diff --git a/.cursor/rules/codegraph.mdc b/.cursor/rules/codegraph.mdc
new file mode 100644
index 0000000..3f23cf6
--- /dev/null
+++ b/.cursor/rules/codegraph.mdc
@@ -0,0 +1,38 @@
+---
+description: CodeGraph MCP usage guide — when to use which tool
+alwaysApply: true
+---
+
+## CodeGraph
+
+This project has a CodeGraph MCP server (`codegraph_*` tools) configured. CodeGraph is a tree-sitter-parsed knowledge graph of every symbol, edge, and file. Reads are sub-millisecond and return structural information grep cannot.
+
+### When to prefer codegraph over native search
+
+Use codegraph for **structural** questions — what calls what, what would break, where is X defined, what is X's signature. Use native grep/read only for **literal text** queries (string contents, comments, log messages) or after you already have a specific file open.
+
+| Question | Tool |
+|---|---|
+| "Where is X defined?" / "Find symbol named X" | `codegraph_search` |
+| "What calls function Y?" | `codegraph_callers` |
+| "What does Y call?" | `codegraph_callees` |
+| "What would break if I changed Z?" | `codegraph_impact` |
+| "Show me Y's signature / source / docstring" | `codegraph_node` |
+| "Give me focused context for a task/area" | `codegraph_context` |
+| "See several related symbols' source at once" | `codegraph_explore` |
+| "What files exist under path/" | `codegraph_files` |
+| "Is the index healthy?" | `codegraph_status` |
+
+### Rules of thumb
+
+- **Answer directly — don't delegate exploration.** For "how does X work" / architecture / trace questions, answer with 2-3 codegraph calls: `codegraph_context` first, then ONE `codegraph_explore` for the source of the symbols it surfaces. Codegraph IS the pre-built index, so spawning a separate file-reading sub-task/agent — or running a grep + read loop — repeats work codegraph already did and costs more for the same answer.
+- **Trust codegraph results.** They come from a full AST parse. Do NOT re-verify them with grep — that's slower, less accurate, and wastes context.
+- **Don't grep first** when looking up a symbol by name. `codegraph_search` is faster and returns kind + location + signature in one call.
+- **Don't chain `codegraph_search` + `codegraph_node`** when you just want context — `codegraph_context` is one call.
+- **Don't loop `codegraph_node` over many symbols** — one `codegraph_explore` call returns several symbols' source grouped in a single capped call, while each separate node/Read call re-reads the whole context and costs far more.
+- **Index lag**: the file watcher debounces ~500ms behind writes; don't re-query immediately after editing a file in the same turn.
+
+### If `.codegraph/` doesn't exist
+
+The MCP server returns "not initialized." Ask the user: *"I notice this project doesn't have CodeGraph initialized. Want me to run `codegraph init -i` to build the index?"*
+
diff --git a/AGENTS.md b/AGENTS.md
new file mode 100644
index 0000000..c618c33
--- /dev/null
+++ b/AGENTS.md
@@ -0,0 +1,44 @@
+# GEO Platform Agents
+
+## Agent Registry
+
+| Agent | Name | Type | Supported Tasks | File |
+|-------|------|------|----------------|------|
+| CitationDetectorAgent | citation_detector | CITATION_DETECTOR | citation_detect, citation_batch | backend/app/agent_framework/agents/citation_detector.py |
+| ContentGeneratorAgent | content_generator | CONTENT_GENERATOR | content_generate, content_regenerate | backend/app/agent_framework/agents/content_generator_agent.py |
+| DeAIAgent | deai_agent | DEAI_AGENT | deai_process | backend/app/agent_framework/agents/deai_agent.py |
+| GEOOptimizerAgent | geo_optimizer | GEO_OPTIMIZER | geo_optimize | backend/app/agent_framework/agents/geo_optimizer.py |
+| MonitorAgent | monitor | PERFORMANCE_TRACKER | monitor_track, monitor_check_single | backend/app/agent_framework/agents/monitor_agent.py |
+| SchemaAdvisorAgent | schema_advisor | SCHEMA_ADVISOR | schema_advise | backend/app/agent_framework/agents/schema_advisor.py |
+| CompetitorAnalyzerAgent | competitor_analyzer | COMPETITOR_ANALYZER | competitor_analyze, competitor_gap_analysis | backend/app/agent_framework/agents/competitor_analyzer.py |
+| TrendAgent | trend_agent | TREND_AGENT | trend_insight, trend_hotspot | backend/app/agent_framework/agents/trend_agent.py |
+
+## Running Agents
+
+### Standalone Mode (No Redis Required)
+```bash
+cd geo/backend
+python3 -m app.agent_framework.standalone [agent_name|all]
+```
+
+### With Redis Queue
+Agents auto-register via Registry when Redis is available. Tasks dispatched via TaskDispatcher.
+
+## Agent Framework Components
+- BaseAgent: Abstract base class (backend/app/agent_framework/base_agent.py)
+- Dispatcher: Task distribution (backend/app/agent_framework/dispatcher.py)
+- Registry: Agent registration (backend/app/agent_framework/registry.py)
+- Protocol: Message types and AgentType enum (backend/app/agent_framework/protocol.py)
+- ConfigManager: Agent configuration (backend/app/agent_framework/config_manager.py)
+
+## New Agent Data Models
+- MonitoringRecord + ContentBaseline (backend/app/models/monitoring.py)
+- SchemaSuggestion (backend/app/models/schema_suggestion.py)
+- CompetitorInsight (backend/app/models/competitor_insight.py)
+- TrendInsight (backend/app/models/trend_insight.py)
+
+## New Agent API Endpoints
+- /api/v1/monitoring - MonitorAgent API
+- /api/v1/competitor - CompetitorAnalyzer API
+- /api/v1/schema - SchemaAdvisor API
+- /api/v1/trends - TrendAgent API
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..cb11f3e
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,58 @@
+# GEO Platform - AI Context
+
+## Project Overview
+GEO (Generative Engine Optimization) SaaS platform that helps brands get cited by AI search engines (ChatGPT, Perplexity, DeepSeek, etc.).
+
+## Tech Stack
+- Backend: FastAPI + SQLAlchemy 2.0 (async) + PostgreSQL 15 + Redis 7
+- Frontend: Next.js 14 + React 18 + TypeScript + Tailwind CSS + shadcn/ui
+- AI: Multi-LLM provider (DeepSeek, OpenAI, etc.) + RAG knowledge base
+- Infrastructure: Docker Compose + Prometheus metrics
+
+## Architecture
+- Backend API: 33 route modules under backend/app/api/
+- Agent Framework: 8 agents (CitationDetector, ContentGenerator, DeAIAgent, GEOOptimizer, MonitorAgent, SchemaAdvisor, CompetitorAnalyzer, TrendAgent)
+- Data Models: 35 models under backend/app/models/
+- Services: 80+ service files organized by domain under backend/app/services/
+- Repositories: 12 repository files under backend/app/repositories/ for data access
+- Frontend: 25+ pages under frontend/app/(dashboard)/
+- Frontend API Clients: 27 modules under frontend/lib/api/
+
+## Key Patterns
+- Repository pattern for data access (backend/app/repositories/)
+- Agent Framework: BaseAgent abstract class → concrete agents, Dispatcher + Registry + Redis Queue
+- Shared BrandScoringDataService for V2 5-dimension scoring
+- Content Pipeline: Generate → DeAI → GEO Optimize → HTML Output
+- GEO Workflow: Diagnosis → Strategy → Plan → Content Generation → Monitoring
+
+## API Conventions
+- All routes under /api/v1/ prefix
+- Authentication via JWT Bearer token (get_current_user dependency)
+- Pydantic schemas for request/response validation
+- Async SQLAlchemy ORM with AsyncSession
+
+## Frontend Conventions
+- API clients in frontend/lib/api/ using fetchWithAuth from client.ts
+- Barrel export in frontend/lib/api/index.ts
+- Types co-located with API modules
+- Dashboard pages under frontend/app/(dashboard)/dashboard/
+
+## Environment
+- Backend runs on port 8000
+- Frontend runs on port 3001 (dev)
+- PostgreSQL on port 5432
+- Redis on port 6379
+- .env.example has all required variables
+
+## Common Commands
+- Start backend: cd geo/backend && python3 -m uvicorn app.main:app --reload --port 8000
+- Start frontend: cd geo/frontend && npm run dev
+- Start all agents: cd geo/backend && python3 -m app.agent_framework.standalone all
+- Docker: docker-compose up -d
+
+## Important Notes
+- SEOOptimizer in code actually performs GEO optimization, not traditional SEO
+- content.py (content generation) vs contents.py (content management) - different modules
+- Analytics endpoints all require authentication via _get_org_id or get_current_user
+- CORS: dev mode uses allow_origins=["*"], production must restrict
+- Brand scoring uses V2 5-dimension system: mention_rate, recommendation_rank, sentiment_score, citation_quality, competitive_position
diff --git a/backend/app/agent_framework/agents/__init__.py b/backend/app/agent_framework/agents/__init__.py
index 87cc090..3275afd 100644
--- a/backend/app/agent_framework/agents/__init__.py
+++ b/backend/app/agent_framework/agents/__init__.py
@@ -4,10 +4,12 @@ from .citation_detector import CitationDetectorAgent
from .content_generator_agent import ContentGeneratorAgent
from .deai_agent import DeAIAgent
from .geo_optimizer_agent import GEOOptimizerAgent
+from .trend_agent import TrendAgent
__all__ = [
"CitationDetectorAgent",
"ContentGeneratorAgent",
"DeAIAgent",
"GEOOptimizerAgent",
+ "TrendAgent",
]
diff --git a/backend/app/agent_framework/agents/citation_detector.py b/backend/app/agent_framework/agents/citation_detector.py
index faaf13a..80c18b2 100644
--- a/backend/app/agent_framework/agents/citation_detector.py
+++ b/backend/app/agent_framework/agents/citation_detector.py
@@ -1,9 +1,9 @@
-"""CitationDetector Agent - 将现有 CitationEngine 封装为 Agent"""
-
import logging
-import re
import time
-from datetime import datetime, timezone
+import uuid
+from datetime import datetime, timedelta, timezone
+
+from sqlalchemy import select
from app.agent_framework.base import BaseAgent
from app.agent_framework.protocol import (
@@ -16,19 +16,13 @@ from app.agent_framework.protocol import (
from app.database import AsyncSessionLocal
from app.models.citation_record import CitationRecord
from app.models.query import Query
-from app.workers.citation_engine import CitationEngine
+from app.models.query_task import QueryTask
+from app.services.ai_engine.platform_bridge import execute_single_platform
logger = logging.getLogger(__name__)
class CitationDetectorAgent(BaseAgent):
- """
- 引用检测 Agent:将现有 CitationEngine 封装为 BaseAgent 实现。
-
- 支持的任务类型:
- - citation_detect: 执行完整的引用检测(遍历 query 的所有平台)
- - citation_detect_single: 执行单个平台的引用检测
- """
def __init__(self):
super().__init__(
@@ -36,7 +30,6 @@ class CitationDetectorAgent(BaseAgent):
agent_type=AgentType.CITATION_DETECTOR,
version="1.0.0",
)
- self._engine = CitationEngine()
def get_capabilities(self) -> AgentCapability:
return AgentCapability(
@@ -49,7 +42,6 @@ class CitationDetectorAgent(BaseAgent):
)
async def execute(self, task: TaskMessage) -> TaskResult:
- """执行引用检测任务"""
started_at = datetime.now(timezone.utc)
start_time = time.monotonic()
@@ -94,17 +86,11 @@ class CitationDetectorAgent(BaseAgent):
)
async def _execute_full_detect(self, task: TaskMessage) -> dict:
- """
- 执行完整的引用检测(遍历 query 的所有平台)。
- input_data 需包含: query_id (str)
- """
query_id = task.input_data.get("query_id")
if not query_id:
raise ValueError("input_data must contain 'query_id'")
async with AsyncSessionLocal() as db:
- from sqlalchemy import select
-
stmt = select(Query).where(Query.id == query_id)
result = await db.execute(stmt)
query = result.scalar_one_or_none()
@@ -112,23 +98,75 @@ class CitationDetectorAgent(BaseAgent):
if not query:
raise ValueError(f"Query {query_id} not found")
- # 上报进度:开始检测
await self.report_progress(
task_id=task.task_id,
progress=0.1,
message=f"Starting citation detection for query '{query.keyword}'",
)
- records = await self._engine.execute_query(query, db)
+ records: list[CitationRecord] = []
+ platforms = query.platforms or ["wenxin", "kimi"]
+ brand_aliases = query.brand_aliases or []
+
+ for i, platform_name in enumerate(platforms):
+ progress = 0.1 + 0.8 * (i / len(platforms))
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=progress,
+ message=f"Detecting on platform '{platform_name}' ({i+1}/{len(platforms)})",
+ )
+
+ task_obj = await self._get_or_create_task(db, query.id, platform_name)
+ task_obj.status = "running"
+ task_obj.started_at = datetime.utcnow()
+ task_obj.error_message = None
+ await db.commit()
+
+ try:
+ detect_result = await execute_single_platform(
+ keyword=query.keyword,
+ platform=platform_name,
+ target_brand=query.target_brand,
+ brand_aliases=brand_aliases,
+ )
+
+ record = CitationRecord.from_citation_result(
+ query_id=query.id,
+ platform=platform_name,
+ result=detect_result,
+ )
+ db.add(record)
+ records.append(record)
+
+ task_obj.status = "success"
+ task_obj.completed_at = datetime.utcnow()
+ await db.commit()
+
+ except Exception as e:
+ logger.error(f"平台 {platform_name} 查询失败: {e}")
+ task_obj.status = "failed"
+ task_obj.error_message = str(e)
+ task_obj.completed_at = datetime.utcnow()
+
+ record = CitationRecord.from_citation_result(
+ query_id=query.id,
+ platform=platform_name,
+ result={"cited": False, "raw_response": str(e)},
+ )
+ db.add(record)
+ records.append(record)
+ await db.commit()
+
+ query.last_queried_at = datetime.utcnow()
+ query.next_query_at = self._calculate_next_query_at(query.frequency)
+ await db.commit()
- # 上报进度:检测完成
await self.report_progress(
task_id=task.task_id,
progress=1.0,
message=f"Detection completed: {len(records)} records found",
)
- # 构建输出
record_summaries = []
for r in records:
record_summaries.append({
@@ -148,10 +186,6 @@ class CitationDetectorAgent(BaseAgent):
}
async def _execute_single_detect(self, task: TaskMessage) -> dict:
- """
- 执行单个平台的引用检测。
- input_data 需包含: keyword, platform, target_brand, brand_aliases(optional)
- """
keyword = task.input_data.get("keyword")
platform = task.input_data.get("platform")
target_brand = task.input_data.get("target_brand")
@@ -162,56 +196,118 @@ class CitationDetectorAgent(BaseAgent):
"input_data must contain 'keyword', 'platform', 'target_brand'"
)
- # 上报进度
await self.report_progress(
task_id=task.task_id,
progress=0.2,
message=f"Querying platform '{platform}' with keyword '{keyword}'",
)
- result = await self._engine.execute_single_platform(
+ result = await execute_single_platform(
keyword=keyword,
platform=platform,
target_brand=target_brand,
brand_aliases=brand_aliases,
)
- # 上报进度:完成
await self.report_progress(
task_id=task.task_id,
progress=1.0,
message=f"Single platform detection completed on '{platform}'",
)
- # 清理 raw_response 以避免返回过大的数据
output = {k: v for k, v in result.items() if k != "raw_response"}
return output
- # -----------------------------------------------------------------------
- # 向后兼容:保留原有 CitationEngine 的同步调用接口
- # -----------------------------------------------------------------------
-
async def execute_query_compat(self, query: Query, db) -> list[CitationRecord]:
- """
- 向后兼容方法:供现有 scheduler 继续调用。
- 签名与 CitationEngine.execute_query 完全一致。
- """
- return await self._engine.execute_query(query, db)
+ records: list[CitationRecord] = []
+ platforms = query.platforms or ["wenxin", "kimi"]
+ brand_aliases = query.brand_aliases or []
+
+ for platform_name in platforms:
+ task_obj = await self._get_or_create_task(db, query.id, platform_name)
+ task_obj.status = "running"
+ task_obj.started_at = datetime.utcnow()
+ task_obj.error_message = None
+ await db.commit()
+
+ try:
+ detect_result = await execute_single_platform(
+ keyword=query.keyword,
+ platform=platform_name,
+ target_brand=query.target_brand,
+ brand_aliases=brand_aliases,
+ )
+
+ record = CitationRecord.from_citation_result(
+ query_id=query.id,
+ platform=platform_name,
+ result=detect_result,
+ )
+ db.add(record)
+ records.append(record)
+
+ task_obj.status = "success"
+ task_obj.completed_at = datetime.utcnow()
+ await db.commit()
+
+ except Exception as e:
+ logger.error(f"平台 {platform_name} 查询失败: {e}")
+ task_obj.status = "failed"
+ task_obj.error_message = str(e)
+ task_obj.completed_at = datetime.utcnow()
+
+ record = CitationRecord.from_citation_result(
+ query_id=query.id,
+ platform=platform_name,
+ result={"cited": False, "raw_response": str(e)},
+ )
+ db.add(record)
+ records.append(record)
+ await db.commit()
+
+ query.last_queried_at = datetime.utcnow()
+ query.next_query_at = self._calculate_next_query_at(query.frequency)
+ await db.commit()
+
+ return records
async def execute_single_platform_compat(
self, keyword: str, platform: str, target_brand: str, brand_aliases: list
) -> dict:
- """
- 向后兼容方法:供现有 scheduler 继续调用。
- 签名与 CitationEngine.execute_single_platform 完全一致。
- """
- return await self._engine.execute_single_platform(
+ return await execute_single_platform(
keyword=keyword,
platform=platform,
target_brand=target_brand,
brand_aliases=brand_aliases,
)
- async def close(self):
- """关闭底层引擎"""
- await self._engine.close()
+ async def _get_or_create_task(self, db, query_id: uuid.UUID, platform: str) -> QueryTask:
+ stmt = select(QueryTask).where(
+ QueryTask.query_id == query_id,
+ QueryTask.platform == platform,
+ )
+ result = await db.execute(stmt)
+ task_obj = result.scalar_one_or_none()
+
+ if not task_obj:
+ task_obj = QueryTask(
+ query_id=query_id,
+ platform=platform,
+ status="pending",
+ )
+ db.add(task_obj)
+ await db.commit()
+ await db.refresh(task_obj)
+
+ return task_obj
+
+ @staticmethod
+ def _calculate_next_query_at(frequency: str | None) -> datetime:
+ now = datetime.utcnow()
+ freq_map = {
+ "daily": timedelta(days=1),
+ "weekly": timedelta(days=7),
+ "monthly": timedelta(days=30),
+ }
+ delta = freq_map.get(frequency or "weekly", timedelta(days=7))
+ return now + delta
diff --git a/backend/app/agent_framework/agents/competitor_analyzer.py b/backend/app/agent_framework/agents/competitor_analyzer.py
new file mode 100644
index 0000000..1cfd023
--- /dev/null
+++ b/backend/app/agent_framework/agents/competitor_analyzer.py
@@ -0,0 +1,163 @@
+import logging
+import time
+from datetime import datetime, timezone
+
+from app.agent_framework.base import BaseAgent
+from app.agent_framework.protocol import (
+ AgentCapability,
+ AgentType,
+ TaskMessage,
+ TaskResult,
+ TaskStatus,
+)
+from app.services.llm import LLMFactory, LLMError
+
+logger = logging.getLogger(__name__)
+
+
+class CompetitorAnalyzerAgent(BaseAgent):
+
+ def __init__(self):
+ super().__init__(
+ name="competitor_analyzer",
+ agent_type=AgentType.COMPETITOR_ANALYZER,
+ version="1.0.0",
+ )
+
+ def get_capabilities(self) -> AgentCapability:
+ return AgentCapability(
+ agent_name=self.name,
+ agent_type=self.agent_type,
+ version=self.version,
+ supported_tasks=["competitor_analyze", "competitor_gap_analysis"],
+ max_concurrency=2,
+ description="竞品策略分析Agent:对比品牌与竞品的引用数据,识别差距领域,发现机会点,生成策略建议",
+ )
+
+ async def execute(self, task: TaskMessage) -> TaskResult:
+ started_at = datetime.now(timezone.utc)
+ start_time = time.monotonic()
+
+ try:
+ if task.task_type == "competitor_gap_analysis":
+ output = await self._gap_analysis(task)
+ else:
+ output = await self._analyze(task)
+
+ elapsed = time.monotonic() - start_time
+ return TaskResult(
+ task_id=task.task_id,
+ agent_name=self.name,
+ status=TaskStatus.COMPLETED,
+ output_data=output,
+ error_message=None,
+ started_at=started_at,
+ completed_at=datetime.now(timezone.utc),
+ metrics={
+ "elapsed_seconds": round(elapsed, 2),
+ "task_type": task.task_type,
+ },
+ )
+
+ except LLMError as e:
+ elapsed = time.monotonic() - start_time
+ logger.error(f"CompetitorAnalyzer LLM error on task {task.task_id}: {e}")
+ return TaskResult(
+ task_id=task.task_id,
+ agent_name=self.name,
+ status=TaskStatus.FAILED,
+ output_data=None,
+ error_message=f"LLM调用失败: {e}",
+ started_at=started_at,
+ completed_at=datetime.now(timezone.utc),
+ metrics={
+ "elapsed_seconds": round(elapsed, 2),
+ "task_type": task.task_type,
+ },
+ )
+
+ except Exception as e:
+ elapsed = time.monotonic() - start_time
+ logger.error(f"CompetitorAnalyzer task {task.task_id} failed: {e}")
+ return TaskResult(
+ task_id=task.task_id,
+ agent_name=self.name,
+ status=TaskStatus.FAILED,
+ output_data=None,
+ error_message=str(e),
+ started_at=started_at,
+ completed_at=datetime.now(timezone.utc),
+ metrics={
+ "elapsed_seconds": round(elapsed, 2),
+ "task_type": task.task_type,
+ },
+ )
+
+ async def _analyze(self, task: TaskMessage) -> dict:
+ from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService
+
+ input_data = task.input_data
+ brand_id = input_data.get("brand_id")
+ analysis_types = input_data.get("analysis_types")
+ period_days = input_data.get("period_days", 30)
+
+ if not brand_id:
+ raise ValueError("input_data必须包含'brand_id'字段")
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.1,
+ message="开始竞品策略分析...",
+ )
+
+ service = CompetitorAnalyzerService()
+ result = await service.analyze_competitor(
+ brand_id=brand_id,
+ analysis_types=analysis_types,
+ period_days=period_days,
+ progress_callback=lambda p, m: self.report_progress(
+ task_id=task.task_id, progress=p, message=m,
+ ),
+ )
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=1.0,
+ message="竞品策略分析完成",
+ )
+
+ return result
+
+ async def _gap_analysis(self, task: TaskMessage) -> dict:
+ from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService
+
+ input_data = task.input_data
+ brand_id = input_data.get("brand_id")
+ period_days = input_data.get("period_days", 30)
+
+ if not brand_id:
+ raise ValueError("input_data必须包含'brand_id'字段")
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.1,
+ message="开始竞品差距分析...",
+ )
+
+ service = CompetitorAnalyzerService()
+ result = await service.analyze_competitor(
+ brand_id=brand_id,
+ analysis_types=["citation_gap", "platform_coverage", "query_overlap"],
+ period_days=period_days,
+ progress_callback=lambda p, m: self.report_progress(
+ task_id=task.task_id, progress=p, message=m,
+ ),
+ )
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=1.0,
+ message="竞品差距分析完成",
+ )
+
+ return result
diff --git a/backend/app/agent_framework/agents/content_generator_agent.py b/backend/app/agent_framework/agents/content_generator_agent.py
index 45d907a..7564fcb 100644
--- a/backend/app/agent_framework/agents/content_generator_agent.py
+++ b/backend/app/agent_framework/agents/content_generator_agent.py
@@ -2,7 +2,6 @@
import json
import logging
-import re
import time
from datetime import datetime, timezone
@@ -16,6 +15,7 @@ from app.agent_framework.protocol import (
TaskStatus,
)
from app.services.llm import LLMFactory, LLMError
+from app.utils.json_extractor import extract_json
logger = logging.getLogger(__name__)
@@ -165,7 +165,7 @@ class ContentGeneratorAgent(BaseAgent):
# 4. 解析JSON输出
try:
- topics = json.loads(self._extract_json(response.content))
+ topics = json.loads(extract_json(response.content))
except json.JSONDecodeError:
topics = [{"title": response.content, "reason": "LLM输出解析失败,返回原始内容"}]
@@ -277,22 +277,3 @@ class ContentGeneratorAgent(BaseAgent):
logger.warning(f"RAG检索失败,跳过知识库上下文: {e}")
return "暂无相关知识库内容"
-
- def _extract_json(self, text: str) -> str:
- """从LLM响应中提取JSON(可能被markdown包裹)"""
- # 尝试提取```json ... ```中的内容
- match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL)
- if match:
- return match.group(1).strip()
- # 尝试直接找 [ 或 { 开头的JSON
- for i, c in enumerate(text):
- if c in '[{':
- depth = 0
- for j in range(i, len(text)):
- if text[j] in '[{':
- depth += 1
- elif text[j] in ']}':
- depth -= 1
- if depth == 0:
- return text[i : j + 1]
- return text
diff --git a/backend/app/agent_framework/agents/geo_optimizer_agent.py b/backend/app/agent_framework/agents/geo_optimizer_agent.py
index dd05c97..586b972 100644
--- a/backend/app/agent_framework/agents/geo_optimizer_agent.py
+++ b/backend/app/agent_framework/agents/geo_optimizer_agent.py
@@ -2,7 +2,6 @@
import json
import logging
-import re
import time
from datetime import datetime, timezone
@@ -16,6 +15,7 @@ from app.agent_framework.protocol import (
TaskStatus,
)
from app.services.llm import LLMFactory, LLMError
+from app.utils.json_extractor import extract_json
logger = logging.getLogger(__name__)
@@ -155,7 +155,7 @@ class GEOOptimizerAgent(BaseAgent):
# 尝试解析JSON输出
try:
- result = json.loads(self._extract_json(response.content))
+ result = json.loads(extract_json(response.content))
result["usage"] = response.usage
# 上报进度:完成
await self.report_progress(
@@ -178,20 +178,3 @@ class GEOOptimizerAgent(BaseAgent):
"changes": ["LLM输出非标准格式,已返回原始优化结果"],
"usage": response.usage,
}
-
- def _extract_json(self, text: str) -> str:
- """从响应中提取JSON"""
- match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL)
- if match:
- return match.group(1).strip()
- for i, c in enumerate(text):
- if c == '{':
- depth = 0
- for j in range(i, len(text)):
- if text[j] == '{':
- depth += 1
- elif text[j] == '}':
- depth -= 1
- if depth == 0:
- return text[i : j + 1]
- return text
diff --git a/backend/app/agent_framework/agents/monitor_agent.py b/backend/app/agent_framework/agents/monitor_agent.py
new file mode 100644
index 0000000..c962c43
--- /dev/null
+++ b/backend/app/agent_framework/agents/monitor_agent.py
@@ -0,0 +1,231 @@
+import logging
+import time
+import uuid
+from datetime import datetime, timezone
+
+from sqlalchemy import select
+
+from app.agent_framework.base import BaseAgent
+from app.agent_framework.protocol import (
+ AgentCapability,
+ AgentType,
+ TaskMessage,
+ TaskResult,
+ TaskStatus,
+)
+from app.database import AsyncSessionLocal
+from app.models.monitoring import MonitoringRecord
+from app.models.query import Query
+from app.models.brand import Brand
+from app.services.monitoring.monitor_service import MonitorService
+
+logger = logging.getLogger(__name__)
+
+
+class MonitorAgent(BaseAgent):
+
+ def __init__(self):
+ super().__init__(
+ name="monitor",
+ agent_type=AgentType.PERFORMANCE_TRACKER,
+ version="1.0.0",
+ )
+
+ def get_capabilities(self) -> AgentCapability:
+ return AgentCapability(
+ agent_name=self.name,
+ agent_type=self.agent_type,
+ version=self.version,
+ supported_tasks=["monitor_track", "monitor_check_single"],
+ max_concurrency=3,
+ description="效果追踪Agent:监测品牌引用量、情感、排名变化,生成变化报告",
+ )
+
+ async def execute(self, task: TaskMessage) -> TaskResult:
+ started_at = datetime.now(timezone.utc)
+ start_time = time.monotonic()
+
+ try:
+ task_type = task.task_type
+ if task_type == "monitor_track":
+ output = await self._monitor_track(task)
+ elif task_type == "monitor_check_single":
+ output = await self._monitor_check_single(task)
+ else:
+ raise ValueError(f"不支持的任务类型: {task_type}")
+
+ elapsed = time.monotonic() - start_time
+ return TaskResult(
+ task_id=task.task_id,
+ agent_name=self.name,
+ status=TaskStatus.COMPLETED,
+ output_data=output,
+ error_message=None,
+ started_at=started_at,
+ completed_at=datetime.now(timezone.utc),
+ metrics={
+ "elapsed_seconds": round(elapsed, 2),
+ "task_type": task_type,
+ },
+ )
+
+ except Exception as e:
+ elapsed = time.monotonic() - start_time
+ logger.error(f"MonitorAgent task {task.task_id} failed: {e}")
+ return TaskResult(
+ task_id=task.task_id,
+ agent_name=self.name,
+ status=TaskStatus.FAILED,
+ output_data=None,
+ error_message=str(e),
+ started_at=started_at,
+ completed_at=datetime.now(timezone.utc),
+ metrics={
+ "elapsed_seconds": round(elapsed, 2),
+ "task_type": task.task_type,
+ },
+ )
+
+ async def _monitor_track(self, task: TaskMessage) -> dict:
+ input_data = task.input_data
+ brand_id = input_data.get("brand_id")
+ if not brand_id:
+ raise ValueError("input_data必须包含'brand_id'字段")
+
+ brand_id = uuid.UUID(brand_id)
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.1,
+ message="开始效果追踪...",
+ )
+
+ async with AsyncSessionLocal() as db:
+ stmt = select(Brand).where(Brand.id == brand_id)
+ result = await db.execute(stmt)
+ brand = result.scalar_one_or_none()
+ if not brand:
+ raise ValueError(f"品牌不存在: {brand_id}")
+
+ queries_stmt = select(Query).where(Query.target_brand == brand.name)
+ queries_result = await db.execute(queries_stmt)
+ queries = list(queries_result.scalars().all())
+
+ if not queries:
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=1.0,
+ message="品牌下无关联查询,效果追踪完成",
+ )
+ return {
+ "brand_id": str(brand_id),
+ "brand_name": brand.name,
+ "total_queries": 0,
+ "reports": [],
+ }
+
+ total_queries = len(queries)
+ reports = []
+ service = MonitorService()
+
+ monitoring_stmt = select(MonitoringRecord).where(
+ MonitoringRecord.brand_id == brand_id,
+ MonitoringRecord.status == "active",
+ )
+ monitoring_result = await db.execute(monitoring_stmt)
+ monitoring_records = list(monitoring_result.scalars().all())
+
+ for idx, record in enumerate(monitoring_records):
+ progress = 0.2 + (0.7 * idx / max(len(monitoring_records), 1))
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=progress,
+ message=f"正在检测第 {idx + 1}/{len(monitoring_records)} 条监测记录...",
+ )
+
+ updated_record = await service.check_and_compare(db, record.id)
+ if updated_record:
+ report = await service.generate_change_report(db, updated_record.id)
+ if report:
+ reports.append(report)
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=1.0,
+ message="效果追踪完成",
+ )
+
+ return {
+ "brand_id": str(brand_id),
+ "brand_name": brand.name,
+ "total_queries": total_queries,
+ "checked_records": len(monitoring_records),
+ "reports": reports,
+ }
+
+ async def _monitor_check_single(self, task: TaskMessage) -> dict:
+ input_data = task.input_data
+ brand_id = input_data.get("brand_id")
+ keyword = input_data.get("keyword")
+ platform = input_data.get("platform")
+
+ if not brand_id:
+ raise ValueError("input_data必须包含'brand_id'字段")
+
+ brand_id = uuid.UUID(brand_id)
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.1,
+ message="开始单关键词检测...",
+ )
+
+ async with AsyncSessionLocal() as db:
+ service = MonitorService()
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.3,
+ message="正在创建监测记录...",
+ )
+
+ record = await service.create_monitoring_record(
+ db=db,
+ brand_id=brand_id,
+ query_keywords=keyword,
+ platform=platform,
+ check_interval_hours=input_data.get("check_interval_hours", 24),
+ )
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.6,
+ message="正在执行检测对比...",
+ )
+
+ updated_record = await service.check_and_compare(db, record.id)
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.8,
+ message="正在生成变化报告...",
+ )
+
+ report = None
+ if updated_record:
+ report = await service.generate_change_report(db, updated_record.id)
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=1.0,
+ message="单关键词检测完成",
+ )
+
+ return {
+ "record_id": str(record.id),
+ "brand_id": str(brand_id),
+ "keyword": keyword,
+ "platform": platform,
+ "change_type": updated_record.change_type if updated_record else None,
+ "report": report,
+ }
diff --git a/backend/app/agent_framework/agents/schema_advisor.py b/backend/app/agent_framework/agents/schema_advisor.py
new file mode 100644
index 0000000..518b281
--- /dev/null
+++ b/backend/app/agent_framework/agents/schema_advisor.py
@@ -0,0 +1,398 @@
+import copy
+import json
+import logging
+import time
+from datetime import datetime, timezone
+
+from app.agent_framework.base import BaseAgent
+from app.agent_framework.prompts.schema_advisor import SCHEMA_ADVISOR_TEMPLATE
+from app.agent_framework.protocol import (
+ AgentCapability,
+ AgentType,
+ TaskMessage,
+ TaskResult,
+ TaskStatus,
+)
+from app.services.llm import LLMFactory, LLMError
+from app.utils.json_extractor import extract_json
+
+logger = logging.getLogger(__name__)
+
+SCHEMA_TEMPLATES = {
+ "Organization": {
+ "@context": "https://schema.org",
+ "@type": "Organization",
+ "name": "",
+ "description": "",
+ "url": "",
+ "logo": "",
+ "sameAs": [],
+ "contactPoint": {
+ "@type": "ContactPoint",
+ "contactType": "customer service",
+ "telephone": "",
+ },
+ },
+ "Product": {
+ "@context": "https://schema.org",
+ "@type": "Product",
+ "name": "",
+ "description": "",
+ "brand": {"@type": "Brand", "name": ""},
+ "offers": {
+ "@type": "Offer",
+ "priceCurrency": "CNY",
+ "availability": "https://schema.org/InStock",
+ },
+ },
+ "FAQPage": {
+ "@context": "https://schema.org",
+ "@type": "FAQPage",
+ "mainEntity": [
+ {
+ "@type": "Question",
+ "name": "",
+ "acceptedAnswer": {
+ "@type": "Answer",
+ "text": "",
+ },
+ }
+ ],
+ },
+ "Article": {
+ "@context": "https://schema.org",
+ "@type": "Article",
+ "headline": "",
+ "description": "",
+ "author": {"@type": "Organization", "name": ""},
+ "datePublished": "",
+ "image": "",
+ },
+ "LocalBusiness": {
+ "@context": "https://schema.org",
+ "@type": "LocalBusiness",
+ "name": "",
+ "address": {
+ "@type": "PostalAddress",
+ "streetAddress": "",
+ "addressLocality": "",
+ "addressRegion": "",
+ "postalCode": "",
+ "addressCountry": "CN",
+ },
+ "geo": {
+ "@type": "GeoCoordinates",
+ "latitude": "",
+ "longitude": "",
+ },
+ "telephone": "",
+ "openingHours": "",
+ },
+}
+
+DIMENSION_SCHEMA_MAP = {
+ "schema_marketing": ["Organization", "LocalBusiness"],
+ "entity_clarity": ["Organization", "Product"],
+ "citation_readiness": ["FAQPage", "Article"],
+ "brand_visibility": ["Organization", "Product"],
+ "local_seo": ["LocalBusiness"],
+}
+
+PRIORITY_THRESHOLD = {
+ "high": 30.0,
+ "medium": 60.0,
+}
+
+DIFFICULTY_MAP = {
+ "Organization": "easy",
+ "Product": "medium",
+ "FAQPage": "medium",
+ "Article": "easy",
+ "LocalBusiness": "hard",
+}
+
+
+class SchemaAdvisorAgent(BaseAgent):
+
+ def __init__(self):
+ super().__init__(
+ name="schema_advisor",
+ agent_type=AgentType.SCHEMA_ADVISOR,
+ version="1.0.0",
+ )
+
+ def get_capabilities(self) -> AgentCapability:
+ return AgentCapability(
+ agent_name=self.name,
+ agent_type=self.agent_type,
+ version=self.version,
+ supported_tasks=["schema_advise"],
+ max_concurrency=2,
+ description="Schema优化建议Agent:识别Schema缺失维度,生成JSON-LD结构化数据建议",
+ )
+
+ async def execute(self, task: TaskMessage) -> TaskResult:
+ started_at = datetime.now(timezone.utc)
+ start_time = time.monotonic()
+
+ try:
+ output = await self._advise(task)
+
+ elapsed = time.monotonic() - start_time
+ return TaskResult(
+ task_id=task.task_id,
+ agent_name=self.name,
+ status=TaskStatus.COMPLETED,
+ output_data=output,
+ error_message=None,
+ started_at=started_at,
+ completed_at=datetime.now(timezone.utc),
+ metrics={
+ "elapsed_seconds": round(elapsed, 2),
+ "task_type": task.task_type,
+ },
+ )
+
+ except LLMError as e:
+ elapsed = time.monotonic() - start_time
+ logger.error(f"SchemaAdvisor LLM error on task {task.task_id}: {e}")
+ return TaskResult(
+ task_id=task.task_id,
+ agent_name=self.name,
+ status=TaskStatus.FAILED,
+ output_data=None,
+ error_message=f"LLM调用失败: {e}",
+ started_at=started_at,
+ completed_at=datetime.now(timezone.utc),
+ metrics={
+ "elapsed_seconds": round(elapsed, 2),
+ "task_type": task.task_type,
+ },
+ )
+
+ except Exception as e:
+ elapsed = time.monotonic() - start_time
+ logger.error(f"SchemaAdvisor task {task.task_id} failed: {e}")
+ return TaskResult(
+ task_id=task.task_id,
+ agent_name=self.name,
+ status=TaskStatus.FAILED,
+ output_data=None,
+ error_message=str(e),
+ started_at=started_at,
+ completed_at=datetime.now(timezone.utc),
+ metrics={
+ "elapsed_seconds": round(elapsed, 2),
+ "task_type": task.task_type,
+ },
+ )
+
+ async def _advise(self, task: TaskMessage) -> dict:
+ input_data = task.input_data
+ brand_id = input_data.get("brand_id")
+ diagnosis_data = input_data.get("diagnosis_data", {})
+ brand_info = input_data.get("brand_info", {})
+ focus_dimensions = input_data.get("focus_dimensions")
+
+ if not brand_id:
+ raise ValueError("input_data必须包含'brand_id'字段")
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.1,
+ message="开始Schema建议分析...",
+ )
+
+ missing_dimensions = self._identify_missing_dimensions(diagnosis_data, focus_dimensions)
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.3,
+ message=f"识别到{len(missing_dimensions)}个Schema缺失维度...",
+ )
+
+ matched = self._match_templates(missing_dimensions)
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.5,
+ message="匹配预定义模板完成,开始LLM填充...",
+ )
+
+ filled = await self._fill_with_llm(matched, brand_info)
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.8,
+ message="LLM填充完成,验证JSON-LD格式...",
+ )
+
+ validated = self._validate_and_sort(filled)
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=1.0,
+ message="Schema建议生成完成",
+ )
+
+ return {
+ "brand_id": brand_id,
+ "suggestions": validated,
+ "total": len(validated),
+ }
+
+ def _identify_missing_dimensions(
+ self,
+ diagnosis_data: dict,
+ focus_dimensions: list[str] | None = None,
+ ) -> list[dict]:
+ dimensions = []
+ dimension_scores = diagnosis_data.get("dimensions", {})
+ for dim_name, dim_info in dimension_scores.items():
+ if dim_name not in DIMENSION_SCHEMA_MAP:
+ continue
+ if focus_dimensions and dim_name not in focus_dimensions:
+ continue
+ score = dim_info.get("score", 0) if isinstance(dim_info, dict) else dim_info
+ max_score = dim_info.get("max_score", 100) if isinstance(dim_info, dict) else 100
+ percentage = (score / max_score * 100) if max_score > 0 else 0
+ if percentage < 80:
+ dimensions.append({
+ "dimension": dim_name,
+ "current_score": round(score, 2),
+ "max_score": max_score,
+ "percentage": round(percentage, 2),
+ })
+ if not dimensions and diagnosis_data:
+ overall = diagnosis_data.get("overall_score", 0)
+ if overall < 80:
+ for dim_name in DIMENSION_SCHEMA_MAP:
+ if focus_dimensions and dim_name not in focus_dimensions:
+ continue
+ dimensions.append({
+ "dimension": dim_name,
+ "current_score": 0,
+ "max_score": 100,
+ "percentage": 0,
+ })
+ return dimensions
+
+ def _match_templates(self, missing_dimensions: list[dict]) -> list[dict]:
+ matched = []
+ seen_types = set()
+ for dim in missing_dimensions:
+ schema_types = DIMENSION_SCHEMA_MAP.get(dim["dimension"], [])
+ for schema_type in schema_types:
+ if schema_type in seen_types:
+ continue
+ seen_types.add(schema_type)
+ template = SCHEMA_TEMPLATES.get(schema_type)
+ if template:
+ percentage = dim["percentage"]
+ if percentage < PRIORITY_THRESHOLD["high"]:
+ priority = "high"
+ elif percentage < PRIORITY_THRESHOLD["medium"]:
+ priority = "medium"
+ else:
+ priority = "low"
+ matched.append({
+ "schema_type": schema_type,
+ "priority": priority,
+ "diagnosis_dimensions": {
+ "dimension": dim["dimension"],
+ "current_score": dim["current_score"],
+ "max_score": dim["max_score"],
+ "percentage": dim["percentage"],
+ },
+ "json_ld_template": copy.deepcopy(template),
+ "implementation_difficulty": DIFFICULTY_MAP.get(schema_type, "medium"),
+ })
+ return matched
+
+ async def _fill_with_llm(self, matched: list[dict], brand_info: dict) -> list[dict]:
+ provider = LLMFactory.get_default()
+ results = []
+ for item in matched:
+ schema_type = item["schema_type"]
+ try:
+ variables = {
+ "brand_name": brand_info.get("name", ""),
+ "brand_website": brand_info.get("website", ""),
+ "brand_industry": brand_info.get("industry", ""),
+ "schema_type": schema_type,
+ "diagnosis_data": json.dumps(item.get("diagnosis_dimensions", {}), ensure_ascii=False),
+ "existing_schemas": "无",
+ }
+ messages = SCHEMA_ADVISOR_TEMPLATE.render(variables)
+ response = await provider.chat(
+ messages,
+ temperature=0.3,
+ max_tokens=2048,
+ )
+ filled = json.loads(extract_json(response.content))
+ item["json_ld_filled"] = filled
+ item["estimated_impact"] = self._generate_impact_description(
+ schema_type, item.get("diagnosis_dimensions", {}).get("dimension", "")
+ )
+ except (json.JSONDecodeError, LLMError, ValueError) as e:
+ logger.warning(f"LLM填充Schema {schema_type} 失败: {e}")
+ item["json_ld_filled"] = None
+ item["estimated_impact"] = self._generate_impact_description(
+ schema_type, item.get("diagnosis_dimensions", {}).get("dimension", "")
+ )
+ results.append(item)
+ return results
+
+ def _validate_json_ld(self, json_ld: dict) -> dict:
+ errors = []
+ warnings = []
+
+ if not json_ld:
+ return {"is_valid": False, "errors": ["JSON-LD为空"], "warnings": []}
+
+ if "@context" not in json_ld:
+ errors.append("缺少@context字段")
+
+ if "@type" not in json_ld:
+ errors.append("缺少@type字段")
+
+ if "@context" in json_ld and json_ld["@context"] != "https://schema.org":
+ warnings.append(f"@context值非标准: {json_ld.get('@context')}")
+
+ if "@type" in json_ld and json_ld["@type"] not in SCHEMA_TEMPLATES:
+ warnings.append(f"@type非推荐类型: {json_ld.get('@type')}")
+
+ try:
+ json.dumps(json_ld)
+ except (json.JSONDecodeError, TypeError) as e:
+ errors.append(f"JSON序列化失败: {e}")
+
+ return {
+ "is_valid": len(errors) == 0,
+ "errors": errors,
+ "warnings": warnings,
+ }
+
+ def _validate_and_sort(self, items: list[dict]) -> list[dict]:
+ validated = []
+ for item in items:
+ json_ld_filled = item.get("json_ld_filled")
+ if json_ld_filled:
+ validation = self._validate_json_ld(json_ld_filled)
+ item["validation_errors"] = None if validation["is_valid"] else {"errors": validation["errors"], "warnings": validation["warnings"]}
+ else:
+ item["validation_errors"] = {"errors": ["JSON-LD填充失败"], "warnings": []}
+ validated.append(item)
+ priority_order = {"high": 0, "medium": 1, "low": 2}
+ validated.sort(key=lambda x: priority_order.get(x.get("priority", "medium"), 1))
+ return validated
+
+ def _generate_impact_description(self, schema_type: str, dimension: str) -> str:
+ impacts = {
+ "Organization": "增强品牌实体识别,提升AI搜索引擎对品牌的理解和引用概率",
+ "Product": "提升产品在搜索结果中的富摘要展示,增加点击率和引用率",
+ "FAQPage": "增加FAQ富摘要展示机会,提升在AI回答中的直接引用概率",
+ "Article": "优化文章内容的结构化表达,提升AI搜索引擎的内容理解和引用",
+ "LocalBusiness": "增强本地搜索可见性,提升地理位置相关查询的引用率",
+ }
+ return impacts.get(schema_type, f"提升{dimension}维度的得分和AI引用率")
diff --git a/backend/app/agent_framework/agents/trend_agent.py b/backend/app/agent_framework/agents/trend_agent.py
new file mode 100644
index 0000000..34b510a
--- /dev/null
+++ b/backend/app/agent_framework/agents/trend_agent.py
@@ -0,0 +1,166 @@
+import logging
+import time
+from datetime import datetime, timezone
+
+from app.agent_framework.base import BaseAgent
+from app.agent_framework.protocol import (
+ AgentCapability,
+ AgentType,
+ TaskMessage,
+ TaskResult,
+ TaskStatus,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class TrendAgent(BaseAgent):
+
+ def __init__(self):
+ super().__init__(
+ name="trend_agent",
+ agent_type=AgentType.TREND_AGENT,
+ version="1.0.0",
+ )
+
+ def get_capabilities(self) -> AgentCapability:
+ return AgentCapability(
+ agent_name=self.name,
+ agent_type=self.agent_type,
+ version=self.version,
+ supported_tasks=["trend_insight", "trend_hotspot"],
+ max_concurrency=2,
+ description="趋势洞察Agent:分析品牌引用趋势、识别热点话题、推断变化原因并生成建议",
+ )
+
+ async def execute(self, task: TaskMessage) -> TaskResult:
+ started_at = datetime.now(timezone.utc)
+ start_time = time.monotonic()
+
+ try:
+ if task.task_type == "trend_insight":
+ output = await self._trend_insight(task)
+ elif task.task_type == "trend_hotspot":
+ output = await self._trend_hotspot(task)
+ else:
+ raise ValueError(f"不支持的任务类型: {task.task_type}")
+
+ elapsed = time.monotonic() - start_time
+ return TaskResult(
+ task_id=task.task_id,
+ agent_name=self.name,
+ status=TaskStatus.COMPLETED,
+ output_data=output,
+ error_message=None,
+ started_at=started_at,
+ completed_at=datetime.now(timezone.utc),
+ metrics={
+ "elapsed_seconds": round(elapsed, 2),
+ "task_type": task.task_type,
+ },
+ )
+
+ except Exception as e:
+ elapsed = time.monotonic() - start_time
+ logger.error(f"TrendAgent task {task.task_id} failed: {e}")
+ return TaskResult(
+ task_id=task.task_id,
+ agent_name=self.name,
+ status=TaskStatus.FAILED,
+ output_data=None,
+ error_message=str(e),
+ started_at=started_at,
+ completed_at=datetime.now(timezone.utc),
+ metrics={
+ "elapsed_seconds": round(elapsed, 2),
+ "task_type": task.task_type,
+ },
+ )
+
+ async def _trend_insight(self, task: TaskMessage) -> dict:
+ input_data = task.input_data
+ brand_id = input_data.get("brand_id")
+ days = input_data.get("days", 30)
+ platforms = input_data.get("platforms")
+ keywords = input_data.get("keywords")
+
+ if not brand_id:
+ raise ValueError("input_data必须包含'brand_id'字段")
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.1,
+ message="开始趋势洞察分析...",
+ )
+
+ from app.database import AsyncSessionLocal
+ from app.services.trend.trend_analyzer_service import TrendAnalyzerService
+
+ async with AsyncSessionLocal() as db:
+ service = TrendAnalyzerService(db)
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.3,
+ message="获取历史引用数据...",
+ )
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.5,
+ message="执行时间序列分析...",
+ )
+
+ result = await service.analyze_trends(
+ brand_id=brand_id,
+ days=days,
+ platforms=platforms,
+ keywords=keywords,
+ )
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=1.0,
+ message="趋势洞察分析完成",
+ )
+
+ return result
+
+ async def _trend_hotspot(self, task: TaskMessage) -> dict:
+ input_data = task.input_data
+ brand_id = input_data.get("brand_id")
+ days = input_data.get("days", 30)
+
+ if not brand_id:
+ raise ValueError("input_data必须包含'brand_id'字段")
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.1,
+ message="开始热点话题分析...",
+ )
+
+ from app.database import AsyncSessionLocal
+ from app.services.trend.trend_analyzer_service import TrendAnalyzerService
+
+ async with AsyncSessionLocal() as db:
+ service = TrendAnalyzerService(db)
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=0.5,
+ message="检测引用量突增的关键词/话题...",
+ )
+
+ result = await service.get_hotspots(
+ brand_id=brand_id,
+ days=days,
+ )
+
+ await self.report_progress(
+ task_id=task.task_id,
+ progress=1.0,
+ message="热点话题分析完成",
+ )
+
+ return result
diff --git a/backend/app/agent_framework/base.py b/backend/app/agent_framework/base.py
index 93a07cf..8ca954d 100644
--- a/backend/app/agent_framework/base.py
+++ b/backend/app/agent_framework/base.py
@@ -21,6 +21,33 @@ from app.config import settings
logger = logging.getLogger(__name__)
+# Module-level lazy singleton for TaskDispatcher — avoids creating a new
+# dispatcher on every method call while still deferring the import to
+# prevent circular-dependency issues at module-load time.
+_dispatcher_instance = None
+
+
+def _get_dispatcher():
+ """Return a cached TaskDispatcher instance (lazy singleton)."""
+ global _dispatcher_instance
+ if _dispatcher_instance is None:
+ from app.agent_framework.dispatcher import TaskDispatcher
+ _dispatcher_instance = TaskDispatcher(settings.REDIS_URL)
+ return _dispatcher_instance
+
+
+# Module-level lazy singleton for AgentRegistry — same rationale.
+_registry_instance = None
+
+
+def _get_registry():
+ """Return a cached AgentRegistry instance (lazy singleton)."""
+ global _registry_instance
+ if _registry_instance is None:
+ from app.agent_framework.registry import AgentRegistry
+ _registry_instance = AgentRegistry()
+ return _registry_instance
+
class BaseAgent(ABC):
"""所有 Agent 的基类,定义标准生命周期"""
@@ -40,6 +67,10 @@ class BaseAgent(ABC):
def status(self) -> AgentStatus:
return self._status
+ @property
+ def is_distributed(self) -> bool:
+ return self._redis is not None
+
@abstractmethod
async def execute(self, task: TaskMessage) -> TaskResult:
"""执行任务的核心逻辑,子类必须实现"""
@@ -51,48 +82,47 @@ class BaseAgent(ABC):
...
async def start(self):
- """启动 Agent:注册到 Registry,开始监听任务队列"""
logger.info(f"Starting agent '{self.name}' (type={self.agent_type}, version={self.version})")
- # 初始化 Redis 连接
- self._redis = aioredis.from_url(
- settings.REDIS_URL,
- decode_responses=True,
- )
+ try:
+ self._redis = aioredis.from_url(
+ settings.REDIS_URL,
+ decode_responses=True,
+ )
+ await self._redis.ping()
- # 注册到 Registry
- from app.agent_framework.registry import AgentRegistry
+ registry = _get_registry()
+ capability = self.get_capabilities()
+ await registry.register(capability, endpoint=f"agent:{self.name}")
- registry = AgentRegistry()
- capability = self.get_capabilities()
- await registry.register(capability, endpoint=f"agent:{self.name}")
+ self._status = AgentStatus.ONLINE
- # 更新状态
- self._status = AgentStatus.ONLINE
+ capability = self.get_capabilities()
+ max_concurrency = getattr(capability, 'max_concurrency', 1) or 1
+ self._semaphore = asyncio.Semaphore(max_concurrency)
+ logger.info(
+ f"Agent '{self.name}' concurrency limit set to {max_concurrency}"
+ )
- # 根据 capabilities 的 max_concurrency 初始化 Semaphore
- capability = self.get_capabilities()
- max_concurrency = getattr(capability, 'max_concurrency', 1) or 1
- self._semaphore = asyncio.Semaphore(max_concurrency)
- logger.info(
- f"Agent '{self.name}' concurrency limit set to {max_concurrency}"
- )
+ self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
+ self._listen_task = asyncio.create_task(self._listen_for_tasks())
- # 启动心跳
- self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
+ logger.info(f"Agent '{self.name}' started in distributed mode")
+ except Exception as e:
+ self._redis = None
+ self._status = AgentStatus.ONLINE
- # 开始监听任务队列
- self._listen_task = asyncio.create_task(self._listen_for_tasks())
+ capability = self.get_capabilities()
+ max_concurrency = getattr(capability, 'max_concurrency', 1) or 1
+ self._semaphore = asyncio.Semaphore(max_concurrency)
- logger.info(f"Agent '{self.name}' started successfully")
+ logger.warning(f"Agent '{self.name}' started in local mode (Redis unavailable: {e})")
async def stop(self):
- """停止 Agent:注销,停止监听"""
logger.info(f"Stopping agent '{self.name}'")
self._status = AgentStatus.OFFLINE
- # 取消监听任务
if self._listen_task and not self._listen_task.done():
self._listen_task.cancel()
try:
@@ -100,7 +130,6 @@ class BaseAgent(ABC):
except asyncio.CancelledError:
pass
- # 取消心跳任务
if self._heartbeat_task and not self._heartbeat_task.done():
self._heartbeat_task.cancel()
try:
@@ -108,14 +137,10 @@ class BaseAgent(ABC):
except asyncio.CancelledError:
pass
- # 注销
- from app.agent_framework.registry import AgentRegistry
+ if self._redis is not None:
+ registry = _get_registry()
+ await registry.unregister(self.name)
- registry = AgentRegistry()
- await registry.unregister(self.name)
-
- # 关闭 Redis 连接
- if self._redis:
await self._redis.close()
self._redis = None
@@ -123,13 +148,10 @@ class BaseAgent(ABC):
async def heartbeat(self):
"""定期心跳上报"""
- from app.agent_framework.registry import AgentRegistry
-
- registry = AgentRegistry()
+ registry = _get_registry()
await registry.update_heartbeat(self.name)
async def report_progress(self, task_id: str, progress: float, message: str):
- """上报任务进度"""
progress_obj = TaskProgress(
task_id=task_id,
agent_name=self.name,
@@ -138,7 +160,6 @@ class BaseAgent(ABC):
updated_at=datetime.now(timezone.utc),
)
- # 通过 Redis Pub/Sub 发布进度
if self._redis:
try:
await self._redis.publish(
@@ -148,11 +169,11 @@ class BaseAgent(ABC):
except Exception as e:
logger.warning(f"Failed to publish progress for task {task_id}: {e}")
- # 同时更新数据库
- from app.agent_framework.dispatcher import TaskDispatcher
-
- dispatcher = TaskDispatcher(settings.REDIS_URL)
- await dispatcher.handle_progress(progress_obj)
+ try:
+ dispatcher = _get_dispatcher()
+ await dispatcher.handle_progress(progress_obj)
+ except Exception as e:
+ logger.warning(f"Failed to report progress to dispatcher for task {task_id}: {e}")
async def _heartbeat_loop(self):
"""心跳循环"""
@@ -198,7 +219,6 @@ class BaseAgent(ABC):
await self._execute_task(task)
async def _execute_task(self, task: TaskMessage):
- """执行单个任务"""
self._running_tasks.add(task.task_id)
self._status = AgentStatus.BUSY
@@ -209,15 +229,12 @@ class BaseAgent(ABC):
result.started_at = started_at
result.completed_at = datetime.now(timezone.utc)
- # 处理结果
- from app.agent_framework.dispatcher import TaskDispatcher
-
- dispatcher = TaskDispatcher(settings.REDIS_URL)
- await dispatcher.handle_result(result)
+ if self._redis is not None:
+ dispatcher = _get_dispatcher()
+ await dispatcher.handle_result(result)
except Exception as e:
logger.error(f"Agent '{self.name}' task {task.task_id} failed: {e}")
- # 构建失败结果
error_result = TaskResult(
task_id=task.task_id,
agent_name=self.name,
@@ -228,10 +245,9 @@ class BaseAgent(ABC):
completed_at=datetime.now(timezone.utc),
metrics=None,
)
- from app.agent_framework.dispatcher import TaskDispatcher
-
- dispatcher = TaskDispatcher(settings.REDIS_URL)
- await dispatcher.handle_result(error_result)
+ if self._redis is not None:
+ dispatcher = _get_dispatcher()
+ await dispatcher.handle_result(error_result)
finally:
self._running_tasks.discard(task.task_id)
diff --git a/backend/app/agent_framework/prompts/__init__.py b/backend/app/agent_framework/prompts/__init__.py
index 1c87d32..d6cc579 100644
--- a/backend/app/agent_framework/prompts/__init__.py
+++ b/backend/app/agent_framework/prompts/__init__.py
@@ -15,6 +15,7 @@ from .content_generator import CONTENT_GENERATOR_TEMPLATE
from .deai_agent import DEAI_TEMPLATE
from .geo_optimizer import GEO_OPTIMIZER_TEMPLATE
from .rule_checker import RULE_CHECKER_TEMPLATE
+from .schema_advisor import SCHEMA_ADVISOR_TEMPLATE
from .topic_selector import TOPIC_SELECTOR_TEMPLATE
__all__ = [
@@ -25,4 +26,5 @@ __all__ = [
"DEAI_TEMPLATE",
"GEO_OPTIMIZER_TEMPLATE",
"RULE_CHECKER_TEMPLATE",
+ "SCHEMA_ADVISOR_TEMPLATE",
]
diff --git a/backend/app/agent_framework/prompts/schema_advisor.py b/backend/app/agent_framework/prompts/schema_advisor.py
new file mode 100644
index 0000000..a8e1542
--- /dev/null
+++ b/backend/app/agent_framework/prompts/schema_advisor.py
@@ -0,0 +1,70 @@
+from .base_template import PromptSection, PromptTemplate
+
+SCHEMA_ADVISOR_TEMPLATE = PromptTemplate(
+ PromptSection(
+ identity="""你是一位精通Schema.org结构化数据和JSON-LD的技术专家。
+你深刻理解搜索引擎和AI模型(如ChatGPT、Perplexity、Kimi)如何解析和利用结构化数据,
+知道如何通过精准的Schema标记提升品牌在AI搜索结果中的可见性和引用率。
+你生成的JSON-LD严格遵循Schema.org规范,确保可被搜索引擎正确解析。""",
+
+ context="""## 品牌信息
+- 品牌名称:${brand_name}
+- 网站:${brand_website}
+- 行业:${brand_industry}
+
+## 诊断数据
+${diagnosis_data}
+
+## 已有Schema标记
+${existing_schemas}
+
+## 目标Schema类型
+${schema_type}""",
+
+ instructions="""请根据以上品牌信息和诊断数据,为品牌生成完整的JSON-LD结构化数据。
+
+生成要求:
+
+1. 内容填充:
+ - 所有字段必须填充真实、具体的内容,不得留空
+ - 品牌名称、网站等基本信息必须与提供的数据一致
+ - 描述性文本应当专业、准确,体现品牌特色
+
+2. Schema类型特定要求:
+ - Organization: 包含name, description, url, logo, sameAs(社交媒体链接), contactPoint
+ - Product: 包含name, description, brand, offers, aggregateRating(如有)
+ - FAQPage: 生成3-5个与品牌行业相关的高质量FAQ,问题和答案需自然且信息丰富
+ - Article: 包含headline, author, datePublished, description, image
+ - LocalBusiness: 包含name, address(完整地址结构), geo, telephone, openingHours
+
+3. 语言要求:
+ - 所有自然语言内容使用与品牌名称相同的语言
+ - 技术字段(如@type, @context)保持英文
+
+4. 结构完整性:
+ - 必须包含@context和@type
+ - 嵌套对象必须完整,不得省略必要子属性""",
+
+ constraints="""## 约束条件
+- 严格遵循Schema.org规范,不得使用非标准属性
+- @context必须为"https://schema.org"
+- @type必须是Schema.org定义的有效类型
+- 不得编造不存在的品牌信息(如无实际地址,LocalBusiness的address可使用占位结构)
+- FAQ的问题必须是用户真实可能搜索的问题
+- 所有URL字段如无实际值,留空字符串
+- 不得在JSON-LD中包含HTML标签""",
+
+ output_format="""## 输出格式
+请以JSON格式输出填充后的JSON-LD:
+
+```json
+{
+ "@context": "https://schema.org",
+ "@type": "...",
+ "...": "..."
+}
+```
+
+仅输出JSON-LD对象,不要包含任何解释文字。""",
+ )
+)
diff --git a/backend/app/agent_framework/protocol.py b/backend/app/agent_framework/protocol.py
index 624dd0c..82c5e09 100644
--- a/backend/app/agent_framework/protocol.py
+++ b/backend/app/agent_framework/protocol.py
@@ -6,7 +6,6 @@ from enum import Enum
class AgentType(str, Enum):
- """Agent 类型枚举"""
CITATION_DETECTOR = "citation_detector"
CONTENT_GENERATOR = "content_generator"
DEAI_AGENT = "deai_agent"
@@ -14,6 +13,9 @@ class AgentType(str, Enum):
RULE_CHECKER = "rule_checker"
COMPETITOR_ANALYZER = "competitor_analyzer"
PERFORMANCE_TRACKER = "performance_tracker"
+ SCHEMA_ADVISOR = "schema_advisor"
+ MONITOR_AGENT = "monitor_agent"
+ TREND_AGENT = "trend_agent"
class TaskStatus(str, Enum):
diff --git a/backend/app/agent_framework/standalone.py b/backend/app/agent_framework/standalone.py
new file mode 100644
index 0000000..330b2e3
--- /dev/null
+++ b/backend/app/agent_framework/standalone.py
@@ -0,0 +1,62 @@
+import asyncio
+import argparse
+import logging
+import sys
+
+from app.agent_framework.agents.citation_detector import CitationDetectorAgent
+from app.agent_framework.agents.content_generator_agent import ContentGeneratorAgent
+from app.agent_framework.agents.deai_agent import DeAIAgent
+from app.agent_framework.agents.geo_optimizer_agent import GEOOptimizerAgent
+from app.agent_framework.agents.monitor_agent import MonitorAgent
+from app.agent_framework.agents.schema_advisor import SchemaAdvisorAgent
+from app.agent_framework.agents.competitor_analyzer import CompetitorAnalyzerAgent
+from app.agent_framework.agents.trend_agent import TrendAgent
+
+AGENTS = {
+ "citation_detector": CitationDetectorAgent,
+ "content_generator": ContentGeneratorAgent,
+ "deai_agent": DeAIAgent,
+ "geo_optimizer": GEOOptimizerAgent,
+ "monitor": MonitorAgent,
+ "schema_advisor": SchemaAdvisorAgent,
+ "competitor_analyzer": CompetitorAnalyzerAgent,
+ "trend_agent": TrendAgent,
+ "all": None,
+}
+
+
+async def run_agent(name: str):
+ if name == "all":
+ agents = [cls() for cls in AGENTS.values() if cls is not None]
+ else:
+ cls = AGENTS.get(name)
+ if cls is None:
+ print(f"Unknown agent: {name}")
+ sys.exit(1)
+ agents = [cls()]
+
+ for agent in agents:
+ await agent.start()
+
+ print(f"Agent(s) running: {[a.name for a in agents]}")
+ try:
+ await asyncio.Future()
+ except asyncio.CancelledError:
+ pass
+ finally:
+ for agent in agents:
+ await agent.stop()
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Run GEO Agent(s)")
+ parser.add_argument("agent", choices=list(AGENTS.keys()), help="Agent name or 'all'")
+ parser.add_argument("--log-level", default="INFO", help="Logging level")
+ args = parser.parse_args()
+
+ logging.basicConfig(level=getattr(logging, args.log_level.upper()))
+ asyncio.run(run_agent(args.agent))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/backend/app/api/ai_engines.py b/backend/app/api/ai_engines.py
index 4a8f356..a6b8228 100644
--- a/backend/app/api/ai_engines.py
+++ b/backend/app/api/ai_engines.py
@@ -46,6 +46,7 @@ class QueryResultResponse(BaseModel):
brand_context: str | None
competitor_contexts: list[str]
response_time_ms: int
+ timestamp: str
model_config = {"from_attributes": True}
@@ -61,6 +62,7 @@ class CitationRateResponse(BaseModel):
class BatchQueryResponse(BaseModel):
results: list[QueryResultResponse]
citation_rate: CitationRateResponse
+ avg_response_time_ms: float = 0.0
def _result_to_response(r: AIQueryResult) -> QueryResultResponse:
@@ -73,6 +75,7 @@ def _result_to_response(r: AIQueryResult) -> QueryResultResponse:
brand_context=r.brand_context,
competitor_contexts=r.competitor_contexts,
response_time_ms=r.response_time_ms,
+ timestamp=r.timestamp.isoformat(),
)
@@ -97,9 +100,16 @@ async def _execute_batch(
engine_types, query, brand_name, competitor_names
)
citation_rate = service.calculate_citation_rate(results)
+ response_results = [_result_to_response(r) for r in results]
+ avg_response_time = (
+ sum(r.response_time_ms for r in results) / len(results)
+ if results
+ else 0.0
+ )
return BatchQueryResponse(
- results=[_result_to_response(r) for r in results],
+ results=response_results,
citation_rate=CitationRateResponse(**citation_rate),
+ avg_response_time_ms=avg_response_time,
)
diff --git a/backend/app/api/alerts.py b/backend/app/api/alerts.py
index c5fa5a1..f1cf5a7 100644
--- a/backend/app/api/alerts.py
+++ b/backend/app/api/alerts.py
@@ -24,7 +24,7 @@ from app.schemas.alert import (
AlertSettingListResponse,
AlertSettingBulkUpdate,
)
-from app.services.alert_engine import AlertEngine
+from app.services.alert.alert_engine import AlertEngine
logger = logging.getLogger(__name__)
diff --git a/backend/app/api/api_keys.py b/backend/app/api/api_keys.py
index 667ca82..78b56ac 100644
--- a/backend/app/api/api_keys.py
+++ b/backend/app/api/api_keys.py
@@ -6,7 +6,7 @@ from pydantic import BaseModel
from app.api.deps import get_current_user
from app.models.user import User
from app.services.api_key_manager import APIKeyManager, KeySource, KeyStatus
-from app.services.smart_router import ENGINE_COST_PROFILES
+from app.services.llm.smart_router import ENGINE_COST_PROFILES
logger = logging.getLogger(__name__)
diff --git a/backend/app/api/attribution.py b/backend/app/api/attribution.py
new file mode 100644
index 0000000..31a1d6d
--- /dev/null
+++ b/backend/app/api/attribution.py
@@ -0,0 +1,300 @@
+import logging
+import uuid
+from datetime import datetime
+
+from fastapi import APIRouter, Depends, HTTPException, status
+from pydantic import BaseModel
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.api.deps import get_current_user
+from app.database import get_db
+from app.models.attribution_record import AttributionRecord
+from app.models.brand import Brand
+from app.models.diagnosis_record import DiagnosisRecord
+from app.models.user import User
+from app.services.attribution.attribution_engine import AttributionEngine
+from app.services.attribution.roi_calculator import ROICalculator
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+PLAN_COSTS = {
+ "free": 0.0,
+ "starter": 99.0,
+ "pro": 299.0,
+ "enterprise": 999.0,
+}
+
+
+def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
+ if isinstance(value, uuid.UUID):
+ return value
+ return uuid.UUID(str(value))
+
+
+class StartTrackingRequest(BaseModel):
+ brand_id: str
+ content_id: str | None = None
+
+
+class AttributionResponse(BaseModel):
+ id: str
+ brand_id: str
+ content_id: str | None
+ baseline_score: float
+ current_score: float | None
+ score_delta: float | None
+ status: str
+ roi_percentage: float | None
+ created_at: datetime
+
+ model_config = {"from_attributes": True}
+
+
+class ROIReport(BaseModel):
+ brand_id: str
+ brand_name: str
+ subscription_cost: float
+ current_plan: str
+ total_score_delta: float
+ value_generated: float
+ roi_percentage: float
+ break_even_delta: float
+ tracking_records: list[AttributionResponse]
+ ab_comparison: dict | None
+
+
+class ABComparisonResponse(BaseModel):
+ brand_id: str
+ brand_name: str
+ overall_before: float
+ overall_after: float
+ overall_delta: float
+ dimensions: list[dict]
+
+
+def _record_to_response(record: AttributionRecord) -> AttributionResponse:
+ return AttributionResponse(
+ id=str(record.id),
+ brand_id=str(record.brand_id),
+ content_id=str(record.content_id) if record.content_id else None,
+ baseline_score=record.baseline_score,
+ current_score=record.current_score,
+ score_delta=record.score_delta,
+ status=record.status,
+ roi_percentage=record.roi_percentage,
+ created_at=record.created_at,
+ )
+
+
+async def _get_brand_or_404(
+ brand_id: uuid.UUID,
+ current_user: User,
+ db: AsyncSession,
+) -> Brand:
+ user_uuid = _to_uuid(current_user.id)
+ stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == user_uuid)
+ result = await db.execute(stmt)
+ brand = result.scalar_one_or_none()
+ if not brand:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="品牌不存在",
+ )
+ return brand
+
+
+@router.post("/start", response_model=AttributionResponse)
+async def start_tracking(
+ body: StartTrackingRequest,
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ brand_id = _to_uuid(body.brand_id)
+ brand = await _get_brand_or_404(brand_id, current_user, db)
+
+ content_id = _to_uuid(body.content_id) if body.content_id else None
+
+ engine = AttributionEngine()
+ record = await engine.start_tracking(db, brand.id, content_id, current_user.id)
+ return _record_to_response(record)
+
+
+@router.get("/brand/{brand_id}")
+async def get_brand_attribution(
+ brand_id: uuid.UUID,
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ brand = await _get_brand_or_404(brand_id, current_user, db)
+
+ engine = AttributionEngine()
+ summary = await engine.get_brand_attribution_summary(db, brand.id)
+ summary["records"] = [_record_to_response(r) for r in summary["records"]]
+ return summary
+
+
+@router.get("/{record_id}", response_model=AttributionResponse)
+async def get_attribution_record(
+ record_id: uuid.UUID,
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ stmt = select(AttributionRecord).where(AttributionRecord.id == record_id)
+ result = await db.execute(stmt)
+ record = result.scalar_one_or_none()
+ if not record:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="归因记录不存在",
+ )
+ return _record_to_response(record)
+
+
+@router.post("/{record_id}/check", response_model=AttributionResponse)
+async def check_attribution(
+ record_id: uuid.UUID,
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ engine = AttributionEngine()
+ try:
+ record = await engine.check_attribution(db, record_id)
+ except ValueError:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="归因记录不存在",
+ )
+ return _record_to_response(record)
+
+
+@router.get("/roi/{brand_id}", response_model=ROIReport)
+async def get_roi_report(
+ brand_id: uuid.UUID,
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ brand = await _get_brand_or_404(brand_id, current_user, db)
+
+ engine = AttributionEngine()
+ summary = await engine.get_brand_attribution_summary(db, brand.id)
+
+ user_plan = getattr(current_user, "plan", "free") or "free"
+ subscription_cost = PLAN_COSTS.get(user_plan, 0.0)
+
+ calculator = ROICalculator()
+ roi_data = calculator.calculate_roi(
+ subscription_cost=subscription_cost,
+ score_delta=summary["total_score_delta"],
+ attribution_records=summary["records"],
+ )
+
+ ab_comparison = None
+ baseline_record = (
+ await db.execute(
+ select(DiagnosisRecord)
+ .where(
+ DiagnosisRecord.brand_id == brand.id,
+ DiagnosisRecord.status == "completed",
+ )
+ .order_by(DiagnosisRecord.completed_at.asc())
+ .limit(1)
+ )
+ ).scalar_one_or_none()
+ latest_record = (
+ await db.execute(
+ select(DiagnosisRecord)
+ .where(
+ DiagnosisRecord.brand_id == brand.id,
+ DiagnosisRecord.status == "completed",
+ )
+ .order_by(DiagnosisRecord.completed_at.desc())
+ .limit(1)
+ )
+ ).scalar_one_or_none()
+
+ if baseline_record and latest_record and baseline_record.id != latest_record.id:
+ before_dims = baseline_record.result_json.get("dimensions", []) if baseline_record.result_json else []
+ after_dims = latest_record.result_json.get("dimensions", []) if latest_record.result_json else []
+ before_map = {d.get("name"): d for d in before_dims}
+ after_map = {d.get("name"): d for d in after_dims}
+ ab_comparison = calculator.generate_ab_comparison(
+ before_score=baseline_record.overall_score or 0,
+ after_score=latest_record.overall_score or 0,
+ before_dimensions=before_map,
+ after_dimensions=after_map,
+ )
+
+ return ROIReport(
+ brand_id=str(brand.id),
+ brand_name=brand.name,
+ subscription_cost=subscription_cost,
+ current_plan=user_plan,
+ total_score_delta=summary["total_score_delta"],
+ value_generated=roi_data["value_generated"],
+ roi_percentage=roi_data["roi_percentage"],
+ break_even_delta=roi_data["break_even_delta"],
+ tracking_records=[_record_to_response(r) for r in summary["records"]],
+ ab_comparison=ab_comparison,
+ )
+
+
+@router.get("/ab-comparison/{brand_id}", response_model=ABComparisonResponse)
+async def get_ab_comparison(
+ brand_id: uuid.UUID,
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ brand = await _get_brand_or_404(brand_id, current_user, db)
+
+ baseline_record = (
+ await db.execute(
+ select(DiagnosisRecord)
+ .where(
+ DiagnosisRecord.brand_id == brand.id,
+ DiagnosisRecord.status == "completed",
+ )
+ .order_by(DiagnosisRecord.completed_at.asc())
+ .limit(1)
+ )
+ ).scalar_one_or_none()
+ latest_record = (
+ await db.execute(
+ select(DiagnosisRecord)
+ .where(
+ DiagnosisRecord.brand_id == brand.id,
+ DiagnosisRecord.status == "completed",
+ )
+ .order_by(DiagnosisRecord.completed_at.desc())
+ .limit(1)
+ )
+ ).scalar_one_or_none()
+
+ if not baseline_record or not latest_record:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="暂无诊断数据,无法生成A/B对比",
+ )
+
+ calculator = ROICalculator()
+ before_dims = baseline_record.result_json.get("dimensions", []) if baseline_record.result_json else []
+ after_dims = latest_record.result_json.get("dimensions", []) if latest_record.result_json else []
+ before_map = {d.get("name"): d for d in before_dims}
+ after_map = {d.get("name"): d for d in after_dims}
+ comparison = calculator.generate_ab_comparison(
+ before_score=baseline_record.overall_score or 0,
+ after_score=latest_record.overall_score or 0,
+ before_dimensions=before_map,
+ after_dimensions=after_map,
+ )
+
+ return ABComparisonResponse(
+ brand_id=str(brand.id),
+ brand_name=brand.name,
+ overall_before=comparison["overall_before"],
+ overall_after=comparison["overall_after"],
+ overall_delta=comparison["overall_delta"],
+ dimensions=comparison["dimensions"],
+ )
diff --git a/backend/app/api/brands.py b/backend/app/api/brands.py
index fa704c0..7a2deab 100644
--- a/backend/app/api/brands.py
+++ b/backend/app/api/brands.py
@@ -1,5 +1,6 @@
"""Brands API endpoints."""
import json
+import logging
import uuid
from typing import Annotated
@@ -14,9 +15,12 @@ from app.api.scoring import router as scoring_router
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
+from app.models.query import Query as QueryModel
from app.schemas.brand import BrandCreate, BrandUpdate, BrandResponse, BrandListResponse
from app.services.cache import get_cache_service, TTL_BRANDS
+logger = logging.getLogger(__name__)
+
router = APIRouter()
# Include competitors router under brands
@@ -127,6 +131,28 @@ async def update_brand(
# Update only provided fields
update_data = brand_data.model_dump(exclude_unset=True)
+
+ # 检测品牌名称变更,同步更新关联 Query 的 target_brand 和 brand_aliases
+ old_name = brand.name
+ new_name = update_data.get("name")
+
+ if new_name and new_name != old_name:
+ queries_stmt = select(QueryModel).where(
+ QueryModel.user_id == current_user.id,
+ QueryModel.target_brand == old_name,
+ )
+ queries_result = await db.execute(queries_stmt)
+ related_queries = queries_result.scalars().all()
+
+ for query in related_queries:
+ query.target_brand = new_name
+ # 如果 brand_aliases 中包含旧名称,也同步更新
+ if query.brand_aliases and old_name in query.brand_aliases:
+ query.brand_aliases = [new_name if a == old_name else a for a in query.brand_aliases]
+
+ if related_queries:
+ logger.info(f"Brand renamed from '{old_name}' to '{new_name}', synced {len(related_queries)} queries")
+
for field, value in update_data.items():
setattr(brand, field, value)
diff --git a/backend/app/api/citations.py b/backend/app/api/citations.py
index b2fb949..beba568 100644
--- a/backend/app/api/citations.py
+++ b/backend/app/api/citations.py
@@ -13,7 +13,7 @@ from app.schemas.citation import (
CitationListResponse,
CitationStatsResponse,
)
-from app.services.citation import (
+from app.services.citation.citation import (
get_citation_stats,
get_citations,
)
diff --git a/backend/app/api/competitor_analysis.py b/backend/app/api/competitor_analysis.py
new file mode 100644
index 0000000..6534927
--- /dev/null
+++ b/backend/app/api/competitor_analysis.py
@@ -0,0 +1,135 @@
+import logging
+import uuid
+
+from fastapi import APIRouter, Depends, HTTPException, Query, status
+from sqlalchemy import select, func
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.api.deps import get_current_user
+from app.database import get_db
+from app.models.user import User
+from app.models.brand import Brand
+from app.models.competitor_insight import CompetitorInsight
+from app.schemas.competitor_insight import (
+ CompetitorAnalysisRequest,
+ CompetitorInsightResponse,
+ CompetitorInsightList,
+ CompetitorGapSummary,
+)
+from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+
+async def _get_brand_if_owned(
+ brand_id: uuid.UUID,
+ current_user: User,
+ db: AsyncSession,
+) -> Brand:
+ stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
+ result = await db.execute(stmt)
+ brand = result.scalar_one_or_none()
+ if not brand:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="品牌不存在",
+ )
+ return brand
+
+
+@router.post("/analyze", response_model=CompetitorInsightList)
+async def analyze_competitor(
+ request: CompetitorAnalysisRequest,
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ await _get_brand_if_owned(request.brand_id, current_user, db)
+
+ service = CompetitorAnalyzerService()
+ try:
+ result = await service.analyze_competitor(
+ brand_id=request.brand_id,
+ analysis_types=request.analysis_types,
+ period_days=request.period_days or 30,
+ )
+ except ValueError as e:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=str(e),
+ )
+
+ stmt = (
+ select(CompetitorInsight)
+ .where(CompetitorInsight.brand_id == request.brand_id)
+ .order_by(CompetitorInsight.created_at.desc())
+ )
+ db_result = await db.execute(stmt)
+ insights = list(db_result.scalars().all())
+
+ return {"items": insights, "total": len(insights)}
+
+
+@router.get("/brand/{brand_id}", response_model=CompetitorInsightList)
+async def get_brand_insights(
+ brand_id: uuid.UUID,
+ skip: int = Query(0, ge=0),
+ limit: int = Query(20, ge=1, le=100),
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ await _get_brand_if_owned(brand_id, current_user, db)
+
+ count_stmt = select(func.count()).select_from(CompetitorInsight).where(
+ CompetitorInsight.brand_id == brand_id,
+ )
+ count_result = await db.execute(count_stmt)
+ total = count_result.scalar_one()
+
+ stmt = (
+ select(CompetitorInsight)
+ .where(CompetitorInsight.brand_id == brand_id)
+ .order_by(CompetitorInsight.created_at.desc())
+ .offset(skip)
+ .limit(limit)
+ )
+ result = await db.execute(stmt)
+ insights = list(result.scalars().all())
+
+ return {"items": insights, "total": total}
+
+
+@router.get("/{insight_id}", response_model=CompetitorInsightResponse)
+async def get_insight_detail(
+ insight_id: uuid.UUID,
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ stmt = select(CompetitorInsight).where(CompetitorInsight.id == insight_id)
+ result = await db.execute(stmt)
+ insight = result.scalar_one_or_none()
+
+ if not insight:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="洞察不存在",
+ )
+
+ await _get_brand_if_owned(insight.brand_id, current_user, db)
+
+ return insight
+
+
+@router.get("/brand/{brand_id}/gap-summary", response_model=list[CompetitorGapSummary])
+async def get_gap_summary(
+ brand_id: uuid.UUID,
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ brand = await _get_brand_if_owned(brand_id, current_user, db)
+
+ service = CompetitorAnalyzerService()
+ gap_summaries = await service.calculate_gap_score(db, brand_id, brand.name)
+
+ return gap_summaries
diff --git a/backend/app/api/competitors.py b/backend/app/api/competitors.py
index 31b3d22..38eaba9 100644
--- a/backend/app/api/competitors.py
+++ b/backend/app/api/competitors.py
@@ -21,6 +21,7 @@ from app.schemas.competitor import (
CompetitorRecommendationItem,
CompetitorRecommendationResponse,
)
+from app.utils.json_extractor import extract_json
logger = logging.getLogger(__name__)
@@ -363,7 +364,7 @@ async def _get_llm_recommendations(
raise ValueError("API返回空响应")
# 提取JSON
- json_str = _extract_json_from_text(content)
+ json_str = extract_json(content)
data = json.loads(json_str)
recommendations = []
@@ -393,32 +394,6 @@ async def _get_llm_recommendations(
raise
-def _extract_json_from_text(text: str) -> str:
- """从文本中提取JSON字符串。"""
- import re
-
- # 尝试直接解析
- try:
- json.loads(text)
- return text
- except json.JSONDecodeError:
- pass
-
- # 尝试从代码块中提取
- json_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```'
- match = re.search(json_pattern, text)
- if match:
- return match.group(1).strip()
-
- # 尝试找到第一个{到最后一个}之间的内容
- first_brace = text.find('{')
- last_brace = text.rfind('}')
- if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
- return text[first_brace:last_brace + 1]
-
- raise ValueError(f"无法从响应中提取JSON: {text[:200]}")
-
-
def _get_industry_label(industry: str | None) -> str:
"""获取行业中文标签。"""
labels = {
diff --git a/backend/app/api/content.py b/backend/app/api/content.py
index afda162..5549b15 100644
--- a/backend/app/api/content.py
+++ b/backend/app/api/content.py
@@ -1,18 +1,28 @@
-"""内容生产API - 串联Agent Pipeline"""
+"""内容生产API - 串联Agent Pipeline
+
+业务逻辑已委托给 ContentGenerationService,API 层仅负责:
+1. 请求解析与参数校验
+2. 调用服务层
+3. 格式化响应
+"""
import json
import logging
import re
+import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
-from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
+from app.models.brand import Brand
from app.models.content import Content, ContentVersion
+from app.models.diagnosis_record import DiagnosisRecord
from app.models.user import User
+from app.services.content.content_generation_service import ContentGenerationService
+from app.services.llm import LLMError
logger = logging.getLogger(__name__)
@@ -29,6 +39,7 @@ class ContentGenerateRequest(BaseModel):
brand_description: str = ""
run_deai: bool = True
run_geo: bool = True
+ use_agent_framework: bool = False
class ContentGenerateResponse(BaseModel):
@@ -41,44 +52,6 @@ class ContentGenerateResponse(BaseModel):
pipeline_stages: list[dict] = [] # 每个阶段的执行结果摘要
-async def _get_knowledge_context(
- db: AsyncSession,
- brand_name: str,
- knowledge_base_ids: list[str],
- target_keyword: str,
-) -> str:
- """
- 从知识库检索与查询相关的上下文。
-
- 如果有知识库ID,则调用 RAGService.search 获取相关内容;
- 否则返回空字符串,不影响后续流程。
- """
- if not knowledge_base_ids:
- return ""
-
- try:
- from app.services.knowledge.rag_service import RAGService
- rag_service = RAGService()
- results = await rag_service.search(
- session=db,
- query=f"{brand_name} {target_keyword}" if brand_name else target_keyword,
- knowledge_base_ids=knowledge_base_ids,
- top_k=3,
- )
- if results:
- context_parts = []
- for r in results:
- content = r.get("content", "")
- title = r.get("document_title", "")
- if content:
- context_parts.append(f"[{title}] {content}")
- return "\n".join(context_parts)
- return ""
- except Exception as e:
- logger.warning(f"知识库检索失败,将不使用知识库上下文: {e}")
- return ""
-
-
@router.post("/generate", response_model=ContentGenerateResponse)
async def generate_content(
req: ContentGenerateRequest,
@@ -88,115 +61,132 @@ async def generate_content(
"""
一键生成内容(同步执行Pipeline),结果存入数据库
- 流程:ContentGenerator → DeAI → GEOOptimizer
+ 流程:ContentGenerator -> DeAI -> GEOOptimizer
+ 业务逻辑委托给 ContentGenerationService
"""
- from app.services.llm import LLMError, LLMFactory
- from app.agent_framework.prompts import (
- CONTENT_GENERATOR_TEMPLATE,
- DEAI_TEMPLATE,
- GEO_OPTIMIZER_TEMPLATE,
- )
-
org_id = getattr(current_user, "organization_id", None)
if not org_id:
raise HTTPException(status_code=403, detail="用户未关联组织")
- stages = []
-
try:
- provider = LLMFactory.get_default()
-
- # 获取知识库上下文
- knowledge_context = await _get_knowledge_context(
- db, req.brand_name, req.knowledge_base_ids, req.target_keyword
+ service = ContentGenerationService()
+ result = await service.generate_content(
+ keyword=req.target_keyword,
+ brand_name=req.brand_name,
+ platform=req.target_platform,
+ content_style=req.content_style,
+ word_count=req.word_count,
+ knowledge_base_ids=req.knowledge_base_ids,
+ db=db,
+ user_id=current_user.id,
+ org_id=org_id,
+ run_deai=req.run_deai,
+ run_geo=req.run_geo,
+ use_agent_framework=req.use_agent_framework,
)
- # Stage 1: 内容生成
- gen_variables = {
- "topic_title": req.target_keyword,
- "target_keyword": req.target_keyword,
- "target_platform": req.target_platform,
- "content_angle": "综合分析",
- "content_style": req.content_style,
- "word_count": str(req.word_count),
- "brand_name": req.brand_name,
- "knowledge_context": knowledge_context,
- }
- messages = CONTENT_GENERATOR_TEMPLATE.render(gen_variables)
- response = await provider.chat(messages, temperature=0.7, max_tokens=req.word_count * 2)
- content = response.content
- stages.append({"stage": "content_generation", "status": "success", "word_count": len(content)})
-
- # Stage 2: 去AI化(可选)
- if req.run_deai:
- deai_variables = {
- "original_content": content,
- "target_style": "自然流畅",
- "preserve_structure": "是",
- }
- messages = DEAI_TEMPLATE.render(deai_variables)
- response = await provider.chat(messages, temperature=0.9, max_tokens=len(content) * 2)
- content = response.content
- stages.append({"stage": "deai", "status": "success"})
-
- # Stage 3: GEO优化(可选)
- optimized = content
- seo_score = None
- if req.run_geo:
- geo_variables = {
- "original_content": content,
- "target_keywords": req.target_keyword,
- "target_platform": req.target_platform,
- "optimization_level": "moderate",
- }
- messages = GEO_OPTIMIZER_TEMPLATE.render(geo_variables)
- response = await provider.chat(messages, temperature=0.5, max_tokens=len(content) * 2)
- optimized = response.content
- stages.append({"stage": "geo_optimization", "status": "success"})
-
- # ---- 存入数据库 ----
- content_obj = Content(
- organization_id=org_id,
- title=req.target_keyword,
- content_type="article",
- body=optimized,
- status="draft",
- target_platforms=[req.target_platform] if req.target_platform else [],
- keywords=[req.target_keyword],
- extra_metadata={
- "original_content": content if content != optimized else None,
- "pipeline_stages": stages,
- "seo_score": seo_score,
- "brand_name": req.brand_name,
- "content_style": req.content_style,
- "word_count_target": req.word_count,
- },
- created_by=current_user.id,
- current_version=1,
- )
- db.add(content_obj)
- await db.flush() # get content_obj.id
-
- # 创建版本记录(初始版本)
- version = ContentVersion(
- content_id=content_obj.id,
- version_number=1,
- title=req.target_keyword,
- body=optimized,
- change_summary="Pipeline自动生成",
- created_by=current_user.id,
- )
- db.add(version)
- await db.commit()
- await db.refresh(content_obj)
-
return ContentGenerateResponse(
status="success",
- content=content,
- optimized_content=optimized,
- seo_score=seo_score,
- content_id=str(content_obj.id),
- pipeline_stages=stages,
+ content=result["content"],
+ optimized_content=result["optimized_content"],
+ seo_score=result["seo_score"],
+ content_id=result["content_id"],
+ pipeline_stages=result["pipeline_stages"],
+ )
+
+ except LLMError as e:
+ raise HTTPException(status_code=502, detail=f"LLM调用失败: {str(e)}")
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"内容生成异常: {str(e)}")
+
+
+class GEOContentGenerateRequest(BaseModel):
+ brand_id: str
+ target_keywords: list[str]
+ platform: str = "通用"
+ content_style: str = "专业严谨"
+ word_count: int = 2000
+ knowledge_base_ids: list[str] = []
+ run_deai: bool = True
+ run_geo: bool = True
+
+
+class GEOContentGenerateResponse(BaseModel):
+ content_id: Optional[str] = None
+ content: str = ""
+ optimized_content: str = ""
+ seo_score: Optional[int] = None
+ pipeline_stages: list[dict] = []
+
+
+@router.post("/generate-geo", response_model=GEOContentGenerateResponse, status_code=201)
+async def generate_geo_content(
+ req: GEOContentGenerateRequest,
+ db: AsyncSession = Depends(get_db),
+ current_user: User = Depends(get_current_user),
+):
+ org_id = getattr(current_user, "organization_id", None)
+ if not org_id:
+ raise HTTPException(status_code=403, detail="用户未关联组织")
+
+ from sqlalchemy import select
+
+ try:
+ brand_uuid = uuid.UUID(req.brand_id)
+ except ValueError:
+ raise HTTPException(status_code=400, detail=f"Invalid brand_id format: {req.brand_id}")
+
+ brand_stmt = select(Brand).where(Brand.id == brand_uuid)
+ brand_result = await db.execute(brand_stmt)
+ brand = brand_result.scalar_one_or_none()
+ if not brand:
+ raise HTTPException(status_code=404, detail=f"Brand not found: {req.brand_id}")
+
+ diagnosis_context = ""
+ diag_stmt = (
+ select(DiagnosisRecord)
+ .where(DiagnosisRecord.brand_id == brand_uuid, DiagnosisRecord.status == "completed")
+ .order_by(DiagnosisRecord.created_at.desc())
+ )
+ diag_result = await db.execute(diag_stmt)
+ diagnosis = diag_result.scalar_one_or_none()
+ if diagnosis and diagnosis.result_json:
+ result_json = diagnosis.result_json
+ weak_dimensions = []
+ if isinstance(result_json, dict):
+ dimensions = result_json.get("dimensions", {})
+ for dim_name, dim_data in dimensions.items():
+ if isinstance(dim_data, dict) and dim_data.get("score", 100) < 60:
+ weak_dimensions.append(dim_name)
+ if weak_dimensions:
+ diagnosis_context = f"基于诊断结果,以下维度需要重点优化:{', '.join(weak_dimensions)}。请围绕这些维度生成针对性内容。"
+
+ keyword = "、".join(req.target_keywords)
+ if diagnosis_context:
+ keyword = f"{keyword}({diagnosis_context})"
+
+ try:
+ service = ContentGenerationService()
+ result = await service.generate_content(
+ keyword=keyword,
+ brand_name=brand.name,
+ platform=req.platform,
+ content_style=req.content_style,
+ word_count=req.word_count,
+ knowledge_base_ids=req.knowledge_base_ids,
+ db=db,
+ user_id=current_user.id,
+ org_id=org_id,
+ run_deai=req.run_deai,
+ run_geo=req.run_geo,
+ )
+
+ return GEOContentGenerateResponse(
+ content_id=result["content_id"],
+ content=result["content"],
+ optimized_content=result["optimized_content"],
+ seo_score=result["seo_score"],
+ pipeline_stages=result["pipeline_stages"],
)
except LLMError as e:
diff --git a/backend/app/api/dashboard.py b/backend/app/api/dashboard.py
index 70760c9..954bedf 100644
--- a/backend/app/api/dashboard.py
+++ b/backend/app/api/dashboard.py
@@ -1,6 +1,5 @@
"""Dashboard API endpoints."""
import uuid
-from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, Query
from sqlalchemy import select, func, Integer
@@ -19,325 +18,15 @@ from app.schemas.dashboard import (
PlatformScoreItem,
RecentQueryItem,
)
-from app.services.scoring_service import ScoringService, get_health_level
-from app.services.sentiment_service import get_sentiment_service
-from app.schemas.scoring import CitationResult
+from app.services.scoring.scoring_service import get_health_level
+from app.services.scoring.brand_scoring_data_service import (
+ get_brand_scoring_data_service,
+ REQUIRED_PLATFORMS,
+)
from app.services.cache import get_cache_service, TTL_DASHBOARD
router = APIRouter()
-# Required platforms for the dashboard
-REQUIRED_PLATFORMS = [
- "wenxin", # 文心一言
- "kimi", # Kimi
- "tongyi", # 通义千问
- "doubao", # 豆包
- "xinghuo", # 讯飞星火
- "tiangong", # 天工AI
- "qingyan", # 智谱清言
-]
-
-
-async def _get_brand_score_by_platform(
- db: AsyncSession,
- user_id: uuid.UUID,
- brand_id: uuid.UUID,
-) -> dict[str, float]:
- """
- Calculate brand score by platform.
-
- Returns a dict mapping platform key to score (0-100).
- """
- brand_stmt = select(Brand).where(
- Brand.id == brand_id,
- Brand.user_id == user_id,
- )
- brand_result = await db.execute(brand_stmt)
- brand = brand_result.scalar_one_or_none()
-
- if not brand:
- return {platform: 0.0 for platform in REQUIRED_PLATFORMS}
-
- queries_stmt = select(QueryModel).where(
- QueryModel.user_id == user_id,
- QueryModel.target_brand == brand.name,
- )
- queries_result = await db.execute(queries_stmt)
- queries = list(queries_result.scalars().all())
-
- if not queries:
- return {platform: 0.0 for platform in REQUIRED_PLATFORMS}
-
- query_ids = [q.id for q in queries]
-
- platform_scores = {platform: 0.0 for platform in REQUIRED_PLATFORMS}
-
- for platform in REQUIRED_PLATFORMS:
- citation_stmt = select(
- func.count().label("total"),
- func.sum(
- func.cast(
- func.case((CitationRecord.cited == True, 1), else_=0),
- Integer
- )
- ).label("cited")
- ).where(
- CitationRecord.query_id.in_(query_ids),
- CitationRecord.platform == platform,
- )
- result = await db.execute(citation_stmt)
- row = result.one()
-
- total = row.total or 0
- cited = row.cited or 0
-
- if total > 0:
- citation_rate = cited / total
- platform_scores[platform] = round(citation_rate * 100, 2)
- else:
- platform_scores[platform] = 0.0
-
- return platform_scores
-
-
-async def _get_competitor_scores_by_platform(
- db: AsyncSession,
- user_id: uuid.UUID,
- brand_id: uuid.UUID,
-) -> dict[str, float]:
- """
- Calculate average competitor score by platform.
-
- Returns a dict mapping platform key to average competitor score (0-100).
- """
- competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id)
- competitor_result = await db.execute(competitor_stmt)
- competitors = list(competitor_result.scalars().all())
-
- if not competitors:
- return {platform: 0.0 for platform in REQUIRED_PLATFORMS}
-
- competitor_names = [c.name for c in competitors]
-
- competitor_queries_stmt = select(QueryModel).where(
- QueryModel.user_id == user_id,
- QueryModel.target_brand.in_(competitor_names),
- )
- competitor_queries_result = await db.execute(competitor_queries_stmt)
- competitor_queries = list(competitor_queries_result.scalars().all())
-
- if not competitor_queries:
- return {platform: 0.0 for platform in REQUIRED_PLATFORMS}
-
- competitor_query_ids = [q.id for q in competitor_queries]
-
- platform_scores = {platform: 0.0 for platform in REQUIRED_PLATFORMS}
-
- for platform in REQUIRED_PLATFORMS:
- citation_stmt = select(
- func.count().label("total"),
- func.sum(
- func.cast(
- func.case((CitationRecord.cited == True, 1), else_=0),
- Integer
- )
- ).label("cited")
- ).where(
- CitationRecord.query_id.in_(competitor_query_ids),
- CitationRecord.platform == platform,
- )
- result = await db.execute(citation_stmt)
- row = result.one()
-
- total = row.total or 0
- cited = row.cited or 0
-
- if total > 0:
- citation_rate = cited / total
- platform_scores[platform] = round(citation_rate * 100, 2)
- else:
- platform_scores[platform] = 0.0
-
- return platform_scores
-
-
-async def _calculate_overall_score_v2(
- db: AsyncSession,
- user_id: uuid.UUID,
- brand_id: uuid.UUID,
-) -> tuple[float, float, list[DimensionScoreItem], int, int]:
- """
- Calculate V2 overall score, change from yesterday, dimension scores,
- and competitor position.
-
- Returns: (overall_score, change_from_yesterday, dimensions, ahead_count, behind_count)
- """
- brand_stmt = select(Brand).where(
- Brand.id == brand_id,
- Brand.user_id == user_id,
- )
- brand_result = await db.execute(brand_stmt)
- brand = brand_result.scalar_one_or_none()
-
- if not brand:
- return 0.0, 0.0, [], 0, 0
-
- # Get queries for this brand
- queries_stmt = select(QueryModel).where(
- QueryModel.user_id == user_id,
- QueryModel.target_brand == brand.name,
- )
- queries_result = await db.execute(queries_stmt)
- queries = list(queries_result.scalars().all())
-
- if not queries:
- return 0.0, 0.0, [], 0, 0
-
- query_ids = [q.id for q in queries]
-
- # Get all citations
- citations_stmt = select(CitationRecord).where(
- CitationRecord.query_id.in_(query_ids),
- )
- citations_result = await db.execute(citations_stmt)
- all_citations = list(citations_result.scalars().all())
-
- total_queries = len(all_citations)
- brand_citations = [c for c in all_citations if c.cited]
-
- # Get competitor data
- competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id)
- competitor_result = await db.execute(competitor_stmt)
- competitors = list(competitor_result.scalars().all())
- competitor_names = [c.name for c in competitors]
-
- # Calculate competitor mentions
- competitor_mentions: dict[str, int] = {}
- for comp_name in competitor_names:
- count = sum(
- 1 for c in all_citations
- if c.cited and c.competitor_brands
- and comp_name in c.competitor_brands
- )
- if count > 0:
- competitor_mentions[comp_name] = count
-
- # Sentiment analysis - prefer persisted sentiment field
- sentiment_service = get_sentiment_service()
- sentiment_counts = {"positive": 0, "neutral": 0, "negative": 0}
- for citation in brand_citations:
- if citation.sentiment and citation.sentiment in ("positive", "neutral", "negative"):
- sentiment_counts[citation.sentiment] += 1
- else:
- content = citation.raw_response or citation.citation_text or ""
- if content.strip():
- try:
- result = await sentiment_service.analyze(
- brand_name=brand.name,
- content=content,
- )
- sentiment_counts[result.sentiment] += 1
- except Exception:
- sentiment_counts["neutral"] += 1
- else:
- sentiment_counts["neutral"] += 1
-
- # Build citation results
- citation_results = [
- CitationResult(
- cited=c.cited,
- position=c.citation_position,
- citation_text=c.citation_text,
- sentiment="neutral", # simplified for dashboard
- confidence=c.confidence or 0.0,
- )
- for c in brand_citations
- ]
-
- # Extract positions
- positions = [c.citation_position for c in brand_citations if c.cited]
-
- # Calculate V2 score
- scoring_service = ScoringService()
- v2_result = scoring_service.calculate_v2(
- mentioned_count=len(brand_citations),
- total_queries=total_queries,
- positions=positions,
- sentiment_counts=sentiment_counts,
- citations=citation_results,
- brand_mentions=len(brand_citations),
- competitor_mentions=competitor_mentions,
- )
-
- # Build dimension items
- dimensions = [
- DimensionScoreItem(
- name=v2_result.mention_rate.name,
- score=round(v2_result.mention_rate.score, 2),
- max_score=v2_result.mention_rate.max_score,
- percentage=round(v2_result.mention_rate.percentage, 2),
- ),
- DimensionScoreItem(
- name=v2_result.recommendation_rank.name,
- score=round(v2_result.recommendation_rank.score, 2),
- max_score=v2_result.recommendation_rank.max_score,
- percentage=round(v2_result.recommendation_rank.percentage, 2),
- ),
- DimensionScoreItem(
- name=v2_result.sentiment_score.name,
- score=round(v2_result.sentiment_score.score, 2),
- max_score=v2_result.sentiment_score.max_score,
- percentage=round(v2_result.sentiment_score.percentage, 2),
- ),
- DimensionScoreItem(
- name=v2_result.citation_quality.name,
- score=round(v2_result.citation_quality.score, 2),
- max_score=v2_result.citation_quality.max_score,
- percentage=round(v2_result.citation_quality.percentage, 2),
- ),
- DimensionScoreItem(
- name=v2_result.competitive_position.name,
- score=round(v2_result.competitive_position.score, 2),
- max_score=v2_result.competitive_position.max_score,
- percentage=round(v2_result.competitive_position.percentage, 2),
- ),
- ]
-
- # Calculate competitor ahead/behind
- ahead_count = sum(
- 1 for count in competitor_mentions.values()
- if len(brand_citations) > count
- )
- behind_count = sum(
- 1 for count in competitor_mentions.values()
- if len(brand_citations) <= count
- )
-
- # Calculate change from yesterday
- today = datetime.now().date()
- yesterday = today - timedelta(days=1)
-
- today_citations = [
- c for c in all_citations
- if c.queried_at.date() == today
- ]
- yesterday_citations = [
- c for c in all_citations
- if c.queried_at.date() == yesterday
- ]
-
- today_cited = sum(1 for c in today_citations if c.cited)
- today_total = len(today_citations)
- today_score = (today_cited / today_total * 100) if today_total > 0 else 0.0
-
- yesterday_cited = sum(1 for c in yesterday_citations if c.cited)
- yesterday_total = len(yesterday_citations)
- yesterday_score = (yesterday_cited / yesterday_total * 100) if yesterday_total > 0 else 0.0
-
- change = round(today_score - yesterday_score, 2)
-
- return v2_result.overall_score, change, dimensions, ahead_count, behind_count
-
@router.get("/stats", response_model=DashboardStatsResponse)
async def get_dashboard_stats(
@@ -357,7 +46,6 @@ async def get_dashboard_stats(
- 最近查询记录
"""
cache = get_cache_service()
- # 如果 brand_id 尚未确定,先查库取第一个品牌
if brand_id is None:
brand_stmt = select(Brand).where(Brand.user_id == current_user.id).limit(1)
brand_result = await db.execute(brand_stmt)
@@ -379,34 +67,67 @@ async def get_dashboard_stats(
total_platforms=7,
)
- # 尝试从缓存读取(TTL: 2 分钟)
cache_key = f"dashboard:stats:{current_user.id}:{brand_id}"
cached = await cache.get_json(cache_key)
if cached is not None:
return cached
- # Get brand name
brand_stmt = select(Brand).where(Brand.id == brand_id)
brand_result = await db.execute(brand_stmt)
brand = brand_result.scalar_one_or_none()
brand_name = brand.name if brand else None
- # Calculate V2 overall score and dimensions
- overall_score, score_change, dimensions, ahead_count, behind_count = (
- await _calculate_overall_score_v2(db, current_user.id, brand_id)
+ scoring_data_service = get_brand_scoring_data_service()
+
+ scoring_data = await scoring_data_service.get_brand_scoring_data(
+ db, current_user.id, brand
)
- # Get platform scores
- platform_scores_dict = await _get_brand_score_by_platform(
+ overall_score = scoring_data.v2_result.overall_score
+ score_change = scoring_data.change_from_yesterday
+
+ dimensions = [
+ DimensionScoreItem(
+ name=scoring_data.v2_result.mention_rate.name,
+ score=round(scoring_data.v2_result.mention_rate.score, 2),
+ max_score=scoring_data.v2_result.mention_rate.max_score,
+ percentage=round(scoring_data.v2_result.mention_rate.percentage, 2),
+ ),
+ DimensionScoreItem(
+ name=scoring_data.v2_result.recommendation_rank.name,
+ score=round(scoring_data.v2_result.recommendation_rank.score, 2),
+ max_score=scoring_data.v2_result.recommendation_rank.max_score,
+ percentage=round(scoring_data.v2_result.recommendation_rank.percentage, 2),
+ ),
+ DimensionScoreItem(
+ name=scoring_data.v2_result.sentiment_score.name,
+ score=round(scoring_data.v2_result.sentiment_score.score, 2),
+ max_score=scoring_data.v2_result.sentiment_score.max_score,
+ percentage=round(scoring_data.v2_result.sentiment_score.percentage, 2),
+ ),
+ DimensionScoreItem(
+ name=scoring_data.v2_result.citation_quality.name,
+ score=round(scoring_data.v2_result.citation_quality.score, 2),
+ max_score=scoring_data.v2_result.citation_quality.max_score,
+ percentage=round(scoring_data.v2_result.citation_quality.percentage, 2),
+ ),
+ DimensionScoreItem(
+ name=scoring_data.v2_result.competitive_position.name,
+ score=round(scoring_data.v2_result.competitive_position.score, 2),
+ max_score=scoring_data.v2_result.competitive_position.max_score,
+ percentage=round(scoring_data.v2_result.competitive_position.percentage, 2),
+ ),
+ ]
+
+ ahead_count = scoring_data.competitor_data.get("ahead_count", 0)
+ behind_count = scoring_data.competitor_data.get("behind_count", 0)
+
+ platform_scores_dict = scoring_data.platform_scores
+
+ competitor_scores_dict = await scoring_data_service.get_competitor_platform_scores(
db, current_user.id, brand_id
)
- # Get competitor platform scores
- competitor_scores_dict = await _get_competitor_scores_by_platform(
- db, current_user.id, brand_id
- )
-
- # Get first competitor name for display
competitor_stmt = select(Competitor).where(
Competitor.brand_id == brand_id
).limit(1)
@@ -424,10 +145,8 @@ async def get_dashboard_stats(
for platform, score in platform_scores_dict.items()
]
- # Count monitored platforms
monitored = sum(1 for s in platform_scores_dict.values() if s > 0)
- # Get recent queries (last 10)
recent_queries_stmt = (
select(QueryModel)
.where(QueryModel.user_id == current_user.id)
@@ -437,7 +156,6 @@ async def get_dashboard_stats(
recent_queries_result = await db.execute(recent_queries_stmt)
recent_queries_list = list(recent_queries_result.scalars().all())
- # Get citation count for each query
recent_queries = []
for query in recent_queries_list:
citation_count_stmt = select(
@@ -460,7 +178,6 @@ async def get_dashboard_stats(
queried_at=query.last_queried_at or query.created_at,
))
- # Health level
health_level = get_health_level(overall_score)
response = DashboardStatsResponse(
@@ -477,7 +194,6 @@ async def get_dashboard_stats(
brand_name=brand_name,
)
- # 将结果写入缓存(TTL: 2 分钟)
await cache.set_json(
cache_key,
response.model_dump(mode="json"),
diff --git a/backend/app/api/detection.py b/backend/app/api/detection.py
index 168e1a0..83c32bf 100644
--- a/backend/app/api/detection.py
+++ b/backend/app/api/detection.py
@@ -16,7 +16,7 @@ from app.schemas.detection_task import (
DetectionTaskUpdate,
DetectionTriggerResponse,
)
-from app.services.detection_scheduler import DetectionSchedulerService, TaskNotFoundError
+from app.services.detection.detection_scheduler import DetectionSchedulerService, TaskNotFoundError
logger = logging.getLogger(__name__)
diff --git a/backend/app/api/diagnosis.py b/backend/app/api/diagnosis.py
index 2a6af11..a00a0fa 100644
--- a/backend/app/api/diagnosis.py
+++ b/backend/app/api/diagnosis.py
@@ -1,6 +1,8 @@
"""诊断API端点 - 提供SEO和GEO诊断功能"""
+import asyncio
import logging
import uuid
+from datetime import UTC, datetime
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
@@ -11,13 +13,32 @@ from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
-from app.services.seo_diagnosis import SEODiagnosisService
-from app.services.geo_diagnosis import GEODiagnosisService, GEODiagnosisInput
+from app.models.diagnosis_record import DiagnosisRecord
+from app.schemas.diagnosis import (
+ GEODiagnosisHistoryItem,
+ GEODiagnosisHistoryResponse,
+ GEODiagnosisResponse,
+ GEODiagnosisResultResponse,
+ GEODiagnosisTaskResponse,
+ GEODiagnosisTriggerRequest,
+)
+from app.services.diagnosis.data_collector import DataCollectorService
+from app.services.diagnosis.seo_diagnosis import SEODiagnosisService
+from app.services.diagnosis.geo_diagnosis import GEODiagnosisService, GEODiagnosisInput
+from app.utils.health import get_health_level_label
logger = logging.getLogger(__name__)
router = APIRouter()
+_FREE_TIER_DIMENSIONS = {"内容可提取性", "E-E-A-T信号", "引用就绪度"}
+
+
+def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
+ if isinstance(value, uuid.UUID):
+ return value
+ return uuid.UUID(str(value))
+
@router.get("/seo/{brand_id}")
async def get_seo_diagnosis(
@@ -25,24 +46,14 @@ async def get_seo_diagnosis(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
- """
- 获取品牌的SEO诊断结果
-
- 返回5维度SEO诊断:
- - 技术SEO (25分)
- - 页面SEO (20分)
- - 内容质量 (20分)
- - 外链分析 (15分)
- - 用户体验 (20分)
- """
brand = await _get_brand_or_404(brand_id, current_user, db)
-
+
try:
service = SEODiagnosisService()
result = service.diagnose()
-
+
logger.info(f"SEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={result.overall_score}")
-
+
return result.to_dict()
except Exception as e:
logger.error(f"SEO诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True)
@@ -52,39 +63,158 @@ async def get_seo_diagnosis(
)
-@router.get("/geo/{brand_id}")
-async def get_geo_diagnosis(
+@router.post("/geo/{brand_id}", status_code=status.HTTP_202_ACCEPTED)
+async def trigger_geo_diagnosis(
brand_id: uuid.UUID,
+ body: GEODiagnosisTriggerRequest = GEODiagnosisTriggerRequest(),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
- """
- 获取品牌的GEO诊断结果
-
- 返回6维度GEO诊断:
- - 内容可提取性 (20分)
- - 实体清晰度 (15分)
- - E-E-A-T信号 (20分)
- - Schema标记 (15分)
- - 主题权威 (15分)
- - 引用就绪度 (15分)
- """
brand = await _get_brand_or_404(brand_id, current_user, db)
-
- try:
- input_data = GEODiagnosisInput()
- service = GEODiagnosisService()
- result = service.diagnose(input_data)
-
- logger.info(f"GEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={result.overall_score}")
-
- return result.to_dict()
- except Exception as e:
- logger.error(f"GEO诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True)
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
- detail="GEO诊断服务异常,请稍后重试",
+
+ if not body.force_refresh:
+ existing = await _find_recent_completed(db, brand_id, hours=24)
+ if existing:
+ return GEODiagnosisTaskResponse(
+ task_id=str(existing.id),
+ brand_id=str(brand_id),
+ status="completed",
+ )
+
+ record = DiagnosisRecord(
+ brand_id=brand_id,
+ user_id=_to_uuid(current_user.id),
+ diagnosis_type="geo",
+ status="pending",
+ )
+ db.add(record)
+ await db.commit()
+ await db.refresh(record)
+
+ asyncio.create_task(
+ _run_geo_diagnosis(
+ record_id=record.id,
+ brand_id=brand_id,
+ brand_name=brand.name,
+ brand_aliases=brand.aliases or [],
+ website=brand.website,
+ industry=brand.industry,
+ user_id=_to_uuid(current_user.id),
)
+ )
+
+ return GEODiagnosisTaskResponse(
+ task_id=str(record.id),
+ brand_id=str(brand_id),
+ status="pending",
+ )
+
+
+@router.get("/geo/{brand_id}/result")
+async def get_geo_diagnosis_result(
+ brand_id: uuid.UUID,
+ task_id: str | None = None,
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ await _get_brand_or_404(brand_id, current_user, db)
+
+ if task_id:
+ try:
+ tid = uuid.UUID(task_id)
+ except ValueError:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="无效的task_id",
+ )
+ stmt = select(DiagnosisRecord).where(
+ DiagnosisRecord.id == tid,
+ DiagnosisRecord.brand_id == brand_id,
+ )
+ else:
+ stmt = (
+ select(DiagnosisRecord)
+ .where(
+ DiagnosisRecord.brand_id == brand_id,
+ DiagnosisRecord.diagnosis_type == "geo",
+ )
+ .order_by(DiagnosisRecord.created_at.desc())
+ .limit(1)
+ )
+
+ result = await db.execute(stmt)
+ record = result.scalar_one_or_none()
+
+ if not record:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="诊断记录不存在",
+ )
+
+ if record.status == "pending" or record.status == "running":
+ return GEODiagnosisResultResponse(
+ task_id=str(record.id),
+ brand_id=str(brand_id),
+ status=record.status,
+ )
+
+ if record.status == "failed":
+ return GEODiagnosisResultResponse(
+ task_id=str(record.id),
+ brand_id=str(brand_id),
+ status="failed",
+ error=record.error_message,
+ )
+
+ user_plan = getattr(current_user, "plan", None) or "free"
+ is_paid = user_plan not in ("free", None)
+ diagnosis_resp = _build_diagnosis_response(record.result_json, is_paid)
+
+ return GEODiagnosisResultResponse(
+ task_id=str(record.id),
+ brand_id=str(brand_id),
+ status="completed",
+ result=diagnosis_resp,
+ )
+
+
+@router.get("/geo/{brand_id}/history")
+async def get_geo_diagnosis_history(
+ brand_id: uuid.UUID,
+ limit: int = 10,
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ await _get_brand_or_404(brand_id, current_user, db)
+
+ stmt = (
+ select(DiagnosisRecord)
+ .where(
+ DiagnosisRecord.brand_id == brand_id,
+ DiagnosisRecord.diagnosis_type == "geo",
+ DiagnosisRecord.status == "completed",
+ )
+ .order_by(DiagnosisRecord.created_at.desc())
+ .limit(limit)
+ )
+ result = await db.execute(stmt)
+ records = result.scalars().all()
+
+ items = [
+ GEODiagnosisHistoryItem(
+ task_id=str(r.id),
+ overall_score=r.overall_score or 0,
+ health_level=r.result_json.get("health_level", "danger") if r.result_json else "danger",
+ created_at=r.created_at.isoformat(),
+ completed_at=r.completed_at.isoformat() if r.completed_at else None,
+ )
+ for r in records
+ ]
+
+ return GEODiagnosisHistoryResponse(
+ brand_id=str(brand_id),
+ history=items,
+ )
@router.get("/combined/{brand_id}")
@@ -93,29 +223,32 @@ async def get_combined_diagnosis(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
- """
- 获取品牌的综合诊断结果
-
- 结合SEO和GEO诊断,返回综合评分和详细诊断结果
- """
brand = await _get_brand_or_404(brand_id, current_user, db)
-
+
try:
seo_service = SEODiagnosisService()
seo_result = seo_service.diagnose()
-
+
+ collector = DataCollectorService(db)
+ collection = await collector.collect(
+ brand_name=brand.name,
+ brand_aliases=brand.aliases or [],
+ website=brand.website,
+ industry=brand.industry,
+ )
+
geo_service = GEODiagnosisService()
- geo_result = geo_service.diagnose(GEODiagnosisInput())
-
+ geo_result = geo_service.diagnose(collection.diagnosis_input)
+
combined_score = round((seo_result.overall_score + geo_result.overall_score) / 2, 2)
-
+
logger.info(
f"综合诊断完成: brand_id={brand_id}, brand={brand.name}, "
f"seo_score={seo_result.overall_score}, "
f"geo_score={geo_result.overall_score}, "
f"combined_score={combined_score}"
)
-
+
return {
"seo_score": seo_result.overall_score,
"geo_score": geo_result.overall_score,
@@ -131,20 +264,131 @@ async def get_combined_diagnosis(
)
+async def _run_geo_diagnosis(
+ record_id: uuid.UUID,
+ brand_id: uuid.UUID,
+ brand_name: str,
+ brand_aliases: list[str],
+ website: str | None,
+ industry: str | None,
+ user_id: uuid.UUID,
+) -> None:
+ from app.database import AsyncSessionLocal
+
+ async with AsyncSessionLocal() as db:
+ try:
+ stmt = select(DiagnosisRecord).where(DiagnosisRecord.id == record_id)
+ result = await db.execute(stmt)
+ record = result.scalar_one_or_none()
+ if not record:
+ return
+
+ record.status = "running"
+ await db.commit()
+
+ collector = DataCollectorService(db)
+ collection = await collector.collect(
+ brand_name=brand_name,
+ brand_aliases=brand_aliases,
+ website=website,
+ industry=industry,
+ )
+
+ geo_service = GEODiagnosisService()
+ diagnosis_result = geo_service.diagnose(collection.diagnosis_input)
+
+ record.status = "completed"
+ record.overall_score = diagnosis_result.overall_score
+ record.result_json = diagnosis_result.to_dict()
+ record.completed_at = datetime.now(UTC)
+ record.collection_metadata = collection.metadata
+ if collection.errors:
+ record.collection_metadata["errors"] = collection.errors
+
+ await db.commit()
+
+ logger.info(
+ f"GEO诊断完成: brand_id={brand_id}, brand={brand_name}, "
+ f"score={diagnosis_result.overall_score}"
+ )
+
+ except Exception as e:
+ logger.error(f"GEO诊断任务失败: record_id={record_id}, error={e}", exc_info=True)
+ try:
+ stmt = select(DiagnosisRecord).where(DiagnosisRecord.id == record_id)
+ result = await db.execute(stmt)
+ record = result.scalar_one_or_none()
+ if record:
+ record.status = "failed"
+ record.error_message = str(e)
+ await db.commit()
+ except Exception:
+ logger.error(f"更新失败状态也失败: record_id={record_id}")
+
+
+def _build_diagnosis_response(result_json: dict | None, is_paid: bool) -> GEODiagnosisResponse:
+ if not result_json:
+ return GEODiagnosisResponse(
+ overall_score=0,
+ health_level="danger",
+ health_level_label=get_health_level_label("danger"),
+ dimensions=[],
+ recommendations=[],
+ is_full_report=is_paid,
+ )
+
+ dimensions = result_json.get("dimensions", [])
+ if not is_paid:
+ dimensions = [d for d in dimensions if d.get("name") in _FREE_TIER_DIMENSIONS]
+
+ recommendations = result_json.get("recommendations", [])
+ if not is_paid:
+ recommendations = [r for r in recommendations if r.get("priority") == "P0"]
+
+ return GEODiagnosisResponse(
+ overall_score=result_json.get("overall_score", 0),
+ health_level=result_json.get("health_level", "danger"),
+ health_level_label=result_json.get("health_level_label", get_health_level_label("danger")),
+ dimensions=dimensions,
+ recommendations=recommendations,
+ is_full_report=is_paid,
+ )
+
+
+async def _find_recent_completed(
+ db: AsyncSession, brand_id: uuid.UUID, hours: int = 24
+) -> DiagnosisRecord | None:
+ from datetime import timedelta
+
+ cutoff = datetime.now(UTC) - timedelta(hours=hours)
+ stmt = (
+ select(DiagnosisRecord)
+ .where(
+ DiagnosisRecord.brand_id == brand_id,
+ DiagnosisRecord.status == "completed",
+ DiagnosisRecord.completed_at >= cutoff,
+ )
+ .order_by(DiagnosisRecord.completed_at.desc())
+ .limit(1)
+ )
+ result = await db.execute(stmt)
+ return result.scalar_one_or_none()
+
+
async def _get_brand_or_404(
brand_id: uuid.UUID,
current_user: User,
db: AsyncSession,
) -> Brand:
- """获取品牌或抛出404异常"""
- stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
+ user_uuid = _to_uuid(current_user.id)
+ stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == user_uuid)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
-
+
if not brand:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="品牌不存在",
)
-
+
return brand
diff --git a/backend/app/api/distribution.py b/backend/app/api/distribution.py
index f01d92c..4361e68 100644
--- a/backend/app/api/distribution.py
+++ b/backend/app/api/distribution.py
@@ -4,6 +4,7 @@ import uuid
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, status
+from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
@@ -25,7 +26,9 @@ from app.schemas.distribution import (
)
from app.services.distribution.formatter import ContentFormatter
from app.services.distribution.platform_rules import PLATFORM_RULES, PlatformRuleEngine
+from app.services.distribution.publish_engine import PublishEngine
from app.services.distribution.publish_strategy import PublishStrategyService
+from app.services.distribution.publishers.base import PublishResult
router = APIRouter()
@@ -235,3 +238,72 @@ async def create_schedule(
tips=tips,
created_at=schedule.created_at.strftime("%Y-%m-%d %H:%M:%S"),
)
+
+
+class PublishRequest(BaseModel):
+ content_id: str = Field(min_length=1)
+ platforms: list[str] = Field(min_length=1)
+
+
+class PublishStatusItem(BaseModel):
+ platform: str
+ status: str
+ article_url: str | None = None
+ published_at: str | None = None
+
+
+class PublishStatusResponse(BaseModel):
+ platforms: list[PublishStatusItem]
+
+
+class PublishResponse(BaseModel):
+ results: list[PublishResult]
+
+
+_publish_engine = PublishEngine()
+
+
+@router.post("/publish", response_model=PublishResponse)
+async def publish_content(
+ req: PublishRequest,
+ db: AsyncSession = Depends(get_db),
+ current_user: User = Depends(get_current_user),
+):
+ org_id = current_user.organization_id
+ if not org_id:
+ raise HTTPException(status_code=403, detail="用户未关联组织")
+
+ try:
+ results = await _publish_engine.publish_content(
+ content_id=req.content_id,
+ platforms=req.platforms,
+ db=db,
+ user_id=str(current_user.id),
+ org_id=str(org_id),
+ )
+ except ValueError as e:
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
+
+ return PublishResponse(results=results)
+
+
+@router.get("/publish/{content_id}/status", response_model=PublishStatusResponse)
+async def get_publish_status(
+ content_id: str,
+ db: AsyncSession = Depends(get_db),
+ current_user: User = Depends(get_current_user),
+):
+ try:
+ status_list = await _publish_engine.get_publish_status(
+ content_id=content_id,
+ db=db,
+ )
+ except ValueError:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"Content not found: {content_id}",
+ )
+
+ return PublishStatusResponse(
+ platforms=[PublishStatusItem(**s) for s in status_list]
+ )
diff --git a/backend/app/api/health_score.py b/backend/app/api/health_score.py
new file mode 100644
index 0000000..0b9d6c1
--- /dev/null
+++ b/backend/app/api/health_score.py
@@ -0,0 +1,144 @@
+import hashlib
+import logging
+
+from fastapi import APIRouter, Depends, Query
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.database import get_db
+from app.models.brand import Brand
+from app.schemas.health_score import (
+ HealthScoreDimension,
+ HealthScoreRecommendation,
+ HealthScoreResponse,
+)
+from app.services.cache import get_cache_service
+from app.services.diagnosis.data_collector import DataCollectorService
+from app.services.diagnosis.geo_diagnosis import GEODiagnosisService
+from app.utils.health import get_health_level, get_health_level_label
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+_FREE_TIER_DIMENSIONS = {"内容可提取性", "E-E-A-T信号", "引用就绪度"}
+
+_CACHE_TTL = 86400
+
+
+def _build_default_response(brand_name: str) -> HealthScoreResponse:
+ return HealthScoreResponse(
+ brand_name=brand_name,
+ overall_score=0.0,
+ health_level="danger",
+ health_level_label=get_health_level_label("danger"),
+ dimensions=[
+ HealthScoreDimension(
+ name=d,
+ score=0.0,
+ max_score=0.0,
+ percentage=0.0,
+ status="fail",
+ )
+ for d in sorted(_FREE_TIER_DIMENSIONS)
+ ],
+ recommendations=[],
+ is_full_report=False,
+ cached=False,
+ )
+
+
+@router.get("/health-score", response_model=HealthScoreResponse)
+async def get_public_health_score(
+ brand: str = Query(..., min_length=1),
+ competitors: str = Query(default=""),
+ db: AsyncSession = Depends(get_db),
+):
+ cache = get_cache_service()
+ cache_key = f"health_score:{hashlib.md5(brand.lower().encode()).hexdigest()}"
+
+ cached_data = await cache.get_json(cache_key)
+ if cached_data is not None:
+ cached_data["cached"] = True
+ return cached_data
+
+ brand_name = brand.strip()
+ competitor_list = [c.strip() for c in competitors.split(",") if c.strip()][:3]
+
+ brand_aliases: list[str] = []
+ website: str | None = None
+ industry: str | None = None
+
+ stmt = select(Brand).where(Brand.name == brand_name).limit(1)
+ result = await db.execute(stmt)
+ brand_record = result.scalar_one_or_none()
+
+ if brand_record:
+ brand_aliases = brand_record.aliases or []
+ website = brand_record.website
+ industry = brand_record.industry
+
+ try:
+ collector = DataCollectorService(db)
+ collection = await collector.collect(
+ brand_name=brand_name,
+ brand_aliases=brand_aliases,
+ website=website,
+ industry=industry,
+ )
+
+ geo_service = GEODiagnosisService()
+ diagnosis_result = geo_service.diagnose(collection.diagnosis_input)
+ except Exception as e:
+ logger.error(f"健康分数据采集失败: brand={brand_name}, error={e}", exc_info=True)
+ default_resp = _build_default_response(brand_name)
+ await cache.set_json(cache_key, default_resp.model_dump(mode="json"), expire=_CACHE_TTL)
+ return default_resp
+
+ all_dimensions = diagnosis_result.to_dict().get("dimensions", [])
+ free_dimensions = [d for d in all_dimensions if d.get("name") in _FREE_TIER_DIMENSIONS]
+
+ overall_score = round(sum(d.get("score", 0.0) for d in free_dimensions), 2)
+ health_level = get_health_level(overall_score)
+
+ dimension_responses = [
+ HealthScoreDimension(
+ name=d.get("name", ""),
+ score=round(d.get("score", 0.0), 2),
+ max_score=d.get("max_score", 0.0),
+ percentage=round(d.get("percentage", 0.0), 2),
+ status=d.get("status", "fail"),
+ )
+ for d in free_dimensions
+ ]
+
+ all_recommendations = diagnosis_result.to_dict().get("recommendations", [])
+ free_recommendations = [
+ r for r in all_recommendations
+ if r.get("priority") == "P0" and r.get("dimension") in _FREE_TIER_DIMENSIONS
+ ]
+
+ recommendation_responses = [
+ HealthScoreRecommendation(
+ priority=r.get("priority", "P0"),
+ dimension=r.get("dimension", ""),
+ title=r.get("title", ""),
+ description=r.get("description", ""),
+ )
+ for r in free_recommendations
+ ]
+
+ response = HealthScoreResponse(
+ brand_name=brand_name,
+ overall_score=overall_score,
+ health_level=health_level,
+ health_level_label=get_health_level_label(health_level),
+ dimensions=dimension_responses,
+ recommendations=recommendation_responses,
+ is_full_report=False,
+ cached=False,
+ )
+
+ await cache.set_json(cache_key, response.model_dump(mode="json"), expire=_CACHE_TTL)
+
+ return response
diff --git a/backend/app/api/monitoring.py b/backend/app/api/monitoring.py
new file mode 100644
index 0000000..69f411c
--- /dev/null
+++ b/backend/app/api/monitoring.py
@@ -0,0 +1,207 @@
+import uuid
+
+from fastapi import APIRouter, Depends, HTTPException, Query, status
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.api.deps import get_current_user
+from app.database import get_db
+from app.models.user import User
+from app.models.brand import Brand
+from app.models.monitoring import MonitoringRecord
+from app.schemas.monitoring import (
+ MonitoringRecordCreate,
+ MonitoringRecordResponse,
+ MonitoringRecordList,
+ MonitoringChangeReport,
+ MonitoringStatusUpdate,
+)
+from app.services.monitoring.monitor_service import MonitorService
+
+router = APIRouter()
+
+
+async def _get_brand_with_access(
+ brand_id: uuid.UUID,
+ db: AsyncSession,
+ current_user: User,
+) -> Brand:
+ stmt = select(Brand).where(
+ Brand.id == brand_id,
+ Brand.user_id == current_user.id,
+ )
+ result = await db.execute(stmt)
+ brand = result.scalar_one_or_none()
+
+ if not brand:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="品牌不存在",
+ )
+ return brand
+
+
+def _record_to_response(record: MonitoringRecord) -> MonitoringRecordResponse:
+ return MonitoringRecordResponse(
+ id=record.id,
+ brand_id=record.brand_id,
+ content_id=record.content_id,
+ query_keywords=record.query_keywords,
+ platform=record.platform,
+ baseline_citation_count=record.baseline_citation_count,
+ baseline_sentiment=record.baseline_sentiment,
+ baseline_rank=record.baseline_rank,
+ current_citation_count=record.current_citation_count,
+ current_sentiment=record.current_sentiment,
+ current_rank=record.current_rank,
+ change_type=record.change_type,
+ change_details=record.change_details,
+ check_interval_hours=record.check_interval_hours,
+ last_checked_at=record.last_checked_at,
+ next_check_at=record.next_check_at,
+ status=record.status,
+ created_at=record.created_at,
+ updated_at=record.updated_at,
+ )
+
+
+@router.post("/tasks", response_model=MonitoringRecordResponse)
+async def create_monitoring_task(
+ request: MonitoringRecordCreate,
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ await _get_brand_with_access(request.brand_id, db, current_user)
+
+ service = MonitorService()
+ record = await service.create_monitoring_record(
+ db=db,
+ brand_id=request.brand_id,
+ content_id=request.content_id,
+ query_keywords=request.query_keywords,
+ platform=request.platform,
+ check_interval_hours=request.check_interval_hours,
+ )
+ return _record_to_response(record)
+
+
+@router.get("/brand/{brand_id}", response_model=MonitoringRecordList)
+async def get_brand_monitoring(
+ brand_id: uuid.UUID,
+ skip: int = Query(0, ge=0),
+ limit: int = Query(20, ge=1, le=100),
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ await _get_brand_with_access(brand_id, db, current_user)
+
+ service = MonitorService()
+ records, total = await service.get_brand_monitoring(
+ db=db,
+ brand_id=brand_id,
+ skip=skip,
+ limit=limit,
+ )
+ return MonitoringRecordList(
+ records=[_record_to_response(r) for r in records],
+ total=total,
+ )
+
+
+@router.get("/{record_id}/report", response_model=MonitoringChangeReport)
+async def get_change_report(
+ record_id: uuid.UUID,
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ stmt = select(MonitoringRecord).where(MonitoringRecord.id == record_id)
+ result = await db.execute(stmt)
+ record = result.scalar_one_or_none()
+
+ if not record:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="监测记录不存在",
+ )
+
+ await _get_brand_with_access(record.brand_id, db, current_user)
+
+ service = MonitorService()
+ report = await service.generate_change_report(db, record_id)
+ if not report:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="变化报告不存在",
+ )
+
+ return MonitoringChangeReport(
+ monitoring_record_id=uuid.UUID(report["monitoring_record_id"]),
+ brand_id=uuid.UUID(report["brand_id"]),
+ change_type=report["change_type"],
+ change_details=report["change_details"],
+ baseline=report["baseline"],
+ current=report["current"],
+ recommendations=report["recommendations"],
+ )
+
+
+@router.put("/{record_id}/status", response_model=MonitoringRecordResponse)
+async def update_monitoring_status(
+ record_id: uuid.UUID,
+ status_update: MonitoringStatusUpdate,
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ valid_statuses = {"active", "paused", "completed"}
+ if status_update.status not in valid_statuses:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"无效的状态值,支持: {', '.join(valid_statuses)}",
+ )
+
+ stmt = select(MonitoringRecord).where(MonitoringRecord.id == record_id)
+ result = await db.execute(stmt)
+ record = result.scalar_one_or_none()
+
+ if not record:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="监测记录不存在",
+ )
+
+ await _get_brand_with_access(record.brand_id, db, current_user)
+
+ record.status = status_update.status
+ await db.commit()
+ await db.refresh(record)
+
+ return _record_to_response(record)
+
+
+@router.post("/{record_id}/check", response_model=MonitoringRecordResponse)
+async def trigger_manual_check(
+ record_id: uuid.UUID,
+ current_user: User = Depends(get_current_user),
+ db: AsyncSession = Depends(get_db),
+):
+ stmt = select(MonitoringRecord).where(MonitoringRecord.id == record_id)
+ result = await db.execute(stmt)
+ record = result.scalar_one_or_none()
+
+ if not record:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="监测记录不存在",
+ )
+
+ await _get_brand_with_access(record.brand_id, db, current_user)
+
+ service = MonitorService()
+ updated_record = await service.check_and_compare(db, record_id)
+ if not updated_record:
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="检测执行失败",
+ )
+
+ return _record_to_response(updated_record)
diff --git a/backend/app/api/onboarding.py b/backend/app/api/onboarding.py
index 471e9cf..181e87a 100644
--- a/backend/app/api/onboarding.py
+++ b/backend/app/api/onboarding.py
@@ -20,15 +20,24 @@ from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.competitor import Competitor
-from app.models.citation_record import CitationRecord
from app.models.query import Query as QueryModel
-from app.services.scoring_service import ScoringService
from app.schemas.brand import BrandCreate, BrandResponse
+from app.services.diagnosis.data_collector import DataCollectorService
+from app.services.diagnosis.geo_diagnosis import GEODiagnosisService
+from app.utils.health import get_health_level, get_health_level_label
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/onboarding", tags=["onboarding"])
+_FREE_TIER_DIMENSIONS = {"内容可提取性", "E-E-A-T信号", "引用就绪度"}
+
+
+def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
+ if isinstance(value, uuid.UUID):
+ return value
+ return uuid.UUID(str(value))
+
# ------------------------------------------------------------------
# Request / Response schemas
@@ -60,23 +69,43 @@ class CompetitorRecommendationSimpleResponse(BaseModel):
recommendations: list[CompetitorRecommendationSimple]
+class HealthDimensionItem(BaseModel):
+ name: str
+ score: float
+ max_score: float
+ percentage: float
+ status: str
+
+
+class HealthRecommendationItem(BaseModel):
+ priority: str
+ dimension: str
+ title: str
+ description: str
+
+
class HealthReportResponse(BaseModel):
- """初始健康评分报告"""
brand_id: str
brand_name: str
overall_score: float
+ health_level: str
+ health_level_label: str
platform_scores: dict
strengths: list[str]
weaknesses: list[str]
competitor_scores: list[dict]
+ dimensions: list[HealthDimensionItem] = []
+ recommendations: list[HealthRecommendationItem] = []
+ is_full_report: bool = False
class ActionSuggestion(BaseModel):
- """行动建议项"""
title: str
description: str
- priority: str # high / medium / low
- action_type: str # e.g. coverage, keyword, sentiment, platform
+ priority: str
+ action_type: str
+ is_paid_action: bool = False
+ action_button_text: str = ""
class ActionSuggestionsResponse(BaseModel):
@@ -105,7 +134,7 @@ async def get_onboarding_status(
- completed=True 且 brand_id 有值 → 已完成
- completed=False, current_step=1 → 需要创建品牌
"""
- stmt = select(Brand).where(Brand.user_id == current_user.id)
+ stmt = select(Brand).where(Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
@@ -144,7 +173,7 @@ async def create_onboarding_brand(
)
brand = Brand(
- user_id=current_user.id,
+ user_id=_to_uuid(current_user.id),
name=full_brand_data.name,
aliases=full_brand_data.aliases,
website=full_brand_data.website,
@@ -156,6 +185,38 @@ async def create_onboarding_brand(
await db.commit()
await db.refresh(brand)
+ # 自动创建默认查询词(检查 max_queries 限制)
+ try:
+ current_query_count_stmt = select(func.count()).select_from(QueryModel).where(
+ QueryModel.user_id == current_user.id
+ )
+ current_query_count_result = await db.execute(current_query_count_stmt)
+ current_query_count = current_query_count_result.scalar_one()
+
+ max_queries = getattr(current_user, "max_queries", 3) # 默认免费版 3 个
+
+ if current_query_count < max_queries:
+ default_query = QueryModel(
+ user_id=current_user.id,
+ keyword=f"{brand.name} 推荐",
+ target_brand=brand.name,
+ brand_aliases=brand.aliases or [],
+ platforms=brand.platforms or ["wenxin", "kimi"],
+ frequency=brand.frequency or "weekly",
+ status="active",
+ )
+ db.add(default_query)
+ await db.commit()
+ await db.refresh(default_query)
+ logger.info(f"Auto-created default query for brand '{brand.name}'")
+ else:
+ logger.info(
+ f"Skipped auto-creating default query for brand '{brand.name}': "
+ f"query limit reached ({current_query_count}/{max_queries})"
+ )
+ except Exception as e:
+ logger.warning(f"Failed to auto-create default query for brand '{brand.name}': {e}")
+
return brand
@@ -171,8 +232,7 @@ async def get_onboarding_competitor_recommendations(
复用 brands/competitors 中的推荐逻辑,
支持 LLM 智能推荐和规则推荐两种模式。
"""
- # 验证品牌归属
- stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
+ stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
@@ -182,7 +242,6 @@ async def get_onboarding_competitor_recommendations(
detail="品牌不存在",
)
- # 获取已有竞品名称(排除)
existing_stmt = select(Competitor.name).where(Competitor.brand_id == brand_id)
existing_result = await db.execute(existing_stmt)
existing_names = [row[0] for row in existing_result.all()]
@@ -228,14 +287,8 @@ async def get_onboarding_health_report(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
- """
- 获取品牌初始健康评分报告。
-
- 基于 citation_records 表统计品牌的引用数据,
- 如果没有引用数据则返回初始化状态(overall_score: 0)。
- """
- # 验证品牌归属
- stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
+ user_uuid = _to_uuid(current_user.id)
+ stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == user_uuid)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
@@ -245,7 +298,55 @@ async def get_onboarding_health_report(
detail="品牌不存在",
)
- # 查询与品牌关联的 queries
+ user_plan = getattr(current_user, "plan", None) or "free"
+ is_paid = user_plan not in ("free", None)
+
+ try:
+ collector = DataCollectorService(db)
+ collection = await collector.collect(
+ brand_name=brand.name,
+ brand_aliases=brand.aliases or [],
+ website=brand.website,
+ industry=brand.industry,
+ )
+
+ geo_service = GEODiagnosisService()
+ diagnosis_result = geo_service.diagnose(collection.diagnosis_input)
+ except Exception as e:
+ logger.error(f"Onboarding健康报告采集失败: brand_id={brand_id}, error={e}", exc_info=True)
+ return HealthReportResponse(
+ brand_id=str(brand.id),
+ brand_name=brand.name,
+ overall_score=0.0,
+ health_level="danger",
+ health_level_label=get_health_level_label("danger"),
+ platform_scores={},
+ strengths=["品牌已创建,等待数据采集"],
+ weaknesses=["数据采集失败,请稍后重试"],
+ competitor_scores=[],
+ dimensions=[
+ HealthDimensionItem(name=d, score=0.0, max_score=0.0, percentage=0.0, status="fail")
+ for d in sorted(_FREE_TIER_DIMENSIONS)
+ ],
+ recommendations=[],
+ is_full_report=is_paid,
+ )
+
+ result_dict = diagnosis_result.to_dict()
+ all_dimensions = result_dict.get("dimensions", [])
+ all_recommendations = result_dict.get("recommendations", [])
+
+ if is_paid:
+ filtered_dimensions = all_dimensions
+ filtered_recommendations = all_recommendations
+ else:
+ filtered_dimensions = [d for d in all_dimensions if d.get("name") in _FREE_TIER_DIMENSIONS]
+ filtered_recommendations = [r for r in all_recommendations if r.get("priority") == "P0"]
+
+ overall_score = round(diagnosis_result.overall_score, 2)
+ health_level = get_health_level(overall_score)
+
+ platform_scores: dict[str, float] = {}
queries_stmt = select(QueryModel).where(
QueryModel.user_id == current_user.id,
QueryModel.target_brand == brand.name,
@@ -253,138 +354,89 @@ async def get_onboarding_health_report(
queries_result = await db.execute(queries_stmt)
queries = list(queries_result.scalars().all())
- # 没有查询数据 → 返回初始化状态
- if not queries:
- return HealthReportResponse(
- brand_id=str(brand.id),
- brand_name=brand.name,
- overall_score=0.0,
- platform_scores={},
- strengths=["品牌已创建,等待数据采集"],
- weaknesses=["尚无AI平台引用数据,需等待查询执行"],
- competitor_scores=[],
+ if queries:
+ from app.models.citation_record import CitationRecord
+ query_ids = [q.id for q in queries]
+ citations_stmt = select(CitationRecord).where(
+ CitationRecord.query_id.in_(query_ids),
)
+ citations_result = await db.execute(citations_stmt)
+ citations = list(citations_result.scalars().all())
- query_ids = [q.id for q in queries]
+ platforms_seen: dict[str, dict] = {}
+ for c in citations:
+ p = c.platform or "unknown"
+ if p not in platforms_seen:
+ platforms_seen[p] = {"total": 0, "cited": 0}
+ platforms_seen[p]["total"] += 1
+ if c.cited:
+ platforms_seen[p]["cited"] += 1
- # 获取引用记录
- citations_stmt = select(CitationRecord).where(
- CitationRecord.query_id.in_(query_ids),
- )
- citations_result = await db.execute(citations_stmt)
- citations = list(citations_result.scalars().all())
+ for p, data in platforms_seen.items():
+ rate = (data["cited"] / data["total"] * 100) if data["total"] > 0 else 0.0
+ platform_scores[p] = round(rate, 2)
- total = len(citations)
- cited = [c for c in citations if c.cited]
-
- # 计算各平台评分
- platform_scores: dict[str, float] = {}
- platforms_seen: dict[str, dict] = {} # {platform: {total, cited}}
-
- for c in citations:
- p = c.platform or "unknown"
- if p not in platforms_seen:
- platforms_seen[p] = {"total": 0, "cited": 0}
- platforms_seen[p]["total"] += 1
- if c.cited:
- platforms_seen[p]["cited"] += 1
-
- for p, data in platforms_seen.items():
- rate = (data["cited"] / data["total"] * 100) if data["total"] > 0 else 0.0
- platform_scores[p] = round(rate, 2)
-
- # 使用 ScoringService 计算 overall_score
- scoring_service = ScoringService()
- sentiment_counts = {"positive": 0, "neutral": 0, "negative": 0}
- for c in cited:
- sentiment = c.sentiment or "neutral"
- if sentiment in sentiment_counts:
- sentiment_counts[sentiment] += 1
-
- from app.schemas.scoring import CitationResult
- citation_results = [
- CitationResult(
- cited=c.cited,
- position=c.citation_position,
- citation_text=c.citation_text,
- sentiment=c.sentiment or "neutral",
- confidence=c.confidence or 0.0,
- )
- for c in cited
- ]
- positions = [c.citation_position for c in cited if c.cited]
-
- # 获取竞品信息
- competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id)
- competitor_result = await db.execute(competitor_stmt)
- competitors = list(competitor_result.scalars().all())
- competitor_names = [c.name for c in competitors]
- competitor_mentions: dict[str, int] = {}
- for comp_name in competitor_names:
- count = sum(
- 1 for c in citations
- if c.cited and c.competitor_brands and comp_name in c.competitor_brands
- )
- if count > 0:
- competitor_mentions[comp_name] = count
-
- v2_result = scoring_service.calculate_v2(
- mentioned_count=len(cited),
- total_queries=total,
- positions=positions,
- sentiment_counts=sentiment_counts,
- citations=citation_results,
- brand_mentions=len(cited),
- competitor_mentions=competitor_mentions,
- )
-
- # 生成 strengths/weaknesses
strengths = []
weaknesses = []
- if total == 0:
- strengths.append("品牌已创建")
- weaknesses.append("尚无引用数据")
- else:
- mention_rate = len(cited) / total * 100 if total > 0 else 0
- if mention_rate >= 50:
- strengths.append(f"提及率较高 ({round(mention_rate, 1)}%)")
- else:
- weaknesses.append(f"提及率偏低 ({round(mention_rate, 1)}%)")
+ for d in filtered_dimensions:
+ pct = d.get("percentage", 0)
+ if pct >= 60:
+ strengths.append(f"{d.get('name', '')}表现良好 ({round(pct, 1)}%)")
+ elif pct > 0:
+ weaknesses.append(f"{d.get('name', '')}有待提升 ({round(pct, 1)}%)")
- for p, score in platform_scores.items():
- if score >= 60:
- strengths.append(f"{p} 平台表现良好 ({score}%)")
- elif score > 0:
- weaknesses.append(f"{p} 平台覆盖率不足 ({score}%)")
+ if not strengths:
+ strengths.append("已有初步诊断数据")
+ if not weaknesses:
+ weaknesses.append("暂无明显短板")
- if sentiment_counts["positive"] > sentiment_counts["negative"]:
- strengths.append("情感倾向正面")
- elif sentiment_counts["negative"] > sentiment_counts["positive"]:
- weaknesses.append("情感倾向偏负面")
+ competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id)
+ competitor_result = await db.execute(competitor_stmt)
+ competitors = list(competitor_result.scalars().all())
- if not strengths:
- strengths.append("已有初步引用数据")
- if not weaknesses:
- weaknesses.append("暂无明显短板")
-
- # 竞品评分
competitor_scores = []
- for comp_name, mentions in competitor_mentions.items():
- comp_score = round(mentions / total * 100, 2) if total > 0 else 0.0
+ for comp in competitors:
competitor_scores.append({
- "name": comp_name,
- "score": comp_score,
+ "name": comp.name,
+ "score": 0.0,
+ "is_leading": False,
})
+ dimension_items = [
+ HealthDimensionItem(
+ name=d.get("name", ""),
+ score=round(d.get("score", 0.0), 2),
+ max_score=d.get("max_score", 0.0),
+ percentage=round(d.get("percentage", 0.0), 2),
+ status=d.get("status", "fail"),
+ )
+ for d in filtered_dimensions
+ ]
+
+ recommendation_items = [
+ HealthRecommendationItem(
+ priority=r.get("priority", "P0"),
+ dimension=r.get("dimension", ""),
+ title=r.get("title", ""),
+ description=r.get("description", ""),
+ )
+ for r in filtered_recommendations
+ ]
+
return HealthReportResponse(
brand_id=str(brand.id),
brand_name=brand.name,
- overall_score=round(v2_result.overall_score, 2),
+ overall_score=overall_score,
+ health_level=health_level,
+ health_level_label=get_health_level_label(health_level),
platform_scores=platform_scores,
strengths=strengths,
weaknesses=weaknesses,
competitor_scores=competitor_scores,
+ dimensions=dimension_items,
+ recommendations=recommendation_items,
+ is_full_report=is_paid,
)
@@ -394,21 +446,21 @@ async def get_onboarding_action_suggestions(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
- """
- 根据健康报告生成行动建议(基于规则引擎,不需要 LLM)。
- """
- # 先获取健康报告数据(复用逻辑)
report = await get_onboarding_health_report(brand_id, current_user, db)
+ user_plan = getattr(current_user, "plan", None) or "free"
+ is_paid = user_plan not in ("free", None)
+
suggestions = []
- # 规则引擎:基于评分和平台数据生成建议
if report.overall_score < 20:
suggestions.append(ActionSuggestion(
title="提升 AI 平台覆盖率",
description=f"当前综合评分仅 {report.overall_score},品牌在AI搜索中几乎未被提及。建议增加查询词覆盖面,让AI平台更频繁地引用品牌。",
priority="high",
action_type="coverage",
+ is_paid_action=False,
+ action_button_text="设置查询词",
))
if report.overall_score < 50:
@@ -417,9 +469,10 @@ async def get_onboarding_action_suggestions(
description="品牌在关键查询词下的提及率偏低,建议调整查询关键词策略,聚焦行业核心术语。",
priority="high",
action_type="keyword",
+ is_paid_action=False,
+ action_button_text="优化关键词",
))
- # 平台维度建议
for platform, score in report.platform_scores.items():
if score < 30:
suggestions.append(ActionSuggestion(
@@ -427,28 +480,32 @@ async def get_onboarding_action_suggestions(
description=f"品牌在 {platform} 平台的引用率仅为 {score}%,需要针对性优化该平台的内容策略。",
priority="medium",
action_type="platform",
+ is_paid_action=False,
+ action_button_text=f"优化{platform}",
))
- # 情感维度建议
- if "情感倾向偏负面" in report.weaknesses:
- suggestions.append(ActionSuggestion(
- title="改善品牌情感倾向",
- description="AI平台对品牌的情感评价偏负面,建议发布正面品牌内容、优化品牌描述以改善情感得分。",
- priority="medium",
- action_type="sentiment",
- ))
+ for dim in report.dimensions:
+ if dim.percentage < 40 and dim.name not in _FREE_TIER_DIMENSIONS:
+ suggestions.append(ActionSuggestion(
+ title=f"提升{dim.name}得分",
+ description=f"{dim.name}当前得分仅 {round(dim.percentage, 1)}%,解锁详细诊断和优化方案。",
+ priority="medium",
+ action_type="dimension",
+ is_paid_action=True,
+ action_button_text="升级解锁",
+ ))
- # 竞品对比建议
for comp in report.competitor_scores:
- if comp["score"] > report.overall_score:
+ if comp.get("score", 0) > report.overall_score:
suggestions.append(ActionSuggestion(
title=f"应对竞品 {comp['name']} 威胁",
description=f"竞品 {comp['name']} 评分 ({comp['score']}) 高于本品牌 ({report.overall_score}),建议分析竞品优势领域并制定差异化策略。",
priority="high",
action_type="competitive",
+ is_paid_action=True,
+ action_button_text="查看竞品分析",
))
- # 如果没有引用数据,给出基础建议
if report.overall_score == 0:
suggestions = [
ActionSuggestion(
@@ -456,28 +513,45 @@ async def get_onboarding_action_suggestions(
description="品牌尚无查询数据,建议首先设置与品牌最相关的核心查询词,让系统开始数据采集。",
priority="high",
action_type="keyword",
+ is_paid_action=False,
+ action_button_text="设置查询词",
),
ActionSuggestion(
title="添加竞品对比",
description="添加主要竞品以便进行对比分析,了解品牌在市场中的定位。",
priority="medium",
action_type="coverage",
+ is_paid_action=False,
+ action_button_text="添加竞品",
),
ActionSuggestion(
- title="完善品牌信息",
- description="补充品牌别名、网站、行业等详细信息,有助于提升AI平台识别率。",
+ title="解锁完整6维度诊断",
+ description="免费版仅展示3个核心维度,升级后可查看完整6维度诊断报告和深度优化方案。",
priority="medium",
- action_type="brand_info",
+ action_type="upgrade",
+ is_paid_action=True,
+ action_button_text="升级Pro",
),
]
- # 确保至少有1条建议
if not suggestions:
suggestions.append(ActionSuggestion(
title="持续监测品牌表现",
description="品牌表现良好,建议持续监测并保持当前策略。",
priority="low",
action_type="monitor",
+ is_paid_action=False,
+ action_button_text="查看Dashboard",
+ ))
+
+ if not is_paid:
+ suggestions.append(ActionSuggestion(
+ title="升级Pro解锁完整诊断",
+ description="免费版仅展示3个核心维度和P0建议。升级Pro可获取完整6维度诊断、深度竞品分析和AI优化方案。",
+ priority="low",
+ action_type="upgrade",
+ is_paid_action=True,
+ action_button_text="升级Pro",
))
return ActionSuggestionsResponse(suggestions=suggestions)
@@ -496,8 +570,7 @@ async def complete_onboarding(
User 模型当前没有 onboarding_completed 专用字段,
品牌的创建即代表 onboarding 完成。
"""
- # 验证品牌归属
- stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
+ stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
@@ -507,8 +580,6 @@ async def complete_onboarding(
detail="品牌不存在",
)
- # 品牌已创建即代表 onboarding 完成,无需额外字段更新
- # 后续如需专用字段,可通过 alembic 迁移添加 user.onboarding_completed
logger.info(f"User {current_user.id} completed onboarding with brand {brand_id}")
return OnboardingCompleteResponse(success=True)
\ No newline at end of file
diff --git a/backend/app/api/payments.py b/backend/app/api/payments.py
new file mode 100644
index 0000000..94005ac
--- /dev/null
+++ b/backend/app/api/payments.py
@@ -0,0 +1,274 @@
+import logging
+import uuid
+from datetime import datetime, timezone
+
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from pydantic import BaseModel
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.api.deps import get_current_user
+from app.database import get_db
+from app.models.payment_order import PaymentOrder as PaymentOrderModel
+from app.models.user import User
+from app.services.payment import get_payment_gateway
+from app.services.subscription import PLANS, subscribe
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/api/v1/payments", tags=["支付"])
+
+
+class CreateOrderRequest(BaseModel):
+ plan: str
+ payment_provider: str = "wechat"
+
+
+class CreateOrderResponse(BaseModel):
+ order_id: str
+ pay_url: str
+ amount: float
+ currency: str = "CNY"
+ status: str = "pending"
+
+
+class OrderStatusResponse(BaseModel):
+ order_id: str
+ status: str
+ plan: str
+ amount: float
+ payment_provider: str
+ payment_id: str | None = None
+ created_at: str | None = None
+ paid_at: str | None = None
+
+
+class RefundRequest(BaseModel):
+ reason: str = ""
+
+
+@router.post("/orders", response_model=CreateOrderResponse, status_code=status.HTTP_201_CREATED)
+async def create_payment_order(
+ request: CreateOrderRequest,
+ db: AsyncSession = Depends(get_db),
+ current_user: User = Depends(get_current_user),
+):
+ plan_data = PLANS.get(request.plan)
+ if plan_data is None:
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail=f"无效的套餐: {request.plan}",
+ )
+
+ if plan_data["price"] == 0:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="免费套餐无需支付",
+ )
+
+ order_id = uuid.uuid4()
+ amount = plan_data["price"]
+
+ gateway = get_payment_gateway(request.payment_provider)
+ payment_order = await gateway.create_order(
+ order_id=str(order_id),
+ amount=amount,
+ description=f"GEO平台{plan_data['name']}订阅",
+ user_id=current_user.id,
+ plan=request.plan,
+ )
+
+ db_order = PaymentOrderModel(
+ id=order_id,
+ user_id=current_user.id,
+ plan=request.plan,
+ amount=amount,
+ payment_provider=request.payment_provider,
+ status="pending",
+ pay_url=payment_order.pay_url,
+ )
+ db.add(db_order)
+ await db.commit()
+
+ return CreateOrderResponse(
+ order_id=str(order_id),
+ pay_url=payment_order.pay_url,
+ amount=amount,
+ status="pending",
+ )
+
+
+@router.post("/callback/wechat")
+async def wechat_pay_callback(request: Request):
+ body = await request.form()
+ request_data = dict(body)
+
+ gateway = get_payment_gateway("wechat")
+ callback = await gateway.verify_callback(request_data)
+
+ return await _handle_payment_callback(request_data, callback, "wechat")
+
+
+@router.post("/callback/alipay")
+async def alipay_callback(request: Request):
+ body = await request.form()
+ request_data = dict(body)
+
+ gateway = get_payment_gateway("alipay")
+ callback = await gateway.verify_callback(request_data)
+
+ return await _handle_payment_callback(request_data, callback, "alipay")
+
+
+async def _handle_payment_callback(request_data: dict, callback, provider: str):
+ from app.database import AsyncSessionLocal
+
+ async with AsyncSessionLocal() as db:
+ try:
+ result = await _process_callback(db, callback, provider)
+ await db.commit()
+ return result
+ except Exception as e:
+ logger.error(f"[PaymentCallback] 处理回调异常: {e}", exc_info=True)
+ await db.rollback()
+ if provider == "wechat":
+ return _wechat_fail_response()
+ return "fail"
+
+
+async def _process_callback(db: AsyncSession, callback, provider: str):
+ stmt = select(PaymentOrderModel).where(
+ PaymentOrderModel.id == uuid.UUID(callback.order_id)
+ )
+ result = await db.execute(stmt)
+ order = result.scalar_one_or_none()
+
+ if order is None:
+ logger.warning(f"[PaymentCallback] 订单不存在: order_id={callback.order_id}")
+ if provider == "wechat":
+ return _wechat_fail_response()
+ return "fail"
+
+ if callback.status == "success":
+ order.status = "paid"
+ order.payment_id = callback.payment_id
+ order.callback_data = callback.raw_data
+ order.paid_at = datetime.now(timezone.utc)
+
+ await subscribe(db, order.user_id, order.plan)
+
+ logger.info(
+ f"[PaymentCallback] 支付成功: order_id={callback.order_id}, "
+ f"plan={order.plan}, provider={provider}"
+ )
+ else:
+ order.status = "failed"
+ order.callback_data = callback.raw_data
+
+ if provider == "wechat":
+ return _wechat_success_response()
+ return "success"
+
+
+def _wechat_success_response():
+ from fastapi.responses import Response
+ return Response(content="
| + + |
| + + |
| + + |
| + + |
Test
", + body_text="Test", + ) + result = mock_email_service.send_email(msg) + assert result.success is True + assert result.message_id is not None + assert result.message_id.startswith("sim_") + assert result.error is None + + def test_mock_mode_simulate_flag(self, mock_email_service): + assert mock_email_service.simulate_mode is True + + def test_production_mode_would_use_smtp(self): + service = EmailService( + simulate_mode=False, + smtp_host="localhost", + smtp_port=587, + smtp_user="test", + smtp_password="test", + ) + assert service.simulate_mode is False + + +class TestEmailTemplates: + def test_geo_weekly_report_template_exists(self): + template_path = TEMPLATES_DIR / "geo_weekly_report.html" + assert template_path.exists() + + def test_renewal_reminder_template_exists(self): + template_path = TEMPLATES_DIR / "renewal_reminder.html" + assert template_path.exists() + + def test_trial_expiring_template_exists(self): + template_path = TEMPLATES_DIR / "trial_expiring.html" + assert template_path.exists() + + def test_welcome_template_exists(self): + template_path = TEMPLATES_DIR / "welcome.html" + assert template_path.exists() + + def test_geo_weekly_report_renders_with_context(self, email_scheduler): + template_html = email_scheduler._load_template("geo_weekly_report.html") + context = { + "user_name": "测试用户", + "score_change": "+5", + "current_score": "78", + "previous_score": "73", + "top_improved": "内容质量 (+12%)", + "top_declined": "品牌权威 (-3%)", + "suggestions": "建议增加技术白皮书", + "report_link": "https://example.com", + "year": "2026", + } + rendered = email_scheduler._render_template(template_html, context) + assert "测试用户" in rendered + assert "+5" in rendered + assert "78" in rendered + assert "内容质量 (+12%)" in rendered + + def test_renewal_reminder_renders_with_context(self, email_scheduler): + template_html = email_scheduler._load_template("renewal_reminder.html") + context = { + "user_name": "续费用户", + "plan_name": "专业版", + "end_date": "2026-06-30", + "days_remaining": "7", + "plan_price": "599", + "renew_link": "https://example.com", + "year": "2026", + } + rendered = email_scheduler._render_template(template_html, context) + assert "续费用户" in rendered + assert "专业版" in rendered + assert "7" in rendered + assert "599" in rendered + + def test_trial_expiring_renders_with_context(self, email_scheduler): + template_html = email_scheduler._load_template("trial_expiring.html") + context = { + "user_name": "试用用户", + "days_remaining": "3", + "upgrade_link": "https://example.com", + "year": "2026", + } + rendered = email_scheduler._render_template(template_html, context) + assert "试用用户" in rendered + assert "3" in rendered + + def test_welcome_renders_with_context(self, email_scheduler): + template_html = email_scheduler._load_template("welcome.html") + context = { + "user_name": "新用户", + "dashboard_link": "https://example.com", + "diagnosis_link": "https://example.com/diagnosis", + "help_link": "https://example.com/help", + "year": "2026", + } + rendered = email_scheduler._render_template(template_html, context) + assert "新用户" in rendered + assert "3" in rendered + + +class TestSendTemplateEmail: + def test_send_template_email_with_welcome(self, mock_email_service): + result = mock_email_service.send_template_email( + to="test@example.com", + subject="欢迎", + template_name="welcome.html", + context={ + "user_name": "测试", + "dashboard_link": "https://example.com", + "diagnosis_link": "https://example.com/d", + "help_link": "https://example.com/h", + "year": "2026", + }, + ) + assert result.success is True + + def test_send_template_email_invalid_template(self, mock_email_service): + with pytest.raises(ValueError, match="模板文件不存在"): + mock_email_service.send_template_email( + to="test@example.com", + subject="Test", + template_name="nonexistent.html", + context={}, + ) + + def test_send_template_email_renders_context(self, mock_email_service): + result = mock_email_service.send_template_email( + to="test@example.com", + subject="周报", + template_name="geo_weekly_report.html", + context={ + "user_name": "渲染测试", + "score_change": "+10", + "current_score": "85", + "previous_score": "75", + "top_improved": "测试提升", + "top_declined": "测试下降", + "suggestions": "测试建议", + "report_link": "https://example.com", + "year": "2026", + }, + ) + assert result.success is True diff --git a/backend/tests/test_api/test_health_score_contract.py b/backend/tests/test_api/test_health_score_contract.py new file mode 100644 index 0000000..5d91974 --- /dev/null +++ b/backend/tests/test_api/test_health_score_contract.py @@ -0,0 +1,329 @@ +import hashlib +import uuid +from unittest.mock import AsyncMock, patch, MagicMock + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.api.deps import get_current_user, get_db +from app.database import Base +from app.main import app +from app.models.brand import Brand +from app.models.user import User +from app.services.auth import hash_password +from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisResult + + +def _make_user( + user_id: str | None = None, + email: str = "test@example.com", + plan: str = "free", +) -> User: + uid = user_id or str(uuid.uuid4()) + user = User( + id=uid, + email=email, + password=hash_password("Test@123456"), + firstName="Test", + lastName="User", + isActive=True, + emailVerified=True, + ) + user.plan = plan + user.max_queries = 50 if plan != "free" else 5 + return user + + +def _to_uuid(value: str | uuid.UUID) -> uuid.UUID: + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(str(value)) + + +@pytest_asyncio.fixture +async def async_engine(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with maker() as session: + yield session + + +@pytest_asyncio.fixture +async def test_user(async_session): + user = _make_user(plan="free") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def test_brand(async_session, test_user): + brand = Brand( + id=uuid.uuid4(), + user_id=_to_uuid(test_user.id), + name="TestBrand", + aliases=["TB", "Test Brand"], + website="https://testbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def public_client(async_session): + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +class TestHealthScorePublicContract: + @pytest.mark.asyncio + async def test_returns_200_with_brand(self, public_client, test_brand): + mock_cache = AsyncMock() + mock_cache.get_json = AsyncMock(return_value=None) + mock_cache.set_json = AsyncMock() + + mock_result = MagicMock() + mock_result.to_dict.return_value = { + "overall_score": 45.0, + "health_level": "pass", + "health_level_label": "及格", + "dimensions": [ + {"name": "内容可提取性", "score": 12.0, "max_score": 20.0, "percentage": 60.0, "status": "pass"}, + {"name": "E-E-A-T信号", "score": 10.0, "max_score": 20.0, "percentage": 50.0, "status": "pass"}, + {"name": "引用就绪度", "score": 8.0, "max_score": 15.0, "percentage": 53.3, "status": "pass"}, + {"name": "实体清晰度", "score": 8.0, "max_score": 15.0, "percentage": 53.3, "status": "pass"}, + {"name": "Schema标记", "score": 5.0, "max_score": 15.0, "percentage": 33.3, "status": "fail"}, + {"name": "主题权威", "score": 6.0, "max_score": 15.0, "percentage": 40.0, "status": "pass"}, + ], + "recommendations": [ + {"priority": "P0", "dimension": "内容可提取性", "title": "添加结构化数据", "description": "建议添加Schema标记"}, + ], + } + + with patch("app.api.health_score.get_cache_service", return_value=mock_cache), \ + patch("app.api.health_score.DataCollectorService") as MockCollector, \ + patch("app.api.health_score.GEODiagnosisService") as MockDiagService: + mock_collector = MockCollector.return_value + mock_collection = MagicMock() + mock_collection.diagnosis_input = GEODiagnosisInput() + mock_collector.collect = AsyncMock(return_value=mock_collection) + + mock_diag = MockDiagService.return_value + mock_diag.diagnose = MagicMock(return_value=mock_result) + + response = await public_client.get( + "/api/v1/public/health-score", + params={"brand": "TestBrand"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["brand_name"] == "TestBrand" + assert data["overall_score"] > 0 + assert data["health_level"] in ("excellent", "good", "pass", "danger") + assert data["is_full_report"] is False + assert len(data["dimensions"]) == 3 + dim_names = {d["name"] for d in data["dimensions"]} + assert "内容可提取性" in dim_names + assert "E-E-A-T信号" in dim_names + assert "引用就绪度" in dim_names + + @pytest.mark.asyncio + async def test_returns_cached_result(self, public_client): + cached_data = { + "brand_name": "CachedBrand", + "overall_score": 55.0, + "health_level": "pass", + "health_level_label": "及格", + "dimensions": [ + {"name": "内容可提取性", "score": 15.0, "max_score": 20.0, "percentage": 75.0, "status": "good"}, + {"name": "E-E-A-T信号", "score": 12.0, "max_score": 20.0, "percentage": 60.0, "status": "pass"}, + {"name": "引用就绪度", "score": 10.0, "max_score": 15.0, "percentage": 66.7, "status": "pass"}, + ], + "recommendations": [], + "is_full_report": False, + "cached": False, + } + + mock_cache = AsyncMock() + mock_cache.get_json = AsyncMock(return_value=cached_data) + + with patch("app.api.health_score.get_cache_service", return_value=mock_cache): + response = await public_client.get( + "/api/v1/public/health-score", + params={"brand": "CachedBrand"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["cached"] is True + assert data["brand_name"] == "CachedBrand" + + @pytest.mark.asyncio + async def test_missing_brand_param_returns_422(self, public_client): + response = await public_client.get("/api/v1/public/health-score") + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_collection_failure_returns_default(self, public_client): + mock_cache = AsyncMock() + mock_cache.get_json = AsyncMock(return_value=None) + mock_cache.set_json = AsyncMock() + + with patch("app.api.health_score.get_cache_service", return_value=mock_cache), \ + patch("app.api.health_score.DataCollectorService") as MockCollector: + mock_collector = MockCollector.return_value + mock_collector.collect = AsyncMock(side_effect=Exception("采集失败")) + + response = await public_client.get( + "/api/v1/public/health-score", + params={"brand": "UnknownBrand"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["brand_name"] == "UnknownBrand" + assert data["overall_score"] == 0.0 + assert data["health_level"] == "danger" + + @pytest.mark.asyncio + async def test_competitors_param_accepted(self, public_client): + mock_cache = AsyncMock() + mock_cache.get_json = AsyncMock(return_value=None) + mock_cache.set_json = AsyncMock() + + mock_result = MagicMock() + mock_result.to_dict.return_value = { + "overall_score": 30.0, + "health_level": "pass", + "health_level_label": "及格", + "dimensions": [ + {"name": "内容可提取性", "score": 10.0, "max_score": 20.0, "percentage": 50.0, "status": "pass"}, + {"name": "E-E-A-T信号", "score": 8.0, "max_score": 20.0, "percentage": 40.0, "status": "fail"}, + {"name": "引用就绪度", "score": 5.0, "max_score": 15.0, "percentage": 33.3, "status": "fail"}, + {"name": "实体清晰度", "score": 5.0, "max_score": 15.0, "percentage": 33.3, "status": "fail"}, + {"name": "Schema标记", "score": 3.0, "max_score": 15.0, "percentage": 20.0, "status": "fail"}, + {"name": "主题权威", "score": 4.0, "max_score": 15.0, "percentage": 26.7, "status": "fail"}, + ], + "recommendations": [], + } + + with patch("app.api.health_score.get_cache_service", return_value=mock_cache), \ + patch("app.api.health_score.DataCollectorService") as MockCollector, \ + patch("app.api.health_score.GEODiagnosisService") as MockDiagService: + mock_collector = MockCollector.return_value + mock_collection = MagicMock() + mock_collection.diagnosis_input = GEODiagnosisInput() + mock_collector.collect = AsyncMock(return_value=mock_collection) + + mock_diag = MockDiagService.return_value + mock_diag.diagnose = MagicMock(return_value=mock_result) + + response = await public_client.get( + "/api/v1/public/health-score", + params={"brand": "SomeBrand", "competitors": "CompA,CompB,CompC"}, + ) + + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_only_p0_recommendations_returned(self, public_client): + mock_cache = AsyncMock() + mock_cache.get_json = AsyncMock(return_value=None) + mock_cache.set_json = AsyncMock() + + mock_result = MagicMock() + mock_result.to_dict.return_value = { + "overall_score": 30.0, + "health_level": "pass", + "health_level_label": "及格", + "dimensions": [ + {"name": "内容可提取性", "score": 10.0, "max_score": 20.0, "percentage": 50.0, "status": "pass"}, + {"name": "E-E-A-T信号", "score": 8.0, "max_score": 20.0, "percentage": 40.0, "status": "fail"}, + {"name": "引用就绪度", "score": 5.0, "max_score": 15.0, "percentage": 33.3, "status": "fail"}, + ], + "recommendations": [ + {"priority": "P0", "dimension": "内容可提取性", "title": "P0建议", "description": "紧急"}, + {"priority": "P1", "dimension": "E-E-A-T信号", "title": "P1建议", "description": "重要"}, + {"priority": "P2", "dimension": "引用就绪度", "title": "P2建议", "description": "一般"}, + ], + } + + with patch("app.api.health_score.get_cache_service", return_value=mock_cache), \ + patch("app.api.health_score.DataCollectorService") as MockCollector, \ + patch("app.api.health_score.GEODiagnosisService") as MockDiagService: + mock_collector = MockCollector.return_value + mock_collection = MagicMock() + mock_collection.diagnosis_input = GEODiagnosisInput() + mock_collector.collect = AsyncMock(return_value=mock_collection) + + mock_diag = MockDiagService.return_value + mock_diag.diagnose = MagicMock(return_value=mock_result) + + response = await public_client.get( + "/api/v1/public/health-score", + params={"brand": "TestBrand"}, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data["recommendations"]) == 1 + assert data["recommendations"][0]["priority"] == "P0" + + @pytest.mark.asyncio + async def test_no_auth_required(self, public_client): + mock_cache = AsyncMock() + mock_cache.get_json = AsyncMock(return_value={ + "brand_name": "Test", + "overall_score": 50.0, + "health_level": "pass", + "health_level_label": "及格", + "dimensions": [], + "recommendations": [], + "is_full_report": False, + "cached": False, + }) + + with patch("app.api.health_score.get_cache_service", return_value=mock_cache): + response = await public_client.get( + "/api/v1/public/health-score", + params={"brand": "Test"}, + ) + + assert response.status_code == 200 diff --git a/backend/tests/test_api/test_onboarding_contract.py b/backend/tests/test_api/test_onboarding_contract.py new file mode 100644 index 0000000..f9fd9d5 --- /dev/null +++ b/backend/tests/test_api/test_onboarding_contract.py @@ -0,0 +1,453 @@ +import uuid +from unittest.mock import AsyncMock, patch + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.api.deps import get_current_user, get_db +from app.database import Base +from app.main import app +from app.models.brand import Brand +from app.models.user import User +from app.services.auth import hash_password + + +def _make_user( + user_id: str | None = None, + email: str = "test@example.com", + plan: str = "free", +) -> User: + uid = user_id or str(uuid.uuid4()) + user = User( + id=uid, + email=email, + password=hash_password("Test@123456"), + firstName="Test", + lastName="User", + isActive=True, + emailVerified=True, + ) + user.plan = plan + user.max_queries = 50 if plan != "free" else 5 + return user + + +def _to_uuid(value: str | uuid.UUID) -> uuid.UUID: + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(str(value)) + + +@pytest_asyncio.fixture +async def async_engine(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with maker() as session: + yield session + + +@pytest_asyncio.fixture +async def test_user(async_session): + user = _make_user(plan="free") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def paid_user(async_session): + user = _make_user(email="paid@example.com", plan="pro") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def test_brand(async_session, test_user): + brand = Brand( + id=uuid.uuid4(), + user_id=_to_uuid(test_user.id), + name="TestBrand", + aliases=["TB"], + website="https://testbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def paid_brand(async_session, paid_user): + brand = Brand( + id=uuid.uuid4(), + user_id=_to_uuid(paid_user.id), + name="PaidBrand", + aliases=["PB"], + website="https://paidbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def async_client(async_session, test_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return test_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +@pytest_asyncio.fixture +async def paid_client(async_session, paid_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return paid_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +class TestOnboardingStatusContract: + @pytest.mark.asyncio + async def test_status_incomplete_no_brand(self, async_client): + response = await async_client.get("/api/v1/onboarding/status") + assert response.status_code == 200 + data = response.json() + assert data["completed"] is False + assert data["brand_id"] is None + assert data["current_step"] == 1 + + @pytest.mark.asyncio + async def test_status_complete_with_brand(self, async_client, test_brand): + response = await async_client.get("/api/v1/onboarding/status") + assert response.status_code == 200 + data = response.json() + assert data["completed"] is True + assert data["brand_id"] == str(test_brand.id) + + +class TestOnboardingHealthReportContract: + @pytest.mark.asyncio + async def test_health_report_returns_diagnosis_data(self, async_client, test_brand): + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + + from app.services.diagnosis.data_collector import DataCollectionResult + from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + answer_ownership_rate=0.3, + ) + mock_collector.collect.return_value = DataCollectionResult( + diagnosis_input=input_data, + ) + + service = GEODiagnosisService() + result = service.diagnose(input_data) + + with patch("app.api.onboarding.GEODiagnosisService", return_value=service): + response = await async_client.get( + f"/api/v1/onboarding/health-report/{test_brand.id}" + ) + + assert response.status_code == 200 + data = response.json() + assert data["brand_id"] == str(test_brand.id) + assert data["brand_name"] == "TestBrand" + assert "overall_score" in data + assert "health_level" in data + assert "health_level_label" in data + assert "dimensions" in data + assert "recommendations" in data + assert "is_full_report" in data + assert isinstance(data["dimensions"], list) + assert isinstance(data["recommendations"], list) + + @pytest.mark.asyncio + async def test_health_report_brand_not_found(self, async_client): + response = await async_client.get( + f"/api/v1/onboarding/health-report/{uuid.uuid4()}" + ) + assert response.status_code == 404 + + @pytest.mark.asyncio + async def test_health_report_free_user_3_dimensions( + self, async_client, test_brand + ): + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + + from app.services.diagnosis.data_collector import DataCollectionResult + from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + has_author_bio=True, + author_credentials_complete=0.8, + has_organization=True, + content_depth_score=0.7, + answer_ownership_rate=0.3, + ) + mock_collector.collect.return_value = DataCollectionResult( + diagnosis_input=input_data, + ) + + service = GEODiagnosisService() + result = service.diagnose(input_data) + + with patch("app.api.onboarding.GEODiagnosisService", return_value=service): + response = await async_client.get( + f"/api/v1/onboarding/health-report/{test_brand.id}" + ) + + data = response.json() + assert data["is_full_report"] is False + dim_names = {d["name"] for d in data["dimensions"]} + assert len(dim_names) == 3 + assert "内容可提取性" in dim_names + assert "E-E-A-T信号" in dim_names + assert "引用就绪度" in dim_names + + @pytest.mark.asyncio + async def test_health_report_paid_user_6_dimensions( + self, paid_client, paid_brand + ): + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + + from app.services.diagnosis.data_collector import DataCollectionResult + from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + has_author_bio=True, + author_credentials_complete=0.8, + has_organization=True, + content_depth_score=0.7, + answer_ownership_rate=0.3, + ) + mock_collector.collect.return_value = DataCollectionResult( + diagnosis_input=input_data, + ) + + service = GEODiagnosisService() + result = service.diagnose(input_data) + + with patch("app.api.onboarding.GEODiagnosisService", return_value=service): + response = await paid_client.get( + f"/api/v1/onboarding/health-report/{paid_brand.id}" + ) + + data = response.json() + assert data["is_full_report"] is True + assert len(data["dimensions"]) == 6 + + @pytest.mark.asyncio + async def test_health_report_collection_failure_fallback( + self, async_client, test_brand + ): + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + mock_collector.collect.side_effect = Exception("DB error") + + response = await async_client.get( + f"/api/v1/onboarding/health-report/{test_brand.id}" + ) + + assert response.status_code == 200 + data = response.json() + assert data["overall_score"] == 0 + assert data["health_level"] == "danger" + assert len(data["dimensions"]) == 3 + + +class TestOnboardingActionSuggestionsContract: + @pytest.mark.asyncio + async def test_suggestions_have_paid_action_field( + self, async_client, test_brand + ): + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + + from app.services.diagnosis.data_collector import DataCollectionResult + from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + answer_ownership_rate=0.3, + ) + mock_collector.collect.return_value = DataCollectionResult( + diagnosis_input=input_data, + ) + + service = GEODiagnosisService() + result = service.diagnose(input_data) + + with patch("app.api.onboarding.GEODiagnosisService", return_value=service): + response = await async_client.get( + f"/api/v1/onboarding/action-suggestions/{test_brand.id}" + ) + + assert response.status_code == 200 + data = response.json() + assert "suggestions" in data + assert isinstance(data["suggestions"], list) + for s in data["suggestions"]: + assert "is_paid_action" in s + assert "action_button_text" in s + assert isinstance(s["is_paid_action"], bool) + + @pytest.mark.asyncio + async def test_free_user_gets_upgrade_suggestion( + self, async_client, test_brand + ): + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + + from app.services.diagnosis.data_collector import DataCollectionResult + from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + answer_ownership_rate=0.3, + ) + mock_collector.collect.return_value = DataCollectionResult( + diagnosis_input=input_data, + ) + + service = GEODiagnosisService() + result = service.diagnose(input_data) + + with patch("app.api.onboarding.GEODiagnosisService", return_value=service): + response = await async_client.get( + f"/api/v1/onboarding/action-suggestions/{test_brand.id}" + ) + + data = response.json() + upgrade_suggestions = [ + s for s in data["suggestions"] if s["action_type"] == "upgrade" + ] + assert len(upgrade_suggestions) >= 1 + assert any(s["is_paid_action"] for s in upgrade_suggestions) + + @pytest.mark.asyncio + async def test_zero_score_suggestions_include_upgrade( + self, async_client, test_brand + ): + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + mock_collector.collect.side_effect = Exception("No data") + + response = await async_client.get( + f"/api/v1/onboarding/action-suggestions/{test_brand.id}" + ) + + assert response.status_code == 200 + data = response.json() + upgrade_suggestions = [ + s for s in data["suggestions"] if s["action_type"] == "upgrade" + ] + assert len(upgrade_suggestions) >= 1 + + +class TestOnboardingCreateBrandContract: + @pytest.mark.asyncio + async def test_create_brand_success(self, async_client): + response = await async_client.post( + "/api/v1/onboarding/brand", + json={"name": "NewBrand", "industry": "tech"}, + ) + assert response.status_code == 201 + data = response.json() + assert data["name"] == "NewBrand" + + @pytest.mark.asyncio + async def test_create_brand_short_name_rejected(self, async_client): + response = await async_client.post( + "/api/v1/onboarding/brand", + json={"name": "A"}, + ) + assert response.status_code == 422 + + +class TestOnboardingCompleteContract: + @pytest.mark.asyncio + async def test_complete_onboarding(self, async_client, test_brand): + response = await async_client.post( + f"/api/v1/onboarding/complete/{test_brand.id}" + ) + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + + @pytest.mark.asyncio + async def test_complete_onboarding_brand_not_found(self, async_client): + response = await async_client.post( + f"/api/v1/onboarding/complete/{uuid.uuid4()}" + ) + assert response.status_code == 404 diff --git a/backend/tests/test_api/test_payment_contract.py b/backend/tests/test_api/test_payment_contract.py new file mode 100644 index 0000000..36a0c1c --- /dev/null +++ b/backend/tests/test_api/test_payment_contract.py @@ -0,0 +1,372 @@ +import uuid +from unittest.mock import AsyncMock, patch, MagicMock + +import pytest +import pytest_asyncio +from fastapi import Depends, FastAPI +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.api.deps import get_current_user, get_db +from app.database import Base +from app.main import app +from app.middleware.subscription_enforcement import SubscriptionEnforcement +from app.models.payment_order import PaymentOrder as PaymentOrderModel +from app.models.user import User +from app.services.auth import hash_password +from app.services.payment.base import PaymentCallback + + +def _make_user( + user_id: str | None = None, + email: str = "test@example.com", + plan: str = "free", +) -> User: + uid = user_id or str(uuid.uuid4()) + user = User( + id=uid, + email=email, + password=hash_password("Test@123456"), + firstName="Test", + lastName="User", + isActive=True, + emailVerified=True, + ) + user.plan = plan + user.max_queries = 50 if plan != "free" else 5 + return user + + +@pytest_asyncio.fixture +async def async_engine(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with maker() as session: + yield session + + +@pytest_asyncio.fixture +async def free_user(async_session): + user = _make_user(plan="free") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def pro_user(async_session): + uid = str(uuid.uuid4()) + user = _make_user(user_id=uid, email="pro@example.com", plan="pro") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def enterprise_user(async_session): + uid = str(uuid.uuid4()) + user = _make_user(user_id=uid, email="enterprise@example.com", plan="enterprise") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def client_with_free_user(async_session, free_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return free_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +@pytest_asyncio.fixture +async def client_with_pro_user(async_session, pro_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return pro_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +@pytest_asyncio.fixture +async def client_with_enterprise_user(async_session, enterprise_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return enterprise_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +class TestCreatePaymentOrder: + @pytest.mark.asyncio + async def test_create_order_returns_201_with_order_id_and_pay_url(self, client_with_free_user): + response = await client_with_free_user.post( + "/api/v1/payments/orders", + json={"plan": "starter", "payment_provider": "wechat"}, + ) + assert response.status_code == 201 + data = response.json() + assert "order_id" in data + assert "pay_url" in data + assert data["amount"] == 199 + assert data["status"] == "pending" + assert data["currency"] == "CNY" + + @pytest.mark.asyncio + async def test_create_order_with_invalid_plan_returns_422(self, client_with_free_user): + response = await client_with_free_user.post( + "/api/v1/payments/orders", + json={"plan": "invalid_plan", "payment_provider": "wechat"}, + ) + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_create_order_for_free_plan_returns_400(self, client_with_free_user): + response = await client_with_free_user.post( + "/api/v1/payments/orders", + json={"plan": "free", "payment_provider": "wechat"}, + ) + assert response.status_code == 400 + + @pytest.mark.asyncio + async def test_create_order_pro_plan(self, client_with_free_user): + response = await client_with_free_user.post( + "/api/v1/payments/orders", + json={"plan": "pro", "payment_provider": "alipay"}, + ) + assert response.status_code == 201 + data = response.json() + assert data["amount"] == 599 + + @pytest.mark.asyncio + async def test_create_order_enterprise_plan(self, client_with_free_user): + response = await client_with_free_user.post( + "/api/v1/payments/orders", + json={"plan": "enterprise", "payment_provider": "wechat"}, + ) + assert response.status_code == 201 + data = response.json() + assert data["amount"] == 1999 + + +class TestWechatCallback: + @pytest.mark.asyncio + async def test_wechat_callback_updates_order_and_activates_subscription( + self, async_session, free_user + ): + order_id = uuid.uuid4() + order = PaymentOrderModel( + id=order_id, + user_id=free_user.id, + plan="starter", + amount=199, + payment_provider="wechat", + status="pending", + pay_url=f"mock://pay/{order_id}", + ) + async_session.add(order) + await async_session.commit() + + callback = PaymentCallback( + order_id=str(order_id), + payment_id=f"wx_pay_{order_id}", + amount=199, + status="success", + raw_data={"out_trade_no": str(order_id), "result_code": "SUCCESS"}, + ) + + from app.api.payments import _process_callback + result = await _process_callback(async_session, callback, "wechat") + await async_session.commit() + + await async_session.refresh(order) + assert order.status == "paid" + assert order.payment_id is not None + + +class TestAlipayCallback: + @pytest.mark.asyncio + async def test_alipay_callback_updates_order_and_activates_subscription( + self, async_session, free_user + ): + order_id = uuid.uuid4() + order = PaymentOrderModel( + id=order_id, + user_id=free_user.id, + plan="pro", + amount=599, + payment_provider="alipay", + status="pending", + pay_url=f"mock://pay/{order_id}", + ) + async_session.add(order) + await async_session.commit() + + callback = PaymentCallback( + order_id=str(order_id), + payment_id=f"ali_pay_{order_id}", + amount=599, + status="success", + raw_data={"out_trade_no": str(order_id), "trade_status": "TRADE_SUCCESS"}, + ) + + from app.api.payments import _process_callback + result = await _process_callback(async_session, callback, "alipay") + await async_session.commit() + + await async_session.refresh(order) + assert order.status == "paid" + assert order.payment_id is not None + + +class TestQueryOrder: + @pytest.mark.asyncio + async def test_query_order_returns_current_status(self, client_with_free_user, async_session, free_user): + order_id = uuid.uuid4() + order = PaymentOrderModel( + id=order_id, + user_id=free_user.id, + plan="starter", + amount=199, + payment_provider="wechat", + status="pending", + pay_url=f"mock://pay/{order_id}", + ) + async_session.add(order) + await async_session.commit() + + response = await client_with_free_user.get(f"/api/v1/payments/orders/{order_id}") + assert response.status_code == 200 + data = response.json() + assert data["order_id"] == str(order_id) + assert data["status"] == "pending" + assert data["plan"] == "starter" + assert data["amount"] == 199 + + @pytest.mark.asyncio + async def test_query_nonexistent_order_returns_404(self, client_with_free_user): + fake_id = uuid.uuid4() + response = await client_with_free_user.get(f"/api/v1/payments/orders/{fake_id}") + assert response.status_code == 404 + + +class TestSubscriptionEnforcement: + @pytest.mark.asyncio + async def test_free_user_accessing_pro_only_endpoint_gets_403(self, client_with_free_user): + test_app = FastAPI() + + @test_app.get("/test/pro-only") + async def pro_only(user: User = Depends(SubscriptionEnforcement.require_plan("pro", "enterprise"))): + return {"message": "ok"} + + async def override_get_current_user(): + return _make_user(plan="free") + + test_app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=test_app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/test/pro-only") + + assert response.status_code == 403 + data = response.json() + assert data["detail"]["required_plan"] == "pro" + assert data["detail"]["current_plan"] == "free" + + @pytest.mark.asyncio + async def test_pro_user_accessing_pro_only_endpoint_succeeds(self, client_with_pro_user): + test_app = FastAPI() + + @test_app.get("/test/pro-only-2") + async def pro_only(user: User = Depends(SubscriptionEnforcement.require_plan("pro", "enterprise"))): + return {"message": "ok"} + + async def override_get_current_user(): + return _make_user(plan="pro") + + test_app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=test_app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/test/pro-only-2") + + assert response.status_code == 200 + + +class TestQuotaCheck: + @pytest.mark.asyncio + async def test_quota_check_returns_remaining_usage_info(self, client_with_free_user): + test_app = FastAPI() + + @test_app.get("/test/quota") + async def check_quota(quota=Depends(SubscriptionEnforcement.check_quota("queries"))): + return quota + + async def override_get_current_user(): + return _make_user(plan="free") + + async def override_get_db(): + yield AsyncMock() + + test_app.dependency_overrides[get_current_user] = override_get_current_user + test_app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=test_app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/test/quota") + + assert response.status_code == 200 + data = response.json() + assert data["resource"] == "queries" + assert data["plan"] == "free" + assert "remaining" in data diff --git a/tests/test_content_generation.py b/backend/tests/test_content_pipeline/test_content_generation.py similarity index 100% rename from tests/test_content_generation.py rename to backend/tests/test_content_pipeline/test_content_generation.py diff --git a/backend/tests/test_infrastructure/__init__.py b/backend/tests/test_infrastructure/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/test_infrastructure/test_middleware_unified.py b/backend/tests/test_infrastructure/test_middleware_unified.py new file mode 100644 index 0000000..f1fc1f1 --- /dev/null +++ b/backend/tests/test_infrastructure/test_middleware_unified.py @@ -0,0 +1,76 @@ +"""验证 middleware/ 统一导入结构。 + +合并策略: +- app/middleware/ 作为统一目录 +- monitoring/metrics.py (Prometheus指标定义) -> middleware/prometheus_metrics.py +- monitoring/middleware.py (MonitoringMiddleware) -> 合并到 middleware/metrics.py +- monitoring/agent_hooks.py -> middleware/agent_hooks.py +- monitoring/llm_metrics.py -> middleware/llm_metrics.py +- monitoring/ 目录已删除 +""" +import pytest + + +class TestUnifiedMiddlewareImports: + """测试统一后的 app.middleware 导入路径是否可用。""" + + def test_import_logging_filter(self): + from app.middleware.logging_filter import APIKeyFilter + assert APIKeyFilter is not None + + def test_import_logging_middleware(self): + from app.middleware.logging_middleware import RequestLoggingMiddleware + assert RequestLoggingMiddleware is not None + + def test_import_rate_limit(self): + from app.middleware.rate_limit import RateLimitMiddleware + assert RateLimitMiddleware is not None + + def test_import_request_id(self): + from app.middleware.request_id import RequestIdMiddleware, REQUEST_ID_HEADER + assert RequestIdMiddleware is not None + assert REQUEST_ID_HEADER == "X-Request-ID" + + def test_import_metrics_middleware(self): + from app.middleware.metrics import MetricsMiddleware + assert MetricsMiddleware is not None + + def test_import_monitoring_middleware_from_metrics(self): + from app.middleware.metrics import MonitoringMiddleware + assert MonitoringMiddleware is not None + + def test_import_prometheus_metrics(self): + from app.middleware.prometheus_metrics import ( + API_REQUESTS_TOTAL, + API_REQUEST_DURATION_SECONDS, + API_REQUESTS_IN_PROGRESS, + AGENT_EXECUTIONS_TOTAL, + AGENT_EXECUTION_DURATION_SECONDS, + AGENT_RUNNING_TASKS, + LLM_REQUESTS_TOTAL, + LLM_REQUEST_DURATION_SECONDS, + LLM_TOKENS_TOTAL, + LLM_COST_ESTIMATED, + BRAND_COUNT, + QUERY_COUNT_TOTAL, + CONTENT_GENERATED_TOTAL, + CITATION_DETECTED_TOTAL, + SERVICE_INFO, + ) + assert API_REQUESTS_TOTAL is not None + assert SERVICE_INFO is not None + + def test_import_agent_hooks(self): + from app.middleware.agent_hooks import agent_execution_context, record_agent_execution + assert agent_execution_context is not None + assert record_agent_execution is not None + + def test_import_llm_metrics(self): + from app.middleware.llm_metrics import get_llm_metrics, LLMMetricsWrapper + assert get_llm_metrics is not None + assert LLMMetricsWrapper is not None + + def test_import_excluded_paths(self): + from app.middleware.metrics import _SKIP_PATHS + assert isinstance(_SKIP_PATHS, set) + assert "/health" in _SKIP_PATHS diff --git a/tests/test_monitoring.py b/backend/tests/test_infrastructure/test_monitoring.py similarity index 99% rename from tests/test_monitoring.py rename to backend/tests/test_infrastructure/test_monitoring.py index 4e34d6e..a8c6639 100644 --- a/tests/test_monitoring.py +++ b/backend/tests/test_infrastructure/test_monitoring.py @@ -12,7 +12,7 @@ import pytest import pytest_asyncio from prometheus_client import REGISTRY -from app.monitoring.metrics import ( +from app.middleware.prometheus_metrics import ( API_REQUESTS_TOTAL, AGENT_EXECUTIONS_TOTAL, LLM_REQUESTS_TOTAL, diff --git a/tests/test_business_flow.py b/backend/tests/test_integration/test_business_flow.py similarity index 81% rename from tests/test_business_flow.py rename to backend/tests/test_integration/test_business_flow.py index a2da714..a782e7e 100644 --- a/tests/test_business_flow.py +++ b/backend/tests/test_integration/test_business_flow.py @@ -7,14 +7,8 @@ import pytest_asyncio from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.main import app -from app.api.deps import get_current_user -from app.database import get_db from app.models.citation_record import CitationRecord -from app.models.query import Query from app.models.query_task import QueryTask -from app.models.user import User -from app.services.auth import hash_password # --------------------------------------------------------------------------- @@ -22,14 +16,8 @@ from app.services.auth import hash_password # --------------------------------------------------------------------------- @pytest_asyncio.fixture async def user_a(test_session: AsyncSession): - """Create a real user in the test DB.""" - user = User( - email="user_a@example.com", - password_hash=hash_password("password123"), - name="User A", - plan="free", - max_queries=5, - ) + from tests.fixtures.auth import _make_user + user = _make_user(email="user_a@example.com", plan="free") test_session.add(user) await test_session.commit() await test_session.refresh(user) @@ -38,14 +26,8 @@ async def user_a(test_session: AsyncSession): @pytest_asyncio.fixture async def user_b(test_session: AsyncSession): - """Create another real user in the test DB.""" - user = User( - email="user_b@example.com", - password_hash=hash_password("password456"), - name="User B", - plan="free", - max_queries=5, - ) + from tests.fixtures.auth import _make_user + user = _make_user(email="user_b@example.com", plan="free") test_session.add(user) await test_session.commit() await test_session.refresh(user) @@ -53,24 +35,21 @@ async def user_b(test_session: AsyncSession): @pytest_asyncio.fixture -async def auth_client_a(async_client, user_a): - """Login user_a and return client with auth headers.""" - response = await async_client.post( +async def auth_client_a(plain_client, user_a): + response = await plain_client.post( "/api/v1/auth/login", - json={"email": "user_a@example.com", "password": "password123"}, + json={"email": "user_a@example.com", "password": "Test@123456"}, ) assert response.status_code == 200 token = response.json()["access_token"] - # Return a small helper or just the headers return {"Authorization": f"Bearer {token}"} @pytest_asyncio.fixture -async def auth_client_b(async_client, user_b): - """Login user_b and return auth headers.""" - response = await async_client.post( +async def auth_client_b(plain_client, user_b): + response = await plain_client.post( "/api/v1/auth/login", - json={"email": "user_b@example.com", "password": "password456"}, + json={"email": "user_b@example.com", "password": "Test@123456"}, ) assert response.status_code == 200 token = response.json()["access_token"] @@ -81,9 +60,8 @@ async def auth_client_b(async_client, user_b): # 1. 完整流程:注册 -> 登录 -> 创建查询词 -> 查看列表 # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_full_user_flow(async_client, override_get_db): - # Register - reg_resp = await async_client.post( +async def test_full_user_flow(plain_client, override_get_db): + reg_resp = await plain_client.post( "/api/v1/auth/register", json={"email": "flow@example.com", "password": "flowpass", "name": "Flow User"}, ) @@ -92,7 +70,7 @@ async def test_full_user_flow(async_client, override_get_db): assert reg_data["email"] == "flow@example.com" # Login - login_resp = await async_client.post( + login_resp = await plain_client.post( "/api/v1/auth/login", json={"email": "flow@example.com", "password": "flowpass"}, ) @@ -101,7 +79,7 @@ async def test_full_user_flow(async_client, override_get_db): headers = {"Authorization": f"Bearer {token}"} # Create query - create_resp = await async_client.post( + create_resp = await plain_client.post( "/api/v1/queries/", headers=headers, json={ @@ -117,7 +95,7 @@ async def test_full_user_flow(async_client, override_get_db): assert query_data["target_brand"] == "FlowBrand" # List queries - list_resp = await async_client.get("/api/v1/queries/", headers=headers) + list_resp = await plain_client.get("/api/v1/queries/", headers=headers) assert list_resp.status_code == 200 list_data = list_resp.json() assert list_data["total"] == 1 @@ -129,11 +107,11 @@ async def test_full_user_flow(async_client, override_get_db): # 2. 查询词生命周期:创建 -> 更新 -> 暂停 -> 恢复 -> 删除 # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_query_lifecycle(async_client, override_get_db, auth_client_a): +async def test_query_lifecycle(plain_client, override_get_db, auth_client_a): headers = auth_client_a # Create - create_resp = await async_client.post( + create_resp = await plain_client.post( "/api/v1/queries/", headers=headers, json={ @@ -147,7 +125,7 @@ async def test_query_lifecycle(async_client, override_get_db, auth_client_a): query_id = create_resp.json()["id"] # Update - update_resp = await async_client.put( + update_resp = await plain_client.put( f"/api/v1/queries/{query_id}", headers=headers, json={"keyword": "updated keyword", "frequency": "daily"}, @@ -157,7 +135,7 @@ async def test_query_lifecycle(async_client, override_get_db, auth_client_a): assert update_resp.json()["frequency"] == "daily" # Pause (update status to paused) - pause_resp = await async_client.put( + pause_resp = await plain_client.put( f"/api/v1/queries/{query_id}", headers=headers, json={"status": "paused"}, @@ -166,7 +144,7 @@ async def test_query_lifecycle(async_client, override_get_db, auth_client_a): assert pause_resp.json()["status"] == "paused" # Resume (update status back to active) - resume_resp = await async_client.put( + resume_resp = await plain_client.put( f"/api/v1/queries/{query_id}", headers=headers, json={"status": "active"}, @@ -175,14 +153,14 @@ async def test_query_lifecycle(async_client, override_get_db, auth_client_a): assert resume_resp.json()["status"] == "active" # Delete - delete_resp = await async_client.delete( + delete_resp = await plain_client.delete( f"/api/v1/queries/{query_id}", headers=headers, ) assert delete_resp.status_code == 204 # Verify deletion - get_resp = await async_client.get(f"/api/v1/queries/{query_id}", headers=headers) + get_resp = await plain_client.get(f"/api/v1/queries/{query_id}", headers=headers) assert get_resp.status_code == 404 @@ -190,12 +168,12 @@ async def test_query_lifecycle(async_client, override_get_db, auth_client_a): # 3. 查询数量限制:free 用户最多 5 个 # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_query_limit_free_user(async_client, override_get_db, auth_client_a): +async def test_query_limit_free_user(plain_client, override_get_db, auth_client_a): headers = auth_client_a # Create 5 queries (limit for free plan) for i in range(5): - resp = await async_client.post( + resp = await plain_client.post( "/api/v1/queries/", headers=headers, json={ @@ -208,7 +186,7 @@ async def test_query_limit_free_user(async_client, override_get_db, auth_client_ assert resp.status_code == 201, f"Failed to create query {i}" # 6th query should be rejected - resp = await async_client.post( + resp = await plain_client.post( "/api/v1/queries/", headers=headers, json={ @@ -227,12 +205,12 @@ async def test_query_limit_free_user(async_client, override_get_db, auth_client_ # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_citation_stats_correctness( - async_client, override_get_db, auth_client_a, test_session + plain_client, override_get_db, auth_client_a, test_session ): headers = auth_client_a # Create a query - create_resp = await async_client.post( + create_resp = await plain_client.post( "/api/v1/queries/", headers=headers, json={ @@ -277,7 +255,7 @@ async def test_citation_stats_correctness( await test_session.commit() # Call stats API - stats_resp = await async_client.get("/api/v1/citations/stats", headers=headers) + stats_resp = await plain_client.get("/api/v1/citations/stats", headers=headers) assert stats_resp.status_code == 200 stats = stats_resp.json() @@ -300,12 +278,12 @@ async def test_citation_stats_correctness( # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_export_csv( - async_client, override_get_db, auth_client_a, test_session + plain_client, override_get_db, auth_client_a, test_session ): headers = auth_client_a # Create query - create_resp = await async_client.post( + create_resp = await plain_client.post( "/api/v1/queries/", headers=headers, json={ @@ -331,7 +309,7 @@ async def test_export_csv( await test_session.commit() # Export CSV - export_resp = await async_client.get( + export_resp = await plain_client.get( f"/api/v1/reports/export/csv?query_id={query_id}", headers=headers, ) @@ -347,12 +325,12 @@ async def test_export_csv( # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_run_now_creates_query_task( - async_client, override_get_db, auth_client_a, test_session + plain_client, override_get_db, auth_client_a, test_session ): headers = auth_client_a # Create an active query - create_resp = await async_client.post( + create_resp = await plain_client.post( "/api/v1/queries/", headers=headers, json={ @@ -366,7 +344,7 @@ async def test_run_now_creates_query_task( query_id = create_resp.json()["id"] # Trigger run-now - run_resp = await async_client.post( + run_resp = await plain_client.post( f"/api/v1/queries/{query_id}/run-now", headers=headers, ) @@ -391,13 +369,13 @@ async def test_run_now_creates_query_task( # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_permission_isolation( - async_client, override_get_db, auth_client_a, auth_client_b + plain_client, override_get_db, auth_client_a, auth_client_b ): headers_a = auth_client_a headers_b = auth_client_b # User A creates a query - create_resp = await async_client.post( + create_resp = await plain_client.post( "/api/v1/queries/", headers=headers_a, json={ @@ -411,14 +389,14 @@ async def test_permission_isolation( query_id = create_resp.json()["id"] # User B tries to access User A's query - get_resp = await async_client.get( + get_resp = await plain_client.get( f"/api/v1/queries/{query_id}", headers=headers_b, ) assert get_resp.status_code == 404 # User B tries to update User A's query - put_resp = await async_client.put( + put_resp = await plain_client.put( f"/api/v1/queries/{query_id}", headers=headers_b, json={"keyword": "hacked"}, @@ -426,14 +404,14 @@ async def test_permission_isolation( assert put_resp.status_code == 404 # User B tries to delete User A's query - del_resp = await async_client.delete( + del_resp = await plain_client.delete( f"/api/v1/queries/{query_id}", headers=headers_b, ) assert del_resp.status_code == 404 # User B tries to run-now User A's query - run_resp = await async_client.post( + run_resp = await plain_client.post( f"/api/v1/queries/{query_id}/run-now", headers=headers_b, ) diff --git a/backend/tests/test_integration/test_monetization_flow.py b/backend/tests/test_integration/test_monetization_flow.py new file mode 100644 index 0000000..1644942 --- /dev/null +++ b/backend/tests/test_integration/test_monetization_flow.py @@ -0,0 +1,191 @@ +import uuid +from unittest.mock import AsyncMock, patch + +import pytest + +from tests.fixtures.auth import _make_user, _to_uuid + + +class TestMonetizationFlow: + @pytest.mark.asyncio + async def test_full_monetization_flow(self, async_client, async_session): + user = _make_user(email="monetization@example.com", plan="free") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + + from app.api.deps import get_current_user, get_db + from app.main import app + + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + try: + brand_resp = await async_client.post( + "/api/v1/onboarding/brand", + json={"name": "MonoBrand", "industry": "technology"}, + ) + assert brand_resp.status_code == 201 + brand_data = brand_resp.json() + brand_id = brand_data["id"] + assert brand_data["name"] == "MonoBrand" + + onboarding_resp = await async_client.post( + f"/api/v1/onboarding/complete/{brand_id}" + ) + assert onboarding_resp.status_code == 200 + assert onboarding_resp.json()["success"] is True + + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + + from app.services.diagnosis.data_collector import DataCollectionResult + from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + has_author_bio=True, + author_credentials_complete=0.8, + has_organization=True, + content_depth_score=0.7, + answer_ownership_rate=0.3, + ) + mock_collector.collect.return_value = DataCollectionResult( + diagnosis_input=input_data, + ) + + service = GEODiagnosisService() + result = service.diagnose(input_data) + + with patch("app.api.onboarding.GEODiagnosisService", return_value=service): + health_resp = await async_client.get( + f"/api/v1/onboarding/health-report/{brand_id}" + ) + + assert health_resp.status_code == 200 + health_data = health_resp.json() + assert health_data["brand_id"] == brand_id + assert "overall_score" in health_data + assert "dimensions" in health_data + assert health_data["is_full_report"] is False + + order_resp = await async_client.post( + "/api/v1/payments/orders", + json={"plan": "pro", "payment_provider": "wechat"}, + ) + assert order_resp.status_code == 201 + order_data = order_resp.json() + order_id = order_data["order_id"] + assert "pay_url" in order_data + assert order_data["amount"] == 599 + assert order_data["status"] == "pending" + + from app.models.payment_order import PaymentOrder as PaymentOrderModel + from app.services.payment.base import PaymentCallback + + from sqlalchemy import select + stmt = select(PaymentOrderModel).where( + PaymentOrderModel.id == uuid.UUID(order_id) + ) + db_result = await async_session.execute(stmt) + order = db_result.scalar_one() + + callback = PaymentCallback( + order_id=str(order.id), + payment_id=f"wx_pay_{order.id}", + amount=599, + status="success", + raw_data={"out_trade_no": str(order.id), "result_code": "SUCCESS"}, + ) + + from app.api.payments import _process_callback + await _process_callback(async_session, callback, "wechat") + await async_session.commit() + + await async_session.refresh(order) + assert order.status == "paid" + assert order.payment_id is not None + + await async_session.refresh(user) + assert user.plan == "pro" + assert user.max_queries == 50 + + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + + from app.services.diagnosis.data_collector import DataCollectionResult + from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + has_author_bio=True, + author_credentials_complete=0.8, + has_organization=True, + content_depth_score=0.7, + answer_ownership_rate=0.3, + ) + mock_collector.collect.return_value = DataCollectionResult( + diagnosis_input=input_data, + ) + + service = GEODiagnosisService() + result = service.diagnose(input_data) + + with patch("app.api.onboarding.GEODiagnosisService", return_value=service): + paid_health_resp = await async_client.get( + f"/api/v1/onboarding/health-report/{brand_id}" + ) + + assert paid_health_resp.status_code == 200 + paid_health_data = paid_health_resp.json() + assert paid_health_data["is_full_report"] is True + assert len(paid_health_data["dimensions"]) == 6 + + from app.models.diagnosis_record import DiagnosisRecord + + diag_record = DiagnosisRecord( + brand_id=uuid.UUID(brand_id), + user_id=_to_uuid(user.id), + diagnosis_type="geo", + status="completed", + overall_score=result.overall_score, + result_json=result.to_dict(), + ) + async_session.add(diag_record) + await async_session.commit() + await async_session.refresh(diag_record) + + attr_resp = await async_client.post( + "/api/v1/attribution/start", + json={"brand_id": brand_id}, + ) + assert attr_resp.status_code == 200 + attr_data = attr_resp.json() + assert "id" in attr_data + assert attr_data["brand_id"] == brand_id + assert attr_data["baseline_score"] == result.overall_score + assert attr_data["status"] == "tracking" + + roi_resp = await async_client.get( + f"/api/v1/attribution/roi/{brand_id}" + ) + assert roi_resp.status_code == 200 + roi_data = roi_resp.json() + assert "roi_percentage" in roi_data + assert "value_generated" in roi_data + assert "subscription_cost" in roi_data + assert roi_data["brand_name"] == "MonoBrand" + assert roi_data["current_plan"] == "pro" + + finally: + app.dependency_overrides.clear() diff --git a/backend/tests/test_models/test_citation_record.py b/backend/tests/test_models/test_citation_record.py new file mode 100644 index 0000000..8fff54f --- /dev/null +++ b/backend/tests/test_models/test_citation_record.py @@ -0,0 +1,151 @@ +"""Tests for CitationRecord.from_citation_result() factory method""" +import uuid + +import pytest + +from app.models.citation_record import CitationRecord + + +class TestFromCitationResultBasicFields: + """基本字段映射""" + + def test_basic_fields_mapped(self): + query_id = uuid.uuid4() + result = { + "cited": True, + "position": 2, + "citation_text": "Brand X is recommended", + "competitor_brands": ["Brand Y", "Brand Z"], + "raw_response": "Some raw response", + "confidence": 0.95, + "match_type": "direct", + } + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="kimi", + result=result, + ) + assert record.query_id == query_id + assert record.platform == "kimi" + assert record.cited is True + assert record.citation_position == 2 + assert record.citation_text == "Brand X is recommended" + assert record.competitor_brands == ["Brand Y", "Brand Z"] + assert record.confidence == 0.95 + assert record.match_type == "direct" + + def test_missing_optional_fields_use_defaults(self): + query_id = uuid.uuid4() + result = {"cited": False} + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="wenxin", + result=result, + ) + assert record.cited is False + assert record.citation_position is None + assert record.citation_text is None + assert record.competitor_brands == [] + assert record.raw_response == "" + assert record.confidence is None + assert record.match_type is None + + +class TestFromCitationResultSourceFields: + """引用源分析字段""" + + def test_source_fields_populated(self): + query_id = uuid.uuid4() + result = { + "cited": True, + "data_source": "ai_platform", + "source_urls": ["https://example.com"], + "source_titles": ["Example Title"], + "citation_contexts": ["context snippet"], + "ai_response_text": "AI said something", + } + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="kimi", + result=result, + ) + assert record.data_source == "ai_platform" + assert record.source_urls == ["https://example.com"] + assert record.source_titles == ["Example Title"] + assert record.citation_contexts == ["context snippet"] + assert record.ai_response_text == "AI said something" + + def test_source_fields_default_to_none(self): + query_id = uuid.uuid4() + result = {"cited": False} + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="kimi", + result=result, + ) + assert record.data_source is None + assert record.source_urls is None + assert record.source_titles is None + assert record.citation_contexts is None + assert record.ai_response_text == "" + + +class TestFromCitationResultSanitization: + """raw_response 和 ai_response_text 应该被清理""" + + def test_raw_response_sanitized(self): + query_id = uuid.uuid4() + result = { + "cited": True, + "raw_response": "Hello\x00World\x08!", + } + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="kimi", + result=result, + ) + # NULL 字节和 \x08 应该被移除 + assert "\x00" not in record.raw_response + assert "\x08" not in record.raw_response + assert "Hello" in record.raw_response + assert "World" in record.raw_response + + def test_ai_response_text_sanitized(self): + query_id = uuid.uuid4() + result = { + "cited": True, + "ai_response_text": "Text\x00with\x0bcontrol\x1fchars", + } + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="kimi", + result=result, + ) + assert "\x00" not in record.ai_response_text + assert "\x0b" not in record.ai_response_text + assert "\x1f" not in record.ai_response_text + + def test_newline_preserved_in_sanitization(self): + query_id = uuid.uuid4() + result = { + "cited": True, + "raw_response": "Line1\nLine2\tTab\rReturn", + } + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="kimi", + result=result, + ) + assert "\n" in record.raw_response + assert "\t" in record.raw_response + assert "\r" in record.raw_response + + def test_none_raw_response_becomes_empty_string(self): + query_id = uuid.uuid4() + result = {"cited": True, "raw_response": None} + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="kimi", + result=result, + ) + assert record.raw_response == "" diff --git a/backend/tests/test_platform_adapters.py b/backend/tests/test_platform_adapters.py index 85f2c94..5b6ad67 100644 --- a/backend/tests/test_platform_adapters.py +++ b/backend/tests/test_platform_adapters.py @@ -10,13 +10,10 @@ AI平台适配器测试 - 验证各平台适配器是否正常工作 import pytest from unittest.mock import Mock, patch, AsyncMock, MagicMock -import sys -sys.path.insert(0, '/Users/Chiguyong/Code/Fischer/geo/backend') - -from app.workers.platforms.kimi import KimiAdapter -from app.workers.platforms.wenxin import WenxinAdapter -from app.workers.platforms.doubao import DoubaoAdapter +from app.services.ai_engine.kimi import KimiAdapter +from app.services.ai_engine.wenxin import WenxinAdapter +from app.services.ai_engine.doubao import DoubaoAdapter from app.workers.citation_extractor import ( extract_markdown_links, extract_urls_with_context, diff --git a/backend/tests/test_services/test_brand_citation_llm.py b/backend/tests/test_services/test_brand_citation_llm.py new file mode 100644 index 0000000..ae720d0 --- /dev/null +++ b/backend/tests/test_services/test_brand_citation_llm.py @@ -0,0 +1,221 @@ +"""Tests for BrandCitationLLMService - replacing LLMAdapter""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.services.llm.brand_citation_service import BrandCitationLLMService, BRAND_CITATION_PROMPT +from app.services.llm.base import LLMError + + +class TestBrandCitationLLMServiceInit: + """服务初始化测试""" + + def test_init_with_default_factory(self): + service = BrandCitationLLMService() + assert service is not None + assert service._provider_name is None + assert service._model is None + + def test_init_with_custom_provider(self): + service = BrandCitationLLMService(provider_name="deepseek", model="deepseek-chat") + assert service._provider_name == "deepseek" + assert service._model == "deepseek-chat" + + +class TestBrandCitationLLMServiceQuery: + """品牌引用查询测试""" + + @pytest.mark.asyncio + async def test_query_brand_citation_cited(self): + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 1, "citation_text": "Brand X is great", "sentiment": "positive", "confidence": 0.9}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( + keyword="AI搜索", + brand_name="Brand X", + brand_aliases=["BX"] + ) + assert result.cited is True + assert result.position == 1 + assert result.confidence == 0.9 + + @pytest.mark.asyncio + async def test_query_brand_citation_not_cited(self): + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": false, "position": null, "citation_text": null, "sentiment": "neutral", "confidence": 0.8}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( + keyword="AI搜索", + brand_name="Brand X", + brand_aliases=[] + ) + assert result.cited is False + assert result.position is None + assert result.citation_text is None + + @pytest.mark.asyncio + async def test_query_brand_citation_with_markdown_json(self): + """Test that markdown-wrapped JSON is handled via extract_json""" + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='```json\n{"cited": true, "position": 2, "citation_text": "test", "sentiment": "positive", "confidence": 0.7}\n```' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( + keyword="test", + brand_name="Brand X", + brand_aliases=[] + ) + assert result.cited is True + assert result.position == 2 + + @pytest.mark.asyncio + async def test_query_brand_citation_sentiment_positive(self): + """测试正面情感""" + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 2, "citation_text": "YYY品牌产品质量非常好", "sentiment": "positive", "confidence": 0.92}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( + keyword="AI搜索", + brand_name="YYY", + brand_aliases=[] + ) + assert result.sentiment == "positive" + + @pytest.mark.asyncio + async def test_query_brand_citation_sentiment_negative(self): + """测试负面情感""" + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 3, "citation_text": "ZZZ品牌存在质量问题", "sentiment": "negative", "confidence": 0.88}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( + keyword="AI搜索", + brand_name="ZZZ", + brand_aliases=[] + ) + assert result.sentiment == "negative" + + @pytest.mark.asyncio + async def test_query_brand_citation_invalid_sentiment_defaults_neutral(self): + """测试无效sentiment值默认为neutral""" + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 1, "citation_text": "test", "sentiment": "unknown", "confidence": 0.5}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( + keyword="test", + brand_name="Test", + brand_aliases=[] + ) + assert result.sentiment == "neutral" + + @pytest.mark.asyncio + async def test_query_brand_citation_confidence_clamped(self): + """测试confidence值被钳制到0.0-1.0范围""" + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 1, "citation_text": "test", "sentiment": "neutral", "confidence": 1.5}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( + keyword="test", + brand_name="Test", + brand_aliases=[] + ) + assert result.confidence == 1.0 + + +class TestBrandCitationLLMServicePrompt: + """Prompt构建测试""" + + def test_build_prompt(self): + service = BrandCitationLLMService() + prompt = service._build_prompt("AI搜索", "Brand X", ["BX"]) + assert "Brand X" in prompt + assert "BX" in prompt + assert "AI搜索" in prompt + + def test_build_prompt_no_aliases(self): + service = BrandCitationLLMService() + prompt = service._build_prompt("AI搜索", "Brand X", []) + assert "Brand X" in prompt + assert "AI搜索" in prompt + assert "无" in prompt + + def test_brand_citation_prompt_constant_exists(self): + """验证BRAND_CITATION_PROMPT常量存在且包含关键占位符""" + assert "{keyword}" in BRAND_CITATION_PROMPT + assert "{brand_name}" in BRAND_CITATION_PROMPT + assert "{brand_aliases}" in BRAND_CITATION_PROMPT + + +class TestBrandCitationLLMServiceErrors: + """错误处理测试""" + + @pytest.mark.asyncio + async def test_llm_disabled_raises_error(self): + service = BrandCitationLLMService() + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = False + with pytest.raises(LLMError, match="LLM"): + await service.query_brand_citation("test", "Brand X", []) + + @pytest.mark.asyncio + async def test_llm_disabled_error_contains_config_guidance(self): + """LLM禁用时错误信息应包含配置指引""" + service = BrandCitationLLMService() + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = False + with pytest.raises(LLMError) as exc_info: + await service.query_brand_citation("test", "Brand X", []) + error_msg = str(exc_info.value) + assert "ENABLE_LLM" in error_msg + + @pytest.mark.asyncio + async def test_missing_required_fields_raises_error(self): + """响应缺少必需字段时抛出异常""" + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"invalid": "response"}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + with pytest.raises((LLMError, ValueError)): + await service.query_brand_citation("test", "Brand X", []) + + @pytest.mark.asyncio + async def test_no_mock_result_method(self): + """确保不存在_get_mock_result方法(与旧LLMAdapter的区别)""" + service = BrandCitationLLMService() + assert not hasattr(service, "_get_mock_result") diff --git a/backend/tests/test_services/test_citation_pattern.py b/backend/tests/test_services/test_citation_pattern.py index 25e5923..99d7a39 100644 --- a/backend/tests/test_services/test_citation_pattern.py +++ b/backend/tests/test_services/test_citation_pattern.py @@ -3,7 +3,7 @@ from datetime import UTC, datetime import pytest from app.services.ai_engine.base import AIQueryResult, CitationInfo, EngineType -from app.services.citation_pattern import ( +from app.services.citation.citation_pattern import ( AuthoritySignalAnalyzer, CitationFormatAnalyzer, CitationPattern, diff --git a/backend/tests/test_services/test_content_generation.py b/backend/tests/test_services/test_content_generation.py new file mode 100644 index 0000000..73b354b --- /dev/null +++ b/backend/tests/test_services/test_content_generation.py @@ -0,0 +1,420 @@ +"""Tests for ContentGenerationService - extracted from api/content.py + +TDD RED phase: tests for the service that extracts the 3-stage +content generation flow (generate -> de-AI -> GEO optimize) +out of the API handler. +""" +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.services.llm.base import LLMError, LLMResponse + + +def _make_llm_response(content: str) -> LLMResponse: + """Helper to create LLMResponse objects for mocking.""" + return LLMResponse(content=content, model="test-model") + + +class TestContentGenerationService: + """ContentGenerationService unit tests.""" + + @pytest.mark.asyncio + async def test_generate_content_basic_three_stages(self): + """Test basic 3-stage content generation (generate -> de-AI -> GEO optimize).""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw generated content"), + _make_llm_response("De-AIed content"), + _make_llm_response("GEO optimized content"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await service.generate_content( + keyword="AI搜索", + brand_name="Brand X", + platform="wenxin", + content_style="专业严谨", + word_count=2000, + run_deai=True, + run_geo=True, + ) + + assert result is not None + assert result["content"] == "De-AIed content" + assert result["optimized_content"] == "GEO optimized content" + assert result["pipeline_stages"] is not None + assert len(result["pipeline_stages"]) == 3 + assert result["pipeline_stages"][0]["stage"] == "content_generation" + assert result["pipeline_stages"][1]["stage"] == "deai" + assert result["pipeline_stages"][2]["stage"] == "geo_optimization" + + @pytest.mark.asyncio + async def test_generate_content_skip_deai(self): + """Test generation with run_deai=False skips the de-AI stage.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw generated content"), + _make_llm_response("GEO optimized content"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await service.generate_content( + keyword="test", + brand_name="Brand X", + platform="wenxin", + run_deai=False, + run_geo=True, + ) + + assert result["content"] == "Raw generated content" + assert result["optimized_content"] == "GEO optimized content" + assert len(result["pipeline_stages"]) == 2 + stage_names = [s["stage"] for s in result["pipeline_stages"]] + assert "deai" not in stage_names + + @pytest.mark.asyncio + async def test_generate_content_skip_geo(self): + """Test generation with run_geo=False skips the GEO optimization stage.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw generated content"), + _make_llm_response("De-AIed content"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await service.generate_content( + keyword="test", + brand_name="Brand X", + platform="wenxin", + run_deai=True, + run_geo=False, + ) + + assert result["content"] == "De-AIed content" + assert result["optimized_content"] == "De-AIed content" + assert len(result["pipeline_stages"]) == 2 + stage_names = [s["stage"] for s in result["pipeline_stages"]] + assert "geo_optimization" not in stage_names + + @pytest.mark.asyncio + async def test_generate_content_skip_both_stages(self): + """Test generation with both run_deai=False and run_geo=False.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw generated content"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await service.generate_content( + keyword="test", + brand_name="Brand X", + platform="wenxin", + run_deai=False, + run_geo=False, + ) + + assert result["content"] == "Raw generated content" + assert result["optimized_content"] == "Raw generated content" + assert len(result["pipeline_stages"]) == 1 + + @pytest.mark.asyncio + async def test_generate_content_with_knowledge_context(self): + """Test that knowledge_context is passed through to the generation prompt.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw content"), + _make_llm_response("De-AIed"), + _make_llm_response("Optimized"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await service.generate_content( + keyword="test keyword", + brand_name="Brand X", + platform="wenxin", + knowledge_context="Some knowledge base content", + ) + + # The first call to provider.chat should include the knowledge context + first_call_args = mock_provider.chat.call_args_list[0] + messages = first_call_args[0][0] + # Verify knowledge context appears in the rendered messages + all_content = " ".join(str(m) for m in messages) + assert "Some knowledge base content" in all_content + + @pytest.mark.asyncio + async def test_generate_content_saves_to_database(self): + """Test that generated content is saved to database when db and user_id are provided.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw content"), + _make_llm_response("De-AIed"), + _make_llm_response("Optimized"), + ] + mock_db = AsyncMock() + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await service.generate_content( + keyword="test", + brand_name="Brand X", + platform="wenxin", + db=mock_db, + user_id="test-user-123", + org_id=str(uuid.uuid4()), + ) + + # Verify db.add was called (Content + ContentVersion) + assert mock_db.add.call_count == 2 + mock_db.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_generate_content_no_db_no_save(self): + """Test that when db is not provided, no database operations occur.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw"), + _make_llm_response("De-AIed"), + _make_llm_response("Optimized"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await service.generate_content( + keyword="test", + brand_name="Brand X", + platform="wenxin", + ) + + # No content_id when db is not provided + assert result.get("content_id") is None + + @pytest.mark.asyncio + async def test_generate_content_llm_error_propagates(self): + """Test that LLMError from the provider propagates correctly.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = LLMError( + "API rate limit", provider="openai" + ) + + with patch.object(service, "_get_provider", return_value=mock_provider): + with pytest.raises(LLMError, match="API rate limit"): + await service.generate_content( + keyword="test", + brand_name="Brand X", + platform="wenxin", + ) + + @pytest.mark.asyncio + async def test_get_knowledge_context_with_ids(self): + """Test _get_knowledge_context retrieves context when knowledge_base_ids are provided.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + mock_db = AsyncMock() + + mock_rag = MagicMock() + mock_rag.search = AsyncMock(return_value=[ + {"content": "Knowledge chunk 1", "document_title": "Doc A"}, + {"content": "Knowledge chunk 2", "document_title": "Doc B"}, + ]) + + with patch( + "app.services.knowledge.rag_service.RAGService", + return_value=mock_rag, + ): + context = await service._get_knowledge_context( + db=mock_db, + brand_name="Brand X", + knowledge_base_ids=["kb-1"], + target_keyword="test keyword", + ) + + assert "Knowledge chunk 1" in context + assert "Knowledge chunk 2" in context + assert "Doc A" in context + + @pytest.mark.asyncio + async def test_get_knowledge_context_empty_ids(self): + """Test _get_knowledge_context returns empty string when no IDs provided.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + mock_db = AsyncMock() + + context = await service._get_knowledge_context( + db=mock_db, + brand_name="Brand X", + knowledge_base_ids=[], + target_keyword="test", + ) + + assert context == "" + + @pytest.mark.asyncio + async def test_get_knowledge_context_rag_failure_returns_empty(self): + """Test _get_knowledge_context returns empty string when RAG search fails.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + mock_db = AsyncMock() + + mock_rag = MagicMock() + mock_rag.search = AsyncMock(side_effect=Exception("RAG service down")) + + with patch( + "app.services.knowledge.rag_service.RAGService", + return_value=mock_rag, + ): + context = await service._get_knowledge_context( + db=mock_db, + brand_name="Brand X", + knowledge_base_ids=["kb-1"], + target_keyword="test", + ) + + assert context == "" + + @pytest.mark.asyncio + async def test_generate_content_passes_correct_prompt_variables(self): + """Test that the service passes correct variables to each prompt template stage.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw content"), + _make_llm_response("De-AIed content"), + _make_llm_response("Optimized content"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + with patch( + "app.services.content.content_generation_service.CONTENT_GENERATOR_TEMPLATE" + ) as mock_gen_template, patch( + "app.services.content.content_generation_service.DEAI_TEMPLATE" + ) as mock_deai_template, patch( + "app.services.content.content_generation_service.GEO_OPTIMIZER_TEMPLATE" + ) as mock_geo_template: + mock_gen_template.render.return_value = [ + {"role": "user", "content": "gen prompt"} + ] + mock_deai_template.render.return_value = [ + {"role": "user", "content": "deai prompt"} + ] + mock_geo_template.render.return_value = [ + {"role": "user", "content": "geo prompt"} + ] + + await service.generate_content( + keyword="AI搜索", + brand_name="Brand X", + platform="微信公众号", + content_style="专业严谨", + word_count=3000, + knowledge_context="Some context", + ) + + # Verify CONTENT_GENERATOR_TEMPLATE.render called with correct variables + gen_call_kwargs = mock_gen_template.render.call_args[0][0] + assert gen_call_kwargs["topic_title"] == "AI搜索" + assert gen_call_kwargs["target_keyword"] == "AI搜索" + assert gen_call_kwargs["target_platform"] == "微信公众号" + assert gen_call_kwargs["content_style"] == "专业严谨" + assert gen_call_kwargs["word_count"] == "3000" + assert gen_call_kwargs["brand_name"] == "Brand X" + assert gen_call_kwargs["knowledge_context"] == "Some context" + + # Verify DEAI_TEMPLATE.render called with the generated content + deai_call_kwargs = mock_deai_template.render.call_args[0][0] + assert deai_call_kwargs["original_content"] == "Raw content" + + # Verify GEO_OPTIMIZER_TEMPLATE.render called with de-AIed content + geo_call_kwargs = mock_geo_template.render.call_args[0][0] + assert geo_call_kwargs["original_content"] == "De-AIed content" + assert geo_call_kwargs["target_keywords"] == "AI搜索" + assert geo_call_kwargs["target_platform"] == "微信公众号" + + @pytest.mark.asyncio + async def test_generate_content_default_parameters(self): + """Test that default parameter values are applied correctly.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw"), + _make_llm_response("De-AIed"), + _make_llm_response("Optimized"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await service.generate_content( + keyword="test", + brand_name="Brand X", + ) + + # Defaults: platform="通用", content_style="专业严谨", word_count=2000 + # run_deai=True, run_geo=True + assert result is not None + assert len(result["pipeline_stages"]) == 3 diff --git a/backend/tests/test_services/test_data_collector.py b/backend/tests/test_services/test_data_collector.py new file mode 100644 index 0000000..d17435c --- /dev/null +++ b/backend/tests/test_services/test_data_collector.py @@ -0,0 +1,381 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.database import Base +from app.models.brand import Brand +from app.models.citation_record import CitationRecord +from app.models.query import Query +from app.models.user import User +from app.services.auth import hash_password +from app.services.diagnosis.data_collector import DataCollectorService, DataCollectionResult +from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput + + +@pytest_asyncio.fixture +async def async_engine(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with maker() as session: + yield session + + +class TestWebsiteSignalParsing: + def test_parse_html_with_schema_org(self): + service = DataCollectorService.__new__(DataCollectorService) + service._db = None + + html = """ + + + + +TestBrand是一家专注于技术创新的公司,为企业提供智能化解决方案。
+Hello world
" + signals = service._parse_html_signals(html) + + assert signals["has_organization"] is False + assert signals["has_product"] is False + assert signals["has_qa_headings"] is False + + def test_parse_html_article_schema(self): + service = DataCollectorService.__new__(DataCollectorService) + service._db = None + + html = """ + + + + + + """ + signals = service._parse_html_signals(html) + + assert signals["has_article"] is True + assert signals["has_faq"] is True + assert signals["has_breadcrumb"] is True + + +class TestSignalApplication: + def test_apply_ai_signals(self): + service = DataCollectorService.__new__(DataCollectorService) + service._db = None + + inp = GEODiagnosisInput() + ai_data = { + "aor": 0.4, + "accuracy": 0.85, + "sov": 0.25, + "competitor_gap": 0.15, + "total_responses": 10, + "cited_count": 4, + "accurate_count": 3, + "has_author_bio": True, + "author_credentials_complete": 0.7, + "has_data_sources": True, + } + service._apply_ai_signals(inp, ai_data) + + assert inp.answer_ownership_rate == 0.4 + assert inp.citation_accuracy == 0.85 + assert inp.ai_sov == 0.25 + assert inp.competitor_gap == 0.15 + assert inp.total_ai_responses == 10 + assert inp.brand_mention_count == 4 + assert inp.accurate_citation_count == 3 + assert inp.has_author_bio is True + assert inp.author_credentials_complete == 0.7 + assert inp.has_data_sources is True + + def test_apply_citation_signals(self): + service = DataCollectorService.__new__(DataCollectorService) + service._db = None + + inp = GEODiagnosisInput() + citation_data = { + "aor": 0.5, + "accuracy": 0.9, + "sov": 0.3, + "competitor_gap": 0.1, + "total_responses": 20, + "cited_count": 10, + "accurate_count": 9, + "has_certifications": True, + "certification_count": 3, + "has_expert_endorsements": True, + "endorsement_count": 5, + "content_depth_score": 0.8, + "topic_coverage_ratio": 0.7, + "entity_consistency_score": 0.85, + "cluster_completeness": 0.6, + "total_content_count": 20, + "topic_cluster_count": 8, + } + service._apply_citation_signals(inp, citation_data) + + assert inp.answer_ownership_rate == 0.5 + assert inp.citation_accuracy == 0.9 + assert inp.has_certifications is True + assert inp.certification_count == 3 + assert inp.content_depth_score == 0.8 + + def test_apply_website_signals(self): + service = DataCollectorService.__new__(DataCollectorService) + service._db = None + + inp = GEODiagnosisInput() + website_data = { + "has_direct_answer": True, + "has_qa_headings": True, + "has_structured_data": True, + "has_internal_links": True, + "has_freshness_info": True, + "has_brand_definition": True, + "has_target_audience": True, + "has_unique_value": True, + "has_organization": True, + "has_product": True, + "has_article": True, + "has_faq": False, + "has_howto": False, + "has_breadcrumb": False, + } + service._apply_website_signals(inp, website_data) + + assert inp.has_direct_answer is True + assert inp.has_qa_headings is True + assert inp.has_organization is True + assert inp.has_product is True + assert inp.has_article is True + assert inp.has_faq is False + + def test_signals_merge_max_values(self): + service = DataCollectorService.__new__(DataCollectorService) + service._db = None + + inp = GEODiagnosisInput() + service._apply_ai_signals(inp, {"aor": 0.3, "accuracy": 0.7}) + service._apply_citation_signals(inp, {"aor": 0.5, "accuracy": 0.9}) + + assert inp.answer_ownership_rate == 0.5 + assert inp.citation_accuracy == 0.9 + + +class TestDataCollectorIntegration: + @pytest.mark.asyncio + async def test_collect_with_no_data_sources(self, async_session): + service = DataCollectorService(async_session) + + with patch.object( + service, "_collect_ai_platform_signals", new_callable=AsyncMock + ) as mock_ai, patch.object( + service, "_collect_citation_record_signals", new_callable=AsyncMock + ) as mock_cite, patch.object( + service, "_collect_website_signals", new_callable=AsyncMock + ) as mock_web: + mock_ai.return_value = { + "aor": 0.0, + "accuracy": 0.0, + "sov": 0.0, + "competitor_gap": 0.5, + "total_responses": 0, + "cited_count": 0, + "accurate_count": 0, + "metadata": {}, + } + mock_cite.return_value = { + "total_responses": 0, + "cited_count": 0, + "accurate_count": 0, + "aor": 0.0, + "accuracy": 0.0, + "sov": 0.0, + "competitor_gap": 0.0, + "metadata": {"records_found": 0}, + } + mock_web.return_value = {"metadata": {"skipped": True, "reason": "no_website"}} + + result = await service.collect(brand_name="UnknownBrand") + + assert isinstance(result, DataCollectionResult) + assert isinstance(result.diagnosis_input, GEODiagnosisInput) + assert result.diagnosis_input.has_industry_classification is False + + @pytest.mark.asyncio + async def test_collect_with_industry(self, async_session): + service = DataCollectorService(async_session) + + with patch.object( + service, "_collect_ai_platform_signals", new_callable=AsyncMock + ) as mock_ai, patch.object( + service, "_collect_citation_record_signals", new_callable=AsyncMock + ) as mock_cite, patch.object( + service, "_collect_website_signals", new_callable=AsyncMock + ) as mock_web: + mock_ai.return_value = { + "aor": 0.0, + "accuracy": 0.0, + "sov": 0.0, + "competitor_gap": 0.5, + "total_responses": 0, + "cited_count": 0, + "accurate_count": 0, + "metadata": {}, + } + mock_cite.return_value = { + "total_responses": 0, + "cited_count": 0, + "accurate_count": 0, + "aor": 0.0, + "accuracy": 0.0, + "sov": 0.0, + "competitor_gap": 0.0, + "metadata": {"records_found": 0}, + } + mock_web.return_value = {"metadata": {"skipped": True}} + + result = await service.collect( + brand_name="TestBrand", industry="technology" + ) + + assert result.diagnosis_input.has_industry_classification is True + + @pytest.mark.asyncio + async def test_collect_produces_nonzero_with_website_signals( + self, async_session + ): + service = DataCollectorService(async_session) + + with patch.object( + service, "_collect_ai_platform_signals", new_callable=AsyncMock + ) as mock_ai, patch.object( + service, "_collect_citation_record_signals", new_callable=AsyncMock + ) as mock_cite, patch.object( + service, "_collect_website_signals", new_callable=AsyncMock + ) as mock_web: + mock_ai.return_value = { + "aor": 0.2, + "accuracy": 0.6, + "sov": 0.1, + "competitor_gap": 0.3, + "total_responses": 5, + "cited_count": 1, + "accurate_count": 0, + "has_author_bio": True, + "author_credentials_complete": 0.5, + "has_data_sources": False, + "metadata": {}, + } + mock_cite.return_value = { + "total_responses": 0, + "cited_count": 0, + "accurate_count": 0, + "aor": 0.0, + "accuracy": 0.0, + "sov": 0.0, + "competitor_gap": 0.0, + "metadata": {"records_found": 0}, + } + mock_web.return_value = { + "has_direct_answer": True, + "has_qa_headings": True, + "has_structured_data": True, + "has_internal_links": True, + "has_freshness_info": True, + "has_brand_definition": True, + "has_target_audience": True, + "has_unique_value": True, + "has_organization": True, + "has_product": True, + "has_article": True, + "has_faq": False, + "has_howto": False, + "has_breadcrumb": False, + "metadata": {"url": "https://test.com"}, + } + + result = await service.collect( + brand_name="TestBrand", + website="https://test.com", + industry="technology", + ) + + from app.services.diagnosis.geo_diagnosis import GEODiagnosisService + + geo_service = GEODiagnosisService() + diagnosis = geo_service.diagnose(result.diagnosis_input) + + assert diagnosis.overall_score > 0 + assert len(diagnosis.dimensions) == 6 + + @pytest.mark.asyncio + async def test_collect_handles_channel_failure(self, async_session): + service = DataCollectorService(async_session) + + with patch.object( + service, "_collect_ai_platform_signals", new_callable=AsyncMock + ) as mock_ai, patch.object( + service, "_collect_citation_record_signals", new_callable=AsyncMock + ) as mock_cite, patch.object( + service, "_collect_website_signals", new_callable=AsyncMock + ) as mock_web: + mock_ai.side_effect = Exception("AI platform unavailable") + mock_cite.return_value = { + "total_responses": 0, + "cited_count": 0, + "accurate_count": 0, + "aor": 0.0, + "accuracy": 0.0, + "sov": 0.0, + "competitor_gap": 0.0, + "metadata": {"records_found": 0}, + } + mock_web.return_value = {"metadata": {"skipped": True}} + + result = await service.collect(brand_name="TestBrand") + + assert len(result.errors) >= 1 + assert any("ai_platform" in e for e in result.errors) diff --git a/backend/tests/test_services/test_detection_scheduler.py b/backend/tests/test_services/test_detection_scheduler.py index 6638cac..d56025b 100644 --- a/backend/tests/test_services/test_detection_scheduler.py +++ b/backend/tests/test_services/test_detection_scheduler.py @@ -132,7 +132,7 @@ class TestDetectionTaskModel: class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_create_task(self, async_session, test_brand, test_user): - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService service = DetectionSchedulerService() task_data = { @@ -153,7 +153,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_update_task(self, async_session, test_brand, test_user): from app.models.detection_task import DetectionTask - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService task = DetectionTask( brand_id=test_brand.id, @@ -182,7 +182,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_delete_task(self, async_session, test_brand, test_user): from app.models.detection_task import DetectionTask - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService task = DetectionTask( brand_id=test_brand.id, @@ -207,7 +207,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_get_tasks(self, async_session, test_brand, test_user): from app.models.detection_task import DetectionTask - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService for i in range(3): task = DetectionTask( @@ -228,7 +228,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_trigger_task(self, async_session, test_brand, test_user): from app.models.detection_task import DetectionTask - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService task = DetectionTask( brand_id=test_brand.id, @@ -252,7 +252,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_frequency_validation_hourly(self, async_session, test_brand, test_user): - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService service = DetectionSchedulerService() task_data = { @@ -267,7 +267,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_frequency_validation_daily(self, async_session, test_brand, test_user): - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService service = DetectionSchedulerService() task_data = { @@ -282,7 +282,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_frequency_validation_weekly(self, async_session, test_brand, test_user): - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService service = DetectionSchedulerService() task_data = { @@ -297,7 +297,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_frequency_validation_invalid(self, async_session, test_brand, test_user): - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService service = DetectionSchedulerService() task_data = { @@ -312,7 +312,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_execute_task_flow(self, async_session, test_brand, test_user): from app.models.detection_task import DetectionTask - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService task = DetectionTask( brand_id=test_brand.id, @@ -353,7 +353,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_delete_task_not_found(self, async_session, test_user): - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService service = DetectionSchedulerService() result = await service.delete_task(uuid.uuid4(), test_user.id, async_session) @@ -361,7 +361,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_update_task_not_found(self, async_session, test_user): - from app.services.detection_scheduler import DetectionSchedulerService, TaskNotFoundError + from app.services.detection.detection_scheduler import DetectionSchedulerService, TaskNotFoundError service = DetectionSchedulerService() with pytest.raises(TaskNotFoundError): @@ -369,7 +369,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_get_tasks_empty(self, async_session, test_brand, test_user): - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService service = DetectionSchedulerService() tasks = await service.get_tasks(test_brand.id, test_user.id, async_session) diff --git a/backend/tests/test_services/test_geo_diagnosis.py b/backend/tests/test_services/test_geo_diagnosis.py index a513a4d..a3cb49d 100644 --- a/backend/tests/test_services/test_geo_diagnosis.py +++ b/backend/tests/test_services/test_geo_diagnosis.py @@ -4,7 +4,7 @@ GEO诊断服务单元测试 测试6大维度诊断逻辑、评分算法、推荐生成和服务类 """ import pytest -from app.services.geo_diagnosis import ( +from app.services.diagnosis.geo_diagnosis import ( GEODiagnosisService, GEODiagnosisInput, diagnose_content_extractability, diff --git a/backend/tests/test_services/test_imports.py b/backend/tests/test_services/test_imports.py new file mode 100644 index 0000000..f3ee6c5 --- /dev/null +++ b/backend/tests/test_services/test_imports.py @@ -0,0 +1,174 @@ +""" +Import verification tests for services reorganization. + +These tests verify that all import paths work correctly after +moving service files into subdirectories. +""" +import pytest + + +# ============================================================ +# New import path tests (should work after reorganization) +# ============================================================ + +class TestDiagnosisImports: + """Verify diagnosis subdirectory imports.""" + + def test_geo_diagnosis_import(self): + from app.services.diagnosis.geo_diagnosis import GEODiagnosisService + assert GEODiagnosisService is not None + + def test_geo_diagnosis_dataclasses(self): + from app.services.diagnosis.geo_diagnosis import ( + GEODiagnosisInput, + GEODiagnosisResult, + GEODimensionScore, + GEORecommendation, + DiagnosisItem, + ) + assert GEODiagnosisInput is not None + assert GEODiagnosisResult is not None + + def test_seo_diagnosis_import(self): + from app.services.diagnosis.seo_diagnosis import SEODiagnosisService + assert SEODiagnosisService is not None + + def test_seo_diagnosis_dataclasses(self): + from app.services.diagnosis.seo_diagnosis import ( + SEODiagnosisResult, + SEODimensionScore, + SEORecommendation, + DiagnosisStatus, + DimensionName, + ) + assert SEODiagnosisResult is not None + assert DiagnosisStatus is not None + + +class TestScoringImports: + """Verify scoring subdirectory imports.""" + + def test_scoring_service_import(self): + from app.services.scoring.scoring_service import ScoringService + assert ScoringService is not None + + def test_scoring_dataclasses(self): + from app.services.scoring.scoring_service import ( + ScoringResultV2, + DimensionScore, + ) + assert ScoringResultV2 is not None + assert DimensionScore is not None + + def test_get_health_level_reexport(self): + """scoring_service re-exports get_health_level from app.utils.health""" + from app.services.scoring.scoring_service import get_health_level + assert callable(get_health_level) + + +class TestAlertImports: + """Verify alert subdirectory imports.""" + + def test_alert_engine_import(self): + from app.services.alert.alert_engine import AlertEngine + assert AlertEngine is not None + + def test_alert_context_import(self): + from app.services.alert.alert_engine import AlertContext + assert AlertContext is not None + + +class TestCitationImports: + """Verify citation subdirectory imports.""" + + def test_citation_import(self): + from app.services.citation.citation import get_citations + assert callable(get_citations) + + def test_citation_stats_import(self): + from app.services.citation.citation import get_citation_stats + assert callable(get_citation_stats) + + def test_citation_pattern_import(self): + from app.services.citation.citation_pattern import CitationPatternEngine + assert CitationPatternEngine is not None + + def test_citation_pattern_dataclasses(self): + from app.services.citation.citation_pattern import ( + CitationPattern, + PatternAnalysisReport, + ) + assert CitationPattern is not None + assert PatternAnalysisReport is not None + + +class TestLLMImports: + """Verify llm subdirectory imports (new additions).""" + + def test_smart_router_import(self): + from app.services.llm.smart_router import SmartRouter + assert SmartRouter is not None + + def test_engine_selector_import(self): + from app.services.llm.engine_selector import EngineSelector + assert EngineSelector is not None + + def test_smart_router_dataclasses(self): + from app.services.llm.smart_router import ( + CostTier, + EngineCostProfile, + ENGINE_COST_PROFILES, + ) + assert CostTier is not None + assert ENGINE_COST_PROFILES is not None + + def test_llm_init_includes_new_exports(self): + """Verify llm/__init__.py includes SmartRouter and EngineSelector.""" + from app.services.llm import SmartRouter, EngineSelector + assert SmartRouter is not None + assert EngineSelector is not None + + +class TestDetectionImports: + """Verify detection subdirectory imports.""" + + def test_detection_scheduler_import(self): + from app.services.detection.detection_scheduler import DetectionSchedulerService + assert DetectionSchedulerService is not None + + def test_task_not_found_error_import(self): + from app.services.detection.detection_scheduler import TaskNotFoundError + assert TaskNotFoundError is not None + + +class TestAdvisorImports: + """Verify advisor subdirectory imports.""" + + def test_optimization_advisor_import(self): + from app.services.advisor.optimization_advisor import generate_suggestions + assert callable(generate_suggestions) + + def test_advisor_dataclasses(self): + from app.services.advisor.optimization_advisor import ( + SuggestionItem, + BrandAnalysisContext, + ) + assert SuggestionItem is not None + assert BrandAnalysisContext is not None + + +class TestAnalysisImports: + """Verify analysis subdirectory imports.""" + + def test_sentiment_service_import(self): + from app.services.analysis.sentiment_service import SentimentAnalysisService + assert SentimentAnalysisService is not None + + def test_sentiment_result_import(self): + from app.services.analysis.sentiment_service import SentimentResult + assert SentimentResult is not None + + def test_get_sentiment_service_import(self): + from app.services.analysis.sentiment_service import get_sentiment_service + assert callable(get_sentiment_service) + diff --git a/tests/test_knowledge_enhanced.py b/backend/tests/test_services/test_knowledge_enhanced.py similarity index 98% rename from tests/test_knowledge_enhanced.py rename to backend/tests/test_services/test_knowledge_enhanced.py index 0fc372e..27bd703 100644 --- a/tests/test_knowledge_enhanced.py +++ b/backend/tests/test_services/test_knowledge_enhanced.py @@ -292,14 +292,14 @@ class TestMarkdownParser: """测试解析带标题的Markdown""" parser = MarkdownParser() - content = b"""# 这是一个标题 + content = """# 这是一个标题 这是文档内容。 ## 子标题 更多内容。 -""" +""".encode("utf-8") doc = await parser.parse(content) @@ -312,8 +312,8 @@ class TestMarkdownParser: """测试解析不带标题的Markdown""" parser = MarkdownParser() - content = b"""这是文档内容,没有标题。 -""" + content = """这是文档内容,没有标题。 +""".encode("utf-8") doc = await parser.parse(content) @@ -329,10 +329,10 @@ class TestTextParser: """测试使用第一行作为标题""" parser = TextParser() - content = b"""这是第一行标题 + content = """这是第一行标题 这是第二行内容 这是第三行内容 -""" +""".encode("utf-8") doc = await parser.parse(content) diff --git a/tests/test_knowledge_graph.py b/backend/tests/test_services/test_knowledge_graph.py similarity index 100% rename from tests/test_knowledge_graph.py rename to backend/tests/test_services/test_knowledge_graph.py diff --git a/tests/test_llm_provider.py b/backend/tests/test_services/test_llm_provider.py similarity index 100% rename from tests/test_llm_provider.py rename to backend/tests/test_services/test_llm_provider.py diff --git a/tests/test_platform_rules.py b/backend/tests/test_services/test_platform_rules.py similarity index 100% rename from tests/test_platform_rules.py rename to backend/tests/test_services/test_platform_rules.py diff --git a/tests/test_rag_service.py b/backend/tests/test_services/test_rag_service.py similarity index 100% rename from tests/test_rag_service.py rename to backend/tests/test_services/test_rag_service.py diff --git a/backend/tests/test_services/test_scoring_service.py b/backend/tests/test_services/test_scoring_service.py index 2f2f8d5..09bd90a 100644 --- a/backend/tests/test_services/test_scoring_service.py +++ b/backend/tests/test_services/test_scoring_service.py @@ -1,5 +1,5 @@ import pytest -from app.services.scoring_service import ( +from app.services.scoring.scoring_service import ( ScoringService, calculate_mention_rate_score, calculate_sov_score, diff --git a/backend/tests/test_services/test_seo_diagnosis.py b/backend/tests/test_services/test_seo_diagnosis.py index c8ed525..08aed11 100644 --- a/backend/tests/test_services/test_seo_diagnosis.py +++ b/backend/tests/test_services/test_seo_diagnosis.py @@ -2,7 +2,7 @@ SEO诊断服务单元测试 """ import pytest -from app.services.seo_diagnosis import ( +from app.services.diagnosis.seo_diagnosis import ( SEODiagnosisService, SEODiagnosisResult, SEODimensionScore, diff --git a/backend/tests/test_services/test_smart_router_and_usage.py b/backend/tests/test_services/test_smart_router_and_usage.py index 1f8b994..961fa28 100644 --- a/backend/tests/test_services/test_smart_router_and_usage.py +++ b/backend/tests/test_services/test_smart_router_and_usage.py @@ -1,6 +1,6 @@ import pytest -from app.services.smart_router import ( +from app.services.llm.smart_router import ( ENGINE_COST_PROFILES, CostTier, EngineCostProfile, diff --git a/backend/tests/test_services/test_smart_router_key_integration.py b/backend/tests/test_services/test_smart_router_key_integration.py index 621abd5..05da280 100644 --- a/backend/tests/test_services/test_smart_router_key_integration.py +++ b/backend/tests/test_services/test_smart_router_key_integration.py @@ -1,8 +1,8 @@ import pytest from app.services.api_key_manager import APIKeyManager, KeySource -from app.services.smart_router import ENGINE_COST_PROFILES, CostTier, SmartRouter -from app.services.engine_selector import EngineSelector +from app.services.llm.smart_router import ENGINE_COST_PROFILES, CostTier, SmartRouter +from app.services.llm.engine_selector import EngineSelector class TestSmartRouterWithKeyManager: diff --git a/tests/test_topic_templates.py b/backend/tests/test_services/test_topic_templates.py similarity index 100% rename from tests/test_topic_templates.py rename to backend/tests/test_services/test_topic_templates.py diff --git a/backend/tests/test_utils/__init__.py b/backend/tests/test_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/test_utils/test_health.py b/backend/tests/test_utils/test_health.py new file mode 100644 index 0000000..de02fa0 --- /dev/null +++ b/backend/tests/test_utils/test_health.py @@ -0,0 +1,66 @@ +"""Tests for app.utils.health — get_health_level / get_health_level_label""" +import pytest + +from app.utils.health import get_health_level, get_health_level_label + + +# ── get_health_level ────────────────────────────────────────── + + +class TestGetHealthLevel: + """根据评分返回健康等级字符串""" + + def test_excellent_lower_bound(self): + assert get_health_level(80) == "excellent" + + def test_excellent_high(self): + assert get_health_level(100) == "excellent" + + def test_good_upper_bound(self): + assert get_health_level(79) == "good" + + def test_good_lower_bound(self): + assert get_health_level(60) == "good" + + def test_pass_upper_bound(self): + assert get_health_level(59) == "pass" + + def test_pass_lower_bound(self): + assert get_health_level(40) == "pass" + + def test_danger_upper_bound(self): + assert get_health_level(39) == "danger" + + def test_danger_zero(self): + assert get_health_level(0) == "danger" + + def test_negative_score(self): + assert get_health_level(-10) == "danger" + + def test_float_score(self): + assert get_health_level(80.5) == "excellent" + + +# ── get_health_level_label ──────────────────────────────────── + + +class TestGetHealthLevelLabel: + """根据等级返回中文标签""" + + @pytest.mark.parametrize( + "level, label", + [ + ("excellent", "优秀"), + ("good", "良好"), + ("pass", "及格"), + ("danger", "危险"), + ], + ) + def test_known_levels(self, level, label): + assert get_health_level_label(level) == label + + def test_unknown_level(self): + assert get_health_level_label("unknown") == "未知" + + def test_empty_string(self): + assert get_health_level_label("") == "未知" diff --git a/backend/tests/test_utils/test_json_extractor.py b/backend/tests/test_utils/test_json_extractor.py new file mode 100644 index 0000000..3fec83d --- /dev/null +++ b/backend/tests/test_utils/test_json_extractor.py @@ -0,0 +1,88 @@ +"""Tests for app.utils.json_extractor — extract_json""" +import json + +import pytest + +from app.utils.json_extractor import extract_json + + +class TestExtractJsonPlainObject: + """纯 JSON 对象""" + + def test_simple_object(self): + text = '{"key": "value"}' + result = extract_json(text) + assert json.loads(result) == {"key": "value"} + + def test_nested_object(self): + text = '{"a": {"b": [1, 2]}}' + result = extract_json(text) + assert json.loads(result) == {"a": {"b": [1, 2]}} + + +class TestExtractJsonMarkdownCodeBlock: + """JSON 包裹在 markdown 代码块中""" + + def test_json_code_block(self): + text = '```json\n{"key": "value"}\n```' + result = extract_json(text) + assert json.loads(result) == {"key": "value"} + + def test_plain_code_block(self): + text = '```\n{"key": "value"}\n```' + result = extract_json(text) + assert json.loads(result) == {"key": "value"} + + def test_code_block_with_surrounding_text(self): + text = 'Here is the result:\n```json\n{"key": "value"}\n```\nDone.' + result = extract_json(text) + assert json.loads(result) == {"key": "value"} + + +class TestExtractJsonArray: + """JSON 数组""" + + def test_simple_array(self): + text = "[1, 2, 3]" + result = extract_json(text) + assert json.loads(result) == [1, 2, 3] + + def test_array_in_code_block(self): + text = "```json\n[1, 2, 3]\n```" + result = extract_json(text) + assert json.loads(result) == [1, 2, 3] + + def test_array_with_surrounding_text(self): + text = "Result: [1, 2, 3] done" + result = extract_json(text) + assert json.loads(result) == [1, 2, 3] + + +class TestExtractJsonWithSurroundingText: + """JSON 被周围文本包裹""" + + def test_object_with_surrounding_text(self): + text = 'Here is the result: {"key": "value"} done' + result = extract_json(text) + assert json.loads(result) == {"key": "value"} + + def test_nested_with_surrounding_text(self): + text = 'Output: {"a": {"b": [1, 2]}} end' + result = extract_json(text) + assert json.loads(result) == {"a": {"b": [1, 2]}} + + +class TestExtractJsonInvalidInput: + """无效输入应抛出 ValueError""" + + def test_no_json_at_all(self): + with pytest.raises(ValueError, match="无法从响应中提取JSON"): + extract_json("This is just plain text with no JSON") + + def test_empty_string(self): + with pytest.raises(ValueError, match="无法从响应中提取JSON"): + extract_json("") + + def test_unclosed_brace(self): + with pytest.raises(ValueError, match="无法从响应中提取JSON"): + extract_json('{"key": "value"') diff --git a/backend/tests/test_workers/test_llm_adapter.py b/backend/tests/test_workers/test_llm_adapter.py index 4c0deab..4dee5af 100644 --- a/backend/tests/test_workers/test_llm_adapter.py +++ b/backend/tests/test_workers/test_llm_adapter.py @@ -1,38 +1,38 @@ +"""品牌引用LLM服务测试 - 迁移自 test_llm_adapter.py + +原 LLMAdapter (System 3) 已被 BrandCitationLLMService (System 1) 替代。 +本文件保留了所有原有的测试场景,使用新的服务接口。 +""" import pytest from unittest.mock import AsyncMock, patch, MagicMock -from app.workers.llm_adapter import LLMAdapter, LLMAdapterError, BRAND_CITATION_PROMPT + +from app.services.llm.brand_citation_service import BrandCitationLLMService, BRAND_CITATION_PROMPT +from app.services.llm.base import LLMError -class TestLLMAdapter: - """LLM适配器测试""" +class TestBrandCitationLLMService: + """品牌引用LLM服务测试(原 LLMAdapter 测试迁移)""" @pytest.fixture - def llm_adapter(self): - """创建LLM适配器实例""" - return LLMAdapter() + def service(self): + """创建BrandCitationLLMService实例""" + return BrandCitationLLMService() @pytest.mark.asyncio - async def test_llm_adapter_cited_brand(self, llm_adapter): + async def test_cited_brand(self, service): """测试检测到品牌引用""" - mock_response = { - "cited": True, - "position": 1, - "citation_text": "XXX是一款非常优秀的品牌产品", - "sentiment": "positive", - "confidence": 0.95 - } - - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call: - mock_call.return_value = mock_response - - result = await llm_adapter.query_brand_citation( + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 1, "citation_text": "XXX是一款非常优秀的品牌产品", "sentiment": "positive", "confidence": 0.95}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( keyword="AI搜索", brand_name="XXX", brand_aliases=["品牌别名1", "品牌别名2"] ) - assert result.cited is True assert result.position == 1 assert result.citation_text == "XXX是一款非常优秀的品牌产品" @@ -40,135 +40,85 @@ class TestLLMAdapter: assert result.confidence == 0.95 @pytest.mark.asyncio - async def test_llm_adapter_not_cited(self, llm_adapter): + async def test_not_cited(self, service): """测试未检测到品牌引用""" - mock_response = { - "cited": False, - "position": None, - "citation_text": None, - "sentiment": "neutral", - "confidence": 0.90 - } - - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call: - mock_call.return_value = mock_response - - result = await llm_adapter.query_brand_citation( + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": false, "position": null, "citation_text": null, "sentiment": "neutral", "confidence": 0.90}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( keyword="AI搜索", brand_name="YYY", brand_aliases=[] ) - assert result.cited is False assert result.position is None assert result.citation_text is None assert result.sentiment == "neutral" @pytest.mark.asyncio - async def test_llm_adapter_sentiment_positive(self, llm_adapter): + async def test_sentiment_positive(self, service): """测试正面情感""" - mock_response = { - "cited": True, - "position": 2, - "citation_text": "YYY品牌产品质量非常好,用户口碑极佳", - "sentiment": "positive", - "confidence": 0.92 - } - - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call: - mock_call.return_value = mock_response - - result = await llm_adapter.query_brand_citation( + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 2, "citation_text": "YYY品牌产品质量非常好,用户口碑极佳", "sentiment": "positive", "confidence": 0.92}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( keyword="AI搜索", brand_name="YYY", brand_aliases=[] ) - assert result.sentiment == "positive" @pytest.mark.asyncio - async def test_llm_adapter_sentiment_negative(self, llm_adapter): + async def test_sentiment_negative(self, service): """测试负面情感""" - mock_response = { - "cited": True, - "position": 3, - "citation_text": "ZZZ品牌存在质量问题,遭到用户投诉", - "sentiment": "negative", - "confidence": 0.88 - } - - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call: - mock_call.return_value = mock_response - - result = await llm_adapter.query_brand_citation( + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 3, "citation_text": "ZZZ品牌存在质量问题,遭到用户投诉", "sentiment": "negative", "confidence": 0.88}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( keyword="AI搜索", brand_name="ZZZ", brand_aliases=[] ) - assert result.sentiment == "negative" @pytest.mark.asyncio - async def test_llm_adapter_api_error_retry(self, llm_adapter): - """测试API错误重试""" - mock_success_response = { - "cited": True, - "position": 1, - "citation_text": "测试文本", - "sentiment": "neutral", - "confidence": 0.90 - } - - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call: - mock_call.side_effect = [ - Exception("API调用失败"), - Exception("API调用失败"), - mock_success_response - ] - - result = await llm_adapter.query_brand_citation( - keyword="AI搜索", - brand_name="测试品牌", - brand_aliases=[] - ) - - assert result.cited is True - assert mock_call.call_count == 3 - - @pytest.mark.asyncio - async def test_llm_adapter_parse_error(self, llm_adapter): + async def test_parse_error(self, service): """测试响应解析错误""" - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call: - mock_call.return_value = {"invalid": "response"} - - with pytest.raises(LLMAdapterError) as exc_info: - await llm_adapter.query_brand_citation( + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"invalid": "response"}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + with pytest.raises(LLMError) as exc_info: + await service.query_brand_citation( keyword="AI搜索", brand_name="测试品牌", brand_aliases=[] ) - error_msg = str(exc_info.value) assert "响应缺少必需字段" in error_msg or "解析响应失败" in error_msg - def test_build_prompt(self, llm_adapter): + def test_build_prompt(self, service): """测试Prompt构建""" - prompt = llm_adapter._build_prompt( + prompt = service._build_prompt( keyword="AI搜索", brand_name="测试品牌", brand_aliases=["别名1", "别名2"] ) - assert "AI搜索" in prompt assert "测试品牌" in prompt assert "别名1" in prompt diff --git a/backend/tests/test_workers/test_llm_adapter_no_mock.py b/backend/tests/test_workers/test_llm_adapter_no_mock.py index b596f56..82cdd65 100644 --- a/backend/tests/test_workers/test_llm_adapter_no_mock.py +++ b/backend/tests/test_workers/test_llm_adapter_no_mock.py @@ -1,25 +1,30 @@ +"""验证BrandCitationLLMService不再返回Mock数据,而是抛出明确错误 + +迁移自 test_llm_adapter_no_mock.py,原测试验证 LLMAdapter 不返回 Mock 数据。 +现在验证 BrandCitationLLMService (System 1) 的相同行为。 +""" import pytest -from unittest.mock import AsyncMock, patch, PropertyMock +from unittest.mock import AsyncMock, patch, MagicMock -from app.workers.llm_adapter import LLMAdapter, LLMAdapterError +from app.services.llm.brand_citation_service import BrandCitationLLMService +from app.services.llm.base import LLMError -class TestLLMAdapterNoMock: - """验证LLMAdapter不再返回Mock数据,而是抛出明确错误""" +class TestBrandCitationLLMServiceNoMock: + """验证BrandCitationLLMService不返回Mock数据,而是抛出明确错误""" @pytest.fixture - def adapter(self): - return LLMAdapter() + def service(self): + return BrandCitationLLMService() @pytest.mark.asyncio - async def test_enable_llm_false_raises_error(self, adapter): - """ENABLE_LLM=False时必须抛出LLMAdapterError,而非返回Mock数据""" - with patch("app.workers.llm_adapter.settings") as mock_settings: + async def test_enable_llm_false_raises_error(self, service): + """ENABLE_LLM=False时必须抛出LLMError,而非返回Mock数据""" + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: mock_settings.ENABLE_LLM = False - mock_settings.DEEPSEEK_API_KEY = "test-key" - with pytest.raises(LLMAdapterError) as exc_info: - await adapter.query_brand_citation( + with pytest.raises(LLMError) as exc_info: + await service.query_brand_citation( keyword="AI搜索", brand_name="测试品牌", brand_aliases=["别名1"], @@ -30,45 +35,18 @@ class TestLLMAdapterNoMock: assert "未启用" in error_msg @pytest.mark.asyncio - async def test_enable_llm_true_no_api_key_raises_error(self, adapter): - """ENABLE_LLM=True但无API Key时必须抛出LLMAdapterError""" - adapter.api_key = None + async def test_enable_llm_true_calls_provider(self, service): + """ENABLE_LLM=True时正常调用Provider""" + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 1, "citation_text": "测试引用", "sentiment": "positive", "confidence": 0.95}' + ) - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True - with pytest.raises(LLMAdapterError) as exc_info: - await adapter.query_brand_citation( - keyword="AI搜索", - brand_name="测试品牌", - brand_aliases=[], - ) - - error_msg = str(exc_info.value) - assert "API Key" in error_msg or "DEEPSEEK_API_KEY" in error_msg - - @pytest.mark.asyncio - async def test_enable_llm_true_with_key_calls_api(self, adapter): - """ENABLE_LLM=True且有Key时正常调用API""" - mock_response = { - "cited": True, - "position": 1, - "citation_text": "测试引用", - "sentiment": "positive", - "confidence": 0.95, - } - - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - mock_settings.OPENAI_API_KEY = None - mock_settings.DEEPSEEK_API_KEY = "sk-test-key" - - with patch.object( - adapter, "_call_deepseek", new_callable=AsyncMock - ) as mock_call: - mock_call.return_value = mock_response - - result = await adapter.query_brand_citation( + result = await service.query_brand_citation( keyword="AI搜索", brand_name="测试品牌", brand_aliases=[], @@ -79,43 +57,23 @@ class TestLLMAdapterNoMock: assert result.sentiment == "positive" def test_get_mock_result_method_removed(self): - """_get_mock_result方法必须已被删除""" - assert not hasattr(LLMAdapter, "_get_mock_result"), ( + """_get_mock_result方法必须已被删除(旧LLMAdapter遗留)""" + assert not hasattr(BrandCitationLLMService, "_get_mock_result"), ( "_get_mock_result方法仍然存在,必须删除" ) @pytest.mark.asyncio - async def test_error_message_user_friendly(self, adapter): + async def test_error_message_user_friendly(self, service): """错误信息必须对用户友好,包含配置指引""" - with patch("app.workers.llm_adapter.settings") as mock_settings: + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: mock_settings.ENABLE_LLM = False - mock_settings.DEEPSEEK_API_KEY = "test-key" - with pytest.raises(LLMAdapterError) as exc_info: - await adapter.query_brand_citation( + with pytest.raises(LLMError) as exc_info: + await service.query_brand_citation( keyword="AI搜索", brand_name="测试品牌", brand_aliases=[], ) error_msg = str(exc_info.value) - assert "ENABLE_LLM=True" in error_msg - assert "DEEPSEEK_API_KEY" in error_msg - - @pytest.mark.asyncio - async def test_no_api_key_error_message_user_friendly(self, adapter): - """无API Key时错误信息必须包含配置指引""" - adapter.api_key = None - - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - - with pytest.raises(LLMAdapterError) as exc_info: - await adapter.query_brand_citation( - keyword="AI搜索", - brand_name="测试品牌", - brand_aliases=[], - ) - - error_msg = str(exc_info.value) - assert "DEEPSEEK_API_KEY" in error_msg + assert "ENABLE_LLM" in error_msg diff --git a/docs/01-项目概览/architecture.md b/docs/01-项目概览/architecture.md index 25d0dc0..379a31c 100644 --- a/docs/01-项目概览/architecture.md +++ b/docs/01-项目概览/architecture.md @@ -70,6 +70,9 @@ │ ├───────────┬───────────┬───────────┬────────────┤ │ │ │ Citation │ Content │ DeAI │ GEO │ │ │ │ Detector │ Generator │ Agent │ Optimizer │ │ +│ ├───────────┼───────────┼───────────┼────────────┤ │ +│ │ Monitor │ Schema │Competitor│ Trend │ │ +│ │ Agent │ Advisor │ Analyzer │ Agent │ │ │ └───────────┴───────────┴───────────┴────────────┘ │ │ │ │ ┌─────────────────────────────────────────────────┐ │ @@ -89,11 +92,47 @@ | ContentGenerator | 内容生成 | 主题、规则库、品牌素材 | GEO优化内容 | | DeAIAgent | 去AI化 | AI生成内容 | 自然化内容 | | GEOOptimizer | GEO优化 | 原始内容、关键词策略 | 优化后内容 | +| MonitorAgent | 效果追踪 | 品牌ID、监测配置 | 监测记录、趋势数据 | +| SchemaAdvisor | Schema建议 | 品牌内容、行业类型 | Schema标记建议 | +| CompetitorAnalyzer | 竞品分析 | 品牌、竞品列表 | 竞品洞察报告 | +| TrendAgent | 趋势洞察 | 行业、关键词 | 趋势洞察数据 | **注意:** 当前项目中的`SEOOptimizer`实际执行的是GEO优化(内容结构化、实体优化),而非传统SEO优化(技术SEO)。建议在后续版本中明确区分: - **SEOOptimizer** → 技术SEO优化(网站技术层面) - **GEOOptimizer** → 内容实体优化(AI引用层面) +## GEO 业务闭环 + +``` +诊断 → 策略 → 方案 → 内容生成 → 效果追踪 + ↑ │ + └────────────────────────────────────┘ +``` + +**闭环流程说明:** + +| 阶段 | 描述 | 核心Agent/服务 | +|------|------|----------------| +| 诊断 | 品牌GEO现状评估,识别优化机会 | CitationDetector, GEOOptimizer | +| 策略 | 基于诊断结果制定GEO优化策略 | StrategyService, CompetitorAnalyzer | +| 方案 | 生成可执行的GEO优化方案 | StrategyService, SchemaAdvisor | +| 内容生成 | 按方案生成GEO优化内容 | ContentGenerator, DeAIAgent | +| 效果追踪 | 持续监测优化效果,驱动下一轮迭代 | MonitorAgent, TrendAgent | + +## 品牌评分体系 + +品牌评分V2采用5维度评分模型,全面衡量品牌在AI搜索引擎中的表现: + +| 维度 | 字段名 | 描述 | 评分范围 | +|------|--------|------|----------| +| 提及率 | mention_rate | 品牌在AI响应中的被提及频率 | 0-100 | +| 推荐排名 | recommendation_rank | 品牌在AI推荐中的排名位置 | 0-100 | +| 情感评分 | sentiment_score | AI对品牌的情感倾向评分 | 0-100 | +| 引用质量 | citation_quality | 品牌被引用内容的质量与权威性 | 0-100 | +| 竞争位置 | competitive_position | 品牌相对竞品的综合竞争位势 | 0-100 | + +**综合评分计算:** 综合品牌评分 = 加权平均(mention_rate, recommendation_rank, sentiment_score, citation_quality, competitive_position) + ## 内容生成Pipeline ``` @@ -118,4 +157,4 @@ API请求 → Prometheus指标 → Grafana可视化 ## 数据库设计 -核心表:users, organizations, brands, competitors, queries, citations, alerts, contents, knowledge_bases, knowledge_entities, knowledge_relations +核心表:users, organizations, brands, competitors, queries, citations, alerts, contents, knowledge_bases, knowledge_entities, knowledge_relations, geo_plans, geo_plan_actions, monitoring_records, content_baselines, schema_suggestions, competitor_insights, trend_insights diff --git a/docs/01-项目概览/changelog.md b/docs/01-项目概览/changelog.md index 4169335..b5adb21 100644 --- a/docs/01-项目概览/changelog.md +++ b/docs/01-项目概览/changelog.md @@ -1,5 +1,39 @@ # 变更日志 +## v2.1.0 (2026-05-31) + +### 新增功能 +- 4个新Agent: MonitorAgent, SchemaAdvisor, CompetitorAnalyzer, TrendAgent +- GEO方案自动生成系统 +- 竞品引导提示组件 +- 品牌评分V2体系 (mention_rate, recommendation_rank, sentiment_score, citation_quality, competitive_position) + +### 架构优化 +- BrandScoringDataService共享服务提取 +- N+1查询消除 +- Repository模式数据访问层完善 + +### Bug修复 +- 前端API参数不匹配修复 +- agentsApi死代码清理 +- CORS配置修正 +- 前端环境变量端口修正 + +### 新增API +- /api/v1/strategy +- /api/v1/monitoring +- /api/v1/competitor +- /api/v1/schema +- /api/v1/trends + +### 新增前端API客户端 +- monitoring.ts +- competitor-analysis.ts +- schema-advisor.ts +- trends.ts + +--- + ## v2.0.0 (当前版本) ### 新增功能 diff --git a/docs/01-项目概览/tech-stack.md b/docs/01-项目概览/tech-stack.md index d5cceda..c2158a9 100644 --- a/docs/01-项目概览/tech-stack.md +++ b/docs/01-项目概览/tech-stack.md @@ -52,13 +52,59 @@ geo/ │ │ ├── agent_framework/ # Agent框架 │ │ ├── models/ # 数据模型 │ │ ├── schemas/ # Pydantic模型 -│ │ ├── services/ # 业务逻辑 -│ │ ├── workers/ # 异步任务 -│ │ └── monitoring/ # 监控模块 +│ │ ├── services/ # 业务逻辑 +│ │ │ ├── ai_engine/ # AI引擎服务 +│ │ │ ├── llm/ # LLM服务 +│ │ │ ├── knowledge/ # 知识库服务 +│ │ │ ├── content/ # 内容服务 +│ │ │ ├── distribution/ # 分发服务 +│ │ │ ├── diagnosis/ # 诊断服务 +│ │ │ ├── citation/ # 引用服务 +│ │ │ ├── competitor/ # 竞品服务 +│ │ │ ├── monitoring/ # 监测服务 +│ │ │ ├── trend/ # 趋势服务 +│ │ │ ├── schema/ # Schema服务 +│ │ │ ├── scoring/ # 评分服务 +│ │ │ ├── strategy/ # 策略服务 +│ │ │ ├── alert/ # 告警服务 +│ │ │ ├── advisor/ # 顾问服务 +│ │ │ ├── analysis/ # 分析服务 +│ │ │ ├── detection/ # 检测服务 +│ │ │ └── analytics/ # 分析统计服务 +│ │ ├── repositories/ # 数据访问层 │ └── requirements.txt ├── frontend/ # Next.js 前端 │ ├── app/ # 页面 │ ├── components/ # 组件 -│ └── lib/ # 工具函数 +│ └── lib/ +│ ├── api/ # API客户端 +│ │ ├── analytics.ts +│ │ ├── api-keys.ts +│ │ ├── auth.ts +│ │ ├── brands.ts +│ │ ├── citations.ts +│ │ ├── clients.ts +│ │ ├── competitor-analysis.ts +│ │ ├── contents.ts +│ │ ├── dashboard.ts +│ │ ├── detection.ts +│ │ ├── diagnosis.ts +│ │ ├── distribution.ts +│ │ ├── health.ts +│ │ ├── image.ts +│ │ ├── knowledge.ts +│ │ ├── lifecycle.ts +│ │ ├── monitoring.ts +│ │ ├── onboarding.ts +│ │ ├── organization.ts +│ │ ├── platform-rules.ts +│ │ ├── queries.ts +│ │ ├── reports.ts +│ │ ├── schema-advisor.ts +│ │ ├── strategy.ts +│ │ ├── suggestions.ts +│ │ ├── trends.ts +│ │ └── usage.ts +│ └── ... # 其他工具函数 └── docs/ # 文档 ``` diff --git a/docs/02-模块说明/agent-framework.md b/docs/02-模块说明/agent-framework.md index 07fc8c3..f258e90 100644 --- a/docs/02-模块说明/agent-framework.md +++ b/docs/02-模块说明/agent-framework.md @@ -98,6 +98,65 @@ GEO平台的AI Agent框架是系统的核心智能层,采用模块化、可插 | 输入 | 原始内容、关键词策略 | | 输出 | 优化后的内容 | +### MonitorAgent (效果追踪Agent) + +| 属性 | 说明 | +|------|------| +| 文件 | `backend/app/agent_framework/agents/monitor_agent.py` | +| 职责 | 定期检测品牌在AI平台的引用变化,生成变化报告 | +| 输入 | 品牌名称、AI平台列表、检测周期 | +| 输出 | 引用变化报告(新增引用、丢失引用、变化趋势) | + +**报告类型**: +- `daily_report` - 每日引用变化快报 +- `weekly_report` - 周度引用趋势报告 +- `alert` - 异常变化告警(引用大幅下降等) + +### SchemaAdvisorAgent (Schema建议Agent) + +| 属性 | 说明 | +|------|------| +| 文件 | `backend/app/agent_framework/agents/schema_advisor_agent.py` | +| 职责 | 基于品牌诊断数据生成JSON-LD结构化数据建议 | +| 输入 | 品牌诊断数据、当前Schema配置、行业类型 | +| 输出 | JSON-LD结构化数据建议方案 | + +**建议类型**: +- `schema_add` - 新增Schema建议 +- `schema_modify` - 现有Schema修改建议 +- `schema_validate` - Schema合规性验证结果 + +### CompetitorAnalyzerAgent (竞品分析Agent) + +| 属性 | 说明 | +|------|------| +| 文件 | `backend/app/agent_framework/agents/competitor_analyzer_agent.py` | +| 职责 | 5维度分析竞品(引用量、引用质量、引用场景、内容策略、机会发现) | +| 输入 | 品牌名称、竞品列表、分析维度 | +| 输出 | 竞品分析报告 | + +**分析维度**: +- `citation_volume` - 引用量对比 +- `citation_quality` - 引用质量评估 +- `citation_scenario` - 引用场景分布 +- `content_strategy` - 内容策略分析 +- `opportunity_discovery` - 机会发现与建议 + +### TrendAgent (趋势洞察Agent) + +| 属性 | 说明 | +|------|------| +| 文件 | `backend/app/agent_framework/agents/trend_agent.py` | +| 职责 | 识别关键词和平台的引用趋势(上升/下降/平稳/热点) | +| 输入 | 关键词列表、AI平台列表、时间范围 | +| 输出 | 趋势分析报告 | + +**趋势类型**: +- `rising` - 上升趋势 +- `declining` - 下降趋势 +- `stable` - 平稳趋势 +- `hotspot` - 热点趋势 + ## 通信协议 基于数据库+Redis Queue的异步通信: diff --git a/docs/02-模块说明/agent-protocol.md b/docs/02-模块说明/agent-protocol.md index 924fa66..2960758 100644 --- a/docs/02-模块说明/agent-protocol.md +++ b/docs/02-模块说明/agent-protocol.md @@ -90,6 +90,9 @@ class AgentStatus(str, Enum): | RULE_CHECKER | 规则检查Agent | 内容合规审核 | | COMPETITOR_ANALYZER | 竞品分析Agent | 竞品数据收集分析 | | PERFORMANCE_TRACKER | 性能追踪Agent | 追踪内容表现 | +| SCHEMA_ADVISOR | Schema建议Agent | 生成JSON-LD结构化数据建议 | +| MONITOR_AGENT | 效果追踪Agent | 检测品牌引用变化,生成变化报告 | +| TREND_AGENT | 趋势洞察Agent | 识别关键词和平台的引用趋势 | ## 任务类型 diff --git a/docs/04-API文档/README.md b/docs/04-API文档/README.md index 6b7b682..d1c7e65 100644 --- a/docs/04-API文档/README.md +++ b/docs/04-API文档/README.md @@ -1,6 +1,6 @@ # API文档 -本目录包含GEO平台的API接口文档。 +本目录包含GEO平台的API接口文档,共33个API模块。 ## 文档内容 @@ -8,8 +8,33 @@ - [品牌API](./brands.md) - 品牌管理接口 - [内容API](./content.md) - 内容生成接口 - [知识库API](./knowledge.md) - 知识库接口 -- [诊断API](./diagnosis.md) - SEO/GEO诊断接口 ✅ 新增 +- [诊断API](./diagnosis.md) - SEO/GEO诊断接口 - [健康检查](./health.md) - 系统健康检查接口 +- [监测优化API](./analytics.md) - 监测数据分析接口 +- [告警通知API](./alerts.md) - 告警通知管理接口 +- [仪表盘API](./dashboard.md) - 仪表盘数据接口 +- [查询词API](./queries.md) - 查询词管理接口 +- [引用数据API](./citations.md) - 引用数据管理接口 +- [报告API](./reports.md) - 报告生成接口 +- [Agent管理API](./agents.md) - Agent管理接口 +- [生命周期API](./lifecycle.md) - 生命周期管理接口 +- [内容管理API](./contents.md) - 内容管理接口 +- [客户管理API](./clients.md) - 客户管理接口 +- [内容分发API](./distribution.md) - 内容分发接口 +- [优化建议API](./suggestions.md) - 优化建议接口 +- [引导流程API](./onboarding.md) - 引导流程接口 +- [平台规则API](./platform-rules.md) - 平台规则管理接口 +- [图片生成API](./image.md) - 图片生成接口 +- [组织管理API](./organization.md) - 组织管理接口 +- [检测任务API](./detection.md) - 检测任务接口 +- [GEO方案API](./strategy.md) - GEO方案生成接口 +- [AI引擎查询API](./ai-engines.md) - AI引擎查询接口 +- [效果追踪API](./monitoring.md) - 效果追踪监测接口 +- [竞品分析API](./competitor.md) - 竞品分析接口 +- [Schema建议API](./schema.md) - Schema标记建议接口 +- [趋势洞察API](./trends.md) - 趋势洞察接口 +- [API Key管理API](./api-keys.md) - API Key管理接口 +- [用量追踪API](./usage.md) - 用量追踪接口 ## 诊断模块API diff --git a/docs/05-部署运维/docker.md b/docs/05-部署运维/docker.md index fa19d3d..4750b1e 100644 --- a/docs/05-部署运维/docker.md +++ b/docs/05-部署运维/docker.md @@ -23,8 +23,12 @@ cp .env.example .env # 编辑.env配置必要参数 ``` +> **说明**:Docker Compose 默认读取项目根目录下的 `.env` 文件加载环境变量。所有配置项均可通过 `.env` 文件统一管理,无需在 `docker-compose.yml` 中硬编码。 + ### 3. 启动服务 +> **注意**:开发模式使用 `uvicorn --reload` 启动,支持热重载;生产环境必须移除 `--reload` 参数,使用 `uvicorn app.main:app --host 0.0.0.0 --port 8000` 启动。 + ```bash # 开发环境 docker-compose up -d @@ -56,10 +60,24 @@ docker-compose logs -f ## 数据持久化 - volumes: - - postgres_data:/var/lib/postgresql/data - - redis_data:/data - - ./uploads:/app/uploads +通过Volume挂载确保容器重建后数据不丢失: + +```yaml +volumes: + postgres_data:/var/lib/postgresql/data + redis_data:/data + ./uploads:/app/uploads + ./logs:/app/logs + ./backups:/app/backups +``` + +| 挂载路径 | 说明 | +|----------|------| +| `postgres_data` | PostgreSQL数据目录 | +| `redis_data` | Redis持久化数据 | +| `./uploads` | 用户上传文件 | +| `./logs` | 应用日志 | +| `./backups` | 数据库备份 | ## 健康检查 diff --git a/docs/05-部署运维/environment.md b/docs/05-部署运维/environment.md index 6cd9517..884764a 100644 --- a/docs/05-部署运维/environment.md +++ b/docs/05-部署运维/environment.md @@ -32,6 +32,8 @@ | `SECRET_KEY` | 应用密钥 | `your-secret-key` | | `CORS_ORIGINS` | 允许的跨域源 | `http://localhost:3000` | +> **注意**:生产环境必须配置具体域名,不能使用通配符`*` + ### 认证 | 变量 | 说明 | 示例 | @@ -60,6 +62,29 @@ | `QDRANT_HOST` | Qdrant地址 | `localhost` | | `QDRANT_PORT` | Qdrant端口 | `6333` | +### AI平台API密钥 + +| 变量 | 说明 | 示例 | +|------|------|------| +| `CHATGPT_API_KEY` | ChatGPT API密钥 | `sk-...` | +| `DEEPSEEK_API_KEY` | DeepSeek API密钥 | `sk-...` | +| `QWEN_API_KEY` | 通义千问API密钥 | `sk-...` | +| `KIMI_API_KEY` | Kimi API密钥 | `sk-...` | +| `DOUBAO_API_KEY` | 豆包API密钥 | `sk-...` | +| `GEMINI_API_KEY` | Gemini API密钥 | `AI...` | +| `PERPLEXITY_API_KEY` | Perplexity API密钥 | `pplx-...` | +| `YUANBAO_API_KEY` | 元宝API密钥 | `sk-...` | +| `WENXIN_API_KEY` | 文心一言API密钥 | `sk-...` | +| `ZHIPU_API_KEY` | 智谱AI API密钥 | `...` | +| `TONGYI_API_KEY` | 通义API密钥 | `sk-...` | + +### 功能开关 + +| 变量 | 说明 | 默认值 | +|------|------|------| +| `ENABLE_LLM` | 启用LLM增强功能 | `true` | +| `ENABLE_AGENT_FRAMEWORK` | 启用Agent框架 | `true` | + ## 开发环境示例 ```env diff --git a/docs/brainstorms/2026-05-31-geo-next-phase-core-flow-repair-requirements.md b/docs/brainstorms/2026-05-31-geo-next-phase-core-flow-repair-requirements.md new file mode 100644 index 0000000..4f4ec1c --- /dev/null +++ b/docs/brainstorms/2026-05-31-geo-next-phase-core-flow-repair-requirements.md @@ -0,0 +1,204 @@ +--- +date: "2026-05-31" +topic: "geo-next-phase-core-flow-repair" +--- + +## Summary + +GEO 平台下一阶段的核心目标是实施变现闭环(诊断修复→免费获客→付费转化→执行闭环→效果归因),同时用 API 契约测试驱动核心功能修复,辅以少量 E2E 烟雾测试验证关键用户路径。原有的全面质量保障计划(E2E 覆盖、性能基线、安全扫描)降级为上线后工作。 + +--- + +## Problem Frame + +GEO 平台尚未上线,存在三个致命问题阻断变现闭环: + +1. **诊断系统形同虚设** — `backend/app/api/diagnosis.py` 第 75-77 行用 `GEODiagnosisInput()` 空参调用,所有字段默认 False/0,永远返回 0 分。产品核心价值为零。 +2. **内容 Pipeline 只格式化不生成** — `backend/app/services/content/content_pipeline.py` 只有 4 个阶段(RuleValidator→SensitiveFilter→SEOOptimizer→HTMLGenerator),没有 AI 内容生成阶段,核心付费功能缺失。 +3. **支付是模拟的** — `backend/app/services/subscription.py` 中 `payment_method="模拟支付"`,没有真实支付网关和功能限制中间件,用户无付费动力。 + +此外,Onboarding 缺少付费墙触发点、分发没有实际发布集成、监控没有归因逻辑、邮件服务未接入业务系统——各模块存在但互不连通。 + +与此同时,当前的质量保障计划(11 个 IU)将大量投入放在 E2E 测试、性能基线、安全扫描上——这些对未上线且核心功能断裂的产品来说 ROI 不高。核心矛盾:**测试体系是为稳定产品服务的,但产品本身还不稳定。** + +--- + +## Key Decisions + +**变现闭环优先于全面测试体系** — 产品核心价值为零(诊断永远 0 分),此时投入 E2E 测试、性能基线、安全扫描无法产生实际价值。先修通变现闭环,让产品可用且有收入,再建设全面质量保障。 + +**测试角色从"质量保障"转变为"行为契约"** — 不再是"写测试防止回归",而是"用测试定义正确行为,驱动修复"。为变现闭环的每个新/修改 API 编写契约测试,定义期望行为,驱动实现修复。 + +**API 契约测试为主,E2E 烟雾测试为辅** — API 测试反馈快(秒级)、定位精确、维护成本低;E2E 只覆盖 1-2 个最关键路径作为安全网。原 QA 计划的 3 个完整 E2E 套件(U3/U4/U5)缩减为 1-2 个烟雾测试。 + +**获客优先于数据护城河** — 行业基准数据是长期壁垒,但在没有验证付费意愿之前投入风险过高。先用免费 GEO 健康分验证需求,再积累数据资产。 + +**半自动执行优先于全自动** — 品牌方不会信任 AI 完全自主发布内容。执行闭环采用"AI 生成→人工审核→确认发布"模式。 + +**中国 AI 平台生态为差异化核心** — 海外 SEO/GEO 工具无法覆盖文心、Kimi、通义、豆包等中国 AI 平台,这是天然护城河。 + +**性能基线、安全扫描、全面 E2E 推迟到上线后** — 这些投入在有用户之前无法产生实际价值。上线后根据真实数据和用户反馈决定优先级。 + +--- + +## Requirements + +### 变现闭环核心修复 + +R1. 修复诊断系统:实现自动数据采集(AI 平台查询+网站爬取+CitationRecord 分析),让 `GEODiagnosisService.diagnose()` 接受真实数据输入,产出非零有差异的评分 + +R2. 添加 AI 内容生成阶段:在 ContentPipeline 前端添加 Stage 0(AI Generator),基于诊断结果+RAG 知识库+LLM 生成初稿,后续阶段不变 + +R3. 接入真实支付:微信支付+支付宝双通道,支付 Webhook 处理订阅状态更新,功能限制中间件按套餐控制使用量 + +### 获客与转化 + +R4. 建设免费 GEO 健康分公开页面:输入品牌名→30 秒出报告(综合评分+3 核心维度+3 竞品对比),无需注册,结果缓存 24 小时 + +R5. 重设计 Onboarding 流程:第一步为"查看 GEO 健康分"而非填表,在查看详细建议/执行优化/查看归因报告时嵌入升级提示 + +R6. 执行闭环:内容分发集成微信(半自动)/知乎(API)/头条(API),发布流程"AI 生成→人工审核→确认发布" + +R7. 效果归因系统:追踪已发布内容的 AI 引用变化,2-4 周归因窗口,ROI 计算+A/B 对比报告 + +### API 契约测试 + +R8. 为诊断 API 编写契约测试:POST /api/v1/diagnosis 触发异步诊断,GET /api/v1/diagnosis/{id}/result 返回非零评分 + +R9. 为公开健康分 API 编写契约测试:GET /api/v1/public/health-score?brand=XXX 无需认证,返回综合评分+3 维度+竞品对比 + +R10. 为内容生成 API 编写契约测试:POST /api/v1/content/generate 基于诊断+RAG 生成内容 + +R11. 为支付 API 编写契约测试:创建订单→支付回调→订阅激活的完整链路 + +R12. 编写跨步骤集成测试:品牌创建→诊断→方案→内容生成→效果追踪的完整链路 + +### E2E 烟雾测试 + +R13. 编写 1 个获客路径 E2E 烟雾测试:访问健康分页面→输入品牌名→看到报告→点击注册 + +R14. 编写 1 个核心流程 E2E 烟雾测试:登录→创建品牌→触发诊断→查看诊断结果 + +### 测试基础设施(延续) + +R15. 完成后端测试目录统一验证(原 QA 计划 U1) + +R16. 建立共享 fixture 体系(原 QA 计划 U2):API 契约测试完成后迁移到共享 fixture + +--- + +## Key Flows + +- F1. 变现闭环主流程 + - **Trigger:** 品牌市场人员访问平台 + - **Actors:** 品牌市场人员, 平台系统 + - **Steps:** (1) 输入品牌名查看免费 GEO 健康分 (2) 看到震撼的评分和竞品对比 (3) 注册查看详细建议 (4) 升级付费版执行优化 (5) AI 生成优化内容 (6) 人工审核确认发布 (7) 2-4 周后查看效果归因报告 (8) 续费 + - **Outcome:** 用户完成从"不知道 GEO"到"付费执行优化"到"续费"的完整闭环 + +- F2. API 契约驱动修复流程 + - **Trigger:** 开始修复某个断裂环节 + - **Actors:** 开发者 + - **Steps:** (1) 编写该环节的 API 契约测试,定义期望的输入/输出 (2) 运行测试,确认失败(红色) (3) 修复后端实现 (4) 运行测试,确认通过(绿色) (5) 编写下一个环节的契约测试 (6) 重复直到全链路通过 + - **Outcome:** 核心流程每步都有契约保护,断裂点被精确定位和修复 + +--- + +## Acceptance Examples + +- AE1. **诊断系统修复** — Covers R1, R8 + - **Given:** 品牌已创建 + - **When:** 调用诊断 API 并传入品牌 ID + - **Then:** 返回非零评分,且不同品牌评分有差异;免费版返回 3 核心维度,付费版返回全部 6 维度 + +- AE2. **免费健康分获客** — Covers R4, R9 + - **Given:** 未注册用户访问健康分页面 + - **When:** 输入品牌名 + - **Then:** 30 秒内生成包含综合评分、3 维度、3 竞品对比的报告;24 小时内重复查询返回缓存 + +- AE3. **AI 内容生成** — Covers R2, R10 + - **Given:** 付费用户有诊断结果和知识库 + - **When:** 触发内容生成 + - **Then:** 基于诊断问题+RAG 上下文生成针对性优化内容,内容通过 Pipeline 校验 + +- AE4. **支付闭环** — Covers R3, R11 + - **Given:** 用户选择升级套餐 + - **When:** 完成支付 + - **Then:** 订阅状态更新为 active,功能限制中间件解锁对应功能 + +- AE5. **E2E 烟雾测试** — Covers R13, R14 + - **Given:** 前后端服务运行 + - **When:** 执行 Playwright E2E 烟雾测试 + - **Then:** 获客路径和核心流程路径通过 + +--- + +## Success Criteria + +- 诊断系统产出真实非零评分,不同品牌有差异 +- 免费 GEO 健康分页面可公开访问,30 秒内出报告 +- AI 内容生成基于诊断结果产出可用内容 +- 支付闭环可走通(创建订单→支付→激活→功能解锁) +- 核心流程 API 契约测试全部通过 +- 2 个 E2E 烟雾测试通过 + +--- + +## Scope Boundaries + +**Deferred for later:** + +- 全面 E2E 测试覆盖(原 QA 计划 U3/U4/U5 → 上线后逐步扩展) +- 性能基线测试(原 QA 计划 U7 → 上线后有真实负载后建立) +- CI 安全扫描(原 QA 计划 U8 → 上线后集成) +- Alembic 迁移验证(原 QA 计划 U8 → 上线后添加) +- 前端组件测试(原 QA 计划 U10 → 上线后扩展) +- 行业 GEO 基准数据积累 +- API 市场开放和第三方集成 +- 白标报告和专属客户成功经理 +- 多语言和国际 AI 平台支持 + +**Outside this product's identity:** + +- 传统 SEO 工具功能(关键词排名、外链分析等) +- 广告投放和营销自动化 +- 社交媒体管理 +- 基础设施级别的渗透测试 + +--- + +## Dependencies / Assumptions + +- 变现闭环实施计划已存在:`docs/plans/2026-05-31-003-feat-geo-monetization-closed-loop-plan.md` +- 诊断自动数据采集依赖现有 AI 平台适配器(`backend/app/services/ai_engine/`) +- AI 内容生成依赖 LLM API Key 可用 +- 支付集成需要微信支付和支付宝商户号 +- 知乎/头条 API 发布需要 OAuth 授权和 API Key +- API 契约测试使用 httpx AsyncClient + 内存 SQLite,不依赖外部服务 +- E2E 烟雾测试需要前后端服务同时运行 +- 假设品牌方看到 GEO 健康分后会意识到问题并产生付费意愿(未验证,需 MVP 验证) + +--- + +## Outstanding Questions + +**Deferred to Planning:** + +- 效果归因的具体技术实现方案(基于引用检测还是基于评分变化) +- 付费版定价是否需要调整(当前 ¥199/599/1999 是否匹配价值感知) +- 微信公众号 OAuth 授权流程的具体实现方案 +- 知乎和头条 API 的申请和审核周期 +- 诊断自动数据采集的网站爬取可能遇到反爬机制的降级方案 + +--- + +## Sources / Research + +- 变现闭环需求文档:`docs/brainstorms/2026-05-31-geo-platform-monetization-closed-loop-requirements.md` +- 变现闭环实施计划:`docs/plans/2026-05-31-003-feat-geo-monetization-closed-loop-plan.md` +- 质量保障需求文档:`docs/brainstorms/2026-05-31-geo-platform-next-phase-quality-requirements.md` +- 质量保障实施计划:`docs/plans/2026-05-31-002-test-quality-assurance-system-plan.md` +- 诊断系统根因:`backend/app/api/diagnosis.py` 第 75-77 行(空参调用) +- 内容 Pipeline:`backend/app/services/content/content_pipeline.py`(4 阶段,无 AI 生成) +- 订阅服务:`backend/app/services/subscription.py`(模拟支付) +- 现有后端测试:`backend/tests/` (~90 个文件) +- 现有 E2E 测试:`frontend/e2e/tests/` (7 个文件) diff --git a/docs/brainstorms/2026-05-31-geo-platform-acceptance-audit-requirements.md b/docs/brainstorms/2026-05-31-geo-platform-acceptance-audit-requirements.md new file mode 100644 index 0000000..8e78178 --- /dev/null +++ b/docs/brainstorms/2026-05-31-geo-platform-acceptance-audit-requirements.md @@ -0,0 +1,184 @@ +--- +date: "2026-05-31" +topic: "geo-platform-acceptance-audit" +--- + +## Summary + +GEO 平台全项目验收审计方案,定义功能完整性、安全性、性能、代码质量四个维度的验收标准与检查清单,包含已发现 42 项问题的分级修复计划,以及四类项目文档的更新规范,确保平台达到可交付质量水平。 + +## Problem Frame + +GEO 平台经过多轮迭代开发,已实现 33 个 API 路由模块、8 个 Agent、35 个数据模型、25+ 前端页面。但快速迭代带来了技术债务:4 个新 Agent 的前端 API 客户端尚未创建、部分 API 存在认证绕过风险、N+1 查询影响性能、API 文档覆盖率仅 15%。在交付前需要系统性的验收审计,识别并修复阻塞问题,同步更新文档使项目状态与代码一致。 + +--- + +## Key Decisions + +**审计维度覆盖全部四项而非选择性覆盖** — 功能完整性、安全性、性能、代码质量相互关联,部分安全漏洞隐藏在功能实现中,性能问题影响功能可用性,选择性覆盖会遗漏跨维度问题。 + +**Critical/High 问题作为验收阻塞项** — 20 个 Critical + High 问题必须在验收通过前修复,Medium/Low 问题可记录为技术债务在后续迭代处理。 + +**文档更新纳入验收流程** — 文档与代码不同步是当前最大运维风险(API 文档覆盖率仅 15%),文档更新不是附加任务而是验收的必要条件。 + +**自审计而非第三方审计** — 本方案定位为交付前自审计,确保基本质量门槛;渗透测试等深度安全审计作为后续独立项目规划。 + +--- + +## Requirements + +### 功能完整性 + +R1. 所有 33 个 API 路由模块必须有对应的前端 API 客户端模块,且接口签名与后端一致 + +R2. 8 个 Agent 均可通过 standalone 模式独立启动并执行其声明的 supported_tasks + +R3. 用户注册→登录→Onboarding→品牌创建→诊断→方案生成→内容生成→效果追踪的完整业务闭环可端到端走通 + +R4. 4 个新 Agent(MonitorAgent、SchemaAdvisor、CompetitorAnalyzer、TrendAgent)的前端 API 客户端模块必须创建,覆盖其全部 API 端点 + +R5. 前端所有页面的 API 调用参数(字段名、类型、顺序)与后端路由定义一致,无运行时 400/500 错误 + +R6. GEO 方案自动生成流程(诊断→策略→方案→行动)完整可用,竞品为可选输入 + +R7. 品牌评分(5 维度 V2 评分体系)在品牌详情页和仪表盘均可正确展示 + +### 安全性 + +R8. 所有业务 API 端点(除 `/health`、`/ready`、`/auth/login`、`/auth/register` 外)必须要求认证,不存在认证绕过路径 + +R9. analytics 路由的认证检查必须与业务路由一致,不能因中间件顺序或装饰器缺失而跳过 + +R10. API Key 通过加密存储(key_encryption.py),密钥不在日志、响应体、前端代码中明文暴露 + +R11. CORS 配置在生产环境必须限制 `allow_origins` 为具体域名,不能使用通配符 `*` + +R12. 用户输入必须经过 Pydantic Schema 校验,不存在未校验的请求参数直接进入数据库查询 + +R13. SQL 查询使用 SQLAlchemy ORM 参数化查询,不存在字符串拼接 SQL + +### 性能 + +R14. 品牌评分数据获取逻辑提取为共享服务(BrandScoringDataService),消除 dashboard.py 和 strategy.py 中的重复查询 + +R15. 消除 N+1 查询模式:品牌列表页、仪表盘、策略页等高频访问路径不得出现循环内单条查询 + +R16. AI 引擎批量查询使用 `queryBatch` 端点而非逐条查询 + +R17. Redis 缓存层对热点数据(品牌评分、仪表盘统计)生效,缓存命中率作为可观测指标 + +### 代码质量 + +R18. 前端 `agentsApi.enable`/`agentsApi.disable` 死代码必须移除或替换为实际可用的实现 + +R19. 前端 `onboardingApi.getCompetitorRecommendations` 参数签名必须与后端一致 + +R20. `models/__init__.py` 中标注为"重构后遗留"的模型导入必须验证其与数据库迁移的一致性 + +R21. `SEOOptimizer` 命名与实际功能(GEO 优化)不一致的问题必须在代码或文档中明确标注 + +R22. `content.py`(内容生产)与 `contents.py`(内容管理)的命名混淆必须在 API 文档中明确区分 + +R23. 前端 `lib/api/index.ts` 聚合导出必须覆盖所有 API 客户端模块,消除遗漏 + +--- + +## Key Flows + +- F1. 验收审计执行流程 + - **Trigger:** 审计方案文档确认后启动 + - **Actors:** 开发者、审计执行者 + - **Steps:** (1) 按维度执行检查清单 (2) 记录每项的通过/不通过状态 (3) Critical/High 不通过项进入修复流程 (4) 修复后回归验证 (5) 全部阻塞项通过后进入文档更新阶段 (6) 文档更新完成后签发验收报告 + - **Outcome:** 验收通过/不通过判定 + 问题修复记录 + 更新后的文档 + +- F2. 问题修复流程 + - **Trigger:** 审计检查发现 Critical 或 High 问题 + - **Actors:** 开发者 + - **Steps:** (1) 问题分级确认 (2) 按优先级排序(Critical > High > Medium > Low)(3) 修复实现 (4) 回归验证(确认修复未引入新问题)(5) 更新审计检查清单状态 + - **Outcome:** 问题状态从"不通过"变为"通过" + +- F3. 文档更新流程 + - **Trigger:** 代码修复完成后 + - **Actors:** 开发者 + - **Steps:** (1) 更新项目文档(README、架构图、API 文档)(2) 更新 AI 上下文文件(CLAUDE.md、AGENTS.md)(3) 更新设计文档(模块说明、流程图)(4) 更新部署文档(环境变量、Docker 配置)(5) 交叉验证文档与代码一致性 + - **Outcome:** 四类文档与代码实现完全一致 + +--- + +## Acceptance Examples + +- AE1. **认证绕过验证** — Covers R8, R9 + - **Given:** 未认证的 HTTP 请求 + - **When:** 请求 `/api/v1/analytics` 端点 + - **Then:** 返回 401 Unauthorized,不返回业务数据 + +- AE2. **新 Agent 前端可用性** — Covers R4 + - **Given:** 前端 API 客户端模块 `monitoring.ts`、`competitor-analysis.ts`、`schema-advisor.ts`、`trends.ts` 已创建 + - **When:** 前端页面调用这些模块的方法 + - **Then:** 请求正确到达后端对应端点,参数类型和顺序匹配,无 400/500 错误 + +- AE3. **N+1 查询消除** — Covers R14, R15 + - **Given:** 仪表盘页面加载 + - **When:** 后端处理 `/api/v1/dashboard` 请求 + - **Then:** 品牌评分数据通过共享服务一次查询获取,SQL 日志中不出现循环内单条查询模式 + +- AE4. **CORS 生产配置** — Covers R11 + - **Given:** 生产环境部署 + - **When:** CORS 中间件初始化 + - **Then:** `allow_origins` 不包含通配符 `*`,仅允许配置的域名 + +- AE5. **业务闭环端到端** — Covers R3 + - **Given:** 新注册用户完成 Onboarding + - **When:** 依次执行诊断→查看诊断报告→生成 GEO 方案→执行内容生成→查看效果追踪 + - **Then:** 每一步均可成功完成,数据在步骤间正确传递 + +--- + +## Success Criteria + +- 全部 6 个 Critical 问题修复并通过回归验证 +- 全部 14 个 High 问题修复并通过回归验证 +- API 文档覆盖率从 15% 提升至 80% 以上(至少覆盖 27/33 个路由模块) +- AI 上下文文件(CLAUDE.md / AGENTS.md)创建完成,准确反映项目当前架构 +- 四类文档更新完成且与代码实现一致 +- 端到端业务闭环无阻塞错误 + +--- + +## Scope Boundaries + +**Deferred for later:** + +- 自动化 E2E 测试框架搭建 +- 性能压测与负载测试方案 +- 第三方安全审计(渗透测试) +- CI/CD 流水线完善 +- 前端单元测试和集成测试覆盖 + +**Outside this product's identity:** + +- 基础设施级别的安全审计(网络层、K8s 配置) +- 用户体验走查和可用性测试 +- 多语言/国际化验证 + +--- + +## Dependencies / Assumptions + +- 后端服务可正常启动(PostgreSQL + Redis 可用) +- 前端开发服务器可正常启动 +- 至少一个 LLM Provider 的 API Key 可用(用于验证 Agent 功能) +- 当前审计基于代码静态分析 + 架构审查,不包含运行时动态分析 +- 数据库迁移状态与 models 定义一致(需在验收时验证) + +--- + +## Sources / Research + +- 审计扫描结果:Critical 6、High 14、Medium 14、Low 8 共 42 项问题 +- 项目架构文档:`docs/01-项目概览/architecture.md` +- Agent 框架协议:`docs/02-模块说明/agent-protocol.md` +- GEO 工作流分析:`docs/02-模块说明/geo-workflow-analysis.md` +- 新 Agent 实现计划:`docs/plans/2026-05-28-001-feat-geo-platform-new-agents-plan.md` +- 环境变量模板:`.env.example` +- Docker 部署配置:`docker-compose.yml` diff --git a/docs/brainstorms/2026-05-31-geo-platform-monetization-closed-loop-requirements.md b/docs/brainstorms/2026-05-31-geo-platform-monetization-closed-loop-requirements.md new file mode 100644 index 0000000..43dc50b --- /dev/null +++ b/docs/brainstorms/2026-05-31-geo-platform-monetization-closed-loop-requirements.md @@ -0,0 +1,127 @@ +--- +date: 2026-05-31 +topic: geo-platform-monetization-closed-loop +--- + +## Summary + +将现有GEO平台改造为以「免费GEO健康分」为获客入口、以「诊断→建议→执行→效果归因」为付费闭环的商业模式,同时聚焦中国AI平台生态构建差异化壁垒。 + +## Problem Frame + +品牌方市场团队目前完全没有关注过自己在AI搜索引擎中的表现。当ChatGPT、Kimi、DeepSeek等AI平台回答用户提问时,品牌是否被引用、引用是否正面、排名如何——这些直接影响品牌认知的数据,品牌方一无所知。现有的GEO平台已经具备了6维度诊断、内容生成、知识库等能力,但缺少三个关键环节:让目标客户意识到问题的获客机制、从诊断到效果验证的执行闭环、以及证明优化有效的效果归因。没有这三者,平台无法完成商业变现闭环。 + +## Key Decisions + +**获客优先于数据护城河。** 行业基准数据(方案C)是长期壁垒,但在没有验证付费意愿之前投入,风险过高。先用免费GEO健康分验证需求,再积累数据资产。 + +**半自动执行优先于全自动。** 品牌方不会信任AI完全自主发布内容。执行闭环采用"AI生成→人工审核→确认发布"模式,而非全自动发布。 + +**中国AI平台生态为差异化核心。** 海外SEO工具(Ahrefs/Semrush)和GEO工具(Profound/peecAI)无法覆盖文心、Kimi、通义、豆包等中国AI平台,这是天然护城河。 + +**免费诊断的"震撼值"驱动转化。** 免费GEO健康分必须足够震撼(竞品对比、行业排名、具体问题),才能让不知道GEO的品牌方产生紧迫感和付费意愿。 + +**免费诊断采用频率+深度双限制控制成本。** 免费版限制诊断频率(每个品牌每天1次,详细报告每7天1次)和诊断深度(免费版只看3个核心维度和3个竞品,付费版看全部6维度和更多竞品),在保持震撼值的同时控制API调用成本。 + +**三平台同步推进内容分发集成。** 微信公众号、知乎、头条三个平台同步推进发布集成,微信公众号采用半自动模式(生成内容+复制粘贴引导),知乎和头条利用较开放的API实现更自动化的发布。 + +## Actors + +- A1. **品牌市场人员** — 核心用户,需要了解品牌在AI搜索中的表现并采取行动 +- A2. **品牌决策者** — 查看GEO健康分报告后决定是否付费的管理层 +- A3. **平台系统** — 自动执行诊断、监控、归因等后台任务 + +## Requirements + +### 获客层:GEO健康分驱动增长 + +- R1. 提供无需注册的即时GEO健康分诊断,输入品牌名即可在30秒内生成可视化报告 +- R2. 免费报告包含:综合GEO评分、3个核心维度得分、与3个竞品的对比、最严重的3个问题概要;频率限制为每个品牌每天1次,详细报告每7天1次 +- R2b. 付费版解锁全部6维度诊断、更多竞品对比、无频率限制 +- R3. 详细修复建议和执行方案需注册后查看,注册免费 +- R4. 执行修复建议(内容生成、知识库搭建等)需升级付费版 +- R5. 诊断报告支持生成可分享链接和PDF,便于品牌市场人员向决策者汇报 + +### 转化层:从认知到付费 + +- R6. 重设计Onboarding流程,第一步为"查看你的GEO健康分"而非填表 +- R7. 注册后自动触发GEO变化周报邮件订阅,作为持续触达和转化手段 +- R8. 付费版升级触发点嵌入在以下场景:查看详细建议时、执行优化操作时、查看效果归因报告时 +- R9. 提供GEO评分提升承诺和效果保障机制,降低付费决策门槛 + +### 执行闭环层:诊断→建议→执行→效果归因 + +- R10. 打通内容分发到实际发布的链路,同步支持微信公众号(半自动:生成内容+复制粘贴引导)、知乎(API发布)、头条(API发布)三个平台 +- R11. 发布流程采用"AI生成→人工审核→确认发布"的半自动模式 +- R12. 建设效果归因系统,追踪已发布内容是否被AI搜索引擎引用,计算GEO ROI +- R13. 提供优化前后的A/B对比报告,量化GEO优化的实际效果 +- R14. 设定合理的归因时间窗口(建议2-4周),并在此期间持续监控变化趋势 + +### 差异化层:中国AI平台生态壁垒 + +- R15. 持续覆盖中国主流AI平台(文心、Kimi、通义、豆包、元宝、清言、天工等),确保海外工具无法替代 +- R16. 针对中国AI平台的引用偏好建立规则库,持续更新各平台的引用模式变化 +- R17. 知识库针对GEO场景深度优化,自动从品牌官网、产品文档中提取AI搜索引擎偏好的信息结构 + +## Key Flows + +- F1. 免费获客转化流 + - **Trigger:** 品牌市场人员访问平台首页 + - **Actors:** A1, A3 + - **Steps:** 输入品牌名 → 系统即时生成GEO健康分报告 → 展示综合评分和竞品对比 → 提示"查看详细修复建议"需注册 → 注册后查看建议概要 → 执行建议需升级付费版 + - **Outcome:** 用户完成从"不知道GEO"到"付费执行优化"的转化 + +- F2. 执行闭环流 + - **Trigger:** 用户查看GEO优化建议并决定执行 + - **Actors:** A1, A3 + - **Steps:** 选择优化建议 → AI生成优化内容 → 人工审核确认 → 确认发布到目标平台 → 系统开始追踪该内容的AI引用情况 → 2-4周后生成效果归因报告 + - **Outcome:** 用户看到GEO优化的量化效果,形成续费动力 + +- F3. 效果归因流 + - **Trigger:** 内容发布后系统自动启动追踪 + - **Actors:** A3 + - **Steps:** 记录发布时的基线数据 → 定期检查AI搜索引擎是否引用该内容 → 计算引用变化和GEO评分变化 → 生成归因报告(优化前vs优化后) → 通知用户查看效果 + - **Outcome:** 用户获得GEO ROI的量化证明 + +## Acceptance Examples + +- AE1. **Covers R1, R2, R5.** Given 一个未注册用户访问平台, When 输入品牌名"某某科技", Then 30秒内生成包含综合评分、6维度得分、竞品对比和可分享链接的GEO健康分报告 +- AE2. **Covers R3, R4.** Given 一个已注册免费用户查看GEO健康分报告, When 点击"查看详细修复建议", Then 显示建议概要;When 点击"执行优化", Then 提示需升级付费版 +- AE3. **Covers R10, R11, R12.** Given 一个付费用户选择执行某条GEO优化建议, When AI生成优化内容后, Then 进入人工审核环节;When 用户确认发布, Then 内容发布到目标平台且系统开始追踪AI引用变化 +- AE4. **Covers R13, R14.** Given 一条优化内容已发布3周, When 系统检测到该内容被Kimi引用, Then 生成归因报告显示优化前后的GEO评分变化和引用数据对比 + +## Success Criteria + +- 免费GEO健康分报告的完成率(输入品牌名到查看报告)> 60% +- 从免费报告到注册的转化率 > 15% +- 从注册到付费的转化率 > 5% +- 付费用户中能看到效果归因报告的比例 > 70% +- 付费用户月续费率 > 85% + +## Scope Boundaries + +**Deferred for later:** +- 行业GEO基准数据积累和行业白皮书发布 +- API市场开放和第三方集成 +- 白标报告和专属客户成功经理功能 +- 多语言和国际AI平台支持 + +**Outside this product's identity:** +- 传统SEO工具功能(关键词排名、外链分析等) +- 广告投放和营销自动化 +- 社交媒体管理 + +## Dependencies / Assumptions + +- 假设品牌方看到GEO健康分后会意识到问题并产生付费意愿(未验证,需通过MVP验证) +- 假设GEO优化效果可在2-4周内被观测到(需验证,不同AI平台更新周期不同) +- 假设中国主流内容平台(微信公众号、知乎、头条)的API开放程度足以支持半自动发布(需验证) +- 依赖现有6维度GEO诊断系统的准确性和稳定性 +- 依赖现有8个Agent框架的可扩展性 + +## Outstanding Questions + +**Deferred to Planning:** +- 效果归因的具体技术实现方案(基于引用检测还是基于评分变化) +- 付费版定价是否需要调整(当前¥199/599/1999是否匹配价值感知) +- GEO变化周报的邮件发送基础设施选型 diff --git a/docs/brainstorms/2026-05-31-geo-platform-next-phase-quality-requirements.md b/docs/brainstorms/2026-05-31-geo-platform-next-phase-quality-requirements.md new file mode 100644 index 0000000..0826d79 --- /dev/null +++ b/docs/brainstorms/2026-05-31-geo-platform-next-phase-quality-requirements.md @@ -0,0 +1,167 @@ +--- +date: "2026-05-31" +topic: "geo-platform-next-phase-quality" +--- + +## Summary + +GEO 平台下一阶段质量保障体系,采用双轨并行策略:轨道一统一测试基础设施(合并分裂的测试目录、完善 CI 流水线、建立共享 fixture 体系),轨道二同步推进核心业务 E2E 测试和四项专项测试(性能基准、安全扫描、数据库迁移验证、Agent 端到端测试),确保核心业务流程尽早获得回归保护。 + +## Problem Frame + +GEO 平台经过多轮迭代已具备 33 个 API、8 个 Agent、25+ 前端页面,但测试覆盖严重不足:后端两套测试目录分裂导致 17 个测试被 CI 忽略,前端仅 13 个单元测试且零组件测试,E2E 仅覆盖登录流程,核心业务(品牌创建→诊断→方案→内容→监控)完全无回归保护。CI 流水线不运行 E2E 测试,无安全扫描,无性能基线。在用户开始使用前,必须建立质量保障体系防止功能回归。 + +--- + +## Key Decisions + +**双轨并行而非分层递进** — 基础设施统一和 E2E 测试编写同步推进,因为两者几乎不涉及相同文件,并行没有实质冲突。E2E 测试初期用独立 setup,后续迁移到共享 fixture 是低成本重构。如果等基础设施全部就绪再写 E2E,核心流程的回归保护会延迟数周。 + +**E2E 测试仅覆盖 Chromium** — 先在一个浏览器上稳定运行,跨浏览器扩展作为后续迭代。Playwright 已配置 3 浏览器但当前 E2E 用例太少,扩展浏览器覆盖的 ROI 不高。 + +**性能测试先建基线再设阈值** — 在没有历史数据时设定 SLA 容易过严或过松。第一轮只采集数据建立基线,第二轮根据基线设定阈值和告警。 + +**安全扫描集成到 CI 而非独立流程** — bandit 和 npm audit 作为 CI 步骤运行,在 PR 级别就暴露问题,而不是等到发布前才扫描。 + +--- + +## Requirements + +### 轨道一:测试基础设施 + +R1. 将 `geo/tests/` 下的 17 个测试迁移至 `backend/tests/` 对应子目录,合并两套 conftest.py 的 fixture,删除旧目录 + +R2. CI 中 `pytest tests/` 命令能发现并运行全部后端测试(迁移后应 ≥77 个) + +R3. 将 `pytest-cov` 正式加入 `backend/requirements.txt`,不再依赖 CI 中临时安装 + +R4. 建立共享 fixture 体系:数据库会话(含自动 rollback)、认证 mock(JWT token + auth headers)、测试用户创建、httpx AsyncClient + +R5. 前端 vitest 覆盖范围扩展至 `lib/api/` 全部 27 个模块和关键页面组件 + +R6. 更新 `docs/03-开发指南/testing.md` 使其与实际目录结构、CI 配置、fixture 体系一致 + +### 轨道二:核心业务 E2E 测试 + +R7. 编写用户注册→登录→Onboarding→品牌创建的完整 E2E 测试 + +R8. 编写诊断→查看诊断报告→生成 GEO 方案的 E2E 测试 + +R9. 编写内容生成→查看内容→效果追踪的 E2E 测试 + +R10. E2E 测试在 CI 中运行(至少 Chromium),使用 PostgreSQL 和 Redis service container + +R11. E2E 测试失败时自动截图和录制视频,便于排查 + +### 专项测试 + +R12. 为 5-10 个高频 API 端点建立性能基线(p50/p95/p99 响应时间),首轮只采集不设阈值 + +R13. CI 中集成 bandit(Python 安全扫描)和 npm audit(Node.js 依赖安全检查),PR 级别阻断高危漏洞 + +R14. CI 中添加 Alembic 迁移验证步骤:`alembic upgrade head` 在空数据库上成功执行 + +R15. 编写 Agent 框架端到端测试:任务创建→分发→执行→结果查询的完整链路(至少覆盖 CitationDetector 和 ContentGenerator) + +### 前端组件测试 + +R16. 为 5 个关键页面组件编写 React Testing Library 测试:Dashboard、品牌详情、诊断页、策略页、内容编辑器 + +R17. 为 4 个新 Agent 的前端 API 客户端模块编写单元测试:monitoring.ts、competitor-analysis.ts、schema-advisor.ts、trends.ts + +--- + +## Key Flows + +- F1. 双轨并行执行流程 + - **Trigger:** 需求文档确认后启动 + - **Actors:** 开发者 + - **Steps:** (1) 轨道一:统一测试目录→合并 fixture→修复 CI→扩展 vitest 覆盖 (2) 轨道二:编写核心 E2E→集成 CI E2E 步骤→编写专项测试 (3) 两条轨道完成后汇合:E2E 测试迁移到共享 fixture (4) 更新文档 + - **Outcome:** 完整的质量保障体系,CI 全量运行,核心流程有回归保护 + +- F2. E2E 测试编写流程 + - **Trigger:** 开始编写某个业务流程的 E2E 测试 + - **Actors:** 开发者 + - **Steps:** (1) 定义用户旅程步骤 (2) 编写 Playwright 测试用独立 setup (3) 本地验证通过 (4) 提交并在 CI 中验证 (5) 后续迁移到共享 fixture + - **Outcome:** 可在 CI 中稳定运行的 E2E 测试 + +--- + +## Acceptance Examples + +- AE1. **测试目录统一** — Covers R1, R2 + - **Given:** `geo/tests/` 下有 17 个测试文件 + - **When:** 迁移完成并运行 `cd backend && pytest tests/` + - **Then:** 全部测试被发现且通过,`geo/tests/` 目录已删除 + +- AE2. **核心业务 E2E** — Covers R7, R8, R9 + - **Given:** 后端服务 + PostgreSQL + Redis 运行 + - **When:** 执行 Playwright E2E 测试 + - **Then:** 注册→登录→Onboarding→品牌创建→诊断→方案→内容生成→效果追踪 全流程通过 + +- AE3. **CI 安全扫描** — Covers R13 + - **Given:** PR 中引入了有已知漏洞的依赖 + - **When:** CI 运行 PR Check + - **Then:** bandit 或 npm audit 报告高危漏洞,PR 被标记为检查失败 + +- AE4. **性能基线** — Covers R12 + - **Given:** 性能测试首次运行 + - **When:** 对 `/api/v1/dashboard/stats` 发送 100 次请求 + - **Then:** 采集 p50/p95/p99 响应时间并保存为基线数据,不设阈值不阻断 + +- AE5. **迁移验证** — Covers R14 + - **Given:** 新增了 Alembic 迁移脚本 + - **When:** CI 在空数据库上执行 `alembic upgrade head` + - **Then:** 迁移成功执行,无报错 + +--- + +## Success Criteria + +- 后端全部测试(≥77 个)在 CI 中通过,无目录分裂 +- 核心业务 E2E 测试(≥3 条用户旅程)在 CI 中稳定运行 +- 前端 vitest 覆盖率从当前极低水平提升至 lib/api/ 模块 80%+ +- CI 流水线包含:lint + 单元测试 + E2E 测试 + 安全扫描 + 迁移验证 +- 5-10 个高频 API 端点有性能基线数据 +- Agent 框架端到端测试覆盖至少 2 个 Agent +- testing.md 与实际项目结构一致 + +--- + +## Scope Boundaries + +**Deferred for later:** + +- 跨浏览器 E2E 测试(Firefox / WebKit) +- 覆盖率报告上传第三方服务(Codecov / Coveralls) +- 测试数据工厂(factory-boy / faker)— 当前用 fixture 足够 +- PR 评论中显示覆盖率变化 +- 生产环境监控告警集成 + +**Outside this product's identity:** + +- 新业务功能开发 +- UI/UX 改进和设计优化 +- 基础设施级别的渗透测试 +- 移动端适配测试 + +--- + +## Dependencies / Assumptions + +- CI 环境(GitHub Actions)支持 PostgreSQL 和 Redis service container +- E2E 测试需要后端服务可启动(所有依赖可用) +- Agent E2E 测试需要至少一个 LLM Provider 的 API Key 可用(或使用 mock) +- 性能基线数据需要在相对稳定的环境下采集,避免 CI 共享 runner 的噪声 +- 当前 `geo/tests/conftest.py` 中的 fixture(async_client、auth_token 等)可正确迁移 + +--- + +## Sources / Research + +- 现有测试配置:`backend/pyproject.toml`、`frontend/vitest.config.ts`、`frontend/playwright.config.ts` +- CI 配置:`.github/workflows/ci.yml`、`.github/workflows/pr-check.yml` +- 测试策略文档:`docs/03-开发指南/testing.md` +- 现有 E2E 测试:`frontend/e2e/tests/`(7 个文件) +- 现有后端测试:`backend/tests/`(~60 个)+ `geo/tests/`(17 个) +- 现有前端测试:`frontend/__tests__/`(13 个) diff --git a/docs/plans/2026-05-28-001-feat-geo-platform-new-agents-plan.md b/docs/plans/2026-05-28-001-feat-geo-platform-new-agents-plan.md new file mode 100644 index 0000000..7ec47f4 --- /dev/null +++ b/docs/plans/2026-05-28-001-feat-geo-platform-new-agents-plan.md @@ -0,0 +1,384 @@ +# GEO 平台新增 Agent 实现计划 + +**Status:** Draft +**Type:** feat +**Created:** 2026-05-28 +**Updated:** 2026-05-28 + +## Summary + +为 GEO 平台新增 4 个专业 Agent,构建完整的监测-分析-优化闭环: + +1. **MonitorAgent** — 效果追踪:定期检测内容发布后的引用变化 +2. **SchemaAdvisor** — Schema优化:根据诊断结果推荐具体的 Schema 标记方案 +3. **CompetitorAnalyzer** — 竞品分析:深度分析竞品内容策略、引用模式 +4. **TrendAgent** — 趋势洞察:分析行业 AI 搜索趋势、热点话题 + +--- + +## Problem Frame + +当前 GEO 平台已有 4 个 Agent(CitationDetector、ContentGenerator、DeAIAgent、GEOOptimizer),覆盖检测→生成→去AI化→优化链路。但缺少: + +- **效果追踪**:内容发布后无法自动追踪引用变化 +- **Schema 优化**:诊断发现 Schema 问题但没有 Agent 提供具体方案 +- **竞品深度分析**:仅做引用检测,缺乏策略层面的竞品分析 +- **趋势洞察**:无法感知行业 AI 搜索趋势和热点话题 + +这 4 个 Agent 的缺失导致 GEO 平台只能发现问题,无法主动建议行动方向。 + +--- + +## Scope Boundaries + +### In Scope +- 4 个新 Agent 的完整实现 +- Agent 注册和调度集成 +- 相应的数据模型和 API +- 前端展示(可选,取决于 API 完备性) + +### Out of Scope +- 不实现前端页面(只提供 API,供现有页面调用) +- 不实现自动执行闭环(内容发布后的自动追踪需要分发系统配合) +- TrendAgent 的第三方数据源集成(仅使用现有 AI 引擎返回的引用数据) + +--- + +## Key Technical Decisions + +### KTD-1: MonitorAgent 触发方式 + +**Decision:** 采用「定时任务 + 手动触发」双模式 + +**Rationale:** 自动追踪需要内容分发系统配合,目前分发系统仅支持手动发布。定时任务模式可覆盖已有内容的手动追踪需求。手动触发作为 API 入口,可被前端页面直接调用。 + +**Alternatives considered:** +- 仅定时任务:无法满足用户即时查询需求 +- 仅手动触发:无法实现自动化效果追踪 + +### KTD-2: SchemaAdvisor 生成策略 + +**Decision:** 采用「规则模板 + LLM 增强」混合模式 + +**Rationale:** Schema 标记有明确的规范(JSON-LD 格式),规则模板可覆盖 80% 常见场景。LLM 用于生成描述性内容(如 FAQ 的 questions/answers)和处理边界情况。 + +**Alternatives considered:** +- 纯规则:无法处理动态内容场景 +- 纯 LLM:Schema 格式要求严格,纯 LLM 生成可能不合规 + +### KTD-3: CompetitorAnalyzer 数据来源 + +**Decision:** 复用现有 CitationDetector 的检测数据 + +**Rationale:** 竞品分析依赖品牌和竞品的引用数据,现有 CitationDetector 已实现全量检测能力。新 Agent 只需在检测结果基础上做策略分析,避免重复建设。 + +**Alternatives considered:** +- 独立检测通道:增加 API 调用成本和数据冗余 +- 仅分析历史数据:无法获取最新竞品动态 + +### KTD-4: TrendAgent 趋势识别 + +**Decision:** 基于「品牌领域关键词 + AI 引擎引用模式」识别趋势 + +**Rationale:** TrendAgent 的定位是「GEO 趋势洞察」而非「通用舆情分析」。通过分析品牌核心关键词在 AI 引擎中的引用模式变化,可以识别用户关注度趋势,无需引入外部数据源。 + +**Alternatives considered:** +- 接入第三方舆情 API:成本高、数据不精准 +- 基于社交媒体趋势:与 GEO 关联度低 + +--- + +## High-Level Technical Design + +### 4 Agent 在现有架构中的位置 + +``` +┌──────────────────────────────────────────────────────────────────┐ +│ Agent Framework (Dispatcher) │ +└──────────────────────────────────────────────────────────────────┘ + │ + ┌───────────┬───────────┼───────────┬───────────┐ + ▼ ▼ ▼ ▼ ▼ +┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ +│Citation │ │Monitor │ │Schema │ │Competitor│ │Trend │ +│Detector │ │Agent 🆕 │ │Advisor 🆕│ │Analyzer 🆕│ │Agent 🆕 │ +│ │ │ │ │ │ │ │ │ │ +│引用检测 │ │效果追踪 │ │Schema │ │竞品策略 │ │趋势洞察 ││ +└────┬─────┘ └────┬─────┘ └────┬─────┘ └────┬─────┘ └────┬─────┘ + │ │ │ │ │ + ▼ ▼ ▼ ▼ ▼ + CitationRecord MonitoringRecord SchemaSuggestion CompetitorInsight TrendInsight + │ │ │ │ │ + ▼ ▼ ▼ ▼ ▼ + ScoringService GeoPlanGenerator DiagnosisReport StrategyPage Dashboard +``` + +### MonitorAgent 数据流 + +``` +内容发布 → 记录监测任务 → 定时执行检测 → 对比基线数据 → 生成变化报告 → 更新评分 + │ + ┌───────┴───────┐ + │ 变化类型判断 │ + ├───────────────┤ + │ + 引用量增加 │ → 正面反馈 + │ - 引用量下降 │ → 告警 + 建议 + │ = 无变化 │ → 持续观察 + └───────────────┘ +``` + +### SchemaAdvisor 建议生成 + +``` +诊断数据 → 识别 Schema 缺失维度 → 匹配模板 → 生成 JSON-LD → LLM 填充内容 → 验证格式 + │ + ┌───────┴───────┐ + │ 格式验证通过 │ → 返回建议 + │ 格式验证失败 │ → 回退到规则 + └───────────────┘ +``` + +--- + +## Implementation Units + +### U1. MonitorAgent — 效果追踪 Agent + +**Goal:** 实现内容发布后引用变化的自动追踪和对比分析 + +**Requirements:** R1(效果追踪) + +**Files:** +- `app/agent_framework/agents/monitor_agent.py` (新增) +- `app/models/monitoring.py` (新增) +- `app/schemas/monitoring.py` (新增) +- `app/services/monitoring/monitor_service.py` (新增) +- `app/api/monitoring.py` (新增) +- `tests/test_monitor_agent.py` (新增) + +**Approach:** +1. 数据模型:`MonitoringRecord` 记录监测任务,`ContentBaseline` 存储发布时的基线数据 +2. Agent 逻辑:接收监测任务 → 提取关键词和平台 → 调用 CitationService 检测 → 对比基线 → 生成变化报告 +3. 变化判断:基于引用量、情感倾向、引用排名三个维度计算变化幅度 +4. API:提供创建监测任务、查询监测历史、获取变化报告的端点 + +**Test scenarios:** +- Happy path: 内容发布后创建监测任务,定时检测对比基线,生成正面/负面/无变化报告 +- Edge case: 基线数据为空时,使用首次检测作为基线 +- Error path: 检测服务不可用时,记录失败状态并重试 +- Integration: 监测任务触发时,通过 Dispatcher 正确分发给 MonitorAgent + +--- + +### U2. SchemaAdvisor — Schema 优化建议 Agent + +**Goal:** 根据诊断结果生成具体的 Schema 标记方案 + +**Requirements:** R2(Schema 优化) + +**Files:** +- `app/agent_framework/agents/schema_advisor.py` (新增) +- `app/models/schema_suggestion.py` (新增) +- `app/schemas/schema_suggestion.py` (新增) +- `app/services/schema/schema_advisor_service.py` (新增) +- `app/api/schema_advisor.py` (新增) +- `tests/test_schema_advisor.py` (新增) + +**Approach:** +1. 诊断数据输入:接收 GEO 诊断结果,识别 Schema 缺失维度 +2. 模板匹配:5 种 Schema 类型 × 3 种场景 = 15 个预定义模板 + - Organization: 品牌主页 + - Product: 产品详情页 + - FAQPage: 常见问题页 + - Article: 博客/新闻页 + - LocalBusiness: 本地商户页 +3. LLM 增强:使用 LLM 生成 FAQ 的 questions/answers、Product 描述等自然语言内容 +4. 格式验证:使用 JSON-LD 解析器验证生成结果的语法正确性 +5. 优先级排序:基于诊断得分和实施难度排序建议 + +**Test scenarios:** +- Happy path: 诊断发现 FAQPage 缺失,生成完整的 FAQPage JSON-LD +- Edge case: LLM 生成内容格式错误,回退到规则模板 +- Error path: 模板不存在时,返回"暂不支持该场景"的友好提示 +- Format validation: 生成的 JSON-LD 通过 schema.org 验证 + +--- + +### U3. CompetitorAnalyzer — 竞品策略分析 Agent + +**Goal:** 在引用检测数据基础上,深度分析竞品的内容策略和引用模式 + +**Requirements:** R3(竞品分析) + +**Files:** +- `app/agent_framework/agents/competitor_analyzer.py` (新增) +- `app/models/competitor_insight.py` (新增) +- `app/schemas/competitor_insight.py` (新增) +- `app/services/competitor/competitor_analyzer_service.py` (新增) +- `app/api/competitor_analysis.py` (新增) +- `tests/test_competitor_analyzer.py` (新增) + +**Approach:** +1. 数据聚合:聚合品牌和所有竞品的引用检测数据 +2. 维度分析: + - 引用量对比:品牌 vs 竞品在各平台的引用次数 + - 引用质量对比:正面/中性/负面引用比例 + - 引用场景分析:竞品在哪些查询词下被引用 + - 内容策略推断:竞品引用内容的类型(产品介绍/评测/新闻等) +3. Gap 识别:对比发现品牌在哪些维度落后于竞品 +4. 机会发现:竞品未被引用的场景,即品牌的潜在机会 +5. 策略建议:基于分析结果生成"缩小差距"和"差异化竞争"两类建议 + +**Test scenarios:** +- Happy path: 品牌 vs 3个竞品,生成完整的竞品分析报告 +- Edge case: 竞品数据不足(少于5次引用),标记为"数据不足"并说明 +- Error path: 竞品数据获取失败,生成部分报告并提示缺失数据 +- Content strategy: 识别竞品的内容类型分布,给出差异化建议 + +--- + +### U4. TrendAgent — 趋势洞察 Agent + +**Goal:** 分析品牌领域关键词在 AI 引擎中的引用模式变化,识别趋势 + +**Requirements:** R4(趋势洞察) + +**Files:** +- `app/agent_framework/agents/trend_agent.py` (新增) +- `app/models/trend_insight.py` (新增) +- `app/schemas/trend_insight.py` (新增) +- `app/services/trend/trend_analyzer_service.py` (新增) +- `app/api/trends.py` (新增) +- `tests/test_trend_agent.py` (新增) + +**Approach:** +1. 时间序列分析:聚合品牌关键词在过去 N 天的引用数据(按天/周粒度) +2. 趋势识别算法: + - 上升趋势:引用量环比增长 > 20% + - 下降趋势:引用量环比下降 > 20% + - 平稳趋势:变化率在 ±20% 以内 +3. 热点发现:引用量突增的关键词/话题 +4. 平台差异:同一趋势在不同 AI 引擎的差异表现 +5. 原因推断:结合情感分析推断趋势变化的可能原因(正面事件/负面事件/行业热点) + +**Test scenarios:** +- Happy path: 分析过去30天数据,识别3个上升趋势、2个下降趋势、5个热点话题 +- Edge case: 数据不足30天,使用可用数据进行趋势分析并注明 +- Error path: 引用数据获取失败,返回"趋势分析需要至少7天数据"提示 +- Platform comparison: 识别同一趋势在文心 vs Kimi 的差异 + +--- + +### U5. Agent 注册和调度集成 + +**Goal:** 将 4 个新 Agent 注册到 Agent Framework,支持 Dispatcher 统一调度 + +**Requirements:** 所有 R1-R4 + +**Dependencies:** U1, U2, U3, U4 + +**Files:** +- `app/agent_framework/registry.py` (修改) +- `app/agent_framework/protocol.py` (修改) + +**Approach:** +1. 在 `protocol.py` 中定义 4 种新任务类型: + - `monitor_track` — 效果追踪 + - `schema_advise` — Schema 建议 + - `competitor_analyze` — 竞品分析 + - `trend_insight` — 趋势洞察 +2. 在 `registry.py` 中注册 4 个 Agent 类 +3. 验证 Dispatcher 可正确分发任务到对应 Agent + +**Test scenarios:** +- Happy path: Dispatcher 分发 `monitor_track` 任务到 MonitorAgent +- Happy path: Dispatcher 分发 `schema_advise` 任务到 SchemaAdvisor +- Edge case: 未知任务类型返回"Agent not found"错误 + +--- + +### U6. 数据库迁移 + +**Goal:** 创建新 Agent 所需的数据库表 + +**Dependencies:** U1, U2, U3, U4 + +**Files:** +- `alembic/versions/xxxx_add_new_agent_tables.py` (新增) + +**Approach:** +1. 生成 Alembic 迁移脚本 +2. 创建 5 张新表: + - `monitoring_records` — 监测记录 + - `content_baselines` — 内容基线 + - `schema_suggestions` — Schema 建议 + - `competitor_insights` — 竞品洞察 + - `trend_insights` — 趋势洞察 +3. 执行迁移并验证 + +**Test scenarios:** +- Happy path: 迁移脚本成功执行,5张表创建完成 +- Rollback: 执行 downgrade 后,表被正确删除 + +--- + +## Open Questions + +| # | Question | Resolution | +|---|----------|------------| +| 1 | MonitorAgent 的定时检测间隔是多少? | 建议默认 24 小时,支持用户配置(1h/6h/24h/7d) | +| 2 | SchemaAdvisor 是否需要验证生成的 Schema 在目标网站上可正常解析? | 否,验证 JSON-LD 语法即可,网站解析验证超出当前范围 | +| 3 | TrendAgent 历史数据保留多久? | 建议保留 90 天,超出后归档或删除 | +| 4 | 4 个新 Agent 是否需要独立启动入口(类似 standalone.py)? | 否,通过 API 触发即可,无独立常驻需求 | + +--- + +## System-Wide Impact + +- **数据库**:新增 5 张表,需要 Alembic 迁移 +- **API 层**:新增 4 个路由模块(monitoring.py, schema_advisor.py, competitor_analysis.py, trends.py) +- **Agent Framework**:新增 4 种任务类型注册 +- **前端**:可选,取决于是否需要页面展示(当前计划仅提供 API) + +--- + +## Deferred to Follow-Up Work + +1. **前端监测页面**:MonitorAgent 的效果追踪报告需要页面展示 +2. **自动执行闭环**:内容发布后自动创建监测任务(依赖分发系统) +3. **告警集成**:监测到负面变化时自动触发告警通知 +4. **定时调度**:TrendAgent 的周期性趋势报告(依赖 APScheduler 集成) + +--- + +## Verification + +### 编译验证 +```bash +cd geo/backend && python -m py_compile app/agent_framework/agents/monitor_agent.py app/agent_framework/agents/schema_advisor.py app/agent_framework/agents/competitor_analyzer.py app/agent_framework/agents/trend_agent.py +``` + +### API 验证 +```bash +# MonitorAgent +curl -X POST http://localhost:8000/api/v1/monitoring/tasks -H "Authorization: Bearer $TOKEN" -d '{"content_id": "xxx", "brand_id": "yyy"}' + +# SchemaAdvisor +curl -X POST http://localhost:8000/api/v1/schema/advise -H "Authorization: Bearer $TOKEN" -d '{"brand_id": "yyy"}' + +# CompetitorAnalyzer +curl -X POST http://localhost:8000/api/v1/competitor/analyze -H "Authorization: Bearer $TOKEN" -d '{"brand_id": "yyy"}' + +# TrendAgent +curl -X POST http://localhost:8000/api/v1/trends/insight -H "Authorization: Bearer $TOKEN" -d '{"brand_id": "yyy", "days": 30}' +``` + +### Agent 调度验证 +```bash +# 通过 Agent Framework 调度 +curl -X POST http://localhost:8000/api/v1/agents/tasks -H "Authorization: Bearer $TOKEN" -d '{ + "agent_type": "monitor", + "task_type": "monitor_track", + "params": {"content_id": "xxx", "brand_id": "yyy"} +}' +``` diff --git a/docs/plans/2026-05-31-002-test-quality-assurance-system-plan.md b/docs/plans/2026-05-31-002-test-quality-assurance-system-plan.md new file mode 100644 index 0000000..e004d52 --- /dev/null +++ b/docs/plans/2026-05-31-002-test-quality-assurance-system-plan.md @@ -0,0 +1,332 @@ +--- +title: "test: GEO Platform Quality Assurance System" +type: test +status: active +date: "2026-05-31" +origin: docs/brainstorms/2026-05-31-geo-platform-next-phase-quality-requirements.md +--- + +## Summary + +GEO 平台质量保障体系建设,采用双轨并行策略:轨道一统一测试基础设施(合并分裂目录、完善 CI、建立共享 fixture),轨道二同步推进核心业务 E2E 测试和四项专项测试(性能基线、安全扫描、迁移验证、Agent E2E),确保核心业务流程尽早获得回归保护。 + +## Problem Frame + +GEO 平台经过多轮迭代已具备 33 个 API、8 个 Agent、25+ 前端页面,但测试覆盖严重不足:后端两套测试目录分裂导致 17 个测试被 CI 忽略,前端仅 13 个单元测试且零组件测试,E2E 仅覆盖登录流程,核心业务完全无回归保护。CI 不运行 E2E 测试,无安全扫描,无性能基线。在用户开始使用前,必须建立质量保障体系防止功能回归。 + +--- + +## Requirements + +### 测试基础设施 + +R1. 将 `geo/tests/` 下的 17 个测试迁移至 `backend/tests/` 对应子目录,合并两套 conftest.py 的 fixture,删除旧目录 + +R2. CI 中 `pytest tests/` 命令能发现并运行全部后端测试 + +R3. 将 `pytest-cov` 正式加入 `backend/requirements.txt` + +R4. 建立共享 fixture 体系:数据库会话(含自动 rollback)、认证 mock(JWT token + auth headers)、测试用户创建、httpx AsyncClient + +R5. 前端 vitest 覆盖范围扩展至 `lib/api/` 全部 27 个模块和关键页面组件 + +R6. 更新 `docs/03-开发指南/testing.md` 使其与实际目录结构、CI 配置、fixture 体系一致 + +### 核心业务 E2E 测试 + +R7. 编写用户注册→登录→Onboarding→品牌创建的完整 E2E 测试 + +R8. 编写诊断→查看诊断报告→生成 GEO 方案的 E2E 测试 + +R9. 编写内容生成→查看内容→效果追踪的 E2E 测试 + +R10. E2E 测试在 CI 中运行(至少 Chromium),使用 PostgreSQL 和 Redis service container + +R11. E2E 测试失败时自动截图和录制视频 + +### 专项测试 + +R12. 为 5-10 个高频 API 端点建立性能基线(p50/p95/p99 响应时间),首轮只采集不设阈值 + +R13. CI 中集成 bandit(Python 安全扫描)和 npm audit(Node.js 依赖安全检查),PR 级别阻断高危漏洞 + +R14. CI 中添加 Alembic 迁移验证步骤 + +R15. 编写 Agent 框架端到端测试:任务创建→分发→执行→结果查询的完整链路(至少覆盖 CitationDetector 和 ContentGenerator) + +### 前端组件测试 + +R16. 为 5 个关键页面组件编写 React Testing Library 测试 + +R17. 为 4 个新 Agent 的前端 API 客户端模块编写单元测试 + +--- + +## Key Technical Decisions + +KTD1. **双轨并行而非分层递进** — 基础设施统一和 E2E 测试编写同步推进,两者几乎不涉及相同文件,并行没有实质冲突。E2E 测试初期用独立 setup,后续迁移到共享 fixture 是低成本重构。 + +KTD2. **E2E 测试仅覆盖 Chromium** — 先在一个浏览器上稳定运行,跨浏览器扩展作为后续迭代。Playwright 已配置 3 浏览器但当前 E2E 用例太少,扩展浏览器覆盖的 ROI 不高。 + +KTD3. **性能测试先建基线再设阈值** — 在没有历史数据时设定 SLA 容易过严或过松。第一轮只采集数据建立基线,第二轮根据基线设定阈值和告警。 + +KTD4. **Agent E2E 测试优先使用真实 LLM 调用** — 可使用 LLM API Key 进行真实调用测试,确保 Agent 在实际 LLM 响应下的行为正确。CI 中通过环境变量注入 API Key,无 Key 时降级为 mock 模式。 + +KTD5. **安全扫描集成到 CI 而非独立流程** — bandit 和 npm audit 作为 CI 步骤运行,在 PR 级别就暴露问题。 + +--- + +## Implementation Units + +### U1. 统一后端测试目录 + +- **Goal:** 消除测试目录分裂,确保 CI 能发现并运行全部后端测试 +- **Requirements:** R1, R2, R3 +- **Dependencies:** none +- **Files:** + - `geo/tests/` (17 files to migrate) + - `backend/tests/conftest.py` (merge fixtures) + - `geo/tests/conftest.py` (merge then delete) + - `backend/requirements.txt` (add pytest-cov) + - `.github/workflows/ci.yml` (verify pytest discovers all tests) +- **Approach:** 逐个迁移 `geo/tests/` 下的测试文件到 `backend/tests/` 对应子目录(test_api/, test_services/ 等)。合并两套 conftest.py:保留 `backend/tests/conftest.py` 的内存 SQLite 引擎,从 `geo/tests/conftest.py` 迁入 async_client、auth_token、auth_headers、override_get_current_user 等 fixture。删除 `geo/tests/` 目录。将 pytest-cov 加入 requirements.txt。 +- **Patterns to follow:** `backend/tests/conftest.py` 现有 fixture 模式(async_engine, async_session, test_user) +- **Test scenarios:** + - `cd backend && pytest tests/` 发现并运行全部测试(≥77 个) + - 迁移后的测试功能与迁移前一致 + - CI 中 pytest 命令无需修改即可运行全量测试 +- **Verification:** `cd backend && pytest tests/ --tb=short` 全部通过,`geo/tests/` 目录已不存在 + +### U2. 建立共享 fixture 体系 + +- **Goal:** 为后端集成测试和 E2E 测试提供可复用的测试基础设施 +- **Requirements:** R4 +- **Dependencies:** U1 +- **Files:** + - `backend/tests/conftest.py` (enhance with shared fixtures) + - `backend/tests/fixtures/` (new directory for modular fixtures) + - `backend/tests/fixtures/auth.py` (authentication fixtures) + - `backend/tests/fixtures/database.py` (database fixtures) + - `backend/tests/fixtures/client.py` (httpx client fixtures) + - `backend/tests/fixtures/brands.py` (brand and competitor test data) +- **Approach:** 在 `backend/tests/fixtures/` 下按领域拆分 fixture 模块。auth.py 提供 override_get_current_user、auth_token、auth_headers。database.py 提供 async_engine、async_session(含自动 rollback)。client.py 提供 async_client(httpx AsyncClient with app)。brands.py 提供预创建的测试品牌和竞品数据。conftest.py 通过 pytest plugin 机制自动加载 fixtures/ 下所有模块。 +- **Patterns to follow:** `geo/tests/conftest.py` 中已有的 fixture 实现(async_client, auth_token, override_get_current_user) +- **Test scenarios:** + - 使用 auth_headers fixture 的测试能成功调用需要认证的 API + - 使用 async_client fixture 的测试能发送 HTTP 请求到 FastAPI 应用 + - 数据库 fixture 在测试结束后自动 rollback,不污染数据库 + - 多个测试并行运行时 fixture 互不干扰 +- **Verification:** 使用新 fixture 编写 2-3 个示例测试并全部通过 + +### U3. 核心业务 E2E — 用户注册到品牌创建 + +- **Goal:** 覆盖用户从注册到品牌创建的完整 Onboarding 流程 +- **Requirements:** R7, R11 +- **Dependencies:** none (独立 setup,不依赖 U2) +- **Files:** + - `frontend/e2e/tests/onboarding.spec.ts` (new) + - `frontend/e2e/fixtures/auth.ts` (new, independent setup) + - `frontend/playwright.config.ts` (verify screenshot/video config) +- **Approach:** 编写 Playwright 测试覆盖:注册新用户→登录→进入 Onboarding→填写品牌名称→添加竞品→选择平台→查看健康报告→完成引导。使用独立 setup(直接调用 API 创建测试数据),后续可迁移到共享 fixture。确保 playwright.config.ts 中 screenshot:on-failure 和 video:retain-on-failure 已配置。 +- **Patterns to follow:** `frontend/e2e/tests/login.spec.ts` 现有 E2E 测试模式 +- **Test scenarios:** + - Covers AE1 (from origin): 注册→登录→Onboarding→品牌创建全流程通过 + - 注册时输入无效邮箱显示错误提示 + - 品牌名称为空时无法提交 + - 不添加竞品也能完成 Onboarding + - 已注册用户登录后直接跳转 Dashboard +- **Verification:** `cd frontend && npx playwright test onboarding` 通过 + +### U4. 核心业务 E2E — 诊断到 GEO 方案 + +- **Goal:** 覆盖诊断→方案生成的核心业务闭环 +- **Requirements:** R8, R11 +- **Dependencies:** U3 (需要已登录用户和品牌) +- **Files:** + - `frontend/e2e/tests/diagnosis-strategy.spec.ts` (new) +- **Approach:** 在 U3 创建的品牌基础上,编写:进入诊断页面→触发诊断→查看诊断报告→点击"基于诊断制定 GEO 方案"→查看方案详情。诊断和方案生成可能需要较长时间,使用适当的等待策略。 +- **Patterns to follow:** `frontend/e2e/tests/login.spec.ts` 现有 E2E 测试模式 +- **Test scenarios:** + - Covers AE2 (from origin): 诊断→报告→方案生成全流程通过 + - 诊断页面正确显示品牌评分 + - 方案生成后显示行动项列表 + - 方案行动项状态可更新 +- **Verification:** `cd frontend && npx playwright test diagnosis-strategy` 通过 + +### U5. 核心业务 E2E — 内容生成到效果追踪 + +- **Goal:** 覆盖内容生成→查看→效果追踪的内容工坊流程 +- **Requirements:** R9, R11 +- **Dependencies:** U3 (需要已登录用户和品牌) +- **Files:** + - `frontend/e2e/tests/content-monitoring.spec.ts` (new) +- **Approach:** 编写:进入内容工坊→输入关键词→生成内容→查看生成结果→进入效果追踪页面→查看监测数据。内容生成可能需要 LLM 调用,使用较长超时或 mock。 +- **Patterns to follow:** `frontend/e2e/tests/login.spec.ts` 现有 E2E 测试模式 +- **Test scenarios:** + - Covers AE3 (from origin): 内容生成→查看→效果追踪全流程通过 + - 内容生成页面正确显示生成状态 + - 生成完成后内容可查看 + - 效果追踪页面显示监测数据 +- **Verification:** `cd frontend && npx playwright test content-monitoring` 通过 + +### U6. CI 集成 E2E 测试 + +- **Goal:** 将 E2E 测试纳入 CI 流水线 +- **Requirements:** R10 +- **Dependencies:** U3, U4, U5 +- **Files:** + - `.github/workflows/ci.yml` (add E2E step) + - `.github/workflows/pr-check.yml` (add E2E step) +- **Approach:** 在 CI 中添加 E2E 测试步骤:启动 PostgreSQL 和 Redis service container → 启动后端服务 → 启动前端服务 → 运行 Playwright 测试。仅运行 Chromium 项目以控制时间。E2E 步骤放在单元测试之后,允许失败但不阻塞(初期),稳定后改为阻塞。 +- **Patterns to follow:** 现有 CI 中 PostgreSQL 和 Redis service container 配置 +- **Test scenarios:** + - PR 提交时 CI 自动运行 E2E 测试 + - E2E 测试失败时 CI 标记为失败 + - 截图和视频作为 artifact 上传 +- **Verification:** 提交一个测试 PR,观察 CI 中 E2E 步骤是否正确运行 + +### U7. 性能基线测试 + +- **Goal:** 为高频 API 端点建立性能基线数据 +- **Requirements:** R12 +- **Dependencies:** U2 +- **Files:** + - `backend/tests/performance/` (new directory) + - `backend/tests/performance/__init__.py` (new) + - `backend/tests/performance/test_api_baseline.py` (new) + - `backend/tests/performance/conftest.py` (new) +- **Approach:** 使用 httpx AsyncClient 对 5-10 个高频端点(dashboard/stats, brands, queries, citations, content/generate, strategy/generate, monitoring/brand/{id}, analytics/overview, diagnosis, ai-engines/query)发送多次请求,采集 p50/p95/p99 响应时间,输出为 JSON 基线文件。首轮只采集不设阈值。使用 pytest-benchmark 或自定义计时逻辑。 +- **Patterns to follow:** `backend/tests/` 现有测试结构 +- **Test scenarios:** + - Covers AE4 (from origin): 对 dashboard/stats 发送 100 次请求采集性能数据 + - 基线数据以 JSON 格式保存 + - 性能测试可在本地和 CI 中运行 + - 首轮不设阈值不阻断 +- **Verification:** `cd backend && pytest tests/performance/ --tb=short` 运行并输出基线数据 + +### U8. CI 安全扫描和迁移验证 + +- **Goal:** 在 CI 中集成安全扫描和数据库迁移验证 +- **Requirements:** R13, R14 +- **Dependencies:** none +- **Files:** + - `.github/workflows/ci.yml` (add security scan and migration steps) + - `.github/workflows/pr-check.yml` (add security scan and migration steps) +- **Approach:** 在 CI 中添加:1) `pip install bandit && bandit -r backend/app/ -ll` 扫描 Python 代码安全问题;2) `cd frontend && npm audit --audit-level=high` 检查 Node.js 依赖漏洞;3) `cd backend && alembic upgrade head` 在空数据库上验证迁移。安全扫描发现高危问题时 CI 失败。 +- **Patterns to follow:** 现有 CI 步骤结构 +- **Test scenarios:** + - Covers AE3 (from origin): PR 中引入有漏洞依赖时 CI 报告失败 + - Covers AE5 (from origin): 新增迁移脚本时 CI 验证迁移可执行 + - bandit 扫描 Python 代码中的安全问题 + - npm audit 检查前端依赖漏洞 + - alembic upgrade head 在空数据库上成功执行 +- **Verification:** 提交测试 PR,观察 CI 中安全扫描和迁移验证步骤 + +### U9. Agent 框架 E2E 测试 + +- **Goal:** 测试 Agent 框架的完整任务调度链路 +- **Requirements:** R15 +- **Dependencies:** U2 +- **Files:** + - `backend/tests/test_agent_framework/` (new or extend existing) + - `backend/tests/test_agent_framework/test_e2e_citation.py` (new) + - `backend/tests/test_agent_framework/test_e2e_content.py` (new) +- **Approach:** 测试完整链路:创建 Agent 任务→Dispatcher 分发→Agent 执行→结果写入数据库→查询结果。至少覆盖 CitationDetector(citation_detect 任务)和 ContentGenerator(content_generate 任务)。优先使用真实 LLM API Key 调用,CI 中通过环境变量注入;无 Key 时降级为 mock 模式使用预定义响应。 +- **Patterns to follow:** `backend/tests/test_agent_framework/` 现有 Agent 测试模式 +- **Test scenarios:** + - CitationDetector: 创建 citation_detect 任务→执行→结果包含引用数据 + - ContentGenerator: 创建 content_generate 任务→执行→结果包含生成内容 + - 任务状态从 pending→running→completed 正确流转 + - 任务失败时状态变为 failed 并记录错误信息 + - 无 Redis 时 Agent 以 standalone 模式运行 +- **Verification:** `cd backend && pytest tests/test_agent_framework/test_e2e_*.py --tb=short` 通过 + +### U10. 前端组件和 API 客户端测试 + +- **Goal:** 扩展前端测试覆盖至关键组件和新 API 客户端模块 +- **Requirements:** R16, R17 +- **Dependencies:** none +- **Files:** + - `frontend/__tests__/components/` (new directory) + - `frontend/__tests__/components/dashboard.test.tsx` (new) + - `frontend/__tests__/components/brand-detail.test.tsx` (new) + - `frontend/__tests__/components/diagnosis.test.tsx` (new) + - `frontend/__tests__/components/strategy.test.tsx` (new) + - `frontend/__tests__/components/content-editor.test.tsx` (new) + - `frontend/__tests__/api/monitoring.test.ts` (new) + - `frontend/__tests__/api/competitor-analysis.test.ts` (new) + - `frontend/__tests__/api/schema-advisor.test.ts` (new) + - `frontend/__tests__/api/trends.test.ts` (new) + - `frontend/vitest.config.ts` (extend coverage include) +- **Approach:** 使用 React Testing Library 为 5 个关键页面组件编写渲染和交互测试。使用 vitest 为 4 个新 API 客户端模块编写单元测试(mock fetchWithAuth,验证正确的 URL 和参数传递)。扩展 vitest.config.ts 的 coverage.include 配置。 +- **Patterns to follow:** `frontend/__tests__/` 现有测试模式(hooks, stores, lib) +- **Test scenarios:** + - Dashboard 组件正确渲染评分和平台数据 + - 品牌详情页正确显示 GEO 评分维度 + - 诊断页触发诊断后显示结果 + - 策略页显示 GEO 方案和行动项 + - 内容编辑器正确渲染和提交 + - monitoring.ts API 客户端调用正确的端点 + - competitor-analysis.ts API 客户端传递正确的参数 + - schema-advisor.ts API 客户端处理响应 + - trends.ts API 客户端处理查询参数 +- **Verification:** `cd frontend && npm run test:ci` 通过,覆盖率报告显示 lib/api/ 模块覆盖 + +### U11. 更新测试策略文档 + +- **Goal:** 使文档与实际项目结构和 CI 配置一致 +- **Requirements:** R6 +- **Dependencies:** U1, U2, U6, U8 +- **Files:** + - `docs/03-开发指南/testing.md` (update) +- **Approach:** 更新 testing.md 中的目录结构描述(test_api/, test_services/ 等按领域分类),更新 fixture 体系说明(fixtures/ 模块化 fixture),更新 CI 配置示例(包含 E2E 步骤、安全扫描、迁移验证),添加 E2E 测试编写指南。 +- **Patterns to follow:** 现有 testing.md 文档风格 +- **Test scenarios:** + - Test expectation: none — documentation update +- **Verification:** 文档内容与实际项目结构一致 + +--- + +## Scope Boundaries + +**Deferred for later:** + +- 跨浏览器 E2E 测试(Firefox / WebKit) +- 覆盖率报告上传第三方服务(Codecov / Coveralls) +- 测试数据工厂(factory-boy / faker) +- PR 评论中显示覆盖率变化 +- 生产环境监控告警集成 + +**Outside this product's identity:** + +- 新业务功能开发 +- UI/UX 改进和设计优化 +- 基础设施级别的渗透测试 +- 移动端适配测试 + +### Deferred to Follow-Up Work + +- E2E 测试迁移到共享 fixture(U3/U4/U5 完成后,将独立 setup 替换为 U2 的共享 fixture) +- 性能基线设定阈值和告警(U7 采集基线后,根据数据设定合理阈值) +- 跨浏览器 E2E 扩展(稳定 Chromium 后再扩展) + +--- + +## Risks & Dependencies + +- **E2E 测试稳定性** — 依赖后端服务启动和数据库初始化,CI 环境可能比本地更不稳定。缓解:E2E 初期允许失败不阻塞,稳定后再改为阻塞。 +- **LLM 依赖** — 内容生成和诊断 E2E 测试可能需要 LLM 调用。缓解:使用 mock 或较长超时,Agent E2E 使用 mock LLM。 +- **测试目录迁移风险** — 旧测试可能依赖 `geo/tests/conftest.py` 的特定 fixture,迁移后可能需要调整 import 路径。缓解:逐个迁移并验证。 +- **CI 时间增长** — 添加 E2E、安全扫描、迁移验证会延长 CI 运行时间。缓解:E2E 仅 Chromium,安全扫描仅高危阻断,迁移验证快速执行。 + +--- + +## Sources & Research + +- 现有测试配置:`backend/pyproject.toml`, `frontend/vitest.config.ts`, `frontend/playwright.config.ts` +- CI 配置:`.github/workflows/ci.yml`, `.github/workflows/pr-check.yml` +- 测试策略文档:`docs/03-开发指南/testing.md` +- 现有 E2E 测试:`frontend/e2e/tests/` (7 files) +- 现有后端测试:`backend/tests/` (~60 files) + `geo/tests/` (17 files) +- 现有前端测试:`frontend/__tests__/` (13 files) +- Origin document: `docs/brainstorms/2026-05-31-geo-platform-next-phase-quality-requirements.md` diff --git a/docs/plans/2026-05-31-003-feat-geo-monetization-closed-loop-plan.md b/docs/plans/2026-05-31-003-feat-geo-monetization-closed-loop-plan.md new file mode 100644 index 0000000..0fac9d0 --- /dev/null +++ b/docs/plans/2026-05-31-003-feat-geo-monetization-closed-loop-plan.md @@ -0,0 +1,558 @@ +--- +title: "feat: GEO Platform Monetization Closed Loop" +type: feat +status: active +date: "2026-05-31" +origin: docs/brainstorms/2026-05-31-geo-next-phase-core-flow-repair-requirements.md +secondary-origin: docs/brainstorms/2026-05-31-geo-platform-monetization-closed-loop-requirements.md +--- + +## Summary + +修复 GEO 诊断系统(当前空输入=0 分)、建设免费 GEO 健康分获客入口、接入真实支付与功能限制、添加 AI 内容生成能力、打通内容分发与效果归因闭环,实现从获客到续费的完整商业变现链路。同时用 API 契约测试驱动核心功能修复,辅以少量 E2E 烟雾测试验证关键用户路径。 + +## Problem Frame + +GEO 平台已具备 8 个 Agent、6 维度诊断框架、内容 Pipeline、知识库 RAG 等模块,但三个致命问题阻断了变现闭环:(1) 诊断系统使用空输入调用,永远返回 0 分,产品核心价值为零;(2) 内容 Pipeline 只做格式化不生成内容,核心付费功能缺失;(3) 支付是模拟的,没有功能限制,用户无付费动力。此外,Onboarding 缺少付费墙触发点、分发没有实际发布集成、监控没有归因逻辑、邮件服务未接入业务系统——各模块存在但互不连通。 + +## Requirements + +Origin: `docs/brainstorms/2026-05-31-geo-next-phase-core-flow-repair-requirements.md` + +- R1. 修复诊断系统:实现自动数据采集,让诊断产出非零有差异的评分 → U1 +- R2. 添加 AI 内容生成阶段:在 ContentPipeline 前端添加 Stage 0 → U5 +- R3. 接入真实支付:微信支付+支付宝+功能限制中间件 → U4 +- R4. 建设免费 GEO 健康分公开页面 → U2 +- R5. 重设计 Onboarding 流程,嵌入付费墙触发点 → U3 +- R6. 执行闭环:内容分发集成微信/知乎/头条 → U5 +- R7. 效果归因系统 → U6 +- R8. 为诊断 API 编写契约测试 → U1 +- R9. 为公开健康分 API 编写契约测试 → U2 +- R10. 为内容生成 API 编写契约测试 → U5 +- R11. 为支付 API 编写契约测试 → U4 +- R12. 编写跨步骤集成测试 → U8 +- R13. 编写获客路径 E2E 烟雾测试 → U9 +- R14. 编写核心流程 E2E 烟雾测试 → U9 +- R15. 完成后端测试目录统一验证 → U8 +- R16. 建立共享 fixture 体系 → U8 + +## Key Technical Decisions + +KTD-1. **诊断自动数据采集采用"AI 平台查询+CitationRecord 分析"双通道方案(V1 跳过网站爬取)。** 通过调用现有 AI 平台适配器查询品牌关键词分析引用情况,通过分析已有 CitationRecord 数据聚合引用指标。网站爬取通道留待 V2 迭代,降低首版复杂度。AI 平台查询填充"引用就绪度"维度,CitationRecord 分析填充"E-E-A-T/主题权威"维度,品牌官网 URL 解析填充"内容可提取性/Schema 标记"维度(基于简单 HTTP GET + HTML 解析,非完整爬取)。 + +KTD-2. **免费 GEO 健康分页面为独立公开页面,不依赖现有 Dashboard。** 新建 `/health-score` 路由,无需认证即可访问。结果缓存 24 小时(按品牌名+竞品哈希),避免重复查询消耗 API 额度。 + +KTD-3. **支付集成采用微信支付+支付宝双通道。** 使用官方 SDK,后端实现支付回调 Webhook 处理订阅状态更新。 + +KTD-4. **AI 内容生成集成到现有 ContentPipeline 前端。** 在 Stage 1(RuleValidator) 之前添加 Stage 0(AI Generator),基于诊断结果+RAG 知识库+LLM 生成初稿,后续阶段不变。 + +KTD-5. **效果归因采用"内容发布时间窗口+引用变化关联"方案。** 发布内容时记录时间戳和基线引用数据,2-4 周内定期检查引用变化。 + +KTD-6. **微信公众号采用半自动模式。** 生成内容+格式化+一键复制,用户手动粘贴到公众号编辑器。知乎和头条利用 API 实现自动发布。 + +KTD-7. **API 契约测试驱动修复,E2E 仅做烟雾测试。** 每个核心 API 先写契约测试定义期望行为,再修复实现。E2E 只覆盖 2 个最关键路径(获客+核心流程),不追求全面覆盖。 + +## High-Level Technical Design + +```mermaid +flowchart TB + subgraph Phase1["Phase 1: 核心价值修复 + 契约测试"] + A[品牌名输入] --> B[自动数据采集] + B --> B1[AI平台查询] + B --> B2[CitationRecord分析] + B1 & B2 --> C[GEO诊断引擎] + C --> D[6维度评分] + D --> T1[诊断API契约测试] + end + + subgraph Phase2["Phase 2: 获客入口 + 契约测试"] + D --> E[免费健康分页面] + E --> F{查看详细建议?} + F -->|免费概要| G[注册] + F -->|完整分析| G + G --> H{执行优化?} + H -->|升级付费| I[订阅支付] + E --> T2[健康分API契约测试] + end + + subgraph Phase3["Phase 3: 执行闭环 + 契约测试"] + I --> J[AI内容生成] + J --> K[人工审核] + K --> L[发布到平台] + L --> M[效果归因追踪] + M --> N[ROI报告] + N --> O[续费/升级] + J --> T3[内容生成API契约测试] + I --> T4[支付API契约测试] + end + + subgraph Phase4["Phase 4: 集成验证"] + T1 & T2 & T3 & T4 --> T5[跨步骤集成测试] + T5 --> T6[E2E烟雾测试] + end +``` + +## Implementation Units + +### U1. GEO 诊断自动数据采集与修复 + +**Goal:** 修复当前诊断系统(空输入=0 分),实现自动数据采集填充 GEODiagnosisInput,让诊断产出真实有价值的分数。 + +**Requirements:** R1, R8 + +**Dependencies:** 无 + +**Files:** +- `backend/app/services/diagnosis/geo_diagnosis.py` — 修改诊断服务,添加自动数据采集 +- `backend/app/services/diagnosis/data_collector.py` — 新建自动数据采集服务 +- `backend/app/api/diagnosis.py` — 修改 API 端点,支持自动采集+异步诊断 +- `backend/app/schemas/diagnosis.py` — 新建诊断请求/响应 schema +- `backend/app/models/diagnosis_record.py` — 新建诊断历史记录模型 +- `backend/tests/test_api/test_diagnosis_contract.py` — 新建诊断 API 契约测试 +- `backend/tests/test_services/test_data_collector.py` — 新建数据采集服务测试 + +**Approach:** +1. 新建 `DataCollectorService`,两个采集通道(V1): + - AI 平台查询通道:复用现有 `ai_engine` 适配器,查询品牌相关关键词,分析引用情况填充"引用就绪度"维度 + - 历史数据通道:从 CitationRecord 聚合已有引用数据,填充"E-E-A-T/主题权威"维度 + - 品牌官网解析:简单 HTTP GET + HTML 解析(检查 Schema 标记、标题层级等),填充"内容可提取性/Schema 标记"维度 +2. 修改 `GEODiagnosisService.diagnose()` 接受 `DataCollectorOutput` 作为输入 +3. 修改诊断 API 端点:`POST /api/v1/diagnosis/geo/{brand_id}` 触发异步诊断,`GET /api/v1/diagnosis/geo/{brand_id}/result` 轮询结果 +4. 诊断结果持久化到 `diagnosis_records` 表,支持历史对比 +5. 免费版返回综合分+3 核心维度,付费版返回全部 6 维度 +6. **先写契约测试**:定义诊断 API 的期望行为(非零评分、维度结构、免费/付费差异),再修复实现 + +**Execution note:** 先写诊断 API 契约测试(红色),再实现 DataCollectorService 和修改诊断服务(绿色)。 + +**Patterns to follow:** +- 复用 `backend/app/services/ai_engine/` 下的平台适配器 +- 复用 `backend/app/workers/citation_engine.py` 的查询执行模式 +- 异步任务模式参考 `backend/app/services/detection/detection_scheduler.py` +- 契约测试参考 `backend/tests/conftest.py` 中的 async_client fixture + +**Test scenarios:** +- 输入有效品牌名,自动采集数据并返回非零诊断分数 +- 输入不存在的品牌名,返回低分但非零(基于 AI 平台查询结果) +- 免费用户请求诊断,只返回 3 个核心维度 +- 付费用户请求诊断,返回全部 6 维度 +- 同一品牌 24 小时内重复请求,返回缓存结果 +- 诊断结果持久化,第二次诊断可展示历史对比 +- 契约测试:POST /api/v1/diagnosis/geo/{brand_id} 返回 202 + task_id +- 契约测试:GET /api/v1/diagnosis/geo/{brand_id}/result 返回非零 overall_score + +**Verification:** 调用 `POST /api/v1/diagnosis/geo/{brand_id}` 返回真实非零评分,且不同品牌评分有差异。契约测试全部通过。 + +--- + +### U2. 免费 GEO 健康分公开页面 + +**Goal:** 建设无需注册即可访问的 GEO 健康分页面,作为获客和市场教育的核心入口。 + +**Requirements:** R4, R9 + +**Dependencies:** U1 + +**Files:** +- `frontend/app/(public)/health-score/page.tsx` — 新建公开健康分页面 +- `frontend/app/(public)/health-score/components/ScoreCard.tsx` — 评分卡片组件 +- `frontend/app/(public)/health-score/components/CompetitorComparison.tsx` — 竞品对比组件 +- `frontend/app/(public)/health-score/components/ShareButton.tsx` — 分享按钮组件 +- `frontend/lib/api/health-score.ts` — 健康分 API 客户端 +- `backend/app/api/health_score.py` — 新建公开健康分 API(无需认证) +- `backend/app/schemas/health_score.py` — 健康分响应 schema +- `backend/tests/test_api/test_health_score_contract.py` — 新建健康分 API 契约测试 + +**Approach:** +1. 新建公开路由 `/health-score`,无需登录即可访问 +2. 页面核心交互:输入品牌名 → 30 秒内展示 GEO 健康分报告 +3. 报告内容:综合评分(大数字+健康等级色标)、3 核心维度雷达图、3 竞品对比柱状图、最严重 3 个问题概要 +4. "查看详细修复建议"按钮 → 触发注册弹窗 +5. 注册后自动关联本次诊断结果到用户账户 +6. 支持生成可分享链接(含品牌名参数)和 PDF 下载 +7. 后端:`GET /api/v1/public/health-score?brand=XXX` 无需认证,返回免费版诊断结果 +8. 结果缓存 24 小时(Redis),避免重复查询 + +**Patterns to follow:** +- 前端页面参考 `frontend/app/(dashboard)/onboarding/Step4HealthReport.tsx` 的报告展示模式 +- API 参考 `backend/app/api/diagnosis.py` 但无需认证 +- 分享功能参考 `backend/app/api/reports.py` 的 PDF 生成 + +**Test scenarios:** +- 未登录用户输入品牌名,30 秒内看到健康分报告 +- 报告包含综合评分、3 维度、3 竞品对比 +- 点击"查看详细建议"弹出注册弹窗 +- 注册后诊断结果自动关联到账户 +- 生成可分享链接,他人打开看到相同报告 +- 24 小时内重复查询同一品牌返回缓存结果 +- 契约测试:GET /api/v1/public/health-score?brand=XXX 返回 200 + 非零评分 +- 契约测试:无品牌名参数返回 422 + +**Verification:** 未登录状态下访问 `/health-score`,输入品牌名后看到完整报告且可分享。契约测试通过。 + +--- + +### U3. Onboarding 重设计与转化层 + +**Goal:** 重设计 Onboarding 流程以"查看 GEO 健康分"为第一步,嵌入付费墙触发点。 + +**Requirements:** R5 + +**Dependencies:** U1, U2 + +**Files:** +- `frontend/app/(dashboard)/onboarding/page.tsx` — 重写 Onboarding 主页面 +- `frontend/app/(dashboard)/onboarding/Step0HealthScore.tsx` — 新建第一步:健康分 +- `frontend/app/(dashboard)/onboarding/Step1BrandName.tsx` — 修改为简化版 +- `frontend/app/(dashboard)/onboarding/Step4HealthReport.tsx` — 修改为含升级提示 +- `frontend/app/(dashboard)/onboarding/Step5ActionSuggestions.tsx` — 修改为含执行按钮+付费墙 +- `frontend/components/subscription/UpgradePrompt.tsx` — 新建升级提示组件 + +**Approach:** +1. 新流程:Step0(输入品牌名看健康分) → Step1(注册/登录) → Step2(补充竞品信息) → Step3(查看详细报告+升级提示) → Step4(行动建议+执行按钮) +2. Step0 直接复用 U2 的健康分页面组件,嵌入 Onboarding 流程 +3. 在以下场景嵌入升级提示: + - 查看详细 6 维度分析时 → "升级 Pro 版查看完整分析" + - 点击执行优化建议时 → "升级 Starter 版开始优化" + - 查看效果归因报告时 → "升级 Pro 版追踪优化效果" +4. 升级提示组件 `UpgradePrompt` 统一管理,显示当前套餐限制和升级收益 + +**Patterns to follow:** +- 现有 Onboarding 步骤组件模式(Step1-5 的 props 接口) +- `useOnboardingData` hook 的数据管理模式 + +**Test scenarios:** +- 新用户进入 Onboarding,第一步看到健康分输入框 +- 输入品牌名后看到免费健康分报告 +- 注册后继续 Onboarding,看到详细报告含升级提示 +- 免费用户点击"执行优化"时看到升级弹窗 +- 已完成 Onboarding 的用户不再看到引导 + +**Verification:** 新用户从健康分开始 Onboarding,在关键节点看到升级提示。 + +--- + +### U4. 真实支付集成与功能限制 + +**Goal:** 接入微信支付+支付宝,实现功能限制中间件,让付费真正生效。 + +**Requirements:** R3, R11 + +**Dependencies:** U3 + +**Files:** +- `backend/app/services/payment/wechat_pay.py` — 新建微信支付服务 +- `backend/app/services/payment/alipay.py` — 新建支付宝服务 +- `backend/app/services/payment/base.py` — 新建支付基类 +- `backend/app/api/payments.py` — 新建支付 API 端点 +- `backend/app/middleware/subscription_enforcement.py` — 新建功能限制中间件 +- `backend/app/services/subscription.py` — 修改添加真实支付逻辑 +- `backend/app/api/subscriptions.py` — 修改添加支付 Webhook +- `backend/app/models/subscription.py` — 修改添加支付相关字段 +- `backend/tests/test_api/test_payment_contract.py` — 新建支付 API 契约测试 +- `backend/tests/test_services/test_payment.py` — 新建支付服务测试 + +**Approach:** +1. 支付基类定义统一接口:`create_order`, `verify_callback`, `refund` +2. 微信支付:使用官方 SDK,实现 JSAPI 支付(H5 页面)和 Native 支付(扫码) +3. 支付宝:使用官方 SDK,实现手机网站支付 +4. 支付 Webhook:`POST /api/v1/payments/callback/wechat` 和 `/alipay`,验证签名后更新订阅状态 +5. 功能限制中间件:在 API 路由层检查用户套餐,超限返回 403+升级提示 +6. 限制维度:查询数/品牌数/诊断频率/内容生成数/知识库数/数据保留期 +7. **先写契约测试**:定义支付 API 的期望行为(创建订单、回调处理、功能限制) + +**Execution note:** 先写支付 API 契约测试(红色),再实现支付服务(绿色)。 + +**Patterns to follow:** +- 支付回调参考 `backend/app/api/auth.py` 的 Webhook 处理模式 +- 功能限制参考 `backend/app/middleware/rate_limit.py` 的中间件模式 +- 订阅逻辑复用 `backend/app/services/subscription.py` 的 PLANS 配置 + +**Test scenarios:** +- 创建支付订单,返回支付链接/二维码 +- 支付成功回调,订阅状态更新为 active +- 免费用户超出查询限制,返回 403+升级提示 +- Pro 用户在限制内正常使用 +- 支付失败回调,订阅状态不变 +- 订阅到期,自动降级为免费版 +- 契约测试:POST /api/v1/payments/orders 返回 201 + order_id + pay_url +- 契约测试:POST /api/v1/payments/callback/wechat 验证签名后更新订阅 + +**Verification:** 完整支付流程:创建订单→支付→回调→订阅激活→功能解锁。契约测试通过。 + +--- + +### U5. AI 内容生成与分发集成 + +**Goal:** 在 ContentPipeline 前端添加 AI 生成阶段,集成知乎/头条 API 发布和微信半自动发布。 + +**Requirements:** R2, R6, R10 + +**Dependencies:** U1, U4 + +**Files:** +- `backend/app/services/content/ai_generator.py` — 新建 AI 内容生成服务 +- `backend/app/services/content/content_pipeline.py` — 修改添加 Stage 0 +- `backend/app/services/distribution/publishers/zhihu_publisher.py` — 新建知乎发布器 +- `backend/app/services/distribution/publishers/toutiao_publisher.py` — 新建头条发布器 +- `backend/app/services/distribution/publishers/wechat_publisher.py` — 新建微信半自动发布器 +- `backend/app/services/distribution/publish_engine.py` — 新建发布引擎 +- `backend/app/api/distribution.py` — 修改添加发布 API +- `backend/tests/test_api/test_content_contract.py` — 新建内容生成 API 契约测试 +- `backend/tests/test_services/test_ai_generator.py` — 新建 AI 生成测试 +- `backend/tests/test_services/test_publishers.py` — 新建发布器测试 + +**Approach:** +1. AI 内容生成服务: + - 输入:诊断结果(问题维度)+ RAG 知识库上下文 + 目标关键词 + 目标平台 + - 调用 LLM 生成初稿,基于诊断结果针对性优化 + - 输出传入现有 Pipeline(RuleValidator → SensitiveFilter → SEOOptimizer → HTMLGenerator) +2. 发布引擎: + - 知乎:通过知乎创作中心 API 发布文章(需 OAuth 授权) + - 头条:通过头条号 API 发布文章(需 API Key) + - 微信:生成格式化内容+复制引导,不直接调用 API +3. 发布流程:AI 生成 → 人工预览/编辑 → 确认发布 → 调用对应 Publisher → 记录发布状态 +4. 发布记录关联到 Content 模型,用于后续归因 + +**Patterns to follow:** +- AI 生成参考 `backend/app/services/content/content_generation_service.py` 的 LLM 调用模式 +- RAG 集成参考 `backend/app/services/knowledge/rag_service.py` +- 平台规则复用 `backend/app/services/distribution/platform_rules.py` +- Publisher 模式参考 `backend/app/services/ai_engine/` 的适配器模式 + +**Test scenarios:** +- 基于诊断结果+知识库上下文生成针对性优化内容 +- 生成内容通过 Pipeline 校验(RuleValidator + SensitiveFilter) +- 知乎 API 发布成功,返回文章 URL +- 头条 API 发布成功,返回文章 URL +- 微信半自动模式生成格式化内容+复制引导 +- 发布记录保存到 Content 模型,状态为 published +- 契约测试:POST /api/v1/content/generate 返回 201 + content_id + 生成内容 + +**Verification:** 从诊断建议到内容生成到发布的完整流程可走通。契约测试通过。 + +--- + +### U6. 效果归因系统与 ROI 报告 + +**Goal:** 建设效果归因系统,追踪已发布内容的 AI 引用变化,生成 ROI 报告证明付费价值。 + +**Requirements:** R7 + +**Dependencies:** U5 + +**Files:** +- `backend/app/services/attribution/attribution_engine.py` — 新建归因引擎 +- `backend/app/services/attribution/roi_calculator.py` — 新建 ROI 计算器 +- `backend/app/api/attribution.py` — 新建归因 API +- `backend/app/models/attribution_record.py` — 新建归因记录模型 +- `backend/app/services/monitoring/monitor_service.py` — 修改集成归因 +- `frontend/app/(dashboard)/dashboard/roi/page.tsx` — 新建 ROI 报告页面 +- `frontend/lib/api/attribution.ts` — 新建归因 API 客户端 +- `backend/tests/test_services/test_attribution.py` — 新建归因测试 + +**Approach:** +1. 归因引擎核心逻辑: + - 内容发布时记录:发布时间、目标关键词、发布前基线引用数据 + - 定期(每 3 天)检查目标关键词的引用变化 + - 归因判定:引用率上升 + 新引用内容与已发布内容语义相关 → 归因为该内容贡献 + - 归因窗口:2-4 周,窗口内持续追踪 +2. ROI 计算器:输入订阅费用+归因引用提升+行业平均引用价值,输出 GEO ROI 百分比 +3. A/B 对比报告:优化前诊断分数 vs 当前诊断分数,按维度对比 +4. 效果保障:如果 2-4 周内无任何引用提升,提供额外优化建议或延长服务 + +**Patterns to follow:** +- 监控模式参考 `backend/app/services/monitoring/monitor_service.py` +- 定期任务参考 `backend/app/services/detection/detection_scheduler.py` +- 报告生成参考 `backend/app/api/reports.py` + +**Test scenarios:** +- 发布内容后自动创建归因追踪记录 +- 3 天后检查引用变化,记录 delta +- 引用率上升且语义相关,归因为该内容贡献 +- 归因窗口结束后生成 ROI 报告 +- 优化前后 A/B 对比报告展示各维度变化 +- 无引用提升时提供额外优化建议 + +**Verification:** 发布内容→2-4 周后→归因报告展示引用提升和 ROI。 + +--- + +### U7. 邮件集成与 Dashboard 变现 UI + +**Goal:** 将 EmailService 接入业务系统(周报、续费提醒),Dashboard 添加订阅状态、用量进度、ROI 面板。 + +**Requirements:** R7 (邮件部分) + +**Dependencies:** U4, U6 + +**Files:** +- `backend/app/services/email_service.py` — 修改添加新模板 +- `backend/app/services/email/email_scheduler.py` — 新建邮件调度器 +- `backend/app/templates/geo_weekly_report.html` — 新建周报模板 +- `backend/app/templates/renewal_reminder.html` — 新建续费提醒模板 +- `backend/app/templates/trial_expiring.html` — 新建试用期到期模板 +- `frontend/app/(dashboard)/dashboard/page.tsx` — 修改添加变现 UI +- `frontend/components/subscription/SubscriptionStatus.tsx` — 新建订阅状态组件 +- `frontend/components/subscription/UsageProgress.tsx` — 新建用量进度组件 +- `frontend/components/dashboard/ROICard.tsx` — 新建 ROI 卡片组件 + +**Approach:** +1. 邮件模板新增:GEO 变化周报、续费提醒(到期前 7 天/3 天/1 天)、试用期到期提醒、欢迎邮件 +2. 邮件调度器:APScheduler 定时任务,每周一发送周报,每日检查续费提醒 +3. Dashboard 变现 UI:顶部订阅状态栏、用量进度条、ROI 卡片、功能锁定 UI +4. 配置真实 SMTP(支持 SendGrid/阿里云邮件推送) + +**Patterns to follow:** +- 邮件模板参考 `backend/app/services/email_service.py` 现有模板格式 +- Dashboard 组件参考 `frontend/app/(dashboard)/dashboard/page.tsx` 的 KPI 卡片模式 +- 订阅状态参考 `backend/app/services/subscription.py` 的 PLANS 配置 + +**Test scenarios:** +- 每周一自动发送 GEO 变化周报给注册用户 +- 订阅到期前 7 天发送续费提醒 +- 试用期到期前 3 天发送提醒 +- Dashboard 显示当前套餐和用量进度 +- 超出用量时显示升级提示 +- ROI 卡片展示引用率提升和归因内容 + +**Verification:** 注册用户收到周报邮件,Dashboard 显示订阅状态和 ROI 数据。 + +--- + +### U8. API 契约集成测试与测试基础设施 + +**Goal:** 编写跨步骤集成测试验证完整链路,完成后端测试基础设施统一。 + +**Requirements:** R12, R15, R16 + +**Dependencies:** U1, U2, U4, U5 + +**Files:** +- `backend/tests/test_integration/test_monetization_flow.py` — 新建变现闭环集成测试 +- `backend/tests/conftest.py` — 增强 fixture 体系 +- `backend/tests/fixtures/auth.py` — 新建认证 fixture 模块 +- `backend/tests/fixtures/database.py` — 新建数据库 fixture 模块 +- `backend/tests/fixtures/client.py` — 新建 httpx 客户端 fixture 模块 +- `backend/tests/fixtures/brands.py` — 新建品牌测试数据 fixture 模块 + +**Approach:** +1. 编写跨步骤集成测试:品牌创建→诊断→方案→内容生成→效果追踪的完整链路 +2. 验证 U1-U5 的 API 契约测试在集成场景下仍然通过 +3. 完成后端测试目录统一验证(原 QA 计划 U1) +4. 建立共享 fixture 体系:按领域拆分到 `backend/tests/fixtures/` 下 + - auth.py:override_get_current_user、auth_token、auth_headers + - database.py:async_engine、async_session(含自动 rollback) + - client.py:async_client(httpx AsyncClient with app) + - brands.py:预创建的测试品牌和竞品数据 +5. conftest.py 通过 pytest plugin 机制自动加载 fixtures/ 下所有模块 + +**Patterns to follow:** +- `backend/tests/conftest.py` 现有 fixture 模式 +- `backend/tests/test_integration/test_business_flow.py` 现有集成测试模式 + +**Test scenarios:** +- 品牌创建→诊断→方案→内容生成→效果追踪全链路通过 +- 数据在步骤间正确传递(品牌 ID、诊断结果 ID、内容 ID) +- 共享 fixture 可被其他测试复用 +- 数据库 fixture 在测试结束后自动 rollback +- 多个测试并行运行时 fixture 互不干扰 + +**Verification:** `cd backend && pytest tests/test_integration/test_monetization_flow.py --tb=short` 通过。`cd backend && pytest tests/ -q` 全部测试被发现且无新增失败。 + +--- + +### U9. E2E 烟雾测试 + +**Goal:** 编写 2 个关键路径的 Playwright E2E 烟雾测试,验证最核心的用户路径端到端可用。 + +**Requirements:** R13, R14 + +**Dependencies:** U2, U3 + +**Files:** +- `frontend/e2e/tests/health-score-smoke.spec.ts` — 新建获客路径烟雾测试 +- `frontend/e2e/tests/core-flow-smoke.spec.ts` — 新建核心流程烟雾测试 + +**Approach:** +1. 获客路径烟雾测试:访问健康分页面→输入品牌名→看到报告→点击注册 +2. 核心流程烟雾测试:登录→创建品牌→触发诊断→查看诊断结果 +3. 使用独立 setup(直接调用 API 创建测试数据),不依赖共享 fixture +4. 确保截图和视频录制配置正确(screenshot:on-failure, video:retain-on-failure) + +**Patterns to follow:** +- `frontend/e2e/tests/login.spec.ts` 现有 E2E 测试模式 +- `frontend/e2e/pages/login.page.ts` Page Object 模式 + +**Test scenarios:** +- 获客路径:未登录→访问 /health-score→输入品牌名→30 秒内看到报告→点击注册→注册弹窗出现 +- 核心流程:登录→创建品牌→品牌出现在 Dashboard→触发诊断→诊断结果非零 +- 测试失败时自动截图保存 + +**Verification:** `cd frontend && npx playwright test health-score-smoke core-flow-smoke` 通过。 + +## Scope Boundaries + +**Deferred for later:** + +- 网站爬取通道(诊断数据采集 V2) +- 全面 E2E 测试覆盖(原 QA 计划 U3/U4/U5 → 上线后逐步扩展) +- 性能基线测试(原 QA 计划 U7 → 上线后有真实负载后建立) +- CI 安全扫描(原 QA 计划 U8 → 上线后集成) +- Alembic 迁移验证(原 QA 计划 U8 → 上线后添加) +- 前端组件测试(原 QA 计划 U10 → 上线后扩展) +- 行业 GEO 基准数据积累 +- API 市场开放和第三方集成 +- 白标报告和专属客户成功经理 +- 多语言和国际 AI 平台支持 +- 知识库自动构建(从品牌官网自动提取 GEO 相关信息结构) +- 批量内容生成(一次生成多平台内容) +- 内容日历和排期执行系统 + +**Outside this product's identity:** + +- 传统 SEO 工具功能(关键词排名、外链分析等) +- 广告投放和营销自动化 +- 社交媒体管理 +- 基础设施级别的渗透测试 + +**Deferred to follow-up work:** + +- E2E 测试迁移到共享 fixture(U9 完成后,将独立 setup 替换为 U8 的共享 fixture) +- 性能基线设定阈值和告警(上线后根据数据设定合理阈值) +- 跨浏览器 E2E 扩展(稳定 Chromium 后再扩展) +- 知乎/头条 API OAuth 授权完整流程 +- 批量内容生成 + +## Open Questions + +- 微信公众号 OAuth 授权流程的具体实现方案需在 U5 实施时确认 +- 知乎和头条 API 的申请和审核周期可能影响 U5 的交付时间 +- 诊断自动数据采集的品牌官网解析可能遇到反爬机制,需准备降级方案 +- 真实 SMTP 配置需要用户提供邮件服务商账号信息 +- 效果归因的具体技术实现方案(基于引用检测还是基于评分变化) +- 付费版定价是否需要调整(当前 ¥199/599/1999 是否匹配价值感知) + +## Risks & Dependencies + +| Risk | Impact | Mitigation | +|------|--------|------------| +| 诊断数据采集不准确 | 用户不信任健康分,获客失效 | 先用 AI 平台查询数据(已有适配器),品牌官网解析作为增强 | +| 微信/知乎/头条 API 审核不通过 | 分发集成延期 | 微信降级为半自动,知乎/头条备选手动发布 | +| 支付 SDK 集成复杂度超预期 | 付费闭环延期 | 先实现支付宝(文档更友好),微信支付后续迭代 | +| GEO 效果归因窗口内无可见提升 | 用户认为产品无效 | 提供效果保障机制,延长归因窗口或提供额外优化 | +| 免费诊断 API 调用成本过高 | 运营成本失控 | 24 小时缓存+频率限制+深度限制 | +| LLM API Key 不可用影响内容生成 | 核心付费功能不可用 | 支持多 LLM Provider 降级,至少保证一个可用 | + +## Sources & Research + +- Origin requirements: `docs/brainstorms/2026-05-31-geo-next-phase-core-flow-repair-requirements.md` +- Secondary origin: `docs/brainstorms/2026-05-31-geo-platform-monetization-closed-loop-requirements.md` +- QA plan (downscoped): `docs/plans/2026-05-31-002-test-quality-assurance-system-plan.md` +- Existing diagnosis system: `backend/app/services/diagnosis/geo_diagnosis.py` +- Diagnosis API root cause: `backend/app/api/diagnosis.py` line 75 (`GEODiagnosisInput()`) +- Existing content pipeline: `backend/app/services/content/content_pipeline.py` +- Existing distribution rules: `backend/app/services/distribution/platform_rules.py` +- Existing monitoring: `backend/app/services/monitoring/monitor_service.py` +- Existing subscription: `backend/app/services/subscription.py` +- Existing email: `backend/app/services/email_service.py` +- Existing E2E tests: `frontend/e2e/tests/` diff --git a/frontend/__tests__/hooks/use-compare-data.test.ts b/frontend/__tests__/hooks/use-compare-data.test.ts new file mode 100644 index 0000000..e7168b5 --- /dev/null +++ b/frontend/__tests__/hooks/use-compare-data.test.ts @@ -0,0 +1,192 @@ +/** + * useCompareData Hook 单元测试 + */ + +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { renderHook, waitFor, act } from "@testing-library/react"; +import { SWRConfig } from "swr"; +import type { ReactNode } from "react"; + +vi.mock("@/lib/api/client", () => ({ + fetchWithAuth: vi.fn(), +})); + +import { fetchWithAuth } from "@/lib/api/client"; +const mockFetchWithAuth = vi.mocked(fetchWithAuth); + +vi.mock("next-auth/react", () => ({ + useSession: vi.fn(() => ({ + data: { accessToken: "test-token" }, + status: "authenticated", + })), +})); + +import { useCompareData } from "@/lib/hooks/use-compare-data"; +import type { BrandListResponse, CompareResponse } from "@/types/brand"; + +const noRetryOptions = { shouldRetryOnError: false }; + +/** 使用独立 SWR 缓存的 wrapper,避免测试间缓存冲突 */ +function createWrapper() { + return ({ children }: { children: ReactNode }) => + SWRConfig({ + value: { provider: () => new Map() }, + children, + }); +} + +describe("useCompareData", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + const mockBrandsResponse: BrandListResponse = { + items: [ + { + id: "brand-1", + name: "测试品牌", + aliases: [], + platforms: ["wenxin"], + frequency: "weekly", + status: "active", + score: 80, + last_queried_at: null, + next_query_at: null, + created_at: "2024-01-01", + }, + ], + total: 1, + }; + + const mockCompareResponse: CompareResponse = { + brand_id: "brand-1", + brand_name: "测试品牌", + items: [ + { + entity_id: "brand-1", + entity_name: "测试品牌", + entity_type: "brand", + mention_rate_score: 70, + sov_score: 65, + quality_score: 80, + overall_score: 75, + citation_count: 100, + dimensions: [], + overall_trend: "up", + overall_trend_value: 5, + }, + ], + radar_data: [], + }; + + it("应返回品牌列表数据", async () => { + mockFetchWithAuth.mockImplementation((url: string) => { + if (url.includes("/api/v1/brands")) return Promise.resolve(mockBrandsResponse); + return Promise.resolve({}); + }); + + const { result } = renderHook( + () => useCompareData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.brands).toEqual(mockBrandsResponse.items); + }); + }); + + it("selectedBrandId 为空时应暂停对比数据请求", () => { + mockFetchWithAuth.mockResolvedValue(mockBrandsResponse); + + const { result } = renderHook( + () => useCompareData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + expect(result.current.compareData).toBeUndefined(); + }); + + it("selectedBrandId 存在时应获取对比数据", async () => { + mockFetchWithAuth.mockImplementation((url: string) => { + if (url.includes("/compare")) return Promise.resolve(mockCompareResponse); + return Promise.resolve(mockBrandsResponse); + }); + + const { result } = renderHook( + () => useCompareData({ initialBrandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.compareData).toEqual(mockCompareResponse); + }); + }); + + it("应正确暴露 loading 状态", async () => { + mockFetchWithAuth.mockImplementation((url: string) => { + if (url.includes("/compare")) return Promise.resolve(mockCompareResponse); + return Promise.resolve(mockBrandsResponse); + }); + + const { result } = renderHook( + () => useCompareData({ initialBrandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + }); + + it("应正确暴露 error 状态", async () => { + mockFetchWithAuth.mockRejectedValue(new Error("网络错误")); + + const { result } = renderHook( + () => useCompareData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.error).toBeDefined(); + expect(result.current.error?.message).toBe("网络错误"); + }); + }); + + it("setSelectedBrandId 应更新选中品牌", async () => { + mockFetchWithAuth.mockImplementation((url: string) => { + if (url.includes("/compare")) return Promise.resolve(mockCompareResponse); + return Promise.resolve(mockBrandsResponse); + }); + + const { result } = renderHook( + () => useCompareData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.brands.length).toBeGreaterThan(0); + }); + + act(() => { + result.current.setSelectedBrandId("brand-1"); + }); + + expect(result.current.selectedBrandId).toBe("brand-1"); + }); + + it("应提供 refresh 方法刷新数据", async () => { + mockFetchWithAuth.mockResolvedValue(mockBrandsResponse); + + const { result } = renderHook( + () => useCompareData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + + expect(typeof result.current.refreshBrands).toBe("function"); + expect(typeof result.current.refreshCompare).toBe("function"); + }); +}); diff --git a/frontend/__tests__/hooks/use-content-data.test.ts b/frontend/__tests__/hooks/use-content-data.test.ts new file mode 100644 index 0000000..1ae6b1c --- /dev/null +++ b/frontend/__tests__/hooks/use-content-data.test.ts @@ -0,0 +1,141 @@ +/** + * useContentData Hook 单元测试 + */ + +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { renderHook, waitFor } from "@testing-library/react"; +import { SWRConfig } from "swr"; +import type { ReactNode } from "react"; + +vi.mock("@/lib/api/client", () => ({ + fetchWithAuth: vi.fn(), +})); + +import { fetchWithAuth } from "@/lib/api/client"; +const mockFetchWithAuth = vi.mocked(fetchWithAuth); + +import { useContentData } from "@/lib/hooks/use-content-data"; +import type { Content } from "@/lib/api/contents"; +import type { KnowledgeBase } from "@/lib/api/knowledge"; + +const noRetryOptions = { shouldRetryOnError: false }; + +function createWrapper() { + return ({ children }: { children: ReactNode }) => + SWRConfig({ value: { provider: () => new Map() }, children }); +} + +describe("useContentData", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + const mockContents: Content[] = [ + { + id: "content-1", + title: "测试内容", + body: "内容正文", + content_type: "article", + status: "draft", + author_id: "user-1", + tags: ["test"], + created_at: "2024-01-01", + updated_at: "2024-01-01", + }, + ]; + + const mockKnowledgeBases: KnowledgeBase[] = [ + { + id: "kb-1", + name: "测试知识库", + type: "enterprise", + document_count: 5, + status: "ready", + created_at: "2024-01-01", + }, + ]; + + it("应同时获取内容列表和知识库列表", async () => { + mockFetchWithAuth.mockImplementation((url: string) => { + if (url.includes("/api/v1/contents/")) return Promise.resolve(mockContents); + if (url.includes("/api/v1/knowledge/bases")) return Promise.resolve(mockKnowledgeBases); + return Promise.resolve([]); + }); + + const { result } = renderHook( + () => useContentData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + + expect(result.current.contents).toEqual(mockContents); + expect(result.current.knowledgeBases).toEqual(mockKnowledgeBases); + }); + + it("应正确暴露 isLoading 状态", async () => { + mockFetchWithAuth.mockImplementation((url: string) => { + if (url.includes("/api/v1/contents/")) return Promise.resolve(mockContents); + if (url.includes("/api/v1/knowledge/bases")) return Promise.resolve(mockKnowledgeBases); + return Promise.resolve([]); + }); + + const { result } = renderHook( + () => useContentData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + }); + + it("应正确暴露 error 状态", async () => { + mockFetchWithAuth.mockRejectedValue(new Error("数据加载失败")); + + const { result } = renderHook( + () => useContentData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.error).toBeDefined(); + expect(result.current.error?.message).toBe("数据加载失败"); + }); + }); + + it("应提供 refresh 方法", async () => { + mockFetchWithAuth.mockImplementation((url: string) => { + if (url.includes("/api/v1/contents/")) return Promise.resolve(mockContents); + if (url.includes("/api/v1/knowledge/bases")) return Promise.resolve(mockKnowledgeBases); + return Promise.resolve([]); + }); + + const { result } = renderHook( + () => useContentData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + + expect(typeof result.current.refreshContents).toBe("function"); + expect(typeof result.current.refreshKnowledgeBases).toBe("function"); + }); + + it("数据加载失败时 error 应包含错误信息", async () => { + mockFetchWithAuth.mockRejectedValue(new Error("网络异常")); + + const { result } = renderHook( + () => useContentData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.error?.message).toBe("网络异常"); + }); + }); +}); diff --git a/frontend/__tests__/hooks/use-onboarding-data.test.ts b/frontend/__tests__/hooks/use-onboarding-data.test.ts new file mode 100644 index 0000000..138f146 --- /dev/null +++ b/frontend/__tests__/hooks/use-onboarding-data.test.ts @@ -0,0 +1,174 @@ +/** + * useOnboardingData Hook 单元测试 + */ + +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { renderHook, waitFor, act } from "@testing-library/react"; +import { SWRConfig } from "swr"; +import type { ReactNode } from "react"; + +vi.mock("@/lib/api/client", () => ({ + fetchWithAuth: vi.fn(), +})); + +import { fetchWithAuth } from "@/lib/api/client"; +const mockFetchWithAuth = vi.mocked(fetchWithAuth); + +vi.mock("next-auth/react", () => ({ + useSession: vi.fn(() => ({ + data: { accessToken: "test-token" }, + status: "authenticated", + })), +})); + +import { useOnboardingData } from "@/lib/hooks/use-onboarding-data"; + +const noRetryOptions = { shouldRetryOnError: false }; + +function createWrapper() { + return ({ children }: { children: ReactNode }) => + SWRConfig({ value: { provider: () => new Map() }, children }); +} + +describe("useOnboardingData", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("应检查引导状态", async () => { + mockFetchWithAuth.mockResolvedValue({ completed: false }); + + const { result } = renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.onboardingStatus).toEqual({ completed: false }); + }); + }); + + it("引导已完成时应标记 isCompleted", async () => { + mockFetchWithAuth.mockResolvedValue({ completed: true }); + + const { result } = renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.onboardingStatus).toEqual({ completed: true }); + expect(result.current.isCompleted).toBe(true); + }); + }); + + it("应正确暴露 isLoading 状态", async () => { + mockFetchWithAuth.mockResolvedValue({ completed: false }); + + const { result } = renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + }); + + it("应正确暴露 error 状态", async () => { + mockFetchWithAuth.mockRejectedValue(new Error("检查引导状态失败")); + + const { result } = renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.error).toBeDefined(); + expect(result.current.error?.message).toBe("检查引导状态失败"); + }); + }); + + it("应提供 createBrand mutation 方法", () => { + mockFetchWithAuth.mockResolvedValue({ completed: false }); + + const { result } = renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + expect(typeof result.current.createBrand).toBe("function"); + }); + + it("createBrand 成功应返回 brand_id", async () => { + mockFetchWithAuth.mockImplementation((url: string, options?: RequestInit) => { + if (options?.method === "POST") return Promise.resolve({ brand_id: "new-brand-1" }); + return Promise.resolve({ completed: false }); + }); + + const { result } = renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + + let brandId: string | null = null; + await act(async () => { + brandId = await result.current.createBrand({ + name: "新品牌", + competitors: [], + platforms: ["wenxin"], + frequency: "weekly", + }); + }); + + expect(brandId).toBe("new-brand-1"); + }); + + it("createBrand 失败应返回 null 并设置 mutationError", async () => { + mockFetchWithAuth.mockImplementation((url: string, options?: RequestInit) => { + if (options?.method === "POST") return Promise.reject(new Error("创建品牌失败")); + return Promise.resolve({ completed: false }); + }); + + const { result } = renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + + let brandId: string | null = null; + await act(async () => { + brandId = await result.current.createBrand({ + name: "新品牌", + competitors: [], + platforms: ["wenxin"], + frequency: "weekly", + }); + }); + + expect(brandId).toBeNull(); + expect(result.current.mutationError).toBeInstanceOf(Error); + }); + + it("应提供 refresh 方法", async () => { + mockFetchWithAuth.mockResolvedValue({ completed: false }); + + const { result } = renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + + expect(typeof result.current.refresh).toBe("function"); + }); +}); diff --git a/frontend/__tests__/hooks/use-suggestions-data.test.ts b/frontend/__tests__/hooks/use-suggestions-data.test.ts new file mode 100644 index 0000000..b589d26 --- /dev/null +++ b/frontend/__tests__/hooks/use-suggestions-data.test.ts @@ -0,0 +1,167 @@ +/** + * useSuggestionsData Hook 单元测试 + */ + +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { renderHook, waitFor } from "@testing-library/react"; +import { SWRConfig } from "swr"; +import type { ReactNode } from "react"; + +vi.mock("@/lib/api/client", () => ({ + fetchWithAuth: vi.fn(), +})); + +import { fetchWithAuth } from "@/lib/api/client"; +const mockFetchWithAuth = vi.mocked(fetchWithAuth); + +vi.mock("next-auth/react", () => ({ + useSession: vi.fn(() => ({ + data: { accessToken: "test-token" }, + status: "authenticated", + })), +})); + +import { useSuggestionsData } from "@/lib/hooks/use-suggestions-data"; +import type { SuggestionListResponse } from "@/types/suggestion"; + +const noRetryOptions = { shouldRetryOnError: false }; + +function createWrapper() { + return ({ children }: { children: ReactNode }) => + SWRConfig({ value: { provider: () => new Map() }, children }); +} + +describe("useSuggestionsData", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + const mockSuggestionsResponse: SuggestionListResponse = { + suggestions: [ + { + id: "sug-1", + brand_id: "brand-1", + type: "content_optimization", + priority: "high", + title: "优化内容", + description: "建议优化内容", + action: null, + expected_impact: null, + difficulty: "easy", + status: "pending", + generated_at: "2024-01-01", + updated_at: "2024-01-01", + batch_id: "batch-1", + source: "rule", + }, + ], + total: 1, + }; + + it("brandId 为空时应暂停建议请求", () => { + mockFetchWithAuth.mockResolvedValue({}); + + const { result } = renderHook(() => useSuggestionsData(), { wrapper: createWrapper() }); + + expect(result.current.suggestions).toBeUndefined(); + }); + + it("brandId 存在时应获取建议列表", async () => { + mockFetchWithAuth.mockResolvedValue(mockSuggestionsResponse); + + const { result } = renderHook( + () => useSuggestionsData({ brandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.suggestions).toEqual(mockSuggestionsResponse.suggestions); + }); + }); + + it("应支持筛选参数", async () => { + mockFetchWithAuth.mockResolvedValue(mockSuggestionsResponse); + + const { result } = renderHook( + () => useSuggestionsData({ + brandId: "brand-1", + filters: { type: "content_optimization", priority: "high", status: "pending" }, + swrOptions: noRetryOptions, + }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.suggestions).toEqual(mockSuggestionsResponse.suggestions); + }); + + const suggestionCall = mockFetchWithAuth.mock.calls.find( + (call) => typeof call[0] === "string" && call[0].includes("/suggestions") + ); + expect(suggestionCall).toBeDefined(); + const calledUrl = suggestionCall![0] as string; + expect(calledUrl).toContain("type=content_optimization"); + expect(calledUrl).toContain("priority=high"); + expect(calledUrl).toContain("status=pending"); + }); + + it("应正确暴露 isLoading 状态", async () => { + mockFetchWithAuth.mockResolvedValue(mockSuggestionsResponse); + + const { result } = renderHook( + () => useSuggestionsData({ brandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + }); + + it("应正确暴露 error 状态", async () => { + mockFetchWithAuth.mockRejectedValue(new Error("加载建议失败")); + + const { result } = renderHook( + () => useSuggestionsData({ brandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.error).toBeDefined(); + expect(result.current.error?.message).toBe("加载建议失败"); + }); + }); + + it("应提供 regenerate 方法", () => { + mockFetchWithAuth.mockResolvedValue(mockSuggestionsResponse); + + const { result } = renderHook( + () => useSuggestionsData({ brandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + expect(typeof result.current.regenerate).toBe("function"); + }); + + it("应提供 updateStatus 方法", () => { + mockFetchWithAuth.mockResolvedValue(mockSuggestionsResponse); + + const { result } = renderHook( + () => useSuggestionsData({ brandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + expect(typeof result.current.updateStatus).toBe("function"); + }); + + it("应提供 refresh 方法", () => { + mockFetchWithAuth.mockResolvedValue(mockSuggestionsResponse); + + const { result } = renderHook( + () => useSuggestionsData({ brandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + expect(typeof result.current.refresh).toBe("function"); + }); +}); diff --git a/frontend/__tests__/lib/action-suggestions.test.ts b/frontend/__tests__/lib/action-suggestions.test.ts new file mode 100644 index 0000000..04de5c3 --- /dev/null +++ b/frontend/__tests__/lib/action-suggestions.test.ts @@ -0,0 +1,157 @@ +/** + * 行动建议系统统一测试 + * 验证 ActionSuggestion 和 NextAction 使用统一的 ActionItem 类型 + */ + +import { describe, it, expect } from "vitest"; +import { generateActionSuggestions } from "@/lib/dashboard-health"; +import { generateNextActions } from "@/lib/next-action"; +import type { ActionItem, ActionPriority } from "@/types/suggestion"; +import type { ActionSuggestion } from "@/types/dashboard-health"; +import type { NextAction } from "@/types/next-action"; + +describe("ActionItem 统一类型", () => { + it("ActionItem 类型应包含所有必要字段", async () => { + const suggestionModule = await import("@/types/suggestion"); + // 验证类型可以被导入(编译时检查) + expect(suggestionModule).toBeDefined(); + }); + + it("ActionPriority 应为 'primary' | 'secondary' | 'optional'", async () => { + const suggestionModule = await import("@/types/suggestion"); + // ActionPriority 是类型,运行时不可直接检查,但我们可以验证模块导出 + expect(suggestionModule).toBeDefined(); + }); +}); + +describe("ActionSuggestion 兼容 ActionItem", () => { + it("generateActionSuggestions 返回的对象应兼容 ActionItem 结构", () => { + const stats = { + platformScores: [ + { + platform: "kimi", + score: 30, + competitor_score: 60, + competitor_name: "竞品A", + }, + ], + overallScore: 30, + hasQueries: true, + }; + + const suggestions = generateActionSuggestions(stats); + + expect(suggestions.length).toBeGreaterThan(0); + + for (const suggestion of suggestions) { + // ActionItem 必需字段 + expect(suggestion).toHaveProperty("id"); + expect(suggestion).toHaveProperty("title"); + expect(suggestion).toHaveProperty("description"); + expect(suggestion).toHaveProperty("icon"); + expect(suggestion).toHaveProperty("href"); + + // type 字段应兼容 ActionPriority + const validPriorities: ActionPriority[] = [ + "primary", + "secondary", + "optional", + ]; + expect(validPriorities).toContain(suggestion.type); + } + }); +}); + +describe("NextAction 兼容 ActionItem", () => { + it("generateNextActions 返回的对象应兼容 ActionItem 结构", () => { + const context = { + hasData: true, + hasBrands: true, + brandCount: 1, + overallScore: 50, + scoreChange: -5, + competitorCount: 2, + hasQueryHistory: true, + currentPage: "dashboard" as const, + }; + + const actions = generateNextActions(context); + + expect(actions.length).toBeGreaterThan(0); + + for (const action of actions) { + // ActionItem 必需字段 + expect(action).toHaveProperty("id"); + expect(action).toHaveProperty("title"); + expect(action).toHaveProperty("description"); + expect(action).toHaveProperty("icon"); + + // NextAction 应有 href 字段(映射自 actionUrl) + expect(action).toHaveProperty("href"); + + // priority 字段应兼容 ActionPriority + const validPriorities: ActionPriority[] = [ + "primary", + "secondary", + "optional", + ]; + expect(validPriorities).toContain(action.priority); + } + }); + + it("NextAction 的 href 应与 actionUrl 一致", () => { + const context = { + hasData: false, + hasBrands: false, + brandCount: 0, + overallScore: 0, + scoreChange: 0, + competitorCount: 0, + hasQueryHistory: false, + currentPage: "dashboard" as const, + }; + + const actions = generateNextActions(context); + + for (const action of actions) { + // href 应该存在且非空 + expect(action.href).toBeTruthy(); + expect(typeof action.href).toBe("string"); + } + }); +}); + +describe("两个系统生成结果的一致性", () => { + it("ActionSuggestion 和 NextAction 都能生成有效的行动建议", () => { + // dashboard-health 的 generateActionSuggestions + const dashboardSuggestions = generateActionSuggestions({ + platformScores: [ + { platform: "kimi", score: 30, competitor_score: 60 }, + ], + overallScore: 30, + hasQueries: true, + }); + + // next-action 的 generateNextActions + const nextActions = generateNextActions({ + hasData: true, + hasBrands: true, + brandCount: 1, + overallScore: 30, + scoreChange: -5, + competitorCount: 1, + hasQueryHistory: true, + currentPage: "dashboard", + }); + + // 两个系统都应该能生成建议 + expect(dashboardSuggestions.length).toBeGreaterThan(0); + expect(nextActions.length).toBeGreaterThan(0); + + // 所有建议都应有唯一 id + const dashboardIds = dashboardSuggestions.map((s) => s.id); + const nextActionIds = nextActions.map((a) => a.id); + expect(new Set(dashboardIds).size).toBe(dashboardIds.length); + expect(new Set(nextActionIds).size).toBe(nextActionIds.length); + }); +}); diff --git a/frontend/__tests__/lib/platforms.test.ts b/frontend/__tests__/lib/platforms.test.ts new file mode 100644 index 0000000..f97db4c --- /dev/null +++ b/frontend/__tests__/lib/platforms.test.ts @@ -0,0 +1,71 @@ +/** + * 平台映射统一测试 + * 验证 platforms.ts 作为唯一真实来源 + */ + +import { describe, it, expect } from "vitest"; +import { PLATFORM_MAP, PLATFORMS } from "@/lib/platforms"; + +describe("PLATFORM_MAP 统一平台映射", () => { + const EXPECTED_PLATFORMS = [ + "wenxin", + "kimi", + "tongyi", + "baidu_ai", + "yuanbao", + "qingyan", + "doubao", + "tiangong", + "xinghuo", + ] as const; + + it("PLATFORM_MAP 应包含全部9个平台", () => { + const keys = Object.keys(PLATFORM_MAP); + expect(keys).toHaveLength(9); + }); + + it("PLATFORM_MAP 应包含所有预期的平台键", () => { + for (const platform of EXPECTED_PLATFORMS) { + expect(PLATFORM_MAP).toHaveProperty(platform); + } + }); + + it("PLATFORM_MAP 必须包含 baidu_ai 和 yuanbao(dashboard-health.ts 的 PLATFORM_LABELS 缺失的)", () => { + expect(PLATFORM_MAP).toHaveProperty("baidu_ai"); + expect(PLATFORM_MAP).toHaveProperty("yuanbao"); + expect(PLATFORM_MAP["baidu_ai"]).toBe("百度AI搜索"); + expect(PLATFORM_MAP["yuanbao"]).toBe("腾讯元宝"); + }); + + it("每个平台的 label 应为非空字符串", () => { + for (const [key, label] of Object.entries(PLATFORM_MAP)) { + expect(label).toBeTruthy(); + expect(typeof label).toBe("string"); + expect(label.length).toBeGreaterThan(0); + } + }); + + it("PLATFORMS 数组应与 PLATFORM_MAP 键一致", () => { + const platformKeys = PLATFORMS.map((p) => p.key); + const mapKeys = Object.keys(PLATFORM_MAP); + expect(platformKeys.sort()).toEqual(mapKeys.sort()); + }); + + it("PLATFORMS 数组中每个条目的 label 应与 PLATFORM_MAP 一致", () => { + for (const platform of PLATFORMS) { + expect(platform.label).toBe(PLATFORM_MAP[platform.key]); + } + }); +}); + +describe("dashboard-health.ts 不再定义自己的 PLATFORM_LABELS", () => { + it("dashboard-health.ts 不应导出 PLATFORM_LABELS", async () => { + const dashboardHealthModule = await import("@/types/dashboard-health"); + expect("PLATFORM_LABELS" in dashboardHealthModule).toBe(false); + }); + + it("dashboard-health.ts 不应导出 PLATFORM_ICONS", async () => { + const dashboardHealthModule = await import("@/types/dashboard-health"); + expect("PLATFORM_ICONS" in dashboardHealthModule).toBe(false); + }); +}); diff --git a/frontend/__tests__/types/health-level.test.ts b/frontend/__tests__/types/health-level.test.ts new file mode 100644 index 0000000..d2fabed --- /dev/null +++ b/frontend/__tests__/types/health-level.test.ts @@ -0,0 +1,133 @@ +/** + * HealthLevel 类型和配置的统一测试 + * 验证 dashboard-health.ts 作为唯一真实来源 + */ + +import { describe, it, expect } from "vitest"; +import { + HealthLevel, + HEALTH_LEVEL_CONFIG, +} from "@/types/dashboard-health"; +import { getHealthLevel } from "@/lib/dashboard-health"; + +describe("HealthLevel 类型统一", () => { + it("HealthLevel 应包含正确的4个值:excellent, good, pass, danger", () => { + const expectedLevels: HealthLevel[] = [ + "excellent", + "good", + "pass", + "danger", + ]; + // 验证类型推断正确 - 如果 HealthLevel 包含 "fair" 则编译失败 + const actualLevels = Object.keys(HEALTH_LEVEL_CONFIG) as HealthLevel[]; + expect(actualLevels.sort()).toEqual(expectedLevels.sort()); + }); + + it("HEALTH_LEVEL_CONFIG 应包含所有4个等级的配置", () => { + const keys = Object.keys(HEALTH_LEVEL_CONFIG); + expect(keys).toContain("excellent"); + expect(keys).toContain("good"); + expect(keys).toContain("pass"); + expect(keys).toContain("danger"); + expect(keys).toHaveLength(4); + }); + + it("HEALTH_LEVEL_CONFIG 不应包含 'fair' 键", () => { + const keys = Object.keys(HEALTH_LEVEL_CONFIG); + expect(keys).not.toContain("fair"); + }); + + it("每个 HEALTH_LEVEL_CONFIG 条目应包含完整的配置字段", () => { + for (const [key, config] of Object.entries(HEALTH_LEVEL_CONFIG)) { + expect(config).toHaveProperty("level", key); + expect(config).toHaveProperty("label"); + expect(config).toHaveProperty("icon"); + expect(config).toHaveProperty("color"); + expect(config).toHaveProperty("minScore"); + expect(config).toHaveProperty("maxScore"); + expect(config.color).toHaveProperty("bg"); + expect(config.color).toHaveProperty("text"); + expect(config.color).toHaveProperty("border"); + expect(config.label.length).toBeGreaterThan(0); + } + }); +}); + +describe("getHealthLevel 函数", () => { + it("评分 85 应返回 'excellent'", () => { + expect(getHealthLevel(85)).toBe("excellent"); + }); + + it("评分 65 应返回 'good'", () => { + expect(getHealthLevel(65)).toBe("good"); + }); + + it("评分 45 应返回 'pass'", () => { + expect(getHealthLevel(45)).toBe("pass"); + }); + + it("评分 20 应返回 'danger'", () => { + expect(getHealthLevel(20)).toBe("danger"); + }); + + it("getHealthLevel 不应返回 'fair'", () => { + // 遍历所有可能的分数,确保永远不会返回 "fair" + for (let score = 0; score <= 100; score++) { + const result = getHealthLevel(score); + expect(result).not.toBe("fair"); + } + }); + + it("边界值测试:80 为 excellent 下界", () => { + expect(getHealthLevel(80)).toBe("excellent"); + }); + + it("边界值测试:79 为 good 上界", () => { + expect(getHealthLevel(79)).toBe("good"); + }); + + it("边界值测试:60 为 good 下界", () => { + expect(getHealthLevel(60)).toBe("good"); + }); + + it("边界值测试:59 为 pass 上界", () => { + expect(getHealthLevel(59)).toBe("pass"); + }); + + it("边界值测试:40 为 pass 下界", () => { + expect(getHealthLevel(40)).toBe("pass"); + }); + + it("边界值测试:39 为 danger 上界", () => { + expect(getHealthLevel(39)).toBe("danger"); + }); + + it("边界值测试:0 为 danger", () => { + expect(getHealthLevel(0)).toBe("danger"); + }); + + it("边界值测试:100 为 excellent", () => { + expect(getHealthLevel(100)).toBe("excellent"); + }); +}); + +describe("onboarding.ts 不再定义自己的 HealthLevel", () => { + it("onboarding.ts 应从 dashboard-health 导入 HealthLevel 而非自己定义", async () => { + // 读取 onboarding.ts 源码,验证不再有独立的 HealthLevel 定义 + const onboardingModule = await import("@/types/onboarding"); + // onboarding 导出的 HealthLevel 应该与 dashboard-health 的完全一致 + const dashboardModule = await import("@/types/dashboard-health"); + + // 验证 onboarding 导出的 getHealthLevel 返回 "pass" 而非 "fair" + if ("getHealthLevel" in onboardingModule) { + const onboardingGetHealthLevel = onboardingModule.getHealthLevel as ( + score: number, + ) => string; + expect(onboardingGetHealthLevel(45)).toBe("pass"); + expect(onboardingGetHealthLevel(45)).not.toBe("fair"); + } + + // 验证 onboarding 导出的 HEALTH_LEVELS 是 dashboard-health 的 HEALTH_LEVEL_CONFIG 的别名 + expect(onboardingModule.HEALTH_LEVELS).toBe(dashboardModule.HEALTH_LEVEL_CONFIG); + }); +}); diff --git a/frontend/app/(dashboard)/brands/[id]/page.tsx b/frontend/app/(dashboard)/brands/[id]/page.tsx index 079a5c5..cc36752 100644 --- a/frontend/app/(dashboard)/brands/[id]/page.tsx +++ b/frontend/app/(dashboard)/brands/[id]/page.tsx @@ -16,7 +16,13 @@ import type { CreateCompetitorRequest, CompetitorRecommendationItem, CompetitorRecommendationResponse, + BrandScoreV2Response, } from "@/types/brand"; +import { + HEALTH_LEVEL_CONFIG, + DIMENSION_COLORS, +} from "@/types/dashboard-health"; +import { getHealthLevelClassName } from "@/lib/dashboard-health"; import { ArrowLeft, Star, @@ -75,6 +81,9 @@ export default function BrandDetailPage() { const [loadingRecommendations, setLoadingRecommendations] = useState(false); const [showRecommendations, setShowRecommendations] = useState(false); + const [scoreV2, setScoreV2] = useState暂无评分数据,请先执行查询
+GEO和SEO是AI营销时代的共生体
++ GEO和SEO是AI营销时代的共生体 +
+ 解锁更多功能 +
++ {feature.label} +
+{feature.desc}
+功能开发中
-Agent状态监控即将上线
++ 功能开发中 +
++ Agent状态监控即将上线 +
+ 追踪内容发布效果,计算投资回报率 +
++ 追踪内容发布效果,计算投资回报率 +
++ 追踪内容发布效果,计算投资回报率 +
++ 追踪内容发布效果,计算投资回报率 +
+ROI
+= 0 ? "text-emerald-600" : "text-red-600" + )}> + {roi.roi_percentage >= 0 ? "+" : ""}{roi.roi_percentage}% +
+价值产出
++ ¥{roi.value_generated.toLocaleString()} +
+订阅成本
++ ¥{roi.subscription_cost.toLocaleString()} +
++ 当前套餐: {roi.current_plan} +
+分数提升
+= 0 ? "text-emerald-600" : "text-red-600" + )}> + {roi.total_score_delta >= 0 ? "+" : ""}{roi.total_score_delta} +
++ 盈亏平衡: {roi.break_even_delta} +
+制定GEO优化策略、关键词规划与目标设定
+{action.reason}
+ + {action.estimated_impact && ( +此功能正在开发中,敬请期待。
-{title}
+ {tasks.length > 0 && ( ++ 添加竞品后,AI将生成更精准的对比分析和优化建议,帮助您了解与竞品的差距并制定针对性策略。 +
+ +制定GEO优化策略、关键词规划与目标设定
+制定GEO优化策略、关键词规划与目标设定
++ 基于诊断数据,AI将为您制定个性化GEO优化方案 +
+ +制定GEO优化策略、关键词规划与目标设定
+{rec.title}
++ {rec.description} +
+{suggestion.description}
+ {actionButton && ( +- 基于您的品牌 “{brandName}” 的表现, 我们为您准备了以下优化建议 + 基于您的品牌 “{brandName}” 的表现,我们为您准备了以下优化建议
{error}
} +新用户引导
+ 系统正在分析品牌在AI搜索中的表现,请稍候 +
+{message}
+以下是该品牌在AI搜索中的综合表现
+数据来自缓存
+ )} +{rec.title}
+{rec.description}
++ 注册 GEO 平台账户,即可解锁完整健康报告、详细修复建议和持续监控功能 +
++ 免费检测您的品牌在AI搜索引擎中的表现,了解品牌可见度、推荐排名和情感倾向 +
++ 已有账户?{" "} + + 登录 + +
+投资回报率 (ROI)
++ {roiPercentage.toFixed(1)}% +
+ +创造价值
++ {formatCurrency(valueGenerated)} +
+订阅成本
++ {formatCurrency(subscriptionCost)} +
++ {description || `升级Pro版即可解锁${feature}功能`} +
+ +{feat.title}
++ {feat.description} +
++ {isCritical ? "额度即将用尽" : "额度使用已超过80%"} +
+ )} +