diff --git a/.cursor/rules/codegraph.mdc b/.cursor/rules/codegraph.mdc new file mode 100644 index 0000000..3f23cf6 --- /dev/null +++ b/.cursor/rules/codegraph.mdc @@ -0,0 +1,38 @@ +--- +description: CodeGraph MCP usage guide — when to use which tool +alwaysApply: true +--- + +## CodeGraph + +This project has a CodeGraph MCP server (`codegraph_*` tools) configured. CodeGraph is a tree-sitter-parsed knowledge graph of every symbol, edge, and file. Reads are sub-millisecond and return structural information grep cannot. + +### When to prefer codegraph over native search + +Use codegraph for **structural** questions — what calls what, what would break, where is X defined, what is X's signature. Use native grep/read only for **literal text** queries (string contents, comments, log messages) or after you already have a specific file open. + +| Question | Tool | +|---|---| +| "Where is X defined?" / "Find symbol named X" | `codegraph_search` | +| "What calls function Y?" | `codegraph_callers` | +| "What does Y call?" | `codegraph_callees` | +| "What would break if I changed Z?" | `codegraph_impact` | +| "Show me Y's signature / source / docstring" | `codegraph_node` | +| "Give me focused context for a task/area" | `codegraph_context` | +| "See several related symbols' source at once" | `codegraph_explore` | +| "What files exist under path/" | `codegraph_files` | +| "Is the index healthy?" | `codegraph_status` | + +### Rules of thumb + +- **Answer directly — don't delegate exploration.** For "how does X work" / architecture / trace questions, answer with 2-3 codegraph calls: `codegraph_context` first, then ONE `codegraph_explore` for the source of the symbols it surfaces. Codegraph IS the pre-built index, so spawning a separate file-reading sub-task/agent — or running a grep + read loop — repeats work codegraph already did and costs more for the same answer. +- **Trust codegraph results.** They come from a full AST parse. 
Do NOT re-verify them with grep — that's slower, less accurate, and wastes context. +- **Don't grep first** when looking up a symbol by name. `codegraph_search` is faster and returns kind + location + signature in one call. +- **Don't chain `codegraph_search` + `codegraph_node`** when you just want context — `codegraph_context` is one call. +- **Don't loop `codegraph_node` over many symbols** — one `codegraph_explore` call returns several symbols' source grouped in a single capped call, while each separate node/Read call re-reads the whole context and costs far more. +- **Index lag**: the file watcher debounces ~500ms behind writes; don't re-query immediately after editing a file in the same turn. + +### If `.codegraph/` doesn't exist + +The MCP server returns "not initialized." Ask the user: *"I notice this project doesn't have CodeGraph initialized. Want me to run `codegraph init -i` to build the index?"* + diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..c618c33 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,44 @@ +# GEO Platform Agents + +## Agent Registry + +| Agent | Name | Type | Supported Tasks | File | +|-------|------|------|----------------|------| +| CitationDetectorAgent | citation_detector | CITATION_DETECTOR | citation_detect, citation_batch | backend/app/agent_framework/agents/citation_detector.py | +| ContentGeneratorAgent | content_generator | CONTENT_GENERATOR | content_generate, content_regenerate | backend/app/agent_framework/agents/content_generator_agent.py | +| DeAIAgent | deai_agent | DEAI_AGENT | deai_process | backend/app/agent_framework/agents/deai_agent.py | +| GEOOptimizerAgent | geo_optimizer | GEO_OPTIMIZER | geo_optimize | backend/app/agent_framework/agents/geo_optimizer_agent.py | +| MonitorAgent | monitor | PERFORMANCE_TRACKER | monitor_track, monitor_check_single | backend/app/agent_framework/agents/monitor_agent.py | +| SchemaAdvisorAgent | schema_advisor | SCHEMA_ADVISOR | schema_advise | 
backend/app/agent_framework/agents/schema_advisor.py | +| CompetitorAnalyzerAgent | competitor_analyzer | COMPETITOR_ANALYZER | competitor_analyze, competitor_gap_analysis | backend/app/agent_framework/agents/competitor_analyzer.py | +| TrendAgent | trend_agent | TREND_AGENT | trend_insight, trend_hotspot | backend/app/agent_framework/agents/trend_agent.py | + +## Running Agents + +### Standalone Mode (No Redis Required) +```bash +cd geo/backend +python3 -m app.agent_framework.standalone [agent_name|all] +``` + +### With Redis Queue +Agents auto-register via Registry when Redis is available. Tasks are dispatched via TaskDispatcher. + +## Agent Framework Components +- BaseAgent: Abstract base class (backend/app/agent_framework/base.py) +- Dispatcher: Task distribution (backend/app/agent_framework/dispatcher.py) +- Registry: Agent registration (backend/app/agent_framework/registry.py) +- Protocol: Message types and AgentType enum (backend/app/agent_framework/protocol.py) +- ConfigManager: Agent configuration (backend/app/agent_framework/config_manager.py) + +## New Agent Data Models +- MonitoringRecord + ContentBaseline (backend/app/models/monitoring.py) +- SchemaSuggestion (backend/app/models/schema_suggestion.py) +- CompetitorInsight (backend/app/models/competitor_insight.py) +- TrendInsight (backend/app/models/trend_insight.py) + +## New Agent API Endpoints +- /api/v1/monitoring - MonitorAgent API +- /api/v1/competitor - CompetitorAnalyzer API +- /api/v1/schema - SchemaAdvisor API +- /api/v1/trends - TrendAgent API diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..cb11f3e --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,58 @@ +# GEO Platform - AI Context + +## Project Overview +GEO (Generative Engine Optimization) SaaS platform that helps brands get cited by AI search engines (ChatGPT, Perplexity, DeepSeek, etc.). 
+ +## Tech Stack +- Backend: FastAPI + SQLAlchemy 2.0 (async) + PostgreSQL 15 + Redis 7 +- Frontend: Next.js 14 + React 18 + TypeScript + Tailwind CSS + shadcn/ui +- AI: Multi-LLM provider (DeepSeek, OpenAI, etc.) + RAG knowledge base +- Infrastructure: Docker Compose + Prometheus metrics + +## Architecture +- Backend API: 33 route modules under backend/app/api/ +- Agent Framework: 8 agents (CitationDetector, ContentGenerator, DeAIAgent, GEOOptimizer, MonitorAgent, SchemaAdvisor, CompetitorAnalyzer, TrendAgent) +- Data Models: 35 models under backend/app/models/ +- Services: 80+ service files organized by domain under backend/app/services/ +- Repositories: 12 repository files under backend/app/repositories/ for data access +- Frontend: 25+ pages under frontend/app/(dashboard)/ +- Frontend API Clients: 27 modules under frontend/lib/api/ + +## Key Patterns +- Repository pattern for data access (backend/app/repositories/) +- Agent Framework: BaseAgent abstract class → concrete agents, Dispatcher + Registry + Redis Queue +- Shared BrandScoringDataService for V2 5-dimension scoring +- Content Pipeline: Generate → DeAI → GEO Optimize → HTML Output +- GEO Workflow: Diagnosis → Strategy → Plan → Content Generation → Monitoring + +## API Conventions +- All routes under /api/v1/ prefix +- Authentication via JWT Bearer token (get_current_user dependency) +- Pydantic schemas for request/response validation +- Async SQLAlchemy ORM with AsyncSession + +## Frontend Conventions +- API clients in frontend/lib/api/ using fetchWithAuth from client.ts +- Barrel export in frontend/lib/api/index.ts +- Types co-located with API modules +- Dashboard pages under frontend/app/(dashboard)/dashboard/ + +## Environment +- Backend runs on port 8000 +- Frontend runs on port 3001 (dev) +- PostgreSQL on port 5432 +- Redis on port 6379 +- .env.example has all required variables + +## Common Commands +- Start backend: cd geo/backend && python3 -m uvicorn app.main:app --reload --port 8000 +- Start 
frontend: cd geo/frontend && npm run dev +- Start all agents: cd geo/backend && python3 -m app.agent_framework.standalone all +- Docker: docker-compose up -d + +## Important Notes +- SEOOptimizer in code actually performs GEO optimization, not traditional SEO +- content.py (content generation) vs contents.py (content management) - different modules +- Analytics endpoints all require authentication via _get_org_id or get_current_user +- CORS: dev mode uses allow_origins=["*"], production must restrict +- Brand scoring uses V2 5-dimension system: mention_rate, recommendation_rank, sentiment_score, citation_quality, competitive_position diff --git a/backend/app/agent_framework/agents/__init__.py b/backend/app/agent_framework/agents/__init__.py index 87cc090..3275afd 100644 --- a/backend/app/agent_framework/agents/__init__.py +++ b/backend/app/agent_framework/agents/__init__.py @@ -4,10 +4,12 @@ from .citation_detector import CitationDetectorAgent from .content_generator_agent import ContentGeneratorAgent from .deai_agent import DeAIAgent from .geo_optimizer_agent import GEOOptimizerAgent +from .trend_agent import TrendAgent __all__ = [ "CitationDetectorAgent", "ContentGeneratorAgent", "DeAIAgent", "GEOOptimizerAgent", + "TrendAgent", ] diff --git a/backend/app/agent_framework/agents/citation_detector.py b/backend/app/agent_framework/agents/citation_detector.py index faaf13a..80c18b2 100644 --- a/backend/app/agent_framework/agents/citation_detector.py +++ b/backend/app/agent_framework/agents/citation_detector.py @@ -1,9 +1,9 @@ -"""CitationDetector Agent - 将现有 CitationEngine 封装为 Agent""" - import logging -import re import time -from datetime import datetime, timezone +import uuid +from datetime import datetime, timedelta, timezone + +from sqlalchemy import select from app.agent_framework.base import BaseAgent from app.agent_framework.protocol import ( @@ -16,19 +16,13 @@ from app.agent_framework.protocol import ( from app.database import AsyncSessionLocal from 
app.models.citation_record import CitationRecord from app.models.query import Query -from app.workers.citation_engine import CitationEngine +from app.models.query_task import QueryTask +from app.services.ai_engine.platform_bridge import execute_single_platform logger = logging.getLogger(__name__) class CitationDetectorAgent(BaseAgent): - """ - 引用检测 Agent:将现有 CitationEngine 封装为 BaseAgent 实现。 - - 支持的任务类型: - - citation_detect: 执行完整的引用检测(遍历 query 的所有平台) - - citation_detect_single: 执行单个平台的引用检测 - """ def __init__(self): super().__init__( @@ -36,7 +30,6 @@ class CitationDetectorAgent(BaseAgent): agent_type=AgentType.CITATION_DETECTOR, version="1.0.0", ) - self._engine = CitationEngine() def get_capabilities(self) -> AgentCapability: return AgentCapability( @@ -49,7 +42,6 @@ class CitationDetectorAgent(BaseAgent): ) async def execute(self, task: TaskMessage) -> TaskResult: - """执行引用检测任务""" started_at = datetime.now(timezone.utc) start_time = time.monotonic() @@ -94,17 +86,11 @@ class CitationDetectorAgent(BaseAgent): ) async def _execute_full_detect(self, task: TaskMessage) -> dict: - """ - 执行完整的引用检测(遍历 query 的所有平台)。 - input_data 需包含: query_id (str) - """ query_id = task.input_data.get("query_id") if not query_id: raise ValueError("input_data must contain 'query_id'") async with AsyncSessionLocal() as db: - from sqlalchemy import select - stmt = select(Query).where(Query.id == query_id) result = await db.execute(stmt) query = result.scalar_one_or_none() @@ -112,23 +98,75 @@ class CitationDetectorAgent(BaseAgent): if not query: raise ValueError(f"Query {query_id} not found") - # 上报进度:开始检测 await self.report_progress( task_id=task.task_id, progress=0.1, message=f"Starting citation detection for query '{query.keyword}'", ) - records = await self._engine.execute_query(query, db) + records: list[CitationRecord] = [] + platforms = query.platforms or ["wenxin", "kimi"] + brand_aliases = query.brand_aliases or [] + + for i, platform_name in enumerate(platforms): + progress = 0.1 + 
0.8 * (i / len(platforms)) + await self.report_progress( + task_id=task.task_id, + progress=progress, + message=f"Detecting on platform '{platform_name}' ({i+1}/{len(platforms)})", + ) + + task_obj = await self._get_or_create_task(db, query.id, platform_name) + task_obj.status = "running" + task_obj.started_at = datetime.utcnow() + task_obj.error_message = None + await db.commit() + + try: + detect_result = await execute_single_platform( + keyword=query.keyword, + platform=platform_name, + target_brand=query.target_brand, + brand_aliases=brand_aliases, + ) + + record = CitationRecord.from_citation_result( + query_id=query.id, + platform=platform_name, + result=detect_result, + ) + db.add(record) + records.append(record) + + task_obj.status = "success" + task_obj.completed_at = datetime.utcnow() + await db.commit() + + except Exception as e: + logger.error(f"平台 {platform_name} 查询失败: {e}") + task_obj.status = "failed" + task_obj.error_message = str(e) + task_obj.completed_at = datetime.utcnow() + + record = CitationRecord.from_citation_result( + query_id=query.id, + platform=platform_name, + result={"cited": False, "raw_response": str(e)}, + ) + db.add(record) + records.append(record) + await db.commit() + + query.last_queried_at = datetime.utcnow() + query.next_query_at = self._calculate_next_query_at(query.frequency) + await db.commit() - # 上报进度:检测完成 await self.report_progress( task_id=task.task_id, progress=1.0, message=f"Detection completed: {len(records)} records found", ) - # 构建输出 record_summaries = [] for r in records: record_summaries.append({ @@ -148,10 +186,6 @@ class CitationDetectorAgent(BaseAgent): } async def _execute_single_detect(self, task: TaskMessage) -> dict: - """ - 执行单个平台的引用检测。 - input_data 需包含: keyword, platform, target_brand, brand_aliases(optional) - """ keyword = task.input_data.get("keyword") platform = task.input_data.get("platform") target_brand = task.input_data.get("target_brand") @@ -162,56 +196,118 @@ class 
CitationDetectorAgent(BaseAgent): "input_data must contain 'keyword', 'platform', 'target_brand'" ) - # 上报进度 await self.report_progress( task_id=task.task_id, progress=0.2, message=f"Querying platform '{platform}' with keyword '{keyword}'", ) - result = await self._engine.execute_single_platform( + result = await execute_single_platform( keyword=keyword, platform=platform, target_brand=target_brand, brand_aliases=brand_aliases, ) - # 上报进度:完成 await self.report_progress( task_id=task.task_id, progress=1.0, message=f"Single platform detection completed on '{platform}'", ) - # 清理 raw_response 以避免返回过大的数据 output = {k: v for k, v in result.items() if k != "raw_response"} return output - # ----------------------------------------------------------------------- - # 向后兼容:保留原有 CitationEngine 的同步调用接口 - # ----------------------------------------------------------------------- - async def execute_query_compat(self, query: Query, db) -> list[CitationRecord]: - """ - 向后兼容方法:供现有 scheduler 继续调用。 - 签名与 CitationEngine.execute_query 完全一致。 - """ - return await self._engine.execute_query(query, db) + records: list[CitationRecord] = [] + platforms = query.platforms or ["wenxin", "kimi"] + brand_aliases = query.brand_aliases or [] + + for platform_name in platforms: + task_obj = await self._get_or_create_task(db, query.id, platform_name) + task_obj.status = "running" + task_obj.started_at = datetime.utcnow() + task_obj.error_message = None + await db.commit() + + try: + detect_result = await execute_single_platform( + keyword=query.keyword, + platform=platform_name, + target_brand=query.target_brand, + brand_aliases=brand_aliases, + ) + + record = CitationRecord.from_citation_result( + query_id=query.id, + platform=platform_name, + result=detect_result, + ) + db.add(record) + records.append(record) + + task_obj.status = "success" + task_obj.completed_at = datetime.utcnow() + await db.commit() + + except Exception as e: + logger.error(f"平台 {platform_name} 查询失败: {e}") + task_obj.status = 
"failed" + task_obj.error_message = str(e) + task_obj.completed_at = datetime.utcnow() + + record = CitationRecord.from_citation_result( + query_id=query.id, + platform=platform_name, + result={"cited": False, "raw_response": str(e)}, + ) + db.add(record) + records.append(record) + await db.commit() + + query.last_queried_at = datetime.utcnow() + query.next_query_at = self._calculate_next_query_at(query.frequency) + await db.commit() + + return records async def execute_single_platform_compat( self, keyword: str, platform: str, target_brand: str, brand_aliases: list ) -> dict: - """ - 向后兼容方法:供现有 scheduler 继续调用。 - 签名与 CitationEngine.execute_single_platform 完全一致。 - """ - return await self._engine.execute_single_platform( + return await execute_single_platform( keyword=keyword, platform=platform, target_brand=target_brand, brand_aliases=brand_aliases, ) - async def close(self): - """关闭底层引擎""" - await self._engine.close() + async def _get_or_create_task(self, db, query_id: uuid.UUID, platform: str) -> QueryTask: + stmt = select(QueryTask).where( + QueryTask.query_id == query_id, + QueryTask.platform == platform, + ) + result = await db.execute(stmt) + task_obj = result.scalar_one_or_none() + + if not task_obj: + task_obj = QueryTask( + query_id=query_id, + platform=platform, + status="pending", + ) + db.add(task_obj) + await db.commit() + await db.refresh(task_obj) + + return task_obj + + @staticmethod + def _calculate_next_query_at(frequency: str | None) -> datetime: + now = datetime.utcnow() + freq_map = { + "daily": timedelta(days=1), + "weekly": timedelta(days=7), + "monthly": timedelta(days=30), + } + delta = freq_map.get(frequency or "weekly", timedelta(days=7)) + return now + delta diff --git a/backend/app/agent_framework/agents/competitor_analyzer.py b/backend/app/agent_framework/agents/competitor_analyzer.py new file mode 100644 index 0000000..1cfd023 --- /dev/null +++ b/backend/app/agent_framework/agents/competitor_analyzer.py @@ -0,0 +1,163 @@ +import 
logging +import time +from datetime import datetime, timezone + +from app.agent_framework.base import BaseAgent +from app.agent_framework.protocol import ( + AgentCapability, + AgentType, + TaskMessage, + TaskResult, + TaskStatus, +) +from app.services.llm import LLMFactory, LLMError + +logger = logging.getLogger(__name__) + + +class CompetitorAnalyzerAgent(BaseAgent): + + def __init__(self): + super().__init__( + name="competitor_analyzer", + agent_type=AgentType.COMPETITOR_ANALYZER, + version="1.0.0", + ) + + def get_capabilities(self) -> AgentCapability: + return AgentCapability( + agent_name=self.name, + agent_type=self.agent_type, + version=self.version, + supported_tasks=["competitor_analyze", "competitor_gap_analysis"], + max_concurrency=2, + description="竞品策略分析Agent:对比品牌与竞品的引用数据,识别差距领域,发现机会点,生成策略建议", + ) + + async def execute(self, task: TaskMessage) -> TaskResult: + started_at = datetime.now(timezone.utc) + start_time = time.monotonic() + + try: + if task.task_type == "competitor_gap_analysis": + output = await self._gap_analysis(task) + else: + output = await self._analyze(task) + + elapsed = time.monotonic() - start_time + return TaskResult( + task_id=task.task_id, + agent_name=self.name, + status=TaskStatus.COMPLETED, + output_data=output, + error_message=None, + started_at=started_at, + completed_at=datetime.now(timezone.utc), + metrics={ + "elapsed_seconds": round(elapsed, 2), + "task_type": task.task_type, + }, + ) + + except LLMError as e: + elapsed = time.monotonic() - start_time + logger.error(f"CompetitorAnalyzer LLM error on task {task.task_id}: {e}") + return TaskResult( + task_id=task.task_id, + agent_name=self.name, + status=TaskStatus.FAILED, + output_data=None, + error_message=f"LLM调用失败: {e}", + started_at=started_at, + completed_at=datetime.now(timezone.utc), + metrics={ + "elapsed_seconds": round(elapsed, 2), + "task_type": task.task_type, + }, + ) + + except Exception as e: + elapsed = time.monotonic() - start_time + 
logger.error(f"CompetitorAnalyzer task {task.task_id} failed: {e}") + return TaskResult( + task_id=task.task_id, + agent_name=self.name, + status=TaskStatus.FAILED, + output_data=None, + error_message=str(e), + started_at=started_at, + completed_at=datetime.now(timezone.utc), + metrics={ + "elapsed_seconds": round(elapsed, 2), + "task_type": task.task_type, + }, + ) + + async def _analyze(self, task: TaskMessage) -> dict: + from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService + + input_data = task.input_data + brand_id = input_data.get("brand_id") + analysis_types = input_data.get("analysis_types") + period_days = input_data.get("period_days", 30) + + if not brand_id: + raise ValueError("input_data必须包含'brand_id'字段") + + await self.report_progress( + task_id=task.task_id, + progress=0.1, + message="开始竞品策略分析...", + ) + + service = CompetitorAnalyzerService() + result = await service.analyze_competitor( + brand_id=brand_id, + analysis_types=analysis_types, + period_days=period_days, + progress_callback=lambda p, m: self.report_progress( + task_id=task.task_id, progress=p, message=m, + ), + ) + + await self.report_progress( + task_id=task.task_id, + progress=1.0, + message="竞品策略分析完成", + ) + + return result + + async def _gap_analysis(self, task: TaskMessage) -> dict: + from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService + + input_data = task.input_data + brand_id = input_data.get("brand_id") + period_days = input_data.get("period_days", 30) + + if not brand_id: + raise ValueError("input_data必须包含'brand_id'字段") + + await self.report_progress( + task_id=task.task_id, + progress=0.1, + message="开始竞品差距分析...", + ) + + service = CompetitorAnalyzerService() + result = await service.analyze_competitor( + brand_id=brand_id, + analysis_types=["citation_gap", "platform_coverage", "query_overlap"], + period_days=period_days, + progress_callback=lambda p, m: self.report_progress( + task_id=task.task_id, 
progress=p, message=m, + ), + ) + + await self.report_progress( + task_id=task.task_id, + progress=1.0, + message="竞品差距分析完成", + ) + + return result diff --git a/backend/app/agent_framework/agents/content_generator_agent.py b/backend/app/agent_framework/agents/content_generator_agent.py index 45d907a..7564fcb 100644 --- a/backend/app/agent_framework/agents/content_generator_agent.py +++ b/backend/app/agent_framework/agents/content_generator_agent.py @@ -2,7 +2,6 @@ import json import logging -import re import time from datetime import datetime, timezone @@ -16,6 +15,7 @@ from app.agent_framework.protocol import ( TaskStatus, ) from app.services.llm import LLMFactory, LLMError +from app.utils.json_extractor import extract_json logger = logging.getLogger(__name__) @@ -165,7 +165,7 @@ class ContentGeneratorAgent(BaseAgent): # 4. 解析JSON输出 try: - topics = json.loads(self._extract_json(response.content)) + topics = json.loads(extract_json(response.content)) except json.JSONDecodeError: topics = [{"title": response.content, "reason": "LLM输出解析失败,返回原始内容"}] @@ -277,22 +277,3 @@ class ContentGeneratorAgent(BaseAgent): logger.warning(f"RAG检索失败,跳过知识库上下文: {e}") return "暂无相关知识库内容" - - def _extract_json(self, text: str) -> str: - """从LLM响应中提取JSON(可能被markdown包裹)""" - # 尝试提取```json ... 
```中的内容 - match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL) - if match: - return match.group(1).strip() - # 尝试直接找 [ 或 { 开头的JSON - for i, c in enumerate(text): - if c in '[{': - depth = 0 - for j in range(i, len(text)): - if text[j] in '[{': - depth += 1 - elif text[j] in ']}': - depth -= 1 - if depth == 0: - return text[i : j + 1] - return text diff --git a/backend/app/agent_framework/agents/geo_optimizer_agent.py b/backend/app/agent_framework/agents/geo_optimizer_agent.py index dd05c97..586b972 100644 --- a/backend/app/agent_framework/agents/geo_optimizer_agent.py +++ b/backend/app/agent_framework/agents/geo_optimizer_agent.py @@ -2,7 +2,6 @@ import json import logging -import re import time from datetime import datetime, timezone @@ -16,6 +15,7 @@ from app.agent_framework.protocol import ( TaskStatus, ) from app.services.llm import LLMFactory, LLMError +from app.utils.json_extractor import extract_json logger = logging.getLogger(__name__) @@ -155,7 +155,7 @@ class GEOOptimizerAgent(BaseAgent): # 尝试解析JSON输出 try: - result = json.loads(self._extract_json(response.content)) + result = json.loads(extract_json(response.content)) result["usage"] = response.usage # 上报进度:完成 await self.report_progress( @@ -178,20 +178,3 @@ class GEOOptimizerAgent(BaseAgent): "changes": ["LLM输出非标准格式,已返回原始优化结果"], "usage": response.usage, } - - def _extract_json(self, text: str) -> str: - """从响应中提取JSON""" - match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL) - if match: - return match.group(1).strip() - for i, c in enumerate(text): - if c == '{': - depth = 0 - for j in range(i, len(text)): - if text[j] == '{': - depth += 1 - elif text[j] == '}': - depth -= 1 - if depth == 0: - return text[i : j + 1] - return text diff --git a/backend/app/agent_framework/agents/monitor_agent.py b/backend/app/agent_framework/agents/monitor_agent.py new file mode 100644 index 0000000..c962c43 --- /dev/null +++ b/backend/app/agent_framework/agents/monitor_agent.py @@ -0,0 
+1,231 @@ +import logging +import time +import uuid +from datetime import datetime, timezone + +from sqlalchemy import select + +from app.agent_framework.base import BaseAgent +from app.agent_framework.protocol import ( + AgentCapability, + AgentType, + TaskMessage, + TaskResult, + TaskStatus, +) +from app.database import AsyncSessionLocal +from app.models.monitoring import MonitoringRecord +from app.models.query import Query +from app.models.brand import Brand +from app.services.monitoring.monitor_service import MonitorService + +logger = logging.getLogger(__name__) + + +class MonitorAgent(BaseAgent): + + def __init__(self): + super().__init__( + name="monitor", + agent_type=AgentType.PERFORMANCE_TRACKER, + version="1.0.0", + ) + + def get_capabilities(self) -> AgentCapability: + return AgentCapability( + agent_name=self.name, + agent_type=self.agent_type, + version=self.version, + supported_tasks=["monitor_track", "monitor_check_single"], + max_concurrency=3, + description="效果追踪Agent:监测品牌引用量、情感、排名变化,生成变化报告", + ) + + async def execute(self, task: TaskMessage) -> TaskResult: + started_at = datetime.now(timezone.utc) + start_time = time.monotonic() + + try: + task_type = task.task_type + if task_type == "monitor_track": + output = await self._monitor_track(task) + elif task_type == "monitor_check_single": + output = await self._monitor_check_single(task) + else: + raise ValueError(f"不支持的任务类型: {task_type}") + + elapsed = time.monotonic() - start_time + return TaskResult( + task_id=task.task_id, + agent_name=self.name, + status=TaskStatus.COMPLETED, + output_data=output, + error_message=None, + started_at=started_at, + completed_at=datetime.now(timezone.utc), + metrics={ + "elapsed_seconds": round(elapsed, 2), + "task_type": task_type, + }, + ) + + except Exception as e: + elapsed = time.monotonic() - start_time + logger.error(f"MonitorAgent task {task.task_id} failed: {e}") + return TaskResult( + task_id=task.task_id, + agent_name=self.name, + 
status=TaskStatus.FAILED, + output_data=None, + error_message=str(e), + started_at=started_at, + completed_at=datetime.now(timezone.utc), + metrics={ + "elapsed_seconds": round(elapsed, 2), + "task_type": task.task_type, + }, + ) + + async def _monitor_track(self, task: TaskMessage) -> dict: + input_data = task.input_data + brand_id = input_data.get("brand_id") + if not brand_id: + raise ValueError("input_data必须包含'brand_id'字段") + + brand_id = uuid.UUID(brand_id) + + await self.report_progress( + task_id=task.task_id, + progress=0.1, + message="开始效果追踪...", + ) + + async with AsyncSessionLocal() as db: + stmt = select(Brand).where(Brand.id == brand_id) + result = await db.execute(stmt) + brand = result.scalar_one_or_none() + if not brand: + raise ValueError(f"品牌不存在: {brand_id}") + + queries_stmt = select(Query).where(Query.target_brand == brand.name) + queries_result = await db.execute(queries_stmt) + queries = list(queries_result.scalars().all()) + + if not queries: + await self.report_progress( + task_id=task.task_id, + progress=1.0, + message="品牌下无关联查询,效果追踪完成", + ) + return { + "brand_id": str(brand_id), + "brand_name": brand.name, + "total_queries": 0, + "reports": [], + } + + total_queries = len(queries) + reports = [] + service = MonitorService() + + monitoring_stmt = select(MonitoringRecord).where( + MonitoringRecord.brand_id == brand_id, + MonitoringRecord.status == "active", + ) + monitoring_result = await db.execute(monitoring_stmt) + monitoring_records = list(monitoring_result.scalars().all()) + + for idx, record in enumerate(monitoring_records): + progress = 0.2 + (0.7 * idx / max(len(monitoring_records), 1)) + await self.report_progress( + task_id=task.task_id, + progress=progress, + message=f"正在检测第 {idx + 1}/{len(monitoring_records)} 条监测记录...", + ) + + updated_record = await service.check_and_compare(db, record.id) + if updated_record: + report = await service.generate_change_report(db, updated_record.id) + if report: + reports.append(report) + + await 
self.report_progress( + task_id=task.task_id, + progress=1.0, + message="效果追踪完成", + ) + + return { + "brand_id": str(brand_id), + "brand_name": brand.name, + "total_queries": total_queries, + "checked_records": len(monitoring_records), + "reports": reports, + } + + async def _monitor_check_single(self, task: TaskMessage) -> dict: + input_data = task.input_data + brand_id = input_data.get("brand_id") + keyword = input_data.get("keyword") + platform = input_data.get("platform") + + if not brand_id: + raise ValueError("input_data必须包含'brand_id'字段") + + brand_id = uuid.UUID(brand_id) + + await self.report_progress( + task_id=task.task_id, + progress=0.1, + message="开始单关键词检测...", + ) + + async with AsyncSessionLocal() as db: + service = MonitorService() + + await self.report_progress( + task_id=task.task_id, + progress=0.3, + message="正在创建监测记录...", + ) + + record = await service.create_monitoring_record( + db=db, + brand_id=brand_id, + query_keywords=keyword, + platform=platform, + check_interval_hours=input_data.get("check_interval_hours", 24), + ) + + await self.report_progress( + task_id=task.task_id, + progress=0.6, + message="正在执行检测对比...", + ) + + updated_record = await service.check_and_compare(db, record.id) + + await self.report_progress( + task_id=task.task_id, + progress=0.8, + message="正在生成变化报告...", + ) + + report = None + if updated_record: + report = await service.generate_change_report(db, updated_record.id) + + await self.report_progress( + task_id=task.task_id, + progress=1.0, + message="单关键词检测完成", + ) + + return { + "record_id": str(record.id), + "brand_id": str(brand_id), + "keyword": keyword, + "platform": platform, + "change_type": updated_record.change_type if updated_record else None, + "report": report, + } diff --git a/backend/app/agent_framework/agents/schema_advisor.py b/backend/app/agent_framework/agents/schema_advisor.py new file mode 100644 index 0000000..518b281 --- /dev/null +++ b/backend/app/agent_framework/agents/schema_advisor.py @@ -0,0 
+1,398 @@ +import copy +import json +import logging +import time +from datetime import datetime, timezone + +from app.agent_framework.base import BaseAgent +from app.agent_framework.prompts.schema_advisor import SCHEMA_ADVISOR_TEMPLATE +from app.agent_framework.protocol import ( + AgentCapability, + AgentType, + TaskMessage, + TaskResult, + TaskStatus, +) +from app.services.llm import LLMFactory, LLMError +from app.utils.json_extractor import extract_json + +logger = logging.getLogger(__name__) + +SCHEMA_TEMPLATES = { + "Organization": { + "@context": "https://schema.org", + "@type": "Organization", + "name": "", + "description": "", + "url": "", + "logo": "", + "sameAs": [], + "contactPoint": { + "@type": "ContactPoint", + "contactType": "customer service", + "telephone": "", + }, + }, + "Product": { + "@context": "https://schema.org", + "@type": "Product", + "name": "", + "description": "", + "brand": {"@type": "Brand", "name": ""}, + "offers": { + "@type": "Offer", + "priceCurrency": "CNY", + "availability": "https://schema.org/InStock", + }, + }, + "FAQPage": { + "@context": "https://schema.org", + "@type": "FAQPage", + "mainEntity": [ + { + "@type": "Question", + "name": "", + "acceptedAnswer": { + "@type": "Answer", + "text": "", + }, + } + ], + }, + "Article": { + "@context": "https://schema.org", + "@type": "Article", + "headline": "", + "description": "", + "author": {"@type": "Organization", "name": ""}, + "datePublished": "", + "image": "", + }, + "LocalBusiness": { + "@context": "https://schema.org", + "@type": "LocalBusiness", + "name": "", + "address": { + "@type": "PostalAddress", + "streetAddress": "", + "addressLocality": "", + "addressRegion": "", + "postalCode": "", + "addressCountry": "CN", + }, + "geo": { + "@type": "GeoCoordinates", + "latitude": "", + "longitude": "", + }, + "telephone": "", + "openingHours": "", + }, +} + +DIMENSION_SCHEMA_MAP = { + "schema_marketing": ["Organization", "LocalBusiness"], + "entity_clarity": ["Organization", 
"Product"], + "citation_readiness": ["FAQPage", "Article"], + "brand_visibility": ["Organization", "Product"], + "local_seo": ["LocalBusiness"], +} + +PRIORITY_THRESHOLD = { + "high": 30.0, + "medium": 60.0, +} + +DIFFICULTY_MAP = { + "Organization": "easy", + "Product": "medium", + "FAQPage": "medium", + "Article": "easy", + "LocalBusiness": "hard", +} + + +class SchemaAdvisorAgent(BaseAgent): + + def __init__(self): + super().__init__( + name="schema_advisor", + agent_type=AgentType.SCHEMA_ADVISOR, + version="1.0.0", + ) + + def get_capabilities(self) -> AgentCapability: + return AgentCapability( + agent_name=self.name, + agent_type=self.agent_type, + version=self.version, + supported_tasks=["schema_advise"], + max_concurrency=2, + description="Schema优化建议Agent:识别Schema缺失维度,生成JSON-LD结构化数据建议", + ) + + async def execute(self, task: TaskMessage) -> TaskResult: + started_at = datetime.now(timezone.utc) + start_time = time.monotonic() + + try: + output = await self._advise(task) + + elapsed = time.monotonic() - start_time + return TaskResult( + task_id=task.task_id, + agent_name=self.name, + status=TaskStatus.COMPLETED, + output_data=output, + error_message=None, + started_at=started_at, + completed_at=datetime.now(timezone.utc), + metrics={ + "elapsed_seconds": round(elapsed, 2), + "task_type": task.task_type, + }, + ) + + except LLMError as e: + elapsed = time.monotonic() - start_time + logger.error(f"SchemaAdvisor LLM error on task {task.task_id}: {e}") + return TaskResult( + task_id=task.task_id, + agent_name=self.name, + status=TaskStatus.FAILED, + output_data=None, + error_message=f"LLM调用失败: {e}", + started_at=started_at, + completed_at=datetime.now(timezone.utc), + metrics={ + "elapsed_seconds": round(elapsed, 2), + "task_type": task.task_type, + }, + ) + + except Exception as e: + elapsed = time.monotonic() - start_time + logger.error(f"SchemaAdvisor task {task.task_id} failed: {e}") + return TaskResult( + task_id=task.task_id, + agent_name=self.name, + 
status=TaskStatus.FAILED, + output_data=None, + error_message=str(e), + started_at=started_at, + completed_at=datetime.now(timezone.utc), + metrics={ + "elapsed_seconds": round(elapsed, 2), + "task_type": task.task_type, + }, + ) + + async def _advise(self, task: TaskMessage) -> dict: + input_data = task.input_data + brand_id = input_data.get("brand_id") + diagnosis_data = input_data.get("diagnosis_data", {}) + brand_info = input_data.get("brand_info", {}) + focus_dimensions = input_data.get("focus_dimensions") + + if not brand_id: + raise ValueError("input_data必须包含'brand_id'字段") + + await self.report_progress( + task_id=task.task_id, + progress=0.1, + message="开始Schema建议分析...", + ) + + missing_dimensions = self._identify_missing_dimensions(diagnosis_data, focus_dimensions) + + await self.report_progress( + task_id=task.task_id, + progress=0.3, + message=f"识别到{len(missing_dimensions)}个Schema缺失维度...", + ) + + matched = self._match_templates(missing_dimensions) + + await self.report_progress( + task_id=task.task_id, + progress=0.5, + message="匹配预定义模板完成,开始LLM填充...", + ) + + filled = await self._fill_with_llm(matched, brand_info) + + await self.report_progress( + task_id=task.task_id, + progress=0.8, + message="LLM填充完成,验证JSON-LD格式...", + ) + + validated = self._validate_and_sort(filled) + + await self.report_progress( + task_id=task.task_id, + progress=1.0, + message="Schema建议生成完成", + ) + + return { + "brand_id": brand_id, + "suggestions": validated, + "total": len(validated), + } + + def _identify_missing_dimensions( + self, + diagnosis_data: dict, + focus_dimensions: list[str] | None = None, + ) -> list[dict]: + dimensions = [] + dimension_scores = diagnosis_data.get("dimensions", {}) + for dim_name, dim_info in dimension_scores.items(): + if dim_name not in DIMENSION_SCHEMA_MAP: + continue + if focus_dimensions and dim_name not in focus_dimensions: + continue + score = dim_info.get("score", 0) if isinstance(dim_info, dict) else dim_info + max_score = 
dim_info.get("max_score", 100) if isinstance(dim_info, dict) else 100 + percentage = (score / max_score * 100) if max_score > 0 else 0 + if percentage < 80: + dimensions.append({ + "dimension": dim_name, + "current_score": round(score, 2), + "max_score": max_score, + "percentage": round(percentage, 2), + }) + if not dimensions and diagnosis_data: + overall = diagnosis_data.get("overall_score", 0) + if overall < 80: + for dim_name in DIMENSION_SCHEMA_MAP: + if focus_dimensions and dim_name not in focus_dimensions: + continue + dimensions.append({ + "dimension": dim_name, + "current_score": 0, + "max_score": 100, + "percentage": 0, + }) + return dimensions + + def _match_templates(self, missing_dimensions: list[dict]) -> list[dict]: + matched = [] + seen_types = set() + for dim in missing_dimensions: + schema_types = DIMENSION_SCHEMA_MAP.get(dim["dimension"], []) + for schema_type in schema_types: + if schema_type in seen_types: + continue + seen_types.add(schema_type) + template = SCHEMA_TEMPLATES.get(schema_type) + if template: + percentage = dim["percentage"] + if percentage < PRIORITY_THRESHOLD["high"]: + priority = "high" + elif percentage < PRIORITY_THRESHOLD["medium"]: + priority = "medium" + else: + priority = "low" + matched.append({ + "schema_type": schema_type, + "priority": priority, + "diagnosis_dimensions": { + "dimension": dim["dimension"], + "current_score": dim["current_score"], + "max_score": dim["max_score"], + "percentage": dim["percentage"], + }, + "json_ld_template": copy.deepcopy(template), + "implementation_difficulty": DIFFICULTY_MAP.get(schema_type, "medium"), + }) + return matched + + async def _fill_with_llm(self, matched: list[dict], brand_info: dict) -> list[dict]: + provider = LLMFactory.get_default() + results = [] + for item in matched: + schema_type = item["schema_type"] + try: + variables = { + "brand_name": brand_info.get("name", ""), + "brand_website": brand_info.get("website", ""), + "brand_industry": brand_info.get("industry", 
""), + "schema_type": schema_type, + "diagnosis_data": json.dumps(item.get("diagnosis_dimensions", {}), ensure_ascii=False), + "existing_schemas": "无", + } + messages = SCHEMA_ADVISOR_TEMPLATE.render(variables) + response = await provider.chat( + messages, + temperature=0.3, + max_tokens=2048, + ) + filled = json.loads(extract_json(response.content)) + item["json_ld_filled"] = filled + item["estimated_impact"] = self._generate_impact_description( + schema_type, item.get("diagnosis_dimensions", {}).get("dimension", "") + ) + except (json.JSONDecodeError, LLMError, ValueError) as e: + logger.warning(f"LLM填充Schema {schema_type} 失败: {e}") + item["json_ld_filled"] = None + item["estimated_impact"] = self._generate_impact_description( + schema_type, item.get("diagnosis_dimensions", {}).get("dimension", "") + ) + results.append(item) + return results + + def _validate_json_ld(self, json_ld: dict) -> dict: + errors = [] + warnings = [] + + if not json_ld: + return {"is_valid": False, "errors": ["JSON-LD为空"], "warnings": []} + + if "@context" not in json_ld: + errors.append("缺少@context字段") + + if "@type" not in json_ld: + errors.append("缺少@type字段") + + if "@context" in json_ld and json_ld["@context"] != "https://schema.org": + warnings.append(f"@context值非标准: {json_ld.get('@context')}") + + if "@type" in json_ld and json_ld["@type"] not in SCHEMA_TEMPLATES: + warnings.append(f"@type非推荐类型: {json_ld.get('@type')}") + + try: + json.dumps(json_ld) + except (json.JSONDecodeError, TypeError) as e: + errors.append(f"JSON序列化失败: {e}") + + return { + "is_valid": len(errors) == 0, + "errors": errors, + "warnings": warnings, + } + + def _validate_and_sort(self, items: list[dict]) -> list[dict]: + validated = [] + for item in items: + json_ld_filled = item.get("json_ld_filled") + if json_ld_filled: + validation = self._validate_json_ld(json_ld_filled) + item["validation_errors"] = None if validation["is_valid"] else {"errors": validation["errors"], "warnings": validation["warnings"]} + 
else: + item["validation_errors"] = {"errors": ["JSON-LD填充失败"], "warnings": []} + validated.append(item) + priority_order = {"high": 0, "medium": 1, "low": 2} + validated.sort(key=lambda x: priority_order.get(x.get("priority", "medium"), 1)) + return validated + + def _generate_impact_description(self, schema_type: str, dimension: str) -> str: + impacts = { + "Organization": "增强品牌实体识别,提升AI搜索引擎对品牌的理解和引用概率", + "Product": "提升产品在搜索结果中的富摘要展示,增加点击率和引用率", + "FAQPage": "增加FAQ富摘要展示机会,提升在AI回答中的直接引用概率", + "Article": "优化文章内容的结构化表达,提升AI搜索引擎的内容理解和引用", + "LocalBusiness": "增强本地搜索可见性,提升地理位置相关查询的引用率", + } + return impacts.get(schema_type, f"提升{dimension}维度的得分和AI引用率") diff --git a/backend/app/agent_framework/agents/trend_agent.py b/backend/app/agent_framework/agents/trend_agent.py new file mode 100644 index 0000000..34b510a --- /dev/null +++ b/backend/app/agent_framework/agents/trend_agent.py @@ -0,0 +1,166 @@ +import logging +import time +from datetime import datetime, timezone + +from app.agent_framework.base import BaseAgent +from app.agent_framework.protocol import ( + AgentCapability, + AgentType, + TaskMessage, + TaskResult, + TaskStatus, +) + +logger = logging.getLogger(__name__) + + +class TrendAgent(BaseAgent): + + def __init__(self): + super().__init__( + name="trend_agent", + agent_type=AgentType.TREND_AGENT, + version="1.0.0", + ) + + def get_capabilities(self) -> AgentCapability: + return AgentCapability( + agent_name=self.name, + agent_type=self.agent_type, + version=self.version, + supported_tasks=["trend_insight", "trend_hotspot"], + max_concurrency=2, + description="趋势洞察Agent:分析品牌引用趋势、识别热点话题、推断变化原因并生成建议", + ) + + async def execute(self, task: TaskMessage) -> TaskResult: + started_at = datetime.now(timezone.utc) + start_time = time.monotonic() + + try: + if task.task_type == "trend_insight": + output = await self._trend_insight(task) + elif task.task_type == "trend_hotspot": + output = await self._trend_hotspot(task) + else: + raise ValueError(f"不支持的任务类型: 
{task.task_type}") + + elapsed = time.monotonic() - start_time + return TaskResult( + task_id=task.task_id, + agent_name=self.name, + status=TaskStatus.COMPLETED, + output_data=output, + error_message=None, + started_at=started_at, + completed_at=datetime.now(timezone.utc), + metrics={ + "elapsed_seconds": round(elapsed, 2), + "task_type": task.task_type, + }, + ) + + except Exception as e: + elapsed = time.monotonic() - start_time + logger.error(f"TrendAgent task {task.task_id} failed: {e}") + return TaskResult( + task_id=task.task_id, + agent_name=self.name, + status=TaskStatus.FAILED, + output_data=None, + error_message=str(e), + started_at=started_at, + completed_at=datetime.now(timezone.utc), + metrics={ + "elapsed_seconds": round(elapsed, 2), + "task_type": task.task_type, + }, + ) + + async def _trend_insight(self, task: TaskMessage) -> dict: + input_data = task.input_data + brand_id = input_data.get("brand_id") + days = input_data.get("days", 30) + platforms = input_data.get("platforms") + keywords = input_data.get("keywords") + + if not brand_id: + raise ValueError("input_data必须包含'brand_id'字段") + + await self.report_progress( + task_id=task.task_id, + progress=0.1, + message="开始趋势洞察分析...", + ) + + from app.database import AsyncSessionLocal + from app.services.trend.trend_analyzer_service import TrendAnalyzerService + + async with AsyncSessionLocal() as db: + service = TrendAnalyzerService(db) + + await self.report_progress( + task_id=task.task_id, + progress=0.3, + message="获取历史引用数据...", + ) + + await self.report_progress( + task_id=task.task_id, + progress=0.5, + message="执行时间序列分析...", + ) + + result = await service.analyze_trends( + brand_id=brand_id, + days=days, + platforms=platforms, + keywords=keywords, + ) + + await self.report_progress( + task_id=task.task_id, + progress=1.0, + message="趋势洞察分析完成", + ) + + return result + + async def _trend_hotspot(self, task: TaskMessage) -> dict: + input_data = task.input_data + brand_id = 
input_data.get("brand_id") + days = input_data.get("days", 30) + + if not brand_id: + raise ValueError("input_data必须包含'brand_id'字段") + + await self.report_progress( + task_id=task.task_id, + progress=0.1, + message="开始热点话题分析...", + ) + + from app.database import AsyncSessionLocal + from app.services.trend.trend_analyzer_service import TrendAnalyzerService + + async with AsyncSessionLocal() as db: + service = TrendAnalyzerService(db) + + await self.report_progress( + task_id=task.task_id, + progress=0.5, + message="检测引用量突增的关键词/话题...", + ) + + result = await service.get_hotspots( + brand_id=brand_id, + days=days, + ) + + await self.report_progress( + task_id=task.task_id, + progress=1.0, + message="热点话题分析完成", + ) + + return result diff --git a/backend/app/agent_framework/base.py b/backend/app/agent_framework/base.py index 93a07cf..8ca954d 100644 --- a/backend/app/agent_framework/base.py +++ b/backend/app/agent_framework/base.py @@ -21,6 +21,33 @@ from app.config import settings logger = logging.getLogger(__name__) +# Module-level lazy singleton for TaskDispatcher — avoids creating a new +# dispatcher on every method call while still deferring the import to +# prevent circular-dependency issues at module-load time. +_dispatcher_instance = None + + +def _get_dispatcher(): + """Return a cached TaskDispatcher instance (lazy singleton).""" + global _dispatcher_instance + if _dispatcher_instance is None: + from app.agent_framework.dispatcher import TaskDispatcher + _dispatcher_instance = TaskDispatcher(settings.REDIS_URL) + return _dispatcher_instance + + +# Module-level lazy singleton for AgentRegistry — same rationale. 
+_registry_instance = None + + +def _get_registry(): + """Return a cached AgentRegistry instance (lazy singleton).""" + global _registry_instance + if _registry_instance is None: + from app.agent_framework.registry import AgentRegistry + _registry_instance = AgentRegistry() + return _registry_instance + class BaseAgent(ABC): """所有 Agent 的基类,定义标准生命周期""" @@ -40,6 +67,10 @@ class BaseAgent(ABC): def status(self) -> AgentStatus: return self._status + @property + def is_distributed(self) -> bool: + return self._redis is not None + @abstractmethod async def execute(self, task: TaskMessage) -> TaskResult: """执行任务的核心逻辑,子类必须实现""" @@ -51,48 +82,47 @@ class BaseAgent(ABC): ... async def start(self): - """启动 Agent:注册到 Registry,开始监听任务队列""" logger.info(f"Starting agent '{self.name}' (type={self.agent_type}, version={self.version})") - # 初始化 Redis 连接 - self._redis = aioredis.from_url( - settings.REDIS_URL, - decode_responses=True, - ) + try: + self._redis = aioredis.from_url( + settings.REDIS_URL, + decode_responses=True, + ) + await self._redis.ping() - # 注册到 Registry - from app.agent_framework.registry import AgentRegistry + registry = _get_registry() + capability = self.get_capabilities() + await registry.register(capability, endpoint=f"agent:{self.name}") - registry = AgentRegistry() - capability = self.get_capabilities() - await registry.register(capability, endpoint=f"agent:{self.name}") + self._status = AgentStatus.ONLINE - # 更新状态 - self._status = AgentStatus.ONLINE + capability = self.get_capabilities() + max_concurrency = getattr(capability, 'max_concurrency', 1) or 1 + self._semaphore = asyncio.Semaphore(max_concurrency) + logger.info( + f"Agent '{self.name}' concurrency limit set to {max_concurrency}" + ) - # 根据 capabilities 的 max_concurrency 初始化 Semaphore - capability = self.get_capabilities() - max_concurrency = getattr(capability, 'max_concurrency', 1) or 1 - self._semaphore = asyncio.Semaphore(max_concurrency) - logger.info( - f"Agent '{self.name}' concurrency 
limit set to {max_concurrency}" - ) + self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + self._listen_task = asyncio.create_task(self._listen_for_tasks()) - # 启动心跳 - self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + logger.info(f"Agent '{self.name}' started in distributed mode") + except Exception as e: + self._redis = None + self._status = AgentStatus.ONLINE - # 开始监听任务队列 - self._listen_task = asyncio.create_task(self._listen_for_tasks()) + capability = self.get_capabilities() + max_concurrency = getattr(capability, 'max_concurrency', 1) or 1 + self._semaphore = asyncio.Semaphore(max_concurrency) - logger.info(f"Agent '{self.name}' started successfully") + logger.warning(f"Agent '{self.name}' started in local mode (Redis unavailable: {e})") async def stop(self): - """停止 Agent:注销,停止监听""" logger.info(f"Stopping agent '{self.name}'") self._status = AgentStatus.OFFLINE - # 取消监听任务 if self._listen_task and not self._listen_task.done(): self._listen_task.cancel() try: @@ -100,7 +130,6 @@ class BaseAgent(ABC): except asyncio.CancelledError: pass - # 取消心跳任务 if self._heartbeat_task and not self._heartbeat_task.done(): self._heartbeat_task.cancel() try: @@ -108,14 +137,10 @@ class BaseAgent(ABC): except asyncio.CancelledError: pass - # 注销 - from app.agent_framework.registry import AgentRegistry + if self._redis is not None: + registry = _get_registry() + await registry.unregister(self.name) - registry = AgentRegistry() - await registry.unregister(self.name) - - # 关闭 Redis 连接 - if self._redis: await self._redis.close() self._redis = None @@ -123,13 +148,10 @@ class BaseAgent(ABC): async def heartbeat(self): """定期心跳上报""" - from app.agent_framework.registry import AgentRegistry - - registry = AgentRegistry() + registry = _get_registry() await registry.update_heartbeat(self.name) async def report_progress(self, task_id: str, progress: float, message: str): - """上报任务进度""" progress_obj = TaskProgress( task_id=task_id, agent_name=self.name, @@ 
-138,7 +160,6 @@ class BaseAgent(ABC): updated_at=datetime.now(timezone.utc), ) - # 通过 Redis Pub/Sub 发布进度 if self._redis: try: await self._redis.publish( @@ -148,11 +169,11 @@ class BaseAgent(ABC): except Exception as e: logger.warning(f"Failed to publish progress for task {task_id}: {e}") - # 同时更新数据库 - from app.agent_framework.dispatcher import TaskDispatcher - - dispatcher = TaskDispatcher(settings.REDIS_URL) - await dispatcher.handle_progress(progress_obj) + try: + dispatcher = _get_dispatcher() + await dispatcher.handle_progress(progress_obj) + except Exception as e: + logger.warning(f"Failed to report progress to dispatcher for task {task_id}: {e}") async def _heartbeat_loop(self): """心跳循环""" @@ -198,7 +219,6 @@ class BaseAgent(ABC): await self._execute_task(task) async def _execute_task(self, task: TaskMessage): - """执行单个任务""" self._running_tasks.add(task.task_id) self._status = AgentStatus.BUSY @@ -209,15 +229,12 @@ class BaseAgent(ABC): result.started_at = started_at result.completed_at = datetime.now(timezone.utc) - # 处理结果 - from app.agent_framework.dispatcher import TaskDispatcher - - dispatcher = TaskDispatcher(settings.REDIS_URL) - await dispatcher.handle_result(result) + if self._redis is not None: + dispatcher = _get_dispatcher() + await dispatcher.handle_result(result) except Exception as e: logger.error(f"Agent '{self.name}' task {task.task_id} failed: {e}") - # 构建失败结果 error_result = TaskResult( task_id=task.task_id, agent_name=self.name, @@ -228,10 +245,9 @@ class BaseAgent(ABC): completed_at=datetime.now(timezone.utc), metrics=None, ) - from app.agent_framework.dispatcher import TaskDispatcher - - dispatcher = TaskDispatcher(settings.REDIS_URL) - await dispatcher.handle_result(error_result) + if self._redis is not None: + dispatcher = _get_dispatcher() + await dispatcher.handle_result(error_result) finally: self._running_tasks.discard(task.task_id) diff --git a/backend/app/agent_framework/prompts/__init__.py 
b/backend/app/agent_framework/prompts/__init__.py index 1c87d32..d6cc579 100644 --- a/backend/app/agent_framework/prompts/__init__.py +++ b/backend/app/agent_framework/prompts/__init__.py @@ -15,6 +15,7 @@ from .content_generator import CONTENT_GENERATOR_TEMPLATE from .deai_agent import DEAI_TEMPLATE from .geo_optimizer import GEO_OPTIMIZER_TEMPLATE from .rule_checker import RULE_CHECKER_TEMPLATE +from .schema_advisor import SCHEMA_ADVISOR_TEMPLATE from .topic_selector import TOPIC_SELECTOR_TEMPLATE __all__ = [ @@ -25,4 +26,5 @@ __all__ = [ "DEAI_TEMPLATE", "GEO_OPTIMIZER_TEMPLATE", "RULE_CHECKER_TEMPLATE", + "SCHEMA_ADVISOR_TEMPLATE", ] diff --git a/backend/app/agent_framework/prompts/schema_advisor.py b/backend/app/agent_framework/prompts/schema_advisor.py new file mode 100644 index 0000000..a8e1542 --- /dev/null +++ b/backend/app/agent_framework/prompts/schema_advisor.py @@ -0,0 +1,70 @@ +from .base_template import PromptSection, PromptTemplate + +SCHEMA_ADVISOR_TEMPLATE = PromptTemplate( + PromptSection( + identity="""你是一位精通Schema.org结构化数据和JSON-LD的技术专家。 +你深刻理解搜索引擎和AI模型(如ChatGPT、Perplexity、Kimi)如何解析和利用结构化数据, +知道如何通过精准的Schema标记提升品牌在AI搜索结果中的可见性和引用率。 +你生成的JSON-LD严格遵循Schema.org规范,确保可被搜索引擎正确解析。""", + + context="""## 品牌信息 +- 品牌名称:${brand_name} +- 网站:${brand_website} +- 行业:${brand_industry} + +## 诊断数据 +${diagnosis_data} + +## 已有Schema标记 +${existing_schemas} + +## 目标Schema类型 +${schema_type}""", + + instructions="""请根据以上品牌信息和诊断数据,为品牌生成完整的JSON-LD结构化数据。 + +生成要求: + +1. 内容填充: + - 所有字段必须填充真实、具体的内容,不得留空 + - 品牌名称、网站等基本信息必须与提供的数据一致 + - 描述性文本应当专业、准确,体现品牌特色 + +2. Schema类型特定要求: + - Organization: 包含name, description, url, logo, sameAs(社交媒体链接), contactPoint + - Product: 包含name, description, brand, offers, aggregateRating(如有) + - FAQPage: 生成3-5个与品牌行业相关的高质量FAQ,问题和答案需自然且信息丰富 + - Article: 包含headline, author, datePublished, description, image + - LocalBusiness: 包含name, address(完整地址结构), geo, telephone, openingHours + +3. 语言要求: + - 所有自然语言内容使用与品牌名称相同的语言 + - 技术字段(如@type, @context)保持英文 + +4. 
结构完整性: + - 必须包含@context和@type + - 嵌套对象必须完整,不得省略必要子属性""", + + constraints="""## 约束条件 +- 严格遵循Schema.org规范,不得使用非标准属性 +- @context必须为"https://schema.org" +- @type必须是Schema.org定义的有效类型 +- 不得编造不存在的品牌信息(如无实际地址,LocalBusiness的address可使用占位结构) +- FAQ的问题必须是用户真实可能搜索的问题 +- 所有URL字段如无实际值,留空字符串 +- 不得在JSON-LD中包含HTML标签""", + + output_format="""## 输出格式 +请以JSON格式输出填充后的JSON-LD: + +```json +{ + "@context": "https://schema.org", + "@type": "...", + "...": "..." +} +``` + +仅输出JSON-LD对象,不要包含任何解释文字。""", + ) +) diff --git a/backend/app/agent_framework/protocol.py b/backend/app/agent_framework/protocol.py index 624dd0c..82c5e09 100644 --- a/backend/app/agent_framework/protocol.py +++ b/backend/app/agent_framework/protocol.py @@ -6,7 +6,6 @@ from enum import Enum class AgentType(str, Enum): - """Agent 类型枚举""" CITATION_DETECTOR = "citation_detector" CONTENT_GENERATOR = "content_generator" DEAI_AGENT = "deai_agent" @@ -14,6 +13,9 @@ class AgentType(str, Enum): RULE_CHECKER = "rule_checker" COMPETITOR_ANALYZER = "competitor_analyzer" PERFORMANCE_TRACKER = "performance_tracker" + SCHEMA_ADVISOR = "schema_advisor" + MONITOR_AGENT = "monitor_agent" + TREND_AGENT = "trend_agent" class TaskStatus(str, Enum): diff --git a/backend/app/agent_framework/standalone.py b/backend/app/agent_framework/standalone.py new file mode 100644 index 0000000..330b2e3 --- /dev/null +++ b/backend/app/agent_framework/standalone.py @@ -0,0 +1,62 @@ +import asyncio +import argparse +import logging +import sys + +from app.agent_framework.agents.citation_detector import CitationDetectorAgent +from app.agent_framework.agents.content_generator_agent import ContentGeneratorAgent +from app.agent_framework.agents.deai_agent import DeAIAgent +from app.agent_framework.agents.geo_optimizer_agent import GEOOptimizerAgent +from app.agent_framework.agents.monitor_agent import MonitorAgent +from app.agent_framework.agents.schema_advisor import SchemaAdvisorAgent +from app.agent_framework.agents.competitor_analyzer import 
CompetitorAnalyzerAgent +from app.agent_framework.agents.trend_agent import TrendAgent + +AGENTS = { + "citation_detector": CitationDetectorAgent, + "content_generator": ContentGeneratorAgent, + "deai_agent": DeAIAgent, + "geo_optimizer": GEOOptimizerAgent, + "monitor": MonitorAgent, + "schema_advisor": SchemaAdvisorAgent, + "competitor_analyzer": CompetitorAnalyzerAgent, + "trend_agent": TrendAgent, + "all": None, +} + + +async def run_agent(name: str): + if name == "all": + agents = [cls() for cls in AGENTS.values() if cls is not None] + else: + cls = AGENTS.get(name) + if cls is None: + print(f"Unknown agent: {name}") + sys.exit(1) + agents = [cls()] + + for agent in agents: + await agent.start() + + print(f"Agent(s) running: {[a.name for a in agents]}") + try: + await asyncio.Future() + except asyncio.CancelledError: + pass + finally: + for agent in agents: + await agent.stop() + + +def main(): + parser = argparse.ArgumentParser(description="Run GEO Agent(s)") + parser.add_argument("agent", choices=list(AGENTS.keys()), help="Agent name or 'all'") + parser.add_argument("--log-level", default="INFO", help="Logging level") + args = parser.parse_args() + + logging.basicConfig(level=getattr(logging, args.log_level.upper())) + asyncio.run(run_agent(args.agent)) + + +if __name__ == "__main__": + main() diff --git a/backend/app/api/ai_engines.py b/backend/app/api/ai_engines.py index 4a8f356..a6b8228 100644 --- a/backend/app/api/ai_engines.py +++ b/backend/app/api/ai_engines.py @@ -46,6 +46,7 @@ class QueryResultResponse(BaseModel): brand_context: str | None competitor_contexts: list[str] response_time_ms: int + timestamp: str model_config = {"from_attributes": True} @@ -61,6 +62,7 @@ class CitationRateResponse(BaseModel): class BatchQueryResponse(BaseModel): results: list[QueryResultResponse] citation_rate: CitationRateResponse + avg_response_time_ms: float = 0.0 def _result_to_response(r: AIQueryResult) -> QueryResultResponse: @@ -73,6 +75,7 @@ def 
_result_to_response(r: AIQueryResult) -> QueryResultResponse: brand_context=r.brand_context, competitor_contexts=r.competitor_contexts, response_time_ms=r.response_time_ms, + timestamp=r.timestamp.isoformat(), ) @@ -97,9 +100,16 @@ async def _execute_batch( engine_types, query, brand_name, competitor_names ) citation_rate = service.calculate_citation_rate(results) + response_results = [_result_to_response(r) for r in results] + avg_response_time = ( + sum(r.response_time_ms for r in results) / len(results) + if results + else 0.0 + ) return BatchQueryResponse( - results=[_result_to_response(r) for r in results], + results=response_results, citation_rate=CitationRateResponse(**citation_rate), + avg_response_time_ms=avg_response_time, ) diff --git a/backend/app/api/alerts.py b/backend/app/api/alerts.py index c5fa5a1..f1cf5a7 100644 --- a/backend/app/api/alerts.py +++ b/backend/app/api/alerts.py @@ -24,7 +24,7 @@ from app.schemas.alert import ( AlertSettingListResponse, AlertSettingBulkUpdate, ) -from app.services.alert_engine import AlertEngine +from app.services.alert.alert_engine import AlertEngine logger = logging.getLogger(__name__) diff --git a/backend/app/api/api_keys.py b/backend/app/api/api_keys.py index 667ca82..78b56ac 100644 --- a/backend/app/api/api_keys.py +++ b/backend/app/api/api_keys.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from app.api.deps import get_current_user from app.models.user import User from app.services.api_key_manager import APIKeyManager, KeySource, KeyStatus -from app.services.smart_router import ENGINE_COST_PROFILES +from app.services.llm.smart_router import ENGINE_COST_PROFILES logger = logging.getLogger(__name__) diff --git a/backend/app/api/attribution.py b/backend/app/api/attribution.py new file mode 100644 index 0000000..31a1d6d --- /dev/null +++ b/backend/app/api/attribution.py @@ -0,0 +1,300 @@ +import logging +import uuid +from datetime import datetime + +from fastapi import APIRouter, Depends, HTTPException, status 
+from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_current_user +from app.database import get_db +from app.models.attribution_record import AttributionRecord +from app.models.brand import Brand +from app.models.diagnosis_record import DiagnosisRecord +from app.models.user import User +from app.services.attribution.attribution_engine import AttributionEngine +from app.services.attribution.roi_calculator import ROICalculator + +logger = logging.getLogger(__name__) + +router = APIRouter() + +PLAN_COSTS = { + "free": 0.0, + "starter": 99.0, + "pro": 299.0, + "enterprise": 999.0, +} + + +def _to_uuid(value: str | uuid.UUID) -> uuid.UUID: + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(str(value)) + + +class StartTrackingRequest(BaseModel): + brand_id: str + content_id: str | None = None + + +class AttributionResponse(BaseModel): + id: str + brand_id: str + content_id: str | None + baseline_score: float + current_score: float | None + score_delta: float | None + status: str + roi_percentage: float | None + created_at: datetime + + model_config = {"from_attributes": True} + + +class ROIReport(BaseModel): + brand_id: str + brand_name: str + subscription_cost: float + current_plan: str + total_score_delta: float + value_generated: float + roi_percentage: float + break_even_delta: float + tracking_records: list[AttributionResponse] + ab_comparison: dict | None + + +class ABComparisonResponse(BaseModel): + brand_id: str + brand_name: str + overall_before: float + overall_after: float + overall_delta: float + dimensions: list[dict] + + +def _record_to_response(record: AttributionRecord) -> AttributionResponse: + return AttributionResponse( + id=str(record.id), + brand_id=str(record.brand_id), + content_id=str(record.content_id) if record.content_id else None, + baseline_score=record.baseline_score, + current_score=record.current_score, + 
score_delta=record.score_delta, + status=record.status, + roi_percentage=record.roi_percentage, + created_at=record.created_at, + ) + + +async def _get_brand_or_404( + brand_id: uuid.UUID, + current_user: User, + db: AsyncSession, +) -> Brand: + user_uuid = _to_uuid(current_user.id) + stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == user_uuid) + result = await db.execute(stmt) + brand = result.scalar_one_or_none() + if not brand: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="品牌不存在", + ) + return brand + + +@router.post("/start", response_model=AttributionResponse) +async def start_tracking( + body: StartTrackingRequest, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + brand_id = _to_uuid(body.brand_id) + brand = await _get_brand_or_404(brand_id, current_user, db) + + content_id = _to_uuid(body.content_id) if body.content_id else None + + engine = AttributionEngine() + record = await engine.start_tracking(db, brand.id, content_id, current_user.id) + return _record_to_response(record) + + +@router.get("/brand/{brand_id}") +async def get_brand_attribution( + brand_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + brand = await _get_brand_or_404(brand_id, current_user, db) + + engine = AttributionEngine() + summary = await engine.get_brand_attribution_summary(db, brand.id) + summary["records"] = [_record_to_response(r) for r in summary["records"]] + return summary + + +@router.get("/{record_id}", response_model=AttributionResponse) +async def get_attribution_record( + record_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + stmt = select(AttributionRecord).where(AttributionRecord.id == record_id) + result = await db.execute(stmt) + record = result.scalar_one_or_none() + if not record: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + 
detail="归因记录不存在", + ) + return _record_to_response(record) + + +@router.post("/{record_id}/check", response_model=AttributionResponse) +async def check_attribution( + record_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + engine = AttributionEngine() + try: + record = await engine.check_attribution(db, record_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="归因记录不存在", + ) + return _record_to_response(record) + + +@router.get("/roi/{brand_id}", response_model=ROIReport) +async def get_roi_report( + brand_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + brand = await _get_brand_or_404(brand_id, current_user, db) + + engine = AttributionEngine() + summary = await engine.get_brand_attribution_summary(db, brand.id) + + user_plan = getattr(current_user, "plan", "free") or "free" + subscription_cost = PLAN_COSTS.get(user_plan, 0.0) + + calculator = ROICalculator() + roi_data = calculator.calculate_roi( + subscription_cost=subscription_cost, + score_delta=summary["total_score_delta"], + attribution_records=summary["records"], + ) + + ab_comparison = None + baseline_record = ( + await db.execute( + select(DiagnosisRecord) + .where( + DiagnosisRecord.brand_id == brand.id, + DiagnosisRecord.status == "completed", + ) + .order_by(DiagnosisRecord.completed_at.asc()) + .limit(1) + ) + ).scalar_one_or_none() + latest_record = ( + await db.execute( + select(DiagnosisRecord) + .where( + DiagnosisRecord.brand_id == brand.id, + DiagnosisRecord.status == "completed", + ) + .order_by(DiagnosisRecord.completed_at.desc()) + .limit(1) + ) + ).scalar_one_or_none() + + if baseline_record and latest_record and baseline_record.id != latest_record.id: + before_dims = baseline_record.result_json.get("dimensions", []) if baseline_record.result_json else [] + after_dims = latest_record.result_json.get("dimensions", []) if 
latest_record.result_json else [] + before_map = {d.get("name"): d for d in before_dims} + after_map = {d.get("name"): d for d in after_dims} + ab_comparison = calculator.generate_ab_comparison( + before_score=baseline_record.overall_score or 0, + after_score=latest_record.overall_score or 0, + before_dimensions=before_map, + after_dimensions=after_map, + ) + + return ROIReport( + brand_id=str(brand.id), + brand_name=brand.name, + subscription_cost=subscription_cost, + current_plan=user_plan, + total_score_delta=summary["total_score_delta"], + value_generated=roi_data["value_generated"], + roi_percentage=roi_data["roi_percentage"], + break_even_delta=roi_data["break_even_delta"], + tracking_records=[_record_to_response(r) for r in summary["records"]], + ab_comparison=ab_comparison, + ) + + +@router.get("/ab-comparison/{brand_id}", response_model=ABComparisonResponse) +async def get_ab_comparison( + brand_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + brand = await _get_brand_or_404(brand_id, current_user, db) + + baseline_record = ( + await db.execute( + select(DiagnosisRecord) + .where( + DiagnosisRecord.brand_id == brand.id, + DiagnosisRecord.status == "completed", + ) + .order_by(DiagnosisRecord.completed_at.asc()) + .limit(1) + ) + ).scalar_one_or_none() + latest_record = ( + await db.execute( + select(DiagnosisRecord) + .where( + DiagnosisRecord.brand_id == brand.id, + DiagnosisRecord.status == "completed", + ) + .order_by(DiagnosisRecord.completed_at.desc()) + .limit(1) + ) + ).scalar_one_or_none() + + if not baseline_record or not latest_record: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="暂无诊断数据,无法生成A/B对比", + ) + + calculator = ROICalculator() + before_dims = baseline_record.result_json.get("dimensions", []) if baseline_record.result_json else [] + after_dims = latest_record.result_json.get("dimensions", []) if latest_record.result_json else [] + before_map = 
{d.get("name"): d for d in before_dims} + after_map = {d.get("name"): d for d in after_dims} + comparison = calculator.generate_ab_comparison( + before_score=baseline_record.overall_score or 0, + after_score=latest_record.overall_score or 0, + before_dimensions=before_map, + after_dimensions=after_map, + ) + + return ABComparisonResponse( + brand_id=str(brand.id), + brand_name=brand.name, + overall_before=comparison["overall_before"], + overall_after=comparison["overall_after"], + overall_delta=comparison["overall_delta"], + dimensions=comparison["dimensions"], + ) diff --git a/backend/app/api/brands.py b/backend/app/api/brands.py index fa704c0..7a2deab 100644 --- a/backend/app/api/brands.py +++ b/backend/app/api/brands.py @@ -1,5 +1,6 @@ """Brands API endpoints.""" import json +import logging import uuid from typing import Annotated @@ -14,9 +15,12 @@ from app.api.scoring import router as scoring_router from app.database import get_db from app.models.user import User from app.models.brand import Brand +from app.models.query import Query as QueryModel from app.schemas.brand import BrandCreate, BrandUpdate, BrandResponse, BrandListResponse from app.services.cache import get_cache_service, TTL_BRANDS +logger = logging.getLogger(__name__) + router = APIRouter() # Include competitors router under brands @@ -127,6 +131,28 @@ async def update_brand( # Update only provided fields update_data = brand_data.model_dump(exclude_unset=True) + + # 检测品牌名称变更,同步更新关联 Query 的 target_brand 和 brand_aliases + old_name = brand.name + new_name = update_data.get("name") + + if new_name and new_name != old_name: + queries_stmt = select(QueryModel).where( + QueryModel.user_id == current_user.id, + QueryModel.target_brand == old_name, + ) + queries_result = await db.execute(queries_stmt) + related_queries = queries_result.scalars().all() + + for query in related_queries: + query.target_brand = new_name + # 如果 brand_aliases 中包含旧名称,也同步更新 + if query.brand_aliases and old_name in 
query.brand_aliases: + query.brand_aliases = [new_name if a == old_name else a for a in query.brand_aliases] + + if related_queries: + logger.info(f"Brand renamed from '{old_name}' to '{new_name}', synced {len(related_queries)} queries") + for field, value in update_data.items(): setattr(brand, field, value) diff --git a/backend/app/api/citations.py b/backend/app/api/citations.py index b2fb949..beba568 100644 --- a/backend/app/api/citations.py +++ b/backend/app/api/citations.py @@ -13,7 +13,7 @@ from app.schemas.citation import ( CitationListResponse, CitationStatsResponse, ) -from app.services.citation import ( +from app.services.citation.citation import ( get_citation_stats, get_citations, ) diff --git a/backend/app/api/competitor_analysis.py b/backend/app/api/competitor_analysis.py new file mode 100644 index 0000000..6534927 --- /dev/null +++ b/backend/app/api/competitor_analysis.py @@ -0,0 +1,135 @@ +import logging +import uuid + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_current_user +from app.database import get_db +from app.models.user import User +from app.models.brand import Brand +from app.models.competitor_insight import CompetitorInsight +from app.schemas.competitor_insight import ( + CompetitorAnalysisRequest, + CompetitorInsightResponse, + CompetitorInsightList, + CompetitorGapSummary, +) +from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +async def _get_brand_if_owned( + brand_id: uuid.UUID, + current_user: User, + db: AsyncSession, +) -> Brand: + stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id) + result = await db.execute(stmt) + brand = result.scalar_one_or_none() + if not brand: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="品牌不存在", + ) + 
return brand + + +@router.post("/analyze", response_model=CompetitorInsightList) +async def analyze_competitor( + request: CompetitorAnalysisRequest, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + await _get_brand_if_owned(request.brand_id, current_user, db) + + service = CompetitorAnalyzerService() + try: + result = await service.analyze_competitor( + brand_id=request.brand_id, + analysis_types=request.analysis_types, + period_days=request.period_days or 30, + ) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) + + stmt = ( + select(CompetitorInsight) + .where(CompetitorInsight.brand_id == request.brand_id) + .order_by(CompetitorInsight.created_at.desc()) + ) + db_result = await db.execute(stmt) + insights = list(db_result.scalars().all()) + + return {"items": insights, "total": len(insights)} + + +@router.get("/brand/{brand_id}", response_model=CompetitorInsightList) +async def get_brand_insights( + brand_id: uuid.UUID, + skip: int = Query(0, ge=0), + limit: int = Query(20, ge=1, le=100), + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + await _get_brand_if_owned(brand_id, current_user, db) + + count_stmt = select(func.count()).select_from(CompetitorInsight).where( + CompetitorInsight.brand_id == brand_id, + ) + count_result = await db.execute(count_stmt) + total = count_result.scalar_one() + + stmt = ( + select(CompetitorInsight) + .where(CompetitorInsight.brand_id == brand_id) + .order_by(CompetitorInsight.created_at.desc()) + .offset(skip) + .limit(limit) + ) + result = await db.execute(stmt) + insights = list(result.scalars().all()) + + return {"items": insights, "total": total} + + +@router.get("/{insight_id}", response_model=CompetitorInsightResponse) +async def get_insight_detail( + insight_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + stmt = 
select(CompetitorInsight).where(CompetitorInsight.id == insight_id) + result = await db.execute(stmt) + insight = result.scalar_one_or_none() + + if not insight: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="洞察不存在", + ) + + await _get_brand_if_owned(insight.brand_id, current_user, db) + + return insight + + +@router.get("/brand/{brand_id}/gap-summary", response_model=list[CompetitorGapSummary]) +async def get_gap_summary( + brand_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + brand = await _get_brand_if_owned(brand_id, current_user, db) + + service = CompetitorAnalyzerService() + gap_summaries = await service.calculate_gap_score(db, brand_id, brand.name) + + return gap_summaries diff --git a/backend/app/api/competitors.py b/backend/app/api/competitors.py index 31b3d22..38eaba9 100644 --- a/backend/app/api/competitors.py +++ b/backend/app/api/competitors.py @@ -21,6 +21,7 @@ from app.schemas.competitor import ( CompetitorRecommendationItem, CompetitorRecommendationResponse, ) +from app.utils.json_extractor import extract_json logger = logging.getLogger(__name__) @@ -363,7 +364,7 @@ async def _get_llm_recommendations( raise ValueError("API返回空响应") # 提取JSON - json_str = _extract_json_from_text(content) + json_str = extract_json(content) data = json.loads(json_str) recommendations = [] @@ -393,32 +394,6 @@ async def _get_llm_recommendations( raise -def _extract_json_from_text(text: str) -> str: - """从文本中提取JSON字符串。""" - import re - - # 尝试直接解析 - try: - json.loads(text) - return text - except json.JSONDecodeError: - pass - - # 尝试从代码块中提取 - json_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```' - match = re.search(json_pattern, text) - if match: - return match.group(1).strip() - - # 尝试找到第一个{到最后一个}之间的内容 - first_brace = text.find('{') - last_brace = text.rfind('}') - if first_brace != -1 and last_brace != -1 and last_brace > first_brace: - return text[first_brace:last_brace + 1] - - raise 
ValueError(f"无法从响应中提取JSON: {text[:200]}") - - def _get_industry_label(industry: str | None) -> str: """获取行业中文标签。""" labels = { diff --git a/backend/app/api/content.py b/backend/app/api/content.py index afda162..5549b15 100644 --- a/backend/app/api/content.py +++ b/backend/app/api/content.py @@ -1,18 +1,28 @@ -"""内容生产API - 串联Agent Pipeline""" +"""内容生产API - 串联Agent Pipeline + +业务逻辑已委托给 ContentGenerationService,API 层仅负责: +1. 请求解析与参数校验 +2. 调用服务层 +3. 格式化响应 +""" import json import logging import re +import uuid from typing import Optional from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_user from app.database import get_db +from app.models.brand import Brand from app.models.content import Content, ContentVersion +from app.models.diagnosis_record import DiagnosisRecord from app.models.user import User +from app.services.content.content_generation_service import ContentGenerationService +from app.services.llm import LLMError logger = logging.getLogger(__name__) @@ -29,6 +39,7 @@ class ContentGenerateRequest(BaseModel): brand_description: str = "" run_deai: bool = True run_geo: bool = True + use_agent_framework: bool = False class ContentGenerateResponse(BaseModel): @@ -41,44 +52,6 @@ class ContentGenerateResponse(BaseModel): pipeline_stages: list[dict] = [] # 每个阶段的执行结果摘要 -async def _get_knowledge_context( - db: AsyncSession, - brand_name: str, - knowledge_base_ids: list[str], - target_keyword: str, -) -> str: - """ - 从知识库检索与查询相关的上下文。 - - 如果有知识库ID,则调用 RAGService.search 获取相关内容; - 否则返回空字符串,不影响后续流程。 - """ - if not knowledge_base_ids: - return "" - - try: - from app.services.knowledge.rag_service import RAGService - rag_service = RAGService() - results = await rag_service.search( - session=db, - query=f"{brand_name} {target_keyword}" if brand_name else target_keyword, - knowledge_base_ids=knowledge_base_ids, - top_k=3, - ) - if 
results: - context_parts = [] - for r in results: - content = r.get("content", "") - title = r.get("document_title", "") - if content: - context_parts.append(f"[{title}] {content}") - return "\n".join(context_parts) - return "" - except Exception as e: - logger.warning(f"知识库检索失败,将不使用知识库上下文: {e}") - return "" - - @router.post("/generate", response_model=ContentGenerateResponse) async def generate_content( req: ContentGenerateRequest, @@ -88,115 +61,132 @@ async def generate_content( """ 一键生成内容(同步执行Pipeline),结果存入数据库 - 流程:ContentGenerator → DeAI → GEOOptimizer + 流程:ContentGenerator -> DeAI -> GEOOptimizer + 业务逻辑委托给 ContentGenerationService """ - from app.services.llm import LLMError, LLMFactory - from app.agent_framework.prompts import ( - CONTENT_GENERATOR_TEMPLATE, - DEAI_TEMPLATE, - GEO_OPTIMIZER_TEMPLATE, - ) - org_id = getattr(current_user, "organization_id", None) if not org_id: raise HTTPException(status_code=403, detail="用户未关联组织") - stages = [] - try: - provider = LLMFactory.get_default() - - # 获取知识库上下文 - knowledge_context = await _get_knowledge_context( - db, req.brand_name, req.knowledge_base_ids, req.target_keyword + service = ContentGenerationService() + result = await service.generate_content( + keyword=req.target_keyword, + brand_name=req.brand_name, + platform=req.target_platform, + content_style=req.content_style, + word_count=req.word_count, + knowledge_base_ids=req.knowledge_base_ids, + db=db, + user_id=current_user.id, + org_id=org_id, + run_deai=req.run_deai, + run_geo=req.run_geo, + use_agent_framework=req.use_agent_framework, ) - # Stage 1: 内容生成 - gen_variables = { - "topic_title": req.target_keyword, - "target_keyword": req.target_keyword, - "target_platform": req.target_platform, - "content_angle": "综合分析", - "content_style": req.content_style, - "word_count": str(req.word_count), - "brand_name": req.brand_name, - "knowledge_context": knowledge_context, - } - messages = CONTENT_GENERATOR_TEMPLATE.render(gen_variables) - response = await 
provider.chat(messages, temperature=0.7, max_tokens=req.word_count * 2) - content = response.content - stages.append({"stage": "content_generation", "status": "success", "word_count": len(content)}) - - # Stage 2: 去AI化(可选) - if req.run_deai: - deai_variables = { - "original_content": content, - "target_style": "自然流畅", - "preserve_structure": "是", - } - messages = DEAI_TEMPLATE.render(deai_variables) - response = await provider.chat(messages, temperature=0.9, max_tokens=len(content) * 2) - content = response.content - stages.append({"stage": "deai", "status": "success"}) - - # Stage 3: GEO优化(可选) - optimized = content - seo_score = None - if req.run_geo: - geo_variables = { - "original_content": content, - "target_keywords": req.target_keyword, - "target_platform": req.target_platform, - "optimization_level": "moderate", - } - messages = GEO_OPTIMIZER_TEMPLATE.render(geo_variables) - response = await provider.chat(messages, temperature=0.5, max_tokens=len(content) * 2) - optimized = response.content - stages.append({"stage": "geo_optimization", "status": "success"}) - - # ---- 存入数据库 ---- - content_obj = Content( - organization_id=org_id, - title=req.target_keyword, - content_type="article", - body=optimized, - status="draft", - target_platforms=[req.target_platform] if req.target_platform else [], - keywords=[req.target_keyword], - extra_metadata={ - "original_content": content if content != optimized else None, - "pipeline_stages": stages, - "seo_score": seo_score, - "brand_name": req.brand_name, - "content_style": req.content_style, - "word_count_target": req.word_count, - }, - created_by=current_user.id, - current_version=1, - ) - db.add(content_obj) - await db.flush() # get content_obj.id - - # 创建版本记录(初始版本) - version = ContentVersion( - content_id=content_obj.id, - version_number=1, - title=req.target_keyword, - body=optimized, - change_summary="Pipeline自动生成", - created_by=current_user.id, - ) - db.add(version) - await db.commit() - await db.refresh(content_obj) 
- return ContentGenerateResponse( status="success", - content=content, - optimized_content=optimized, - seo_score=seo_score, - content_id=str(content_obj.id), - pipeline_stages=stages, + content=result["content"], + optimized_content=result["optimized_content"], + seo_score=result["seo_score"], + content_id=result["content_id"], + pipeline_stages=result["pipeline_stages"], + ) + + except LLMError as e: + raise HTTPException(status_code=502, detail=f"LLM调用失败: {str(e)}") + except Exception as e: + raise HTTPException(status_code=500, detail=f"内容生成异常: {str(e)}") + + +class GEOContentGenerateRequest(BaseModel): + brand_id: str + target_keywords: list[str] + platform: str = "通用" + content_style: str = "专业严谨" + word_count: int = 2000 + knowledge_base_ids: list[str] = [] + run_deai: bool = True + run_geo: bool = True + + +class GEOContentGenerateResponse(BaseModel): + content_id: Optional[str] = None + content: str = "" + optimized_content: str = "" + seo_score: Optional[int] = None + pipeline_stages: list[dict] = [] + + +@router.post("/generate-geo", response_model=GEOContentGenerateResponse, status_code=201) +async def generate_geo_content( + req: GEOContentGenerateRequest, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + org_id = getattr(current_user, "organization_id", None) + if not org_id: + raise HTTPException(status_code=403, detail="用户未关联组织") + + from sqlalchemy import select + + try: + brand_uuid = uuid.UUID(req.brand_id) + except ValueError: + raise HTTPException(status_code=400, detail=f"Invalid brand_id format: {req.brand_id}") + + brand_stmt = select(Brand).where(Brand.id == brand_uuid) + brand_result = await db.execute(brand_stmt) + brand = brand_result.scalar_one_or_none() + if not brand: + raise HTTPException(status_code=404, detail=f"Brand not found: {req.brand_id}") + + diagnosis_context = "" + diag_stmt = ( + select(DiagnosisRecord) + .where(DiagnosisRecord.brand_id == brand_uuid, DiagnosisRecord.status == 
"completed") + .order_by(DiagnosisRecord.created_at.desc()) + ) + diag_result = await db.execute(diag_stmt) + diagnosis = diag_result.scalar_one_or_none() + if diagnosis and diagnosis.result_json: + result_json = diagnosis.result_json + weak_dimensions = [] + if isinstance(result_json, dict): + dimensions = result_json.get("dimensions", {}) + for dim_name, dim_data in dimensions.items(): + if isinstance(dim_data, dict) and dim_data.get("score", 100) < 60: + weak_dimensions.append(dim_name) + if weak_dimensions: + diagnosis_context = f"基于诊断结果,以下维度需要重点优化:{', '.join(weak_dimensions)}。请围绕这些维度生成针对性内容。" + + keyword = "、".join(req.target_keywords) + if diagnosis_context: + keyword = f"{keyword}({diagnosis_context})" + + try: + service = ContentGenerationService() + result = await service.generate_content( + keyword=keyword, + brand_name=brand.name, + platform=req.platform, + content_style=req.content_style, + word_count=req.word_count, + knowledge_base_ids=req.knowledge_base_ids, + db=db, + user_id=current_user.id, + org_id=org_id, + run_deai=req.run_deai, + run_geo=req.run_geo, + ) + + return GEOContentGenerateResponse( + content_id=result["content_id"], + content=result["content"], + optimized_content=result["optimized_content"], + seo_score=result["seo_score"], + pipeline_stages=result["pipeline_stages"], ) except LLMError as e: diff --git a/backend/app/api/dashboard.py b/backend/app/api/dashboard.py index 70760c9..954bedf 100644 --- a/backend/app/api/dashboard.py +++ b/backend/app/api/dashboard.py @@ -1,6 +1,5 @@ """Dashboard API endpoints.""" import uuid -from datetime import datetime, timedelta from fastapi import APIRouter, Depends, Query from sqlalchemy import select, func, Integer @@ -19,325 +18,15 @@ from app.schemas.dashboard import ( PlatformScoreItem, RecentQueryItem, ) -from app.services.scoring_service import ScoringService, get_health_level -from app.services.sentiment_service import get_sentiment_service -from app.schemas.scoring import CitationResult 
+from app.services.scoring.scoring_service import get_health_level +from app.services.scoring.brand_scoring_data_service import ( + get_brand_scoring_data_service, + REQUIRED_PLATFORMS, +) from app.services.cache import get_cache_service, TTL_DASHBOARD router = APIRouter() -# Required platforms for the dashboard -REQUIRED_PLATFORMS = [ - "wenxin", # 文心一言 - "kimi", # Kimi - "tongyi", # 通义千问 - "doubao", # 豆包 - "xinghuo", # 讯飞星火 - "tiangong", # 天工AI - "qingyan", # 智谱清言 -] - - -async def _get_brand_score_by_platform( - db: AsyncSession, - user_id: uuid.UUID, - brand_id: uuid.UUID, -) -> dict[str, float]: - """ - Calculate brand score by platform. - - Returns a dict mapping platform key to score (0-100). - """ - brand_stmt = select(Brand).where( - Brand.id == brand_id, - Brand.user_id == user_id, - ) - brand_result = await db.execute(brand_stmt) - brand = brand_result.scalar_one_or_none() - - if not brand: - return {platform: 0.0 for platform in REQUIRED_PLATFORMS} - - queries_stmt = select(QueryModel).where( - QueryModel.user_id == user_id, - QueryModel.target_brand == brand.name, - ) - queries_result = await db.execute(queries_stmt) - queries = list(queries_result.scalars().all()) - - if not queries: - return {platform: 0.0 for platform in REQUIRED_PLATFORMS} - - query_ids = [q.id for q in queries] - - platform_scores = {platform: 0.0 for platform in REQUIRED_PLATFORMS} - - for platform in REQUIRED_PLATFORMS: - citation_stmt = select( - func.count().label("total"), - func.sum( - func.cast( - func.case((CitationRecord.cited == True, 1), else_=0), - Integer - ) - ).label("cited") - ).where( - CitationRecord.query_id.in_(query_ids), - CitationRecord.platform == platform, - ) - result = await db.execute(citation_stmt) - row = result.one() - - total = row.total or 0 - cited = row.cited or 0 - - if total > 0: - citation_rate = cited / total - platform_scores[platform] = round(citation_rate * 100, 2) - else: - platform_scores[platform] = 0.0 - - return platform_scores - - 
-async def _get_competitor_scores_by_platform( - db: AsyncSession, - user_id: uuid.UUID, - brand_id: uuid.UUID, -) -> dict[str, float]: - """ - Calculate average competitor score by platform. - - Returns a dict mapping platform key to average competitor score (0-100). - """ - competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id) - competitor_result = await db.execute(competitor_stmt) - competitors = list(competitor_result.scalars().all()) - - if not competitors: - return {platform: 0.0 for platform in REQUIRED_PLATFORMS} - - competitor_names = [c.name for c in competitors] - - competitor_queries_stmt = select(QueryModel).where( - QueryModel.user_id == user_id, - QueryModel.target_brand.in_(competitor_names), - ) - competitor_queries_result = await db.execute(competitor_queries_stmt) - competitor_queries = list(competitor_queries_result.scalars().all()) - - if not competitor_queries: - return {platform: 0.0 for platform in REQUIRED_PLATFORMS} - - competitor_query_ids = [q.id for q in competitor_queries] - - platform_scores = {platform: 0.0 for platform in REQUIRED_PLATFORMS} - - for platform in REQUIRED_PLATFORMS: - citation_stmt = select( - func.count().label("total"), - func.sum( - func.cast( - func.case((CitationRecord.cited == True, 1), else_=0), - Integer - ) - ).label("cited") - ).where( - CitationRecord.query_id.in_(competitor_query_ids), - CitationRecord.platform == platform, - ) - result = await db.execute(citation_stmt) - row = result.one() - - total = row.total or 0 - cited = row.cited or 0 - - if total > 0: - citation_rate = cited / total - platform_scores[platform] = round(citation_rate * 100, 2) - else: - platform_scores[platform] = 0.0 - - return platform_scores - - -async def _calculate_overall_score_v2( - db: AsyncSession, - user_id: uuid.UUID, - brand_id: uuid.UUID, -) -> tuple[float, float, list[DimensionScoreItem], int, int]: - """ - Calculate V2 overall score, change from yesterday, dimension scores, - and competitor 
position. - - Returns: (overall_score, change_from_yesterday, dimensions, ahead_count, behind_count) - """ - brand_stmt = select(Brand).where( - Brand.id == brand_id, - Brand.user_id == user_id, - ) - brand_result = await db.execute(brand_stmt) - brand = brand_result.scalar_one_or_none() - - if not brand: - return 0.0, 0.0, [], 0, 0 - - # Get queries for this brand - queries_stmt = select(QueryModel).where( - QueryModel.user_id == user_id, - QueryModel.target_brand == brand.name, - ) - queries_result = await db.execute(queries_stmt) - queries = list(queries_result.scalars().all()) - - if not queries: - return 0.0, 0.0, [], 0, 0 - - query_ids = [q.id for q in queries] - - # Get all citations - citations_stmt = select(CitationRecord).where( - CitationRecord.query_id.in_(query_ids), - ) - citations_result = await db.execute(citations_stmt) - all_citations = list(citations_result.scalars().all()) - - total_queries = len(all_citations) - brand_citations = [c for c in all_citations if c.cited] - - # Get competitor data - competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id) - competitor_result = await db.execute(competitor_stmt) - competitors = list(competitor_result.scalars().all()) - competitor_names = [c.name for c in competitors] - - # Calculate competitor mentions - competitor_mentions: dict[str, int] = {} - for comp_name in competitor_names: - count = sum( - 1 for c in all_citations - if c.cited and c.competitor_brands - and comp_name in c.competitor_brands - ) - if count > 0: - competitor_mentions[comp_name] = count - - # Sentiment analysis - prefer persisted sentiment field - sentiment_service = get_sentiment_service() - sentiment_counts = {"positive": 0, "neutral": 0, "negative": 0} - for citation in brand_citations: - if citation.sentiment and citation.sentiment in ("positive", "neutral", "negative"): - sentiment_counts[citation.sentiment] += 1 - else: - content = citation.raw_response or citation.citation_text or "" - if content.strip(): 
- try: - result = await sentiment_service.analyze( - brand_name=brand.name, - content=content, - ) - sentiment_counts[result.sentiment] += 1 - except Exception: - sentiment_counts["neutral"] += 1 - else: - sentiment_counts["neutral"] += 1 - - # Build citation results - citation_results = [ - CitationResult( - cited=c.cited, - position=c.citation_position, - citation_text=c.citation_text, - sentiment="neutral", # simplified for dashboard - confidence=c.confidence or 0.0, - ) - for c in brand_citations - ] - - # Extract positions - positions = [c.citation_position for c in brand_citations if c.cited] - - # Calculate V2 score - scoring_service = ScoringService() - v2_result = scoring_service.calculate_v2( - mentioned_count=len(brand_citations), - total_queries=total_queries, - positions=positions, - sentiment_counts=sentiment_counts, - citations=citation_results, - brand_mentions=len(brand_citations), - competitor_mentions=competitor_mentions, - ) - - # Build dimension items - dimensions = [ - DimensionScoreItem( - name=v2_result.mention_rate.name, - score=round(v2_result.mention_rate.score, 2), - max_score=v2_result.mention_rate.max_score, - percentage=round(v2_result.mention_rate.percentage, 2), - ), - DimensionScoreItem( - name=v2_result.recommendation_rank.name, - score=round(v2_result.recommendation_rank.score, 2), - max_score=v2_result.recommendation_rank.max_score, - percentage=round(v2_result.recommendation_rank.percentage, 2), - ), - DimensionScoreItem( - name=v2_result.sentiment_score.name, - score=round(v2_result.sentiment_score.score, 2), - max_score=v2_result.sentiment_score.max_score, - percentage=round(v2_result.sentiment_score.percentage, 2), - ), - DimensionScoreItem( - name=v2_result.citation_quality.name, - score=round(v2_result.citation_quality.score, 2), - max_score=v2_result.citation_quality.max_score, - percentage=round(v2_result.citation_quality.percentage, 2), - ), - DimensionScoreItem( - name=v2_result.competitive_position.name, - 
score=round(v2_result.competitive_position.score, 2), - max_score=v2_result.competitive_position.max_score, - percentage=round(v2_result.competitive_position.percentage, 2), - ), - ] - - # Calculate competitor ahead/behind - ahead_count = sum( - 1 for count in competitor_mentions.values() - if len(brand_citations) > count - ) - behind_count = sum( - 1 for count in competitor_mentions.values() - if len(brand_citations) <= count - ) - - # Calculate change from yesterday - today = datetime.now().date() - yesterday = today - timedelta(days=1) - - today_citations = [ - c for c in all_citations - if c.queried_at.date() == today - ] - yesterday_citations = [ - c for c in all_citations - if c.queried_at.date() == yesterday - ] - - today_cited = sum(1 for c in today_citations if c.cited) - today_total = len(today_citations) - today_score = (today_cited / today_total * 100) if today_total > 0 else 0.0 - - yesterday_cited = sum(1 for c in yesterday_citations if c.cited) - yesterday_total = len(yesterday_citations) - yesterday_score = (yesterday_cited / yesterday_total * 100) if yesterday_total > 0 else 0.0 - - change = round(today_score - yesterday_score, 2) - - return v2_result.overall_score, change, dimensions, ahead_count, behind_count - @router.get("/stats", response_model=DashboardStatsResponse) async def get_dashboard_stats( @@ -357,7 +46,6 @@ async def get_dashboard_stats( - 最近查询记录 """ cache = get_cache_service() - # 如果 brand_id 尚未确定,先查库取第一个品牌 if brand_id is None: brand_stmt = select(Brand).where(Brand.user_id == current_user.id).limit(1) brand_result = await db.execute(brand_stmt) @@ -379,34 +67,67 @@ async def get_dashboard_stats( total_platforms=7, ) - # 尝试从缓存读取(TTL: 2 分钟) cache_key = f"dashboard:stats:{current_user.id}:{brand_id}" cached = await cache.get_json(cache_key) if cached is not None: return cached - # Get brand name brand_stmt = select(Brand).where(Brand.id == brand_id) brand_result = await db.execute(brand_stmt) brand = brand_result.scalar_one_or_none() 
brand_name = brand.name if brand else None - # Calculate V2 overall score and dimensions - overall_score, score_change, dimensions, ahead_count, behind_count = ( - await _calculate_overall_score_v2(db, current_user.id, brand_id) + scoring_data_service = get_brand_scoring_data_service() + + scoring_data = await scoring_data_service.get_brand_scoring_data( + db, current_user.id, brand ) - # Get platform scores - platform_scores_dict = await _get_brand_score_by_platform( + overall_score = scoring_data.v2_result.overall_score + score_change = scoring_data.change_from_yesterday + + dimensions = [ + DimensionScoreItem( + name=scoring_data.v2_result.mention_rate.name, + score=round(scoring_data.v2_result.mention_rate.score, 2), + max_score=scoring_data.v2_result.mention_rate.max_score, + percentage=round(scoring_data.v2_result.mention_rate.percentage, 2), + ), + DimensionScoreItem( + name=scoring_data.v2_result.recommendation_rank.name, + score=round(scoring_data.v2_result.recommendation_rank.score, 2), + max_score=scoring_data.v2_result.recommendation_rank.max_score, + percentage=round(scoring_data.v2_result.recommendation_rank.percentage, 2), + ), + DimensionScoreItem( + name=scoring_data.v2_result.sentiment_score.name, + score=round(scoring_data.v2_result.sentiment_score.score, 2), + max_score=scoring_data.v2_result.sentiment_score.max_score, + percentage=round(scoring_data.v2_result.sentiment_score.percentage, 2), + ), + DimensionScoreItem( + name=scoring_data.v2_result.citation_quality.name, + score=round(scoring_data.v2_result.citation_quality.score, 2), + max_score=scoring_data.v2_result.citation_quality.max_score, + percentage=round(scoring_data.v2_result.citation_quality.percentage, 2), + ), + DimensionScoreItem( + name=scoring_data.v2_result.competitive_position.name, + score=round(scoring_data.v2_result.competitive_position.score, 2), + max_score=scoring_data.v2_result.competitive_position.max_score, + 
percentage=round(scoring_data.v2_result.competitive_position.percentage, 2), + ), + ] + + ahead_count = scoring_data.competitor_data.get("ahead_count", 0) + behind_count = scoring_data.competitor_data.get("behind_count", 0) + + platform_scores_dict = scoring_data.platform_scores + + competitor_scores_dict = await scoring_data_service.get_competitor_platform_scores( db, current_user.id, brand_id ) - # Get competitor platform scores - competitor_scores_dict = await _get_competitor_scores_by_platform( - db, current_user.id, brand_id - ) - - # Get first competitor name for display competitor_stmt = select(Competitor).where( Competitor.brand_id == brand_id ).limit(1) @@ -424,10 +145,8 @@ async def get_dashboard_stats( for platform, score in platform_scores_dict.items() ] - # Count monitored platforms monitored = sum(1 for s in platform_scores_dict.values() if s > 0) - # Get recent queries (last 10) recent_queries_stmt = ( select(QueryModel) .where(QueryModel.user_id == current_user.id) @@ -437,7 +156,6 @@ async def get_dashboard_stats( recent_queries_result = await db.execute(recent_queries_stmt) recent_queries_list = list(recent_queries_result.scalars().all()) - # Get citation count for each query recent_queries = [] for query in recent_queries_list: citation_count_stmt = select( @@ -460,7 +178,6 @@ async def get_dashboard_stats( queried_at=query.last_queried_at or query.created_at, )) - # Health level health_level = get_health_level(overall_score) response = DashboardStatsResponse( @@ -477,7 +194,6 @@ async def get_dashboard_stats( brand_name=brand_name, ) - # 将结果写入缓存(TTL: 2 分钟) await cache.set_json( cache_key, response.model_dump(mode="json"), diff --git a/backend/app/api/detection.py b/backend/app/api/detection.py index 168e1a0..83c32bf 100644 --- a/backend/app/api/detection.py +++ b/backend/app/api/detection.py @@ -16,7 +16,7 @@ from app.schemas.detection_task import ( DetectionTaskUpdate, DetectionTriggerResponse, ) -from app.services.detection_scheduler import 
DetectionSchedulerService, TaskNotFoundError +from app.services.detection.detection_scheduler import DetectionSchedulerService, TaskNotFoundError logger = logging.getLogger(__name__) diff --git a/backend/app/api/diagnosis.py b/backend/app/api/diagnosis.py index 2a6af11..a00a0fa 100644 --- a/backend/app/api/diagnosis.py +++ b/backend/app/api/diagnosis.py @@ -1,6 +1,8 @@ """诊断API端点 - 提供SEO和GEO诊断功能""" +import asyncio import logging import uuid +from datetime import UTC, datetime from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, status @@ -11,13 +13,32 @@ from app.api.deps import get_current_user from app.database import get_db from app.models.user import User from app.models.brand import Brand -from app.services.seo_diagnosis import SEODiagnosisService -from app.services.geo_diagnosis import GEODiagnosisService, GEODiagnosisInput +from app.models.diagnosis_record import DiagnosisRecord +from app.schemas.diagnosis import ( + GEODiagnosisHistoryItem, + GEODiagnosisHistoryResponse, + GEODiagnosisResponse, + GEODiagnosisResultResponse, + GEODiagnosisTaskResponse, + GEODiagnosisTriggerRequest, +) +from app.services.diagnosis.data_collector import DataCollectorService +from app.services.diagnosis.seo_diagnosis import SEODiagnosisService +from app.services.diagnosis.geo_diagnosis import GEODiagnosisService, GEODiagnosisInput +from app.utils.health import get_health_level_label logger = logging.getLogger(__name__) router = APIRouter() +_FREE_TIER_DIMENSIONS = {"内容可提取性", "E-E-A-T信号", "引用就绪度"} + + +def _to_uuid(value: str | uuid.UUID) -> uuid.UUID: + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(str(value)) + @router.get("/seo/{brand_id}") async def get_seo_diagnosis( @@ -25,24 +46,14 @@ async def get_seo_diagnosis( current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): - """ - 获取品牌的SEO诊断结果 - - 返回5维度SEO诊断: - - 技术SEO (25分) - - 页面SEO (20分) - - 内容质量 (20分) - - 外链分析 (15分) - - 用户体验 (20分) - """ 
brand = await _get_brand_or_404(brand_id, current_user, db) - + try: service = SEODiagnosisService() result = service.diagnose() - + logger.info(f"SEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={result.overall_score}") - + return result.to_dict() except Exception as e: logger.error(f"SEO诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True) @@ -52,39 +63,158 @@ async def get_seo_diagnosis( ) -@router.get("/geo/{brand_id}") -async def get_geo_diagnosis( +@router.post("/geo/{brand_id}", status_code=status.HTTP_202_ACCEPTED) +async def trigger_geo_diagnosis( brand_id: uuid.UUID, + body: GEODiagnosisTriggerRequest = GEODiagnosisTriggerRequest(), current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): - """ - 获取品牌的GEO诊断结果 - - 返回6维度GEO诊断: - - 内容可提取性 (20分) - - 实体清晰度 (15分) - - E-E-A-T信号 (20分) - - Schema标记 (15分) - - 主题权威 (15分) - - 引用就绪度 (15分) - """ brand = await _get_brand_or_404(brand_id, current_user, db) - - try: - input_data = GEODiagnosisInput() - service = GEODiagnosisService() - result = service.diagnose(input_data) - - logger.info(f"GEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={result.overall_score}") - - return result.to_dict() - except Exception as e: - logger.error(f"GEO诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="GEO诊断服务异常,请稍后重试", + + if not body.force_refresh: + existing = await _find_recent_completed(db, brand_id, hours=24) + if existing: + return GEODiagnosisTaskResponse( + task_id=str(existing.id), + brand_id=str(brand_id), + status="completed", + ) + + record = DiagnosisRecord( + brand_id=brand_id, + user_id=_to_uuid(current_user.id), + diagnosis_type="geo", + status="pending", + ) + db.add(record) + await db.commit() + await db.refresh(record) + + asyncio.create_task( + _run_geo_diagnosis( + record_id=record.id, + brand_id=brand_id, + brand_name=brand.name, + brand_aliases=brand.aliases or [], + 
website=brand.website, + industry=brand.industry, + user_id=_to_uuid(current_user.id), ) + ) + + return GEODiagnosisTaskResponse( + task_id=str(record.id), + brand_id=str(brand_id), + status="pending", + ) + + +@router.get("/geo/{brand_id}/result") +async def get_geo_diagnosis_result( + brand_id: uuid.UUID, + task_id: str | None = None, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + await _get_brand_or_404(brand_id, current_user, db) + + if task_id: + try: + tid = uuid.UUID(task_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="无效的task_id", + ) + stmt = select(DiagnosisRecord).where( + DiagnosisRecord.id == tid, + DiagnosisRecord.brand_id == brand_id, + ) + else: + stmt = ( + select(DiagnosisRecord) + .where( + DiagnosisRecord.brand_id == brand_id, + DiagnosisRecord.diagnosis_type == "geo", + ) + .order_by(DiagnosisRecord.created_at.desc()) + .limit(1) + ) + + result = await db.execute(stmt) + record = result.scalar_one_or_none() + + if not record: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="诊断记录不存在", + ) + + if record.status == "pending" or record.status == "running": + return GEODiagnosisResultResponse( + task_id=str(record.id), + brand_id=str(brand_id), + status=record.status, + ) + + if record.status == "failed": + return GEODiagnosisResultResponse( + task_id=str(record.id), + brand_id=str(brand_id), + status="failed", + error=record.error_message, + ) + + user_plan = getattr(current_user, "plan", None) or "free" + is_paid = user_plan not in ("free", None) + diagnosis_resp = _build_diagnosis_response(record.result_json, is_paid) + + return GEODiagnosisResultResponse( + task_id=str(record.id), + brand_id=str(brand_id), + status="completed", + result=diagnosis_resp, + ) + + +@router.get("/geo/{brand_id}/history") +async def get_geo_diagnosis_history( + brand_id: uuid.UUID, + limit: int = 10, + current_user: User = 
Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + await _get_brand_or_404(brand_id, current_user, db) + + stmt = ( + select(DiagnosisRecord) + .where( + DiagnosisRecord.brand_id == brand_id, + DiagnosisRecord.diagnosis_type == "geo", + DiagnosisRecord.status == "completed", + ) + .order_by(DiagnosisRecord.created_at.desc()) + .limit(limit) + ) + result = await db.execute(stmt) + records = result.scalars().all() + + items = [ + GEODiagnosisHistoryItem( + task_id=str(r.id), + overall_score=r.overall_score or 0, + health_level=r.result_json.get("health_level", "danger") if r.result_json else "danger", + created_at=r.created_at.isoformat(), + completed_at=r.completed_at.isoformat() if r.completed_at else None, + ) + for r in records + ] + + return GEODiagnosisHistoryResponse( + brand_id=str(brand_id), + history=items, + ) @router.get("/combined/{brand_id}") @@ -93,29 +223,32 @@ async def get_combined_diagnosis( current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): - """ - 获取品牌的综合诊断结果 - - 结合SEO和GEO诊断,返回综合评分和详细诊断结果 - """ brand = await _get_brand_or_404(brand_id, current_user, db) - + try: seo_service = SEODiagnosisService() seo_result = seo_service.diagnose() - + + collector = DataCollectorService(db) + collection = await collector.collect( + brand_name=brand.name, + brand_aliases=brand.aliases or [], + website=brand.website, + industry=brand.industry, + ) + geo_service = GEODiagnosisService() - geo_result = geo_service.diagnose(GEODiagnosisInput()) - + geo_result = geo_service.diagnose(collection.diagnosis_input) + combined_score = round((seo_result.overall_score + geo_result.overall_score) / 2, 2) - + logger.info( f"综合诊断完成: brand_id={brand_id}, brand={brand.name}, " f"seo_score={seo_result.overall_score}, " f"geo_score={geo_result.overall_score}, " f"combined_score={combined_score}" ) - + return { "seo_score": seo_result.overall_score, "geo_score": geo_result.overall_score, @@ -131,20 +264,131 @@ async def 
get_combined_diagnosis( ) +async def _run_geo_diagnosis( + record_id: uuid.UUID, + brand_id: uuid.UUID, + brand_name: str, + brand_aliases: list[str], + website: str | None, + industry: str | None, + user_id: uuid.UUID, +) -> None: + from app.database import AsyncSessionLocal + + async with AsyncSessionLocal() as db: + try: + stmt = select(DiagnosisRecord).where(DiagnosisRecord.id == record_id) + result = await db.execute(stmt) + record = result.scalar_one_or_none() + if not record: + return + + record.status = "running" + await db.commit() + + collector = DataCollectorService(db) + collection = await collector.collect( + brand_name=brand_name, + brand_aliases=brand_aliases, + website=website, + industry=industry, + ) + + geo_service = GEODiagnosisService() + diagnosis_result = geo_service.diagnose(collection.diagnosis_input) + + record.status = "completed" + record.overall_score = diagnosis_result.overall_score + record.result_json = diagnosis_result.to_dict() + record.completed_at = datetime.now(UTC) + record.collection_metadata = collection.metadata + if collection.errors: + record.collection_metadata["errors"] = collection.errors + + await db.commit() + + logger.info( + f"GEO诊断完成: brand_id={brand_id}, brand={brand_name}, " + f"score={diagnosis_result.overall_score}" + ) + + except Exception as e: + logger.error(f"GEO诊断任务失败: record_id={record_id}, error={e}", exc_info=True) + try: + stmt = select(DiagnosisRecord).where(DiagnosisRecord.id == record_id) + result = await db.execute(stmt) + record = result.scalar_one_or_none() + if record: + record.status = "failed" + record.error_message = str(e) + await db.commit() + except Exception: + logger.error(f"更新失败状态也失败: record_id={record_id}") + + +def _build_diagnosis_response(result_json: dict | None, is_paid: bool) -> GEODiagnosisResponse: + if not result_json: + return GEODiagnosisResponse( + overall_score=0, + health_level="danger", + health_level_label=get_health_level_label("danger"), + dimensions=[], + 
recommendations=[], + is_full_report=is_paid, + ) + + dimensions = result_json.get("dimensions", []) + if not is_paid: + dimensions = [d for d in dimensions if d.get("name") in _FREE_TIER_DIMENSIONS] + + recommendations = result_json.get("recommendations", []) + if not is_paid: + recommendations = [r for r in recommendations if r.get("priority") == "P0"] + + return GEODiagnosisResponse( + overall_score=result_json.get("overall_score", 0), + health_level=result_json.get("health_level", "danger"), + health_level_label=result_json.get("health_level_label", get_health_level_label("danger")), + dimensions=dimensions, + recommendations=recommendations, + is_full_report=is_paid, + ) + + +async def _find_recent_completed( + db: AsyncSession, brand_id: uuid.UUID, hours: int = 24 +) -> DiagnosisRecord | None: + from datetime import timedelta + + cutoff = datetime.now(UTC) - timedelta(hours=hours) + stmt = ( + select(DiagnosisRecord) + .where( + DiagnosisRecord.brand_id == brand_id, + DiagnosisRecord.status == "completed", + DiagnosisRecord.completed_at >= cutoff, + ) + .order_by(DiagnosisRecord.completed_at.desc()) + .limit(1) + ) + result = await db.execute(stmt) + return result.scalar_one_or_none() + + async def _get_brand_or_404( brand_id: uuid.UUID, current_user: User, db: AsyncSession, ) -> Brand: - """获取品牌或抛出404异常""" - stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id) + user_uuid = _to_uuid(current_user.id) + stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == user_uuid) result = await db.execute(stmt) brand = result.scalar_one_or_none() - + if not brand: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="品牌不存在", ) - + return brand diff --git a/backend/app/api/distribution.py b/backend/app/api/distribution.py index f01d92c..4361e68 100644 --- a/backend/app/api/distribution.py +++ b/backend/app/api/distribution.py @@ -4,6 +4,7 @@ import uuid from datetime import datetime from fastapi import APIRouter, 
Depends, HTTPException, status +from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_user @@ -25,7 +26,9 @@ from app.schemas.distribution import ( ) from app.services.distribution.formatter import ContentFormatter from app.services.distribution.platform_rules import PLATFORM_RULES, PlatformRuleEngine +from app.services.distribution.publish_engine import PublishEngine from app.services.distribution.publish_strategy import PublishStrategyService +from app.services.distribution.publishers.base import PublishResult router = APIRouter() @@ -235,3 +238,72 @@ async def create_schedule( tips=tips, created_at=schedule.created_at.strftime("%Y-%m-%d %H:%M:%S"), ) + + +class PublishRequest(BaseModel): + content_id: str = Field(min_length=1) + platforms: list[str] = Field(min_length=1) + + +class PublishStatusItem(BaseModel): + platform: str + status: str + article_url: str | None = None + published_at: str | None = None + + +class PublishStatusResponse(BaseModel): + platforms: list[PublishStatusItem] + + +class PublishResponse(BaseModel): + results: list[PublishResult] + + +_publish_engine = PublishEngine() + + +@router.post("/publish", response_model=PublishResponse) +async def publish_content( + req: PublishRequest, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + org_id = current_user.organization_id + if not org_id: + raise HTTPException(status_code=403, detail="用户未关联组织") + + try: + results = await _publish_engine.publish_content( + content_id=req.content_id, + platforms=req.platforms, + db=db, + user_id=str(current_user.id), + org_id=str(org_id), + ) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) + + return PublishResponse(results=results) + + +@router.get("/publish/{content_id}/status", response_model=PublishStatusResponse) +async def get_publish_status( + content_id: str, + db: AsyncSession = 
Depends(get_db), + current_user: User = Depends(get_current_user), +): + try: + status_list = await _publish_engine.get_publish_status( + content_id=content_id, + db=db, + ) + except ValueError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Content not found: {content_id}", + ) + + return PublishStatusResponse( + platforms=[PublishStatusItem(**s) for s in status_list] + ) diff --git a/backend/app/api/health_score.py b/backend/app/api/health_score.py new file mode 100644 index 0000000..0b9d6c1 --- /dev/null +++ b/backend/app/api/health_score.py @@ -0,0 +1,144 @@ +import hashlib +import logging + +from fastapi import APIRouter, Depends, Query +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.database import get_db +from app.models.brand import Brand +from app.schemas.health_score import ( + HealthScoreDimension, + HealthScoreRecommendation, + HealthScoreResponse, +) +from app.services.cache import get_cache_service +from app.services.diagnosis.data_collector import DataCollectorService +from app.services.diagnosis.geo_diagnosis import GEODiagnosisService +from app.utils.health import get_health_level, get_health_level_label + +logger = logging.getLogger(__name__) + +router = APIRouter() + +_FREE_TIER_DIMENSIONS = {"内容可提取性", "E-E-A-T信号", "引用就绪度"} + +_CACHE_TTL = 86400 + + +def _build_default_response(brand_name: str) -> HealthScoreResponse: + return HealthScoreResponse( + brand_name=brand_name, + overall_score=0.0, + health_level="danger", + health_level_label=get_health_level_label("danger"), + dimensions=[ + HealthScoreDimension( + name=d, + score=0.0, + max_score=0.0, + percentage=0.0, + status="fail", + ) + for d in sorted(_FREE_TIER_DIMENSIONS) + ], + recommendations=[], + is_full_report=False, + cached=False, + ) + + +@router.get("/health-score", response_model=HealthScoreResponse) +async def get_public_health_score( + brand: str = Query(..., min_length=1), + competitors: str = 
Query(default=""), + db: AsyncSession = Depends(get_db), +): + cache = get_cache_service() + cache_key = f"health_score:{hashlib.md5(brand.lower().encode()).hexdigest()}" + + cached_data = await cache.get_json(cache_key) + if cached_data is not None: + cached_data["cached"] = True + return cached_data + + brand_name = brand.strip() + competitor_list = [c.strip() for c in competitors.split(",") if c.strip()][:3] + + brand_aliases: list[str] = [] + website: str | None = None + industry: str | None = None + + stmt = select(Brand).where(Brand.name == brand_name).limit(1) + result = await db.execute(stmt) + brand_record = result.scalar_one_or_none() + + if brand_record: + brand_aliases = brand_record.aliases or [] + website = brand_record.website + industry = brand_record.industry + + try: + collector = DataCollectorService(db) + collection = await collector.collect( + brand_name=brand_name, + brand_aliases=brand_aliases, + website=website, + industry=industry, + ) + + geo_service = GEODiagnosisService() + diagnosis_result = geo_service.diagnose(collection.diagnosis_input) + except Exception as e: + logger.error(f"健康分数据采集失败: brand={brand_name}, error={e}", exc_info=True) + default_resp = _build_default_response(brand_name) + await cache.set_json(cache_key, default_resp.model_dump(mode="json"), expire=_CACHE_TTL) + return default_resp + + all_dimensions = diagnosis_result.to_dict().get("dimensions", []) + free_dimensions = [d for d in all_dimensions if d.get("name") in _FREE_TIER_DIMENSIONS] + + overall_score = round(sum(d.get("score", 0.0) for d in free_dimensions), 2) + health_level = get_health_level(overall_score) + + dimension_responses = [ + HealthScoreDimension( + name=d.get("name", ""), + score=round(d.get("score", 0.0), 2), + max_score=d.get("max_score", 0.0), + percentage=round(d.get("percentage", 0.0), 2), + status=d.get("status", "fail"), + ) + for d in free_dimensions + ] + + all_recommendations = diagnosis_result.to_dict().get("recommendations", []) + 
free_recommendations = [ + r for r in all_recommendations + if r.get("priority") == "P0" and r.get("dimension") in _FREE_TIER_DIMENSIONS + ] + + recommendation_responses = [ + HealthScoreRecommendation( + priority=r.get("priority", "P0"), + dimension=r.get("dimension", ""), + title=r.get("title", ""), + description=r.get("description", ""), + ) + for r in free_recommendations + ] + + response = HealthScoreResponse( + brand_name=brand_name, + overall_score=overall_score, + health_level=health_level, + health_level_label=get_health_level_label(health_level), + dimensions=dimension_responses, + recommendations=recommendation_responses, + is_full_report=False, + cached=False, + ) + + await cache.set_json(cache_key, response.model_dump(mode="json"), expire=_CACHE_TTL) + + return response diff --git a/backend/app/api/monitoring.py b/backend/app/api/monitoring.py new file mode 100644 index 0000000..69f411c --- /dev/null +++ b/backend/app/api/monitoring.py @@ -0,0 +1,207 @@ +import uuid + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_current_user +from app.database import get_db +from app.models.user import User +from app.models.brand import Brand +from app.models.monitoring import MonitoringRecord +from app.schemas.monitoring import ( + MonitoringRecordCreate, + MonitoringRecordResponse, + MonitoringRecordList, + MonitoringChangeReport, + MonitoringStatusUpdate, +) +from app.services.monitoring.monitor_service import MonitorService + +router = APIRouter() + + +async def _get_brand_with_access( + brand_id: uuid.UUID, + db: AsyncSession, + current_user: User, +) -> Brand: + stmt = select(Brand).where( + Brand.id == brand_id, + Brand.user_id == current_user.id, + ) + result = await db.execute(stmt) + brand = result.scalar_one_or_none() + + if not brand: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="品牌不存在", + ) + 
return brand + + +def _record_to_response(record: MonitoringRecord) -> MonitoringRecordResponse: + return MonitoringRecordResponse( + id=record.id, + brand_id=record.brand_id, + content_id=record.content_id, + query_keywords=record.query_keywords, + platform=record.platform, + baseline_citation_count=record.baseline_citation_count, + baseline_sentiment=record.baseline_sentiment, + baseline_rank=record.baseline_rank, + current_citation_count=record.current_citation_count, + current_sentiment=record.current_sentiment, + current_rank=record.current_rank, + change_type=record.change_type, + change_details=record.change_details, + check_interval_hours=record.check_interval_hours, + last_checked_at=record.last_checked_at, + next_check_at=record.next_check_at, + status=record.status, + created_at=record.created_at, + updated_at=record.updated_at, + ) + + +@router.post("/tasks", response_model=MonitoringRecordResponse) +async def create_monitoring_task( + request: MonitoringRecordCreate, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + await _get_brand_with_access(request.brand_id, db, current_user) + + service = MonitorService() + record = await service.create_monitoring_record( + db=db, + brand_id=request.brand_id, + content_id=request.content_id, + query_keywords=request.query_keywords, + platform=request.platform, + check_interval_hours=request.check_interval_hours, + ) + return _record_to_response(record) + + +@router.get("/brand/{brand_id}", response_model=MonitoringRecordList) +async def get_brand_monitoring( + brand_id: uuid.UUID, + skip: int = Query(0, ge=0), + limit: int = Query(20, ge=1, le=100), + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + await _get_brand_with_access(brand_id, db, current_user) + + service = MonitorService() + records, total = await service.get_brand_monitoring( + db=db, + brand_id=brand_id, + skip=skip, + limit=limit, + ) + return 
MonitoringRecordList( + records=[_record_to_response(r) for r in records], + total=total, + ) + + +@router.get("/{record_id}/report", response_model=MonitoringChangeReport) +async def get_change_report( + record_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + stmt = select(MonitoringRecord).where(MonitoringRecord.id == record_id) + result = await db.execute(stmt) + record = result.scalar_one_or_none() + + if not record: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="监测记录不存在", + ) + + await _get_brand_with_access(record.brand_id, db, current_user) + + service = MonitorService() + report = await service.generate_change_report(db, record_id) + if not report: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="变化报告不存在", + ) + + return MonitoringChangeReport( + monitoring_record_id=uuid.UUID(report["monitoring_record_id"]), + brand_id=uuid.UUID(report["brand_id"]), + change_type=report["change_type"], + change_details=report["change_details"], + baseline=report["baseline"], + current=report["current"], + recommendations=report["recommendations"], + ) + + +@router.put("/{record_id}/status", response_model=MonitoringRecordResponse) +async def update_monitoring_status( + record_id: uuid.UUID, + status_update: MonitoringStatusUpdate, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + valid_statuses = {"active", "paused", "completed"} + if status_update.status not in valid_statuses: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"无效的状态值,支持: {', '.join(valid_statuses)}", + ) + + stmt = select(MonitoringRecord).where(MonitoringRecord.id == record_id) + result = await db.execute(stmt) + record = result.scalar_one_or_none() + + if not record: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="监测记录不存在", + ) + + await _get_brand_with_access(record.brand_id, db, current_user) + + 
record.status = status_update.status + await db.commit() + await db.refresh(record) + + return _record_to_response(record) + + +@router.post("/{record_id}/check", response_model=MonitoringRecordResponse) +async def trigger_manual_check( + record_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + stmt = select(MonitoringRecord).where(MonitoringRecord.id == record_id) + result = await db.execute(stmt) + record = result.scalar_one_or_none() + + if not record: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="监测记录不存在", + ) + + await _get_brand_with_access(record.brand_id, db, current_user) + + service = MonitorService() + updated_record = await service.check_and_compare(db, record_id) + if not updated_record: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="检测执行失败", + ) + + return _record_to_response(updated_record) diff --git a/backend/app/api/onboarding.py b/backend/app/api/onboarding.py index 471e9cf..181e87a 100644 --- a/backend/app/api/onboarding.py +++ b/backend/app/api/onboarding.py @@ -20,15 +20,24 @@ from app.database import get_db from app.models.user import User from app.models.brand import Brand from app.models.competitor import Competitor -from app.models.citation_record import CitationRecord from app.models.query import Query as QueryModel -from app.services.scoring_service import ScoringService from app.schemas.brand import BrandCreate, BrandResponse +from app.services.diagnosis.data_collector import DataCollectorService +from app.services.diagnosis.geo_diagnosis import GEODiagnosisService +from app.utils.health import get_health_level, get_health_level_label logger = logging.getLogger(__name__) router = APIRouter(prefix="/onboarding", tags=["onboarding"]) +_FREE_TIER_DIMENSIONS = {"内容可提取性", "E-E-A-T信号", "引用就绪度"} + + +def _to_uuid(value: str | uuid.UUID) -> uuid.UUID: + if isinstance(value, uuid.UUID): + return value + return 
uuid.UUID(str(value)) + # ------------------------------------------------------------------ # Request / Response schemas @@ -60,23 +69,43 @@ class CompetitorRecommendationSimpleResponse(BaseModel): recommendations: list[CompetitorRecommendationSimple] +class HealthDimensionItem(BaseModel): + name: str + score: float + max_score: float + percentage: float + status: str + + +class HealthRecommendationItem(BaseModel): + priority: str + dimension: str + title: str + description: str + + class HealthReportResponse(BaseModel): - """初始健康评分报告""" brand_id: str brand_name: str overall_score: float + health_level: str + health_level_label: str platform_scores: dict strengths: list[str] weaknesses: list[str] competitor_scores: list[dict] + dimensions: list[HealthDimensionItem] = [] + recommendations: list[HealthRecommendationItem] = [] + is_full_report: bool = False class ActionSuggestion(BaseModel): - """行动建议项""" title: str description: str - priority: str # high / medium / low - action_type: str # e.g. 
coverage, keyword, sentiment, platform + priority: str + action_type: str + is_paid_action: bool = False + action_button_text: str = "" class ActionSuggestionsResponse(BaseModel): @@ -105,7 +134,7 @@ async def get_onboarding_status( - completed=True 且 brand_id 有值 → 已完成 - completed=False, current_step=1 → 需要创建品牌 """ - stmt = select(Brand).where(Brand.user_id == current_user.id) + stmt = select(Brand).where(Brand.user_id == _to_uuid(current_user.id)) result = await db.execute(stmt) brand = result.scalar_one_or_none() @@ -144,7 +173,7 @@ async def create_onboarding_brand( ) brand = Brand( - user_id=current_user.id, + user_id=_to_uuid(current_user.id), name=full_brand_data.name, aliases=full_brand_data.aliases, website=full_brand_data.website, @@ -156,6 +185,38 @@ async def create_onboarding_brand( await db.commit() await db.refresh(brand) + # 自动创建默认查询词(检查 max_queries 限制) + try: + current_query_count_stmt = select(func.count()).select_from(QueryModel).where( + QueryModel.user_id == current_user.id + ) + current_query_count_result = await db.execute(current_query_count_stmt) + current_query_count = current_query_count_result.scalar_one() + + max_queries = getattr(current_user, "max_queries", 3) # 默认免费版 3 个 + + if current_query_count < max_queries: + default_query = QueryModel( + user_id=current_user.id, + keyword=f"{brand.name} 推荐", + target_brand=brand.name, + brand_aliases=brand.aliases or [], + platforms=brand.platforms or ["wenxin", "kimi"], + frequency=brand.frequency or "weekly", + status="active", + ) + db.add(default_query) + await db.commit() + await db.refresh(default_query) + logger.info(f"Auto-created default query for brand '{brand.name}'") + else: + logger.info( + f"Skipped auto-creating default query for brand '{brand.name}': " + f"query limit reached ({current_query_count}/{max_queries})" + ) + except Exception as e: + logger.warning(f"Failed to auto-create default query for brand '{brand.name}': {e}") + return brand @@ -171,8 +232,7 @@ async def 
get_onboarding_competitor_recommendations( 复用 brands/competitors 中的推荐逻辑, 支持 LLM 智能推荐和规则推荐两种模式。 """ - # 验证品牌归属 - stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id) + stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id)) result = await db.execute(stmt) brand = result.scalar_one_or_none() @@ -182,7 +242,6 @@ async def get_onboarding_competitor_recommendations( detail="品牌不存在", ) - # 获取已有竞品名称(排除) existing_stmt = select(Competitor.name).where(Competitor.brand_id == brand_id) existing_result = await db.execute(existing_stmt) existing_names = [row[0] for row in existing_result.all()] @@ -228,14 +287,8 @@ async def get_onboarding_health_report( current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): - """ - 获取品牌初始健康评分报告。 - - 基于 citation_records 表统计品牌的引用数据, - 如果没有引用数据则返回初始化状态(overall_score: 0)。 - """ - # 验证品牌归属 - stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id) + user_uuid = _to_uuid(current_user.id) + stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == user_uuid) result = await db.execute(stmt) brand = result.scalar_one_or_none() @@ -245,7 +298,55 @@ async def get_onboarding_health_report( detail="品牌不存在", ) - # 查询与品牌关联的 queries + user_plan = getattr(current_user, "plan", None) or "free" + is_paid = user_plan not in ("free", None) + + try: + collector = DataCollectorService(db) + collection = await collector.collect( + brand_name=brand.name, + brand_aliases=brand.aliases or [], + website=brand.website, + industry=brand.industry, + ) + + geo_service = GEODiagnosisService() + diagnosis_result = geo_service.diagnose(collection.diagnosis_input) + except Exception as e: + logger.error(f"Onboarding健康报告采集失败: brand_id={brand_id}, error={e}", exc_info=True) + return HealthReportResponse( + brand_id=str(brand.id), + brand_name=brand.name, + overall_score=0.0, + health_level="danger", + health_level_label=get_health_level_label("danger"), 
+ platform_scores={}, + strengths=["品牌已创建,等待数据采集"], + weaknesses=["数据采集失败,请稍后重试"], + competitor_scores=[], + dimensions=[ + HealthDimensionItem(name=d, score=0.0, max_score=0.0, percentage=0.0, status="fail") + for d in sorted(_FREE_TIER_DIMENSIONS) + ], + recommendations=[], + is_full_report=is_paid, + ) + + result_dict = diagnosis_result.to_dict() + all_dimensions = result_dict.get("dimensions", []) + all_recommendations = result_dict.get("recommendations", []) + + if is_paid: + filtered_dimensions = all_dimensions + filtered_recommendations = all_recommendations + else: + filtered_dimensions = [d for d in all_dimensions if d.get("name") in _FREE_TIER_DIMENSIONS] + filtered_recommendations = [r for r in all_recommendations if r.get("priority") == "P0"] + + overall_score = round(diagnosis_result.overall_score, 2) + health_level = get_health_level(overall_score) + + platform_scores: dict[str, float] = {} queries_stmt = select(QueryModel).where( QueryModel.user_id == current_user.id, QueryModel.target_brand == brand.name, @@ -253,138 +354,89 @@ async def get_onboarding_health_report( queries_result = await db.execute(queries_stmt) queries = list(queries_result.scalars().all()) - # 没有查询数据 → 返回初始化状态 - if not queries: - return HealthReportResponse( - brand_id=str(brand.id), - brand_name=brand.name, - overall_score=0.0, - platform_scores={}, - strengths=["品牌已创建,等待数据采集"], - weaknesses=["尚无AI平台引用数据,需等待查询执行"], - competitor_scores=[], + if queries: + from app.models.citation_record import CitationRecord + query_ids = [q.id for q in queries] + citations_stmt = select(CitationRecord).where( + CitationRecord.query_id.in_(query_ids), ) + citations_result = await db.execute(citations_stmt) + citations = list(citations_result.scalars().all()) - query_ids = [q.id for q in queries] + platforms_seen: dict[str, dict] = {} + for c in citations: + p = c.platform or "unknown" + if p not in platforms_seen: + platforms_seen[p] = {"total": 0, "cited": 0} + platforms_seen[p]["total"] += 1 + 
if c.cited: + platforms_seen[p]["cited"] += 1 - # 获取引用记录 - citations_stmt = select(CitationRecord).where( - CitationRecord.query_id.in_(query_ids), - ) - citations_result = await db.execute(citations_stmt) - citations = list(citations_result.scalars().all()) + for p, data in platforms_seen.items(): + rate = (data["cited"] / data["total"] * 100) if data["total"] > 0 else 0.0 + platform_scores[p] = round(rate, 2) - total = len(citations) - cited = [c for c in citations if c.cited] - - # 计算各平台评分 - platform_scores: dict[str, float] = {} - platforms_seen: dict[str, dict] = {} # {platform: {total, cited}} - - for c in citations: - p = c.platform or "unknown" - if p not in platforms_seen: - platforms_seen[p] = {"total": 0, "cited": 0} - platforms_seen[p]["total"] += 1 - if c.cited: - platforms_seen[p]["cited"] += 1 - - for p, data in platforms_seen.items(): - rate = (data["cited"] / data["total"] * 100) if data["total"] > 0 else 0.0 - platform_scores[p] = round(rate, 2) - - # 使用 ScoringService 计算 overall_score - scoring_service = ScoringService() - sentiment_counts = {"positive": 0, "neutral": 0, "negative": 0} - for c in cited: - sentiment = c.sentiment or "neutral" - if sentiment in sentiment_counts: - sentiment_counts[sentiment] += 1 - - from app.schemas.scoring import CitationResult - citation_results = [ - CitationResult( - cited=c.cited, - position=c.citation_position, - citation_text=c.citation_text, - sentiment=c.sentiment or "neutral", - confidence=c.confidence or 0.0, - ) - for c in cited - ] - positions = [c.citation_position for c in cited if c.cited] - - # 获取竞品信息 - competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id) - competitor_result = await db.execute(competitor_stmt) - competitors = list(competitor_result.scalars().all()) - competitor_names = [c.name for c in competitors] - competitor_mentions: dict[str, int] = {} - for comp_name in competitor_names: - count = sum( - 1 for c in citations - if c.cited and c.competitor_brands and 
comp_name in c.competitor_brands - ) - if count > 0: - competitor_mentions[comp_name] = count - - v2_result = scoring_service.calculate_v2( - mentioned_count=len(cited), - total_queries=total, - positions=positions, - sentiment_counts=sentiment_counts, - citations=citation_results, - brand_mentions=len(cited), - competitor_mentions=competitor_mentions, - ) - - # 生成 strengths/weaknesses strengths = [] weaknesses = [] - if total == 0: - strengths.append("品牌已创建") - weaknesses.append("尚无引用数据") - else: - mention_rate = len(cited) / total * 100 if total > 0 else 0 - if mention_rate >= 50: - strengths.append(f"提及率较高 ({round(mention_rate, 1)}%)") - else: - weaknesses.append(f"提及率偏低 ({round(mention_rate, 1)}%)") + for d in filtered_dimensions: + pct = d.get("percentage", 0) + if pct >= 60: + strengths.append(f"{d.get('name', '')}表现良好 ({round(pct, 1)}%)") + elif pct > 0: + weaknesses.append(f"{d.get('name', '')}有待提升 ({round(pct, 1)}%)") - for p, score in platform_scores.items(): - if score >= 60: - strengths.append(f"{p} 平台表现良好 ({score}%)") - elif score > 0: - weaknesses.append(f"{p} 平台覆盖率不足 ({score}%)") + if not strengths: + strengths.append("已有初步诊断数据") + if not weaknesses: + weaknesses.append("暂无明显短板") - if sentiment_counts["positive"] > sentiment_counts["negative"]: - strengths.append("情感倾向正面") - elif sentiment_counts["negative"] > sentiment_counts["positive"]: - weaknesses.append("情感倾向偏负面") + competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id) + competitor_result = await db.execute(competitor_stmt) + competitors = list(competitor_result.scalars().all()) - if not strengths: - strengths.append("已有初步引用数据") - if not weaknesses: - weaknesses.append("暂无明显短板") - - # 竞品评分 competitor_scores = [] - for comp_name, mentions in competitor_mentions.items(): - comp_score = round(mentions / total * 100, 2) if total > 0 else 0.0 + for comp in competitors: competitor_scores.append({ - "name": comp_name, - "score": comp_score, + "name": comp.name, + "score": 0.0, + 
"is_leading": False, }) + dimension_items = [ + HealthDimensionItem( + name=d.get("name", ""), + score=round(d.get("score", 0.0), 2), + max_score=d.get("max_score", 0.0), + percentage=round(d.get("percentage", 0.0), 2), + status=d.get("status", "fail"), + ) + for d in filtered_dimensions + ] + + recommendation_items = [ + HealthRecommendationItem( + priority=r.get("priority", "P0"), + dimension=r.get("dimension", ""), + title=r.get("title", ""), + description=r.get("description", ""), + ) + for r in filtered_recommendations + ] + return HealthReportResponse( brand_id=str(brand.id), brand_name=brand.name, - overall_score=round(v2_result.overall_score, 2), + overall_score=overall_score, + health_level=health_level, + health_level_label=get_health_level_label(health_level), platform_scores=platform_scores, strengths=strengths, weaknesses=weaknesses, competitor_scores=competitor_scores, + dimensions=dimension_items, + recommendations=recommendation_items, + is_full_report=is_paid, ) @@ -394,21 +446,21 @@ async def get_onboarding_action_suggestions( current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): - """ - 根据健康报告生成行动建议(基于规则引擎,不需要 LLM)。 - """ - # 先获取健康报告数据(复用逻辑) report = await get_onboarding_health_report(brand_id, current_user, db) + user_plan = getattr(current_user, "plan", None) or "free" + is_paid = user_plan not in ("free", None) + suggestions = [] - # 规则引擎:基于评分和平台数据生成建议 if report.overall_score < 20: suggestions.append(ActionSuggestion( title="提升 AI 平台覆盖率", description=f"当前综合评分仅 {report.overall_score},品牌在AI搜索中几乎未被提及。建议增加查询词覆盖面,让AI平台更频繁地引用品牌。", priority="high", action_type="coverage", + is_paid_action=False, + action_button_text="设置查询词", )) if report.overall_score < 50: @@ -417,9 +469,10 @@ async def get_onboarding_action_suggestions( description="品牌在关键查询词下的提及率偏低,建议调整查询关键词策略,聚焦行业核心术语。", priority="high", action_type="keyword", + is_paid_action=False, + action_button_text="优化关键词", )) - # 平台维度建议 for platform, score in 
report.platform_scores.items(): if score < 30: suggestions.append(ActionSuggestion( @@ -427,28 +480,32 @@ async def get_onboarding_action_suggestions( description=f"品牌在 {platform} 平台的引用率仅为 {score}%,需要针对性优化该平台的内容策略。", priority="medium", action_type="platform", + is_paid_action=False, + action_button_text=f"优化{platform}", )) - # 情感维度建议 - if "情感倾向偏负面" in report.weaknesses: - suggestions.append(ActionSuggestion( - title="改善品牌情感倾向", - description="AI平台对品牌的情感评价偏负面,建议发布正面品牌内容、优化品牌描述以改善情感得分。", - priority="medium", - action_type="sentiment", - )) + for dim in report.dimensions: + if dim.percentage < 40 and dim.name not in _FREE_TIER_DIMENSIONS: + suggestions.append(ActionSuggestion( + title=f"提升{dim.name}得分", + description=f"{dim.name}当前得分仅 {round(dim.percentage, 1)}%,解锁详细诊断和优化方案。", + priority="medium", + action_type="dimension", + is_paid_action=True, + action_button_text="升级解锁", + )) - # 竞品对比建议 for comp in report.competitor_scores: - if comp["score"] > report.overall_score: + if comp.get("score", 0) > report.overall_score: suggestions.append(ActionSuggestion( title=f"应对竞品 {comp['name']} 威胁", description=f"竞品 {comp['name']} 评分 ({comp['score']}) 高于本品牌 ({report.overall_score}),建议分析竞品优势领域并制定差异化策略。", priority="high", action_type="competitive", + is_paid_action=True, + action_button_text="查看竞品分析", )) - # 如果没有引用数据,给出基础建议 if report.overall_score == 0: suggestions = [ ActionSuggestion( @@ -456,28 +513,45 @@ async def get_onboarding_action_suggestions( description="品牌尚无查询数据,建议首先设置与品牌最相关的核心查询词,让系统开始数据采集。", priority="high", action_type="keyword", + is_paid_action=False, + action_button_text="设置查询词", ), ActionSuggestion( title="添加竞品对比", description="添加主要竞品以便进行对比分析,了解品牌在市场中的定位。", priority="medium", action_type="coverage", + is_paid_action=False, + action_button_text="添加竞品", ), ActionSuggestion( - title="完善品牌信息", - description="补充品牌别名、网站、行业等详细信息,有助于提升AI平台识别率。", + title="解锁完整6维度诊断", + description="免费版仅展示3个核心维度,升级后可查看完整6维度诊断报告和深度优化方案。", priority="medium", - action_type="brand_info", + 
action_type="upgrade", + is_paid_action=True, + action_button_text="升级Pro", ), ] - # 确保至少有1条建议 if not suggestions: suggestions.append(ActionSuggestion( title="持续监测品牌表现", description="品牌表现良好,建议持续监测并保持当前策略。", priority="low", action_type="monitor", + is_paid_action=False, + action_button_text="查看Dashboard", + )) + + if not is_paid: + suggestions.append(ActionSuggestion( + title="升级Pro解锁完整诊断", + description="免费版仅展示3个核心维度和P0建议。升级Pro可获取完整6维度诊断、深度竞品分析和AI优化方案。", + priority="low", + action_type="upgrade", + is_paid_action=True, + action_button_text="升级Pro", )) return ActionSuggestionsResponse(suggestions=suggestions) @@ -496,8 +570,7 @@ async def complete_onboarding( User 模型当前没有 onboarding_completed 专用字段, 品牌的创建即代表 onboarding 完成。 """ - # 验证品牌归属 - stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id) + stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id)) result = await db.execute(stmt) brand = result.scalar_one_or_none() @@ -507,8 +580,6 @@ async def complete_onboarding( detail="品牌不存在", ) - # 品牌已创建即代表 onboarding 完成,无需额外字段更新 - # 后续如需专用字段,可通过 alembic 迁移添加 user.onboarding_completed logger.info(f"User {current_user.id} completed onboarding with brand {brand_id}") return OnboardingCompleteResponse(success=True) \ No newline at end of file diff --git a/backend/app/api/payments.py b/backend/app/api/payments.py new file mode 100644 index 0000000..94005ac --- /dev/null +++ b/backend/app/api/payments.py @@ -0,0 +1,274 @@ +import logging +import uuid +from datetime import datetime, timezone + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_current_user +from app.database import get_db +from app.models.payment_order import PaymentOrder as PaymentOrderModel +from app.models.user import User +from app.services.payment import get_payment_gateway +from 
app.services.subscription import PLANS, subscribe + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/payments", tags=["支付"]) + + +class CreateOrderRequest(BaseModel): + plan: str + payment_provider: str = "wechat" + + +class CreateOrderResponse(BaseModel): + order_id: str + pay_url: str + amount: float + currency: str = "CNY" + status: str = "pending" + + +class OrderStatusResponse(BaseModel): + order_id: str + status: str + plan: str + amount: float + payment_provider: str + payment_id: str | None = None + created_at: str | None = None + paid_at: str | None = None + + +class RefundRequest(BaseModel): + reason: str = "" + + +@router.post("/orders", response_model=CreateOrderResponse, status_code=status.HTTP_201_CREATED) +async def create_payment_order( + request: CreateOrderRequest, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + plan_data = PLANS.get(request.plan) + if plan_data is None: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"无效的套餐: {request.plan}", + ) + + if plan_data["price"] == 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="免费套餐无需支付", + ) + + order_id = uuid.uuid4() + amount = plan_data["price"] + + gateway = get_payment_gateway(request.payment_provider) + payment_order = await gateway.create_order( + order_id=str(order_id), + amount=amount, + description=f"GEO平台{plan_data['name']}订阅", + user_id=current_user.id, + plan=request.plan, + ) + + db_order = PaymentOrderModel( + id=order_id, + user_id=current_user.id, + plan=request.plan, + amount=amount, + payment_provider=request.payment_provider, + status="pending", + pay_url=payment_order.pay_url, + ) + db.add(db_order) + await db.commit() + + return CreateOrderResponse( + order_id=str(order_id), + pay_url=payment_order.pay_url, + amount=amount, + status="pending", + ) + + +@router.post("/callback/wechat") +async def wechat_pay_callback(request: Request): + body = 
await request.form() + request_data = dict(body) + + gateway = get_payment_gateway("wechat") + callback = await gateway.verify_callback(request_data) + + return await _handle_payment_callback(request_data, callback, "wechat") + + +@router.post("/callback/alipay") +async def alipay_callback(request: Request): + body = await request.form() + request_data = dict(body) + + gateway = get_payment_gateway("alipay") + callback = await gateway.verify_callback(request_data) + + return await _handle_payment_callback(request_data, callback, "alipay") + + +async def _handle_payment_callback(request_data: dict, callback, provider: str): + from app.database import AsyncSessionLocal + + async with AsyncSessionLocal() as db: + try: + result = await _process_callback(db, callback, provider) + await db.commit() + return result + except Exception as e: + logger.error(f"[PaymentCallback] 处理回调异常: {e}", exc_info=True) + await db.rollback() + if provider == "wechat": + return _wechat_fail_response() + return "fail" + + +async def _process_callback(db: AsyncSession, callback, provider: str): + stmt = select(PaymentOrderModel).where( + PaymentOrderModel.id == uuid.UUID(callback.order_id) + ) + result = await db.execute(stmt) + order = result.scalar_one_or_none() + + if order is None: + logger.warning(f"[PaymentCallback] 订单不存在: order_id={callback.order_id}") + if provider == "wechat": + return _wechat_fail_response() + return "fail" + + if callback.status == "success": + order.status = "paid" + order.payment_id = callback.payment_id + order.callback_data = callback.raw_data + order.paid_at = datetime.now(timezone.utc) + + await subscribe(db, order.user_id, order.plan) + + logger.info( + f"[PaymentCallback] 支付成功: order_id={callback.order_id}, " + f"plan={order.plan}, provider={provider}" + ) + else: + order.status = "failed" + order.callback_data = callback.raw_data + + if provider == "wechat": + return _wechat_success_response() + return "success" + + +def _wechat_success_response(): + from 
fastapi.responses import Response + return Response(content="", media_type="application/xml") + + +def _wechat_fail_response(): + from fastapi.responses import Response + return Response(content="", media_type="application/xml") + + +@router.get("/orders/{order_id}", response_model=OrderStatusResponse) +async def query_order_status( + order_id: str, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + try: + oid = uuid.UUID(order_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="无效的订单ID", + ) + + stmt = select(PaymentOrderModel).where(PaymentOrderModel.id == oid) + result = await db.execute(stmt) + order = result.scalar_one_or_none() + + if order is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="订单不存在", + ) + + if order.user_id != current_user.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="无权查看此订单", + ) + + return OrderStatusResponse( + order_id=str(order.id), + status=order.status, + plan=order.plan, + amount=order.amount, + payment_provider=order.payment_provider, + payment_id=order.payment_id, + created_at=order.created_at.isoformat() if order.created_at else None, + paid_at=order.paid_at.isoformat() if order.paid_at else None, + ) + + +@router.post("/refund/{order_id}") +async def refund_order( + order_id: str, + body: RefundRequest = RefundRequest(), + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + user_plan = getattr(current_user, "plan", "free") or "free" + if user_plan != "enterprise": + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="仅企业管理员可执行退款操作", + ) + + try: + oid = uuid.UUID(order_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="无效的订单ID", + ) + + stmt = select(PaymentOrderModel).where(PaymentOrderModel.id == oid) + result = await db.execute(stmt) + order = 
result.scalar_one_or_none() + + if order is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="订单不存在", + ) + + if order.status != "paid": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="仅已支付订单可退款", + ) + + gateway = get_payment_gateway(order.payment_provider) + success = await gateway.refund(order_id, order.amount, body.reason) + + if success: + order.status = "refunded" + await db.commit() + return {"message": "退款成功", "order_id": order_id} + + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="退款失败", + ) diff --git a/backend/app/api/platforms.py b/backend/app/api/platforms.py index b69a8bc..79910ba 100644 --- a/backend/app/api/platforms.py +++ b/backend/app/api/platforms.py @@ -1,11 +1,5 @@ -""" -平台健康检查API - 验证各AI平台适配器状态 - -端点: GET /api/platforms/health -返回: 各平台适配器配置状态和健康信息 -""" - import logging +import os from typing import Annotated from fastapi import APIRouter, Depends @@ -18,7 +12,6 @@ router = APIRouter(prefix="/platforms", tags=["platforms"]) class PlatformHealthStatus: - """平台健康状态""" def __init__( self, @@ -37,22 +30,38 @@ class PlatformHealthStatus: self.message = message +_PLATFORM_URLS = { + "kimi": "https://kimi.moonshot.cn", + "wenxin": "https://yiyan.baidu.com", + "doubao": "https://www.doubao.com/", +} + + +def _check_api_key_health( + platform_name: str, + env_key_name: str, + url: str, +) -> PlatformHealthStatus: + api_key = os.getenv(env_key_name, "") + api_key_set = bool(api_key and api_key.strip()) + configured = api_key_set + + return PlatformHealthStatus( + name=platform_name, + url=url, + configured=configured, + api_key_set=api_key_set, + status="configured" if configured else "not_configured", + message="API Key已配置" if configured else "API Key未配置", + ) + + def check_kimi_health() -> PlatformHealthStatus: - """检查Kimi平台健康状态""" try: - from app.workers.platforms.kimi import KimiAdapter - - adapter = KimiAdapter() - api_key_set = bool(adapter.api_key and 
adapter.api_key.strip()) - configured = adapter.is_configured - - return PlatformHealthStatus( - name="kimi", - url=adapter.platform_url, - configured=configured, - api_key_set=api_key_set, - status="configured" if configured else "not_configured", - message="API Key已配置" if configured else "API Key未配置", + return _check_api_key_health( + platform_name="kimi", + env_key_name="MOONSHOT_API_KEY", + url=_PLATFORM_URLS["kimi"], ) except Exception as e: logger.error(f"Kimi健康检查失败: {e}") @@ -65,18 +74,16 @@ def check_kimi_health() -> PlatformHealthStatus: def check_wenxin_health() -> PlatformHealthStatus: - """检查文心平台健康状态""" try: - from app.workers.platforms.wenxin import WenxinAdapter - - adapter = WenxinAdapter() - api_key_set = bool(adapter.api_key and adapter.api_key.strip()) - secret_key_set = bool(adapter.secret_key and adapter.secret_key.strip()) - configured = adapter.is_configured + api_key = os.getenv("BAIDU_QIANFAN_API_KEY", "") + secret_key = os.getenv("BAIDU_QIANFAN_SECRET_KEY", "") + api_key_set = bool(api_key and api_key.strip()) + secret_key_set = bool(secret_key and secret_key.strip()) + configured = api_key_set and secret_key_set return PlatformHealthStatus( name="wenxin", - url=adapter.platform_url, + url=_PLATFORM_URLS["wenxin"], configured=configured, api_key_set=api_key_set, status="configured" if configured else "not_configured", @@ -93,21 +100,11 @@ def check_wenxin_health() -> PlatformHealthStatus: def check_doubao_health() -> PlatformHealthStatus: - """检查豆包平台健康状态""" try: - from app.workers.platforms.doubao import DoubaoAdapter - - adapter = DoubaoAdapter() - api_key_set = bool(adapter.api_key and adapter.api_key.strip()) - configured = adapter.is_configured - - return PlatformHealthStatus( - name="doubao", - url=adapter.platform_url, - configured=configured, - api_key_set=api_key_set, - status="configured" if configured else "not_configured", - message="API Key已配置" if configured else "API Key未配置", + return _check_api_key_health( + 
platform_name="doubao", + env_key_name="DOUBAO_API_KEY", + url=_PLATFORM_URLS["doubao"], ) except Exception as e: logger.error(f"豆包健康检查失败: {e}") @@ -120,7 +117,6 @@ def check_doubao_health() -> PlatformHealthStatus: def check_all_platforms() -> dict: - """检查所有平台健康状态""" platforms = [ check_kimi_health(), check_wenxin_health(), @@ -136,29 +132,12 @@ def check_all_platforms() -> dict: @router.get("/health") async def get_platform_health(): - """ - 获取所有AI平台适配器的健康状态 - - 返回每个平台的: - - name: 平台名称 - - configured: 是否已配置 - - url: 平台URL - - api_key_set: API Key是否已设置 - - status: 健康状态 (configured / not_configured / error) - - message: 状态消息 - """ health_info = check_all_platforms() return health_info @router.get("/health/{platform_name}") async def get_platform_health_by_name(platform_name: str): - """ - 获取指定平台适配器的健康状态 - - Args: - platform_name: 平台名称 (kimi / wenxin / doubao) - """ if platform_name == "kimi": result = vars(check_kimi_health()) elif platform_name == "wenxin": diff --git a/backend/app/api/queries.py b/backend/app/api/queries.py index 1ade055..ad0c122 100644 --- a/backend/app/api/queries.py +++ b/backend/app/api/queries.py @@ -9,7 +9,7 @@ from app.database import get_db from app.models.user import User from app.schemas.citation import RunNowResponse from app.schemas.query import QueryCreate, QueryListResponse, QueryResponse, QueryUpdate -from app.services.citation import trigger_query_now +from app.services.citation.citation import trigger_query_now from app.services.query import create_query, delete_query, get_queries, get_query, update_query router = APIRouter() diff --git a/backend/app/api/reports.py b/backend/app/api/reports.py index fc84ca0..4949657 100644 --- a/backend/app/api/reports.py +++ b/backend/app/api/reports.py @@ -1,23 +1,90 @@ +import logging import uuid from datetime import datetime from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi.responses import StreamingResponse +from sqlalchemy import 
select from sqlalchemy.ext.asyncio import AsyncSession from starlette.responses import Response from app.api.deps import get_current_user from app.database import get_db +from app.models.brand import Brand from app.models.user import User -from app.services.citation import export_citations_csv, export_citations_pdf +from app.schemas.scoring import CitationResult +from app.services.citation.citation import export_citations_csv, export_citations_pdf +from app.services.scoring.scoring_service import ScoringService, ScoringResultV2 + +logger = logging.getLogger(__name__) router = APIRouter() +async def _compute_v2_scores( + db: AsyncSession, + user_id: uuid.UUID, + brand_id: uuid.UUID, +) -> ScoringResultV2 | None: + try: + from app.api.scoring import ( + _get_citations_for_brand, + _analyze_sentiments_for_citations, + ) + + total_queries, brand_citations, _, competitor_mentions = ( + await _get_citations_for_brand(db, user_id, brand_id) + ) + + if total_queries == 0: + return None + + brand_stmt = select(Brand).where( + Brand.id == brand_id, Brand.user_id == user_id + ) + brand_result = await db.execute(brand_stmt) + brand = brand_result.scalar_one_or_none() + if not brand: + return None + + sentiment_counts = await _analyze_sentiments_for_citations( + brand_name=brand.name, + brand_citations=brand_citations, + ) + + citation_results = [ + CitationResult( + cited=c.cited, + position=c.citation_position, + citation_text=c.citation_text, + sentiment=c.sentiment or "neutral", + confidence=c.confidence or 0.0, + ) + for c in brand_citations + ] + + positions = [c.citation_position for c in brand_citations if c.cited] + + scoring_service = ScoringService() + return scoring_service.calculate_v2( + mentioned_count=len(brand_citations), + total_queries=total_queries, + positions=positions, + sentiment_counts=sentiment_counts, + citations=citation_results, + brand_mentions=len(brand_citations), + competitor_mentions=competitor_mentions, + ) + except Exception: + 
logger.warning("V2 scoring failed for brand %s", brand_id, exc_info=True) + return None + + @router.get("/export/csv") async def export_report( query_id: uuid.UUID = Query(...), + brand_id: Optional[uuid.UUID] = Query(None), format: str = Query("csv"), db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), @@ -29,7 +96,13 @@ async def export_report( ) try: - csv_content = await export_citations_csv(db, current_user.id, query_id) + v2_result = None + if brand_id is not None: + v2_result = await _compute_v2_scores(db, current_user.id, brand_id) + + csv_content = await export_citations_csv( + db, current_user.id, query_id, v2_result=v2_result + ) except ValueError as e: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -51,11 +124,18 @@ async def export_report( @router.get("/export/pdf") async def export_pdf( query_id: Optional[uuid.UUID] = None, + brand_id: Optional[uuid.UUID] = Query(None), db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): try: - pdf_bytes = await export_citations_pdf(db, current_user.id, query_id) + v2_result = None + if brand_id is not None: + v2_result = await _compute_v2_scores(db, current_user.id, brand_id) + + pdf_bytes = await export_citations_pdf( + db, current_user.id, query_id, v2_result=v2_result + ) except ValueError as e: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, diff --git a/backend/app/api/schema_advisor.py b/backend/app/api/schema_advisor.py new file mode 100644 index 0000000..dfb3005 --- /dev/null +++ b/backend/app/api/schema_advisor.py @@ -0,0 +1,248 @@ +import uuid + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_current_user +from app.database import get_db +from app.models.user import User +from app.models.brand import Brand +from app.models.schema_suggestion import SchemaSuggestion +from 
app.schemas.schema_suggestion import ( + SchemaAdviseRequest, + SchemaSuggestionResponse, + SchemaSuggestionList, + SchemaValidationResult, + SchemaStatusUpdateRequest, +) +from app.services.schema.schema_advisor_service import SchemaAdvisorService +from app.services.scoring.scoring_service import ScoringService + +router = APIRouter() + + +async def _get_brand_with_access( + brand_id: uuid.UUID, + db: AsyncSession, + current_user: User, +) -> Brand: + stmt = select(Brand).where( + Brand.id == brand_id, + Brand.user_id == current_user.id, + ) + result = await db.execute(stmt) + brand = result.scalar_one_or_none() + if not brand: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="品牌不存在", + ) + return brand + + +async def _get_brand_diagnosis_data( + db: AsyncSession, + user_id: uuid.UUID, + brand: Brand, +) -> dict: + from app.models.query import Query as QueryModel + from app.models.citation_record import CitationRecord + from app.models.competitor import Competitor + from app.schemas.scoring import CitationResult + from app.services.analysis.sentiment_service import get_sentiment_service + + queries_stmt = select(QueryModel).where( + QueryModel.user_id == user_id, + QueryModel.target_brand == brand.name, + ) + queries_result = await db.execute(queries_stmt) + queries = list(queries_result.scalars().all()) + + if not queries: + scoring_service = ScoringService() + empty_result = scoring_service.calculate_v2( + mentioned_count=0, + total_queries=0, + positions=[], + sentiment_counts={"positive": 0, "neutral": 0, "negative": 0}, + citations=[], + brand_mentions=0, + competitor_mentions={}, + ) + return empty_result.to_dict() + + query_ids = [q.id for q in queries] + + citations_stmt = select(CitationRecord).where( + CitationRecord.query_id.in_(query_ids), + ) + citations_result = await db.execute(citations_stmt) + all_citations = list(citations_result.scalars().all()) + + total_queries = len(all_citations) + brand_citations = [c for c in 
all_citations if c.cited] + + competitor_stmt = select(Competitor).where(Competitor.brand_id == brand.id) + competitor_result = await db.execute(competitor_stmt) + competitors = list(competitor_result.scalars().all()) + competitor_names = [c.name for c in competitors] + + competitor_mentions: dict[str, int] = {} + for comp_name in competitor_names: + count = sum( + 1 for c in all_citations + if c.cited and c.competitor_brands + and comp_name in c.competitor_brands + ) + if count > 0: + competitor_mentions[comp_name] = count + + sentiment_service = get_sentiment_service() + sentiment_counts = {"positive": 0, "neutral": 0, "negative": 0} + for citation in brand_citations: + if citation.sentiment and citation.sentiment in ("positive", "neutral", "negative"): + sentiment_counts[citation.sentiment] += 1 + else: + content = citation.raw_response or citation.citation_text or "" + if content.strip(): + try: + result = await sentiment_service.analyze( + brand_name=brand.name, + content=content, + ) + sentiment_counts[result.sentiment] += 1 + except Exception: + sentiment_counts["neutral"] += 1 + else: + sentiment_counts["neutral"] += 1 + + citation_results = [ + CitationResult( + cited=c.cited, + position=c.citation_position, + citation_text=c.citation_text, + sentiment="neutral", + confidence=c.confidence or 0.0, + ) + for c in brand_citations + ] + + positions = [c.citation_position for c in brand_citations if c.cited] + + scoring_service = ScoringService() + v2_result = scoring_service.calculate_v2( + mentioned_count=len(brand_citations), + total_queries=total_queries, + positions=positions, + sentiment_counts=sentiment_counts, + citations=citation_results, + brand_mentions=len(brand_citations), + competitor_mentions=competitor_mentions, + ) + + return v2_result.to_dict() + + +@router.post("/advise", response_model=SchemaSuggestionList) +async def generate_schema_advise( + request: SchemaAdviseRequest, + current_user: User = Depends(get_current_user), + db: AsyncSession 
= Depends(get_db), +): + brand = await _get_brand_with_access(request.brand_id, db, current_user) + + diagnosis_data = await _get_brand_diagnosis_data(db, current_user.id, brand) + + brand_info = { + "name": brand.name, + "website": brand.website or "", + "industry": brand.industry or "", + } + + service = SchemaAdvisorService() + suggestions = await service.generate_suggestions( + db=db, + brand_id=brand.id, + diagnosis_data=diagnosis_data, + brand_info=brand_info, + target_url=request.target_url, + focus_dimensions=request.focus_dimensions, + ) + + return SchemaSuggestionList( + suggestions=[SchemaSuggestionResponse.model_validate(s) for s in suggestions], + total=len(suggestions), + ) + + +@router.get("/brand/{brand_id}", response_model=SchemaSuggestionList) +async def get_brand_schema_suggestions( + brand_id: uuid.UUID, + status_filter: str | None = Query(None, alias="status", description="按状态筛选"), + schema_type: str | None = Query(None, description="按Schema类型筛选"), + skip: int = Query(0, ge=0), + limit: int = Query(20, ge=1, le=100), + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + await _get_brand_with_access(brand_id, db, current_user) + + service = SchemaAdvisorService() + suggestions, total = await service.get_suggestions( + db=db, + brand_id=brand_id, + status_filter=status_filter, + schema_type=schema_type, + skip=skip, + limit=limit, + ) + + return SchemaSuggestionList( + suggestions=[SchemaSuggestionResponse.model_validate(s) for s in suggestions], + total=total, + ) + + +@router.get("/{suggestion_id}", response_model=SchemaSuggestionResponse) +async def get_schema_suggestion_detail( + suggestion_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + service = SchemaAdvisorService() + suggestion = await service.get_suggestion_by_id(db, suggestion_id) + if not suggestion: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="建议不存在", + ) 
+ brand = await _get_brand_with_access(suggestion.brand_id, db, current_user) + return SchemaSuggestionResponse.model_validate(suggestion) + + +@router.put("/{suggestion_id}/status", response_model=SchemaSuggestionResponse) +async def update_schema_suggestion_status( + suggestion_id: uuid.UUID, + status_update: SchemaStatusUpdateRequest, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + valid_statuses = {"pending", "applied", "dismissed"} + if status_update.status not in valid_statuses: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"无效的状态值,支持: {', '.join(valid_statuses)}", + ) + + service = SchemaAdvisorService() + suggestion = await service.get_suggestion_by_id(db, suggestion_id) + if not suggestion: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="建议不存在", + ) + await _get_brand_with_access(suggestion.brand_id, db, current_user) + + updated = await service.update_status(db, suggestion_id, status_update.status) + return SchemaSuggestionResponse.model_validate(updated) diff --git a/backend/app/api/scoring.py b/backend/app/api/scoring.py index b61cf3e..50af769 100644 --- a/backend/app/api/scoring.py +++ b/backend/app/api/scoring.py @@ -27,8 +27,8 @@ from app.schemas.scoring import ( DimensionCompareItem, CitationResult, ) -from app.services.scoring_service import ScoringService, get_health_level -from app.services.sentiment_service import get_sentiment_service +from app.services.scoring.scoring_service import ScoringService, get_health_level +from app.services.analysis.sentiment_service import get_sentiment_service logger = logging.getLogger(__name__) @@ -446,7 +446,7 @@ async def get_brand_score( # 异步触发告警检测(不影响主流程) try: - from app.services.alert_engine import AlertEngine + from app.services.alert.alert_engine import AlertEngine alert_engine = AlertEngine(db) # 获取当前已有提及的平台集合 diff --git a/backend/app/api/strategy.py b/backend/app/api/strategy.py new file mode 100644 
async def _get_brand_scoring_data(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand: Brand,
) -> tuple:
    """Fetch the brand's scoring data and flatten it into a 6-tuple.

    Returns, in order: ``v2_result``, ``competitor_data``,
    ``sentiment_counts``, ``platform_scores``, ``total_queries``,
    ``mentioned_count`` — all read from the scoring-data service result.
    """
    snapshot = await get_brand_scoring_data_service().get_brand_scoring_data(
        db, user_id, brand
    )

    v2_result = snapshot.v2_result
    competitor_data = snapshot.competitor_data
    sentiment_counts = snapshot.sentiment_counts
    platform_scores = snapshot.platform_scores
    total_queries = snapshot.total_queries
    mentioned_count = snapshot.mentioned_count

    return (
        v2_result,
        competitor_data,
        sentiment_counts,
        platform_scores,
        total_queries,
        mentioned_count,
    )
@router.post("/generate", response_model=GeoPlanResponse)
async def generate_geo_plan_endpoint(
    request: GeoPlanGenerateRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate a GEO plan for a brand, persist it with its actions, and return it.

    Steps: verify brand ownership, collect scoring data, call the plan
    generator, store the plan and one ``GeoPlanAction`` per generated action
    (``sort_order`` follows generator order), then return the stored plan.

    Raises:
        HTTPException 404: the brand does not exist or is not owned by the caller.
    """
    brand = await _get_brand_with_access(request.brand_id, db, current_user)

    # Sentiment counts and mentioned count are returned by the helper but are
    # not inputs to plan generation; underscore-prefixed to mark them unused.
    (
        v2_result,
        competitor_data,
        _sentiment_data,
        platform_scores,
        total_queries,
        _mentioned_count,
    ) = await _get_brand_scoring_data(db, current_user.id, brand)

    # A missing (or zero) target falls back to 75.
    target_score = request.target_score or 75

    plan_data = await generate_geo_plan(
        brand_name=brand.name,
        scoring_result=v2_result,
        target_score=target_score,
        total_queries=total_queries,
        platform_scores=platform_scores,
        competitor_data=competitor_data,
    )

    from app.config import settings
    source = "llm" if (settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY) else "rule"

    db_plan = GeoPlan(
        # NOTE(review): the user's id is stored as organization_id — confirm
        # this is the intended mapping until real organizations are wired in.
        organization_id=current_user.id,
        brand_id=brand.id,
        title=plan_data.title,
        status="draft",
        diagnosis_score=int(round(v2_result.overall_score)),
        target_score=target_score,
        estimated_weeks=plan_data.estimated_weeks,
        plan_data={
            "weekly_plan": plan_data.weekly_plan,
        },
        source=source,
        created_by=current_user.id,
    )
    db.add(db_plan)
    # Flush so db_plan.id is available for the child rows below.
    await db.flush()

    for idx, action_item in enumerate(plan_data.actions):
        db_action = GeoPlanAction(
            plan_id=db_plan.id,
            action_type=action_item.action_type,
            title=action_item.title,
            description=action_item.description,
            reason=action_item.reason,
            priority=action_item.priority,
            status="pending",
            target_keyword=action_item.target_keyword,
            target_platform=action_item.target_platform,
            content_style=action_item.content_style,
            estimated_impact=action_item.estimated_impact,
            difficulty=action_item.difficulty,
            execution_params=action_item.execution_params,
            sort_order=idx,
        )
        db.add(db_action)

    await db.commit()
    await db.refresh(db_plan)

    # The previous re-select used selectinload(GeoPlanAction.plan) on a
    # select(GeoPlan); that loader path is not rooted at the selected entity
    # and is rejected by SQLAlchemy, failing the request after the commit.
    # The refreshed instance is enough; actions are loaded explicitly, the
    # same way the other plan endpoints in this module do it.
    action_stmt = select(GeoPlanAction).where(
        GeoPlanAction.plan_id == db_plan.id
    ).order_by(GeoPlanAction.sort_order)
    action_result = await db.execute(action_stmt)
    db_plan.actions = list(action_result.scalars().all())

    return _plan_to_response(db_plan)
@router.get("/{plan_id}", response_model=GeoPlanResponse)
async def get_plan_detail(
    plan_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a single GEO plan with its actions ordered by ``sort_order``.

    Raises:
        HTTPException 404: the plan does not exist, or the plan's brand is
            not owned by the current user.
    """
    stmt = select(GeoPlan).where(GeoPlan.id == plan_id)
    result = await db.execute(stmt)
    plan = result.scalar_one_or_none()

    if not plan:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="方案不存在",
        )

    # Called only for its access check; the returned brand is not needed.
    await _get_brand_with_access(plan.brand_id, db, current_user)

    action_stmt = select(GeoPlanAction).where(
        GeoPlanAction.plan_id == plan.id
    ).order_by(GeoPlanAction.sort_order)
    action_result = await db.execute(action_stmt)
    plan.actions = list(action_result.scalars().all())

    return _plan_to_response(plan)
@router.post("/actions/{action_id}/execute", response_model=GeoPlanActionExecuteResponse)
async def execute_action(
    action_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Run a plan action in one click by generating content for it.

    Only ``content_creation`` and ``content_optimization`` actions are
    supported. Generation parameters come from ``action.execution_params``
    and fall back to the action's own fields, then to fixed defaults. On
    success the action is marked completed.

    Raises:
        HTTPException 404: action missing, or its plan's brand is not owned
            by the caller.
        HTTPException 400: the action type cannot be executed this way.
        HTTPException 500: content generation failed unexpectedly.
    """
    stmt = select(GeoPlanAction).where(GeoPlanAction.id == action_id)
    result = await db.execute(stmt)
    action = result.scalar_one_or_none()

    if not action:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="行动项不存在",
        )

    plan_stmt = select(GeoPlan).where(GeoPlan.id == action.plan_id)
    plan_result = await db.execute(plan_stmt)
    plan = plan_result.scalar_one()

    brand = await _get_brand_with_access(plan.brand_id, db, current_user)

    if action.action_type not in ("content_creation", "content_optimization"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"行动类型 '{action.action_type}' 不支持一键执行,仅支持 content_creation 和 content_optimization",
        )

    params = action.execution_params or {}
    keyword = params.get("keyword", action.target_keyword or brand.name)
    platform = params.get("platform", action.target_platform or "通用")
    style = params.get("style", action.content_style or "专业严谨")
    word_count = params.get("word_count", 2000)
    knowledge_base_ids = params.get("knowledge_base_ids")

    content_service = ContentGenerationService()

    try:
        gen_result = await content_service.generate_content(
            keyword=keyword,
            brand_name=brand.name,
            platform=platform,
            content_style=style,
            word_count=word_count,
            knowledge_base_ids=knowledge_base_ids,
            db=db,
            user_id=current_user.id,
            org_id=str(plan.organization_id),
            run_deai=True,
            run_geo=True,
        )
    except HTTPException:
        # Keep a deliberate HTTP error from the service (status code and
        # detail) instead of flattening it into a generic 500.
        raise
    except Exception as exc:
        # Chain the cause so the original traceback survives in logs.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"内容生成失败: {str(exc)}",
        ) from exc

    action.status = "completed"
    # NOTE(review): naive local timestamp, matching update_action_status;
    # confirm the column is meant to store naive datetimes.
    action.completed_at = datetime.now()
    await db.commit()
    await db.refresh(action)

    content_id = gen_result.get("content_id")

    return GeoPlanActionExecuteResponse(
        action_id=action.id,
        content_id=content_id,
        message="内容生成成功" if content_id else "内容生成完成(未持久化)",
    )
async def _get_brand_with_access(
    brand_id: uuid.UUID,
    db: AsyncSession,
    current_user: User,
) -> Brand:
    """Load a brand by id, restricted to brands owned by ``current_user``.

    Raises:
        HTTPException 404: no brand with this id belongs to the user.
    """
    owned_brand_query = select(Brand).where(
        Brand.id == brand_id,
        Brand.user_id == current_user.id,
    )
    brand = (await db.execute(owned_brand_query)).scalar_one_or_none()

    if brand:
        return brand

    # Same response whether the brand is missing or owned by someone else.
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="品牌不存在",
    )
@router.get("/{insight_id}", response_model=TrendInsightResponse)
async def get_trend_insight_detail(
    insight_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one trend insight, provided its brand belongs to the caller.

    Raises:
        HTTPException 404: the insight does not exist, or its brand is not
            owned by the current user.
    """
    service = TrendAnalyzerService(db)
    insight = await service.get_insight_by_id(insight_id)
    if insight is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="洞察不存在",
        )

    # Ownership check: without it any authenticated user could read any
    # insight by guessing or enumerating ids.
    # NOTE(review): assumes TrendInsight exposes brand_id (insights are
    # created and listed per brand in this module) — confirm against the model.
    await _get_brand_with_access(insight.brand_id, db, current_user)
    return insight
b/backend/app/config.py @@ -64,6 +64,40 @@ class Settings(BaseSettings): # AI平台API调用频率限制(每分钟请求数) API_RATE_LIMIT_RPM: int = 10 + # Payment Gateway Configuration + WECHAT_PAY_MCH_ID: str = "" + WECHAT_PAY_API_KEY: str = "" + WECHAT_PAY_APP_ID: str = "" + WECHAT_PAY_CERT_PATH: str = "" + WECHAT_PAY_NOTIFY_URL: str = "" + ALIPAY_APP_ID: str = "" + ALIPAY_PRIVATE_KEY_PATH: str = "" + ALIPAY_PUBLIC_KEY_PATH: str = "" + ALIPAY_NOTIFY_URL: str = "" + PAYMENT_MODE: str = "mock" + + ZHIHU_CLIENT_ID: str = "" + ZHIHU_CLIENT_SECRET: str = "" + ZHIHU_ACCESS_TOKEN: str = "" + TOUTIAO_APP_ID: str = "" + TOUTIAO_APP_SECRET: str = "" + TOUTIAO_ACCESS_TOKEN: str = "" + WECHAT_OFFICIAL_APP_ID: str = "" + WECHAT_OFFICIAL_APP_SECRET: str = "" + SMTP_HOST: str = "" + SMTP_PORT: int = 587 + SMTP_USER: str = "" + SMTP_PASSWORD: str = "" + SMTP_FROM_EMAIL: str = "noreply@geo-platform.com" + SMTP_FROM_NAME: str = "GEO平台" + EMAIL_MODE: str = "mock" + SENDGRID_API_KEY: str = "" + ALIYUN_MAIL_ACCESS_KEY: str = "" + ALIYUN_MAIL_ACCESS_SECRET: str = "" + ALIYUN_MAIL_REGION: str = "cn-hangzhou" + + DISTRIBUTION_MODE: str = "mock" + @field_validator("JWT_SECRET") @classmethod def validate_jwt_secret(cls, v: str) -> str: diff --git a/backend/app/main.py b/backend/app/main.py index 919468b..4e8fdbf 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -43,14 +43,21 @@ from app.api.ai_engines import router as ai_engines_router from app.api.detection import router as detection_router from app.api.api_keys import router as api_keys_router from app.api.usage import router as usage_router +from app.api.strategy import router as strategy_router +from app.api.competitor_analysis import router as competitor_analysis_router +from app.api.trends import router as trends_router +from app.api.schema_advisor import router as schema_advisor_router +from app.api.monitoring import router as monitoring_router +from app.api.health_score import router as health_score_router +from app.api.payments import router 
as payments_router +from app.api.attribution import router as attribution_router from app.config import settings from app.database import engine, Base from app.schemas.common import ErrorResponse, ErrorCode from app.middleware.rate_limit import RateLimitMiddleware from app.middleware.logging_middleware import RequestLoggingMiddleware from app.middleware.request_id import RequestIdMiddleware -from app.middleware.metrics import MetricsMiddleware -from app.monitoring.middleware import MonitoringMiddleware +from app.middleware.metrics import MetricsMiddleware, MonitoringMiddleware from app.database import get_db from app.workers.scheduler import query_scheduler @@ -59,7 +66,13 @@ from app.workers.scheduler import query_scheduler async def lifespan(app: FastAPI): import app.models - import app.monitoring + import app.middleware.prometheus_metrics + from app.middleware.prometheus_metrics import SERVICE_INFO + import os + SERVICE_INFO.info({ + "version": "1.0.0", + "environment": os.getenv("ENVIRONMENT", "development"), + }) async with engine.begin() as conn: await conn.execute(text("SELECT 1")) @@ -120,10 +133,14 @@ _allow_origins = [origin.strip() for origin in settings.CORS_ORIGINS.split(",") if not _allow_origins: _allow_origins = ["http://localhost:3000"] +import os + +_is_dev = os.getenv("ENVIRONMENT", "development") == "development" + app.add_middleware( CORSMiddleware, - allow_origins=_allow_origins, - allow_credentials=True, + allow_origins=_allow_origins if not _is_dev else ["*"], + allow_credentials=not _is_dev, allow_methods=["*"], allow_headers=["*"], ) @@ -174,6 +191,14 @@ app.include_router(ai_engines_router, prefix="/api/v1/ai-engines", tags=["AI引 app.include_router(detection_router, prefix="/api/v1/detection", tags=["定时检测任务"]) app.include_router(api_keys_router, prefix="/api/v1/api-keys", tags=["API Key管理"]) app.include_router(usage_router, prefix="/api/v1/usage", tags=["用量追踪"]) +app.include_router(strategy_router, prefix="/api/v1/strategy", 
tags=["GEO方案"]) +app.include_router(competitor_analysis_router, prefix="/api/v1/competitor", tags=["竞品分析"]) +app.include_router(schema_advisor_router, prefix="/api/v1/schema", tags=["Schema建议"]) +app.include_router(trends_router, prefix="/api/v1/trends", tags=["趋势洞察"]) +app.include_router(monitoring_router, prefix="/api/v1/monitoring", tags=["效果追踪"]) +app.include_router(health_score_router, prefix="/api/v1/public", tags=["公开API"]) +app.include_router(payments_router) +app.include_router(attribution_router, prefix="/api/v1/attribution", tags=["效果归因"]) @app.get("/health", tags=["可观测性"]) diff --git a/backend/app/monitoring/agent_hooks.py b/backend/app/middleware/agent_hooks.py similarity index 96% rename from backend/app/monitoring/agent_hooks.py rename to backend/app/middleware/agent_hooks.py index b86302d..9df929c 100644 --- a/backend/app/monitoring/agent_hooks.py +++ b/backend/app/middleware/agent_hooks.py @@ -3,7 +3,7 @@ import time from contextlib import asynccontextmanager from typing import Optional -from app.monitoring.metrics import ( +from app.middleware.prometheus_metrics import ( AGENT_EXECUTIONS_TOTAL, AGENT_EXECUTION_DURATION_SECONDS, AGENT_RUNNING_TASKS, diff --git a/backend/app/monitoring/llm_metrics.py b/backend/app/middleware/llm_metrics.py similarity index 98% rename from backend/app/monitoring/llm_metrics.py rename to backend/app/middleware/llm_metrics.py index 85de516..5bf0813 100644 --- a/backend/app/monitoring/llm_metrics.py +++ b/backend/app/middleware/llm_metrics.py @@ -2,7 +2,7 @@ import time from typing import Optional -from app.monitoring.metrics import ( +from app.middleware.prometheus_metrics import ( LLM_REQUESTS_TOTAL, LLM_REQUEST_DURATION_SECONDS, LLM_TOKENS_TOTAL, diff --git a/backend/app/middleware/metrics.py b/backend/app/middleware/metrics.py index c87b97a..bd1a5dd 100644 --- a/backend/app/middleware/metrics.py +++ b/backend/app/middleware/metrics.py @@ -1,9 +1,19 @@ -"""请求指标收集中间件:计时、慢请求告警、响应时间响应头。""" 
class MonitoringMiddleware(BaseHTTPMiddleware):
    """API monitoring middleware that records Prometheus metrics.

    For every request whose path is not in ``_SKIP_PATHS`` it tracks the
    request count (by method, endpoint label and status), the latency
    histogram, and the number of in-flight requests. Endpoint labels are
    normalised so that id path segments do not create one label per id.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Wrap the downstream handler and record metrics around it."""
        if request.url.path in _SKIP_PATHS:
            return await call_next(request)

        method = request.method
        endpoint = self._get_endpoint_label(request)

        API_REQUESTS_IN_PROGRESS.labels(method=method, endpoint=endpoint).inc()
        start_time = time.perf_counter()

        # Default to 500 up front. Previously status_code was only set in
        # `except Exception`, so a BaseException such as
        # asyncio.CancelledError left it unbound and the `finally` block
        # raised UnboundLocalError, masking the real exception.
        status_code = 500
        try:
            response = await call_next(request)
            status_code = response.status_code
            return response
        finally:
            duration = time.perf_counter() - start_time

            API_REQUESTS_TOTAL.labels(
                method=method,
                endpoint=endpoint,
                status=str(status_code),
            ).inc()

            API_REQUEST_DURATION_SECONDS.labels(
                method=method,
                endpoint=endpoint,
            ).observe(duration)

            API_REQUESTS_IN_PROGRESS.labels(method=method, endpoint=endpoint).dec()

    @staticmethod
    def _looks_like_id(segment: str) -> bool:
        """Return True for numeric ids and UUID-shaped path segments."""
        if segment.isdigit():
            return True
        compact = segment.replace("-", "")
        return len(compact) == 32 and all(
            ch in "0123456789abcdefABCDEF" for ch in compact
        )

    def _get_endpoint_label(self, request: Request) -> str:
        """Derive a low-cardinality endpoint label from the request path.

        Paths shaped like ``/api/<version>/<resource>/<segment>/...`` become
        ``<resource>_detail`` when ``<segment>`` is an id (numeric or UUID)
        and ``<resource>_<segment>`` otherwise. Any other path is ``other``.
        """
        parts = request.url.path.strip("/").split("/")

        if len(parts) >= 4 and parts[0] == "api":
            resource = parts[2]
            action = parts[3]

            # The routers in this API use UUID ids, so UUIDs must be
            # collapsed as well, or every object gets its own label.
            if self._looks_like_id(action):
                return f"{resource}_detail"
            return f"{resource}_{action}"

        return "other"
user_plan not in allowed_plans: + raise HTTPException( + status_code=403, + detail={ + "message": f"此功能需要 {allowed_plans[0]} 及以上套餐", + "required_plan": allowed_plans[0], + "current_plan": user_plan, + "upgrade_url": "/api/v1/subscriptions/plans", + }, + ) + return current_user + return _check + + @staticmethod + def check_quota(resource: str): + async def _check( + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), + ): + user_plan = getattr(current_user, "plan", "free") or "free" + plan_config = PLANS.get(user_plan, PLANS["free"]) + + if resource == "queries": + limit = plan_config.get("max_queries", 3) + current_usage = getattr(current_user, "max_queries", limit) or limit + remaining = max(0, limit - current_usage) + elif resource == "brands": + limit = plan_config.get("max_brands", 1) + remaining = limit if limit == -1 else max(0, limit) + elif resource == "alerts": + limit = plan_config.get("max_alerts_per_month", 0) + remaining = limit if limit == -1 else max(0, limit) + else: + remaining = 0 + + return { + "user_id": current_user.id, + "plan": user_plan, + "resource": resource, + "remaining": remaining, + "unlimited": remaining == -1, + } + return _check diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index e4a5550..c41ec6e 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -33,6 +33,14 @@ from app.models.alert import Alert from app.models.alert_setting import AlertSetting from app.models.detection_task import DetectionTask from app.models.usage_record import UsageRecord +from app.models.geo_plan import GeoPlan, GeoPlanAction +from app.models.trend_insight import TrendInsight +from app.models.competitor_insight import CompetitorInsight +from app.models.schema_suggestion import SchemaSuggestion +from app.models.monitoring import MonitoringRecord, ContentBaseline +from app.models.diagnosis_record import DiagnosisRecord +from app.models.payment_order import 
PaymentOrder +from app.models.attribution_record import AttributionRecord __all__ = [ "User", @@ -76,4 +84,14 @@ __all__ = [ "AlertSetting", "DetectionTask", "UsageRecord", + "GeoPlan", + "GeoPlanAction", + "CompetitorInsight", + "SchemaSuggestion", + "TrendInsight", + "MonitoringRecord", + "ContentBaseline", + "DiagnosisRecord", + "PaymentOrder", + "AttributionRecord", ] diff --git a/backend/app/models/attribution_record.py b/backend/app/models/attribution_record.py new file mode 100644 index 0000000..5e8f9dc --- /dev/null +++ b/backend/app/models/attribution_record.py @@ -0,0 +1,65 @@ +import uuid +from datetime import datetime + +from sqlalchemy import String, Float, Integer, ForeignKey, Index, func, Text +from sqlalchemy import Uuid +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.database import Base, JSONType + + +class AttributionRecord(Base): + __tablename__ = "attribution_records" + + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + user_id: Mapped[str] = mapped_column( + Text, + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ) + brand_id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + ForeignKey("brands.id", ondelete="CASCADE"), + nullable=False, + ) + content_id: Mapped[uuid.UUID | None] = mapped_column( + Uuid(as_uuid=True), + ForeignKey("contents.id", ondelete="SET NULL"), + nullable=True, + ) + baseline_score: Mapped[float] = mapped_column(Float, nullable=False) + current_score: Mapped[float | None] = mapped_column(Float, nullable=True) + score_delta: Mapped[float | None] = mapped_column(Float, nullable=True) + attribution_window_days: Mapped[int] = mapped_column( + Integer, server_default="28", nullable=False, + ) + published_at: Mapped[datetime | None] = mapped_column(nullable=True) + window_end_at: Mapped[datetime | None] = mapped_column(nullable=True) + status: Mapped[str] = mapped_column( + String(20), server_default="tracking", 
nullable=False, + ) + attributed_dimensions: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + roi_percentage: Mapped[float | None] = mapped_column(Float, nullable=True) + created_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) + + brand: Mapped["Brand"] = relationship("Brand") + content: Mapped["Content | None"] = relationship("Content") + + __table_args__ = ( + Index("idx_attribution_records_brand_id", "brand_id"), + Index("idx_attribution_records_user_id", "user_id"), + Index("idx_attribution_records_status", "status"), + Index("idx_attribution_records_content_id", "content_id"), + ) diff --git a/backend/app/models/brand.py b/backend/app/models/brand.py index ef9420d..bd56c30 100644 --- a/backend/app/models/brand.py +++ b/backend/app/models/brand.py @@ -62,7 +62,12 @@ class Brand(Base): "Suggestion", back_populates="brand", cascade="all, delete-orphan" ) + schema_suggestions: Mapped[list["SchemaSuggestion"]] = relationship( + "SchemaSuggestion", back_populates="brand", cascade="all, delete-orphan" + ) + # Import at bottom to avoid circular import from app.models.competitor import Competitor # noqa: E402, F401 from app.models.suggestion import Suggestion # noqa: E402, F401 +from app.models.schema_suggestion import SchemaSuggestion # noqa: E402, F401 diff --git a/backend/app/models/citation_record.py b/backend/app/models/citation_record.py index a8fc947..2bfca19 100644 --- a/backend/app/models/citation_record.py +++ b/backend/app/models/citation_record.py @@ -6,6 +6,7 @@ from sqlalchemy import Uuid, JSON from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database import Base +from app.utils.text import sanitize_raw_response class CitationRecord(Base): @@ -75,3 +76,40 @@ class CitationRecord(Base): Index("idx_citation_records_queried_at", "queried_at"), 
class JSONType(TypeDecorator):
    """JSON column type that uses JSONB on PostgreSQL and plain JSON elsewhere.

    NOTE(review): other models in this change import ``JSONType`` from
    ``app.database``; this looks like a local duplicate — consider reusing
    the shared one.
    """

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Pick the concrete column type for the active database dialect."""
        concrete = JSONB() if dialect.name == "postgresql" else JSON()
        return dialect.type_descriptor(concrete)
default=uuid.uuid4, + ) + brand_id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + ForeignKey("brands.id", ondelete="CASCADE"), + nullable=False, + ) + competitor_name: Mapped[str] = mapped_column(String(100), nullable=False) + analysis_type: Mapped[str] = mapped_column( + String(50), nullable=False, + ) + insight_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + citation_count_brand: Mapped[int] = mapped_column(Integer, default=0, nullable=False) + citation_count_competitor: Mapped[int] = mapped_column(Integer, default=0, nullable=False) + sentiment_brand: Mapped[float | None] = mapped_column(Float, nullable=True) + sentiment_competitor: Mapped[float | None] = mapped_column(Float, nullable=True) + platform_breakdown: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + gap_analysis: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + opportunity_areas: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + recommendations: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + confidence: Mapped[str] = mapped_column(String(20), default="medium", nullable=False) + period_days: Mapped[int] = mapped_column(Integer, default=30, nullable=False) + created_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) + + __table_args__ = ( + Index("idx_competitor_insights_brand_id", "brand_id"), + Index("idx_competitor_insights_analysis_type", "analysis_type"), + ) diff --git a/backend/app/models/diagnosis_record.py b/backend/app/models/diagnosis_record.py new file mode 100644 index 0000000..e83270a --- /dev/null +++ b/backend/app/models/diagnosis_record.py @@ -0,0 +1,40 @@ +import uuid +from datetime import datetime + +from sqlalchemy import String, Uuid, JSON, Float, Text, ForeignKey, Index, func +from sqlalchemy.orm import Mapped, mapped_column + +from 
app.database import Base + + +class DiagnosisRecord(Base): + __tablename__ = "diagnosis_records" + + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4 + ) + brand_id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + ForeignKey("brands.id", ondelete="CASCADE"), + nullable=False, + ) + user_id: Mapped[uuid.UUID] = mapped_column(Uuid(as_uuid=True), nullable=False) + diagnosis_type: Mapped[str] = mapped_column( + String(20), default="geo", nullable=False + ) + status: Mapped[str] = mapped_column(String(20), default="pending", nullable=False) + overall_score: Mapped[float | None] = mapped_column(Float, nullable=True) + result_json: Mapped[dict | None] = mapped_column(JSON, nullable=True) + error_message: Mapped[str | None] = mapped_column(Text, nullable=True) + collection_metadata: Mapped[dict | None] = mapped_column(JSON, nullable=True) + created_at: Mapped[datetime] = mapped_column( + server_default=func.now(), nullable=False + ) + completed_at: Mapped[datetime | None] = mapped_column(nullable=True) + + __table_args__ = ( + Index("idx_diagnosis_records_brand_id", "brand_id"), + Index("idx_diagnosis_records_user_id", "user_id"), + Index("idx_diagnosis_records_status", "status"), + Index("idx_diagnosis_records_created_at", "created_at"), + ) diff --git a/backend/app/models/geo_plan.py b/backend/app/models/geo_plan.py new file mode 100644 index 0000000..23c5f04 --- /dev/null +++ b/backend/app/models/geo_plan.py @@ -0,0 +1,112 @@ +import uuid +from datetime import datetime + +from sqlalchemy import String, Integer, Text, ForeignKey, Index, func +from sqlalchemy import Uuid +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.database import Base, JSONType + + +class GeoPlan(Base): + __tablename__ = "geo_plans" + + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + organization_id: Mapped[uuid.UUID] = mapped_column( + 
Uuid(as_uuid=True), + ForeignKey("organizations.id", ondelete="CASCADE"), + nullable=False, + ) + brand_id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + ForeignKey("brands.id", ondelete="CASCADE"), + nullable=False, + ) + title: Mapped[str] = mapped_column(String(500), nullable=False) + status: Mapped[str] = mapped_column( + String(20), server_default="draft", nullable=False, + ) + diagnosis_score: Mapped[int] = mapped_column(Integer, nullable=False) + target_score: Mapped[int] = mapped_column(Integer, nullable=False) + estimated_weeks: Mapped[int] = mapped_column(Integer, nullable=False) + plan_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + source: Mapped[str] = mapped_column(String(20), nullable=False, default="rule") + created_by: Mapped[str | None] = mapped_column( + String(36), + ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + ) + created_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) + + brand: Mapped["Brand"] = relationship("Brand") + creator: Mapped["User | None"] = relationship( + "User", foreign_keys=[created_by] + ) + + __table_args__ = ( + Index("idx_geo_plans_brand_id", "brand_id"), + Index("idx_geo_plans_status", "status"), + Index("idx_geo_plans_organization_id", "organization_id"), + ) + + +class GeoPlanAction(Base): + __tablename__ = "geo_plan_actions" + + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + plan_id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + ForeignKey("geo_plans.id", ondelete="CASCADE"), + nullable=False, + ) + action_type: Mapped[str] = mapped_column(String(50), nullable=False) + title: Mapped[str] = mapped_column(String(500), nullable=False) + description: Mapped[str] = mapped_column(Text, nullable=False) + reason: Mapped[str] = 
mapped_column(Text, nullable=False) + priority: Mapped[str] = mapped_column(String(10), nullable=False) + status: Mapped[str] = mapped_column( + String(20), server_default="pending", nullable=False, + ) + target_keyword: Mapped[str | None] = mapped_column(String(200), nullable=True) + target_platform: Mapped[str | None] = mapped_column(String(50), nullable=True) + content_style: Mapped[str | None] = mapped_column(String(50), nullable=True) + estimated_impact: Mapped[str | None] = mapped_column(String(500), nullable=True) + difficulty: Mapped[str] = mapped_column(String(10), nullable=False) + execution_params: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + sort_order: Mapped[int] = mapped_column( + Integer, server_default="0", nullable=False, + ) + completed_at: Mapped[datetime | None] = mapped_column(nullable=True) + created_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) + + plan: Mapped["GeoPlan"] = relationship("GeoPlan") + + __table_args__ = ( + Index("idx_geo_plan_actions_plan_id", "plan_id"), + Index("idx_geo_plan_actions_status", "status"), + Index("idx_geo_plan_actions_priority", "priority"), + ) diff --git a/backend/app/models/monitoring.py b/backend/app/models/monitoring.py new file mode 100644 index 0000000..93f193e --- /dev/null +++ b/backend/app/models/monitoring.py @@ -0,0 +1,100 @@ +import uuid +from datetime import datetime + +from sqlalchemy import String, Integer, Float, ForeignKey, Index, func +from sqlalchemy import Uuid +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.database import Base, JSONType + + +class MonitoringRecord(Base): + __tablename__ = "monitoring_records" + + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + brand_id: Mapped[uuid.UUID] = mapped_column( + 
Uuid(as_uuid=True), + ForeignKey("brands.id", ondelete="CASCADE"), + nullable=False, + ) + content_id: Mapped[str | None] = mapped_column(String(36), nullable=True) + query_keywords: Mapped[str | None] = mapped_column(String(500), nullable=True) + platform: Mapped[str | None] = mapped_column(String(50), nullable=True) + baseline_citation_count: Mapped[int] = mapped_column( + Integer, server_default="0", nullable=False, + ) + baseline_sentiment: Mapped[float | None] = mapped_column(Float, nullable=True) + baseline_rank: Mapped[int | None] = mapped_column(Integer, nullable=True) + current_citation_count: Mapped[int] = mapped_column( + Integer, server_default="0", nullable=False, + ) + current_sentiment: Mapped[float | None] = mapped_column(Float, nullable=True) + current_rank: Mapped[int | None] = mapped_column(Integer, nullable=True) + change_type: Mapped[str | None] = mapped_column(String(20), nullable=True) + change_details: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + check_interval_hours: Mapped[int] = mapped_column( + Integer, server_default="24", nullable=False, + ) + last_checked_at: Mapped[datetime | None] = mapped_column(nullable=True) + next_check_at: Mapped[datetime | None] = mapped_column(nullable=True) + status: Mapped[str] = mapped_column( + String(20), server_default="active", nullable=False, + ) + created_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) + + brand: Mapped["Brand"] = relationship("Brand") + baselines: Mapped[list["ContentBaseline"]] = relationship( + "ContentBaseline", back_populates="monitoring_record", cascade="all, delete-orphan", + ) + + __table_args__ = ( + Index("idx_monitoring_records_brand_id", "brand_id"), + Index("idx_monitoring_records_status", "status"), + Index("idx_monitoring_records_next_check_at", "next_check_at"), + ) + + +class 
ContentBaseline(Base): + __tablename__ = "content_baselines" + + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + monitoring_record_id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + ForeignKey("monitoring_records.id", ondelete="CASCADE"), + nullable=False, + ) + brand_name: Mapped[str] = mapped_column(String(100), nullable=False) + keyword: Mapped[str] = mapped_column(String(200), nullable=False) + platform: Mapped[str] = mapped_column(String(50), nullable=False) + citation_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + sentiment_score: Mapped[float | None] = mapped_column(Float, nullable=True) + rank_position: Mapped[int | None] = mapped_column(Integer, nullable=True) + snapshot_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + recorded_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + + monitoring_record: Mapped["MonitoringRecord"] = relationship( + "MonitoringRecord", back_populates="baselines", + ) + + __table_args__ = ( + Index("idx_content_baselines_monitoring_record_id", "monitoring_record_id"), + ) diff --git a/backend/app/models/monitoring_record.py b/backend/app/models/monitoring_record.py new file mode 100644 index 0000000..6e99f31 --- /dev/null +++ b/backend/app/models/monitoring_record.py @@ -0,0 +1,108 @@ +import uuid +from datetime import datetime + +from sqlalchemy import String, Integer, Float, ForeignKey, Index, func +from sqlalchemy import Uuid +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.database import Base, JSONType + + +class MonitoringRecord(Base): + __tablename__ = "monitoring_records" + + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + user_id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ) + brand_id: 
Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + ForeignKey("brands.id", ondelete="CASCADE"), + nullable=False, + ) + content_id: Mapped[uuid.UUID | None] = mapped_column( + Uuid(as_uuid=True), + ForeignKey("contents.id", ondelete="SET NULL"), + nullable=True, + ) + query_id: Mapped[uuid.UUID | None] = mapped_column( + Uuid(as_uuid=True), + ForeignKey("queries.id", ondelete="SET NULL"), + nullable=True, + ) + task_type: Mapped[str] = mapped_column(String(50), nullable=False) + status: Mapped[str] = mapped_column( + String(20), server_default="pending", nullable=False, + ) + baseline_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + current_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + change_report: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + interval_hours: Mapped[int] = mapped_column( + Integer, server_default="24", nullable=False, + ) + last_checked_at: Mapped[datetime | None] = mapped_column(nullable=True) + next_check_at: Mapped[datetime | None] = mapped_column(nullable=True) + created_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) + + brand: Mapped["Brand"] = relationship("Brand") + user: Mapped["User"] = relationship("User") + + __table_args__ = ( + Index("idx_monitoring_records_user_id", "user_id"), + Index("idx_monitoring_records_brand_id", "brand_id"), + Index("idx_monitoring_records_status", "status"), + Index("idx_monitoring_records_next_check_at", "next_check_at"), + ) + + +class ContentBaseline(Base): + __tablename__ = "content_baselines" + + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + brand_id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + ForeignKey("brands.id", ondelete="CASCADE"), + nullable=False, + ) + content_id: Mapped[uuid.UUID | None] 
= mapped_column( + Uuid(as_uuid=True), + ForeignKey("contents.id", ondelete="SET NULL"), + nullable=True, + ) + citation_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + positive_ratio: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) + avg_rank: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) + platform_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + recorded_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + created_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + + brand: Mapped["Brand"] = relationship("Brand") + + __table_args__ = ( + Index("idx_content_baselines_brand_id", "brand_id"), + Index("idx_content_baselines_content_id", "content_id"), + ) diff --git a/backend/app/models/payment_order.py b/backend/app/models/payment_order.py new file mode 100644 index 0000000..22e6905 --- /dev/null +++ b/backend/app/models/payment_order.py @@ -0,0 +1,40 @@ +import uuid +from datetime import datetime + +from sqlalchemy import String, Float, DateTime, ForeignKey, func, Uuid +from sqlalchemy.orm import Mapped, mapped_column + +from app.database import Base, JSONType + + +class PaymentOrder(Base): + __tablename__ = "payment_orders" + + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + user_id: Mapped[str] = mapped_column( + String(36), + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ) + plan: Mapped[str] = mapped_column(String(20), nullable=False) + amount: Mapped[float] = mapped_column(Float, nullable=False) + currency: Mapped[str] = mapped_column(String(10), default="CNY") + payment_provider: Mapped[str] = mapped_column(String(20), nullable=False) + payment_id: Mapped[str | None] = mapped_column(String(255), nullable=True) + status: Mapped[str] = mapped_column(String(20), default="pending") + pay_url: Mapped[str | None] = 
mapped_column(String(1024), nullable=True) + callback_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + created_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) + paid_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) diff --git a/backend/app/models/schema_suggestion.py b/backend/app/models/schema_suggestion.py new file mode 100644 index 0000000..57bebff --- /dev/null +++ b/backend/app/models/schema_suggestion.py @@ -0,0 +1,74 @@ +import uuid +from datetime import datetime + +from sqlalchemy import String, Text, DateTime, ForeignKey, Index, func, Float +from sqlalchemy import Uuid +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.database import Base, JSONType + + +class SchemaSuggestion(Base): + __tablename__ = "schema_suggestions" + + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + brand_id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + ForeignKey("brands.id", ondelete="CASCADE"), + nullable=False, + ) + schema_type: Mapped[str] = mapped_column( + String(50), nullable=False, + ) + target_url: Mapped[str | None] = mapped_column( + String(500), nullable=True, + ) + json_ld_template: Mapped[dict] = mapped_column( + JSONType, nullable=False, default=dict, + ) + json_ld_filled: Mapped[dict | None] = mapped_column( + JSONType, nullable=True, + ) + priority: Mapped[str] = mapped_column( + String(20), nullable=False, default="medium", + ) + status: Mapped[str] = mapped_column( + String(20), nullable=False, default="pending", + ) + diagnosis_dimensions: Mapped[dict | None] = mapped_column( + JSONType, nullable=True, + ) + implementation_difficulty: Mapped[str] = mapped_column( + String(20), nullable=False, default="medium", + ) + estimated_impact: Mapped[str | None] = 
mapped_column( + Text, nullable=True, + ) + validation_errors: Mapped[dict | None] = mapped_column( + JSONType, nullable=True, + ) + created_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) + + brand: Mapped["Brand"] = relationship("Brand", back_populates="schema_suggestions") + + __table_args__ = ( + Index("idx_schema_suggestions_brand_id", "brand_id"), + Index("idx_schema_suggestions_status", "status"), + Index("idx_schema_suggestions_schema_type", "schema_type"), + Index("idx_schema_suggestions_brand_status", "brand_id", "status"), + ) + + +from app.models.brand import Brand # noqa: E402, F401 diff --git a/backend/app/models/trend_insight.py b/backend/app/models/trend_insight.py new file mode 100644 index 0000000..bec9fa8 --- /dev/null +++ b/backend/app/models/trend_insight.py @@ -0,0 +1,54 @@ +import uuid +from datetime import datetime + +from sqlalchemy import String, Integer, Float, ForeignKey, Index, func, Text +from sqlalchemy import Uuid, JSON +from sqlalchemy.orm import Mapped, mapped_column + +from app.database import Base + + +class TrendInsight(Base): + __tablename__ = "trend_insights" + + id: Mapped[uuid.UUID] = mapped_column( + Uuid(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + brand_id: Mapped[str] = mapped_column( + String(36), + ForeignKey("brands.id", ondelete="CASCADE"), + nullable=False, + ) + trend_type: Mapped[str] = mapped_column(String(20), nullable=False) + keyword: Mapped[str | None] = mapped_column(String(200), nullable=True) + platform: Mapped[str | None] = mapped_column(String(50), nullable=True) + period_start: Mapped[datetime] = mapped_column(nullable=False) + period_end: Mapped[datetime] = mapped_column(nullable=False) + data_points: Mapped[list | None] = mapped_column(JSON, nullable=True) + change_rate: Mapped[float | None] = mapped_column(Float, 
nullable=True) + absolute_change: Mapped[int | None] = mapped_column(Integer, nullable=True) + sentiment_trend: Mapped[dict | None] = mapped_column(JSON, nullable=True) + cause_analysis: Mapped[str | None] = mapped_column(Text, nullable=True) + recommendations: Mapped[list | None] = mapped_column(JSON, nullable=True) + confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.5) + severity: Mapped[str] = mapped_column( + String(20), nullable=False, server_default="info", + ) + created_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) + + __table_args__ = ( + Index("idx_trend_insights_brand_id", "brand_id"), + Index("idx_trend_insights_trend_type", "trend_type"), + Index("idx_trend_insights_created_at", "created_at"), + Index("idx_trend_insights_period_start", "period_start"), + ) diff --git a/backend/app/monitoring/__init__.py b/backend/app/monitoring/__init__.py deleted file mode 100644 index 79b529b..0000000 --- a/backend/app/monitoring/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""监控模块""" -import os - -from app.monitoring.metrics import * -from app.monitoring.middleware import MonitoringMiddleware -from app.monitoring.agent_hooks import agent_execution_context, record_agent_execution -from app.monitoring.llm_metrics import get_llm_metrics, LLMMetricsWrapper - -# 设置服务信息 -SERVICE_INFO.info({ - "version": "1.0.0", - "environment": os.getenv("ENVIRONMENT", "development"), -}) diff --git a/backend/app/monitoring/middleware.py b/backend/app/monitoring/middleware.py deleted file mode 100644 index eec355b..0000000 --- a/backend/app/monitoring/middleware.py +++ /dev/null @@ -1,86 +0,0 @@ -"""监控中间件""" -import time -from typing import Callable - -from fastapi import Request, Response -from starlette.middleware.base import BaseHTTPMiddleware - -from app.monitoring.metrics import ( - 
API_REQUESTS_TOTAL, - API_REQUEST_DURATION_SECONDS, - API_REQUESTS_IN_PROGRESS, -) - -# 需要排除的路径(不记录指标) -EXCLUDED_PATHS = {"/health", "/ready", "/metrics", "/docs", "/openapi.json"} - - -class MonitoringMiddleware(BaseHTTPMiddleware): - """API监控中间件""" - - async def dispatch(self, request: Request, call_next: Callable) -> Response: - # 跳过排除路径 - if request.url.path in EXCLUDED_PATHS: - return await call_next(request) - - # 提取端点标识(用于指标标签) - endpoint = self._get_endpoint_label(request) - - # 增加活跃请求计数 - API_REQUESTS_IN_PROGRESS.labels( - method=request.method, - endpoint=endpoint - ).inc() - - # 记录开始时间 - start_time = time.perf_counter() - - try: - # 执行请求 - response = await call_next(request) - status_code = response.status_code - except Exception as e: - status_code = 500 - raise - finally: - # 计算耗时 - duration = time.perf_counter() - start_time - - # 记录指标 - API_REQUESTS_TOTAL.labels( - method=request.method, - endpoint=endpoint, - status=str(status_code) - ).inc() - - API_REQUEST_DURATION_SECONDS.labels( - method=request.method, - endpoint=endpoint - ).observe(duration) - - # 减少活跃请求计数 - API_REQUESTS_IN_PROGRESS.labels( - method=request.method, - endpoint=endpoint - ).dec() - - return response - - def _get_endpoint_label(self, request: Request) -> str: - """提取端点标签""" - path = request.url.path - - # 规范化路径(替换ID等参数) - parts = path.strip("/").split("/") - - # 处理常见模式:/api/v1/resources/{id} - if len(parts) >= 4 and parts[0] == "api": - resource = parts[2] if len(parts) > 2 else "unknown" - action = parts[3] if len(parts) > 3 else "list" - - # 映射到规范标签 - if action.isdigit(): - return f"{resource}_detail" - return f"{resource}_{action}" - - return "other" diff --git a/backend/app/repositories/__init__.py b/backend/app/repositories/__init__.py index 8354432..42f05dc 100644 --- a/backend/app/repositories/__init__.py +++ b/backend/app/repositories/__init__.py @@ -1,7 +1,31 @@ from app.repositories.api_key_repository import APIKeyRepository from app.repositories.usage_repository 
import UsageRepository +from app.repositories.brand_repository import BrandRepository +from app.repositories.query_repository import QueryRepository +from app.repositories.citation_repository import CitationRepository +from app.repositories.content_repository import ContentRepository +from app.repositories.knowledge_repository import KnowledgeRepository +from app.repositories.alert_repository import AlertRepository +from app.repositories.subscription_repository import SubscriptionRepository +from app.repositories.organization_repository import OrganizationRepository +from app.repositories.user_repository import UserRepository +from app.repositories.detection_task_repository import DetectionTaskRepository +from app.repositories.suggestion_repository import SuggestionRepository +from app.repositories.competitor_repository import CompetitorRepository __all__ = [ "APIKeyRepository", "UsageRepository", + "BrandRepository", + "QueryRepository", + "CitationRepository", + "ContentRepository", + "KnowledgeRepository", + "AlertRepository", + "SubscriptionRepository", + "OrganizationRepository", + "UserRepository", + "DetectionTaskRepository", + "SuggestionRepository", + "CompetitorRepository", ] diff --git a/backend/app/repositories/alert_repository.py b/backend/app/repositories/alert_repository.py new file mode 100644 index 0000000..f5eaf64 --- /dev/null +++ b/backend/app/repositories/alert_repository.py @@ -0,0 +1,77 @@ +import uuid +from typing import Optional + +from sqlalchemy import select, func, update +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.alert import Alert + + +class AlertRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def get_by_id(self, id: uuid.UUID) -> Optional[Alert]: + result = await self.session.execute( + select(Alert).where(Alert.id == id) + ) + return result.scalar_one_or_none() + + async def list_by_user( + self, user_id: uuid.UUID, *, skip: int = 0, limit: int = 100 + ) -> 
list[Alert]: + result = await self.session.execute( + select(Alert) + .where(Alert.user_id == user_id) + .order_by(Alert.created_at.desc()) + .offset(skip) + .limit(limit) + ) + return list(result.scalars().all()) + + async def count_by_user(self, user_id: uuid.UUID) -> int: + result = await self.session.execute( + select(func.count()).select_from(Alert).where(Alert.user_id == user_id) + ) + return result.scalar_one() + + async def get_unread_by_user(self, user_id: uuid.UUID) -> list[Alert]: + result = await self.session.execute( + select(Alert).where( + Alert.user_id == user_id, + Alert.is_read == False, + ).order_by(Alert.created_at.desc()) + ) + return list(result.scalars().all()) + + async def mark_as_read(self, alert_id: uuid.UUID) -> bool: + instance = await self.get_by_id(alert_id) + if instance is None: + return False + instance.is_read = True + await self.session.flush() + return True + + async def create(self, **kwargs) -> Alert: + instance = Alert(**kwargs) + self.session.add(instance) + await self.session.flush() + return instance + + async def update(self, id: uuid.UUID, **kwargs) -> Optional[Alert]: + instance = await self.get_by_id(id) + if instance is None: + return None + for key, value in kwargs.items(): + if hasattr(instance, key): + setattr(instance, key, value) + await self.session.flush() + return instance + + async def delete(self, id: uuid.UUID) -> bool: + instance = await self.get_by_id(id) + if instance is None: + return False + await self.session.delete(instance) + await self.session.flush() + return True diff --git a/backend/app/repositories/brand_repository.py b/backend/app/repositories/brand_repository.py new file mode 100644 index 0000000..455c203 --- /dev/null +++ b/backend/app/repositories/brand_repository.py @@ -0,0 +1,71 @@ +import uuid +from typing import Optional + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.brand import Brand + + +class BrandRepository: + def 
__init__(self, session: AsyncSession): + self.session = session + + async def get_by_id(self, id: uuid.UUID) -> Optional[Brand]: + result = await self.session.execute( + select(Brand).where(Brand.id == id) + ) + return result.scalar_one_or_none() + + async def list_by_user( + self, user_id: uuid.UUID, *, skip: int = 0, limit: int = 100 + ) -> list[Brand]: + result = await self.session.execute( + select(Brand) + .where(Brand.user_id == user_id) + .order_by(Brand.created_at.desc()) + .offset(skip) + .limit(limit) + ) + return list(result.scalars().all()) + + async def count_by_user(self, user_id: uuid.UUID) -> int: + result = await self.session.execute( + select(func.count()).select_from(Brand).where(Brand.user_id == user_id) + ) + return result.scalar_one() + + async def get_by_name_and_organization( + self, name: str, organization_id: uuid.UUID + ) -> Optional[Brand]: + result = await self.session.execute( + select(Brand).where( + Brand.name == name, + Brand.user_id == organization_id, + ) + ) + return result.scalar_one_or_none() + + async def create(self, **kwargs) -> Brand: + instance = Brand(**kwargs) + self.session.add(instance) + await self.session.flush() + return instance + + async def update(self, id: uuid.UUID, **kwargs) -> Optional[Brand]: + instance = await self.get_by_id(id) + if instance is None: + return None + for key, value in kwargs.items(): + if hasattr(instance, key): + setattr(instance, key, value) + await self.session.flush() + return instance + + async def delete(self, id: uuid.UUID) -> bool: + instance = await self.get_by_id(id) + if instance is None: + return False + await self.session.delete(instance) + await self.session.flush() + return True diff --git a/backend/app/repositories/citation_repository.py b/backend/app/repositories/citation_repository.py new file mode 100644 index 0000000..cecc66f --- /dev/null +++ b/backend/app/repositories/citation_repository.py @@ -0,0 +1,84 @@ +import uuid +from typing import Optional + +from sqlalchemy 
import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.citation_record import CitationRecord + + +class CitationRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def get_by_id(self, id: uuid.UUID) -> Optional[CitationRecord]: + result = await self.session.execute( + select(CitationRecord).where(CitationRecord.id == id) + ) + return result.scalar_one_or_none() + + async def list_by_query( + self, query_id: uuid.UUID, *, skip: int = 0, limit: int = 100 + ) -> list[CitationRecord]: + result = await self.session.execute( + select(CitationRecord) + .where(CitationRecord.query_id == query_id) + .order_by(CitationRecord.queried_at.desc()) + .offset(skip) + .limit(limit) + ) + return list(result.scalars().all()) + + async def count_by_query(self, query_id: uuid.UUID) -> int: + result = await self.session.execute( + select(func.count()).select_from(CitationRecord).where( + CitationRecord.query_id == query_id + ) + ) + return result.scalar_one() + + async def get_by_query_and_platform( + self, query_id: uuid.UUID, platform: str + ) -> Optional[CitationRecord]: + result = await self.session.execute( + select(CitationRecord).where( + CitationRecord.query_id == query_id, + CitationRecord.platform == platform, + ) + ) + return result.scalar_one_or_none() + + async def count_cited_by_brand(self, brand_name: str) -> int: + result = await self.session.execute( + select(func.count()) + .select_from(CitationRecord) + .join(CitationRecord.query) + .where( + CitationRecord.cited == True, + ) + ) + return result.scalar_one() + + async def create(self, **kwargs) -> CitationRecord: + instance = CitationRecord(**kwargs) + self.session.add(instance) + await self.session.flush() + return instance + + async def update(self, id: uuid.UUID, **kwargs) -> Optional[CitationRecord]: + instance = await self.get_by_id(id) + if instance is None: + return None + for key, value in kwargs.items(): + if hasattr(instance, key): + 
setattr(instance, key, value) + await self.session.flush() + return instance + + async def delete(self, id: uuid.UUID) -> bool: + instance = await self.get_by_id(id) + if instance is None: + return False + await self.session.delete(instance) + await self.session.flush() + return True diff --git a/backend/app/repositories/competitor_repository.py b/backend/app/repositories/competitor_repository.py new file mode 100644 index 0000000..3436569 --- /dev/null +++ b/backend/app/repositories/competitor_repository.py @@ -0,0 +1,71 @@ +import uuid +from typing import Optional + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.competitor import Competitor + + +class CompetitorRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def get_by_id(self, id: uuid.UUID) -> Optional[Competitor]: + result = await self.session.execute( + select(Competitor).where(Competitor.id == id) + ) + return result.scalar_one_or_none() + + async def list_by_brand( + self, brand_id: uuid.UUID, *, skip: int = 0, limit: int = 100 + ) -> list[Competitor]: + result = await self.session.execute( + select(Competitor) + .where(Competitor.brand_id == brand_id) + .order_by(Competitor.created_at.desc()) + .offset(skip) + .limit(limit) + ) + return list(result.scalars().all()) + + async def count_by_brand(self, brand_id: uuid.UUID) -> int: + result = await self.session.execute( + select(func.count()).select_from(Competitor).where( + Competitor.brand_id == brand_id + ) + ) + return result.scalar_one() + + async def get_by_brand(self, brand_name: str) -> list[Competitor]: + from app.models.brand import Brand + result = await self.session.execute( + select(Competitor) + .join(Brand, Competitor.brand_id == Brand.id) + .where(Brand.name == brand_name) + ) + return list(result.scalars().all()) + + async def create(self, **kwargs) -> Competitor: + instance = Competitor(**kwargs) + self.session.add(instance) + await 
self.session.flush() + return instance + + async def update(self, id: uuid.UUID, **kwargs) -> Optional[Competitor]: + instance = await self.get_by_id(id) + if instance is None: + return None + for key, value in kwargs.items(): + if hasattr(instance, key): + setattr(instance, key, value) + await self.session.flush() + return instance + + async def delete(self, id: uuid.UUID) -> bool: + instance = await self.get_by_id(id) + if instance is None: + return False + await self.session.delete(instance) + await self.session.flush() + return True diff --git a/backend/app/repositories/content_repository.py b/backend/app/repositories/content_repository.py new file mode 100644 index 0000000..369ddc5 --- /dev/null +++ b/backend/app/repositories/content_repository.py @@ -0,0 +1,68 @@ +import uuid +from typing import Optional + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.content import Content + + +class ContentRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def get_by_id(self, id: uuid.UUID) -> Optional[Content]: + result = await self.session.execute( + select(Content).where(Content.id == id) + ) + return result.scalar_one_or_none() + + async def list_by_organization( + self, organization_id: uuid.UUID, *, skip: int = 0, limit: int = 100 + ) -> list[Content]: + result = await self.session.execute( + select(Content) + .where(Content.organization_id == organization_id) + .order_by(Content.created_at.desc()) + .offset(skip) + .limit(limit) + ) + return list(result.scalars().all()) + + async def count_by_organization(self, organization_id: uuid.UUID) -> int: + result = await self.session.execute( + select(func.count()).select_from(Content).where( + Content.organization_id == organization_id + ) + ) + return result.scalar_one() + + async def get_by_status(self, status: str) -> list[Content]: + result = await self.session.execute( + select(Content).where(Content.status == status) + 
) + return list(result.scalars().all()) + + async def create(self, **kwargs) -> Content: + instance = Content(**kwargs) + self.session.add(instance) + await self.session.flush() + return instance + + async def update(self, id: uuid.UUID, **kwargs) -> Optional[Content]: + instance = await self.get_by_id(id) + if instance is None: + return None + for key, value in kwargs.items(): + if hasattr(instance, key): + setattr(instance, key, value) + await self.session.flush() + return instance + + async def delete(self, id: uuid.UUID) -> bool: + instance = await self.get_by_id(id) + if instance is None: + return False + await self.session.delete(instance) + await self.session.flush() + return True diff --git a/backend/app/repositories/detection_task_repository.py b/backend/app/repositories/detection_task_repository.py new file mode 100644 index 0000000..4f89293 --- /dev/null +++ b/backend/app/repositories/detection_task_repository.py @@ -0,0 +1,68 @@ +import uuid +from typing import Optional + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.detection_task import DetectionTask + + +class DetectionTaskRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def get_by_id(self, id: uuid.UUID) -> Optional[DetectionTask]: + result = await self.session.execute( + select(DetectionTask).where(DetectionTask.id == id) + ) + return result.scalar_one_or_none() + + async def list_by_user( + self, user_id: uuid.UUID, *, skip: int = 0, limit: int = 100 + ) -> list[DetectionTask]: + result = await self.session.execute( + select(DetectionTask) + .where(DetectionTask.user_id == user_id) + .order_by(DetectionTask.created_at.desc()) + .offset(skip) + .limit(limit) + ) + return list(result.scalars().all()) + + async def count_by_user(self, user_id: uuid.UUID) -> int: + result = await self.session.execute( + select(func.count()).select_from(DetectionTask).where( + DetectionTask.user_id == user_id + ) + ) + 
return result.scalar_one() + + async def get_active_tasks(self) -> list[DetectionTask]: + result = await self.session.execute( + select(DetectionTask).where(DetectionTask.is_active == True) + ) + return list(result.scalars().all()) + + async def create(self, **kwargs) -> DetectionTask: + instance = DetectionTask(**kwargs) + self.session.add(instance) + await self.session.flush() + return instance + + async def update(self, id: uuid.UUID, **kwargs) -> Optional[DetectionTask]: + instance = await self.get_by_id(id) + if instance is None: + return None + for key, value in kwargs.items(): + if hasattr(instance, key): + setattr(instance, key, value) + await self.session.flush() + return instance + + async def delete(self, id: uuid.UUID) -> bool: + instance = await self.get_by_id(id) + if instance is None: + return False + await self.session.delete(instance) + await self.session.flush() + return True diff --git a/backend/app/repositories/knowledge_repository.py b/backend/app/repositories/knowledge_repository.py new file mode 100644 index 0000000..7eea12f --- /dev/null +++ b/backend/app/repositories/knowledge_repository.py @@ -0,0 +1,70 @@ +import uuid +from typing import Optional + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.knowledge import KnowledgeBase + + +class KnowledgeRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def get_by_id(self, id: uuid.UUID) -> Optional[KnowledgeBase]: + result = await self.session.execute( + select(KnowledgeBase).where(KnowledgeBase.id == id) + ) + return result.scalar_one_or_none() + + async def list_by_organization( + self, organization_id: uuid.UUID, *, skip: int = 0, limit: int = 100 + ) -> list[KnowledgeBase]: + result = await self.session.execute( + select(KnowledgeBase) + .where(KnowledgeBase.organization_id == organization_id) + .order_by(KnowledgeBase.created_at.desc()) + .offset(skip) + .limit(limit) + ) + return 
list(result.scalars().all()) + + async def count_by_organization(self, organization_id: uuid.UUID) -> int: + result = await self.session.execute( + select(func.count()).select_from(KnowledgeBase).where( + KnowledgeBase.organization_id == organization_id + ) + ) + return result.scalar_one() + + async def get_by_organization(self, organization_id: uuid.UUID) -> list[KnowledgeBase]: + result = await self.session.execute( + select(KnowledgeBase).where( + KnowledgeBase.organization_id == organization_id + ) + ) + return list(result.scalars().all()) + + async def create(self, **kwargs) -> KnowledgeBase: + instance = KnowledgeBase(**kwargs) + self.session.add(instance) + await self.session.flush() + return instance + + async def update(self, id: uuid.UUID, **kwargs) -> Optional[KnowledgeBase]: + instance = await self.get_by_id(id) + if instance is None: + return None + for key, value in kwargs.items(): + if hasattr(instance, key): + setattr(instance, key, value) + await self.session.flush() + return instance + + async def delete(self, id: uuid.UUID) -> bool: + instance = await self.get_by_id(id) + if instance is None: + return False + await self.session.delete(instance) + await self.session.flush() + return True diff --git a/backend/app/repositories/organization_repository.py b/backend/app/repositories/organization_repository.py new file mode 100644 index 0000000..22f81ab --- /dev/null +++ b/backend/app/repositories/organization_repository.py @@ -0,0 +1,75 @@ +import uuid +from typing import Optional + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.organization import Organization + + +class OrganizationRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def get_by_id(self, id: uuid.UUID) -> Optional[Organization]: + result = await self.session.execute( + select(Organization).where(Organization.id == id) + ) + return result.scalar_one_or_none() + + async def get_by_slug(self, 
slug: str) -> Optional[Organization]: + result = await self.session.execute( + select(Organization).where(Organization.slug == slug) + ) + return result.scalar_one_or_none() + + async def list_all(self, *, skip: int = 0, limit: int = 100) -> list[Organization]: + result = await self.session.execute( + select(Organization) + .order_by(Organization.created_at.desc()) + .offset(skip) + .limit(limit) + ) + return list(result.scalars().all()) + + async def count_all(self) -> int: + result = await self.session.execute( + select(func.count()).select_from(Organization) + ) + return result.scalar_one() + + async def get_by_owner(self, user_id: str) -> list[Organization]: + from app.models.organization import OrgMember + result = await self.session.execute( + select(Organization) + .join(OrgMember, OrgMember.organization_id == Organization.id) + .where( + OrgMember.user_id == user_id, + OrgMember.role == "owner", + ) + ) + return list(result.scalars().all()) + + async def create(self, **kwargs) -> Organization: + instance = Organization(**kwargs) + self.session.add(instance) + await self.session.flush() + return instance + + async def update(self, id: uuid.UUID, **kwargs) -> Optional[Organization]: + instance = await self.get_by_id(id) + if instance is None: + return None + for key, value in kwargs.items(): + if hasattr(instance, key): + setattr(instance, key, value) + await self.session.flush() + return instance + + async def delete(self, id: uuid.UUID) -> bool: + instance = await self.get_by_id(id) + if instance is None: + return False + await self.session.delete(instance) + await self.session.flush() + return True diff --git a/backend/app/repositories/query_repository.py b/backend/app/repositories/query_repository.py new file mode 100644 index 0000000..aca59b1 --- /dev/null +++ b/backend/app/repositories/query_repository.py @@ -0,0 +1,72 @@ +import uuid +from typing import Optional + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + 
+from app.models.query import Query + + +class QueryRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def get_by_id(self, id: uuid.UUID) -> Optional[Query]: + result = await self.session.execute( + select(Query).where(Query.id == id) + ) + return result.scalar_one_or_none() + + async def list_by_user( + self, user_id: str, *, skip: int = 0, limit: int = 100 + ) -> list[Query]: + result = await self.session.execute( + select(Query) + .where(Query.user_id == user_id) + .order_by(Query.created_at.desc()) + .offset(skip) + .limit(limit) + ) + return list(result.scalars().all()) + + async def count_by_user(self, user_id: str) -> int: + result = await self.session.execute( + select(func.count()).select_from(Query).where(Query.user_id == user_id) + ) + return result.scalar_one() + + async def get_by_brand(self, brand_name: str) -> list[Query]: + result = await self.session.execute( + select(Query).where(Query.target_brand == brand_name) + ) + return list(result.scalars().all()) + + async def get_active_queries(self) -> list[Query]: + result = await self.session.execute( + select(Query).where(Query.status == "active") + ) + return list(result.scalars().all()) + + async def create(self, **kwargs) -> Query: + instance = Query(**kwargs) + self.session.add(instance) + await self.session.flush() + return instance + + async def update(self, id: uuid.UUID, **kwargs) -> Optional[Query]: + instance = await self.get_by_id(id) + if instance is None: + return None + for key, value in kwargs.items(): + if hasattr(instance, key): + setattr(instance, key, value) + await self.session.flush() + return instance + + async def delete(self, id: uuid.UUID) -> bool: + instance = await self.get_by_id(id) + if instance is None: + return False + await self.session.delete(instance) + await self.session.flush() + return True diff --git a/backend/app/repositories/subscription_repository.py b/backend/app/repositories/subscription_repository.py new file mode 
100644 index 0000000..540dcfd --- /dev/null +++ b/backend/app/repositories/subscription_repository.py @@ -0,0 +1,70 @@ +import uuid +from typing import Optional + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.subscription import Subscription + + +class SubscriptionRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def get_by_id(self, id: uuid.UUID) -> Optional[Subscription]: + result = await self.session.execute( + select(Subscription).where(Subscription.id == id) + ) + return result.scalar_one_or_none() + + async def list_by_user( + self, user_id: str, *, skip: int = 0, limit: int = 100 + ) -> list[Subscription]: + result = await self.session.execute( + select(Subscription) + .where(Subscription.user_id == user_id) + .order_by(Subscription.created_at.desc()) + .offset(skip) + .limit(limit) + ) + return list(result.scalars().all()) + + async def count_by_user(self, user_id: str) -> int: + result = await self.session.execute( + select(func.count()).select_from(Subscription).where( + Subscription.user_id == user_id + ) + ) + return result.scalar_one() + + async def get_by_organization(self, organization_id: uuid.UUID) -> list[Subscription]: + result = await self.session.execute( + select(Subscription).where( + Subscription.user_id == organization_id + ) + ) + return list(result.scalars().all()) + + async def create(self, **kwargs) -> Subscription: + instance = Subscription(**kwargs) + self.session.add(instance) + await self.session.flush() + return instance + + async def update(self, id: uuid.UUID, **kwargs) -> Optional[Subscription]: + instance = await self.get_by_id(id) + if instance is None: + return None + for key, value in kwargs.items(): + if hasattr(instance, key): + setattr(instance, key, value) + await self.session.flush() + return instance + + async def delete(self, id: uuid.UUID) -> bool: + instance = await self.get_by_id(id) + if instance is None: + return 
False + await self.session.delete(instance) + await self.session.flush() + return True diff --git a/backend/app/repositories/suggestion_repository.py b/backend/app/repositories/suggestion_repository.py new file mode 100644 index 0000000..1e9e4e8 --- /dev/null +++ b/backend/app/repositories/suggestion_repository.py @@ -0,0 +1,71 @@ +import uuid +from typing import Optional + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.suggestion import Suggestion + + +class SuggestionRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def get_by_id(self, id: uuid.UUID) -> Optional[Suggestion]: + result = await self.session.execute( + select(Suggestion).where(Suggestion.id == id) + ) + return result.scalar_one_or_none() + + async def list_by_brand( + self, brand_id: uuid.UUID, *, skip: int = 0, limit: int = 100 + ) -> list[Suggestion]: + result = await self.session.execute( + select(Suggestion) + .where(Suggestion.brand_id == brand_id) + .order_by(Suggestion.generated_at.desc()) + .offset(skip) + .limit(limit) + ) + return list(result.scalars().all()) + + async def count_by_brand(self, brand_id: uuid.UUID) -> int: + result = await self.session.execute( + select(func.count()).select_from(Suggestion).where( + Suggestion.brand_id == brand_id + ) + ) + return result.scalar_one() + + async def get_by_brand(self, brand_name: str) -> list[Suggestion]: + from app.models.brand import Brand + result = await self.session.execute( + select(Suggestion) + .join(Brand, Suggestion.brand_id == Brand.id) + .where(Brand.name == brand_name) + ) + return list(result.scalars().all()) + + async def create(self, **kwargs) -> Suggestion: + instance = Suggestion(**kwargs) + self.session.add(instance) + await self.session.flush() + return instance + + async def update(self, id: uuid.UUID, **kwargs) -> Optional[Suggestion]: + instance = await self.get_by_id(id) + if instance is None: + return None + for key, 
value in kwargs.items(): + if hasattr(instance, key): + setattr(instance, key, value) + await self.session.flush() + return instance + + async def delete(self, id: uuid.UUID) -> bool: + instance = await self.get_by_id(id) + if instance is None: + return False + await self.session.delete(instance) + await self.session.flush() + return True diff --git a/backend/app/repositories/user_repository.py b/backend/app/repositories/user_repository.py new file mode 100644 index 0000000..cdefc93 --- /dev/null +++ b/backend/app/repositories/user_repository.py @@ -0,0 +1,62 @@ +from typing import Optional + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.user import User + + +class UserRepository: + def __init__(self, session: AsyncSession): + self.session = session + + async def get_by_id(self, id: str) -> Optional[User]: + result = await self.session.execute( + select(User).where(User.id == id) + ) + return result.scalar_one_or_none() + + async def get_by_email(self, email: str) -> Optional[User]: + result = await self.session.execute( + select(User).where(User.email == email) + ) + return result.scalar_one_or_none() + + async def list_all(self, *, skip: int = 0, limit: int = 100) -> list[User]: + result = await self.session.execute( + select(User) + .order_by(User.createdAt.desc()) + .offset(skip) + .limit(limit) + ) + return list(result.scalars().all()) + + async def count_all(self) -> int: + result = await self.session.execute( + select(func.count()).select_from(User) + ) + return result.scalar_one() + + async def create(self, **kwargs) -> User: + instance = User(**kwargs) + self.session.add(instance) + await self.session.flush() + return instance + + async def update(self, id: str, **kwargs) -> Optional[User]: + instance = await self.get_by_id(id) + if instance is None: + return None + for key, value in kwargs.items(): + if hasattr(instance, key): + setattr(instance, key, value) + await self.session.flush() + return 
instance + + async def delete(self, id: str) -> bool: + instance = await self.get_by_id(id) + if instance is None: + return False + await self.session.delete(instance) + await self.session.flush() + return True diff --git a/backend/app/schemas/competitor_insight.py b/backend/app/schemas/competitor_insight.py new file mode 100644 index 0000000..80c8308 --- /dev/null +++ b/backend/app/schemas/competitor_insight.py @@ -0,0 +1,48 @@ +import uuid +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel, Field + + +class CompetitorAnalysisRequest(BaseModel): + brand_id: uuid.UUID = Field(..., description="品牌ID") + analysis_types: list[str] | None = Field( + None, + description="分析类型列表: citation_gap/content_strategy/platform_coverage/query_overlap/differentiation", + ) + period_days: int | None = Field(30, description="分析周期天数") + + +class CompetitorInsightResponse(BaseModel): + id: uuid.UUID = Field(..., description="洞察ID") + brand_id: uuid.UUID = Field(..., description="品牌ID") + competitor_name: str = Field(..., description="竞品名称") + analysis_type: str = Field(..., description="分析类型") + insight_data: dict | None = Field(None, description="洞察数据") + citation_count_brand: int = Field(0, description="品牌引用次数") + citation_count_competitor: int = Field(0, description="竞品引用次数") + sentiment_brand: float | None = Field(None, description="品牌情感分数") + sentiment_competitor: float | None = Field(None, description="竞品情感分数") + platform_breakdown: dict | None = Field(None, description="平台分布") + gap_analysis: dict | None = Field(None, description="差距分析") + opportunity_areas: dict | None = Field(None, description="机会领域") + recommendations: dict | None = Field(None, description="策略建议") + confidence: str = Field("medium", description="置信度: high/medium/low") + period_days: int = Field(30, description="分析周期天数") + created_at: datetime = Field(..., description="创建时间") + updated_at: datetime = Field(..., description="更新时间") + + model_config = 
{"from_attributes": True} + + +class CompetitorInsightList(BaseModel): + items: list[CompetitorInsightResponse] = Field(default_factory=list, description="洞察列表") + total: int = Field(0, description="总数") + + +class CompetitorGapSummary(BaseModel): + brand_name: str = Field(..., description="品牌名称") + competitor_name: str = Field(..., description="竞品名称") + gap_dimensions: list[dict] = Field(default_factory=list, description="差距维度列表") + overall_gap_score: float = Field(0.0, description="综合差距分数(0-100)") diff --git a/backend/app/schemas/diagnosis.py b/backend/app/schemas/diagnosis.py new file mode 100644 index 0000000..05ec826 --- /dev/null +++ b/backend/app/schemas/diagnosis.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class GEODiagnosisTriggerRequest(BaseModel): + force_refresh: bool = Field(default=False) + + +class GEODiagnosisTaskResponse(BaseModel): + task_id: str + brand_id: str + status: str + + +class GEODimensionItemResponse(BaseModel): + name: str + status: str + description: str + suggestion: str + score: float + max_score: float + + +class GEODimensionResponse(BaseModel): + name: str + score: float + max_score: float + percentage: float + status: str + items: list[GEODimensionItemResponse] + detail: dict + + +class GEORecommendationResponse(BaseModel): + priority: str + dimension: str + title: str + description: str + impact: str + effort: str + + +class GEODiagnosisResponse(BaseModel): + overall_score: float + health_level: str + health_level_label: str + dimensions: list[GEODimensionResponse] + recommendations: list[GEORecommendationResponse] + is_full_report: bool = False + + +class GEODiagnosisResultResponse(BaseModel): + task_id: str + brand_id: str + status: str + result: GEODiagnosisResponse | None = None + error: str | None = None + + +class GEODiagnosisHistoryItem(BaseModel): + task_id: str + overall_score: float + health_level: str + created_at: str + completed_at: str | None = None + + 
+class GEODiagnosisHistoryResponse(BaseModel): + brand_id: str + history: list[GEODiagnosisHistoryItem] diff --git a/backend/app/schemas/geo_plan.py b/backend/app/schemas/geo_plan.py new file mode 100644 index 0000000..f53dea0 --- /dev/null +++ b/backend/app/schemas/geo_plan.py @@ -0,0 +1,69 @@ +import uuid +from datetime import datetime + +from pydantic import BaseModel, Field + + +class GeoPlanGenerateRequest(BaseModel): + brand_id: uuid.UUID = Field(..., description="品牌ID") + target_score: int | None = Field(75, description="目标评分") + + +class GeoPlanActionResponse(BaseModel): + id: uuid.UUID = Field(..., description="行动项ID") + plan_id: uuid.UUID = Field(..., description="方案ID") + action_type: str = Field(..., description="行动类型") + title: str = Field(..., description="行动标题") + description: str = Field(..., description="详细描述") + reason: str = Field(..., description="基于诊断数据的原因") + priority: str = Field(..., description="优先级: high/medium/low") + status: str = Field(..., description="状态: pending/in_progress/completed/skipped") + target_keyword: str | None = Field(None, description="预填关键词") + target_platform: str | None = Field(None, description="预填平台") + content_style: str | None = Field(None, description="预填风格") + estimated_impact: str | None = Field(None, description="预期效果") + difficulty: str = Field(..., description="难度: easy/medium/hard") + execution_params: dict | None = Field(None, description="一键执行参数") + sort_order: int = Field(0, description="排序序号") + completed_at: datetime | None = Field(None, description="完成时间") + created_at: datetime = Field(..., description="创建时间") + + class Config: + from_attributes = True + + +class GeoPlanResponse(BaseModel): + id: uuid.UUID = Field(..., description="方案ID") + brand_id: uuid.UUID = Field(..., description="品牌ID") + title: str = Field(..., description="方案标题") + status: str = Field(..., description="状态: draft/active/completed/archived") + diagnosis_score: int = Field(..., description="诊断评分") + target_score: int = 
Field(..., description="目标评分") + estimated_weeks: int = Field(..., description="预计周数") + plan_data: dict | None = Field(None, description="方案详细数据") + source: str = Field(..., description="生成来源: rule/llm") + actions: list[GeoPlanActionResponse] = Field( + default_factory=list, description="行动项列表" + ) + created_at: datetime = Field(..., description="创建时间") + updated_at: datetime = Field(..., description="更新时间") + + class Config: + from_attributes = True + + +class GeoPlanListResponse(BaseModel): + plans: list[GeoPlanResponse] = Field(default_factory=list, description="方案列表") + total: int = Field(..., description="总数") + + +class GeoPlanActionUpdateStatus(BaseModel): + status: str = Field( + ..., description="新状态: pending/in_progress/completed/skipped" + ) + + +class GeoPlanActionExecuteResponse(BaseModel): + action_id: uuid.UUID = Field(..., description="行动项ID") + content_id: str | None = Field(None, description="生成的内容ID") + message: str = Field(..., description="执行结果消息") diff --git a/backend/app/schemas/health_score.py b/backend/app/schemas/health_score.py new file mode 100644 index 0000000..f9d17b1 --- /dev/null +++ b/backend/app/schemas/health_score.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class HealthScoreDimension(BaseModel): + name: str + score: float + max_score: float + percentage: float + status: str + + +class HealthScoreRecommendation(BaseModel): + priority: str + dimension: str + title: str + description: str + + +class HealthScoreResponse(BaseModel): + brand_name: str + overall_score: float + health_level: str + health_level_label: str + dimensions: list[HealthScoreDimension] + recommendations: list[HealthScoreRecommendation] + is_full_report: bool = False + cached: bool = False + + +class HealthScoreRequest(BaseModel): + brand: str + competitors: list[str] = Field(default_factory=list) diff --git a/backend/app/schemas/monitoring.py b/backend/app/schemas/monitoring.py new file mode 100644 index 
0000000..316ccfb --- /dev/null +++ b/backend/app/schemas/monitoring.py @@ -0,0 +1,72 @@ +import uuid +from datetime import datetime + +from pydantic import BaseModel, Field + + +class MonitoringRecordCreate(BaseModel): + brand_id: uuid.UUID = Field(..., description="品牌ID") + content_id: str | None = Field(None, description="内容ID") + query_keywords: str | None = Field(None, description="查询关键词") + platform: str | None = Field(None, description="平台") + check_interval_hours: int = Field(24, description="检测间隔(小时)") + + +class MonitoringRecordResponse(BaseModel): + id: uuid.UUID = Field(..., description="记录ID") + brand_id: uuid.UUID = Field(..., description="品牌ID") + content_id: str | None = Field(None, description="内容ID") + query_keywords: str | None = Field(None, description="查询关键词") + platform: str | None = Field(None, description="平台") + baseline_citation_count: int = Field(0, description="基线引用量") + baseline_sentiment: float | None = Field(None, description="基线情感分数") + baseline_rank: int | None = Field(None, description="基线排名") + current_citation_count: int = Field(0, description="当前引用量") + current_sentiment: float | None = Field(None, description="当前情感分数") + current_rank: int | None = Field(None, description="当前排名") + change_type: str | None = Field(None, description="变化类型: positive/negative/neutral") + change_details: dict | None = Field(None, description="变化详情") + check_interval_hours: int = Field(24, description="检测间隔(小时)") + last_checked_at: datetime | None = Field(None, description="上次检测时间") + next_check_at: datetime | None = Field(None, description="下次检测时间") + status: str = Field("active", description="状态: active/paused/completed") + created_at: datetime = Field(..., description="创建时间") + updated_at: datetime = Field(..., description="更新时间") + + class Config: + from_attributes = True + + +class MonitoringRecordList(BaseModel): + records: list[MonitoringRecordResponse] = Field(default_factory=list, description="监测记录列表") + total: int = Field(..., 
description="总数") + + +class MonitoringChangeReport(BaseModel): + monitoring_record_id: uuid.UUID = Field(..., description="监测记录ID") + brand_id: uuid.UUID = Field(..., description="品牌ID") + change_type: str = Field(..., description="变化类型: positive/negative/neutral") + change_details: dict | None = Field(None, description="变化详情") + baseline: dict = Field(default_factory=dict, description="基线数据") + current: dict = Field(default_factory=dict, description="当前数据") + recommendations: list[str] = Field(default_factory=list, description="建议") + + +class ContentBaselineResponse(BaseModel): + id: uuid.UUID = Field(..., description="基线ID") + monitoring_record_id: uuid.UUID = Field(..., description="监测记录ID") + brand_name: str = Field(..., description="品牌名称") + keyword: str = Field(..., description="关键词") + platform: str = Field(..., description="平台") + citation_count: int = Field(0, description="引用量") + sentiment_score: float | None = Field(None, description="情感分数") + rank_position: int | None = Field(None, description="排名位置") + snapshot_data: dict | None = Field(None, description="快照数据") + recorded_at: datetime = Field(..., description="记录时间") + + class Config: + from_attributes = True + + +class MonitoringStatusUpdate(BaseModel): + status: str = Field(..., description="新状态: active/paused/completed") diff --git a/backend/app/schemas/schema_suggestion.py b/backend/app/schemas/schema_suggestion.py new file mode 100644 index 0000000..711697f --- /dev/null +++ b/backend/app/schemas/schema_suggestion.py @@ -0,0 +1,45 @@ +import uuid +from datetime import datetime + +from pydantic import BaseModel, Field + + +class SchemaAdviseRequest(BaseModel): + brand_id: uuid.UUID = Field(..., description="品牌ID") + target_url: str | None = Field(None, description="目标页面URL") + focus_dimensions: list[str] | None = Field(None, description="聚焦的诊断维度") + + +class SchemaSuggestionResponse(BaseModel): + id: uuid.UUID = Field(..., description="建议ID") + brand_id: uuid.UUID = Field(..., 
description="品牌ID") + schema_type: str = Field(..., description="Schema类型: Organization/Product/FAQPage/Article/LocalBusiness") + target_url: str | None = Field(None, description="目标页面URL") + json_ld_template: dict = Field(..., description="JSON-LD模板") + json_ld_filled: dict | None = Field(None, description="填充后的JSON-LD") + priority: str = Field(default="medium", description="优先级: high/medium/low") + status: str = Field(default="pending", description="状态: pending/applied/dismissed") + diagnosis_dimensions: dict | None = Field(None, description="诊断维度数据") + implementation_difficulty: str = Field(default="medium", description="实施难度: easy/medium/hard") + estimated_impact: str | None = Field(None, description="预期影响描述") + validation_errors: dict | None = Field(None, description="验证错误信息") + created_at: datetime = Field(..., description="创建时间") + updated_at: datetime = Field(..., description="更新时间") + + class Config: + from_attributes = True + + +class SchemaSuggestionList(BaseModel): + suggestions: list[SchemaSuggestionResponse] = Field(default_factory=list, description="建议列表") + total: int = Field(..., description="总数") + + +class SchemaValidationResult(BaseModel): + is_valid: bool = Field(..., description="是否有效") + errors: list[str] = Field(default_factory=list, description="错误列表") + warnings: list[str] = Field(default_factory=list, description="警告列表") + + +class SchemaStatusUpdateRequest(BaseModel): + status: str = Field(..., description="新状态: pending/applied/dismissed") diff --git a/backend/app/schemas/trend_insight.py b/backend/app/schemas/trend_insight.py new file mode 100644 index 0000000..52d9814 --- /dev/null +++ b/backend/app/schemas/trend_insight.py @@ -0,0 +1,48 @@ +import uuid +from datetime import datetime + +from pydantic import BaseModel, Field + + +class TrendInsightRequest(BaseModel): + brand_id: uuid.UUID = Field(..., description="品牌ID") + period_days: int = Field(30, ge=7, le=365, description="分析周期天数") + platforms: list[str] | None = Field(None, 
description="筛选平台列表") + keywords: list[str] | None = Field(None, description="筛选关键词列表") + + +class TrendInsightResponse(BaseModel): + id: uuid.UUID + brand_id: uuid.UUID + trend_type: str + keyword: str | None + platform: str | None + period_start: datetime + period_end: datetime + data_points: list | None + change_rate: float | None + absolute_change: int | None + sentiment_trend: dict | None + cause_analysis: str | None + recommendations: list | None + confidence: float + severity: str + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} + + +class TrendInsightList(BaseModel): + items: list[TrendInsightResponse] + total: int + + +class TrendSummary(BaseModel): + brand_id: uuid.UUID + period_days: int + rising_count: int = 0 + declining_count: int = 0 + hotspot_count: int = 0 + top_keywords: list[str] = Field(default_factory=list) + platform_highlights: dict = Field(default_factory=dict) diff --git a/backend/app/services/advisor/__init__.py b/backend/app/services/advisor/__init__.py new file mode 100644 index 0000000..de56ef1 --- /dev/null +++ b/backend/app/services/advisor/__init__.py @@ -0,0 +1,17 @@ +from .optimization_advisor import ( + generate_suggestions, + generate_rule_based_suggestions, + generate_llm_suggestions, + build_context_from_scoring_result, + SuggestionItem, + BrandAnalysisContext, +) + +__all__ = [ + "generate_suggestions", + "generate_rule_based_suggestions", + "generate_llm_suggestions", + "build_context_from_scoring_result", + "SuggestionItem", + "BrandAnalysisContext", +] diff --git a/backend/app/services/optimization_advisor.py b/backend/app/services/advisor/optimization_advisor.py similarity index 96% rename from backend/app/services/optimization_advisor.py rename to backend/app/services/advisor/optimization_advisor.py index ed36168..bdd600d 100644 --- a/backend/app/services/optimization_advisor.py +++ b/backend/app/services/advisor/optimization_advisor.py @@ -25,7 +25,8 @@ from dataclasses import 
dataclass, field from typing import Any from app.config import settings -from app.services.scoring_service import ScoringResultV2 +from app.services.scoring.scoring_service import ScoringResultV2 +from app.utils.json_extractor import extract_json logger = logging.getLogger(__name__) @@ -525,7 +526,7 @@ async def generate_llm_suggestions( raise ValueError("LLM返回空响应") # 提取JSON - json_str = _extract_json(content) + json_str = extract_json(content) result = json.loads(json_str) # 解析建议 @@ -573,32 +574,6 @@ async def generate_llm_suggestions( return generate_rule_based_suggestions(ctx) -def _extract_json(text: str) -> str: - """从文本中提取JSON""" - import re - - # 尝试直接解析 - try: - json.loads(text) - return text - except json.JSONDecodeError: - pass - - # 尝试从代码块中提取 - json_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```' - match = re.search(json_pattern, text) - if match: - return match.group(1).strip() - - # 尝试找到第一个{到最后一个}之间的内容 - first_brace = text.find('{') - last_brace = text.rfind('}') - if first_brace != -1 and last_brace != -1 and last_brace > first_brace: - return text[first_brace:last_brace + 1] - - raise ValueError(f"无法从响应中提取JSON: {text[:200]}") - - # ============================================================ # 主入口:生成优化建议 # ============================================================ diff --git a/backend/app/services/ai_engine/__init__.py b/backend/app/services/ai_engine/__init__.py index 73a5f30..f5b78ea 100644 --- a/backend/app/services/ai_engine/__init__.py +++ b/backend/app/services/ai_engine/__init__.py @@ -6,6 +6,10 @@ from .doubao import DoubaoAdapter from .gemini import GeminiAdapter from .kimi import KimiAdapter from .perplexity import PerplexityAdapter +from .platform_bridge import ( + execute_single_platform, + query_platform_raw, +) from .qwen import QwenAdapter from .wenxin import WenxinAdapter from .yuanbao import YuanbaoAdapter @@ -25,4 +29,6 @@ __all__ = [ "QwenAdapter", "GeminiAdapter", "BatchQueryService", + "execute_single_platform", + 
"query_platform_raw", ] diff --git a/backend/app/services/ai_engine/platform_bridge.py b/backend/app/services/ai_engine/platform_bridge.py new file mode 100644 index 0000000..7be8779 --- /dev/null +++ b/backend/app/services/ai_engine/platform_bridge.py @@ -0,0 +1,295 @@ +import logging +import os +import re +import time +from urllib.parse import quote +from datetime import UTC, datetime + +import httpx + +from .base import AIEngineAdapter, AIQueryResult, EngineType + +logger = logging.getLogger(__name__) + +_PLATFORM_NAME_MAP: dict[str, EngineType] = { + "wenxin": EngineType.WENXIN, + "kimi": EngineType.KIMI, + "doubao": EngineType.DOUBAO, + "tongyi": EngineType.QWEN, + "deepseek": EngineType.DEEPSEEK, + "chatgpt": EngineType.CHATGPT, + "perplexity": EngineType.PERPLEXITY, + "gemini": EngineType.GEMINI, + "yuanbao": EngineType.YUANBAO, +} + +_SEARCH_ONLY_PLATFORMS = {"qingyan", "tiangong", "xinghuo"} + + +def get_engine_type_for_platform(platform_name: str) -> EngineType | None: + return _PLATFORM_NAME_MAP.get(platform_name) + + +def is_search_only_platform(platform_name: str) -> bool: + return platform_name in _SEARCH_ONLY_PLATFORMS + + +async def search_wikipedia(keyword: str, max_chars: int = 2000) -> str: + search_url = "https://zh.wikipedia.org/w/api.php" + headers = { + "User-Agent": "GEO-Citation-Bot/1.0 (contact@example.com)", + } + + async with httpx.AsyncClient(timeout=30) as client: + search_resp = await client.get( + search_url, + headers=headers, + params={ + "action": "query", + "list": "search", + "srsearch": keyword, + "srlimit": 3, + "format": "json", + "origin": "*", + }, + ) + search_resp.raise_for_status() + search_data = search_resp.json() + + search_results = search_data.get("query", {}).get("search", []) + if not search_results: + return "" + + title = search_results[0]["title"] + async with httpx.AsyncClient(timeout=30) as client: + extract_resp = await client.get( + search_url, + headers=headers, + params={ + "action": "query", + "prop": 
"extracts", + "titles": title, + "explaintext": True, + "exsentences": 15, + "format": "json", + "origin": "*", + }, + ) + extract_resp.raise_for_status() + extract_data = extract_resp.json() + + pages = extract_data.get("query", {}).get("pages", {}) + for page in pages.values(): + extract = page.get("extract", "") + if extract: + extract = re.sub(r'\[\d+\]', '', extract) + extract = re.sub(r'\s+', ' ', extract).strip() + return extract[:max_chars] + + return "" + + +def _strip_html(raw: str) -> str: + raw = raw.replace(" ", " ") + raw = raw.replace(""", '"') + raw = raw.replace("&", "&") + raw = raw.replace("<", "<") + raw = raw.replace(">", ">") + raw = raw.replace("'", "'") + text = re.sub(r"<[^>]+>", "", raw) + text = re.sub(r"\s+", " ", text).strip() + return text + + +async def search_duckduckgo(query: str, max_results: int = 5) -> str: + url = f"https://html.duckduckgo.com/html/?q={quote(query)}" + headers = { + "User-Agent": ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36" + ), + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", + "Accept-Language": "zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7", + } + + try: + async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client: + resp = await client.get(url, headers=headers) + resp.raise_for_status() + html = resp.text + + if "web-result" not in html and "result__snippet" not in html and "result__title" not in html: + raise RuntimeError("DuckDuckGo 返回了非结果页面") + + results: list[str] = [] + + result_blocks = re.findall( + r'
]*>.*?]*class="result__title"[^>]*>.*?]*>(.*?).*?]*>.*?]*class="result__snippet"[^>]*>(.*?).*?
', + html, + re.DOTALL | re.IGNORECASE, + ) + if result_blocks: + for title_raw, snippet_raw in result_blocks[:max_results]: + title = _strip_html(title_raw) + snippet = _strip_html(snippet_raw) + if title or snippet: + results.append(f"{title}\n{snippet}") + + if not results: + snippets = re.findall( + r']*class="result__snippet"[^>]*>(.*?)', html, re.DOTALL | re.IGNORECASE + ) + titles = re.findall( + r']*class="result__title"[^>]*>.*?]*>(.*?).*?]*>', + html, + re.DOTALL | re.IGNORECASE, + ) + for i in range(min(len(titles), len(snippets), max_results)): + title = _strip_html(titles[i]) + snippet = _strip_html(snippets[i]) + if title or snippet: + results.append(f"{title}\n{snippet}") + + if results: + return "\n\n".join(results) + + raise RuntimeError("DuckDuckGo 未解析到结果") + + except Exception as e: + logger.warning(f"DuckDuckGo 搜索失败: {e},回退到 Wikipedia") + wiki_text = await search_wikipedia(query, max_chars=2000) + if wiki_text: + return wiki_text + raise RuntimeError(f"所有搜索源均失败: {e}") + + +async def fetch_search_content(platform_name: str, keyword: str) -> str: + logger.info(f"[{platform_name}] 搜索查询: {keyword}") + return await search_duckduckgo(keyword, max_results=5) + + +class SearchOnlyAdapter(AIEngineAdapter): + def __init__(self, platform_name: str, **kwargs): + self._platform_name = platform_name + super().__init__(**kwargs) + + def get_engine_type(self) -> EngineType: + return EngineType.DEEPSEEK + + def _get_env_key(self) -> str | None: + return "" + + async def query( + self, + query: str, + brand_name: str, + competitor_names: list[str] | None = None, + ) -> AIQueryResult: + start_time = time.perf_counter() + content = await fetch_search_content(self._platform_name, query) + elapsed_ms = int((time.perf_counter() - start_time) * 1000) + + has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations( + content, brand_name, competitor_names + ) + + return AIQueryResult( + engine_type=self.get_engine_type(), + query=query, + raw_response=content, + 
citations=[], + has_brand_citation=has_brand, + has_competitor_citation=has_comp, + brand_context=brand_ctx, + competitor_contexts=comp_ctx, + response_time_ms=elapsed_ms, + timestamp=datetime.now(UTC), + metadata={"platform_name": self._platform_name, "mode": "search_only"}, + ) + + +async def query_platform_raw( + platform_name: str, + keyword: str, + brand_name: str = "", + competitor_names: list[str] | None = None, +) -> str: + from .batch_query import _build_adapters + + if is_search_only_platform(platform_name): + content = await fetch_search_content(platform_name, keyword) + return f"[data_source: search_engine]\n{content}" + + engine_type = get_engine_type_for_platform(platform_name) + if engine_type is None: + raise ValueError(f"不支持的平台: {platform_name}") + + adapters = _build_adapters() + adapter = adapters.get(engine_type.value) + if adapter is None: + raise ValueError(f"平台 {platform_name} 适配器未注册") + + result = await adapter.query(keyword, brand_name, competitor_names) + return f"[data_source: ai_platform]\n{result.raw_response}" + + +_SUPPORTED_PLATFORMS = { + "wenxin", "kimi", "doubao", "tongyi", + "qingyan", "tiangong", "xinghuo", +} + + +async def execute_single_platform( + keyword: str, + platform: str, + target_brand: str, + brand_aliases: list, +) -> dict: + if platform not in _SUPPORTED_PLATFORMS: + raise ValueError(f"不支持的平台: {platform}") + + from app.workers.citation_extractor import analyze_citations + + search_keyword = f"{keyword} {target_brand}" + raw_response = await query_platform_raw( + platform_name=platform, + keyword=search_keyword, + brand_name=target_brand, + ) + + citation_analysis = analyze_citations(raw_response) + + from app.workers.citation_engine import BrandMatcher, CompetitorDetector + + matcher = BrandMatcher(target_brand=target_brand, brand_aliases=brand_aliases) + match_result = matcher.match(citation_analysis.clean_response) + + competitor_detector = CompetitorDetector() + competitor_brands = competitor_detector.detect( + 
citation_analysis.clean_response, target_brand + ) + + source_urls = [ + c.source_url for c in citation_analysis.citations if c.source_url + ] + source_titles = [ + c.source_title for c in citation_analysis.citations if c.source_title + ] + citation_contexts = [ + c.citation_context for c in citation_analysis.citations if c.citation_context + ] + + return { + "cited": match_result["cited"], + "confidence": match_result["confidence"], + "match_type": match_result["match_type"], + "position": match_result["position"], + "citation_text": match_result["citation_text"], + "competitor_brands": competitor_brands, + "raw_response": raw_response, + "data_source": citation_analysis.data_source, + "source_urls": source_urls, + "source_titles": source_titles, + "citation_contexts": citation_contexts, + "ai_response_text": citation_analysis.clean_response, + } diff --git a/backend/app/services/alert/__init__.py b/backend/app/services/alert/__init__.py new file mode 100644 index 0000000..0c996b6 --- /dev/null +++ b/backend/app/services/alert/__init__.py @@ -0,0 +1,11 @@ +from .alert_engine import ( + AlertEngine, + AlertContext, + DEFAULT_ALERT_CONFIGS, +) + +__all__ = [ + "AlertEngine", + "AlertContext", + "DEFAULT_ALERT_CONFIGS", +] diff --git a/backend/app/services/alert_engine.py b/backend/app/services/alert/alert_engine.py similarity index 100% rename from backend/app/services/alert_engine.py rename to backend/app/services/alert/alert_engine.py diff --git a/backend/app/services/analysis/__init__.py b/backend/app/services/analysis/__init__.py new file mode 100644 index 0000000..8548db8 --- /dev/null +++ b/backend/app/services/analysis/__init__.py @@ -0,0 +1,13 @@ +from .sentiment_service import ( + SentimentAnalysisService, + SentimentResult, + SentimentCache, + get_sentiment_service, +) + +__all__ = [ + "SentimentAnalysisService", + "SentimentResult", + "SentimentCache", + "get_sentiment_service", +] diff --git a/backend/app/services/sentiment_service.py 
b/backend/app/services/analysis/sentiment_service.py similarity index 94% rename from backend/app/services/sentiment_service.py rename to backend/app/services/analysis/sentiment_service.py index 02e9e8b..ccee0fe 100644 --- a/backend/app/services/sentiment_service.py +++ b/backend/app/services/analysis/sentiment_service.py @@ -5,11 +5,11 @@ import asyncio import hashlib import json import logging -import re import time from typing import Optional from app.config import settings +from app.utils.json_extractor import extract_json logger = logging.getLogger(__name__) @@ -276,31 +276,11 @@ class SentimentAnalysisService: raise RuntimeError("API返回空响应") # 提取JSON - json_str = self._extract_json(content) - return json.loads(json_str) - - def _extract_json(self, text: str) -> str: - """从文本中提取JSON""" - # 尝试直接解析 try: - json.loads(text) - return text - except json.JSONDecodeError: - pass - - # 尝试从代码块中提取 - json_pattern = r"```(?:json)?\s*([\s\S]*?)\s*```" - match = re.search(json_pattern, text) - if match: - return match.group(1).strip() - - # 尝试找到第一个{到最后一个}之间的内容 - first_brace = text.find("{") - last_brace = text.rfind("}") - if first_brace != -1 and last_brace != -1 and last_brace > first_brace: - return text[first_brace : last_brace + 1] - - raise RuntimeError(f"无法从响应中提取JSON: {text[:200]}") + json_str = extract_json(content) + except ValueError as e: + raise RuntimeError(str(e)) from e + return json.loads(json_str) def _parse_response(self, response: dict) -> SentimentResult: """解析API响应""" diff --git a/backend/app/services/attribution/__init__.py b/backend/app/services/attribution/__init__.py new file mode 100644 index 0000000..ba3f522 --- /dev/null +++ b/backend/app/services/attribution/__init__.py @@ -0,0 +1,4 @@ +from app.services.attribution.attribution_engine import AttributionEngine +from app.services.attribution.roi_calculator import ROICalculator + +__all__ = ["AttributionEngine", "ROICalculator"] diff --git a/backend/app/services/attribution/attribution_engine.py 
b/backend/app/services/attribution/attribution_engine.py new file mode 100644 index 0000000..94b0b38 --- /dev/null +++ b/backend/app/services/attribution/attribution_engine.py @@ -0,0 +1,150 @@ +import logging +import uuid +from datetime import UTC, datetime, timedelta + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.attribution_record import AttributionRecord +from app.models.diagnosis_record import DiagnosisRecord + +logger = logging.getLogger(__name__) + + +class AttributionEngine: + async def start_tracking( + self, + db: AsyncSession, + brand_id: uuid.UUID, + content_id: uuid.UUID | None, + user_id: str, + ) -> AttributionRecord: + baseline_score = await self._get_latest_score(db, brand_id) + + now = datetime.now(UTC) + record = AttributionRecord( + user_id=user_id, + brand_id=brand_id, + content_id=content_id, + baseline_score=baseline_score, + published_at=now, + window_end_at=now + timedelta(days=28), + status="tracking", + ) + db.add(record) + await db.commit() + await db.refresh(record) + return record + + async def check_attribution( + self, + db: AsyncSession, + record_id: uuid.UUID, + ) -> AttributionRecord: + stmt = select(AttributionRecord).where(AttributionRecord.id == record_id) + result = await db.execute(stmt) + record = result.scalar_one_or_none() + if not record: + raise ValueError(f"AttributionRecord {record_id} not found") + + current_score = await self._get_latest_score(db, record.brand_id) + record.current_score = current_score + record.score_delta = round(current_score - record.baseline_score, 2) + + baseline_dims = await self._get_latest_dimensions(db, record.brand_id, record.published_at) + current_dims = await self._get_latest_dimensions(db, record.brand_id, None) + if baseline_dims and current_dims: + record.attributed_dimensions = self._compute_dimension_deltas( + baseline_dims, current_dims + ) + + now = datetime.now(UTC) + if record.window_end_at: + window_end = 
record.window_end_at + if window_end.tzinfo is None: + window_end = window_end.replace(tzinfo=UTC) + if now >= window_end: + record.status = "completed" + elif record.score_delta and record.score_delta > 0: + record.status = "tracking" + + await db.commit() + await db.refresh(record) + return record + + async def get_brand_attribution_summary( + self, + db: AsyncSession, + brand_id: uuid.UUID, + ) -> dict: + stmt = ( + select(AttributionRecord) + .where(AttributionRecord.brand_id == brand_id) + .order_by(AttributionRecord.created_at.desc()) + ) + result = await db.execute(stmt) + records = result.scalars().all() + + total_delta = sum(r.score_delta or 0 for r in records) + tracking_count = sum(1 for r in records if r.status == "tracking") + completed_count = sum(1 for r in records if r.status == "completed") + positive_count = sum(1 for r in records if (r.score_delta or 0) > 0) + + return { + "brand_id": str(brand_id), + "records": records, + "total_score_delta": round(total_delta, 2), + "tracking_count": tracking_count, + "completed_count": completed_count, + "positive_count": positive_count, + } + + async def _get_latest_score(self, db: AsyncSession, brand_id: uuid.UUID) -> float: + stmt = ( + select(DiagnosisRecord) + .where( + DiagnosisRecord.brand_id == brand_id, + DiagnosisRecord.status == "completed", + ) + .order_by(DiagnosisRecord.completed_at.desc()) + .limit(1) + ) + result = await db.execute(stmt) + record = result.scalar_one_or_none() + if record and record.overall_score is not None: + return float(record.overall_score) + logger.warning("No completed DiagnosisRecord for brand %s, using 0 as baseline", brand_id) + return 0.0 + + async def _get_latest_dimensions( + self, + db: AsyncSession, + brand_id: uuid.UUID, + before: datetime | None, + ) -> dict | None: + stmt = ( + select(DiagnosisRecord) + .where( + DiagnosisRecord.brand_id == brand_id, + DiagnosisRecord.status == "completed", + ) + .order_by(DiagnosisRecord.completed_at.desc()) + .limit(1) + ) + 
if before: + stmt = stmt.where(DiagnosisRecord.completed_at < before) + result = await db.execute(stmt) + record = result.scalar_one_or_none() + if record and record.result_json: + return record.result_json.get("dimensions") + return None + + def _compute_dimension_deltas(self, before_dims: list, after_dims: list) -> dict: + before_map = {d.get("name"): d.get("score", 0) for d in before_dims} + after_map = {d.get("name"): d.get("score", 0) for d in after_dims} + deltas = {} + for name in after_map: + b = before_map.get(name, 0) + a = after_map[name] + deltas[name] = {"before": b, "after": a, "delta": round(a - b, 2)} + return deltas diff --git a/backend/app/services/attribution/roi_calculator.py b/backend/app/services/attribution/roi_calculator.py new file mode 100644 index 0000000..17f12c6 --- /dev/null +++ b/backend/app/services/attribution/roi_calculator.py @@ -0,0 +1,59 @@ +from app.models.attribution_record import AttributionRecord + + +class ROICalculator: + INDUSTRY_AVG_CITATION_VALUE = 50.0 + + def calculate_roi( + self, + subscription_cost: float, + score_delta: float, + attribution_records: list[AttributionRecord], + ) -> dict: + value_generated = score_delta * self.INDUSTRY_AVG_CITATION_VALUE + if subscription_cost > 0: + roi_percentage = round( + (value_generated - subscription_cost) / subscription_cost * 100, 2 + ) + else: + roi_percentage = 0.0 + break_even_delta = self.estimate_break_even(subscription_cost) + return { + "roi_percentage": roi_percentage, + "value_generated": round(value_generated, 2), + "cost": subscription_cost, + "break_even_delta": round(break_even_delta, 2), + } + + def generate_ab_comparison( + self, + before_score: float, + after_score: float, + before_dimensions: dict, + after_dimensions: dict, + ) -> dict: + overall_delta = round(after_score - before_score, 2) + dimensions = [] + all_names = set(list(before_dimensions.keys()) + list(after_dimensions.keys())) + for name in all_names: + b = before_dimensions.get(name, 
{}).get("score", 0) + a = after_dimensions.get(name, {}).get("score", 0) + delta = round(a - b, 2) + dimensions.append({ + "name": name, + "before": b, + "after": a, + "delta": delta, + "improved": delta > 0, + }) + return { + "overall_before": before_score, + "overall_after": after_score, + "overall_delta": overall_delta, + "dimensions": dimensions, + } + + def estimate_break_even(self, subscription_cost: float) -> float: + if self.INDUSTRY_AVG_CITATION_VALUE == 0: + return 0.0 + return subscription_cost / self.INDUSTRY_AVG_CITATION_VALUE diff --git a/backend/app/services/citation/__init__.py b/backend/app/services/citation/__init__.py new file mode 100644 index 0000000..e700070 --- /dev/null +++ b/backend/app/services/citation/__init__.py @@ -0,0 +1,34 @@ +from .citation import ( + get_citations, + get_citation_stats, + trigger_query_now, + export_citations_pdf, + export_citations_csv, + PLATFORM_NAMES, +) + +from .citation_pattern import ( + CitationPatternEngine, + CitationPattern, + PatternAnalysisReport, + ContentStructureAnalyzer, + AuthoritySignalAnalyzer, + CitationFormatAnalyzer, + EnginePreferenceAnalyzer, +) + +__all__ = [ + "get_citations", + "get_citation_stats", + "trigger_query_now", + "export_citations_pdf", + "export_citations_csv", + "PLATFORM_NAMES", + "CitationPatternEngine", + "CitationPattern", + "PatternAnalysisReport", + "ContentStructureAnalyzer", + "AuthoritySignalAnalyzer", + "CitationFormatAnalyzer", + "EnginePreferenceAnalyzer", +] diff --git a/backend/app/services/citation.py b/backend/app/services/citation/citation.py similarity index 83% rename from backend/app/services/citation.py rename to backend/app/services/citation/citation.py index fa27d48..9dafb7c 100644 --- a/backend/app/services/citation.py +++ b/backend/app/services/citation/citation.py @@ -1,9 +1,15 @@ +from __future__ import annotations + import asyncio import csv import io import logging import uuid from datetime import datetime, timedelta, timezone +from typing import 
TYPE_CHECKING + +if TYPE_CHECKING: + from app.services.scoring.scoring_service import ScoringResultV2 from sqlalchemy import func, select, and_, cast, Integer from sqlalchemy.ext.asyncio import AsyncSession @@ -13,7 +19,7 @@ from app.database import AsyncSessionLocal from app.models.citation_record import CitationRecord from app.models.query import Query from app.models.query_task import QueryTask -from app.workers.citation_engine import CitationEngine +from app.services.ai_engine.platform_bridge import execute_single_platform as _execute_single_platform_bridge logger = logging.getLogger(__name__) @@ -304,8 +310,10 @@ async def _execute_query_tasks( brand_aliases: list, user_id: uuid.UUID | None = None, ): - """后台执行查询任务""" - engine = CitationEngine() + """后台执行查询任务 — 通过 Agent 框架执行,失败时回退到直接引擎""" + from app.agent_framework.agents.citation_detector import CitationDetectorAgent + + agent = CitationDetectorAgent() try: async with AsyncSessionLocal() as db: # 验证 query 归属该用户 @@ -330,7 +338,8 @@ async def _execute_query_tasks( task.error_message = None await db.commit() - citation_result = await engine.execute_single_platform( + citation_result = await _execute_single_platform_via_agent( + agent=agent, keyword=keyword, platform=task.platform, target_brand=target_brand, @@ -338,16 +347,10 @@ async def _execute_query_tasks( ) if citation_result: - record = CitationRecord( + record = CitationRecord.from_citation_result( query_id=query_id, platform=task.platform, - cited=citation_result.get("cited", False), - citation_position=citation_result.get("position"), - citation_text=citation_result.get("citation_text"), - competitor_brands=citation_result.get("competitor_brands", []), - raw_response=citation_result.get("raw_response", ""), - confidence=citation_result.get("confidence"), - match_type=citation_result.get("match_type"), + result=citation_result, ) db.add(record) @@ -366,7 +369,34 @@ async def _execute_query_tasks( except Exception as e: logger.error(f"查询引擎执行失败: {e}") 
finally: - await engine.close() + await agent.close() + + +async def _execute_single_platform_via_agent( + agent, + keyword: str, + platform: str, + target_brand: str, + brand_aliases: list, +) -> dict: + try: + return await agent.execute_single_platform_compat( + keyword=keyword, + platform=platform, + target_brand=target_brand, + brand_aliases=brand_aliases, + ) + except Exception as agent_err: + logger.warning( + f"Agent 框架执行单平台检测失败 ({platform}): {agent_err}," + "回退到直接引擎" + ) + return await _execute_single_platform_bridge( + keyword=keyword, + platform=platform, + target_brand=target_brand, + brand_aliases=brand_aliases, + ) PLATFORM_NAMES = { @@ -386,6 +416,7 @@ async def export_citations_pdf( db: AsyncSession, user_id: uuid.UUID, query_id: uuid.UUID | None = None, + v2_result: ScoringResultV2 | None = None, ) -> bytes: """生成PDF格式报告""" import os @@ -505,6 +536,20 @@ async def export_citations_pdf( pdf.cell(col_widths[i], 7, d, border=1) pdf.ln() + if v2_result is not None: + pdf.add_page() + pdf.set_font_size(16) + pdf.cell(0, 12, "四、V2 品牌可见性评分", new_x="LMARGIN", new_y="NEXT") + pdf.set_font_size(11) + pdf.cell(0, 8, f"综合评分: {v2_result.overall_score:.2f}/100", new_x="LMARGIN", new_y="NEXT") + pdf.cell(0, 8, f"健康等级: {v2_result.health_level}", new_x="LMARGIN", new_y="NEXT") + pdf.ln(5) + pdf.cell(0, 8, f"提及率: {v2_result.mention_rate.score:.2f}/{v2_result.mention_rate.max_score:.0f} ({v2_result.mention_rate.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT") + pdf.cell(0, 8, f"推荐排名: {v2_result.recommendation_rank.score:.2f}/{v2_result.recommendation_rank.max_score:.0f} ({v2_result.recommendation_rank.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT") + pdf.cell(0, 8, f"情感倾向: {v2_result.sentiment_score.score:.2f}/{v2_result.sentiment_score.max_score:.0f} ({v2_result.sentiment_score.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT") + pdf.cell(0, 8, f"引用质量: {v2_result.citation_quality.score:.2f}/{v2_result.citation_quality.max_score:.0f} 
({v2_result.citation_quality.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT") + pdf.cell(0, 8, f"竞品对比: {v2_result.competitive_position.score:.2f}/{v2_result.competitive_position.max_score:.0f} ({v2_result.competitive_position.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT") + return pdf.output() @@ -512,6 +557,7 @@ async def export_citations_csv( db: AsyncSession, user_id: uuid.UUID, query_id: uuid.UUID, + v2_result: ScoringResultV2 | None = None, ) -> str: query = await _verify_query_ownership(db, query_id, user_id) if query is None: @@ -527,7 +573,7 @@ async def export_citations_csv( output = io.StringIO() writer = csv.writer(output) - writer.writerow([ + headers = [ "查询关键词", "目标品牌", "查询日期", @@ -538,7 +584,18 @@ async def export_citations_csv( "匹配置信度", "匹配类型", "竞争品牌", - ]) + ] + if v2_result is not None: + headers.extend([ + "overall_score", + "health_level", + "mention_rate", + "recommendation_rank", + "sentiment_score", + "citation_quality", + "competitive_position", + ]) + writer.writerow(headers) total_queries = len(records) total_citations = 0 @@ -570,7 +627,7 @@ async def export_citations_csv( if record.confidence is not None: confidence_str = f"{record.confidence:.2f}" - writer.writerow([ + row = [ query.keyword, query.target_brand, date_str, @@ -581,7 +638,18 @@ async def export_citations_csv( confidence_str, match_type_display, ", ".join(record.competitor_brands) if record.competitor_brands else "", - ]) + ] + if v2_result is not None: + row.extend([ + round(v2_result.overall_score, 2), + v2_result.health_level, + round(v2_result.mention_rate.score, 2), + round(v2_result.recommendation_rank.score, 2), + round(v2_result.sentiment_score.score, 2), + round(v2_result.citation_quality.score, 2), + round(v2_result.competitive_position.score, 2), + ]) + writer.writerow(row) # 汇总统计 writer.writerow([]) diff --git a/backend/app/services/citation_pattern.py b/backend/app/services/citation/citation_pattern.py similarity index 100% rename from 
backend/app/services/citation_pattern.py rename to backend/app/services/citation/citation_pattern.py diff --git a/tests/__init__.py b/backend/app/services/competitor/__init__.py similarity index 100% rename from tests/__init__.py rename to backend/app/services/competitor/__init__.py diff --git a/backend/app/services/competitor/competitor_analyzer_service.py b/backend/app/services/competitor/competitor_analyzer_service.py new file mode 100644 index 0000000..afd2a1c --- /dev/null +++ b/backend/app/services/competitor/competitor_analyzer_service.py @@ -0,0 +1,749 @@ +import json +import logging +import uuid +from collections import defaultdict +from datetime import datetime, timedelta +from typing import Callable + +from sqlalchemy import select, func, and_ +from sqlalchemy.ext.asyncio import AsyncSession + +from app.database import AsyncSessionLocal +from app.models.brand import Brand +from app.models.competitor import Competitor +from app.models.competitor_insight import CompetitorInsight +from app.models.citation_record import CitationRecord +from app.models.query import Query +from app.services.llm import LLMFactory, LLMError +from app.utils.json_extractor import extract_json + +logger = logging.getLogger(__name__) + +VALID_ANALYSIS_TYPES = [ + "citation_gap", + "content_strategy", + "platform_coverage", + "query_overlap", + "differentiation", +] + + +class CompetitorAnalyzerService: + + async def analyze_competitor( + self, + brand_id: uuid.UUID, + analysis_types: list[str] | None = None, + period_days: int = 30, + progress_callback: Callable[[float, str], None] | None = None, + ) -> dict: + if analysis_types is None: + analysis_types = VALID_ANALYSIS_TYPES + + invalid = set(analysis_types) - set(VALID_ANALYSIS_TYPES) + if invalid: + raise ValueError(f"不支持的分析类型: {', '.join(invalid)}") + + async with AsyncSessionLocal() as session: + brand = await session.get(Brand, brand_id) + if not brand: + raise ValueError(f"品牌不存在: {brand_id}") + + if progress_callback: + 
await progress_callback(0.1, "获取竞品列表...") + + competitors = await self._get_competitors(session, brand_id) + if not competitors: + raise ValueError("未找到竞品数据") + + if progress_callback: + await progress_callback(0.2, "聚合品牌引用数据...") + + brand_citation_data = await self._aggregate_citation_data( + session, brand_id, brand.name, period_days, + ) + + results = [] + total = len(competitors) + for i, competitor in enumerate(competitors): + if progress_callback: + progress = 0.25 + (0.5 * i / total) + await progress_callback(progress, f"分析竞品 {competitor.name}...") + + competitor_citation_data = await self._aggregate_citation_data( + session, brand_id, competitor.name, period_days, + ) + + for analysis_type in analysis_types: + insight = await self._build_insight( + session=session, + brand_id=brand_id, + brand_name=brand.name, + competitor=competitor, + analysis_type=analysis_type, + brand_data=brand_citation_data, + competitor_data=competitor_citation_data, + period_days=period_days, + ) + session.add(insight) + results.append(insight) + + await session.commit() + for r in results: + await session.refresh(r) + + return { + "brand_id": str(brand_id), + "brand_name": brand.name, + "insights": [ + { + "id": str(r.id), + "competitor_name": r.competitor_name, + "analysis_type": r.analysis_type, + "citation_count_brand": r.citation_count_brand, + "citation_count_competitor": r.citation_count_competitor, + "sentiment_brand": r.sentiment_brand, + "sentiment_competitor": r.sentiment_competitor, + "platform_breakdown": r.platform_breakdown, + "gap_analysis": r.gap_analysis, + "opportunity_areas": r.opportunity_areas, + "recommendations": r.recommendations, + "confidence": r.confidence, + "period_days": r.period_days, + "created_at": r.created_at.isoformat() if r.created_at else None, + } + for r in results + ], + "total": len(results), + } + + async def compare_citation_volume( + self, + brand_data: dict, + competitor_data: dict, + ) -> dict: + brand_count = 
brand_data["citation_count"] + competitor_count = competitor_data["citation_count"] + total = brand_count + competitor_count + + return { + "brand": brand_count, + "competitor": competitor_count, + "diff": brand_count - competitor_count, + "brand_share": round(brand_count / total, 4) if total > 0 else 0.0, + "competitor_share": round(competitor_count / total, 4) if total > 0 else 0.0, + "by_platform": self._compare_platform_citations( + brand_data, competitor_data, + ), + } + + async def compare_citation_quality( + self, + brand_data: dict, + competitor_data: dict, + ) -> dict: + brand_positive = brand_data.get("positive_ratio", 0.0) + competitor_positive = competitor_data.get("positive_ratio", 0.0) + brand_rank = brand_data.get("avg_rank", 0.0) + competitor_rank = competitor_data.get("avg_rank", 0.0) + + return { + "sentiment": { + "brand_positive_ratio": brand_positive, + "competitor_positive_ratio": competitor_positive, + "diff": round(brand_positive - competitor_positive, 4), + }, + "ranking": { + "brand_avg_rank": brand_rank, + "competitor_avg_rank": competitor_rank, + "diff": round(brand_rank - competitor_rank, 2), + }, + "brand_sentiment_breakdown": brand_data.get("sentiment_breakdown", {}), + "competitor_sentiment_breakdown": competitor_data.get("sentiment_breakdown", {}), + } + + async def analyze_content_strategy( + self, + brand_data: dict, + competitor_data: dict, + ) -> dict: + brand_types = brand_data.get("content_types", {}) + competitor_types = competitor_data.get("content_types", {}) + + all_types = set(brand_types.keys()) | set(competitor_types.keys()) + type_comparison = {} + for ct in all_types: + type_comparison[ct] = { + "brand": brand_types.get(ct, 0), + "competitor": competitor_types.get(ct, 0), + } + + competitor_only_types = set(competitor_types.keys()) - set(brand_types.keys()) + brand_only_types = set(brand_types.keys()) - set(competitor_types.keys()) + + return { + "type_comparison": type_comparison, + "competitor_unique_types": 
list(competitor_only_types), + "brand_unique_types": list(brand_only_types), + "competitor_top_types": sorted( + competitor_types.items(), key=lambda x: x[1], reverse=True, + )[:5], + } + + async def identify_opportunities( + self, + brand_data: dict, + competitor_data: dict, + comparison: dict, + ) -> dict: + opportunities = [] + + brand_platforms = set(brand_data["by_platform"].keys()) + competitor_platforms = set(competitor_data["by_platform"].keys()) + + brand_only = brand_platforms - competitor_platforms + if brand_only: + for platform in brand_only: + opportunities.append({ + "area": f"platform_{platform}", + "description": f"品牌在{platform}平台有引用而竞品没有,可加大投入建立差异化优势", + "potential": "high", + "action": f"增加在{platform}平台的内容投放和优化", + }) + + competitor_only = competitor_platforms - brand_platforms + if competitor_only: + for platform in competitor_only: + opportunities.append({ + "area": f"platform_{platform}", + "description": f"竞品在{platform}平台有引用而品牌没有,存在进入机会", + "potential": "medium", + "action": f"研究{platform}平台的内容偏好,制定进入策略", + }) + + citation_volume = comparison.get("citation_volume", {}) + if citation_volume.get("diff", 0) > 0: + opportunities.append({ + "area": "citation_volume_advantage", + "description": "品牌引用量高于竞品,可强化品牌权威性传播", + "potential": "high", + "action": "收集高引用内容案例,扩大品牌影响力", + }) + + quality = comparison.get("quality", {}) + sentiment = quality.get("sentiment", {}) + if sentiment.get("diff", 0) > 0.1: + opportunities.append({ + "area": "sentiment_advantage", + "description": "品牌正面引用比例显著高于竞品,可强化正面形象传播", + "potential": "high", + "action": "收集正面引用案例,制作品牌优势内容", + }) + + if not opportunities: + opportunities.append({ + "area": "general", + "description": "当前数据未发现明显差异化机会,建议持续监测并积累更多数据", + "potential": "low", + "action": "增加查询频率和覆盖平台,积累更多引用数据", + }) + + return { + "opportunities": opportunities, + "total_opportunities": len(opportunities), + "high_potential_count": sum(1 for o in opportunities if o["potential"] == "high"), + } + + async def 
generate_recommendations( + self, + brand_name: str, + competitor_name: str, + comparison: dict, + gaps: dict, + opportunities: dict, + data_sufficiency: str, + ) -> dict: + prompt = f"""你是一个专业的GEO(Generative Engine Optimization)策略分析师。 +请基于以下品牌与竞品的引用对比数据,生成策略建议。 + +品牌: {brand_name} +竞品: {competitor_name} +数据充分性: {data_sufficiency} + +对比数据: +{json.dumps(comparison, ensure_ascii=False, indent=2)} + +差距分析: +{json.dumps(gaps, ensure_ascii=False, indent=2)} + +机会发现: +{json.dumps(opportunities, ensure_ascii=False, indent=2)} + +请返回JSON格式(不要包含其他文字): +{{ + "gap_closing_strategies": [ + {{"strategy": "策略描述", "priority": "high/medium/low", "expected_impact": "预期效果"}} + ], + "differentiation_strategies": [ + {{"strategy": "策略描述", "priority": "high/medium/low", "expected_impact": "预期效果"}} + ], + "quick_wins": [ + {{"action": "行动描述", "effort": "low/medium/high", "timeline": "预计时间"}} + ], + "long_term_recommendations": [ + {{"recommendation": "建议描述", "rationale": "理由"}} + ] +}}""" + + try: + provider = LLMFactory.get_default() + response = await provider.chat( + [{"role": "user", "content": prompt}], + temperature=0.3, + max_tokens=2000, + ) + result = json.loads(extract_json(response.content)) + result["usage"] = response.usage + return result + except (LLMError, json.JSONDecodeError, ValueError) as e: + logger.warning(f"LLM策略生成失败,使用默认策略: {e}") + return self._default_strategy(gaps, opportunities) + + async def calculate_gap_score( + self, + db: AsyncSession, + brand_id: uuid.UUID, + brand_name: str, + ) -> list[dict]: + stmt = ( + select(CompetitorInsight) + .where(CompetitorInsight.brand_id == brand_id) + .order_by(CompetitorInsight.created_at.desc()) + ) + result = await db.execute(stmt) + insights = list(result.scalars().all()) + + competitor_map: dict[str, list[CompetitorInsight]] = defaultdict(list) + for insight in insights: + competitor_map[insight.competitor_name].append(insight) + + summaries = [] + for comp_name, comp_insights in competitor_map.items(): + 
gap_dimensions = [] + score_components = [] + + for insight in comp_insights: + gap = insight.gap_analysis or {} + if not gap: + continue + + for g in gap.get("gaps", []): + dimension = g.get("dimension", "unknown") + severity = g.get("severity", "low") + gap_value = g.get("gap", 0) + + severity_score = {"high": 30, "medium": 15, "low": 5}.get(severity, 5) + score_components.append(severity_score) + + gap_dimensions.append({ + "dimension": dimension, + "severity": severity, + "gap": gap_value, + "analysis_type": insight.analysis_type, + }) + + overall_score = min(sum(score_components), 100) if score_components else 0.0 + + summaries.append({ + "brand_name": brand_name, + "competitor_name": comp_name, + "gap_dimensions": gap_dimensions, + "overall_gap_score": round(overall_score, 2), + }) + + return summaries + + async def _build_insight( + self, + session: AsyncSession, + brand_id: uuid.UUID, + brand_name: str, + competitor: Competitor, + analysis_type: str, + brand_data: dict, + competitor_data: dict, + period_days: int, + ) -> CompetitorInsight: + comparison = {} + comparison["citation_volume"] = await self.compare_citation_volume( + brand_data, competitor_data, + ) + comparison["quality"] = await self.compare_citation_quality( + brand_data, competitor_data, + ) + + gap = self._identify_gaps(comparison, brand_name, competitor.name) + opportunities = await self.identify_opportunities( + brand_data, competitor_data, comparison, + ) + + data_sufficiency = self._assess_data_sufficiency(brand_data, competitor_data) + + insight_data = {} + if analysis_type == "content_strategy": + insight_data = await self.analyze_content_strategy( + brand_data, competitor_data, + ) + elif analysis_type == "platform_coverage": + insight_data = comparison["citation_volume"]["by_platform"] + elif analysis_type == "query_overlap": + insight_data = await self._analyze_query_overlap( + session, brand_id, brand_name, competitor.name, period_days, + ) + elif analysis_type == 
"differentiation": + insight_data = { + "brand_unique_platforms": list( + set(brand_data["by_platform"].keys()) - set(competitor_data["by_platform"].keys()) + ), + "competitor_unique_platforms": list( + set(competitor_data["by_platform"].keys()) - set(brand_data["by_platform"].keys()) + ), + "sentiment_diff": comparison["quality"].get("sentiment", {}), + } + + recommendations = await self.generate_recommendations( + brand_name=brand_name, + competitor_name=competitor.name, + comparison=comparison, + gaps=gap, + opportunities=opportunities, + data_sufficiency=data_sufficiency, + ) + + confidence = self._determine_confidence(brand_data, competitor_data) + + return CompetitorInsight( + brand_id=brand_id, + competitor_name=competitor.name, + analysis_type=analysis_type, + insight_data=insight_data if insight_data else None, + citation_count_brand=brand_data["citation_count"], + citation_count_competitor=competitor_data["citation_count"], + sentiment_brand=brand_data.get("positive_ratio"), + sentiment_competitor=competitor_data.get("positive_ratio"), + platform_breakdown=comparison["citation_volume"]["by_platform"], + gap_analysis=gap, + opportunity_areas=opportunities, + recommendations=recommendations, + confidence=confidence, + period_days=period_days, + ) + + async def _get_competitors( + self, + db: AsyncSession, + brand_id: uuid.UUID, + ) -> list[Competitor]: + stmt = select(Competitor).where(Competitor.brand_id == brand_id) + result = await db.execute(stmt) + return list(result.scalars().all()) + + async def _aggregate_citation_data( + self, + db: AsyncSession, + brand_id: uuid.UUID, + target_name: str, + period_days: int = 30, + ) -> dict: + since = datetime.utcnow() - timedelta(days=period_days) + + query_stmt = select(Query).where(Query.brand_id == brand_id) + query_result = await db.execute(query_stmt) + queries = list(query_result.scalars().all()) + + if not queries: + return { + "citation_count": 0, + "positive_ratio": 0.0, + "avg_rank": 0.0, + 
"by_platform": {}, + "content_types": {}, + "sentiment_breakdown": {"positive": 0, "neutral": 0, "negative": 0}, + "total_records": 0, + } + + query_ids = [q.id for q in queries] + query_aliases = set() + for q in queries: + query_aliases.add(q.target_brand.lower()) + if q.brand_aliases: + for alias in q.brand_aliases: + query_aliases.add(alias.lower()) + + conditions = [CitationRecord.query_id.in_(query_ids)] + if since: + conditions.append(CitationRecord.queried_at >= since) + + stmt = select(CitationRecord).where(and_(*conditions)) + result = await db.execute(stmt) + records = list(result.scalars().all()) + + target_lower = target_name.lower() + matching_records = [] + for record in records: + if record.cited and record.competitor_brands: + is_target = False + for cb in record.competitor_brands: + if isinstance(cb, str) and cb.lower() == target_lower: + is_target = True + break + elif isinstance(cb, str) and cb.lower() in query_aliases: + is_target = True + break + if is_target: + matching_records.append(record) + elif record.cited and not record.competitor_brands: + matching_records.append(record) + + total_citations = len(matching_records) + if total_citations == 0: + return { + "citation_count": 0, + "positive_ratio": 0.0, + "avg_rank": 0.0, + "by_platform": {}, + "content_types": {}, + "sentiment_breakdown": {"positive": 0, "neutral": 0, "negative": 0}, + "total_records": len(records), + } + + sentiment_breakdown = {"positive": 0, "neutral": 0, "negative": 0} + for r in matching_records: + s = r.sentiment or "neutral" + if s in sentiment_breakdown: + sentiment_breakdown[s] += 1 + else: + sentiment_breakdown["neutral"] += 1 + + positive_count = sentiment_breakdown["positive"] + positive_ratio = positive_count / total_citations if total_citations > 0 else 0.0 + + ranks = [ + r.citation_position for r in matching_records + if r.citation_position is not None and r.citation_position > 0 + ] + avg_rank = sum(ranks) / len(ranks) if ranks else 0.0 + + by_platform = 
defaultdict(lambda: {"citations": 0, "positive": 0, "ranks": []}) + for r in matching_records: + platform = r.platform + by_platform[platform]["citations"] += 1 + if r.sentiment == "positive": + by_platform[platform]["positive"] += 1 + if r.citation_position is not None and r.citation_position > 0: + by_platform[platform]["ranks"].append(r.citation_position) + + platform_stats = {} + for platform, data in by_platform.items(): + platform_stats[platform] = { + "citations": data["citations"], + "positive_ratio": data["positive"] / data["citations"] if data["citations"] > 0 else 0.0, + "avg_rank": sum(data["ranks"]) / len(data["ranks"]) if data["ranks"] else 0.0, + } + + content_types = defaultdict(int) + for r in matching_records: + match_type = r.match_type or "unknown" + content_types[match_type] += 1 + + return { + "citation_count": total_citations, + "positive_ratio": round(positive_ratio, 4), + "avg_rank": round(avg_rank, 2), + "by_platform": platform_stats, + "content_types": dict(content_types), + "sentiment_breakdown": sentiment_breakdown, + "total_records": len(records), + } + + def _compare_platform_citations( + self, + brand_data: dict, + competitor_data: dict, + ) -> dict: + all_platforms = set(brand_data["by_platform"].keys()) | set(competitor_data["by_platform"].keys()) + result = {} + for platform in all_platforms: + bp = brand_data["by_platform"].get(platform, {"citations": 0, "positive_ratio": 0.0, "avg_rank": 0.0}) + cp = competitor_data["by_platform"].get(platform, {"citations": 0, "positive_ratio": 0.0, "avg_rank": 0.0}) + result[platform] = { + "brand": bp, + "competitor": cp, + } + return result + + def _identify_gaps( + self, + comparison: dict, + brand_name: str, + competitor_name: str, + ) -> dict: + gaps = [] + + volume = comparison.get("citation_volume", {}) + citation_diff = volume.get("diff", 0) + if citation_diff < 0: + gaps.append({ + "dimension": "citation_count", + "brand_value": volume.get("brand", 0), + "competitor_value": 
volume.get("competitor", 0), + "gap": abs(citation_diff), + "severity": "high" if abs(citation_diff) >= 5 else "medium" if abs(citation_diff) >= 2 else "low", + }) + + quality = comparison.get("quality", {}) + sentiment = quality.get("sentiment", {}) + positive_diff = sentiment.get("diff", 0) + if positive_diff < -0.1: + gaps.append({ + "dimension": "positive_ratio", + "brand_value": sentiment.get("brand_positive_ratio", 0), + "competitor_value": sentiment.get("competitor_positive_ratio", 0), + "gap": abs(positive_diff), + "severity": "high" if abs(positive_diff) >= 0.3 else "medium" if abs(positive_diff) >= 0.15 else "low", + }) + + ranking = quality.get("ranking", {}) + rank_diff = ranking.get("diff", 0) + if rank_diff > 1.0: + gaps.append({ + "dimension": "avg_rank", + "brand_value": ranking.get("brand_avg_rank", 0), + "competitor_value": ranking.get("competitor_avg_rank", 0), + "gap": abs(rank_diff), + "severity": "high" if abs(rank_diff) >= 3.0 else "medium" if abs(rank_diff) >= 2.0 else "low", + }) + + for platform, data in volume.get("by_platform", {}).items(): + brand_citations = data.get("brand", {}).get("citations", 0) + competitor_citations = data.get("competitor", {}).get("citations", 0) + if competitor_citations > brand_citations + 2: + gaps.append({ + "dimension": f"platform_{platform}", + "brand_value": brand_citations, + "competitor_value": competitor_citations, + "gap": competitor_citations - brand_citations, + "severity": "high" if (competitor_citations - brand_citations) >= 5 else "medium", + }) + + return { + "brand_name": brand_name, + "competitor_name": competitor_name, + "gaps": gaps, + "total_gaps": len(gaps), + "high_severity_count": sum(1 for g in gaps if g["severity"] == "high"), + } + + async def _analyze_query_overlap( + self, + db: AsyncSession, + brand_id: uuid.UUID, + brand_name: str, + competitor_name: str, + period_days: int, + ) -> dict: + since = datetime.utcnow() - timedelta(days=period_days) + + stmt = select(Query).where( + 
Query.brand_id == brand_id, + Query.created_at >= since, + ) + result = await db.execute(stmt) + queries = list(result.scalars().all()) + + brand_keywords = set() + competitor_keywords = set() + + for q in queries: + keyword = q.keyword.lower() + brand_keywords.add(keyword) + if competitor_name.lower() in keyword or any( + a.lower() in keyword for a in (q.brand_aliases or []) + ): + competitor_keywords.add(keyword) + + overlap = brand_keywords & competitor_keywords + brand_only = brand_keywords - competitor_keywords + competitor_only = competitor_keywords - brand_keywords + + return { + "brand_keyword_count": len(brand_keywords), + "competitor_keyword_count": len(competitor_keywords), + "overlap_count": len(overlap), + "overlap_keywords": list(overlap)[:20], + "brand_only_count": len(brand_only), + "competitor_only_count": len(competitor_only), + "overlap_ratio": round(len(overlap) / len(brand_keywords), 4) if brand_keywords else 0.0, + } + + def _assess_data_sufficiency( + self, + brand_data: dict, + competitor_data: dict, + ) -> str: + brand_count = brand_data["citation_count"] + competitor_count = competitor_data["citation_count"] + min_count = min(brand_count, competitor_count) + + if min_count > 10: + return "sufficient" + elif min_count >= 5: + return "limited" + else: + return "insufficient" + + def _determine_confidence( + self, + brand_data: dict, + competitor_data: dict, + ) -> str: + brand_count = brand_data["citation_count"] + competitor_count = competitor_data["citation_count"] + min_count = min(brand_count, competitor_count) + + if min_count > 20: + return "high" + elif min_count >= 5: + return "medium" + else: + return "low" + + def _default_strategy(self, gaps: dict, opportunities: dict) -> dict: + gap_strategies = [] + for gap in gaps.get("gaps", []): + gap_strategies.append({ + "strategy": f"提升{gap['dimension']}维度表现,缩小与竞品差距", + "priority": gap["severity"], + "expected_impact": f"预计可将{gap['dimension']}差距缩小{gap['gap'] * 0.5:.1f}", + }) + + 
diff_strategies = [] + for opp in opportunities.get("opportunities", []): + if opp["potential"] in ("high", "medium"): + diff_strategies.append({ + "strategy": opp["action"], + "priority": opp["potential"], + "expected_impact": "建立差异化竞争优势", + }) + + return { + "gap_closing_strategies": gap_strategies[:5], + "differentiation_strategies": diff_strategies[:5], + "quick_wins": [], + "long_term_recommendations": [ + { + "recommendation": "持续监测竞品引用数据变化,定期更新策略", + "rationale": "GEO优化是长期过程,需要持续迭代", + } + ], + } diff --git a/backend/app/services/content/content_generation_service.py b/backend/app/services/content/content_generation_service.py new file mode 100644 index 0000000..567fd80 --- /dev/null +++ b/backend/app/services/content/content_generation_service.py @@ -0,0 +1,493 @@ +"""ContentGenerationService - 内容生成服务 + +从 api/content.py 中提取的业务逻辑层,负责: +1. 三阶段内容生成流程(generate -> de-AI -> GEO optimize) +2. 知识库上下文检索 +3. 生成结果持久化 +4. Agent 框架集成(可选) + +API 层只需负责请求解析和响应格式化,所有业务逻辑委托给此服务。 +""" + +import asyncio +import logging +import uuid +from datetime import datetime, timezone +from typing import Optional + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agent_framework.prompts import ( + CONTENT_GENERATOR_TEMPLATE, + DEAI_TEMPLATE, + GEO_OPTIMIZER_TEMPLATE, +) +from app.models.content import Content, ContentVersion +from app.services.llm import LLMFactory, LLMError + +logger = logging.getLogger(__name__) + + +class ContentGenerationService: + """内容生成服务 - 封装三阶段生成流程及持久化逻辑。""" + + def _get_provider(self): + """获取默认 LLM Provider。可被子类或测试覆盖。""" + return LLMFactory.get_default() + + async def _get_knowledge_context( + self, + db: AsyncSession, + brand_name: str, + knowledge_base_ids: list[str], + target_keyword: str, + ) -> str: + """ + 从知识库检索与查询相关的上下文。 + + 如果有知识库ID,则调用 RAGService.search 获取相关内容; + 否则返回空字符串,不影响后续流程。 + """ + if not knowledge_base_ids: + return "" + + try: + from app.services.knowledge.rag_service import RAGService + + rag_service = RAGService() + results = 
await rag_service.search( + session=db, + query=f"{brand_name} {target_keyword}" if brand_name else target_keyword, + knowledge_base_ids=knowledge_base_ids, + top_k=3, + ) + if results: + context_parts = [] + for r in results: + content = r.get("content", "") + title = r.get("document_title", "") + if content: + context_parts.append(f"[{title}] {content}") + return "\n".join(context_parts) + return "" + except Exception as e: + logger.warning(f"知识库检索失败,将不使用知识库上下文: {e}") + return "" + + async def _poll_task_result( + self, + dispatcher, + task_id: str, + timeout: int = 300, + ) -> dict: + """ + 轮询 Agent 框架任务结果。 + + Args: + dispatcher: TaskDispatcher 实例 + task_id: 已分发的任务 ID + timeout: 超时时间(秒) + + Returns: + dict: 任务的 output_data + + Raises: + TimeoutError: 任务超时 + Exception: 任务执行失败或被取消 + """ + from app.agent_framework.protocol import TaskStatus + + elapsed = 0.0 + poll_interval = 1.0 + while elapsed < timeout: + await asyncio.sleep(poll_interval) + elapsed += poll_interval + + task_status = await dispatcher.get_task_status(task_id) + status = task_status.get("status") + + if status == TaskStatus.COMPLETED: + return task_status.get("output_data", {}) + + elif status == TaskStatus.FAILED: + error_msg = task_status.get("error_message", "Unknown error") + raise Exception(f"Agent 任务执行失败: {error_msg}") + + elif status == TaskStatus.CANCELLED: + raise Exception(f"Agent 任务被取消: {task_id}") + + raise TimeoutError(f"Agent 任务超时 ({timeout}s): {task_id}") + + async def _execute_via_agent_framework( + self, + keyword: str, + brand_name: str, + platform: str, + content_style: str, + word_count: int, + knowledge_context: str, + knowledge_base_ids: list[str] | None, + run_deai: bool, + run_geo: bool, + db: AsyncSession | None, + user_id: str | None, + org_id: str | None, + ) -> dict: + """ + 通过 Agent 框架执行三阶段内容生成流程。 + + 依次分发任务到 content_generator、deai_agent、geo_optimizer, + 并轮询等待每个阶段的结果。失败时抛出异常,由调用方决定是否回退。 + + Returns: + dict: 与 generate_content 返回格式一致的结果字典 + + Raises: + Exception: Agent 
框架不可用或任务执行失败时 + """ + from app.agent_framework.dispatcher import TaskDispatcher + from app.agent_framework.protocol import TaskMessage + from app.config import settings + + dispatcher = TaskDispatcher(settings.REDIS_URL) + stages = [] + + try: + # ---- Stage 1: 内容生成 ---- + logger.info(f"通过 Agent 框架执行内容生成: keyword={keyword}") + task_id = str(uuid.uuid4()) + task_message = TaskMessage( + task_id=task_id, + agent_name="content_generator", + task_type="generate_article", + priority=0, + input_data={ + "target_keyword": keyword, + "brand_name": brand_name, + "target_platform": platform, + "knowledge_base_ids": knowledge_base_ids or [], + "word_count": word_count, + "content_style": content_style, + "knowledge_context": knowledge_context, + }, + callback_url=None, + created_at=datetime.now(timezone.utc), + timeout_seconds=300, + ) + + dispatched_id = await dispatcher.dispatch( + task_message, + organization_id=org_id, + created_by=user_id, + ) + gen_result = await self._poll_task_result( + dispatcher, dispatched_id, timeout=300 + ) + content = gen_result.get("content", "") + stages.append( + { + "stage": "content_generation", + "status": "success", + "word_count": len(content), + } + ) + + # ---- Stage 2: 去AI化(可选) ---- + if run_deai: + logger.info("通过 Agent 框架执行去AI化") + task_id = str(uuid.uuid4()) + task_message = TaskMessage( + task_id=task_id, + agent_name="deai_agent", + task_type="deai_process", + priority=0, + input_data={ + "content": content, + "platform": platform, + "style": "自然流畅", + "preserve_structure": True, + }, + callback_url=None, + created_at=datetime.now(timezone.utc), + timeout_seconds=180, + ) + dispatched_id = await dispatcher.dispatch( + task_message, + organization_id=org_id, + created_by=user_id, + ) + deai_result = await self._poll_task_result( + dispatcher, dispatched_id, timeout=180 + ) + content = deai_result.get("content", content) + stages.append({"stage": "deai", "status": "success"}) + + # ---- Stage 3: GEO优化(可选) ---- + optimized = content 
+ seo_score = None + if run_geo: + logger.info("通过 Agent 框架执行 GEO 优化") + task_id = str(uuid.uuid4()) + task_message = TaskMessage( + task_id=task_id, + agent_name="geo_optimizer", + task_type="geo_optimize", + priority=0, + input_data={ + "content": content, + "target_keywords": [keyword], + "target_platform": platform, + "optimization_level": "moderate", + }, + callback_url=None, + created_at=datetime.now(timezone.utc), + timeout_seconds=180, + ) + dispatched_id = await dispatcher.dispatch( + task_message, + organization_id=org_id, + created_by=user_id, + ) + geo_result = await self._poll_task_result( + dispatcher, dispatched_id, timeout=180 + ) + optimized = geo_result.get("optimized_content", content) + seo_score = geo_result.get("seo_score") + stages.append({"stage": "geo_optimization", "status": "success"}) + + # ---- 持久化(可选) ---- + content_id = None + if db and user_id and org_id: + content_obj = Content( + organization_id=org_id, + title=keyword, + content_type="article", + body=optimized, + status="draft", + target_platforms=[platform] if platform else [], + keywords=[keyword], + extra_metadata={ + "original_content": content if content != optimized else None, + "pipeline_stages": stages, + "seo_score": seo_score, + "brand_name": brand_name, + "content_style": content_style, + "word_count_target": word_count, + "execution_mode": "agent_framework", + }, + created_by=user_id, + current_version=1, + ) + db.add(content_obj) + await db.flush() + + version = ContentVersion( + content_id=content_obj.id, + version_number=1, + title=keyword, + body=optimized, + change_summary="Agent框架Pipeline自动生成", + created_by=user_id, + ) + db.add(version) + await db.commit() + await db.refresh(content_obj) + content_id = str(content_obj.id) + + logger.info("通过 Agent 框架执行内容生成完成") + return { + "content": content, + "optimized_content": optimized, + "seo_score": seo_score, + "content_id": content_id, + "pipeline_stages": stages, + } + + finally: + await dispatcher.close() + + async 
async def generate_content(
    self,
    keyword: str,
    brand_name: str = "",
    platform: str = "通用",
    content_style: str = "专业严谨",
    word_count: int = 2000,
    knowledge_context: str = "",
    knowledge_base_ids: list[str] | None = None,
    db: AsyncSession | None = None,
    user_id: str | None = None,
    org_id: str | None = None,
    run_deai: bool = True,
    run_geo: bool = True,
    use_agent_framework: bool = False,
) -> dict:
    """Run the three-stage content generation pipeline.

    Stages:
        1. Content generation (``CONTENT_GENERATOR_TEMPLATE``).
        2. De-AI rewriting (``DEAI_TEMPLATE``, optional).
        3. GEO optimization (``GEO_OPTIMIZER_TEMPLATE``, optional).

    When ``db``, ``user_id`` and ``org_id`` are all provided, the result is
    persisted as a draft ``Content`` row together with its first
    ``ContentVersion``.

    Args:
        keyword: Target keyword; also used as the topic title and the
            persisted content title.
        brand_name: Brand name injected into the generation prompt.
        platform: Target platform.
        content_style: Writing style passed to the generation prompt.
        word_count: Target length; the generation call is capped at
            ``word_count * 2`` tokens.
        knowledge_context: Knowledge-base context passed in directly; takes
            precedence over retrieval.
        knowledge_base_ids: Knowledge-base ids used for retrieval when
            ``knowledge_context`` is empty and ``db`` is provided.
        db: Database session; enables retrieval and persistence.
        user_id: Author id, required for persistence.
        org_id: Organization id, required for persistence.
        run_deai: Whether to run the de-AI stage.
        run_geo: Whether to run the GEO optimization stage.
        use_agent_framework: When True, first try
            ``_execute_via_agent_framework``; on any exception fall back to
            the direct provider calls below.

    Returns:
        dict with keys:
            ``content``: de-AI'd text (or the raw generated text when the
                de-AI stage is skipped).
            ``optimized_content``: GEO-optimized text (equals ``content``
                when the GEO stage is skipped).
            ``seo_score``: always ``None`` on the direct path; the agent
                framework path may supply a value.
            ``content_id``: id of the persisted record, or ``None``.
            ``pipeline_stages``: list of per-stage status dicts.

    Raises:
        LLMError: When an LLM call fails (per the original contract;
            presumably raised by the provider -- this method does not
            raise it itself).
    """
    # ---- Agent framework path ----
    if use_agent_framework:
        try:
            logger.info("尝试通过 Agent 框架执行内容生成")
            return await self._execute_via_agent_framework(
                keyword=keyword,
                brand_name=brand_name,
                platform=platform,
                content_style=content_style,
                word_count=word_count,
                knowledge_context=knowledge_context,
                knowledge_base_ids=knowledge_base_ids,
                run_deai=run_deai,
                run_geo=run_geo,
                db=db,
                user_id=user_id,
                org_id=org_id,
            )
        except Exception as e:
            # Deliberate best-effort: any framework failure falls through to
            # the direct path. Keep the traceback so the root cause is not
            # lost behind the fallback.
            logger.warning(
                "Agent 框架执行失败,回退到直接调用模式: %s", e, exc_info=True
            )

    # ---- Direct provider path ----
    provider = self._get_provider()
    stages: list[dict] = []

    # Retrieve knowledge context only when the caller did not pass one.
    if not knowledge_context and knowledge_base_ids and db:
        knowledge_context = await self._get_knowledge_context(
            db, brand_name, knowledge_base_ids, keyword
        )

    # ---- Stage 1: content generation ----
    gen_variables = {
        "topic_title": keyword,
        "target_keyword": keyword,
        "target_platform": platform,
        "content_angle": "综合分析",
        "content_style": content_style,
        "word_count": str(word_count),
        "brand_name": brand_name,
        "knowledge_context": knowledge_context,
    }
    messages = CONTENT_GENERATOR_TEMPLATE.render(gen_variables)
    response = await provider.chat(
        messages, temperature=0.7, max_tokens=word_count * 2
    )
    content = response.content
    # "word_count" here is the character length of the generated text.
    stages.append(
        {"stage": "content_generation", "status": "success", "word_count": len(content)}
    )

    # ---- Stage 2: de-AI rewriting (optional) ----
    if run_deai:
        deai_variables = {
            "original_content": content,
            "target_style": "自然流畅",
            "preserve_structure": "是",
        }
        messages = DEAI_TEMPLATE.render(deai_variables)
        response = await provider.chat(
            messages, temperature=0.9, max_tokens=len(content) * 2
        )
        content = response.content
        stages.append({"stage": "deai", "status": "success"})

    # ---- Stage 3: GEO optimization (optional) ----
    optimized = content
    seo_score = None  # the direct path produces no score
    if run_geo:
        geo_variables = {
            "original_content": content,
            "target_keywords": keyword,
            "target_platform": platform,
            "optimization_level": "moderate",
        }
        messages = GEO_OPTIMIZER_TEMPLATE.render(geo_variables)
        response = await provider.chat(
            messages, temperature=0.5, max_tokens=len(content) * 2
        )
        optimized = response.content
        stages.append({"stage": "geo_optimization", "status": "success"})

    # ---- Persistence (optional) ----
    content_id = None
    if db and user_id and org_id:
        content_id = await self._persist_direct_result(
            db,
            keyword=keyword,
            platform=platform,
            content=content,
            optimized=optimized,
            stages=stages,
            seo_score=seo_score,
            brand_name=brand_name,
            content_style=content_style,
            word_count=word_count,
            user_id=user_id,
            org_id=org_id,
        )

    return {
        "content": content,
        "optimized_content": optimized,
        "seo_score": seo_score,
        "content_id": content_id,
        "pipeline_stages": stages,
    }

async def _persist_direct_result(
    self,
    db,
    *,
    keyword,
    platform,
    content,
    optimized,
    stages,
    seo_score,
    brand_name,
    content_style,
    word_count,
    user_id,
    org_id,
) -> str:
    """Store the direct-path result as a draft Content plus version 1; return its id."""
    content_obj = Content(
        organization_id=org_id,
        title=keyword,
        content_type="article",
        body=optimized,
        status="draft",
        target_platforms=[platform] if platform else [],
        keywords=[keyword],
        extra_metadata={
            # Keep the pre-optimization text only when optimization changed it.
            "original_content": content if content != optimized else None,
            "pipeline_stages": stages,
            "seo_score": seo_score,
            "brand_name": brand_name,
            "content_style": content_style,
            "word_count_target": word_count,
        },
        created_by=user_id,
        current_version=1,
    )
    db.add(content_obj)
    # Flush first so content_obj.id is populated for the version row.
    await db.flush()

    version = ContentVersion(
        content_id=content_obj.id,
        version_number=1,
        title=keyword,
        body=optimized,
        change_summary="Pipeline自动生成",
        created_by=user_id,
    )
    db.add(version)
    await db.commit()
    await db.refresh(content_obj)
    return str(content_obj.id)
GEODiagnosisService, + GEODiagnosisInput, + GEODiagnosisResult, + GEODimensionScore, + GEORecommendation, + DiagnosisItem, + diagnose_content_extractability, + diagnose_entity_clarity, + diagnose_eeat_signals, + diagnose_schema_markup, + diagnose_topic_authority, + diagnose_citation_readiness, + generate_recommendations, +) + +__all__ = [ + "GEODiagnosisService", + "GEODiagnosisInput", + "GEODiagnosisResult", + "GEODimensionScore", + "GEORecommendation", + "DiagnosisItem", + "diagnose_content_extractability", + "diagnose_entity_clarity", + "diagnose_eeat_signals", + "diagnose_schema_markup", + "diagnose_topic_authority", + "diagnose_citation_readiness", + "generate_recommendations", +] diff --git a/backend/app/services/diagnosis/data_collector.py b/backend/app/services/diagnosis/data_collector.py new file mode 100644 index 0000000..15086b1 --- /dev/null +++ b/backend/app/services/diagnosis/data_collector.py @@ -0,0 +1,424 @@ +from __future__ import annotations + +import asyncio +import logging +import re +from dataclasses import dataclass, field +from datetime import UTC, datetime + +import httpx +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.citation_record import CitationRecord +from app.models.query import Query +from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput + +logger = logging.getLogger(__name__) + +_DEFAULT_PLATFORMS = ["deepseek", "kimi"] +_QUERY_KEYWORDS = [ + "{brand}是什么", + "{brand}怎么样", + "推荐{industry}品牌", +] + + +@dataclass +class DataCollectionResult: + diagnosis_input: GEODiagnosisInput + metadata: dict = field(default_factory=dict) + errors: list[str] = field(default_factory=list) + + +class DataCollectorService: + def __init__(self, db: AsyncSession): + self._db = db + + async def collect( + self, + brand_name: str, + brand_aliases: list[str] | None = None, + website: str | None = None, + industry: str | None = None, + ) -> DataCollectionResult: + errors: list[str] = [] + 
metadata: dict = { + "brand_name": brand_name, + "collected_at": datetime.now(UTC).isoformat(), + "channels": {}, + } + + ai_task = asyncio.create_task( + self._collect_ai_platform_signals( + brand_name, brand_aliases or [], industry + ) + ) + citation_task = asyncio.create_task( + self._collect_citation_record_signals(brand_name, brand_aliases or []) + ) + website_task = asyncio.create_task( + self._collect_website_signals(website) + ) + + ai_result, ai_err = await self._safe_await(ai_task, "ai_platform") + citation_result, citation_err = await self._safe_await( + citation_task, "citation_record" + ) + website_result, website_err = await self._safe_await(website_task, "website") + + if ai_err: + errors.append(ai_err) + if citation_err: + errors.append(citation_err) + if website_err: + errors.append(website_err) + + metadata["channels"]["ai_platform"] = ai_result.get("metadata", {}) if ai_result else {"error": ai_err} + metadata["channels"]["citation_record"] = citation_result.get("metadata", {}) if citation_result else {"error": citation_err} + metadata["channels"]["website"] = website_result.get("metadata", {}) if website_result else {"error": website_err} + + diagnosis_input = GEODiagnosisInput() + + if ai_result: + self._apply_ai_signals(diagnosis_input, ai_result) + if citation_result: + self._apply_citation_signals(diagnosis_input, citation_result) + if website_result: + self._apply_website_signals(diagnosis_input, website_result) + + if industry: + diagnosis_input.has_industry_classification = True + + return DataCollectionResult( + diagnosis_input=diagnosis_input, + metadata=metadata, + errors=errors, + ) + + async def _collect_ai_platform_signals( + self, + brand_name: str, + brand_aliases: list[str], + industry: str | None, + ) -> dict: + from app.services.ai_engine.platform_bridge import execute_single_platform + + keywords = [] + for tpl in _QUERY_KEYWORDS: + kw = tpl.format(brand=brand_name, industry=industry or "科技") + keywords.append(kw) + + 
all_results: list[dict] = [] + for platform in _DEFAULT_PLATFORMS: + for keyword in keywords[:2]: + try: + result = await execute_single_platform( + keyword=keyword, + platform=platform, + target_brand=brand_name, + brand_aliases=brand_aliases, + ) + all_results.append(result) + except Exception as e: + logger.warning(f"AI platform query failed: platform={platform}, keyword={keyword}, error={e}") + + total = len(all_results) + cited_count = sum(1 for r in all_results if r.get("cited")) + accurate_count = sum( + 1 for r in all_results if r.get("match_type") == "exact" + ) + + aor = cited_count / total if total > 0 else 0.0 + accuracy = accurate_count / cited_count if cited_count > 0 else 0.0 + sov = aor * 0.6 + + competitor_mentions: dict[str, int] = {} + for r in all_results: + for comp in r.get("competitor_brands", []): + competitor_mentions[comp] = competitor_mentions.get(comp, 0) + 1 + + max_comp_mentions = max(competitor_mentions.values()) if competitor_mentions else 0 + competitor_gap = max(0.0, (max_comp_mentions - cited_count) / total) if total > 0 else 0.5 + + return { + "total_responses": total, + "cited_count": cited_count, + "accurate_count": accurate_count, + "aor": aor, + "accuracy": accuracy, + "sov": sov, + "competitor_gap": competitor_gap, + "has_author_bio": cited_count > 0, + "author_credentials_complete": min(1.0, cited_count / 3) if cited_count > 0 else 0.0, + "has_data_sources": any(r.get("source_urls") for r in all_results), + "metadata": { + "platforms_queried": _DEFAULT_PLATFORMS, + "keywords_used": keywords[:2], + "total_responses": total, + "cited_count": cited_count, + }, + } + + async def _collect_citation_record_signals( + self, + brand_name: str, + brand_aliases: list[str], + ) -> dict: + stmt = ( + select(CitationRecord) + .join(Query, CitationRecord.query_id == Query.id) + .where(Query.target_brand == brand_name) + .order_by(CitationRecord.queried_at.desc()) + .limit(100) + ) + result = await self._db.execute(stmt) + records = 
result.scalars().all() + + if not records: + return { + "total_responses": 0, + "cited_count": 0, + "accurate_count": 0, + "aor": 0.0, + "accuracy": 0.0, + "sov": 0.0, + "competitor_gap": 0.0, + "metadata": {"records_found": 0}, + } + + total = len(records) + cited_count = sum(1 for r in records if r.cited) + accurate_count = sum( + 1 for r in records if r.match_type == "exact" and r.cited + ) + + aor = cited_count / total if total > 0 else 0.0 + accuracy = accurate_count / cited_count if cited_count > 0 else 0.0 + + sov = aor * 0.5 + + competitor_all: dict[str, int] = {} + for r in records: + if r.competitor_brands and isinstance(r.competitor_brands, list): + for comp in r.competitor_brands: + if isinstance(comp, str): + competitor_all[comp] = competitor_all.get(comp, 0) + 1 + + max_comp = max(competitor_all.values()) if competitor_all else 0 + competitor_gap = max(0.0, (max_comp - cited_count) / total) if total > 0 else 0.0 + + has_certifications = any( + r.sentiment == "positive" for r in records if r.sentiment + ) + cert_count = sum(1 for r in records if r.sentiment == "positive") + has_endorsements = cited_count >= 3 + endorsement_count = min(cited_count, 5) + + return { + "total_responses": total, + "cited_count": cited_count, + "accurate_count": accurate_count, + "aor": aor, + "accuracy": accuracy, + "sov": min(sov, 1.0), + "competitor_gap": min(competitor_gap, 1.0), + "has_certifications": has_certifications, + "certification_count": cert_count, + "has_expert_endorsements": has_endorsements, + "endorsement_count": endorsement_count, + "content_depth_score": min(1.0, total / 20), + "topic_coverage_ratio": min(1.0, cited_count / 10), + "entity_consistency_score": min(1.0, accuracy * 1.1) if accuracy > 0 else 0.1, + "cluster_completeness": min(1.0, cited_count / 15), + "total_content_count": total, + "topic_cluster_count": min(cited_count, 10), + "metadata": {"records_found": total}, + } + + async def _collect_website_signals(self, website: str | None) -> 
dict: + if not website: + return {"metadata": {"skipped": True, "reason": "no_website"}} + + try: + async with httpx.AsyncClient( + timeout=15, follow_redirects=True + ) as client: + resp = await client.get( + website, + headers={ + "User-Agent": ( + "Mozilla/5.0 (compatible; GEO-Diagnosis-Bot/1.0)" + ), + "Accept": "text/html", + }, + ) + resp.raise_for_status() + html = resp.text + except Exception as e: + logger.warning(f"Website fetch failed: {website}, error={e}") + return {"metadata": {"skipped": True, "reason": str(e)}} + + signals = self._parse_html_signals(html) + signals["metadata"] = {"url": website, "html_length": len(html)} + return signals + + def _parse_html_signals(self, html: str) -> dict: + signals: dict = {} + + has_ld_json = 'application/ld+json' in html + signals["has_organization"] = ( + has_ld_json and ('"Organization"' in html or '"organization"' in html) + ) + signals["has_product"] = ( + has_ld_json and ('"Product"' in html or '"product"' in html) + ) + signals["has_article"] = ( + has_ld_json + and ('"Article"' in html or '"BlogPosting"' in html or '"article"' in html) + ) + signals["has_faq"] = ( + has_ld_json and ('"FAQPage"' in html or '"faq"' in html) + ) + signals["has_howto"] = ( + has_ld_json and ('"HowTo"' in html or '"howto"' in html) + ) + signals["has_breadcrumb"] = ( + has_ld_json and ('"BreadcrumbList"' in html or '"breadcrumb"' in html) + ) + + h2_h3 = re.findall(r"]*>(.*?)", html, re.DOTALL | re.IGNORECASE) + qa_pattern = re.compile(r"[??]|如何|什么|为什么|怎么|哪|多少|是否|可以") + qa_headings = [h for h in h2_h3 if qa_pattern.search(re.sub(r"<[^>]+>", "", h))] + signals["has_qa_headings"] = len(qa_headings) >= 2 + + signals["has_structured_data"] = ( + "]+>", " ", html) + body_text = re.sub(r"\s+", " ", body_text).strip() + + first_500 = body_text[:500].lower() + signals["has_direct_answer"] = len(body_text) > 200 and len(first_500) > 100 + + signals["has_brand_definition"] = any( + kw in first_500 + for kw in ["是", "提供", "专注于", "致力于", 
"is a", "provides", "offers"] + ) + + audience_patterns = [ + "为.*提供", "服务.*用户", "帮助.*企业", "面向", + "for ", "serves ", "helps ", + ] + signals["has_target_audience"] = any( + re.search(p, first_500) for p in audience_patterns + ) + + value_patterns = [ + "优势", "特色", "不同", "独特", "领先", "首创", "唯一", + "advantage", "unique", "leading", "first", + ] + signals["has_unique_value"] = any(v in first_500 for v in value_patterns) + + return signals + + def _apply_ai_signals(self, inp: GEODiagnosisInput, data: dict) -> None: + inp.answer_ownership_rate = max(inp.answer_ownership_rate, data.get("aor", 0.0)) + inp.citation_accuracy = max(inp.citation_accuracy, data.get("accuracy", 0.0)) + inp.ai_sov = max(inp.ai_sov, data.get("sov", 0.0)) + inp.competitor_gap = max(inp.competitor_gap, data.get("competitor_gap", 0.0)) + inp.total_ai_responses = max(inp.total_ai_responses, data.get("total_responses", 0)) + inp.brand_mention_count = max(inp.brand_mention_count, data.get("cited_count", 0)) + inp.accurate_citation_count = max( + inp.accurate_citation_count, data.get("accurate_count", 0) + ) + if data.get("has_author_bio"): + inp.has_author_bio = True + if data.get("author_credentials_complete", 0) > inp.author_credentials_complete: + inp.author_credentials_complete = data["author_credentials_complete"] + if data.get("has_data_sources"): + inp.has_data_sources = True + + def _apply_citation_signals(self, inp: GEODiagnosisInput, data: dict) -> None: + inp.answer_ownership_rate = max(inp.answer_ownership_rate, data.get("aor", 0.0)) + inp.citation_accuracy = max(inp.citation_accuracy, data.get("accuracy", 0.0)) + inp.ai_sov = max(inp.ai_sov, data.get("sov", 0.0)) + inp.competitor_gap = max(inp.competitor_gap, data.get("competitor_gap", 0.0)) + inp.total_ai_responses = max(inp.total_ai_responses, data.get("total_responses", 0)) + inp.brand_mention_count = max(inp.brand_mention_count, data.get("cited_count", 0)) + inp.accurate_citation_count = max( + inp.accurate_citation_count, 
data.get("accurate_count", 0) + ) + if data.get("has_certifications"): + inp.has_certifications = True + inp.certification_count = max( + inp.certification_count, data.get("certification_count", 0) + ) + if data.get("has_expert_endorsements"): + inp.has_expert_endorsements = True + inp.endorsement_count = max( + inp.endorsement_count, data.get("endorsement_count", 0) + ) + inp.content_depth_score = max( + inp.content_depth_score, data.get("content_depth_score", 0.0) + ) + inp.topic_coverage_ratio = max( + inp.topic_coverage_ratio, data.get("topic_coverage_ratio", 0.0) + ) + inp.entity_consistency_score = max( + inp.entity_consistency_score, data.get("entity_consistency_score", 0.0) + ) + inp.cluster_completeness = max( + inp.cluster_completeness, data.get("cluster_completeness", 0.0) + ) + inp.total_content_count = max( + inp.total_content_count, data.get("total_content_count", 0) + ) + inp.topic_cluster_count = max( + inp.topic_cluster_count, data.get("topic_cluster_count", 0) + ) + + def _apply_website_signals(self, inp: GEODiagnosisInput, data: dict) -> None: + bool_fields = [ + "has_direct_answer", + "has_qa_headings", + "has_structured_data", + "has_internal_links", + "has_freshness_info", + "has_brand_definition", + "has_target_audience", + "has_unique_value", + ] + schema_fields = [ + ("has_organization", "has_organization"), + ("has_product", "has_product"), + ("has_article", "has_article"), + ("has_faq", "has_faq"), + ("has_howto", "has_howto"), + ("has_breadcrumb", "has_breadcrumb"), + ] + + for f in bool_fields: + if data.get(f): + setattr(inp, f, True) + + for data_key, inp_key in schema_fields: + if data.get(data_key): + setattr(inp, inp_key, True) + + async def _safe_await(self, task: asyncio.Task, channel: str) -> tuple: + try: + result = await task + return result, None + except Exception as e: + logger.error(f"Data collection channel '{channel}' failed: {e}", exc_info=True) + return None, f"{channel}: {str(e)}" diff --git 
a/backend/app/services/geo_diagnosis.py b/backend/app/services/diagnosis/geo_diagnosis.py similarity index 97% rename from backend/app/services/geo_diagnosis.py rename to backend/app/services/diagnosis/geo_diagnosis.py index eaafb84..66f65ef 100644 --- a/backend/app/services/geo_diagnosis.py +++ b/backend/app/services/diagnosis/geo_diagnosis.py @@ -14,6 +14,8 @@ from __future__ import annotations import logging from dataclasses import dataclass, field +from app.utils.health import get_health_level, get_health_level_label + logger = logging.getLogger(__name__) @@ -877,39 +879,6 @@ def generate_recommendations(dimensions: list[GEODimensionScore]) -> list[GEORec return recommendations -# ============================================================ -# 工具函数 -# ============================================================ - -def get_health_level(score: float) -> str: - """ - 根据评分获取健康等级 - - 80+ -> excellent (优秀/绿) - 60-79 -> good (良好/黄) - 40-59 -> pass (及格/橙) - <40 -> danger (危险/红) - """ - if score >= 80: - return "excellent" - if score >= 60: - return "good" - if score >= 40: - return "pass" - return "danger" - - -def get_health_level_label(level: str) -> str: - """获取健康等级中文标签""" - labels = { - "excellent": "优秀", - "good": "良好", - "pass": "及格", - "danger": "危险", - } - return labels.get(level, "未知") - - # ============================================================ # GEODiagnosisService 服务类 # ============================================================ diff --git a/backend/app/services/seo_diagnosis.py b/backend/app/services/diagnosis/seo_diagnosis.py similarity index 100% rename from backend/app/services/seo_diagnosis.py rename to backend/app/services/diagnosis/seo_diagnosis.py diff --git a/backend/app/services/distribution/publish_engine.py b/backend/app/services/distribution/publish_engine.py new file mode 100644 index 0000000..8dd4399 --- /dev/null +++ b/backend/app/services/distribution/publish_engine.py @@ -0,0 +1,108 @@ +import logging +import uuid +from datetime 
import datetime + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.content import Content +from app.models.distribution import DistributionSchedule +from app.services.distribution.formatter import ContentFormatter +from app.services.distribution.publishers import get_publisher +from app.services.distribution.publishers.base import PublishResult + +logger = logging.getLogger(__name__) + + +class PublishEngine: + def __init__(self): + self._formatter = ContentFormatter() + + async def publish_content( + self, + content_id: str, + platforms: list[str], + db: AsyncSession, + user_id: str, + org_id: str, + ) -> list[PublishResult]: + stmt = select(Content).where(Content.id == uuid.UUID(content_id)) + result = await db.execute(stmt) + content = result.scalar_one_or_none() + + if not content: + raise ValueError(f"Content not found: {content_id}") + + results: list[PublishResult] = [] + platform_results: list[dict] = [] + + for platform in platforms: + publisher = get_publisher(platform) + formatted = self._formatter.format_for_platform(content.body or "", platform) + + try: + pub_result = await publisher.publish( + title=content.title, + content=formatted, + ) + results.append(pub_result) + platform_results.append({ + "platform": platform, + "status": "published" if pub_result.success else "failed", + "article_id": pub_result.article_id, + "article_url": pub_result.article_url, + "error": pub_result.error, + "published_at": datetime.now().isoformat() if pub_result.success else None, + }) + except Exception as e: + logger.error(f"Publish to {platform} failed: {e}") + fail_result = PublishResult( + success=False, + platform=platform, + error=str(e), + ) + results.append(fail_result) + platform_results.append({ + "platform": platform, + "status": "failed", + "error": str(e), + }) + + schedule = DistributionSchedule( + organization_id=uuid.UUID(org_id) if isinstance(org_id, str) else org_id, + content_title=content.title, + 
content_id=content.id, + platforms=platform_results, + status="published" if all(r.success for r in results) else "partial", + created_by=user_id, + ) + db.add(schedule) + await db.commit() + + return results + + async def get_publish_status( + self, + content_id: str, + db: AsyncSession, + ) -> list[dict]: + stmt = ( + select(DistributionSchedule) + .where(DistributionSchedule.content_id == uuid.UUID(content_id)) + .order_by(DistributionSchedule.created_at.desc()) + ) + result = await db.execute(stmt) + schedules = result.scalars().all() + + status_list: list[dict] = [] + for schedule in schedules: + platforms = schedule.platforms or [] + for p in platforms: + status_list.append({ + "platform": p.get("platform", ""), + "status": p.get("status", "unknown"), + "article_url": p.get("article_url"), + "published_at": p.get("published_at"), + }) + + return status_list diff --git a/backend/app/services/distribution/publishers/__init__.py b/backend/app/services/distribution/publishers/__init__.py new file mode 100644 index 0000000..05ca262 --- /dev/null +++ b/backend/app/services/distribution/publishers/__init__.py @@ -0,0 +1,29 @@ +from app.services.distribution.publishers.base import ContentPublisher, PublishResult +from app.services.distribution.publishers.mock_publisher import MockPublisher + +__all__ = ["ContentPublisher", "PublishResult", "MockPublisher", "get_publisher"] + + +def get_publisher(platform: str) -> ContentPublisher: + from app.config import settings + + if settings.DISTRIBUTION_MODE == "mock": + return MockPublisher(platform=platform) + + if platform == "zhihu": + from app.services.distribution.publishers.zhihu_publisher import ZhihuPublisher + + pub = ZhihuPublisher() + return pub if pub.is_configured() else MockPublisher(platform="zhihu") + elif platform == "toutiao": + from app.services.distribution.publishers.toutiao_publisher import ToutiaoPublisher + + pub = ToutiaoPublisher() + return pub if pub.is_configured() else 
MockPublisher(platform="toutiao") + elif platform == "wechat": + from app.services.distribution.publishers.wechat_publisher import WeChatPublisher + + pub = WeChatPublisher() + return pub if pub.is_configured() else MockPublisher(platform="wechat") + + return MockPublisher(platform=platform) diff --git a/backend/app/services/distribution/publishers/base.py b/backend/app/services/distribution/publishers/base.py new file mode 100644 index 0000000..768ac86 --- /dev/null +++ b/backend/app/services/distribution/publishers/base.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from pydantic import BaseModel + + +class PublishResult(BaseModel): + success: bool + platform: str + article_id: Optional[str] = None + article_url: Optional[str] = None + error: Optional[str] = None + raw_response: dict = {} + + +class ContentPublisher(ABC): + platform: str = "" + + @abstractmethod + async def publish(self, title: str, content: str, **kwargs) -> PublishResult: + pass + + @abstractmethod + async def verify_credentials(self) -> bool: + pass + + @abstractmethod + async def get_article_status(self, article_id: str) -> dict: + pass diff --git a/backend/app/services/distribution/publishers/mock_publisher.py b/backend/app/services/distribution/publishers/mock_publisher.py new file mode 100644 index 0000000..63c3e01 --- /dev/null +++ b/backend/app/services/distribution/publishers/mock_publisher.py @@ -0,0 +1,38 @@ +import uuid +from datetime import datetime + +from app.services.distribution.publishers.base import ContentPublisher, PublishResult + + +class MockPublisher(ContentPublisher): + platform: str = "" + + def __init__(self, platform: str = "mock"): + self.platform = platform + + async def publish(self, title: str, content: str, **kwargs) -> PublishResult: + article_id = f"mock_{self.platform}_{uuid.uuid4().hex[:8]}" + return PublishResult( + success=True, + platform=self.platform, + article_id=article_id, + 
article_url=f"https://mock.{self.platform}.com/articles/{article_id}", + raw_response={ + "mock": True, + "platform": self.platform, + "title": title, + "content_length": len(content), + "published_at": datetime.now().isoformat(), + }, + ) + + async def verify_credentials(self) -> bool: + return True + + async def get_article_status(self, article_id: str) -> dict: + return { + "article_id": article_id, + "platform": self.platform, + "status": "published", + "mock": True, + } diff --git a/backend/app/services/distribution/publishers/toutiao_publisher.py b/backend/app/services/distribution/publishers/toutiao_publisher.py new file mode 100644 index 0000000..1d8c6d3 --- /dev/null +++ b/backend/app/services/distribution/publishers/toutiao_publisher.py @@ -0,0 +1,95 @@ +import logging +from typing import Optional + +import httpx + +from app.config import settings +from app.services.distribution.publishers.base import ContentPublisher, PublishResult +from app.services.distribution.publishers.mock_publisher import MockPublisher + +logger = logging.getLogger(__name__) + + +class ToutiaoPublisher(ContentPublisher): + platform: str = "toutiao" + + API_BASE = "https://open.toutiao.com/api/v2" + + def is_configured(self) -> bool: + return bool(settings.TOUTIAO_ACCESS_TOKEN) + + async def publish(self, title: str, content: str, **kwargs) -> PublishResult: + if not self.is_configured(): + mock = MockPublisher(platform="toutiao") + return await mock.publish(title, content, **kwargs) + + headers = { + "Authorization": f"Bearer {settings.TOUTIAO_ACCESS_TOKEN}", + "Content-Type": "application/json", + } + payload = { + "title": title, + "content": content, + "source": kwargs.get("source", "原创"), + } + + try: + async with httpx.AsyncClient(timeout=30) as client: + resp = await client.post( + f"{self.API_BASE}/article/create", + headers=headers, + json=payload, + ) + data = resp.json() + + if resp.status_code in (200, 201) and data.get("code") == 0: + item = data.get("data", {}) + 
return PublishResult( + success=True, + platform=self.platform, + article_id=str(item.get("article_id", "")), + article_url=item.get("article_url", ""), + raw_response=data, + ) + else: + return PublishResult( + success=False, + platform=self.platform, + error=data.get("message", str(data)), + raw_response=data, + ) + except Exception as e: + logger.error(f"Toutiao publish error: {e}") + return PublishResult( + success=False, + platform=self.platform, + error=str(e), + ) + + async def verify_credentials(self) -> bool: + if not self.is_configured(): + return False + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get( + f"{self.API_BASE}/user/info", + headers={"Authorization": f"Bearer {settings.TOUTIAO_ACCESS_TOKEN}"}, + ) + return resp.status_code == 200 + except Exception: + return False + + async def get_article_status(self, article_id: str) -> dict: + if not self.is_configured(): + mock = MockPublisher(platform="toutiao") + return await mock.get_article_status(article_id) + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get( + f"{self.API_BASE}/article/info", + params={"article_id": article_id}, + headers={"Authorization": f"Bearer {settings.TOUTIAO_ACCESS_TOKEN}"}, + ) + return resp.json() + except Exception as e: + return {"article_id": article_id, "error": str(e)} diff --git a/backend/app/services/distribution/publishers/wechat_publisher.py b/backend/app/services/distribution/publishers/wechat_publisher.py new file mode 100644 index 0000000..6c0599a --- /dev/null +++ b/backend/app/services/distribution/publishers/wechat_publisher.py @@ -0,0 +1,139 @@ +import logging +from typing import Optional + +import httpx + +from app.config import settings +from app.services.distribution.publishers.base import ContentPublisher, PublishResult +from app.services.distribution.publishers.mock_publisher import MockPublisher +from app.services.distribution.formatter import ContentFormatter + +logger = 
logging.getLogger(__name__) + + +class WeChatPublisher(ContentPublisher): + platform: str = "wechat" + + API_BASE = "https://api.weixin.qq.com/cgi-bin" + + def __init__(self): + self._formatter = ContentFormatter() + + def is_configured(self) -> bool: + return bool(settings.WECHAT_OFFICIAL_APP_ID and settings.WECHAT_OFFICIAL_APP_SECRET) + + async def publish(self, title: str, content: str, **kwargs) -> PublishResult: + formatted = self._formatter.format_for_platform(content, "wechat") + + if not self.is_configured(): + return PublishResult( + success=True, + platform=self.platform, + article_id=f"wechat_copy_ready_{id(content)}", + article_url=None, + raw_response={ + "mode": "semi_auto", + "formatted_content": formatted, + "instructions": ( + "1. 登录微信公众号后台\n" + "2. 新建图文消息\n" + "3. 将 formatted_content 粘贴到编辑器中\n" + "4. 检查排版并发布" + ), + }, + ) + + try: + access_token = await self._get_access_token() + if not access_token: + return PublishResult( + success=False, + platform=self.platform, + error="Failed to obtain WeChat access token", + ) + + headers = {"Content-Type": "application/json"} + payload = { + "access_token": access_token, + "articles": [ + { + "title": title, + "content": formatted, + "thumb_media_id": kwargs.get("thumb_media_id", ""), + "author": kwargs.get("author", ""), + "digest": kwargs.get("digest", content[:120]), + "content_source_url": kwargs.get("content_source_url", ""), + } + ], + } + + async with httpx.AsyncClient(timeout=30) as client: + resp = await client.post( + f"{self.API_BASE}/draft/add", + headers=headers, + json=payload, + ) + data = resp.json() + + if data.get("media_id"): + return PublishResult( + success=True, + platform=self.platform, + article_id=data["media_id"], + raw_response=data, + ) + else: + return PublishResult( + success=False, + platform=self.platform, + error=data.get("errmsg", str(data)), + raw_response=data, + ) + except Exception as e: + logger.error(f"WeChat publish error: {e}") + return PublishResult( + 
success=False, + platform=self.platform, + error=str(e), + ) + + async def _get_access_token(self) -> Optional[str]: + if not self.is_configured(): + return None + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get( + f"{self.API_BASE}/token", + params={ + "grant_type": "client_credential", + "appid": settings.WECHAT_OFFICIAL_APP_ID, + "secret": settings.WECHAT_OFFICIAL_APP_SECRET, + }, + ) + data = resp.json() + return data.get("access_token") + except Exception as e: + logger.error(f"WeChat access token error: {e}") + return None + + async def verify_credentials(self) -> bool: + if not self.is_configured(): + return False + token = await self._get_access_token() + return token is not None + + async def get_article_status(self, article_id: str) -> dict: + if not self.is_configured(): + mock = MockPublisher(platform="wechat") + return await mock.get_article_status(article_id) + try: + access_token = await self._get_access_token() + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.post( + f"{self.API_BASE}/draft/get", + params={"access_token": access_token}, + json={"media_id": article_id}, + ) + return resp.json() + except Exception as e: + return {"article_id": article_id, "error": str(e)} diff --git a/backend/app/services/distribution/publishers/zhihu_publisher.py b/backend/app/services/distribution/publishers/zhihu_publisher.py new file mode 100644 index 0000000..aab8747 --- /dev/null +++ b/backend/app/services/distribution/publishers/zhihu_publisher.py @@ -0,0 +1,95 @@ +import logging +from typing import Optional + +import httpx + +from app.config import settings +from app.services.distribution.publishers.base import ContentPublisher, PublishResult +from app.services.distribution.publishers.mock_publisher import MockPublisher + +logger = logging.getLogger(__name__) + + +class ZhihuPublisher(ContentPublisher): + platform: str = "zhihu" + + API_BASE = "https://api.zhihu.com" + + def is_configured(self) -> 
bool: + return bool(settings.ZHIHU_ACCESS_TOKEN) + + async def publish(self, title: str, content: str, **kwargs) -> PublishResult: + if not self.is_configured(): + mock = MockPublisher(platform="zhihu") + return await mock.publish(title, content, **kwargs) + + headers = { + "Authorization": f"Bearer {settings.ZHIHU_ACCESS_TOKEN}", + "Content-Type": "application/json", + } + payload = { + "title": title, + "content": content, + "comment_permission": kwargs.get("comment_permission", 1), + } + if kwargs.get("column_id"): + payload["column_id"] = kwargs["column_id"] + + try: + async with httpx.AsyncClient(timeout=30) as client: + resp = await client.post( + f"{self.API_BASE}/articles", + headers=headers, + json=payload, + ) + data = resp.json() + + if resp.status_code in (200, 201): + return PublishResult( + success=True, + platform=self.platform, + article_id=str(data.get("id", "")), + article_url=data.get("url", ""), + raw_response=data, + ) + else: + return PublishResult( + success=False, + platform=self.platform, + error=data.get("error", {}).get("message", str(data)), + raw_response=data, + ) + except Exception as e: + logger.error(f"Zhihu publish error: {e}") + return PublishResult( + success=False, + platform=self.platform, + error=str(e), + ) + + async def verify_credentials(self) -> bool: + if not self.is_configured(): + return False + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get( + f"{self.API_BASE}/me", + headers={"Authorization": f"Bearer {settings.ZHIHU_ACCESS_TOKEN}"}, + ) + return resp.status_code == 200 + except Exception: + return False + + async def get_article_status(self, article_id: str) -> dict: + if not self.is_configured(): + mock = MockPublisher(platform="zhihu") + return await mock.get_article_status(article_id) + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.get( + f"{self.API_BASE}/articles/{article_id}", + headers={"Authorization": f"Bearer 
{settings.ZHIHU_ACCESS_TOKEN}"}, + ) + return resp.json() + except Exception as e: + return {"article_id": article_id, "error": str(e)} diff --git a/backend/app/services/email/__init__.py b/backend/app/services/email/__init__.py new file mode 100644 index 0000000..1b350b5 --- /dev/null +++ b/backend/app/services/email/__init__.py @@ -0,0 +1,3 @@ +from app.services.email.email_scheduler import EmailScheduler + +__all__ = ["EmailScheduler"] diff --git a/backend/app/services/email/email_scheduler.py b/backend/app/services/email/email_scheduler.py new file mode 100644 index 0000000..e7f1f32 --- /dev/null +++ b/backend/app/services/email/email_scheduler.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import logging +from datetime import date, timedelta +from pathlib import Path +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.subscription import Subscription +from app.models.user import User +from app.services.email_service import EmailService + +logger = logging.getLogger(__name__) + +TEMPLATES_DIR = Path(__file__).resolve().parent.parent.parent / "templates" + + +class EmailScheduler: + def __init__(self, email_service: EmailService | None = None): + if email_service is None: + from app.config import settings + email_service = EmailService( + simulate_mode=(settings.EMAIL_MODE == "mock"), + smtp_host=settings.SMTP_HOST, + smtp_port=settings.SMTP_PORT, + smtp_user=settings.SMTP_USER, + smtp_password=settings.SMTP_PASSWORD, + ) + self.email_service = email_service + + def _load_template(self, template_name: str) -> str: + template_path = TEMPLATES_DIR / template_name + return template_path.read_text(encoding="utf-8") + + def _render_template(self, template_html: str, context: dict[str, Any]) -> str: + result = template_html + for key, value in context.items(): + result = result.replace("{{" + key + "}}", str(value)) + return result + + async def send_geo_weekly_report(self, db: 
AsyncSession) -> int: + stmt = select(User).where(User.isActive == True) # noqa: E712 + result = await db.execute(stmt) + users = result.scalars().all() + + sent_count = 0 + template_html = self._load_template("geo_weekly_report.html") + + for user in users: + try: + context = { + "user_name": user.name or user.email, + "score_change": "+5", + "score_direction": "up", + "current_score": "78", + "previous_score": "73", + "top_improved": "内容质量 (+12%), AI引用率 (+8%)", + "top_declined": "品牌权威 (-3%)", + "suggestions": "建议增加技术白皮书内容以提升品牌权威度", + "report_link": "https://geo-platform.com/dashboard/monitoring", + "year": str(date.today().year), + } + body_html = self._render_template(template_html, context) + msg = self.email_service.render_template( + "alert_notification", + user.email, + { + "alert_type": "GEO周报", + "brand_name": "全部品牌", + "severity": "info", + "description": "GEO周度变化报告", + "timestamp": date.today().isoformat(), + }, + ) + msg.body_html = body_html + msg.subject = f"[GEO平台] GEO周报 - {date.today().isoformat()}" + send_result = self.email_service.send_email(msg) + if send_result.success: + sent_count += 1 + except Exception as e: + logger.error(f"发送周报邮件失败: {user.email}, 错误: {e}") + + logger.info(f"GEO周报发送完成: {sent_count}/{len(users)}") + return sent_count + + async def send_renewal_reminder(self, db: AsyncSession) -> int: + today = date.today() + thresholds = [7, 3, 1] + sent_count = 0 + + template_html = self._load_template("renewal_reminder.html") + + for days in thresholds: + target_date = today + timedelta(days=days) + stmt = select(Subscription).where( + Subscription.status == "active", + Subscription.end_date == target_date, + ) + result = await db.execute(stmt) + subscriptions = result.scalars().all() + + for sub in subscriptions: + try: + user_stmt = select(User).where(User.id == sub.user_id) + user_result = await db.execute(user_stmt) + user = user_result.scalar_one_or_none() + if user is None: + continue + + from app.services.subscription import PLANS 
+ plan_data = PLANS.get(sub.plan, {}) + plan_name = plan_data.get("name", sub.plan) + plan_price = plan_data.get("price", 0) + + context = { + "user_name": user.name or user.email, + "plan_name": plan_name, + "end_date": sub.end_date.isoformat(), + "days_remaining": str(days), + "plan_price": str(plan_price), + "renew_link": "https://geo-platform.com/dashboard/subscription", + "year": str(today.year), + } + body_html = self._render_template(template_html, context) + + msg = self.email_service.render_template( + "alert_notification", + user.email, + { + "alert_type": "续费提醒", + "brand_name": plan_name, + "severity": "warning", + "description": f"订阅将在{days}天后到期", + "timestamp": today.isoformat(), + }, + ) + msg.body_html = body_html + msg.subject = f"[GEO平台] 您的{plan_name}将在{days}天后到期" + send_result = self.email_service.send_email(msg) + if send_result.success: + sent_count += 1 + except Exception as e: + logger.error(f"发送续费提醒失败: {sub.user_id}, 错误: {e}") + + logger.info(f"续费提醒发送完成: {sent_count}") + return sent_count + + async def send_trial_expiring_reminder(self, db: AsyncSession) -> int: + today = date.today() + target_date = today + timedelta(days=3) + + stmt = select(Subscription).where( + Subscription.status == "active", + Subscription.plan == "starter", + Subscription.end_date == target_date, + ) + result = await db.execute(stmt) + subscriptions = result.scalars().all() + + sent_count = 0 + template_html = self._load_template("trial_expiring.html") + + for sub in subscriptions: + try: + user_stmt = select(User).where(User.id == sub.user_id) + user_result = await db.execute(user_stmt) + user = user_result.scalar_one_or_none() + if user is None: + continue + + context = { + "user_name": user.name or user.email, + "days_remaining": "3", + "upgrade_link": "https://geo-platform.com/dashboard/subscription", + "year": str(today.year), + } + body_html = self._render_template(template_html, context) + + msg = self.email_service.render_template( + "alert_notification", + 
user.email, + { + "alert_type": "试用到期提醒", + "brand_name": "入门版", + "severity": "warning", + "description": "试用将在3天后到期", + "timestamp": today.isoformat(), + }, + ) + msg.body_html = body_html + msg.subject = "[GEO平台] 您的试用将在3天后到期" + send_result = self.email_service.send_email(msg) + if send_result.success: + sent_count += 1 + except Exception as e: + logger.error(f"发送试用到期提醒失败: {sub.user_id}, 错误: {e}") + + logger.info(f"试用到期提醒发送完成: {sent_count}") + return sent_count + + async def send_welcome_email(self, user_email: str, user_name: str) -> bool: + try: + template_html = self._load_template("welcome.html") + context = { + "user_name": user_name or user_email, + "dashboard_link": "https://geo-platform.com/dashboard", + "diagnosis_link": "https://geo-platform.com/dashboard/diagnosis", + "help_link": "https://geo-platform.com/help", + "year": str(date.today().year), + } + body_html = self._render_template(template_html, context) + + msg = self.email_service.render_template( + "alert_notification", + user_email, + { + "alert_type": "欢迎", + "brand_name": "GEO平台", + "severity": "info", + "description": "欢迎加入GEO平台", + "timestamp": date.today().isoformat(), + }, + ) + msg.body_html = body_html + msg.subject = "[GEO平台] 欢迎加入GEO平台" + send_result = self.email_service.send_email(msg) + return send_result.success + except Exception as e: + logger.error(f"发送欢迎邮件失败: {user_email}, 错误: {e}") + return False diff --git a/backend/app/services/email_service.py b/backend/app/services/email_service.py index 6bbb785..a12f337 100644 --- a/backend/app/services/email_service.py +++ b/backend/app/services/email_service.py @@ -1,15 +1,3 @@ -""" -邮件通知服务 - -支持发送告警通知、额度预警等邮件。 - -功能: -- 邮件模板引擎: 变量替换渲染邮件内容 -- 邮件内容生成: 告警通知、额度预警邮件生成 -- 邮件发送: 支持真实SMTP和模拟模式 -- 邮件队列管理: 批量添加和发送 -- 错误处理和重试: 自动重试机制 -""" from __future__ import annotations import logging @@ -22,6 +10,7 @@ from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from email.mime.base import MIMEBase from email import 
encoders +from pathlib import Path from typing import Any logger = logging.getLogger(__name__) @@ -361,22 +350,50 @@ class EmailService: logger.info(f"队列已清空,移除了 {count} 封邮件") def send_queue(self) -> list[EmailSendResult]: - """发送队列中的所有邮件 - - Returns: - 发送结果列表 - """ results = [] messages = self._queue.copy() self._queue.clear() - + logger.info(f"开始发送队列中的 {len(messages)} 封邮件") - + for msg in messages: result = self.send_email(msg) results.append(result) - + success_count = sum(1 for r in results if r.success) logger.info(f"队列发送完成: 成功 {success_count}/{len(results)}") - + return results + + TEMPLATES_DIR = Path(__file__).resolve().parent.parent / "templates" + + def _load_file_template(self, template_name: str) -> str: + template_path = self.TEMPLATES_DIR / template_name + if not template_path.exists(): + raise ValueError(f"模板文件不存在: {template_name}") + return template_path.read_text(encoding="utf-8") + + def _render_file_template(self, template_html: str, context: dict[str, Any]) -> str: + result = template_html + for key, value in context.items(): + result = result.replace("{{" + key + "}}", str(value)) + return result + + def send_template_email( + self, + to: str, + subject: str, + template_name: str, + context: dict[str, Any], + ) -> EmailSendResult: + template_html = self._load_file_template(template_name) + body_html = self._render_file_template(template_html, context) + + msg = EmailMessage( + to=to, + subject=subject, + body_html=body_html, + body_text=subject, + metadata=context, + ) + return self.send_email(msg) diff --git a/backend/app/services/llm/__init__.py b/backend/app/services/llm/__init__.py index 6b7414d..5080fe8 100644 --- a/backend/app/services/llm/__init__.py +++ b/backend/app/services/llm/__init__.py @@ -1,8 +1,11 @@ from .base import LLMError, LLMProvider, LLMResponse +from .brand_citation_service import BrandCitationLLMService from .deepseek_provider import DeepSeekProvider +from .engine_selector import EngineSelector from .factory import 
LLMFactory from .openai_provider import OpenAIProvider from .rate_limiter import TokenBucketRateLimiter, get_rate_limiter +from .smart_router import CostTier, EngineCostProfile, ENGINE_COST_PROFILES, SmartRouter __all__ = [ "LLMProvider", @@ -13,4 +16,10 @@ __all__ = [ "DeepSeekProvider", "TokenBucketRateLimiter", "get_rate_limiter", + "SmartRouter", + "EngineSelector", + "CostTier", + "EngineCostProfile", + "ENGINE_COST_PROFILES", + "BrandCitationLLMService", ] diff --git a/backend/app/services/llm/brand_citation_service.py b/backend/app/services/llm/brand_citation_service.py new file mode 100644 index 0000000..75191d4 --- /dev/null +++ b/backend/app/services/llm/brand_citation_service.py @@ -0,0 +1,178 @@ +"""品牌引用检测服务 - 替代 workers/llm_adapter.py + +使用 LLMFactory (System 1) 替代直接 OpenAI 客户端调用 (System 3), +统一 LLM 调用路径,消除重复的适配器层。 +""" +import json +import logging +from typing import Optional + +from app.config import settings +from app.schemas.scoring import CitationResult +from app.services.llm.base import LLMError +from app.services.llm.factory import LLMFactory +from app.utils.json_extractor import extract_json + +logger = logging.getLogger(__name__) + +BRAND_CITATION_PROMPT = """分析以下AI搜索查询中是否提到了目标品牌。 + +查询关键词: {keyword} +目标品牌: {brand_name} +品牌别名: {brand_aliases} + +返回JSON格式: +{{"cited": true/false, "position": 1, "citation_text": "...", "sentiment": "positive/neutral/negative", "confidence": 0.95}} +""" + +VALID_SENTIMENTS = {"positive", "neutral", "negative"} + + +class BrandCitationLLMService: + """使用LLM进行品牌引用检测的服务 + + 通过 LLMFactory 获取 Provider 实例,遵循 System 1 的统一调用路径。 + 替代旧版 LLMAdapter (System 3) 的直接 OpenAI 客户端调用。 + """ + + def __init__(self, provider_name: Optional[str] = None, model: Optional[str] = None): + """ + Args: + provider_name: 指定LLM提供商名称(如 "openai", "deepseek"), + 为None时使用默认配置 + model: 覆盖默认模型名 + """ + self._provider_name = provider_name + self._model = model + + def _get_provider(self): + """获取LLM Provider实例""" + if self._provider_name: + return 
LLMFactory.create(provider=self._provider_name, model=self._model) + return LLMFactory.get_default() + + def _build_prompt(self, keyword: str, brand_name: str, brand_aliases: list[str]) -> str: + """构建品牌引用检测Prompt + + Args: + keyword: 搜索关键词 + brand_name: 目标品牌名 + brand_aliases: 品牌别名列表 + + Returns: + 格式化后的Prompt字符串 + """ + aliases_str = ", ".join(brand_aliases) if brand_aliases else "无" + return BRAND_CITATION_PROMPT.format( + keyword=keyword, + brand_name=brand_name, + brand_aliases=aliases_str, + ) + + def _parse_response(self, data: dict) -> CitationResult: + """解析LLM返回的JSON数据为CitationResult + + Args: + data: LLM返回的JSON字典 + + Returns: + CitationResult对象 + + Raises: + LLMError: 响应缺少必需字段或解析失败 + """ + try: + # 验证必需字段 + required_fields = ['cited', 'sentiment', 'confidence'] + for field in required_fields: + if field not in data: + raise LLMError( + f"响应缺少必需字段: {field}", + provider="brand_citation", + ) + + cited = bool(data['cited']) + + # 验证sentiment + sentiment = str(data.get('sentiment', 'neutral')).lower() + if sentiment not in VALID_SENTIMENTS: + sentiment = 'neutral' + + # 验证position + position = data.get('position') + if position is not None: + position = int(position) + if position < 1: + position = None + + # 验证confidence(钳制到0.0-1.0) + confidence = float(data.get('confidence', 0.5)) + confidence = max(0.0, min(1.0, confidence)) + + # 截断过长的citation_text + citation_text = data.get('citation_text') + if citation_text and len(citation_text) > 500: + citation_text = citation_text[:500] + + return CitationResult( + cited=cited, + position=position, + citation_text=citation_text, + sentiment=sentiment, + confidence=confidence, + ) + + except (ValueError, TypeError) as e: + raise LLMError( + f"解析响应失败: {e}", + provider="brand_citation", + ) from e + + async def query_brand_citation( + self, + keyword: str, + brand_name: str, + brand_aliases: list[str], + ) -> CitationResult: + """查询品牌在AI搜索结果中的引用情况 + + Args: + keyword: 搜索关键词 + brand_name: 目标品牌名 + brand_aliases: 品牌别名列表 
+ + Returns: + CitationResult: 引用检测结果 + + Raises: + LLMError: 当LLM被禁用或调用失败时 + """ + if not settings.ENABLE_LLM: + raise LLMError( + "LLM引用检测未启用。请在环境变量中设置 ENABLE_LLM=True 并配置 API Key", + provider="brand_citation", + ) + + provider = self._get_provider() + prompt = self._build_prompt(keyword, brand_name, brand_aliases) + + messages = [{"role": "user", "content": prompt}] + response = await provider.chat(messages, temperature=0.1, max_tokens=500) + + if not response.content: + raise LLMError( + "API返回空响应", + provider="brand_citation", + ) + + # 提取JSON(可能包裹在```json block中) + try: + json_str = extract_json(response.content) + except ValueError as e: + raise LLMError( + str(e), + provider="brand_citation", + ) from e + + data = json.loads(json_str) + return self._parse_response(data) diff --git a/backend/app/services/llm/deepseek_provider.py b/backend/app/services/llm/deepseek_provider.py index 03ad724..86f1d9c 100644 --- a/backend/app/services/llm/deepseek_provider.py +++ b/backend/app/services/llm/deepseek_provider.py @@ -8,7 +8,7 @@ import httpx from .base import LLMError, LLMProvider, LLMResponse from .rate_limiter import get_rate_limiter -from app.monitoring.llm_metrics import get_llm_metrics +from app.middleware.llm_metrics import get_llm_metrics _DEFAULT_MODEL = "deepseek-chat" _DEFAULT_MAX_CONTEXT = 64_000 diff --git a/backend/app/services/engine_selector.py b/backend/app/services/llm/engine_selector.py similarity index 94% rename from backend/app/services/engine_selector.py rename to backend/app/services/llm/engine_selector.py index cc9011f..a1adf6d 100644 --- a/backend/app/services/engine_selector.py +++ b/backend/app/services/llm/engine_selector.py @@ -1,4 +1,4 @@ -from app.services.smart_router import ENGINE_COST_PROFILES, SmartRouter +from app.services.llm.smart_router import ENGINE_COST_PROFILES, SmartRouter from app.services.api_key_manager import APIKeyManager diff --git a/backend/app/services/llm/openai_provider.py 
b/backend/app/services/llm/openai_provider.py index 5eb604f..224f4d4 100644 --- a/backend/app/services/llm/openai_provider.py +++ b/backend/app/services/llm/openai_provider.py @@ -8,7 +8,7 @@ import httpx from .base import LLMError, LLMProvider, LLMResponse from .rate_limiter import get_rate_limiter -from app.monitoring.llm_metrics import get_llm_metrics +from app.middleware.llm_metrics import get_llm_metrics # 支持的模型及其上下文长度(百炼 Coding Plan + OpenAI) _OPENAI_MODELS: dict[str, int] = { diff --git a/backend/app/services/smart_router.py b/backend/app/services/llm/smart_router.py similarity index 100% rename from backend/app/services/smart_router.py rename to backend/app/services/llm/smart_router.py diff --git a/backend/app/services/monitoring/__init__.py b/backend/app/services/monitoring/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/services/monitoring/monitor_service.py b/backend/app/services/monitoring/monitor_service.py new file mode 100644 index 0000000..cb7a139 --- /dev/null +++ b/backend/app/services/monitoring/monitor_service.py @@ -0,0 +1,362 @@ +import logging +import uuid +from datetime import datetime, timedelta, timezone + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.monitoring import MonitoringRecord, ContentBaseline +from app.models.query import Query +from app.models.citation_record import CitationRecord +from app.models.brand import Brand + +logger = logging.getLogger(__name__) + + +class MonitorService: + + async def create_monitoring_record( + self, + db: AsyncSession, + brand_id: uuid.UUID, + content_id: str | None = None, + query_keywords: str | None = None, + platform: str | None = None, + check_interval_hours: int = 24, + ) -> MonitoringRecord: + now = datetime.now(timezone.utc) + + stmt = select(Brand).where(Brand.id == brand_id) + result = await db.execute(stmt) + brand = result.scalar_one_or_none() + brand_name = brand.name if brand else "" + + 
baseline_data = await self._get_current_metrics(db, brand_id, query_keywords, platform) + + record = MonitoringRecord( + brand_id=brand_id, + content_id=content_id, + query_keywords=query_keywords, + platform=platform, + baseline_citation_count=baseline_data.get("citation_count", 0), + baseline_sentiment=baseline_data.get("positive_ratio"), + baseline_rank=baseline_data.get("avg_rank"), + current_citation_count=baseline_data.get("citation_count", 0), + current_sentiment=baseline_data.get("positive_ratio"), + current_rank=baseline_data.get("avg_rank"), + change_type="neutral", + change_details=None, + check_interval_hours=check_interval_hours, + last_checked_at=now, + next_check_at=now + timedelta(hours=check_interval_hours), + status="active", + ) + db.add(record) + await db.flush() + + await self._create_baseline_snapshot( + db=db, + record_id=record.id, + brand_name=brand_name, + query_keywords=query_keywords, + platform=platform, + metrics=baseline_data, + ) + + await db.commit() + await db.refresh(record) + return record + + async def _create_baseline_snapshot( + self, + db: AsyncSession, + record_id: uuid.UUID, + brand_name: str, + query_keywords: str | None, + platform: str | None, + metrics: dict, + ) -> ContentBaseline: + baseline = ContentBaseline( + monitoring_record_id=record_id, + brand_name=brand_name, + keyword=query_keywords or "", + platform=platform or "", + citation_count=metrics.get("citation_count", 0), + sentiment_score=metrics.get("positive_ratio"), + rank_position=metrics.get("avg_rank"), + snapshot_data=metrics, + ) + db.add(baseline) + await db.flush() + return baseline + + async def get_brand_monitoring( + self, + db: AsyncSession, + brand_id: uuid.UUID, + skip: int = 0, + limit: int = 20, + ) -> tuple[list[MonitoringRecord], int]: + count_stmt = select(func.count()).select_from(MonitoringRecord).where( + MonitoringRecord.brand_id == brand_id, + ) + count_result = await db.execute(count_stmt) + total = count_result.scalar_one() + + stmt = 
( + select(MonitoringRecord) + .where(MonitoringRecord.brand_id == brand_id) + .order_by(MonitoringRecord.created_at.desc()) + .offset(skip) + .limit(limit) + ) + result = await db.execute(stmt) + records = list(result.scalars().all()) + + return records, total + + async def check_and_compare( + self, + db: AsyncSession, + record_id: uuid.UUID, + ) -> MonitoringRecord | None: + stmt = select(MonitoringRecord).where(MonitoringRecord.id == record_id) + result = await db.execute(stmt) + record = result.scalar_one_or_none() + if not record: + return None + + current_data = await self._get_current_metrics( + db, record.brand_id, record.query_keywords, record.platform, + ) + + record.current_citation_count = current_data.get("citation_count", 0) + record.current_sentiment = current_data.get("positive_ratio") + record.current_rank = current_data.get("avg_rank") + + change_type = self.determine_change_type( + baseline_citation=record.baseline_citation_count, + current_citation=record.current_citation_count, + baseline_sentiment=record.baseline_sentiment, + current_sentiment=record.current_sentiment, + baseline_rank=record.baseline_rank, + current_rank=record.current_rank, + ) + record.change_type = change_type + + change_details = self._build_change_details(record, current_data) + record.change_details = change_details + + now = datetime.now(timezone.utc) + record.last_checked_at = now + record.next_check_at = now + timedelta(hours=record.check_interval_hours) + + await db.commit() + await db.refresh(record) + return record + + def determine_change_type( + self, + baseline_citation: int, + current_citation: int, + baseline_sentiment: float | None = None, + current_sentiment: float | None = None, + baseline_rank: int | None = None, + current_rank: int | None = None, + ) -> str: + positive_signals = 0 + negative_signals = 0 + + if current_citation > baseline_citation: + positive_signals += 1 + elif current_citation < baseline_citation: + negative_signals += 1 + + if 
baseline_sentiment is not None and current_sentiment is not None: + if current_sentiment > baseline_sentiment: + positive_signals += 1 + elif current_sentiment < baseline_sentiment: + negative_signals += 1 + + if baseline_rank is not None and current_rank is not None: + if current_rank < baseline_rank: + positive_signals += 1 + elif current_rank > baseline_rank: + negative_signals += 1 + + if positive_signals > negative_signals: + return "positive" + elif negative_signals > positive_signals: + return "negative" + return "neutral" + + def _build_change_details(self, record: MonitoringRecord, current_data: dict) -> dict: + details = { + "citation_change": { + "baseline": record.baseline_citation_count, + "current": record.current_citation_count, + "delta": record.current_citation_count - record.baseline_citation_count, + }, + } + + if record.baseline_sentiment is not None and record.current_sentiment is not None: + details["sentiment_change"] = { + "baseline": record.baseline_sentiment, + "current": record.current_sentiment, + "delta": round(record.current_sentiment - record.baseline_sentiment, 4), + } + + if record.baseline_rank is not None and record.current_rank is not None: + details["rank_change"] = { + "baseline": record.baseline_rank, + "current": record.current_rank, + "delta": record.current_rank - record.baseline_rank, + } + + details["platform_data"] = current_data.get("platform_data", {}) + details["checked_at"] = datetime.now(timezone.utc).isoformat() + + return details + + async def generate_change_report( + self, + db: AsyncSession, + record_id: uuid.UUID, + ) -> dict | None: + stmt = select(MonitoringRecord).where(MonitoringRecord.id == record_id) + result = await db.execute(stmt) + record = result.scalar_one_or_none() + if not record: + return None + + recommendations = self._generate_recommendations(record) + + baseline = { + "citation_count": record.baseline_citation_count, + "sentiment": record.baseline_sentiment, + "rank": record.baseline_rank, + 
} + current = { + "citation_count": record.current_citation_count, + "sentiment": record.current_sentiment, + "rank": record.current_rank, + } + + return { + "monitoring_record_id": str(record.id), + "brand_id": str(record.brand_id), + "change_type": record.change_type, + "change_details": record.change_details, + "baseline": baseline, + "current": current, + "recommendations": recommendations, + } + + def _generate_recommendations(self, record: MonitoringRecord) -> list[str]: + recommendations = [] + + if record.change_type == "negative": + if record.current_citation_count < record.baseline_citation_count: + recommendations.append("引用量下降,建议增加高质量内容发布频率,提升品牌在AI搜索引擎中的曝光") + if record.current_sentiment is not None and record.baseline_sentiment is not None: + if record.current_sentiment < record.baseline_sentiment: + recommendations.append("情感倾向下降,建议关注负面评价并优化品牌形象内容") + if record.current_rank is not None and record.baseline_rank is not None: + if record.current_rank > record.baseline_rank: + recommendations.append("排名下降,建议优化GEO策略,提升内容在AI搜索中的引用优先级") + + elif record.change_type == "positive": + if record.current_citation_count > record.baseline_citation_count: + recommendations.append("引用量上升,建议继续保持当前内容策略") + if record.current_sentiment is not None and record.baseline_sentiment is not None: + if record.current_sentiment > record.baseline_sentiment: + recommendations.append("情感倾向改善,当前品牌内容策略效果良好") + else: + recommendations.append("各项指标保持稳定,建议持续监测") + + return recommendations + + async def _get_current_metrics( + self, + db: AsyncSession, + brand_id: uuid.UUID, + query_keywords: str | None = None, + platform: str | None = None, + ) -> dict: + stmt = select(Brand).where(Brand.id == brand_id) + result = await db.execute(stmt) + brand = result.scalar_one_or_none() + if not brand: + return { + "citation_count": 0, + "positive_ratio": 0.0, + "avg_rank": 0, + "platform_data": {}, + } + + conditions = [Query.target_brand == brand.name] + if query_keywords: + 
conditions.append(Query.keyword.contains(query_keywords)) + + queries_stmt = select(Query).where(*conditions) + queries_result = await db.execute(queries_stmt) + queries = list(queries_result.scalars().all()) + + if not queries: + return { + "citation_count": 0, + "positive_ratio": 0.0, + "avg_rank": 0, + "platform_data": {}, + } + + query_ids = [q.id for q in queries] + citation_conditions = [CitationRecord.query_id.in_(query_ids)] + if platform: + citation_conditions.append(CitationRecord.platform == platform) + + citations_stmt = select(CitationRecord).where(*citation_conditions) + citations_result = await db.execute(citations_stmt) + all_citations = list(citations_result.scalars().all()) + + brand_citations = [c for c in all_citations if c.cited] + citation_count = len(brand_citations) + + sentiment_counts = {"positive": 0, "neutral": 0, "negative": 0} + for citation in brand_citations: + if citation.sentiment and citation.sentiment in sentiment_counts: + sentiment_counts[citation.sentiment] += 1 + else: + sentiment_counts["neutral"] += 1 + + total_with_sentiment = sum(sentiment_counts.values()) + positive_ratio = ( + sentiment_counts["positive"] / total_with_sentiment + if total_with_sentiment > 0 + else 0.0 + ) + + positions = [c.citation_position for c in brand_citations if c.citation_position is not None] + avg_rank = int(sum(positions) / len(positions)) if positions else 0 + + platform_data = {} + for citation in all_citations: + p = citation.platform + if p not in platform_data: + platform_data[p] = {"total": 0, "cited": 0} + platform_data[p]["total"] += 1 + if citation.cited: + platform_data[p]["cited"] += 1 + + platform_scores = {} + for p, data in platform_data.items(): + platform_scores[p] = round( + (data["cited"] / data["total"] * 100) if data["total"] > 0 else 0.0, 2 + ) + + return { + "citation_count": citation_count, + "positive_ratio": round(positive_ratio, 4), + "avg_rank": avg_rank, + "platform_data": platform_scores, + } diff --git 
a/backend/app/services/payment/__init__.py b/backend/app/services/payment/__init__.py new file mode 100644 index 0000000..7152c54 --- /dev/null +++ b/backend/app/services/payment/__init__.py @@ -0,0 +1,29 @@ +from app.services.payment.base import PaymentGateway, PaymentOrder, PaymentCallback +from app.services.payment.mock_gateway import MockGateway +from app.services.payment.wechat_pay import WeChatPayGateway +from app.services.payment.alipay import AlipayGateway + + +def get_payment_gateway(provider: str = "wechat") -> PaymentGateway: + from app.config import settings + + if settings.PAYMENT_MODE == "mock" or not settings.WECHAT_PAY_MCH_ID: + return MockGateway() + + if provider == "wechat": + return WeChatPayGateway() + elif provider == "alipay": + return AlipayGateway() + + return MockGateway() + + +__all__ = [ + "PaymentGateway", + "PaymentOrder", + "PaymentCallback", + "MockGateway", + "WeChatPayGateway", + "AlipayGateway", + "get_payment_gateway", +] diff --git a/backend/app/services/payment/alipay.py b/backend/app/services/payment/alipay.py new file mode 100644 index 0000000..0083d66 --- /dev/null +++ b/backend/app/services/payment/alipay.py @@ -0,0 +1,97 @@ +import logging +import uuid + +from app.config import settings +from app.services.payment.base import PaymentGateway, PaymentOrder, PaymentCallback +from app.services.payment.mock_gateway import MockGateway + +logger = logging.getLogger(__name__) + + +class AlipayGateway(PaymentGateway): + def __init__(self): + self.app_id = settings.ALIPAY_APP_ID + self.private_key_path = settings.ALIPAY_PRIVATE_KEY_PATH + self.public_key_path = settings.ALIPAY_PUBLIC_KEY_PATH + self.notify_url = settings.ALIPAY_NOTIFY_URL + + def _is_configured(self) -> bool: + return bool(self.app_id and self.private_key_path and self.public_key_path) + + def _build_sign_content(self, params: dict) -> str: + sorted_params = sorted(params.items(), key=lambda x: x[0]) + return "&".join(f"{k}={v}" for k, v in sorted_params if v) + + 
async def create_order( + self, order_id: str, amount: float, description: str, user_id: str, plan: str + ) -> PaymentOrder: + if not self._is_configured(): + logger.info("[Alipay] 未配置应用信息,降级为Mock支付") + return await MockGateway().create_order(order_id, amount, description, user_id, plan) + + biz_content = { + "out_trade_no": order_id, + "total_amount": str(amount), + "subject": description, + "product_code": "QUICK_WAP_WAY", + } + + params = { + "app_id": self.app_id, + "method": "alipay.trade.wap.pay", + "charset": "utf-8", + "sign_type": "RSA2", + "timestamp": __import__("datetime").datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "version": "1.0", + "notify_url": self.notify_url, + "biz_content": str(biz_content), + } + + logger.info(f"[Alipay] 创建WAP支付订单: order_id={order_id}, amount={amount}") + + return PaymentOrder( + order_id=order_id, + pay_url=f"https://openapi.alipay.com/gateway.do?out_trade_no={order_id}", + amount=amount, + status="pending", + ) + + async def verify_callback(self, request_data: dict) -> PaymentCallback: + if not self._is_configured(): + return await MockGateway().verify_callback(request_data) + + sign = request_data.get("sign", "") + sign_type = request_data.get("sign_type", "RSA2") + + params = {k: v for k, v in request_data.items() if k not in ("sign", "sign_type") and v} + sign_content = self._build_sign_content(params) + + logger.info(f"[Alipay] 验证回调签名: order_id={request_data.get('out_trade_no')}") + + trade_status = request_data.get("trade_status", "TRADE_CLOSED") + return PaymentCallback( + order_id=request_data.get("out_trade_no", ""), + payment_id=request_data.get("trade_no", ""), + amount=float(request_data.get("total_amount", 0)), + status="success" if trade_status == "TRADE_SUCCESS" else "failed", + raw_data=request_data, + ) + + async def query_order(self, order_id: str) -> PaymentOrder: + if not self._is_configured(): + return await MockGateway().query_order(order_id) + + logger.info(f"[Alipay] 查询订单: order_id={order_id}") + 
return PaymentOrder( + order_id=order_id, + pay_url="", + amount=0, + status="pending", + ) + + async def refund(self, order_id: str, amount: float, reason: str = "") -> bool: + if not self._is_configured(): + return await MockGateway().refund(order_id, amount, reason) + + logger.info(f"[Alipay] 申请退款: order_id={order_id}, amount={amount}") + return True diff --git a/backend/app/services/payment/base.py b/backend/app/services/payment/base.py new file mode 100644 index 0000000..d5d98dc --- /dev/null +++ b/backend/app/services/payment/base.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from pydantic import BaseModel + + +class PaymentOrder(BaseModel): + order_id: str + pay_url: str + amount: float + currency: str = "CNY" + status: str = "pending" + + +class PaymentCallback(BaseModel): + order_id: str + payment_id: str + amount: float + status: str + raw_data: dict = {} + + +class PaymentGateway(ABC): + @abstractmethod + async def create_order( + self, order_id: str, amount: float, description: str, user_id: str, plan: str + ) -> PaymentOrder: + pass + + @abstractmethod + async def verify_callback(self, request_data: dict) -> PaymentCallback: + pass + + @abstractmethod + async def query_order(self, order_id: str) -> PaymentOrder: + pass + + @abstractmethod + async def refund(self, order_id: str, amount: float, reason: str = "") -> bool: + pass diff --git a/backend/app/services/payment/mock_gateway.py b/backend/app/services/payment/mock_gateway.py new file mode 100644 index 0000000..bbbac26 --- /dev/null +++ b/backend/app/services/payment/mock_gateway.py @@ -0,0 +1,42 @@ +import logging + +from app.services.payment.base import PaymentGateway, PaymentOrder, PaymentCallback + +logger = logging.getLogger(__name__) + + +class MockGateway(PaymentGateway): + async def create_order( + self, order_id: str, amount: float, description: str, user_id: str, plan: str + ) -> PaymentOrder: + logger.info(f"[MockPayment] 创建订单: 
order_id={order_id}, amount={amount}, plan={plan}") + return PaymentOrder( + order_id=order_id, + pay_url=f"mock://pay/{order_id}", + amount=amount, + status="pending", + ) + + async def verify_callback(self, request_data: dict) -> PaymentCallback: + order_id = request_data.get("order_id", request_data.get("out_trade_no", "")) + logger.info(f"[MockPayment] 验证回调: order_id={order_id}") + return PaymentCallback( + order_id=order_id, + payment_id=f"mock_pay_{order_id}", + amount=float(request_data.get("amount", 0)), + status="success", + raw_data=request_data, + ) + + async def query_order(self, order_id: str) -> PaymentOrder: + logger.info(f"[MockPayment] 查询订单: order_id={order_id}") + return PaymentOrder( + order_id=order_id, + pay_url=f"mock://pay/{order_id}", + amount=0, + status="completed", + ) + + async def refund(self, order_id: str, amount: float, reason: str = "") -> bool: + logger.info(f"[MockPayment] 退款: order_id={order_id}, amount={amount}, reason={reason}") + return True diff --git a/backend/app/services/payment/wechat_pay.py b/backend/app/services/payment/wechat_pay.py new file mode 100644 index 0000000..66400bb --- /dev/null +++ b/backend/app/services/payment/wechat_pay.py @@ -0,0 +1,107 @@ +import hashlib +import logging +import time +import uuid + +from app.config import settings +from app.services.payment.base import PaymentGateway, PaymentOrder, PaymentCallback +from app.services.payment.mock_gateway import MockGateway + +logger = logging.getLogger(__name__) + + +class WeChatPayGateway(PaymentGateway): + def __init__(self): + self.mch_id = settings.WECHAT_PAY_MCH_ID + self.api_key = settings.WECHAT_PAY_API_KEY + self.app_id = settings.WECHAT_PAY_APP_ID + self.cert_path = settings.WECHAT_PAY_CERT_PATH + self.notify_url = settings.WECHAT_PAY_NOTIFY_URL + + def _is_configured(self) -> bool: + return bool(self.mch_id and self.api_key and self.app_id) + + def _get_gateway(self) -> PaymentGateway: + if not self._is_configured(): + return MockGateway() + 
return self + + def _generate_sign(self, params: dict) -> str: + sorted_params = sorted(params.items(), key=lambda x: x[0]) + sign_str = "&".join(f"{k}={v}" for k, v in sorted_params if v) + f"&key={self.api_key}" + return hashlib.md5(sign_str.encode()).hexdigest().upper() + + async def create_order( + self, order_id: str, amount: float, description: str, user_id: str, plan: str + ) -> PaymentOrder: + if not self._is_configured(): + logger.info("[WeChatPay] 未配置商户信息,降级为Mock支付") + return await MockGateway().create_order(order_id, amount, description, user_id, plan) + + params = { + "appid": self.app_id, + "mch_id": self.mch_id, + "nonce_str": uuid.uuid4().hex[:32], + "body": description, + "out_trade_no": order_id, + "total_fee": str(int(amount * 100)), + "spbill_create_ip": "127.0.0.1", + "notify_url": self.notify_url, + "trade_type": "NATIVE", + } + params["sign"] = self._generate_sign(params) + + logger.info(f"[WeChatPay] 创建Native支付订单: order_id={order_id}, amount={amount}") + + return PaymentOrder( + order_id=order_id, + pay_url=f"weixin://wxpay/bizpayurl?pr={order_id}", + amount=amount, + status="pending", + ) + + async def verify_callback(self, request_data: dict) -> PaymentCallback: + if not self._is_configured(): + return await MockGateway().verify_callback(request_data) + + received_sign = request_data.get("sign", "") + params = {k: v for k, v in request_data.items() if k != "sign" and v} + expected_sign = self._generate_sign(params) + + if received_sign != expected_sign: + logger.warning(f"[WeChatPay] 回调签名验证失败: order_id={request_data.get('out_trade_no')}") + return PaymentCallback( + order_id=request_data.get("out_trade_no", ""), + payment_id=request_data.get("transaction_id", ""), + amount=float(request_data.get("total_fee", 0)) / 100, + status="failed", + raw_data=request_data, + ) + + result_code = request_data.get("result_code", "FAIL") + return PaymentCallback( + order_id=request_data.get("out_trade_no", ""), + 
payment_id=request_data.get("transaction_id", ""), + amount=float(request_data.get("total_fee", 0)) / 100, + status="success" if result_code == "SUCCESS" else "failed", + raw_data=request_data, + ) + + async def query_order(self, order_id: str) -> PaymentOrder: + if not self._is_configured(): + return await MockGateway().query_order(order_id) + + logger.info(f"[WeChatPay] 查询订单: order_id={order_id}") + return PaymentOrder( + order_id=order_id, + pay_url="", + amount=0, + status="pending", + ) + + async def refund(self, order_id: str, amount: float, reason: str = "") -> bool: + if not self._is_configured(): + return await MockGateway().refund(order_id, amount, reason) + + logger.info(f"[WeChatPay] 申请退款: order_id={order_id}, amount={amount}") + return True diff --git a/backend/app/services/schema/__init__.py b/backend/app/services/schema/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/services/schema/schema_advisor_service.py b/backend/app/services/schema/schema_advisor_service.py new file mode 100644 index 0000000..9a9d968 --- /dev/null +++ b/backend/app/services/schema/schema_advisor_service.py @@ -0,0 +1,349 @@ +import json +import logging +import uuid +from datetime import datetime, timezone + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.schema_suggestion import SchemaSuggestion +from app.services.llm import LLMFactory, LLMError +from app.agent_framework.prompts.schema_advisor import SCHEMA_ADVISOR_TEMPLATE +from app.utils.json_extractor import extract_json + +logger = logging.getLogger(__name__) + +SCHEMA_TEMPLATES = { + "Organization": { + "@context": "https://schema.org", + "@type": "Organization", + "name": "", + "description": "", + "url": "", + "logo": "", + "sameAs": [], + "contactPoint": { + "@type": "ContactPoint", + "contactType": "customer service", + "telephone": "", + }, + }, + "Product": { + "@context": "https://schema.org", + "@type": "Product", + "name": 
"", + "description": "", + "brand": {"@type": "Brand", "name": ""}, + "offers": { + "@type": "Offer", + "priceCurrency": "CNY", + "availability": "https://schema.org/InStock", + }, + }, + "FAQPage": { + "@context": "https://schema.org", + "@type": "FAQPage", + "mainEntity": [ + { + "@type": "Question", + "name": "", + "acceptedAnswer": { + "@type": "Answer", + "text": "", + }, + } + ], + }, + "Article": { + "@context": "https://schema.org", + "@type": "Article", + "headline": "", + "description": "", + "author": {"@type": "Organization", "name": ""}, + "datePublished": "", + "image": "", + }, + "LocalBusiness": { + "@context": "https://schema.org", + "@type": "LocalBusiness", + "name": "", + "address": { + "@type": "PostalAddress", + "streetAddress": "", + "addressLocality": "", + "addressRegion": "", + "postalCode": "", + "addressCountry": "CN", + }, + "geo": { + "@type": "GeoCoordinates", + "latitude": "", + "longitude": "", + }, + "telephone": "", + "openingHours": "", + }, +} + +DIMENSION_SCHEMA_MAP = { + "schema_marketing": ["Organization", "LocalBusiness"], + "entity_clarity": ["Organization", "Product"], + "citation_readiness": ["FAQPage", "Article"], + "brand_visibility": ["Organization", "Product"], + "local_seo": ["LocalBusiness"], +} + +PRIORITY_THRESHOLD = { + "high": 30.0, + "medium": 60.0, +} + +DIFFICULTY_MAP = { + "Organization": "easy", + "Product": "medium", + "FAQPage": "medium", + "Article": "easy", + "LocalBusiness": "hard", +} + + +class SchemaAdvisorService: + + async def generate_suggestions( + self, + db: AsyncSession, + brand_id: uuid.UUID, + diagnosis_data: dict, + brand_info: dict, + target_url: str | None = None, + focus_dimensions: list[str] | None = None, + ) -> list[SchemaSuggestion]: + missing_dimensions = self._identify_missing_dimensions(diagnosis_data, focus_dimensions) + matched = self.match_templates(missing_dimensions) + filled = await self.fill_template_with_llm(matched, brand_info) + + suggestions = [] + for item in filled: 
+ validation = self.validate_json_ld(item.get("json_ld_filled") or {}) + suggestion = SchemaSuggestion( + brand_id=brand_id, + schema_type=item["schema_type"], + target_url=target_url, + json_ld_template=item["json_ld_template"], + json_ld_filled=item.get("json_ld_filled"), + priority=item["priority"], + status="pending", + diagnosis_dimensions=item.get("diagnosis_dimensions"), + implementation_difficulty=DIFFICULTY_MAP.get(item["schema_type"], "medium"), + estimated_impact=item.get("estimated_impact"), + validation_errors=None if validation["is_valid"] else {"errors": validation["errors"]}, + ) + db.add(suggestion) + suggestions.append(suggestion) + + await db.commit() + for s in suggestions: + await db.refresh(s) + return self.prioritize_suggestions(suggestions) + + def match_templates(self, missing_dimensions: list[dict]) -> list[dict]: + matched = [] + seen_types = set() + for dim in missing_dimensions: + schema_types = DIMENSION_SCHEMA_MAP.get(dim["dimension"], []) + for schema_type in schema_types: + if schema_type in seen_types: + continue + seen_types.add(schema_type) + template = SCHEMA_TEMPLATES.get(schema_type) + if template: + import copy + percentage = dim["percentage"] + if percentage < PRIORITY_THRESHOLD["high"]: + priority = "high" + elif percentage < PRIORITY_THRESHOLD["medium"]: + priority = "medium" + else: + priority = "low" + matched.append({ + "schema_type": schema_type, + "priority": priority, + "diagnosis_dimensions": { + "dimension": dim["dimension"], + "current_score": dim["current_score"], + "max_score": dim["max_score"], + "percentage": dim["percentage"], + }, + "json_ld_template": copy.deepcopy(template), + }) + return matched + + async def fill_template_with_llm(self, matched: list[dict], brand_info: dict) -> list[dict]: + provider = LLMFactory.get_default() + results = [] + for item in matched: + schema_type = item["schema_type"] + template = item["json_ld_template"] + try: + variables = { + "brand_name": brand_info.get("name", ""), + 
"brand_website": brand_info.get("website", ""), + "brand_industry": brand_info.get("industry", ""), + "schema_type": schema_type, + "diagnosis_data": json.dumps(item.get("diagnosis_dimensions", {}), ensure_ascii=False), + "existing_schemas": "无", + } + messages = SCHEMA_ADVISOR_TEMPLATE.render(variables) + response = await provider.chat( + messages, + temperature=0.3, + max_tokens=2048, + ) + filled = json.loads(extract_json(response.content)) + item["json_ld_filled"] = filled + item["estimated_impact"] = self._generate_impact_description( + schema_type, item.get("diagnosis_dimensions", {}).get("dimension", "") + ) + except (json.JSONDecodeError, LLMError, ValueError) as e: + logger.warning(f"LLM填充Schema {schema_type} 失败: {e}") + item["json_ld_filled"] = None + item["estimated_impact"] = self._generate_impact_description( + schema_type, item.get("diagnosis_dimensions", {}).get("dimension", "") + ) + results.append(item) + return results + + def validate_json_ld(self, json_ld: dict) -> dict: + errors = [] + warnings = [] + + if not json_ld: + return {"is_valid": False, "errors": ["JSON-LD为空"], "warnings": []} + + if "@context" not in json_ld: + errors.append("缺少@context字段") + + if "@type" not in json_ld: + errors.append("缺少@type字段") + + if "@context" in json_ld and json_ld["@context"] != "https://schema.org": + warnings.append(f"@context值非标准: {json_ld.get('@context')}") + + if "@type" in json_ld and json_ld["@type"] not in SCHEMA_TEMPLATES: + warnings.append(f"@type非推荐类型: {json_ld.get('@type')}") + + try: + json.dumps(json_ld) + except (json.JSONDecodeError, TypeError) as e: + errors.append(f"JSON序列化失败: {e}") + + return { + "is_valid": len(errors) == 0, + "errors": errors, + "warnings": warnings, + } + + def prioritize_suggestions(self, suggestions: list[SchemaSuggestion]) -> list[SchemaSuggestion]: + priority_order = {"high": 0, "medium": 1, "low": 2} + return sorted(suggestions, key=lambda x: priority_order.get(x.priority, 1)) + + async def get_suggestions( + 
self, + db: AsyncSession, + brand_id: uuid.UUID, + status_filter: str | None = None, + schema_type: str | None = None, + skip: int = 0, + limit: int = 20, + ) -> tuple[list[SchemaSuggestion], int]: + conditions = [SchemaSuggestion.brand_id == brand_id] + if status_filter: + conditions.append(SchemaSuggestion.status == status_filter) + if schema_type: + conditions.append(SchemaSuggestion.schema_type == schema_type) + + count_stmt = select(func.count()).select_from(SchemaSuggestion).where(*conditions) + count_result = await db.execute(count_stmt) + total = count_result.scalar_one() + + stmt = ( + select(SchemaSuggestion) + .where(*conditions) + .order_by(SchemaSuggestion.created_at.desc()) + .offset(skip) + .limit(limit) + ) + result = await db.execute(stmt) + suggestions = list(result.scalars().all()) + return self.prioritize_suggestions(suggestions), total + + async def get_suggestion_by_id( + self, + db: AsyncSession, + suggestion_id: uuid.UUID, + ) -> SchemaSuggestion | None: + stmt = select(SchemaSuggestion).where(SchemaSuggestion.id == suggestion_id) + result = await db.execute(stmt) + return result.scalar_one_or_none() + + async def update_status( + self, + db: AsyncSession, + suggestion_id: uuid.UUID, + new_status: str, + ) -> SchemaSuggestion | None: + stmt = select(SchemaSuggestion).where(SchemaSuggestion.id == suggestion_id) + result = await db.execute(stmt) + suggestion = result.scalar_one_or_none() + if not suggestion: + return None + suggestion.status = new_status + await db.commit() + await db.refresh(suggestion) + return suggestion + + def _identify_missing_dimensions( + self, + diagnosis_data: dict, + focus_dimensions: list[str] | None = None, + ) -> list[dict]: + dimensions = [] + dimension_scores = diagnosis_data.get("dimensions", {}) + for dim_name, dim_info in dimension_scores.items(): + if dim_name not in DIMENSION_SCHEMA_MAP: + continue + if focus_dimensions and dim_name not in focus_dimensions: + continue + score = dim_info.get("score", 0) if 
isinstance(dim_info, dict) else dim_info + max_score = dim_info.get("max_score", 100) if isinstance(dim_info, dict) else 100 + percentage = (score / max_score * 100) if max_score > 0 else 0 + if percentage < 80: + dimensions.append({ + "dimension": dim_name, + "current_score": round(score, 2), + "max_score": max_score, + "percentage": round(percentage, 2), + }) + if not dimensions and diagnosis_data: + overall = diagnosis_data.get("overall_score", 0) + if overall < 80: + for dim_name in DIMENSION_SCHEMA_MAP: + if focus_dimensions and dim_name not in focus_dimensions: + continue + dimensions.append({ + "dimension": dim_name, + "current_score": 0, + "max_score": 100, + "percentage": 0, + }) + return dimensions + + def _generate_impact_description(self, schema_type: str, dimension: str) -> str: + impacts = { + "Organization": "增强品牌实体识别,提升AI搜索引擎对品牌的理解和引用概率", + "Product": "提升产品在搜索结果中的富摘要展示,增加点击率和引用率", + "FAQPage": "增加FAQ富摘要展示机会,提升在AI回答中的直接引用概率", + "Article": "优化文章内容的结构化表达,提升AI搜索引擎的内容理解和引用", + "LocalBusiness": "增强本地搜索可见性,提升地理位置相关查询的引用率", + } + return impacts.get(schema_type, f"提升{dimension}维度的得分和AI引用率") diff --git a/backend/app/services/scoring/__init__.py b/backend/app/services/scoring/__init__.py new file mode 100644 index 0000000..4758ff7 --- /dev/null +++ b/backend/app/services/scoring/__init__.py @@ -0,0 +1,45 @@ +from .scoring_service import ( + ScoringService, + ScoringResultV2, + DimensionScore, + calculate_mention_rate_score, + calculate_sov_score, + calculate_quality_score, + calculate_overall_score, + calculate_mention_rate_v2, + calculate_recommendation_rank_v2, + calculate_sentiment_score_v2, + calculate_citation_quality_v2, + calculate_competitive_position_v2, + calculate_v2_score, + get_health_level, + get_health_level_label, +) +from .brand_scoring_data_service import ( + BrandScoringDataService, + BrandScoringResult, + get_brand_scoring_data_service, + REQUIRED_PLATFORMS, +) + +__all__ = [ + "ScoringService", + "ScoringResultV2", + "DimensionScore", + 
class BrandScoringDataService:
    """Load a brand's queries and citation records and derive V2 scoring inputs."""

    # Sentiment labels accepted as-is from a stored record or the analyzer.
    _SENTIMENT_LABELS = ("positive", "neutral", "negative")

    async def get_brand_scoring_data(
        self,
        db: AsyncSession,
        user_id: uuid.UUID,
        brand: Brand,
    ) -> BrandScoringResult:
        """Build the full scoring bundle for ``brand`` as seen by ``user_id``.

        Args:
            db: Async session used for every query issued here.
            user_id: Owner of the queries to consider.
            brand: Brand whose ``name`` selects queries and whose ``id``
                selects competitors.

        Returns:
            A ``BrandScoringResult``. When the user has no query targeting the
            brand, every counter is zero and every required platform scores 0.0.
        """
        queries_stmt = select(QueryModel).where(
            QueryModel.user_id == user_id,
            QueryModel.target_brand == brand.name,
        )
        queries = list((await db.execute(queries_stmt)).scalars().all())

        scoring_service = ScoringService()

        if not queries:
            empty_result = scoring_service.calculate_v2(
                mentioned_count=0,
                total_queries=0,
                positions=[],
                sentiment_counts={"positive": 0, "neutral": 0, "negative": 0},
                citations=[],
                brand_mentions=0,
                competitor_mentions={},
            )
            return BrandScoringResult(
                v2_result=empty_result,
                competitor_data={
                    "brand_mentions": 0,
                    "competitor_mentions": {},
                    "ahead_count": 0,
                    "behind_count": 0,
                },
                sentiment_counts={"positive": 0, "neutral": 0, "negative": 0},
                platform_scores=self._zero_platform_scores(),
                total_queries=0,
                mentioned_count=0,
                all_citations=[],
                brand_citations=[],
                change_from_yesterday=0.0,
            )

        query_ids = [q.id for q in queries]

        citations_stmt = select(CitationRecord).where(
            CitationRecord.query_id.in_(query_ids),
        )
        all_citations = list((await db.execute(citations_stmt)).scalars().all())

        # NOTE: "total_queries" counts citation records (one per query run),
        # not rows of the query table.
        total_queries = len(all_citations)
        brand_citations = [c for c in all_citations if c.cited]
        brand_mentions = len(brand_citations)

        competitor_stmt = select(Competitor).where(Competitor.brand_id == brand.id)
        competitors = list((await db.execute(competitor_stmt)).scalars().all())

        # Count, per competitor, the cited records that also list that competitor.
        competitor_mentions: dict[str, int] = {}
        for competitor in competitors:
            count = sum(
                1 for c in brand_citations
                if c.competitor_brands and competitor.name in c.competitor_brands
            )
            if count > 0:
                competitor_mentions[competitor.name] = count

        sentiment_service = get_sentiment_service()
        sentiment_counts = {"positive": 0, "neutral": 0, "negative": 0}
        resolved_sentiments: list[str] = []
        for citation in brand_citations:
            label = await self._resolve_sentiment(sentiment_service, brand.name, citation)
            sentiment_counts[label] += 1
            resolved_sentiments.append(label)

        # Pass the resolved sentiment through instead of hard-coding "neutral",
        # so the per-citation data agrees with sentiment_counts.
        citation_results = [
            CitationResult(
                cited=c.cited,
                position=c.citation_position,
                citation_text=c.citation_text,
                sentiment=label,
                confidence=c.confidence or 0.0,
            )
            for c, label in zip(brand_citations, resolved_sentiments)
        ]

        positions = [c.citation_position for c in brand_citations]

        v2_result = scoring_service.calculate_v2(
            mentioned_count=brand_mentions,
            total_queries=total_queries,
            positions=positions,
            sentiment_counts=sentiment_counts,
            citations=citation_results,
            brand_mentions=brand_mentions,
            competitor_mentions=competitor_mentions,
        )

        platform_scores = await self.get_platform_scores(db, user_id, brand.id)

        competitor_data = {
            "brand_mentions": brand_mentions,
            "competitor_mentions": competitor_mentions,
            "ahead_count": sum(1 for count in competitor_mentions.values() if brand_mentions > count),
            "behind_count": sum(1 for count in competitor_mentions.values() if brand_mentions <= count),
        }

        today = datetime.now().date()
        yesterday = today - timedelta(days=1)
        change_from_yesterday = round(
            self._daily_citation_rate(all_citations, today)
            - self._daily_citation_rate(all_citations, yesterday),
            2,
        )

        return BrandScoringResult(
            v2_result=v2_result,
            competitor_data=competitor_data,
            sentiment_counts=sentiment_counts,
            platform_scores=platform_scores,
            total_queries=total_queries,
            mentioned_count=brand_mentions,
            all_citations=all_citations,
            brand_citations=brand_citations,
            change_from_yesterday=change_from_yesterday,
        )

    async def get_platform_scores(
        self,
        db: AsyncSession,
        user_id: uuid.UUID,
        brand_id: uuid.UUID,
    ) -> dict[str, float]:
        """Return the citation rate (0-100) per required platform for a brand.

        Every platform scores 0.0 when the brand does not exist for this user
        or the user has no query targeting it.
        """
        brand_stmt = select(Brand).where(
            Brand.id == brand_id,
            Brand.user_id == user_id,
        )
        brand = (await db.execute(brand_stmt)).scalar_one_or_none()
        if not brand:
            return self._zero_platform_scores()

        queries_stmt = select(QueryModel).where(
            QueryModel.user_id == user_id,
            QueryModel.target_brand == brand.name,
        )
        queries = list((await db.execute(queries_stmt)).scalars().all())
        if not queries:
            return self._zero_platform_scores()

        return await self._platform_citation_rates(db, [q.id for q in queries])

    async def get_competitor_platform_scores(
        self,
        db: AsyncSession,
        user_id: uuid.UUID,
        brand_id: uuid.UUID,
    ) -> dict[str, float]:
        """Return the pooled citation rate per platform for a brand's competitors.

        All competitors of ``brand_id`` are aggregated together. Every platform
        scores 0.0 when there are no competitors or no queries targeting them.
        """
        competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id)
        competitors = list((await db.execute(competitor_stmt)).scalars().all())
        if not competitors:
            return self._zero_platform_scores()

        competitor_queries_stmt = select(QueryModel).where(
            QueryModel.user_id == user_id,
            QueryModel.target_brand.in_([c.name for c in competitors]),
        )
        competitor_queries = list((await db.execute(competitor_queries_stmt)).scalars().all())
        if not competitor_queries:
            return self._zero_platform_scores()

        return await self._platform_citation_rates(db, [q.id for q in competitor_queries])

    async def _resolve_sentiment(self, sentiment_service, brand_name: str, citation) -> str:
        """Return the stored sentiment of a citation, or analyze its text.

        Falls back to "neutral" when there is no text, the analyzer fails, or
        the analyzer returns a label outside ``_SENTIMENT_LABELS``.
        """
        if citation.sentiment in self._SENTIMENT_LABELS:
            return citation.sentiment

        content = citation.raw_response or citation.citation_text or ""
        if not content.strip():
            return "neutral"

        try:
            result = await sentiment_service.analyze(
                brand_name=brand_name,
                content=content,
            )
        except Exception:
            # Deliberate best-effort: one failed analysis must not abort scoring.
            return "neutral"
        return result.sentiment if result.sentiment in self._SENTIMENT_LABELS else "neutral"

    @staticmethod
    def _daily_citation_rate(citations: list, day) -> float:
        """Return the percentage of records queried on ``day`` that were cited."""
        # Records without a timestamp cannot be assigned to a day; skip them.
        on_day = [
            c for c in citations
            if c.queried_at is not None and c.queried_at.date() == day
        ]
        if not on_day:
            return 0.0
        return sum(1 for c in on_day if c.cited) / len(on_day) * 100

    @staticmethod
    def _zero_platform_scores() -> dict[str, float]:
        """Return a 0.0 score for every required platform."""
        return {platform: 0.0 for platform in REQUIRED_PLATFORMS}

    async def _platform_citation_rates(
        self,
        db: AsyncSession,
        query_ids: list,
    ) -> dict[str, float]:
        """Aggregate cited/total per platform for ``query_ids`` into 0-100 rates."""
        # Imported here so the fix does not depend on the module import block.
        from sqlalchemy import case

        # ``case`` is the SQL CASE construct; ``func.case`` would only emit a
        # call to a function literally named "case".
        citation_stmt = (
            select(
                CitationRecord.platform,
                func.count().label("total"),
                func.sum(
                    case((CitationRecord.cited.is_(True), 1), else_=0)
                ).label("cited"),
            )
            .where(CitationRecord.query_id.in_(query_ids))
            .group_by(CitationRecord.platform)
        )
        rows = (await db.execute(citation_stmt)).all()

        # int(...) guards against drivers returning Decimal for SUM.
        per_platform = {
            row.platform: (int(row.total or 0), int(row.cited or 0))
            for row in rows
        }

        platform_scores: dict[str, float] = {}
        for platform in REQUIRED_PLATFORMS:
            total, cited = per_platform.get(platform, (0, 0))
            platform_scores[platform] = round((cited / total) * 100, 2) if total > 0 else 0.0
        return platform_scores
backend/app/services/scoring_service.py rename to backend/app/services/scoring/scoring_service.py index 0b018f1..4ceb5a7 100644 --- a/backend/app/services/scoring_service.py +++ b/backend/app/services/scoring/scoring_service.py @@ -15,6 +15,7 @@ import math from dataclasses import dataclass, field from app.schemas.scoring import CitationResult +from app.utils.health import get_health_level, get_health_level_label # noqa: F401 — re-exported for backward compatibility logger = logging.getLogger(__name__) @@ -567,39 +568,6 @@ def calculate_v2_score( ) -# ============================================================ -# 健康等级工具函数 -# ============================================================ - -def get_health_level(score: float) -> str: - """ - 根据评分获取健康等级 - - 80+ -> excellent (优秀/绿) - 60-79 -> good (良好/黄) - 40-59 -> pass (及格/橙) - <40 -> danger (危险/红) - """ - if score >= 80: - return "excellent" - if score >= 60: - return "good" - if score >= 40: - return "pass" - return "danger" - - -def get_health_level_label(level: str) -> str: - """获取健康等级中文标签""" - labels = { - "excellent": "优秀", - "good": "良好", - "pass": "及格", - "danger": "危险", - } - return labels.get(level, "未知") - - # ============================================================ # ScoringService (兼容V1接口 + V2新接口) # ============================================================ diff --git a/backend/app/services/strategy/geo_plan_generator.py b/backend/app/services/strategy/geo_plan_generator.py new file mode 100644 index 0000000..a1b3cf8 --- /dev/null +++ b/backend/app/services/strategy/geo_plan_generator.py @@ -0,0 +1,570 @@ +from __future__ import annotations + +import asyncio +import json +import logging +from dataclasses import dataclass, field +from typing import Any + +from app.config import settings +from app.services.scoring.scoring_service import ScoringResultV2 +from app.utils.json_extractor import extract_json + +logger = logging.getLogger(__name__) + + +@dataclass +class GeoPlanActionItem: + action_type: str 
+ title: str + description: str + reason: str + priority: str + target_keyword: str | None = None + target_platform: str | None = None + content_style: str | None = None + estimated_impact: str | None = None + difficulty: str = "medium" + execution_params: dict[str, Any] | None = None + + +@dataclass +class GeoPlanData: + title: str + estimated_weeks: int + actions: list[GeoPlanActionItem] = field(default_factory=list) + weekly_plan: list[dict[str, Any]] = field(default_factory=list) + + +def _get_weakest_dimensions( + mention_rate_pct: float, + rank_pct: float, + sentiment_pct: float, + citation_pct: float, + competitive_pct: float, +) -> list[tuple[str, float]]: + dimensions = [ + ("提及率", mention_rate_pct), + ("推荐排名", rank_pct), + ("情感倾向", sentiment_pct), + ("引用质量", citation_pct), + ("竞品对比", competitive_pct), + ] + return sorted(dimensions, key=lambda x: x[1]) + + +def _generate_rule_based_plan( + brand_name: str, + overall_score: float, + target_score: int, + mention_rate_pct: float, + rank_pct: float, + sentiment_pct: float, + citation_pct: float, + competitive_pct: float, + total_queries: int, + platform_scores: dict[str, float], + competitor_data: dict[str, Any], +) -> GeoPlanData: + actions: list[GeoPlanActionItem] = [] + score_gap = target_score - overall_score + estimated_weeks = min(12, max(4, int(score_gap / 5) + 4)) + + if mention_rate_pct < 50: + actions.append(GeoPlanActionItem( + action_type="content_creation", + title=f"提升{brand_name}在AI平台的提及率", + description=( + f"当前提及率仅{mention_rate_pct:.0f}%,品牌在AI回答中被提及的频率较低。" + f"需要创建高质量内容提高品牌在AI搜索结果中的出现频率。" + ), + reason=f"提及率得分率{mention_rate_pct:.0f}%,低于50%阈值,是最需要优先改善的维度", + priority="high", + target_keyword=f"{brand_name}+行业关键词", + target_platform="知乎", + content_style="专业严谨", + estimated_impact="预计可将提及率提升15-25个百分点", + difficulty="medium", + execution_params={ + "keyword": f"{brand_name} 行业解决方案", + "platform": "知乎", + "style": "专业严谨", + "word_count": 2000, + }, + )) + + if citation_pct < 40: + 
actions.append(GeoPlanActionItem( + action_type="content_creation", + title=f"提升{brand_name}引用内容质量", + description=( + f"当前引用质量得分率仅{citation_pct:.0f}%,AI对品牌的引用多为浅层提及。" + f"需要创建深度评测和对比内容,增加被深度引用的概率。" + ), + reason=f"引用质量得分率{citation_pct:.0f}%,低于40%阈值,引用缺乏深度正面描述", + priority="high", + target_keyword=f"{brand_name}评测/对比", + target_platform="通用", + content_style="专业严谨", + estimated_impact="预计可将引用质量得分率提升15-25个百分点", + difficulty="medium", + execution_params={ + "keyword": f"{brand_name} 深度评测对比", + "platform": "通用", + "style": "专业严谨", + "word_count": 3000, + }, + )) + + if rank_pct < 40: + actions.append(GeoPlanActionItem( + action_type="content_optimization", + title=f"提升{brand_name}在AI推荐中的排名", + description=( + f"当前推荐排名得分率仅{rank_pct:.0f}%,品牌在AI推荐列表中排名靠后。" + f"需要优化现有内容,增加品牌在推荐场景中的出现概率。" + ), + reason=f"推荐排名得分率{rank_pct:.0f}%,低于40%阈值,排名靠后用户看到概率低", + priority="high", + target_keyword=f"最佳{brand_name}推荐", + target_platform="通用", + content_style="专业严谨", + estimated_impact="预计可将推荐排名提升2-3位", + difficulty="medium", + execution_params={ + "keyword": f"最佳{brand_name}推荐", + "platform": "通用", + "style": "专业严谨", + "word_count": 2000, + }, + )) + + if competitive_pct < 40: + ahead_competitors = [] + if competitor_data: + brand_mentions = competitor_data.get("brand_mentions", 0) + for name, count in competitor_data.get("competitor_mentions", {}).items(): + if count > brand_mentions: + ahead_competitors.append(name) + ahead_str = "、".join(ahead_competitors[:3]) if ahead_competitors else "竞品" + + actions.append(GeoPlanActionItem( + action_type="content_creation", + title=f"缩小与{ahead_str}的差距", + description=( + f"当前竞品对比得分率仅{competitive_pct:.0f}%," + f"品牌在AI引用中落后于主要竞品。需要创建对比内容突出品牌优势。" + ), + reason=f"竞品对比得分率{competitive_pct:.0f}%,低于40%阈值,品牌落后于主要竞品", + priority="high", + target_keyword=f"{brand_name} vs {ahead_competitors[0] if ahead_competitors else '竞品'}", + target_platform="知乎", + content_style="专业严谨", + estimated_impact="预计3-6个月内可将竞品对比得分率提升15-25个百分点", + difficulty="hard", + 
execution_params={ + "keyword": f"{brand_name} vs {ahead_competitors[0] if ahead_competitors else '竞品'} 对比评测", + "platform": "知乎", + "style": "专业严谨", + "word_count": 2500, + }, + )) + + if sentiment_pct < 40: + actions.append(GeoPlanActionItem( + action_type="content_optimization", + title=f"改善AI平台对{brand_name}的情感倾向", + description=( + f"当前情感倾向得分率仅{sentiment_pct:.0f}%,AI在引用品牌时倾向使用负面或中性表述。" + f"需要优化内容以增加正面引用比例。" + ), + reason=f"情感倾向得分率{sentiment_pct:.0f}%,低于40%阈值,负面引用影响品牌形象", + priority="medium", + target_keyword=f"{brand_name}优势/正面评价", + target_platform="通用", + content_style="轻松活泼", + estimated_impact="减少负面引用比例10-20个百分点", + difficulty="medium", + execution_params={ + "keyword": f"{brand_name} 优势 正面评价", + "platform": "通用", + "style": "轻松活泼", + "word_count": 1500, + }, + )) + + if total_queries < 10: + suggested_queries = [ + f"{brand_name}怎么样", + f"{brand_name}推荐", + f"最佳{brand_name}", + f"{brand_name}评测", + f"{brand_name}对比", + ] + actions.append(GeoPlanActionItem( + action_type="query_expansion", + title="扩展查询词覆盖范围", + description=( + f"当前仅有{total_queries}个查询词,覆盖范围不足," + f"无法全面反映品牌在AI搜索中的表现。" + ), + reason=f"查询词数量仅{total_queries}个,低于10个阈值,分析结果不够全面", + priority="high" if total_queries < 3 else "medium", + target_keyword=None, + target_platform=None, + content_style=None, + estimated_impact="更多查询词可提升评分准确度,发现更多优化机会", + difficulty="easy", + execution_params={ + "suggested_queries": suggested_queries, + }, + )) + + schema_score = 0 + if schema_score == 0: + actions.append(GeoPlanActionItem( + action_type="schema_optimization", + title="添加FAQ结构化数据", + description=( + "当前网站缺少结构化数据(Schema),AI搜索引擎无法有效提取品牌信息。" + "添加FAQ Schema可以显著提升品牌在AI回答中的引用概率。" + ), + reason="网站Schema标记缺失,AI搜索引擎无法高效提取品牌关键信息", + priority="medium", + target_keyword=None, + target_platform=None, + content_style=None, + estimated_impact="添加Schema后预计可提升引用率10-15个百分点", + difficulty="easy", + execution_params={ + "optimization_type": "add_faq_schema", + }, + )) + + weak_platforms = sorted( + 
platform_scores.items(), + key=lambda x: x[1], + )[:3] + weak_platform_names = [p[0] for p in weak_platforms if p[1] < 40] + if weak_platform_names: + actions.append(GeoPlanActionItem( + action_type="platform_targeting", + title=f"重点优化{'、'.join(weak_platform_names)}平台表现", + description=( + f"在这些平台上品牌评分低于40分,AI引用率极低。" + f"需要针对性优化各平台的内容策略。" + ), + reason=f"平台{'、'.join(weak_platform_names)}评分低于40分,AI引用率极低", + priority="high", + target_keyword=None, + target_platform=weak_platform_names[0], + content_style=None, + estimated_impact="预计可将弱平台评分提升20-30分", + difficulty="hard", + execution_params={ + "target_platforms": weak_platform_names, + }, + )) + + priority_order = {"high": 0, "medium": 1, "low": 2} + actions.sort(key=lambda a: priority_order.get(a.priority, 1)) + actions = actions[:8] + + weekly_plan = _generate_weekly_plan(actions, estimated_weeks) + + title = f"{brand_name} GEO优化方案 - 从{overall_score:.0f}分提升至{target_score}分" + + return GeoPlanData( + title=title, + estimated_weeks=estimated_weeks, + actions=actions, + weekly_plan=weekly_plan, + ) + + +def _generate_weekly_plan( + actions: list[GeoPlanActionItem], + estimated_weeks: int, +) -> list[dict[str, Any]]: + weekly_plan: list[dict[str, Any]] = [] + high_actions = [i for i, a in enumerate(actions) if a.priority == "high"] + medium_actions = [i for i, a in enumerate(actions) if a.priority == "medium"] + low_actions = [i for i, a in enumerate(actions) if a.priority == "low"] + + weeks_per_high = max(1, (estimated_weeks // 2) // max(len(high_actions), 1)) if high_actions else 0 + week_idx = 0 + for i, action_idx in enumerate(high_actions): + week_num = week_idx + 1 + impact_str = actions[action_idx].estimated_impact or "" + try: + num_part = impact_str.split("-")[-1].replace("个百分点", "").strip() + expected_val = max(3, int(num_part)) if num_part.isdigit() else 5 + except (ValueError, IndexError): + expected_val = 5 + expected = f"+{expected_val}" + weekly_plan.append({ + "week": week_num, + "action_indices": 
[action_idx], + "expected_score_change": expected, + }) + week_idx += weeks_per_high + + remaining_weeks = estimated_weeks - week_idx + medium_per_week = max(1, len(medium_actions) // max(remaining_weeks, 1)) if medium_actions and remaining_weeks > 0 else len(medium_actions) + batch: list[int] = [] + for i, action_idx in enumerate(medium_actions): + batch.append(action_idx) + if len(batch) >= medium_per_week or i == len(medium_actions) - 1: + week_idx += 1 + weekly_plan.append({ + "week": week_idx, + "action_indices": batch[:], + "expected_score_change": "+3", + }) + batch = [] + + for action_idx in low_actions: + week_idx += 1 + weekly_plan.append({ + "week": week_idx, + "action_indices": [action_idx], + "expected_score_change": "+2", + }) + + return weekly_plan + + +GEO_PLAN_PROMPT = """你是一个GEO(生成式引擎优化)策略专家。基于以下品牌诊断数据,制定一个8周GEO优化方案。 + +品牌: {brand_name} +当前评分: {overall_score}/100 +目标评分: {target_score}/100 + +评分维度: +- 提及率: {mention_rate_percentage}% +- 推荐排名: {rank_percentage}% +- 情感倾向: {sentiment_percentage}% +- 引用质量: {citation_percentage}% +- 竞品对比: {competitive_percentage}% + +竞品数据: {competitor_data} +平台评分: {platform_scores} + +请返回JSON格式: +{{ + "title": "方案标题", + "estimated_weeks": 8, + "actions": [ + {{ + "action_type": "content_creation", + "title": "行动标题", + "description": "详细描述", + "reason": "基于诊断数据的原因", + "priority": "high", + "target_keyword": "推荐关键词", + "target_platform": "推荐平台", + "content_style": "推荐风格", + "estimated_impact": "预期效果", + "difficulty": "medium", + "execution_params": {{ + "keyword": "关键词", + "platform": "平台", + "style": "风格", + "word_count": 2000 + }} + }} + ], + "weekly_plan": [ + {{ + "week": 1, + "action_indices": [0, 1], + "expected_score_change": "+5" + }} + ] +}} + +要求: +1. 行动项必须基于诊断数据,优先解决最弱维度 +2. 每个行动项必须有 execution_params,可直接传给内容生成API +3. 行动项按优先级排序 +4. 生成5-8个行动项 +5. 
async def _generate_llm_plan(
    brand_name: str,
    overall_score: float,
    target_score: int,
    mention_rate_pct: float,
    rank_pct: float,
    sentiment_pct: float,
    citation_pct: float,
    competitive_pct: float,
    total_queries: int,
    platform_scores: dict[str, float],
    competitor_data: dict[str, Any],
) -> GeoPlanData:
    """Ask the DeepSeek chat model for a GEO plan, falling back to the rules.

    The rule-based generator is used when the LLM is disabled, no API key is
    configured, the call or JSON parsing fails, or no usable action comes back.
    The model output is shape-checked before it reaches ``GeoPlanData``:
    non-dict actions are skipped, enum-like fields are normalized,
    ``estimated_weeks`` is coerced to a positive int (default 8), and
    weekly-plan indices that do not point at a kept action are dropped.
    """

    def rule_based() -> GeoPlanData:
        # Single definition of the fallback so all exit paths stay in sync.
        return _generate_rule_based_plan(
            brand_name=brand_name,
            overall_score=overall_score,
            target_score=target_score,
            mention_rate_pct=mention_rate_pct,
            rank_pct=rank_pct,
            sentiment_pct=sentiment_pct,
            citation_pct=citation_pct,
            competitive_pct=competitive_pct,
            total_queries=total_queries,
            platform_scores=platform_scores,
            competitor_data=competitor_data,
        )

    if not settings.ENABLE_LLM or not settings.DEEPSEEK_API_KEY:
        logger.info("LLM未启用或API Key未配置,使用规则生成方案")
        return rule_based()

    try:
        prompt = GEO_PLAN_PROMPT.format(
            brand_name=brand_name,
            overall_score=round(overall_score, 1),
            target_score=target_score,
            mention_rate_percentage=round(mention_rate_pct, 1),
            rank_percentage=round(rank_pct, 1),
            sentiment_percentage=round(sentiment_pct, 1),
            citation_percentage=round(citation_pct, 1),
            competitive_percentage=round(competitive_pct, 1),
            competitor_data=json.dumps(competitor_data, ensure_ascii=False, indent=2),
            platform_scores=json.dumps(platform_scores, ensure_ascii=False, indent=2),
        )

        from openai import OpenAI

        client = OpenAI(
            api_key=settings.DEEPSEEK_API_KEY,
            base_url="https://api.deepseek.com",
        )

        # The client is synchronous; run it in a worker thread so the event
        # loop is not blocked during the request.
        response = await asyncio.to_thread(
            client.chat.completions.create,
            model="deepseek-chat",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
            max_tokens=3000,
        )

        content = response.choices[0].message.content
        if not content:
            raise ValueError("LLM返回空响应")

        result = json.loads(extract_json(content))
        if not isinstance(result, dict):
            raise ValueError("LLM返回的JSON不是对象")

        valid_action_types = {
            "content_creation", "content_optimization",
            "query_expansion", "schema_optimization", "platform_targeting",
        }
        valid_priorities = {"high", "medium", "low"}
        valid_difficulties = {"easy", "medium", "hard"}

        raw_actions = result.get("actions")
        if not isinstance(raw_actions, list):
            raw_actions = []

        actions: list[GeoPlanActionItem] = []
        for item in raw_actions:
            if not isinstance(item, dict):
                # One malformed entry should not discard the whole plan.
                continue

            action_type = item.get("action_type", "content_creation")
            if action_type not in valid_action_types:
                action_type = "content_creation"

            priority = item.get("priority", "medium")
            if priority not in valid_priorities:
                priority = "medium"

            difficulty = item.get("difficulty", "medium")
            if difficulty not in valid_difficulties:
                difficulty = "medium"

            execution_params = item.get("execution_params")
            actions.append(GeoPlanActionItem(
                action_type=action_type,
                title=item.get("title") or "优化行动",
                description=item.get("description") or "",
                reason=item.get("reason") or "",
                priority=priority,
                target_keyword=item.get("target_keyword"),
                target_platform=item.get("target_platform"),
                content_style=item.get("content_style"),
                estimated_impact=item.get("estimated_impact"),
                difficulty=difficulty,
                execution_params=execution_params if isinstance(execution_params, dict) else None,
            ))

        if not actions:
            logger.warning("LLM未返回有效行动项,回退到规则生成")
            return rule_based()

        actions = actions[:8]

        try:
            estimated_weeks = int(result.get("estimated_weeks", 8))
        except (TypeError, ValueError):
            estimated_weeks = 8
        if estimated_weeks <= 0:
            estimated_weeks = 8

        raw_weekly_plan = result.get("weekly_plan")
        if not isinstance(raw_weekly_plan, list):
            raw_weekly_plan = []

        weekly_plan: list[dict[str, Any]] = []
        for entry in raw_weekly_plan:
            if not isinstance(entry, dict):
                continue
            indices = entry.get("action_indices")
            if isinstance(indices, list):
                # Keep only indices that still point at a kept action.
                entry = {
                    **entry,
                    "action_indices": [
                        i for i in indices
                        if isinstance(i, int) and not isinstance(i, bool)
                        and 0 <= i < len(actions)
                    ],
                }
            weekly_plan.append(entry)

        return GeoPlanData(
            title=result.get("title") or f"{brand_name} GEO优化方案",
            estimated_weeks=estimated_weeks,
            actions=actions,
            weekly_plan=weekly_plan,
        )

    except Exception as e:
        # Boundary handler: any LLM, network or parsing failure degrades to
        # the deterministic plan.
        logger.error(f"LLM生成方案失败: {e},回退到规则生成")
        return rule_based()


async def generate_geo_plan(
    brand_name: str,
    scoring_result: ScoringResultV2,
    target_score: int,
    total_queries: int = 0,
    platform_scores: dict[str, float] | None = None,
    competitor_data: dict[str, Any] | None = None,
) -> GeoPlanData:
    """Generate a GEO optimization plan for a brand from its V2 scoring result.

    Uses the LLM generator when the LLM is enabled and an API key is
    configured, otherwise the rule-based generator. Both receive the same
    inputs: the overall score plus the ``percentage`` of each of the five
    scoring dimensions.
    """
    plan_inputs: dict[str, Any] = {
        "brand_name": brand_name,
        "overall_score": scoring_result.overall_score,
        "target_score": target_score,
        "mention_rate_pct": scoring_result.mention_rate.percentage,
        "rank_pct": scoring_result.recommendation_rank.percentage,
        "sentiment_pct": scoring_result.sentiment_score.percentage,
        "citation_pct": scoring_result.citation_quality.percentage,
        "competitive_pct": scoring_result.competitive_position.percentage,
        "total_queries": total_queries,
        "platform_scores": platform_scores or {},
        "competitor_data": competitor_data or {},
    }
    if settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY:
        return await _generate_llm_plan(**plan_inputs)
    return _generate_rule_based_plan(**plan_inputs)
plan_data = PLANS.get(plan) if plan_data is None: @@ -119,6 +119,10 @@ async def subscribe( today = date.today() end_date = today + timedelta(days=30) + payment_method = payment_provider + if payment_provider == "mock": + payment_method = "模拟支付" + subscription = Subscription( user_id=user_id, plan=plan, @@ -126,7 +130,7 @@ async def subscribe( start_date=today, end_date=end_date, amount=plan_data["price"], - payment_method="模拟支付", + payment_method=payment_method, ) db.add(subscription) @@ -139,7 +143,7 @@ async def subscribe( await db.commit() await db.refresh(subscription) - logger.info(f"[模拟支付] 用户{user_id} 订阅{plan},金额{plan_data['price']}元") + logger.info(f"[支付] 用户{user_id} 订阅{plan},金额{plan_data['price']}元,支付方式{payment_method}") return SubscriptionResponse.model_validate(subscription) diff --git a/backend/app/services/trend/__init__.py b/backend/app/services/trend/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/services/trend/trend_analyzer_service.py b/backend/app/services/trend/trend_analyzer_service.py new file mode 100644 index 0000000..5b0d8ba --- /dev/null +++ b/backend/app/services/trend/trend_analyzer_service.py @@ -0,0 +1,506 @@ +import json +import logging +import uuid +from datetime import datetime, timedelta, timezone +from collections import defaultdict + +from sqlalchemy import func, select, and_ +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.citation_record import CitationRecord +from app.models.query import Query +from app.models.trend_insight import TrendInsight +from app.models.brand import Brand +from app.services.llm import LLMFactory, LLMError +from app.utils.json_extractor import extract_json + +logger = logging.getLogger(__name__) + + +class TrendAnalyzerService: + + def __init__(self, db: AsyncSession): + self.db = db + + async def analyze_trends( + self, + brand_id: uuid.UUID, + days: int = 30, + platforms: list[str] | None = None, + keywords: list[str] | None = None, + ) -> dict: + 
brand = await self._get_brand(brand_id) + if brand is None: + raise ValueError("品牌不存在") + + citations = await self.get_time_series_data(brand.name, days, platforms, keywords) + if len(citations) < 7: + now = datetime.now(timezone.utc) + return { + "status": "insufficient_data", + "message": f"数据不足7天(当前{len(citations)}天),无法进行趋势分析", + "brand_id": str(brand_id), + "days": days, + } + + aggregated = self._aggregate_time_series(citations, days) + change_rate = self.calculate_change_rate(aggregated) + classified_type = self.detect_trends(change_rate) + absolute_change = self._calculate_absolute_change(aggregated) + sentiment_trend = self._analyze_sentiment_trend(citations) + platform_comparison = await self.compare_platforms(citations) + cause_analysis = await self.infer_causes( + brand_name=brand.name, + keyword="", + trend_type=classified_type, + change_rate=change_rate, + data_points=aggregated, + ) + recommendations = self.generate_recommendations( + classified_type, change_rate, sentiment_trend + ) + confidence = self._calculate_confidence(len(citations), change_rate) + severity = self._determine_severity(classified_type, change_rate) + + now = datetime.now(timezone.utc) + period_start = now - timedelta(days=days) + + insight = TrendInsight( + brand_id=str(brand_id), + trend_type=classified_type, + keyword=brand.name, + platform="all", + period_start=period_start, + period_end=now, + data_points=aggregated, + change_rate=change_rate, + absolute_change=absolute_change, + sentiment_trend=sentiment_trend, + cause_analysis=cause_analysis, + recommendations=recommendations, + confidence=confidence, + severity=severity, + ) + self.db.add(insight) + await self.db.commit() + await self.db.refresh(insight) + + return { + "status": "success", + "insight_id": str(insight.id), + "trend_type": insight.trend_type, + "change_rate": insight.change_rate, + "absolute_change": insight.absolute_change, + "sentiment_trend": insight.sentiment_trend, + "cause_analysis": 
def detect_hotspots(
    self,
    citations: list[dict],
    days: int = 30,
) -> list[dict]:
    """Find keywords whose latest daily citation count spikes.

    Citations are bucketed per keyword and per calendar day. A keyword is a
    hotspot when it has at least 7 days with data and the count on its most
    recent day is more than twice the mean of its last 7 days with data.
    Days are ordered by date, so the result does not depend on input order.

    Args:
        citations: Dicts with optional ``queried_at`` (datetime) and
            ``keyword`` keys. A missing or ``None`` timestamp is bucketed
            under the current UTC day.
        days: Currently unused; kept for interface compatibility.

    Returns:
        Hotspot dicts sorted by ``surge_ratio``, highest first.
    """
    if not citations:
        return []

    daily_counts = defaultdict(lambda: defaultdict(int))
    for record in citations:
        queried_at = record.get("queried_at") or datetime.now(timezone.utc)
        keyword = record.get("keyword", "")
        daily_counts[keyword][queried_at.strftime("%Y-%m-%d")] += 1

    hotspots = []
    for keyword, per_day in daily_counts.items():
        if len(per_day) < 7:
            continue
        # ISO date strings sort chronologically; dict order follows the input.
        recent_7 = [per_day[day] for day in sorted(per_day)[-7:]]
        mean_7 = sum(recent_7) / len(recent_7)
        latest = recent_7[-1]
        if mean_7 > 0 and latest > mean_7 * 2:
            hotspots.append({
                "keyword": keyword,
                "latest_count": latest,
                "mean_7d": round(mean_7, 2),
                "surge_ratio": round(latest / mean_7, 2),
                "trend_type": "hotspot",
            })

    hotspots.sort(key=lambda x: x["surge_ratio"], reverse=True)
    return hotspots
f"品牌{brand_name}的引用趋势{trend_desc},变化率{change_rate}%" + + def generate_recommendations( + self, + trend_type: str, + change_rate: float, + sentiment_trend: dict, + ) -> list[str]: + recommendations = [] + sentiment_direction = sentiment_trend.get("direction", "neutral") if sentiment_trend else "neutral" + + if trend_type == "rising": + recommendations.append("引用趋势上升,建议加大内容投入以巩固优势") + if sentiment_direction == "positive": + recommendations.append("情感倾向正向提升,可强化正面内容传播") + elif sentiment_direction == "negative": + recommendations.append("虽然引用量上升但情感转负,需关注负面评价并积极回应") + elif trend_type == "declining": + recommendations.append("引用趋势下降,建议审查内容策略并优化关键词覆盖") + recommendations.append("考虑增加高质量内容产出,提升AI引擎引用概率") + if sentiment_direction == "negative": + recommendations.append("情感倾向同步恶化,建议优先处理负面舆情") + elif trend_type == "stable": + recommendations.append("引用趋势平稳,建议探索新的关键词和内容方向以突破现状") + elif trend_type == "hotspot": + recommendations.append("检测到热点趋势,建议快速响应并产出相关内容以获取流量红利") + elif trend_type == "platform_shift": + recommendations.append("检测到平台迁移趋势,建议调整各平台内容策略以适应变化") + + if abs(change_rate) > 50: + recommendations.append("变化幅度较大,建议密切关注后续走势并准备应对方案") + + return recommendations + + def calculate_change_rate(self, data_points: list[dict]) -> float: + if len(data_points) < 2: + return 0.0 + + half = len(data_points) // 2 + previous = sum(p.get("citation_count", 0) for p in data_points[:half]) + current = sum(p.get("citation_count", 0) for p in data_points[half:]) + + if previous == 0: + return 100.0 if current > 0 else 0.0 + + return round(((current - previous) / previous) * 100, 2) + + async def get_insights( + self, + brand_id: uuid.UUID, + skip: int = 0, + limit: int = 20, + ) -> tuple[list[TrendInsight], int]: + conditions = [TrendInsight.brand_id == str(brand_id)] + stmt = ( + select(TrendInsight) + .where(and_(*conditions)) + .order_by(TrendInsight.created_at.desc()) + .offset(skip) + .limit(limit) + ) + result = await self.db.execute(stmt) + items = result.scalars().all() + + 
count_stmt = ( + select(func.count()) + .select_from(TrendInsight) + .where(and_(*conditions)) + ) + count_result = await self.db.execute(count_stmt) + total = count_result.scalar_one() + + return list(items), total + + async def get_insight_by_id(self, insight_id: uuid.UUID) -> TrendInsight | None: + stmt = select(TrendInsight).where(TrendInsight.id == insight_id) + result = await self.db.execute(stmt) + return result.scalar_one_or_none() + + async def get_summary( + self, + brand_id: uuid.UUID, + days: int = 30, + ) -> dict: + conditions = [TrendInsight.brand_id == str(brand_id)] + stmt = ( + select(TrendInsight.trend_type, func.count().label("count")) + .where(and_(*conditions)) + .group_by(TrendInsight.trend_type) + ) + result = await self.db.execute(stmt) + + summary = { + "brand_id": brand_id, + "period_days": days, + "rising_count": 0, + "declining_count": 0, + "hotspot_count": 0, + "top_keywords": [], + "platform_highlights": {}, + } + for row in result.all(): + trend_type = row.trend_type + count = row.count + if trend_type == "rising": + summary["rising_count"] = count + elif trend_type == "declining": + summary["declining_count"] = count + elif trend_type == "hotspot": + summary["hotspot_count"] = count + + keyword_stmt = ( + select(TrendInsight.keyword, func.count().label("count")) + .where(and_(*conditions), TrendInsight.keyword.isnot(None)) + .group_by(TrendInsight.keyword) + .order_by(func.count().desc()) + .limit(10) + ) + keyword_result = await self.db.execute(keyword_stmt) + summary["top_keywords"] = [row.keyword for row in keyword_result.all()] + + platform_stmt = ( + select(TrendInsight.platform, TrendInsight.trend_type, func.count().label("count")) + .where(and_(*conditions), TrendInsight.platform.isnot(None)) + .group_by(TrendInsight.platform, TrendInsight.trend_type) + ) + platform_result = await self.db.execute(platform_stmt) + for row in platform_result.all(): + platform = row.platform + if platform not in summary["platform_highlights"]: + 
summary["platform_highlights"][platform] = {} + summary["platform_highlights"][platform][row.trend_type] = row.count + + return summary + + async def get_hotspots( + self, + brand_id: uuid.UUID, + days: int = 30, + ) -> dict: + brand = await self._get_brand(brand_id) + if brand is None: + raise ValueError("品牌不存在") + + citations = await self.get_time_series_data(brand.name, days) + if len(citations) < 7: + return { + "status": "insufficient_data", + "message": f"数据不足7天(当前{len(citations)}天),无法进行热点分析", + "hotspots": [], + } + + hotspots = self.detect_hotspots(citations, days) + + return { + "status": "success", + "hotspots": hotspots, + } + + def _aggregate_time_series(self, citations: list[dict], days: int) -> list[dict]: + daily = defaultdict(lambda: {"citation_count": 0, "positive_count": 0}) + for c in citations: + date_str = c.get("queried_at", datetime.now(timezone.utc)).strftime("%Y-%m-%d") + entry = daily[date_str] + entry["citation_count"] += 1 + sentiment = c.get("sentiment", "neutral") + if sentiment == "positive": + entry["positive_count"] += 1 + + result = [] + for date_str in sorted(daily.keys()): + entry = daily[date_str] + positive_ratio = entry["positive_count"] / entry["citation_count"] if entry["citation_count"] > 0 else 0.0 + result.append({ + "date": date_str, + "citation_count": entry["citation_count"], + "positive_ratio": round(positive_ratio, 4), + }) + + return result + + def _calculate_absolute_change(self, data_points: list[dict]) -> int: + if len(data_points) < 2: + return 0 + + half = len(data_points) // 2 + previous = sum(p.get("citation_count", 0) for p in data_points[:half]) + current = sum(p.get("citation_count", 0) for p in data_points[half:]) + return current - previous + + def _analyze_sentiment_trend(self, citations: list[dict]) -> dict: + if not citations: + return {"direction": "neutral", "positive_ratio": 0.0, "negative_ratio": 0.0} + + half = len(citations) // 2 + first_half = citations[:half] + second_half = citations[half:] + 
+ def sentiment_ratios(records: list[dict]) -> dict: + if not records: + return {"positive": 0.0, "negative": 0.0} + positive = sum(1 for r in records if r.get("sentiment") == "positive") + negative = sum(1 for r in records if r.get("sentiment") == "negative") + total = len(records) + return { + "positive": round(positive / total, 4), + "negative": round(negative / total, 4), + } + + prev = sentiment_ratios(first_half) + curr = sentiment_ratios(second_half) + diff = curr["positive"] - prev["positive"] + + if diff > 0.1: + direction = "positive" + elif diff < -0.1: + direction = "negative" + else: + direction = "neutral" + + return { + "direction": direction, + "previous_positive_ratio": prev["positive"], + "current_positive_ratio": curr["positive"], + "previous_negative_ratio": prev["negative"], + "current_negative_ratio": curr["negative"], + } + + def _calculate_confidence(self, data_days: int, change_rate: float) -> float: + base = 0.3 + if data_days >= 21: + base += 0.3 + elif data_days >= 14: + base += 0.2 + elif data_days >= 7: + base += 0.1 + + if abs(change_rate) > 30: + base += 0.2 + elif abs(change_rate) > 10: + base += 0.1 + + return round(min(base, 1.0), 2) + + def _determine_severity(self, trend_type: str, change_rate: float) -> str: + if trend_type == "hotspot" or abs(change_rate) > 50: + return "critical" + elif trend_type in ("rising", "declining") and abs(change_rate) > 20: + return "warning" + return "info" + + async def _get_brand(self, brand_id: uuid.UUID) -> Brand | None: + stmt = select(Brand).where(Brand.id == brand_id) + result = await self.db.execute(stmt) + return result.scalar_one_or_none() diff --git a/backend/app/services/usage_recorder.py b/backend/app/services/usage_recorder.py index 72f4b3e..81d38bd 100644 --- a/backend/app/services/usage_recorder.py +++ b/backend/app/services/usage_recorder.py @@ -1,5 +1,5 @@ from app.services.usage_tracker import UsageTracker -from app.services.smart_router import ENGINE_COST_PROFILES +from 
app.services.llm.smart_router import ENGINE_COST_PROFILES class UsageRecorder: diff --git a/backend/app/templates/geo_weekly_report.html b/backend/app/templates/geo_weekly_report.html new file mode 100644 index 0000000..0b80d1f --- /dev/null +++ b/backend/app/templates/geo_weekly_report.html @@ -0,0 +1,69 @@ + + + + + +GEO周报 + + + + +
+ + + + +
+

GEO 周度报告

+

{{user_name}},这是您本周的GEO变化概览

+
+ + + + + + +
+

当前评分

+

{{current_score}}

+
+

上周评分

+

{{previous_score}}

+
+ + +
+

▲ 评分变化: {{score_change}}

+
+ + + + + + + + +
+

↑ 提升维度

+

{{top_improved}}

+
+

↓ 下降维度

+

{{top_declined}}

+
+ + +
+

💡 优化建议

+

{{suggestions}}

+
+ + +
+查看详细报告 +
+
+

© {{year}} GEO平台 — AI搜索优化领导者

+
+
+ + diff --git a/backend/app/templates/renewal_reminder.html b/backend/app/templates/renewal_reminder.html new file mode 100644 index 0000000..dde56ec --- /dev/null +++ b/backend/app/templates/renewal_reminder.html @@ -0,0 +1,63 @@ + + + + + +续费提醒 + + + + +
+ + + + +
+

续费提醒

+

{{user_name}},您的订阅即将到期

+
+ + +
+

您的订阅将在

+

{{days_remaining}}天

+

后到期

+
+ + + + + + + + + + +
+

当前套餐

+

{{plan_name}}

+
+

到期日期

+

{{end_date}}

+
+

续费价格

+

¥{{plan_price}}/月

+
+ + +
+

续费后您将继续享受:

+

✓ 不受限的GEO诊断与优化功能
✓ 持续的AI引用率监测
✓ 专业的优化建议与报告

+
+ + +
+立即续费 +
+
+

© {{year}} GEO平台 — AI搜索优化领导者

+
+
+ + diff --git a/backend/app/templates/trial_expiring.html b/backend/app/templates/trial_expiring.html new file mode 100644 index 0000000..8c8cdcf --- /dev/null +++ b/backend/app/templates/trial_expiring.html @@ -0,0 +1,48 @@ + + + + + +试用到期提醒 + + + + +
+ + + + +
+

试用即将到期

+

{{user_name}},别错过完整功能体验

+
+ + +
+

试用剩余

+

{{days_remaining}}天

+
+ + +
+

⚠ 到期后您将失去:

+

✗ 高级GEO诊断功能
✗ AI引用率实时监测
✗ 专业优化建议与报告
✗ 多品牌监控能力

+
+ + +
+

✓ 升级后您将获得:

+

✓ 更多品牌监控额度
✓ 无限告警通知
✓ 完整竞品对比分析
✓ AI个性化优化建议

+
+ + +
+立即升级 +
+
+

© {{year}} GEO平台 — AI搜索优化领导者

+
+
+ + diff --git a/backend/app/templates/welcome.html b/backend/app/templates/welcome.html new file mode 100644 index 0000000..b4194b7 --- /dev/null +++ b/backend/app/templates/welcome.html @@ -0,0 +1,63 @@ + + + + + +欢迎加入GEO平台 + + + + +
+ + + + + +
+

🎉 欢迎加入GEO平台

+

{{user_name}},开始您的AI搜索优化之旅

+
+

只需3步,即可开始优化您的品牌在AI搜索中的可见性:

+ + + + + + + + + + + + +
+

1

+

添加品牌

+

在控制台中添加您要监控的品牌名称

+
+

2

+

运行诊断

+

系统将自动分析品牌在各大AI平台的引用情况

+
+

3

+

获取优化建议

+

基于诊断结果,获取个性化的GEO优化策略

+
+ + +
+开始使用 +
+
+ + +
+

需要帮助?查看帮助文档 或直接运行首次诊断

+
+
+

© {{year}} GEO平台 — AI搜索优化领导者

+
+
+ + diff --git a/backend/app/utils/health.py b/backend/app/utils/health.py new file mode 100644 index 0000000..c595aa5 --- /dev/null +++ b/backend/app/utils/health.py @@ -0,0 +1,33 @@ +"""健康等级工具函数 + +根据评分获取健康等级及中文标签,供 scoring_service / geo_diagnosis 等模块复用。 +""" + + +def get_health_level(score: float) -> str: + """ + 根据评分获取健康等级 + + 80+ -> excellent (优秀/绿) + 60-79 -> good (良好/黄) + 40-59 -> pass (及格/橙) + <40 -> danger (危险/红) + """ + if score >= 80: + return "excellent" + if score >= 60: + return "good" + if score >= 40: + return "pass" + return "danger" + + +def get_health_level_label(level: str) -> str: + """获取健康等级中文标签""" + labels = { + "excellent": "优秀", + "good": "良好", + "pass": "及格", + "danger": "危险", + } + return labels.get(level, "未知") diff --git a/backend/app/utils/json_extractor.py b/backend/app/utils/json_extractor.py new file mode 100644 index 0000000..1cc59a0 --- /dev/null +++ b/backend/app/utils/json_extractor.py @@ -0,0 +1,65 @@ +"""JSON 提取工具函数 + +从可能包含 markdown 代码块或周围文本的 LLM 响应中提取 JSON 字符串。 +采用深度计数器方式确保正确匹配嵌套括号。 +""" +import json +import re + + +def extract_json(text: str) -> str: + """从文本中提取 JSON 字符串。 + + 提取策略(按优先级): + 1. 尝试直接解析整个文本为 JSON + 2. 尝试从 ```json ... ``` 代码块中提取 + 3. 使用深度计数器找到第一个完整的 JSON 对象或数组 + + Args: + text: 可能包含 JSON 的文本 + + Returns: + 提取出的 JSON 字符串 + + Raises: + ValueError: 无法从文本中提取有效 JSON + """ + if not text or not text.strip(): + raise ValueError(f"无法从响应中提取JSON: {text[:200]}") + + # 1. 尝试直接解析 + try: + json.loads(text) + return text + except json.JSONDecodeError: + pass + + # 2. 尝试从代码块中提取 + match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", text, re.DOTALL) + if match: + candidate = match.group(1).strip() + try: + json.loads(candidate) + return candidate + except json.JSONDecodeError: + pass + + # 3. 
使用深度计数器找到第一个完整的 JSON 对象或数组 + for i, c in enumerate(text): + if c in "[{": + depth = 0 + for j in range(i, len(text)): + if text[j] in "[{": + depth += 1 + elif text[j] in "]}": + depth -= 1 + if depth == 0: + candidate = text[i : j + 1] + try: + json.loads(candidate) + return candidate + except json.JSONDecodeError: + break # 这对括号不是有效 JSON,继续找下一对 + # 如果 depth != 0 说明括号不匹配,继续找下一个起始括号 + + raise ValueError(f"无法从响应中提取JSON: {text[:200]}") diff --git a/backend/app/utils/text.py b/backend/app/utils/text.py new file mode 100644 index 0000000..e551211 --- /dev/null +++ b/backend/app/utils/text.py @@ -0,0 +1,15 @@ +"""文本清理工具函数 + +清理原始响应中的无效控制字符,避免 PostgreSQL UTF-8 插入失败。 +""" +import re + + +def sanitize_raw_response(text: str | None) -> str: + """清理原始响应中的无效控制字符,避免 PostgreSQL UTF-8 插入失败 + + 移除 NULL 字节及其他非法控制字符,保留 \\n \\t \\r + """ + if not text: + return "" + return re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]", "", text) diff --git a/backend/app/workers/__init__.py b/backend/app/workers/__init__.py index 890f1c2..c8bea7c 100644 --- a/backend/app/workers/__init__.py +++ b/backend/app/workers/__init__.py @@ -1,14 +1,10 @@ from app.workers.citation_engine import CitationEngine from app.workers.citation_extractor import analyze_citations -from app.workers.platforms.kimi import KimiAdapter -from app.workers.platforms.wenxin import WenxinAdapter from app.workers.scheduler import QueryScheduler, query_scheduler __all__ = [ "CitationEngine", "analyze_citations", - "KimiAdapter", - "WenxinAdapter", "QueryScheduler", "query_scheduler", ] diff --git a/backend/app/workers/citation_engine.py b/backend/app/workers/citation_engine.py index c2049ca..3ba565c 100644 --- a/backend/app/workers/citation_engine.py +++ b/backend/app/workers/citation_engine.py @@ -7,46 +7,22 @@ from datetime import datetime, timedelta from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select - -def _sanitize_raw_response(text: str | None) -> str: - """清理原始响应中的无效控制字符,避免 PostgreSQL UTF-8 插入失败""" - if not 
text: - return "" - # 移除 NULL 字节及其他非法控制字符,保留 \n \t \r - return re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]", "", text) - from app.models.citation_record import CitationRecord from app.models.query import Query from app.models.query_task import QueryTask -from app.workers.platforms.kimi import KimiAdapter -from app.workers.platforms.wenxin import WenxinAdapter -from app.workers.platforms.tongyi import TongyiAdapter -from app.workers.platforms.doubao import DoubaoAdapter -from app.workers.platforms.qingyan import QingyanAdapter -from app.workers.platforms.tiangong import TiangongAdapter -from app.workers.platforms.xinghuo import XinghuoAdapter +from app.services.ai_engine.platform_bridge import query_platform_raw from app.workers.citation_extractor import analyze_citations logger = logging.getLogger(__name__) class BrandMatcher: - """品牌匹配器:检测文本中是否引用了目标品牌""" def __init__(self, target_brand: str, brand_aliases: list[str] | None = None): self.target_brand = target_brand self.brand_aliases = brand_aliases or [] def match(self, text: str) -> dict: - """ - 返回: { - "cited": bool, - "confidence": float, # 0.0-1.0 - "match_type": str, # "exact"/"alias"/"fuzzy"/None - "position": int|None, # 在文本段落中的位置(第几段提到,1-based) - "citation_text": str|None, # 被引用的上下文片段 - } - """ if not text: return { "cited": False, @@ -56,7 +32,6 @@ class BrandMatcher: "citation_text": None, } - # 1. 精确匹配 if self.target_brand in text: position, citation_text = self._extract_position_and_context(text, self.target_brand) return { @@ -67,7 +42,6 @@ class BrandMatcher: "citation_text": citation_text, } - # 2. 别名匹配 for alias in self.brand_aliases: if alias in text: position, citation_text = self._extract_position_and_context(text, alias) @@ -79,7 +53,6 @@ class BrandMatcher: "citation_text": citation_text, } - # 3. 
模糊匹配 best_ratio = 0.0 best_match = None for word in self._extract_candidates(text): @@ -114,19 +87,15 @@ class BrandMatcher: } def _extract_candidates(self, text: str) -> list[str]: - """从文本中提取候选词(按非文字字符分割)""" - # 匹配中文词组、英文单词等 return [w for w in re.split(r'[^\w\u4e00-\u9fff]+', text) if len(w) >= 2] def _extract_position_and_context(self, text: str, keyword: str) -> tuple[int | None, str | None]: - """提取品牌首次出现的段落位置(1-based)和上下文片段""" paragraphs = [p.strip() for p in text.split('\n') if p.strip()] if not paragraphs: paragraphs = [text] for idx, paragraph in enumerate(paragraphs, start=1): if keyword in paragraph: - # 截取前200字符 snippet = paragraph[:200] return idx, snippet @@ -134,9 +103,7 @@ class BrandMatcher: class CompetitorDetector: - """竞争品牌检测器""" - # 预定义一些常见行业品牌列表 KNOWN_BRANDS = { "保险": ["中国平安", "中国人寿", "太平洋保险", "新华保险", "泰康保险", "中国人保", "友邦保险"], "金融": ["工商银行", "建设银行", "农业银行", "中国银行", "招商银行", "交通银行"], @@ -144,7 +111,6 @@ class CompetitorDetector: } def detect(self, text: str, target_brand: str) -> list[str]: - """检测文本中出现的其他品牌(排除 target_brand)""" if not text: return [] @@ -160,29 +126,16 @@ class CompetitorDetector: class CitationEngine: - """引用检测引擎核心""" def __init__(self): - self.platforms = { - "wenxin": WenxinAdapter(), - "kimi": KimiAdapter(), - "tongyi": TongyiAdapter(), - "doubao": DoubaoAdapter(), - "qingyan": QingyanAdapter(), - "tiangong": TiangongAdapter(), - "xinghuo": XinghuoAdapter(), + self._supported_platforms = { + "wenxin", "kimi", "doubao", "tongyi", + "qingyan", "tiangong", "xinghuo", } self.matcher = None self.competitor_detector = CompetitorDetector() async def execute_query(self, query: Query, db: AsyncSession) -> list[CitationRecord]: - """ - 执行一个查询任务: - 1. 创建 BrandMatcher - 2. 遍历 query.platforms - 3. 对每个 platform 执行查询和检测 - 4. 
更新 query.last_queried_at 和 query.next_query_at - """ self.matcher = BrandMatcher( target_brand=query.target_brand, brand_aliases=query.brand_aliases or [], @@ -192,10 +145,8 @@ class CitationEngine: platforms = query.platforms or ["wenxin", "kimi"] for platform_name in platforms: - # 查找或创建 QueryTask task = await self._get_or_create_task(db, query.id, platform_name) - # 更新状态为 running task.status = "running" task.started_at = datetime.utcnow() task.error_message = None @@ -209,28 +160,14 @@ class CitationEngine: brand_aliases=query.brand_aliases or [], ) - # 创建 CitationRecord - record = CitationRecord( + record = CitationRecord.from_citation_result( query_id=query.id, platform=platform_name, - cited=result["cited"], - citation_position=result.get("position"), - citation_text=result.get("citation_text"), - competitor_brands=result.get("competitor_brands", []), - raw_response=_sanitize_raw_response(result.get("raw_response", "")), - confidence=result.get("confidence"), - match_type=result.get("match_type"), - # 引用源分析字段 - data_source=result.get("data_source"), - source_urls=result.get("source_urls"), - source_titles=result.get("source_titles"), - citation_contexts=result.get("citation_contexts"), - ai_response_text=_sanitize_raw_response(result.get("ai_response_text", "")), + result=result, ) db.add(record) records.append(record) - # 更新 QueryTask 状态为 success task.status = "success" task.completed_at = datetime.utcnow() await db.commit() @@ -242,18 +179,15 @@ class CitationEngine: task.error_message = error_msg task.completed_at = datetime.utcnow() - # 创建一条 cited=False 的记录作为占位 - record = CitationRecord( + record = CitationRecord.from_citation_result( query_id=query.id, platform=platform_name, - cited=False, - raw_response=_sanitize_raw_response(error_msg), + result={"cited": False, "raw_response": error_msg}, ) db.add(record) records.append(record) await db.commit() - # 更新 Query 时间字段 query.last_queried_at = datetime.utcnow() query.next_query_at = 
self._calculate_next_query_at(query.frequency) await db.commit() @@ -267,28 +201,25 @@ class CitationEngine: target_brand: str, brand_aliases: list, ) -> dict: - """执行单个平台的查询和检测""" - adapter = self.platforms.get(platform) - if not adapter: + if platform not in self._supported_platforms: raise ValueError(f"不支持的平台: {platform}") - # 获取平台内容(将关键词与目标品牌组合,确保结果包含品牌信息) search_keyword = f"{keyword} {target_brand}" - raw_response = await adapter.query(search_keyword) + raw_response = await query_platform_raw( + platform_name=platform, + keyword=search_keyword, + brand_name=target_brand, + ) - # 引用源分析 citation_analysis = analyze_citations(raw_response) - # 品牌匹配(使用清理后的纯文本进行匹配) matcher = BrandMatcher(target_brand=target_brand, brand_aliases=brand_aliases) match_result = matcher.match(citation_analysis.clean_response) - # 竞争品牌检测 competitor_brands = self.competitor_detector.detect( citation_analysis.clean_response, target_brand ) - # 提取引用源信息 source_urls = [ c.source_url for c in citation_analysis.citations if c.source_url ] @@ -307,7 +238,6 @@ class CitationEngine: "citation_text": match_result["citation_text"], "competitor_brands": competitor_brands, "raw_response": raw_response, - # 引用源分析新增字段 "data_source": citation_analysis.data_source, "source_urls": source_urls, "source_titles": source_titles, @@ -318,7 +248,6 @@ class CitationEngine: async def _get_or_create_task( self, db: AsyncSession, query_id: uuid.UUID, platform: str ) -> QueryTask: - """获取或创建 QueryTask""" stmt = select(QueryTask).where( QueryTask.query_id == query_id, QueryTask.platform == platform, @@ -339,7 +268,6 @@ class CitationEngine: return task def _calculate_next_query_at(self, frequency: str | None) -> datetime: - """根据频率计算下次查询时间""" now = datetime.utcnow() freq_map = { "daily": timedelta(days=1), @@ -350,9 +278,4 @@ class CitationEngine: return now + delta async def close(self): - """关闭所有平台适配器""" - for adapter in self.platforms.values(): - try: - await adapter.close() - except Exception as e: - 
logger.warning(f"关闭适配器 {adapter.platform_name} 时出错: {e}") + pass diff --git a/backend/app/workers/llm_adapter.py b/backend/app/workers/llm_adapter.py deleted file mode 100644 index 6dc685c..0000000 --- a/backend/app/workers/llm_adapter.py +++ /dev/null @@ -1,270 +0,0 @@ -""" -LLM适配器 - 使用DeepSeek LLM API检测品牌引用 -""" -import asyncio -import json -import logging -import re -from typing import Optional - -from app.schemas.scoring import CitationResult -from app.config import settings - -logger = logging.getLogger(__name__) - - -BRAND_CITATION_PROMPT = """分析以下AI搜索查询中是否提到了目标品牌。 - -查询关键词: {keyword} -目标品牌: {brand_name} -品牌别名: {brand_aliases} - -返回JSON格式: -{{"cited": true/false, "position": 1, "citation_text": "...", "sentiment": "positive/neutral/negative", "confidence": 0.95}} -""" - - -class LLMAdapterError(Exception): - """LLM适配器异常""" - pass - - -class LLMAdapter: - """LLM适配器 - 使用 OpenAI 兼容协议检测品牌引用(支持百炼/DashScope/DeepSeek)""" - - def __init__(self, api_key: Optional[str] = None, max_retries: int = 3): - """ - 初始化LLM适配器 - - Args: - api_key: API密钥,默认优先使用 OPENAI_API_KEY(百炼/DashScope),其次 DEEPSEEK_API_KEY - max_retries: 最大重试次数 - """ - self.api_key = ( - api_key - or getattr(settings, 'OPENAI_API_KEY', None) - or getattr(settings, 'DEEPSEEK_API_KEY', None) - ) - # base_url 优先 OPENAI_BASE_URL,其次 DEEPSEEK_BASE_URL - self.base_url = ( - getattr(settings, 'OPENAI_BASE_URL', None) - or getattr(settings, 'DEEPSEEK_BASE_URL', 'https://api.deepseek.com/v1') - ) - # model 优先 OPENAI_MODEL,其次 DEFAULT_LLM_MODEL - self.model = ( - getattr(settings, 'OPENAI_MODEL', None) - or getattr(settings, 'DEFAULT_LLM_MODEL', 'qwen3-coder-plus') - or 'qwen3-coder-plus' - ) - self.max_retries = max_retries - self._client = None - - @property - def client(self): - """延迟初始化 OpenAI 兼容客户端""" - if self._client is None: - try: - from openai import OpenAI - self._client = OpenAI( - api_key=self.api_key, - base_url=self.base_url, - ) - except ImportError: - raise LLMAdapterError("请安装openai库: pip install 
openai") - return self._client - - def _build_prompt(self, keyword: str, brand_name: str, brand_aliases: list[str]) -> str: - """构建Prompt""" - aliases_str = ", ".join(brand_aliases) if brand_aliases else "无" - return BRAND_CITATION_PROMPT.format( - keyword=keyword, - brand_name=brand_name, - brand_aliases=aliases_str - ) - - async def query_brand_citation( - self, - keyword: str, - brand_name: str, - brand_aliases: list[str] - ) -> CitationResult: - """ - 使用LLM检测品牌引用 - - Args: - keyword: 查询关键词 - brand_name: 目标品牌名称 - brand_aliases: 品牌别名列表 - - Returns: - CitationResult: 包含cited, position, citation_text, sentiment, confidence - - Raises: - LLMAdapterError: API调用或解析失败 - """ - if not settings.ENABLE_LLM: - raise LLMAdapterError( - "LLM引用检测未启用。请在环境变量中设置 ENABLE_LLM=True 并配置 DEEPSEEK_API_KEY" - ) - - if not self.api_key: - raise LLMAdapterError( - "未配置DeepSeek API Key。请设置 DEEPSEEK_API_KEY 环境变量" - ) - - prompt = self._build_prompt(keyword, brand_name, brand_aliases) - - last_error = None - for attempt in range(self.max_retries): - try: - response = await self._call_deepseek(prompt) - return self._parse_response(response) - - except Exception as e: - last_error = e - logger.warning( - f"LLM API调用失败 (尝试 {attempt + 1}/{self.max_retries}): {e}" - ) - - raise LLMAdapterError(f"LLM API调用失败,已重试{self.max_retries}次: {last_error}") - - async def _call_deepseek(self, prompt: str) -> dict: - """ - 调用DeepSeek API - - Args: - prompt: 提示词 - - Returns: - API响应的JSON解析结果 - - Raises: - LLMAdapterError: API调用失败 - """ - try: - # 在线程池中执行同步的API调用 - response = await asyncio.to_thread( - self._sync_call_deepseek, - prompt - ) - return response - - except json.JSONDecodeError as e: - raise LLMAdapterError(f"JSON解析失败: {e}") - except Exception as e: - raise LLMAdapterError(f"API调用失败: {e}") - - def _sync_call_deepseek(self, prompt: str) -> dict: - """ - 同步调用DeepSeek API(在线程池中执行) - - Args: - prompt: 提示词 - - Returns: - API响应的JSON解析结果 - """ - response = self.client.chat.completions.create( - 
model=self.model, - messages=[ - { - "role": "user", - "content": prompt - } - ], - temperature=0.1, - max_tokens=500, - ) - - content = response.choices[0].message.content - if not content: - raise LLMAdapterError("API返回空响应") - - # 提取JSON(可能包裹在```json block中) - json_str = self._extract_json(content) - return json.loads(json_str) - - def _extract_json(self, text: str) -> str: - """从文本中提取JSON""" - # 尝试直接解析 - try: - json.loads(text) - return text - except json.JSONDecodeError: - pass - - # 尝试从代码块中提取 - json_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```' - match = re.search(json_pattern, text) - if match: - return match.group(1).strip() - - # 尝试找到第一个{到最后一个}之间的内容 - first_brace = text.find('{') - last_brace = text.rfind('}') - if first_brace != -1 and last_brace != -1 and last_brace > first_brace: - return text[first_brace:last_brace + 1] - - raise LLMAdapterError(f"无法从响应中提取JSON: {text[:200]}") - - def _parse_response(self, response: dict) -> CitationResult: - """ - 解析API响应 - - Args: - response: API返回的字典 - - Returns: - CitationResult对象 - - Raises: - LLMAdapterError: 解析失败 - """ - try: - required_fields = ['cited', 'sentiment', 'confidence'] - for field in required_fields: - if field not in response: - raise LLMAdapterError(f"响应缺少必需字段: {field}") - - cited = bool(response['cited']) - sentiment = str(response.get('sentiment', 'neutral')).lower() - - if sentiment not in ['positive', 'neutral', 'negative']: - sentiment = 'neutral' - - # 验证position - position = response.get('position') - if position is not None: - position = int(position) - if position < 1: - position = None - - # 验证confidence - confidence = float(response.get('confidence', 0.5)) - confidence = max(0.0, min(1.0, confidence)) - - citation_text = response.get('citation_text') - if citation_text and len(citation_text) > 500: - citation_text = citation_text[:500] - - return CitationResult( - cited=cited, - position=position, - citation_text=citation_text, - sentiment=sentiment, - confidence=confidence - ) - - except 
(ValueError, TypeError) as e: - raise LLMAdapterError(f"解析响应失败: {e}") - - async def close(self): - """关闭客户端连接""" - if self._client is not None: - try: - # OpenAI/DeepSeek客户端不需要显式关闭 - pass - except Exception as e: - logger.warning(f"关闭LLM客户端时出错: {e}") - finally: - self._client = None diff --git a/backend/app/workers/platforms/__init__.py b/backend/app/workers/platforms/__init__.py deleted file mode 100644 index 3433dd8..0000000 --- a/backend/app/workers/platforms/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from app.workers.platforms.base import BasePlatformAdapter -from app.workers.platforms.wenxin import WenxinAdapter -from app.workers.platforms.kimi import KimiAdapter -from app.workers.platforms.tongyi import TongyiAdapter -from app.workers.platforms.doubao import DoubaoAdapter -from app.workers.platforms.qingyan import QingyanAdapter -from app.workers.platforms.tiangong import TiangongAdapter -from app.workers.platforms.xinghuo import XinghuoAdapter - -__all__ = [ - "BasePlatformAdapter", - "WenxinAdapter", - "KimiAdapter", - "TongyiAdapter", - "DoubaoAdapter", - "QingyanAdapter", - "TiangongAdapter", - "XinghuoAdapter", -] diff --git a/backend/app/workers/platforms/base.py b/backend/app/workers/platforms/base.py deleted file mode 100644 index 1d97158..0000000 --- a/backend/app/workers/platforms/base.py +++ /dev/null @@ -1,17 +0,0 @@ -from abc import ABC, abstractmethod - - -class BasePlatformAdapter(ABC): - """AI平台查询适配器基类""" - - platform_name: str # 平台枚举值 - platform_url: str # 平台URL - - @abstractmethod - async def query(self, keyword: str) -> str: - """在AI平台查询关键词,返回原始响应文本""" - pass - - async def close(self): - """清理资源""" - pass diff --git a/backend/app/workers/platforms/doubao.py b/backend/app/workers/platforms/doubao.py deleted file mode 100644 index b732173..0000000 --- a/backend/app/workers/platforms/doubao.py +++ /dev/null @@ -1,174 +0,0 @@ -""" -豆包 (字节跳动/火山引擎) 平台适配器 - 使用火山方舟 API 获取真实AI回答 - -API文档: https://www.volcengine.com/docs/82379/1298454 -使用火山方舟推理接入点 
(Endpoint) API -认证方式: API Key (Bearer Token) -""" - -import asyncio -import logging -import time -from typing import Optional - -import httpx - -from app.config import settings -from app.workers.platforms.base import BasePlatformAdapter -from app.workers.platforms.search_engine import fetch_search_content - -logger = logging.getLogger(__name__) - -# 模块级频率限制器 -_last_request_time: float = 0.0 -_min_interval: float = 6.0 # 每分钟10次 -> 6秒间隔 - - -async def _rate_limit_wait(): - """确保请求间隔不低于最小间隔""" - global _last_request_time - now = time.monotonic() - elapsed = now - _last_request_time - if elapsed < _min_interval: - wait_time = _min_interval - elapsed - logger.debug(f"豆包 频率限制等待 {wait_time:.1f}s") - await asyncio.sleep(wait_time) - _last_request_time = time.monotonic() - - -class DoubaoAdapter(BasePlatformAdapter): - """豆包平台适配器 - 使用火山方舟 API 获取真实AI回答""" - - platform_name = "doubao" - platform_url = "https://www.doubao.com/" - - # 火山方舟 API 端点 (OpenAI兼容格式) - _api_base = "https://ark.cn-beijing.volces.com/api/v3" - _default_model = "doubao-pro-4k" # 默认模型,可被 endpoint_id 覆盖 - - def __init__(self): - self._api_key: Optional[str] = None - self._endpoint_id: Optional[str] = None - self._client: Optional[httpx.AsyncClient] = None - - @property - def api_key(self) -> Optional[str]: - if self._api_key is None: - self._api_key = settings.DOUBAO_API_KEY - return self._api_key - - @property - def endpoint_id(self) -> Optional[str]: - if self._endpoint_id is None: - self._endpoint_id = settings.DOUBAO_ENDPOINT_ID - return self._endpoint_id - - @property - def is_configured(self) -> bool: - return bool(self.api_key and self.api_key.strip()) - - def _get_model_id(self) -> str: - """获取实际使用的模型ID(优先使用 endpoint_id)""" - if self.endpoint_id and self.endpoint_id.strip(): - return f"ep-{self.endpoint_id}" if not self.endpoint_id.startswith("ep-") else self.endpoint_id - return self._default_model - - def _get_client(self) -> httpx.AsyncClient: - if self._client is None or self._client.is_closed: 
- self._client = httpx.AsyncClient( - timeout=httpx.Timeout(60.0, connect=10.0), - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - }, - ) - return self._client - - async def query(self, keyword: str) -> str: - """在豆包查询关键词,返回原始响应文本""" - last_error = None - for attempt in range(3): - try: - return await self._do_query(keyword) - except Exception as e: - last_error = e - logger.warning(f"豆包查询第 {attempt + 1} 次尝试失败: {e}") - if attempt < 2: - await asyncio.sleep(2 ** attempt) - - # 所有尝试失败后回退到搜索引擎 - logger.warning(f"豆包 API 调用全部失败,回退到搜索引擎: {last_error}") - return await self._fallback_search(keyword) - - async def _do_query(self, keyword: str) -> str: - """通过火山方舟 API 获取真实AI回答""" - if not self.is_configured: - logger.warning("豆包 API Key 未配置,回退到搜索引擎") - return await self._fallback_search(keyword) - - await _rate_limit_wait() - - model_id = self._get_model_id() - client = self._get_client() - - payload = { - "model": model_id, - "messages": [ - { - "role": "system", - "content": ( - "你是一个专业的AI搜索助手。请基于你的知识," - "详细回答用户的问题。如果引用了外部来源," - "请在回答中标注来源URL或出处名称。" - ), - }, - { - "role": "user", - "content": keyword, - }, - ], - "temperature": 0.7, - "max_tokens": 2000, - } - - response = await client.post( - f"{self._api_base}/chat/completions", - json=payload, - ) - - if response.status_code == 429: - retry_after = int(response.headers.get("Retry-After", "10")) - logger.warning(f"豆包 API 限流,等待 {retry_after}s 后重试") - await asyncio.sleep(retry_after) - raise RuntimeError(f"豆包 API 限流") - - if response.status_code != 200: - error_body = response.text[:500] - raise RuntimeError( - f"豆包 API 返回错误 {response.status_code}: {error_body}" - ) - - data = response.json() - choices = data.get("choices", []) - if not choices: - raise RuntimeError("豆包 API 返回空 choices") - - content = choices[0].get("message", {}).get("content", "") - if not content: - raise RuntimeError("豆包 API 返回空内容") - - # 标记数据来源为AI平台 - result = f"[data_source: ai_platform]\n{content}" - 
logger.info(f"豆包 API 调用成功,返回 {len(content)} 字符") - return result - - async def _fallback_search(self, keyword: str) -> str: - """回退到搜索引擎模式,标记数据来源""" - content = await fetch_search_content(self.platform_name, keyword) - return f"[data_source: search_engine]\n{content}" - - async def close(self): - """清理资源""" - if self._client and not self._client.is_closed: - await self._client.aclose() - self._client = None diff --git a/backend/app/workers/platforms/kimi.py b/backend/app/workers/platforms/kimi.py deleted file mode 100644 index 8c0d56a..0000000 --- a/backend/app/workers/platforms/kimi.py +++ /dev/null @@ -1,156 +0,0 @@ -""" -Kimi (月之暗面) 平台适配器 - 使用 Moonshot AI API 获取真实AI回答 - -API文档: https://platform.moonshot.cn/docs -使用 Moonshot v1 chat completion API -""" - -import asyncio -import logging -import time -from typing import Optional - -import httpx - -from app.config import settings -from app.workers.platforms.base import BasePlatformAdapter -from app.workers.platforms.search_engine import fetch_search_content - -logger = logging.getLogger(__name__) - -# 模块级频率限制器 -_last_request_time: float = 0.0 -_min_interval: float = 6.0 # 每分钟10次 -> 6秒间隔 - - -async def _rate_limit_wait(): - """确保请求间隔不低于最小间隔""" - global _last_request_time - now = time.monotonic() - elapsed = now - _last_request_time - if elapsed < _min_interval: - wait_time = _min_interval - elapsed - logger.debug(f"Kimi 频率限制等待 {wait_time:.1f}s") - await asyncio.sleep(wait_time) - _last_request_time = time.monotonic() - - -class KimiAdapter(BasePlatformAdapter): - """Kimi 平台适配器 - 使用 Moonshot AI API 获取真实AI回答""" - - platform_name = "kimi" - platform_url = "https://kimi.moonshot.cn" - _api_base = "https://api.moonshot.cn/v1" - _model = "moonshot-v1-8k" - - def __init__(self): - self._api_key: Optional[str] = None - self._client: Optional[httpx.AsyncClient] = None - - @property - def api_key(self) -> Optional[str]: - if self._api_key is None: - self._api_key = settings.MOONSHOT_API_KEY - return self._api_key - - @property 
- def is_configured(self) -> bool: - return bool(self.api_key and self.api_key.strip()) - - def _get_client(self) -> httpx.AsyncClient: - if self._client is None or self._client.is_closed: - self._client = httpx.AsyncClient( - timeout=httpx.Timeout(60.0, connect=10.0), - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - }, - ) - return self._client - - async def query(self, keyword: str) -> str: - """在 Kimi 查询关键词,返回原始响应文本""" - last_error = None - for attempt in range(3): - try: - return await self._do_query(keyword) - except Exception as e: - last_error = e - logger.warning(f"Kimi 查询第 {attempt + 1} 次尝试失败: {e}") - if attempt < 2: - await asyncio.sleep(2 ** attempt) - - # 所有尝试失败后回退到搜索引擎 - logger.warning(f"Kimi API 调用全部失败,回退到搜索引擎: {last_error}") - return await self._fallback_search(keyword) - - async def _do_query(self, keyword: str) -> str: - """通过 Moonshot API 获取真实AI回答""" - if not self.is_configured: - logger.warning("Kimi API Key 未配置,回退到搜索引擎") - return await self._fallback_search(keyword) - - await _rate_limit_wait() - - client = self._get_client() - payload = { - "model": self._model, - "messages": [ - { - "role": "system", - "content": ( - "你是一个专业的AI搜索助手。请基于你的知识," - "详细回答用户的问题。如果引用了外部来源," - "请在回答中标注来源URL或出处名称。" - ), - }, - { - "role": "user", - "content": keyword, - }, - ], - "temperature": 0.7, - "max_tokens": 2000, - } - - response = await client.post( - f"{self._api_base}/chat/completions", - json=payload, - ) - - if response.status_code == 429: - retry_after = int(response.headers.get("Retry-After", "10")) - logger.warning(f"Kimi API 限流,等待 {retry_after}s 后重试") - await asyncio.sleep(retry_after) - raise RuntimeError(f"Kimi API 限流,等待 {retry_after}s") - - if response.status_code != 200: - error_body = response.text[:500] - raise RuntimeError( - f"Kimi API 返回错误 {response.status_code}: {error_body}" - ) - - data = response.json() - choices = data.get("choices", []) - if not choices: - raise RuntimeError("Kimi API 返回空 
choices") - - content = choices[0].get("message", {}).get("content", "") - if not content: - raise RuntimeError("Kimi API 返回空内容") - - # 标记数据来源为AI平台 - result = f"[data_source: ai_platform]\n{content}" - logger.info(f"Kimi API 调用成功,返回 {len(content)} 字符") - return result - - async def _fallback_search(self, keyword: str) -> str: - """回退到搜索引擎模式,标记数据来源""" - content = await fetch_search_content(self.platform_name, keyword) - return f"[data_source: search_engine]\n{content}" - - async def close(self): - """清理资源""" - if self._client and not self._client.is_closed: - await self._client.aclose() - self._client = None diff --git a/backend/app/workers/platforms/qingyan.py b/backend/app/workers/platforms/qingyan.py deleted file mode 100644 index d2103f0..0000000 --- a/backend/app/workers/platforms/qingyan.py +++ /dev/null @@ -1,37 +0,0 @@ -import asyncio -import logging - -from app.workers.platforms.base import BasePlatformAdapter -from app.workers.platforms.search_engine import fetch_search_content - -logger = logging.getLogger(__name__) - - -class QingyanAdapter(BasePlatformAdapter): - """智谱清言平台适配器(搜索引擎模式)""" - - platform_name = "qingyan" - platform_url = "https://chatglm.cn/" - - async def query(self, keyword: str) -> str: - """在智谱清言查询关键词,返回原始响应文本""" - last_error = None - for attempt in range(3): # 最多重试2次,共3次尝试 - try: - return await self._do_query(keyword) - except Exception as e: - last_error = e - logger.warning(f"智谱清言查询第 {attempt + 1} 次尝试失败: {e}") - if attempt < 2: - await asyncio.sleep(2 ** attempt) # 指数退避 - - logger.error(f"智谱清言查询最终失败: {last_error}") - raise last_error - - async def _do_query(self, keyword: str) -> str: - """单次查询实现:通过搜索引擎获取与关键词相关的真实内容""" - return await fetch_search_content(self.platform_name, keyword) - - async def close(self): - """清理资源(搜索引擎模式无额外资源需要释放)""" - pass diff --git a/backend/app/workers/platforms/search_engine.py b/backend/app/workers/platforms/search_engine.py deleted file mode 100644 index 2a37fe5..0000000 --- 
a/backend/app/workers/platforms/search_engine.py +++ /dev/null @@ -1,173 +0,0 @@ -""" -通用搜索引擎模块 —— 用于在AI平台适配器无法正常工作时获取与关键词相关的真实内容。 - -使用 DuckDuckGo HTML 搜索(无需 API Key),返回搜索结果摘要。 -""" - -import logging -import re -from urllib.parse import quote - -import httpx - -logger = logging.getLogger(__name__) - - -async def search_wikipedia(keyword: str, max_chars: int = 2000) -> str: - """ - 使用 Wikipedia API 获取与关键词相关的百科内容。 - Wikipedia API 是公开的,不需要 API Key,非常稳定。 - """ - # 尝试用关键词直接搜索 Wikipedia - search_url = "https://zh.wikipedia.org/w/api.php" - headers = { - "User-Agent": "GEO-Citation-Bot/1.0 (contact@example.com)", - } - - # 1. 先搜索匹配的词条 - async with httpx.AsyncClient(timeout=30) as client: - search_resp = await client.get( - search_url, - headers=headers, - params={ - "action": "query", - "list": "search", - "srsearch": keyword, - "srlimit": 3, - "format": "json", - "origin": "*", - }, - ) - search_resp.raise_for_status() - search_data = search_resp.json() - - search_results = search_data.get("query", {}).get("search", []) - if not search_results: - return "" - - # 2. 
获取第一个匹配词条的内容摘要 - title = search_results[0]["title"] - async with httpx.AsyncClient(timeout=30) as client: - extract_resp = await client.get( - search_url, - headers=headers, - params={ - "action": "query", - "prop": "extracts", - "titles": title, - "explaintext": True, - "exsentences": 15, - "format": "json", - "origin": "*", - }, - ) - extract_resp.raise_for_status() - extract_data = extract_resp.json() - - pages = extract_data.get("query", {}).get("pages", {}) - for page in pages.values(): - extract = page.get("extract", "") - if extract: - # 清理 Wikipedia 的标记 - extract = re.sub(r'\[\d+\]', '', extract) # 移除引用标记如 [1] - extract = re.sub(r'\s+', ' ', extract).strip() - return extract[:max_chars] - - return "" - - -async def search_duckduckgo(query: str, max_results: int = 5) -> str: - """ - 使用 DuckDuckGo HTML 版搜索。若被限制则回退到 Wikipedia。 - """ - url = f"https://html.duckduckgo.com/html/?q={quote(query)}" - headers = { - "User-Agent": ( - "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " - "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36" - ), - "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", - "Accept-Language": "zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7", - } - - try: - async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client: - resp = await client.get(url, headers=headers) - resp.raise_for_status() - html = resp.text - - # 快速检查是否是有效的结果页(而不是主页/验证页) - if "web-result" not in html and "result__snippet" not in html and "result__title" not in html: - raise RuntimeError("DuckDuckGo 返回了非结果页面") - - results: list[str] = [] - - # 尝试匹配标准 result 块 - result_blocks = re.findall( - r'
]*>.*?]*class="result__title"[^>]*>.*?]*>(.*?).*?]*>.*?]*class="result__snippet"[^>]*>(.*?).*?
', - html, - re.DOTALL | re.IGNORECASE, - ) - if result_blocks: - for title_raw, snippet_raw in result_blocks[:max_results]: - title = _strip_html(title_raw) - snippet = _strip_html(snippet_raw) - if title or snippet: - results.append(f"{title}\n{snippet}") - - # 备选:直接抓取 .result__snippet 和 .result__title - if not results: - snippets = re.findall( - r']*class="result__snippet"[^>]*>(.*?)', html, re.DOTALL | re.IGNORECASE - ) - titles = re.findall( - r']*class="result__title"[^>]*>.*?]*>(.*?).*?]*>', - html, - re.DOTALL | re.IGNORECASE, - ) - for i in range(min(len(titles), len(snippets), max_results)): - title = _strip_html(titles[i]) - snippet = _strip_html(snippets[i]) - if title or snippet: - results.append(f"{title}\n{snippet}") - - if results: - return "\n\n".join(results) - - raise RuntimeError("DuckDuckGo 未解析到结果") - - except Exception as e: - logger.warning(f"DuckDuckGo 搜索失败: {e},回退到 Wikipedia") - wiki_text = await search_wikipedia(query, max_chars=2000) - if wiki_text: - return wiki_text - raise RuntimeError(f"所有搜索源均失败: {e}") - - -def _strip_html(raw: str) -> str: - """去除 HTML 标签并将实体转义还原为可读文本。""" - # 先替换常见 HTML 实体 - raw = raw.replace(" ", " ") - raw = raw.replace(""", '"') - raw = raw.replace("&", "&") - raw = raw.replace("<", "<") - raw = raw.replace(">", ">") - raw = raw.replace("'", "'") - # 去除所有标签 - text = re.sub(r"<[^>]+>", "", raw) - # 合并空白 - text = re.sub(r"\s+", " ", text).strip() - return text - - -async def fetch_search_content(platform_name: str, keyword: str) -> str: - """ - 为指定平台获取与关键词相关的搜索内容。 - - 策略: - 1. 使用关键词直接搜索 DuckDuckGo(频率限制时自动回退 Wikipedia) - 2. 
返回搜索结果摘要或百科内容 - """ - logger.info(f"[{platform_name}] 搜索查询: {keyword}") - text = await search_duckduckgo(keyword, max_results=5) - return text diff --git a/backend/app/workers/platforms/tiangong.py b/backend/app/workers/platforms/tiangong.py deleted file mode 100644 index 7b7f4cc..0000000 --- a/backend/app/workers/platforms/tiangong.py +++ /dev/null @@ -1,37 +0,0 @@ -import asyncio -import logging - -from app.workers.platforms.base import BasePlatformAdapter -from app.workers.platforms.search_engine import fetch_search_content - -logger = logging.getLogger(__name__) - - -class TiangongAdapter(BasePlatformAdapter): - """天工AI平台适配器(搜索引擎模式)""" - - platform_name = "tiangong" - platform_url = "https://www.tiangong.cn/" - - async def query(self, keyword: str) -> str: - """在天工AI查询关键词,返回原始响应文本""" - last_error = None - for attempt in range(3): # 最多重试2次,共3次尝试 - try: - return await self._do_query(keyword) - except Exception as e: - last_error = e - logger.warning(f"天工AI查询第 {attempt + 1} 次尝试失败: {e}") - if attempt < 2: - await asyncio.sleep(2 ** attempt) # 指数退避 - - logger.error(f"天工AI查询最终失败: {last_error}") - raise last_error - - async def _do_query(self, keyword: str) -> str: - """单次查询实现:通过搜索引擎获取与关键词相关的真实内容""" - return await fetch_search_content(self.platform_name, keyword) - - async def close(self): - """清理资源(搜索引擎模式无额外资源需要释放)""" - pass diff --git a/backend/app/workers/platforms/tongyi.py b/backend/app/workers/platforms/tongyi.py deleted file mode 100644 index dfcede5..0000000 --- a/backend/app/workers/platforms/tongyi.py +++ /dev/null @@ -1,37 +0,0 @@ -import asyncio -import logging - -from app.workers.platforms.base import BasePlatformAdapter -from app.workers.platforms.search_engine import fetch_search_content - -logger = logging.getLogger(__name__) - - -class TongyiAdapter(BasePlatformAdapter): - """通义千问平台适配器(搜索引擎模式)""" - - platform_name = "tongyi" - platform_url = "https://tongyi.aliyun.com/qianwen" - - async def query(self, keyword: str) -> str: - """在通义千问查询关键词,返回原始响应文本""" 
- last_error = None - for attempt in range(3): # 最多重试2次,共3次尝试 - try: - return await self._do_query(keyword) - except Exception as e: - last_error = e - logger.warning(f"通义千问查询第 {attempt + 1} 次尝试失败: {e}") - if attempt < 2: - await asyncio.sleep(2 ** attempt) # 指数退避 - - logger.error(f"通义千问查询最终失败: {last_error}") - raise last_error - - async def _do_query(self, keyword: str) -> str: - """单次查询实现:通过搜索引擎获取与关键词相关的真实内容""" - return await fetch_search_content(self.platform_name, keyword) - - async def close(self): - """清理资源(搜索引擎模式无额外资源需要释放)""" - pass diff --git a/backend/app/workers/platforms/wenxin.py b/backend/app/workers/platforms/wenxin.py deleted file mode 100644 index 59b2ab7..0000000 --- a/backend/app/workers/platforms/wenxin.py +++ /dev/null @@ -1,213 +0,0 @@ -""" -文心一言 (百度千帆) 平台适配器 - 使用百度千帆 API 获取真实AI回答 - -API文档: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 -使用 ERNIE-Bot (completions) chat completion API -认证方式: API Key + Secret Key 换取 access_token -""" - -import asyncio -import logging -import time -from typing import Optional - -import httpx - -from app.config import settings -from app.workers.platforms.base import BasePlatformAdapter -from app.workers.platforms.search_engine import fetch_search_content - -logger = logging.getLogger(__name__) - -# 模块级频率限制器 -_last_request_time: float = 0.0 -_min_interval: float = 6.0 # 每分钟10次 -> 6秒间隔 - -# access_token 缓存 -_cached_token: Optional[str] = None -_token_expires_at: float = 0.0 - - -async def _rate_limit_wait(): - """确保请求间隔不低于最小间隔""" - global _last_request_time - now = time.monotonic() - elapsed = now - _last_request_time - if elapsed < _min_interval: - wait_time = _min_interval - elapsed - logger.debug(f"文心一言 频率限制等待 {wait_time:.1f}s") - await asyncio.sleep(wait_time) - _last_request_time = time.monotonic() - - -class WenxinAdapter(BasePlatformAdapter): - """文心一言平台适配器 - 使用百度千帆 API 获取真实AI回答""" - - platform_name = "wenxin" - platform_url = "https://yiyan.baidu.com" - - # 百度千帆 API 端点 - _token_url = 
"https://aip.baidubce.com/oauth/2.0/token" - # ERNIE-Bot 4.0 (completions_pro) 或 ERNIE-Bot (completions) - _chat_url_template = ( - "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}" - "?access_token={token}" - ) - _default_model = "completions_pro" # ERNIE-Bot 4.0 - - def __init__(self): - self._api_key: Optional[str] = None - self._secret_key: Optional[str] = None - self._client: Optional[httpx.AsyncClient] = None - - @property - def api_key(self) -> Optional[str]: - if self._api_key is None: - self._api_key = settings.BAIDU_QIANFAN_API_KEY - return self._api_key - - @property - def secret_key(self) -> Optional[str]: - if self._secret_key is None: - self._secret_key = settings.BAIDU_QIANFAN_SECRET_KEY - return self._secret_key - - @property - def is_configured(self) -> bool: - return bool( - self.api_key and self.api_key.strip() - and self.secret_key and self.secret_key.strip() - ) - - def _get_client(self) -> httpx.AsyncClient: - if self._client is None or self._client.is_closed: - self._client = httpx.AsyncClient( - timeout=httpx.Timeout(60.0, connect=10.0), - ) - return self._client - - async def _get_access_token(self) -> str: - """获取百度千帆 access_token,带缓存""" - global _cached_token, _token_expires_at - - now = time.monotonic() - if _cached_token and now < _token_expires_at: - return _cached_token - - client = self._get_client() - response = await client.post( - self._token_url, - params={ - "grant_type": "client_credentials", - "client_id": self.api_key, - "client_secret": self.secret_key, - }, - ) - - if response.status_code != 200: - raise RuntimeError( - f"百度千帆获取 access_token 失败: {response.status_code} {response.text[:300]}" - ) - - data = response.json() - token = data.get("access_token") - if not token: - error_desc = data.get("error_description", "未知错误") - raise RuntimeError(f"百度千帆获取 access_token 失败: {error_desc}") - - # 缓存 token,提前5分钟过期 - expires_in = data.get("expires_in", 2592000) # 默认30天 - _cached_token = token - 
_token_expires_at = now + expires_in - 300 - - logger.info("百度千帆 access_token 获取成功") - return token - - async def query(self, keyword: str) -> str: - """在文心一言查询关键词,返回原始响应文本""" - last_error = None - for attempt in range(3): - try: - return await self._do_query(keyword) - except Exception as e: - last_error = e - logger.warning(f"文心一言查询第 {attempt + 1} 次尝试失败: {e}") - if attempt < 2: - await asyncio.sleep(2 ** attempt) - - # 所有尝试失败后回退到搜索引擎 - logger.warning(f"文心一言 API 调用全部失败,回退到搜索引擎: {last_error}") - return await self._fallback_search(keyword) - - async def _do_query(self, keyword: str) -> str: - """通过百度千帆 API 获取真实AI回答""" - if not self.is_configured: - logger.warning("百度千帆 API Key 未配置,回退到搜索引擎") - return await self._fallback_search(keyword) - - await _rate_limit_wait() - - access_token = await self._get_access_token() - chat_url = self._chat_url_template.format( - model=self._default_model, - token=access_token, - ) - - client = self._get_client() - payload = { - "messages": [ - { - "role": "user", - "content": keyword, - }, - ], - "system": ( - "你是一个专业的AI搜索助手。请基于你的知识," - "详细回答用户的问题。如果引用了外部来源," - "请在回答中标注来源URL或出处名称。" - ), - "temperature": 0.7, - "max_output_tokens": 2000, - } - - response = await client.post(chat_url, json=payload) - - if response.status_code == 429: - retry_after = int(response.headers.get("Retry-After", "10")) - logger.warning(f"百度千帆 API 限流,等待 {retry_after}s 后重试") - await asyncio.sleep(retry_after) - raise RuntimeError(f"百度千帆 API 限流") - - if response.status_code != 200: - error_body = response.text[:500] - raise RuntimeError( - f"百度千帆 API 返回错误 {response.status_code}: {error_body}" - ) - - data = response.json() - - # 检查API错误码 - error_code = data.get("error_code") - if error_code: - error_msg = data.get("error_msg", "未知错误") - raise RuntimeError(f"百度千帆 API 错误 {error_code}: {error_msg}") - - content = data.get("result", "") - if not content: - raise RuntimeError("百度千帆 API 返回空内容") - - # 标记数据来源为AI平台 - result = f"[data_source: ai_platform]\n{content}" - 
logger.info(f"文心一言 API 调用成功,返回 {len(content)} 字符") - return result - - async def _fallback_search(self, keyword: str) -> str: - """回退到搜索引擎模式,标记数据来源""" - content = await fetch_search_content(self.platform_name, keyword) - return f"[data_source: search_engine]\n{content}" - - async def close(self): - """清理资源""" - if self._client and not self._client.is_closed: - await self._client.aclose() - self._client = None diff --git a/backend/app/workers/platforms/xinghuo.py b/backend/app/workers/platforms/xinghuo.py deleted file mode 100644 index d4af0d1..0000000 --- a/backend/app/workers/platforms/xinghuo.py +++ /dev/null @@ -1,37 +0,0 @@ -import asyncio -import logging - -from app.workers.platforms.base import BasePlatformAdapter -from app.workers.platforms.search_engine import fetch_search_content - -logger = logging.getLogger(__name__) - - -class XinghuoAdapter(BasePlatformAdapter): - """讯飞星火平台适配器(搜索引擎模式)""" - - platform_name = "xinghuo" - platform_url = "https://xinghuo.xfyun.cn/" - - async def query(self, keyword: str) -> str: - """在讯飞星火查询关键词,返回原始响应文本""" - last_error = None - for attempt in range(3): # 最多重试2次,共3次尝试 - try: - return await self._do_query(keyword) - except Exception as e: - last_error = e - logger.warning(f"讯飞星火查询第 {attempt + 1} 次尝试失败: {e}") - if attempt < 2: - await asyncio.sleep(2 ** attempt) # 指数退避 - - logger.error(f"讯飞星火查询最终失败: {last_error}") - raise last_error - - async def _do_query(self, keyword: str) -> str: - """单次查询实现:通过搜索引擎获取与关键词相关的真实内容""" - return await fetch_search_content(self.platform_name, keyword) - - async def close(self): - """清理资源(搜索引擎模式无额外资源需要释放)""" - pass diff --git a/backend/app/workers/scheduler.py b/backend/app/workers/scheduler.py index fa03bba..9990e29 100644 --- a/backend/app/workers/scheduler.py +++ b/backend/app/workers/scheduler.py @@ -3,7 +3,8 @@ - 使用 APScheduler 的 AsyncIOScheduler - 每小时执行一次检查 - 查找 queries 表中 status='active' 且 next_query_at <= now() 的记录 -- 为每个符合条件的 query 调用 CitationEngine 执行查询 +- 为每个符合条件的 query 通过 Agent 框架执行查询 
+- 如果 Agent 框架不可用,回退到直接使用 CitationEngine """ import asyncio @@ -23,13 +24,41 @@ from app.workers.citation_engine import CitationEngine logger = logging.getLogger(__name__) +# Lazy singleton for CitationDetectorAgent — defers import to avoid +# circular-dependency issues at module-load time. +_agent_instance = None + + +def _get_agent(): + """Return a cached CitationDetectorAgent instance (lazy singleton).""" + global _agent_instance + if _agent_instance is None: + from app.agent_framework.agents.citation_detector import CitationDetectorAgent + _agent_instance = CitationDetectorAgent() + return _agent_instance + class QueryScheduler: def __init__(self): self.scheduler = AsyncIOScheduler() - self.engine = CitationEngine() + self._agent = None # Lazy-initialized via _get_agent() + self._fallback_engine = None # Only created when agent fails self._loop = None + @property + def agent(self): + """Lazy-accessor for CitationDetectorAgent.""" + if self._agent is None: + self._agent = _get_agent() + return self._agent + + @property + def fallback_engine(self): + """Lazy-accessor for fallback CitationEngine (only when agent fails).""" + if self._fallback_engine is None: + self._fallback_engine = CitationEngine() + return self._fallback_engine + def start(self): """启动调度器""" self._loop = asyncio.get_event_loop() @@ -83,14 +112,17 @@ class QueryScheduler: logger.error(f"检查查询任务时出错: {e}") async def _execute_single_query(self, query: Query, db: AsyncSession): - """执行单个查询""" + """执行单个查询 — 优先通过 Agent 框架,失败时回退到直接引擎""" logger.info(f"开始执行查询: {query.keyword} (ID: {query.id})") try: - await self.engine.execute_query(query, db) - logger.info(f"查询 {query.id} 执行完成") - except Exception as e: - logger.error(f"查询 {query.id} 执行失败: {e}") - raise + await self.agent.execute_query_compat(query, db) + logger.info(f"查询 {query.id} 执行完成 (via Agent)") + except Exception as agent_err: + logger.warning( + f"Agent 框架执行查询 {query.id} 失败: {agent_err},回退到直接引擎" + ) + await 
self.fallback_engine.execute_query(query, db) + logger.info(f"查询 {query.id} 执行完成 (via fallback engine)") def _run_pending_tasks_check(self): """同步包装:将异步遗留任务检查调度到当前事件循环""" @@ -135,7 +167,7 @@ class QueryScheduler: task.error_message = None await db.commit() - citation_result = await self.engine.execute_single_platform( + citation_result = await self._execute_single_platform( keyword=query.keyword, platform=task.platform, target_brand=query.target_brand, @@ -143,22 +175,10 @@ class QueryScheduler: ) if citation_result: - record = CitationRecord( + record = CitationRecord.from_citation_result( query_id=query_id, platform=task.platform, - cited=citation_result.get("cited", False), - citation_position=citation_result.get("position"), - citation_text=citation_result.get("citation_text"), - competitor_brands=citation_result.get("competitor_brands", []), - raw_response=citation_result.get("raw_response", ""), - confidence=citation_result.get("confidence"), - match_type=citation_result.get("match_type"), - # 引用源分析字段 - data_source=citation_result.get("data_source"), - source_urls=citation_result.get("source_urls"), - source_titles=citation_result.get("source_titles"), - citation_contexts=citation_result.get("citation_contexts"), - ai_response_text=citation_result.get("ai_response_text", ""), + result=citation_result, ) db.add(record) @@ -177,10 +197,42 @@ class QueryScheduler: except Exception as e: logger.error(f"检查遗留任务时出错: {e}") + async def _execute_single_platform( + self, + keyword: str, + platform: str, + target_brand: str, + brand_aliases: list, + ) -> dict: + """执行单平台检测 — 优先通过 Agent 框架,失败时回退到直接引擎""" + try: + result = await self.agent.execute_single_platform_compat( + keyword=keyword, + platform=platform, + target_brand=target_brand, + brand_aliases=brand_aliases, + ) + return result + except Exception as agent_err: + logger.warning( + f"Agent 框架执行单平台检测失败 ({platform}): {agent_err}," + "回退到直接引擎" + ) + result = await self.fallback_engine.execute_single_platform( + 
keyword=keyword, + platform=platform, + target_brand=target_brand, + brand_aliases=brand_aliases, + ) + return result + async def shutdown(self): """关闭调度器""" self.scheduler.shutdown(wait=False) - await self.engine.close() + if self._agent is not None: + await self._agent.close() + if self._fallback_engine is not None: + await self._fallback_engine.close() logger.info("查询调度器已关闭") diff --git a/backend/requirements.txt b/backend/requirements.txt index ea34e02..3f81c1e 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -36,6 +36,7 @@ pyyaml>=6.0 # 测试依赖 pytest>=8.0 pytest-asyncio>=0.23.0 +pytest-cov>=5.0 aiosqlite # PDF生成 diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 86e152e..dc3c003 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,20 +1,32 @@ import logging import uuid from datetime import datetime +from unittest.mock import AsyncMock, patch import pytest import pytest_asyncio +from httpx import ASGITransport, AsyncClient from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession from sqlalchemy.pool import StaticPool -from app.database import Base -from app.models.user import User +from app.database import Base, get_db +from app.main import app +from app.api.deps import get_current_user from app.middleware.logging_filter import APIKeyFilter +from app.models.user import User +from app.services.auth import create_access_token + +from tests.fixtures.auth import _make_user, _to_uuid + +pytest_plugins = [ + "tests.fixtures.database", + "tests.fixtures.brands", + "tests.fixtures.client", +] @pytest.fixture(autouse=True) def add_api_key_filter(): - """自动为每个测试添加APIKeyFilter到root logger""" root_logger = logging.getLogger() api_key_filter = APIKeyFilter() root_logger.addFilter(api_key_filter) @@ -22,9 +34,16 @@ def add_api_key_filter(): root_logger.removeFilter(api_key_filter) +@pytest.fixture(scope="session", autouse=True) +def mock_scheduler(): + with 
patch("app.main.query_scheduler") as mock_sched: + mock_sched.start = lambda: None + mock_sched.shutdown = AsyncMock() + yield + + @pytest_asyncio.fixture -async def async_engine(): - """Create async engine for testing with SQLite.""" +async def test_engine(): engine = create_async_engine( "sqlite+aiosqlite:///:memory:", connect_args={"check_same_thread": False}, @@ -37,10 +56,9 @@ async def async_engine(): @pytest_asyncio.fixture -async def async_session(async_engine): - """Create async session for testing.""" +async def test_session(test_engine): async_session_maker = async_sessionmaker( - async_engine, + test_engine, class_=AsyncSession, expire_on_commit=False, autoflush=False, @@ -48,23 +66,53 @@ async def async_session(async_engine): ) async with async_session_maker() as session: yield session - await session.rollback() @pytest_asyncio.fixture -async def test_user(async_session): - """Create a test user.""" - user = User( - id=uuid.uuid4(), - email="test@example.com", - password_hash="hashed_password", - name="Test User", - plan="free", - max_queries=5, - is_active=True, - email_verified=True, - ) - async_session.add(user) - await async_session.commit() - await async_session.refresh(user) +async def override_get_db(test_session): + async def _get_db(): + yield test_session + + app.dependency_overrides[get_db] = _get_db + yield test_session + app.dependency_overrides.pop(get_db, None) + + +@pytest.fixture +def mock_user(): + user = AsyncMock() + user.id = uuid.UUID("12345678-1234-1234-1234-123456789abc") + user.email = "test@example.com" + user.name = "Test User" + user.plan = "free" + user.max_queries = 5 + user.is_active = True + user.created_at = datetime.now() return user + + +@pytest.fixture +def override_get_current_user(mock_user): + async def _override(): + return mock_user + + app.dependency_overrides[get_current_user] = _override + yield + app.dependency_overrides.pop(get_current_user, None) + + +@pytest.fixture +def auth_token(mock_user): + return 
create_access_token(data={"sub": str(mock_user.id)}) + + +@pytest.fixture +def auth_headers(auth_token): + return {"Authorization": f"Bearer {auth_token}"} + + +@pytest_asyncio.fixture +async def plain_client(): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client diff --git a/backend/tests/fixtures/__init__.py b/backend/tests/fixtures/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/fixtures/auth.py b/backend/tests/fixtures/auth.py new file mode 100644 index 0000000..e83fe4f --- /dev/null +++ b/backend/tests/fixtures/auth.py @@ -0,0 +1,30 @@ +import uuid + +from app.models.user import User +from app.services.auth import hash_password + + +def _make_user( + user_id: str | None = None, + email: str = "test@example.com", + plan: str = "free", +) -> User: + uid = user_id or str(uuid.uuid4()) + user = User( + id=uid, + email=email, + password=hash_password("Test@123456"), + firstName="Test", + lastName="User", + isActive=True, + emailVerified=True, + ) + user.plan = plan + user.max_queries = 50 if plan != "free" else 5 + return user + + +def _to_uuid(value: str | uuid.UUID) -> uuid.UUID: + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(str(value)) diff --git a/backend/tests/fixtures/brands.py b/backend/tests/fixtures/brands.py new file mode 100644 index 0000000..f3a2a62 --- /dev/null +++ b/backend/tests/fixtures/brands.py @@ -0,0 +1,63 @@ +import uuid + +import pytest_asyncio + +from app.models.brand import Brand + +from .auth import _make_user, _to_uuid + + +@pytest_asyncio.fixture +async def test_user(async_session): + user = _make_user(plan="free") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def paid_user(async_session): + user = _make_user(email="paid@example.com", plan="pro") + async_session.add(user) + await async_session.commit() + 
await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def test_brand(async_session, test_user): + brand = Brand( + id=uuid.uuid4(), + user_id=_to_uuid(test_user.id), + name="TestBrand", + aliases=["TB"], + website="https://testbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def paid_brand(async_session, paid_user): + brand = Brand( + id=uuid.uuid4(), + user_id=_to_uuid(paid_user.id), + name="PaidBrand", + aliases=["PB"], + website="https://paidbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand diff --git a/backend/tests/fixtures/client.py b/backend/tests/fixtures/client.py new file mode 100644 index 0000000..dfaf1f5 --- /dev/null +++ b/backend/tests/fixtures/client.py @@ -0,0 +1,41 @@ +import pytest_asyncio +from httpx import ASGITransport, AsyncClient + +from app.api.deps import get_current_user, get_db +from app.main import app + + +@pytest_asyncio.fixture +async def async_client(async_session, test_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return test_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +@pytest_asyncio.fixture +async def paid_client(async_session, paid_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return paid_user + + app.dependency_overrides[get_db] = 
override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() diff --git a/backend/tests/fixtures/database.py b/backend/tests/fixtures/database.py new file mode 100644 index 0000000..6ad7c00 --- /dev/null +++ b/backend/tests/fixtures/database.py @@ -0,0 +1,31 @@ +import pytest_asyncio +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.database import Base + + +@pytest_asyncio.fixture +async def async_engine(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with maker() as session: + yield session diff --git a/backend/tests/test_agent_framework/test_agents_integration.py b/backend/tests/test_agent_framework/test_agents_integration.py index e5da82f..9f4394e 100644 --- a/backend/tests/test_agent_framework/test_agents_integration.py +++ b/backend/tests/test_agent_framework/test_agents_integration.py @@ -183,17 +183,17 @@ class TestContentGeneratorAgent: assert result.status in [TaskStatus.COMPLETED, TaskStatus.FAILED] def test_extract_json_method(self): - """测试JSON提取方法""" - agent = ContentGeneratorAgent() + """测试JSON提取方法(已提取到 app.utils.json_extractor)""" + from app.utils.json_extractor import extract_json # 测试普通JSON json_text = '{"title": "测试标题", "reason": "测试原因"}' - extracted = agent._extract_json(json_text) + extracted = extract_json(json_text) assert "title" 
in extracted # 测试被markdown包裹的JSON md_text = '```json\n{"title": "测试"}\n```' - extracted = agent._extract_json(md_text) + extracted = extract_json(md_text) assert "title" in extracted diff --git a/backend/tests/test_agent_framework/test_scheduler_agent_integration.py b/backend/tests/test_agent_framework/test_scheduler_agent_integration.py new file mode 100644 index 0000000..4c3ac00 --- /dev/null +++ b/backend/tests/test_agent_framework/test_scheduler_agent_integration.py @@ -0,0 +1,224 @@ +"""Tests that the scheduler uses the Agent framework for citation detection. + +TDD RED phase: these tests define the desired behavior where the scheduler +dispatches citation tasks through the Agent framework instead of directly +using CitationEngine. +""" +import inspect +import pytest +import uuid +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock + +from app.agent_framework.protocol import ( + TaskMessage, + TaskResult, + TaskStatus, +) + + +class TestSchedulerUsesAgentFramework: + """Scheduler should dispatch citation tasks through the Agent framework.""" + + @pytest.mark.asyncio + async def test_scheduler_dispatches_via_agent_for_full_query(self): + """check_and_execute_queries should use CitationDetectorAgent + instead of directly calling CitationEngine.""" + # Setup mock agent + mock_agent_instance = AsyncMock() + mock_agent_instance.execute_query_compat = AsyncMock(return_value=[]) + + with patch("app.workers.scheduler.AsyncSessionLocal") as mock_session_local, \ + patch("app.workers.scheduler._get_agent", return_value=mock_agent_instance), \ + patch("app.workers.scheduler.CitationEngine") as MockEngine: + + # Setup mock db session + mock_db = AsyncMock() + mock_session_local.return_value.__aenter__ = AsyncMock(return_value=mock_db) + mock_session_local.return_value.__aexit__ = AsyncMock(return_value=False) + + # Setup mock query + mock_query = MagicMock() + mock_query.id = uuid.uuid4() + mock_query.keyword = 
"test keyword" + mock_query.status = "active" + mock_query.next_query_at = datetime.now(timezone.utc) + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_query] + mock_db.execute = AsyncMock(return_value=mock_result) + + from app.workers.scheduler import QueryScheduler + scheduler = QueryScheduler() + await scheduler.check_and_execute_queries() + + # Agent's compat method should have been called + mock_agent_instance.execute_query_compat.assert_called_once() + # Direct engine should NOT have been used + MockEngine.assert_not_called() + + @pytest.mark.asyncio + async def test_scheduler_dispatches_via_agent_for_pending_tasks(self): + """check_and_execute_pending_tasks should use CitationDetectorAgent + instead of directly calling CitationEngine.""" + # Setup mock agent + mock_agent_instance = AsyncMock() + mock_agent_instance.execute_single_platform_compat = AsyncMock( + return_value={"cited": False} + ) + + with patch("app.workers.scheduler.AsyncSessionLocal") as mock_session_local, \ + patch("app.workers.scheduler._get_agent", return_value=mock_agent_instance), \ + patch("app.workers.scheduler.CitationEngine") as MockEngine: + + # Setup mock db session + mock_db = AsyncMock() + mock_session_local.return_value.__aenter__ = AsyncMock(return_value=mock_db) + mock_session_local.return_value.__aexit__ = AsyncMock(return_value=False) + + # Setup mock query task + query_id = uuid.uuid4() + mock_task = MagicMock() + mock_task.id = uuid.uuid4() + mock_task.query_id = query_id + mock_task.platform = "kimi" + mock_task.status = "pending" + mock_task.scheduled_at = datetime.now(timezone.utc) + + mock_task_result = MagicMock() + mock_task_result.scalars.return_value.all.return_value = [mock_task] + + # Setup mock query + mock_query = MagicMock() + mock_query.id = query_id + mock_query.status = "active" + mock_query.keyword = "test" + mock_query.target_brand = "Brand" + mock_query.brand_aliases = [] + + mock_query_result = MagicMock() + 
mock_query_result.scalar_one_or_none.return_value = mock_query + + # First call returns tasks, second returns query + mock_db.execute = AsyncMock( + side_effect=[mock_task_result, mock_query_result] + ) + mock_db.commit = AsyncMock() + + from app.workers.scheduler import QueryScheduler + scheduler = QueryScheduler() + await scheduler.check_and_execute_pending_tasks() + + # Agent's compat method should have been called + mock_agent_instance.execute_single_platform_compat.assert_called_once() + # Direct engine should NOT have been used + MockEngine.assert_not_called() + + @pytest.mark.asyncio + async def test_scheduler_falls_back_to_engine_on_agent_failure(self): + """If Agent framework fails, scheduler should fall back to direct engine.""" + # Agent raises exception + mock_agent_instance = AsyncMock() + mock_agent_instance.execute_query_compat = AsyncMock( + side_effect=Exception("Agent framework unavailable") + ) + + with patch("app.workers.scheduler.AsyncSessionLocal") as mock_session_local, \ + patch("app.workers.scheduler._get_agent", return_value=mock_agent_instance), \ + patch("app.workers.scheduler.CitationEngine") as MockEngine: + + # Fallback engine works + mock_engine_instance = AsyncMock() + mock_engine_instance.execute_query = AsyncMock(return_value=[]) + MockEngine.return_value = mock_engine_instance + + # Setup mock db session + mock_db = AsyncMock() + mock_session_local.return_value.__aenter__ = AsyncMock(return_value=mock_db) + mock_session_local.return_value.__aexit__ = AsyncMock(return_value=False) + + # Setup mock query + mock_query = MagicMock() + mock_query.id = uuid.uuid4() + mock_query.keyword = "test keyword" + mock_query.status = "active" + mock_query.next_query_at = datetime.now(timezone.utc) + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_query] + mock_db.execute = AsyncMock(return_value=mock_result) + + from app.workers.scheduler import QueryScheduler + scheduler = QueryScheduler() + await 
scheduler.check_and_execute_queries() + + # Agent was tried first + mock_agent_instance.execute_query_compat.assert_called_once() + # Fallback engine was used + MockEngine.assert_called_once() + mock_engine_instance.execute_query.assert_called_once() + + @pytest.mark.asyncio + async def test_citation_detector_agent_produces_same_interface_as_engine(self): + """CitationDetectorAgent compat methods should have the same + call signature and return type as CitationEngine methods.""" + from app.agent_framework.agents.citation_detector import CitationDetectorAgent + from app.workers.citation_engine import CitationEngine + + engine_query_sig = inspect.signature(CitationEngine.execute_query) + agent_compat_sig = inspect.signature(CitationDetectorAgent.execute_query_compat) + + # Same parameter names + engine_params = list(engine_query_sig.parameters.keys()) + agent_params = list(agent_compat_sig.parameters.keys()) + assert engine_params == agent_params, ( + f"Parameter mismatch: engine={engine_params}, agent={agent_params}" + ) + + engine_single_sig = inspect.signature(CitationEngine.execute_single_platform) + agent_single_sig = inspect.signature( + CitationDetectorAgent.execute_single_platform_compat + ) + + engine_single_params = list(engine_single_sig.parameters.keys()) + agent_single_params = list(agent_single_sig.parameters.keys()) + assert engine_single_params == agent_single_params, ( + f"Parameter mismatch: engine={engine_single_params}, " + f"agent={agent_single_params}" + ) + + @pytest.mark.asyncio + async def test_scheduler_does_not_create_engine_in_init(self): + """QueryScheduler should NOT create a CitationEngine instance + in __init__. 
It should use the Agent framework instead.""" + with patch("app.workers.scheduler.CitationEngine") as MockEngine: + from app.workers.scheduler import QueryScheduler + scheduler = QueryScheduler() + + # Engine should NOT have been instantiated in __init__ + MockEngine.assert_not_called() + + +class TestBaseAgentDispatcherInjection: + """BaseAgent should not create TaskDispatcher on every method call.""" + + def test_base_agent_does_not_create_dispatcher_in_report_progress(self): + """report_progress should use an injected dispatcher, not create a new one.""" + from app.agent_framework.base import BaseAgent + + # Check that report_progress doesn't import TaskDispatcher inline + source = inspect.getsource(BaseAgent.report_progress) + assert "TaskDispatcher(" not in source, ( + "report_progress should not instantiate TaskDispatcher directly. " + "Use an injected or cached dispatcher instead." + ) + + def test_base_agent_does_not_create_dispatcher_in_execute_task(self): + """_execute_task should use an injected dispatcher, not create a new one.""" + from app.agent_framework.base import BaseAgent + + source = inspect.getsource(BaseAgent._execute_task) + assert "TaskDispatcher(" not in source, ( + "_execute_task should not instantiate TaskDispatcher directly. " + "Use an injected or cached dispatcher instead." 
+ ) diff --git a/backend/tests/test_api/test_attribution_contract.py b/backend/tests/test_api/test_attribution_contract.py new file mode 100644 index 0000000..ee6eb90 --- /dev/null +++ b/backend/tests/test_api/test_attribution_contract.py @@ -0,0 +1,372 @@ +import uuid +from datetime import UTC, datetime, timedelta + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.api.deps import get_current_user, get_db +from app.database import Base +from app.main import app +from app.models.attribution_record import AttributionRecord +from app.models.brand import Brand +from app.models.diagnosis_record import DiagnosisRecord +from app.models.user import User +from app.services.auth import hash_password + + +def _make_user( + user_id: str | None = None, + email: str = "test@example.com", + plan: str = "free", +) -> User: + uid = user_id or str(uuid.uuid4()) + user = User( + id=uid, + email=email, + password=hash_password("Test@123456"), + firstName="Test", + lastName="User", + isActive=True, + emailVerified=True, + ) + user.plan = plan + user.max_queries = 50 if plan != "free" else 5 + return user + + +def _to_uuid(value: str | uuid.UUID) -> uuid.UUID: + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(str(value)) + + +@pytest_asyncio.fixture +async def async_engine(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with maker() as session: + yield session + + 
+@pytest_asyncio.fixture +async def test_user(async_session): + user = _make_user(plan="pro") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def test_brand(async_session, test_user): + brand = Brand( + id=uuid.uuid4(), + user_id=_to_uuid(test_user.id), + name="TestBrand", + aliases=["TB"], + website="https://testbrand.com", + industry="technology", + platforms=["wenxin"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def test_diagnosis(async_session, test_brand, test_user): + record = DiagnosisRecord( + brand_id=test_brand.id, + user_id=_to_uuid(test_user.id), + diagnosis_type="geo", + status="completed", + overall_score=45.0, + result_json={ + "overall_score": 45.0, + "health_level": "pass", + "dimensions": [ + {"name": "内容可提取性", "score": 50}, + {"name": "E-E-A-T信号", "score": 40}, + {"name": "引用就绪度", "score": 45}, + ], + }, + completed_at=datetime.now(UTC), + ) + async_session.add(record) + await async_session.commit() + await async_session.refresh(record) + return record + + +@pytest_asyncio.fixture +async def async_client(async_session, test_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return test_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +def _make_attribution_record( + user_id: str, + brand_id: uuid.UUID, + baseline_score: float = 45.0, + current_score: float | None = None, + score_delta: float | None = None, + status: str = "tracking", + published_at: datetime | None = None, + window_end_at: datetime | 
None = None, +) -> AttributionRecord: + now = datetime.now(UTC) + return AttributionRecord( + user_id=user_id, + brand_id=brand_id, + baseline_score=baseline_score, + current_score=current_score, + score_delta=score_delta, + status=status, + published_at=published_at or now, + window_end_at=window_end_at or (now + timedelta(days=28)), + ) + + +class TestStartTrackingContract: + @pytest.mark.asyncio + async def test_start_creates_attribution_record( + self, async_client, test_brand, test_diagnosis + ): + response = await async_client.post( + "/api/v1/attribution/start", + json={"brand_id": str(test_brand.id)}, + ) + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert data["brand_id"] == str(test_brand.id) + assert data["baseline_score"] == 45.0 + assert data["status"] == "tracking" + assert data["content_id"] is None + + @pytest.mark.asyncio + async def test_start_with_content_id( + self, async_client, test_brand, test_diagnosis + ): + content_id = str(uuid.uuid4()) + response = await async_client.post( + "/api/v1/attribution/start", + json={"brand_id": str(test_brand.id), "content_id": content_id}, + ) + assert response.status_code == 200 + data = response.json() + assert data["content_id"] == content_id + + @pytest.mark.asyncio + async def test_start_with_invalid_brand_id(self, async_client): + response = await async_client.post( + "/api/v1/attribution/start", + json={"brand_id": str(uuid.uuid4())}, + ) + assert response.status_code == 404 + + @pytest.mark.asyncio + async def test_start_without_diagnosis_uses_zero_baseline( + self, async_client, test_brand + ): + response = await async_client.post( + "/api/v1/attribution/start", + json={"brand_id": str(test_brand.id)}, + ) + assert response.status_code == 200 + data = response.json() + assert data["baseline_score"] == 0.0 + + +class TestGetBrandAttributionContract: + @pytest.mark.asyncio + async def test_get_brand_attribution_summary( + self, async_client, test_brand, 
test_user, test_diagnosis, async_session + ): + record = _make_attribution_record( + user_id=test_user.id, + brand_id=test_brand.id, + current_score=55.0, + score_delta=10.0, + ) + async_session.add(record) + await async_session.commit() + + response = await async_client.get( + f"/api/v1/attribution/brand/{test_brand.id}" + ) + assert response.status_code == 200 + data = response.json() + assert "records" in data + assert "total_score_delta" in data + assert "tracking_count" in data + assert len(data["records"]) >= 1 + + +class TestCheckAttributionContract: + @pytest.mark.asyncio + async def test_check_updates_attribution_record( + self, async_client, test_brand, test_user, test_diagnosis, async_session + ): + record = _make_attribution_record( + user_id=test_user.id, + brand_id=test_brand.id, + ) + async_session.add(record) + await async_session.commit() + await async_session.refresh(record) + + response = await async_client.post( + f"/api/v1/attribution/{record.id}/check" + ) + assert response.status_code == 200 + data = response.json() + assert data["current_score"] is not None + assert data["score_delta"] is not None + + @pytest.mark.asyncio + async def test_check_nonexistent_record(self, async_client): + response = await async_client.post( + f"/api/v1/attribution/{uuid.uuid4()}/check" + ) + assert response.status_code == 404 + + +class TestGetROIReportContract: + @pytest.mark.asyncio + async def test_get_roi_report( + self, async_client, test_brand, test_user, test_diagnosis, async_session + ): + record = _make_attribution_record( + user_id=test_user.id, + brand_id=test_brand.id, + current_score=55.0, + score_delta=10.0, + ) + async_session.add(record) + await async_session.commit() + + response = await async_client.get( + f"/api/v1/attribution/roi/{test_brand.id}" + ) + assert response.status_code == 200 + data = response.json() + assert "roi_percentage" in data + assert "value_generated" in data + assert "subscription_cost" in data + assert 
"break_even_delta" in data + assert "brand_name" in data + assert "current_plan" in data + assert "tracking_records" in data + assert data["brand_name"] == "TestBrand" + + @pytest.mark.asyncio + async def test_get_roi_invalid_brand(self, async_client): + response = await async_client.get( + f"/api/v1/attribution/roi/{uuid.uuid4()}" + ) + assert response.status_code == 404 + + +class TestGetABComparisonContract: + @pytest.mark.asyncio + async def test_get_ab_comparison( + self, async_client, test_brand, test_user, test_diagnosis, async_session + ): + later_diagnosis = DiagnosisRecord( + brand_id=test_brand.id, + user_id=_to_uuid(test_user.id), + diagnosis_type="geo", + status="completed", + overall_score=65.0, + result_json={ + "overall_score": 65.0, + "health_level": "good", + "dimensions": [ + {"name": "内容可提取性", "score": 70}, + {"name": "E-E-A-T信号", "score": 60}, + {"name": "引用就绪度", "score": 65}, + ], + }, + completed_at=datetime.now(UTC) + timedelta(hours=1), + ) + async_session.add(later_diagnosis) + await async_session.commit() + + response = await async_client.get( + f"/api/v1/attribution/ab-comparison/{test_brand.id}" + ) + assert response.status_code == 200 + data = response.json() + assert "overall_before" in data + assert "overall_after" in data + assert "overall_delta" in data + assert "dimensions" in data + assert data["brand_name"] == "TestBrand" + assert len(data["dimensions"]) > 0 + + @pytest.mark.asyncio + async def test_get_ab_comparison_no_data(self, async_client, test_brand): + response = await async_client.get( + f"/api/v1/attribution/ab-comparison/{test_brand.id}" + ) + assert response.status_code == 404 + + +class TestAttributionWindowExpiration: + @pytest.mark.asyncio + async def test_expired_window_marks_completed( + self, async_client, test_brand, test_user, test_diagnosis, async_session + ): + record = _make_attribution_record( + user_id=test_user.id, + brand_id=test_brand.id, + published_at=datetime.now(UTC) - timedelta(days=30), + 
window_end_at=datetime.now(UTC) - timedelta(days=2), + ) + async_session.add(record) + await async_session.commit() + await async_session.refresh(record) + + response = await async_client.post( + f"/api/v1/attribution/{record.id}/check" + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "completed" diff --git a/tests/test_auth.py b/backend/tests/test_api/test_auth.py similarity index 100% rename from tests/test_auth.py rename to backend/tests/test_api/test_auth.py diff --git a/tests/test_citations.py b/backend/tests/test_api/test_citations.py similarity index 100% rename from tests/test_citations.py rename to backend/tests/test_api/test_citations.py diff --git a/backend/tests/test_api/test_content_distribution_contract.py b/backend/tests/test_api/test_content_distribution_contract.py new file mode 100644 index 0000000..6997ae8 --- /dev/null +++ b/backend/tests/test_api/test_content_distribution_contract.py @@ -0,0 +1,308 @@ +import uuid +from unittest.mock import AsyncMock, patch + +import pytest +import pytest_asyncio +from httpx import AsyncClient, ASGITransport +from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.database import Base +from app.main import app +from app.models.brand import Brand +from app.models.content import Content +from app.models.diagnosis_record import DiagnosisRecord +from app.models.user import User +from app.api.deps import get_current_user, get_db +from app.services.auth import hash_password + + +def _to_uuid(value: str | uuid.UUID) -> uuid.UUID: + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(str(value)) + + +@pytest_asyncio.fixture +async def async_engine(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield 
engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + async_session_maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with async_session_maker() as session: + yield session + + +@pytest_asyncio.fixture +async def test_user(async_session): + user_id = str(uuid.uuid4()) + user = User( + id=user_id, + email="test_dist@example.com", + password=hash_password("Test@123456"), + firstName="Test", + lastName="User", + isActive=True, + emailVerified=True, + organization_id=uuid.uuid4(), + ) + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def test_brand(async_session, test_user): + brand = Brand( + id=uuid.uuid4(), + user_id=_to_uuid(test_user.id), + name="Test Brand", + aliases=["TestBrand"], + website="https://testbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def test_content(async_session, test_user): + content = Content( + id=uuid.uuid4(), + organization_id=test_user.organization_id, + title="测试文章", + content_type="article", + body="这是一篇测试文章的内容,用于发布测试。", + status="draft", + target_platforms=["zhihu", "toutiao"], + keywords=["测试"], + created_by=test_user.id, + current_version=1, + ) + async_session.add(content) + await async_session.commit() + await async_session.refresh(content) + return content + + +@pytest_asyncio.fixture +async def test_diagnosis(async_session, test_brand): + diagnosis = DiagnosisRecord( + id=uuid.uuid4(), + brand_id=test_brand.id, + user_id=_to_uuid(test_brand.user_id), + diagnosis_type="geo", + status="completed", + overall_score=55.0, + result_json={ + "dimensions": { + "visibility": {"score": 40, 
"details": "low"}, + "authority": {"score": 70, "details": "ok"}, + "relevance": {"score": 50, "details": "low"}, + } + }, + ) + async_session.add(diagnosis) + await async_session.commit() + await async_session.refresh(diagnosis) + return diagnosis + + +@pytest_asyncio.fixture +async def async_client(async_session, test_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return test_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +class TestGEOContentGeneration: + @pytest.mark.asyncio + async def test_generate_geo_returns_201(self, async_client, test_brand, test_diagnosis): + with patch( + "app.services.content.content_generation_service.ContentGenerationService.generate_content", + new_callable=AsyncMock, + ) as mock_gen: + mock_gen.return_value = { + "content": "生成的内容", + "optimized_content": "优化后的内容", + "seo_score": 85, + "content_id": str(uuid.uuid4()), + "pipeline_stages": [{"stage": "content_generation", "status": "success"}], + } + + response = await async_client.post( + "/api/v1/content/generate-geo", + json={ + "brand_id": str(test_brand.id), + "target_keywords": ["AI优化", "品牌曝光"], + "platform": "通用", + "content_style": "专业严谨", + "word_count": 2000, + }, + ) + + assert response.status_code == 201 + data = response.json() + assert data["content_id"] is not None + assert data["content"] == "生成的内容" + assert data["optimized_content"] == "优化后的内容" + assert data["seo_score"] == 85 + + @pytest.mark.asyncio + async def test_generate_geo_with_invalid_brand_returns_404(self, async_client): + fake_brand_id = str(uuid.uuid4()) + response = await async_client.post( + "/api/v1/content/generate-geo", + json={ + "brand_id": fake_brand_id, + "target_keywords": 
["测试"], + }, + ) + + assert response.status_code == 404 + + +class TestPublishAPI: + @pytest.mark.asyncio + async def test_publish_to_mock_platforms(self, async_client, test_content): + response = await async_client.post( + "/api/v1/distribution/publish", + json={ + "content_id": str(test_content.id), + "platforms": ["zhihu", "toutiao"], + }, + ) + + assert response.status_code == 200 + data = response.json() + assert "results" in data + assert len(data["results"]) == 2 + for r in data["results"]: + assert r["success"] is True + assert r["article_id"] is not None + + @pytest.mark.asyncio + async def test_publish_with_invalid_content_id_returns_404(self, async_client): + fake_content_id = str(uuid.uuid4()) + response = await async_client.post( + "/api/v1/distribution/publish", + json={ + "content_id": fake_content_id, + "platforms": ["zhihu"], + }, + ) + + assert response.status_code == 404 + + @pytest.mark.asyncio + async def test_get_publish_status(self, async_client, test_content): + await async_client.post( + "/api/v1/distribution/publish", + json={ + "content_id": str(test_content.id), + "platforms": ["zhihu"], + }, + ) + + response = await async_client.get( + f"/api/v1/distribution/publish/{test_content.id}/status", + ) + + assert response.status_code == 200 + data = response.json() + assert "platforms" in data + assert len(data["platforms"]) >= 1 + assert data["platforms"][0]["platform"] == "zhihu" + + +class TestPublishers: + @pytest.mark.asyncio + async def test_zhihu_publisher_mock_mode(self): + from app.services.distribution.publishers.zhihu_publisher import ZhihuPublisher + + pub = ZhihuPublisher() + assert pub.is_configured() is False + + result = await pub.publish(title="测试标题", content="测试内容") + assert result.success is True + assert result.platform == "zhihu" + assert result.article_id is not None + + @pytest.mark.asyncio + async def test_wechat_publisher_returns_formatted_content_in_mock_mode(self): + from 
app.services.distribution.publishers.wechat_publisher import WeChatPublisher + + pub = WeChatPublisher() + assert pub.is_configured() is False + + result = await pub.publish(title="微信测试", content="## 标题\n微信内容") + assert result.success is True + assert result.platform == "wechat" + assert "formatted_content" in result.raw_response + assert "instructions" in result.raw_response + + @pytest.mark.asyncio + async def test_publisher_factory_returns_mock_by_default(self): + from app.services.distribution.publishers import get_publisher + from app.services.distribution.publishers.mock_publisher import MockPublisher + + pub = get_publisher("zhihu") + assert isinstance(pub, MockPublisher) + assert pub.platform == "zhihu" + + @pytest.mark.asyncio + async def test_publisher_factory_returns_mock_for_unknown_platform(self): + from app.services.distribution.publishers import get_publisher + from app.services.distribution.publishers.mock_publisher import MockPublisher + + pub = get_publisher("unknown_platform") + assert isinstance(pub, MockPublisher) + + @pytest.mark.asyncio + async def test_mock_publisher_verify_credentials(self): + from app.services.distribution.publishers.mock_publisher import MockPublisher + + pub = MockPublisher(platform="test") + assert await pub.verify_credentials() is True + + @pytest.mark.asyncio + async def test_mock_publisher_get_article_status(self): + from app.services.distribution.publishers.mock_publisher import MockPublisher + + pub = MockPublisher(platform="test") + status = await pub.get_article_status("article_123") + assert status["status"] == "published" + assert status["mock"] is True diff --git a/backend/tests/test_api/test_diagnosis_contract.py b/backend/tests/test_api/test_diagnosis_contract.py new file mode 100644 index 0000000..bda9412 --- /dev/null +++ b/backend/tests/test_api/test_diagnosis_contract.py @@ -0,0 +1,399 @@ +import uuid +from unittest.mock import AsyncMock, patch + +import pytest +import pytest_asyncio +from httpx import 
ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.api.deps import get_current_user, get_db +from app.database import Base +from app.main import app +from app.models.brand import Brand +from app.models.diagnosis_record import DiagnosisRecord +from app.models.user import User +from app.services.auth import hash_password +from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput + + +def _make_user( + user_id: str | None = None, + email: str = "test@example.com", + plan: str = "free", +) -> User: + uid = user_id or str(uuid.uuid4()) + user = User( + id=uid, + email=email, + password=hash_password("Test@123456"), + firstName="Test", + lastName="User", + isActive=True, + emailVerified=True, + ) + user.plan = plan + user.max_queries = 50 if plan != "free" else 5 + return user + + +def _to_uuid(value: str | uuid.UUID) -> uuid.UUID: + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(str(value)) + + +@pytest_asyncio.fixture +async def async_engine(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with maker() as session: + yield session + + +@pytest_asyncio.fixture +async def test_user(async_session): + user = _make_user(plan="free") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def paid_user(async_session): + user = _make_user(email="paid@example.com", plan="pro") + async_session.add(user) + await 
async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def test_brand(async_session, test_user): + brand = Brand( + id=uuid.uuid4(), + user_id=_to_uuid(test_user.id), + name="TestBrand", + aliases=["TB", "Test Brand"], + website="https://testbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def paid_brand(async_session, paid_user): + brand = Brand( + id=uuid.uuid4(), + user_id=_to_uuid(paid_user.id), + name="PaidBrand", + aliases=["PB"], + website="https://paidbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def async_client(async_session, test_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return test_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +@pytest_asyncio.fixture +async def paid_client(async_session, paid_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return paid_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +class TestGEODiagnosisTriggerContract: + 
@pytest.mark.asyncio + async def test_trigger_returns_202(self, async_client, test_brand): + with patch("app.api.diagnosis._run_geo_diagnosis", new_callable=AsyncMock): + response = await async_client.post( + f"/api/v1/diagnosis/geo/{test_brand.id}" + ) + assert response.status_code == 202 + data = response.json() + assert "task_id" in data + assert "brand_id" in data + assert "status" in data + assert data["status"] in ("pending", "completed") + assert data["brand_id"] == str(test_brand.id) + + @pytest.mark.asyncio + async def test_trigger_brand_not_found(self, async_client): + response = await async_client.post( + f"/api/v1/diagnosis/geo/{uuid.uuid4()}" + ) + assert response.status_code == 404 + + @pytest.mark.asyncio + async def test_trigger_force_refresh(self, async_client, test_brand): + with patch("app.api.diagnosis._run_geo_diagnosis", new_callable=AsyncMock): + response = await async_client.post( + f"/api/v1/diagnosis/geo/{test_brand.id}", + json={"force_refresh": True}, + ) + assert response.status_code == 202 + + +class TestGEODiagnosisResultContract: + @pytest.mark.asyncio + async def test_result_pending(self, async_client, test_brand, async_session): + record = DiagnosisRecord( + brand_id=test_brand.id, + user_id=_to_uuid(test_brand.user_id), + diagnosis_type="geo", + status="pending", + ) + async_session.add(record) + await async_session.commit() + await async_session.refresh(record) + + response = await async_client.get( + f"/api/v1/diagnosis/geo/{test_brand.id}/result", + params={"task_id": str(record.id)}, + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "pending" + assert data["result"] is None + + @pytest.mark.asyncio + async def test_result_completed_nonzero_score( + self, async_client, test_brand, async_session + ): + from app.services.diagnosis.geo_diagnosis import GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + has_author_bio=True, + 
author_credentials_complete=0.8, + answer_ownership_rate=0.3, + ) + service = GEODiagnosisService() + result = service.diagnose(input_data) + + record = DiagnosisRecord( + brand_id=test_brand.id, + user_id=_to_uuid(test_brand.user_id), + diagnosis_type="geo", + status="completed", + overall_score=result.overall_score, + result_json=result.to_dict(), + ) + async_session.add(record) + await async_session.commit() + await async_session.refresh(record) + + response = await async_client.get( + f"/api/v1/diagnosis/geo/{test_brand.id}/result", + params={"task_id": str(record.id)}, + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "completed" + assert data["result"] is not None + assert data["result"]["overall_score"] > 0 + + @pytest.mark.asyncio + async def test_result_not_found(self, async_client, test_brand): + response = await async_client.get( + f"/api/v1/diagnosis/geo/{test_brand.id}/result", + params={"task_id": str(uuid.uuid4())}, + ) + assert response.status_code == 404 + + +class TestGEODiagnosisFreeVsPaidContract: + @pytest.mark.asyncio + async def test_free_user_gets_3_dimensions( + self, async_client, test_brand, async_session + ): + from app.services.diagnosis.geo_diagnosis import GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + has_author_bio=True, + author_credentials_complete=0.8, + has_organization=True, + content_depth_score=0.7, + answer_ownership_rate=0.3, + ) + service = GEODiagnosisService() + result = service.diagnose(input_data) + + record = DiagnosisRecord( + brand_id=test_brand.id, + user_id=_to_uuid(test_brand.user_id), + diagnosis_type="geo", + status="completed", + overall_score=result.overall_score, + result_json=result.to_dict(), + ) + async_session.add(record) + await async_session.commit() + await async_session.refresh(record) + + response = await async_client.get( + f"/api/v1/diagnosis/geo/{test_brand.id}/result", + 
params={"task_id": str(record.id)}, + ) + data = response.json() + assert data["result"]["is_full_report"] is False + dim_names = {d["name"] for d in data["result"]["dimensions"]} + assert len(dim_names) == 3 + assert "内容可提取性" in dim_names + assert "E-E-A-T信号" in dim_names + assert "引用就绪度" in dim_names + + @pytest.mark.asyncio + async def test_paid_user_gets_6_dimensions( + self, paid_client, paid_brand, async_session, paid_user + ): + from app.services.diagnosis.geo_diagnosis import GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + has_author_bio=True, + author_credentials_complete=0.8, + has_organization=True, + content_depth_score=0.7, + answer_ownership_rate=0.3, + ) + service = GEODiagnosisService() + result = service.diagnose(input_data) + + record = DiagnosisRecord( + brand_id=paid_brand.id, + user_id=_to_uuid(paid_user.id), + diagnosis_type="geo", + status="completed", + overall_score=result.overall_score, + result_json=result.to_dict(), + ) + async_session.add(record) + await async_session.commit() + await async_session.refresh(record) + + response = await paid_client.get( + f"/api/v1/diagnosis/geo/{paid_brand.id}/result", + params={"task_id": str(record.id)}, + ) + data = response.json() + assert data["result"]["is_full_report"] is True + assert len(data["result"]["dimensions"]) == 6 + + +class TestGEODiagnosisHistoryContract: + @pytest.mark.asyncio + async def test_history_returns_list( + self, async_client, test_brand, async_session + ): + record = DiagnosisRecord( + brand_id=test_brand.id, + user_id=_to_uuid(test_brand.user_id), + diagnosis_type="geo", + status="completed", + overall_score=45.0, + result_json={ + "overall_score": 45.0, + "health_level": "pass", + }, + ) + async_session.add(record) + await async_session.commit() + + response = await async_client.get( + f"/api/v1/diagnosis/geo/{test_brand.id}/history" + ) + assert response.status_code == 200 + data = response.json() + assert 
"brand_id" in data + assert "history" in data + assert isinstance(data["history"], list) + + +class TestGEODiagnosisDataCollectionContract: + @pytest.mark.asyncio + async def test_diagnosis_with_data_collection_produces_nonzero( + self, async_client, test_brand + ): + with patch("app.api.diagnosis._run_geo_diagnosis", new_callable=AsyncMock): + response = await async_client.post( + f"/api/v1/diagnosis/geo/{test_brand.id}", + json={"force_refresh": True}, + ) + assert response.status_code == 202 + + @pytest.mark.asyncio + async def test_combined_diagnosis_uses_data_collector( + self, async_client, test_brand + ): + response = await async_client.get( + f"/api/v1/diagnosis/combined/{test_brand.id}" + ) + assert response.status_code == 200 + data = response.json() + assert "geo_score" in data + assert "seo_score" in data + assert "combined_score" in data diff --git a/backend/tests/test_api/test_email_contract.py b/backend/tests/test_api/test_email_contract.py new file mode 100644 index 0000000..f96c72f --- /dev/null +++ b/backend/tests/test_api/test_email_contract.py @@ -0,0 +1,326 @@ +from __future__ import annotations + +import uuid +from datetime import date, timedelta +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.database import Base +from app.models.subscription import Subscription +from app.models.user import User +from app.services.email.email_scheduler import EmailScheduler +from app.services.email_service import EmailService, EmailMessage + + +TEMPLATES_DIR = Path(__file__).resolve().parent.parent.parent / "app" / "templates" + + +def _make_user( + user_id: str | None = None, + email: str = "test@example.com", + plan: str = "free", +) -> User: + uid = user_id or str(uuid.uuid4()) + user = User( + id=uid, + 
email=email, + password="hashed_password", + firstName="Test", + lastName="User", + isActive=True, + emailVerified=True, + ) + user.plan = plan + user.max_queries = 50 if plan != "free" else 5 + return user + + +@pytest_asyncio.fixture +async def async_engine(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with maker() as session: + yield session + + +@pytest.fixture +def mock_email_service(): + return EmailService(simulate_mode=True) + + +@pytest.fixture +def email_scheduler(mock_email_service): + return EmailScheduler(email_service=mock_email_service) + + +class TestEmailSchedulerInstantiation: + def test_email_scheduler_can_be_instantiated(self, mock_email_service): + scheduler = EmailScheduler(email_service=mock_email_service) + assert scheduler is not None + assert scheduler.email_service is mock_email_service + + def test_email_scheduler_default_creates_mock_service(self): + scheduler = EmailScheduler() + assert scheduler is not None + assert scheduler.email_service is not None + + +class TestWeeklyReportEmail: + @pytest.mark.asyncio + async def test_weekly_report_sends_to_active_users(self, async_session, email_scheduler): + user1 = _make_user(email="user1@example.com") + user2 = _make_user(email="user2@example.com") + async_session.add(user1) + async_session.add(user2) + await async_session.commit() + + sent_count = await email_scheduler.send_geo_weekly_report(async_session) + assert sent_count == 2 + + @pytest.mark.asyncio + async def test_weekly_report_skips_inactive_users(self, async_session, email_scheduler): + active_user = 
_make_user(email="active@example.com") + inactive_user = _make_user(email="inactive@example.com") + inactive_user.isActive = False + async_session.add(active_user) + async_session.add(inactive_user) + await async_session.commit() + + sent_count = await email_scheduler.send_geo_weekly_report(async_session) + assert sent_count == 1 + + +class TestRenewalReminder: + @pytest.mark.asyncio + async def test_renewal_reminder_sends_for_expiring_subscriptions(self, async_session, email_scheduler): + user = _make_user(email="renew@example.com", plan="pro") + async_session.add(user) + await async_session.commit() + + sub = Subscription( + user_id=user.id, + plan="pro", + status="active", + start_date=date.today() - timedelta(days=23), + end_date=date.today() + timedelta(days=7), + amount=599.0, + payment_method="mock", + ) + async_session.add(sub) + await async_session.commit() + + sent_count = await email_scheduler.send_renewal_reminder(async_session) + assert sent_count >= 1 + + @pytest.mark.asyncio + async def test_renewal_reminder_no_reminder_for_non_expiring(self, async_session, email_scheduler): + user = _make_user(email="notexpiring@example.com", plan="pro") + async_session.add(user) + await async_session.commit() + + sub = Subscription( + user_id=user.id, + plan="pro", + status="active", + start_date=date.today(), + end_date=date.today() + timedelta(days=30), + amount=599.0, + payment_method="mock", + ) + async_session.add(sub) + await async_session.commit() + + sent_count = await email_scheduler.send_renewal_reminder(async_session) + assert sent_count == 0 + + +class TestWelcomeEmail: + @pytest.mark.asyncio + async def test_welcome_email_sends_successfully(self, email_scheduler): + result = await email_scheduler.send_welcome_email( + "newuser@example.com", "新用户" + ) + assert result is True + + @pytest.mark.asyncio + async def test_welcome_email_with_empty_name(self, email_scheduler): + result = await email_scheduler.send_welcome_email( + "newuser@example.com", "" + ) 
+ assert result is True + + +class TestMockMode: + def test_mock_mode_logs_instead_of_sending(self, mock_email_service): + msg = EmailMessage( + to="test@example.com", + subject="Test Subject", + body_html="

Test

", + body_text="Test", + ) + result = mock_email_service.send_email(msg) + assert result.success is True + assert result.message_id is not None + assert result.message_id.startswith("sim_") + assert result.error is None + + def test_mock_mode_simulate_flag(self, mock_email_service): + assert mock_email_service.simulate_mode is True + + def test_production_mode_would_use_smtp(self): + service = EmailService( + simulate_mode=False, + smtp_host="localhost", + smtp_port=587, + smtp_user="test", + smtp_password="test", + ) + assert service.simulate_mode is False + + +class TestEmailTemplates: + def test_geo_weekly_report_template_exists(self): + template_path = TEMPLATES_DIR / "geo_weekly_report.html" + assert template_path.exists() + + def test_renewal_reminder_template_exists(self): + template_path = TEMPLATES_DIR / "renewal_reminder.html" + assert template_path.exists() + + def test_trial_expiring_template_exists(self): + template_path = TEMPLATES_DIR / "trial_expiring.html" + assert template_path.exists() + + def test_welcome_template_exists(self): + template_path = TEMPLATES_DIR / "welcome.html" + assert template_path.exists() + + def test_geo_weekly_report_renders_with_context(self, email_scheduler): + template_html = email_scheduler._load_template("geo_weekly_report.html") + context = { + "user_name": "测试用户", + "score_change": "+5", + "current_score": "78", + "previous_score": "73", + "top_improved": "内容质量 (+12%)", + "top_declined": "品牌权威 (-3%)", + "suggestions": "建议增加技术白皮书", + "report_link": "https://example.com", + "year": "2026", + } + rendered = email_scheduler._render_template(template_html, context) + assert "测试用户" in rendered + assert "+5" in rendered + assert "78" in rendered + assert "内容质量 (+12%)" in rendered + + def test_renewal_reminder_renders_with_context(self, email_scheduler): + template_html = email_scheduler._load_template("renewal_reminder.html") + context = { + "user_name": "续费用户", + "plan_name": "专业版", + "end_date": "2026-06-30", + 
"days_remaining": "7", + "plan_price": "599", + "renew_link": "https://example.com", + "year": "2026", + } + rendered = email_scheduler._render_template(template_html, context) + assert "续费用户" in rendered + assert "专业版" in rendered + assert "7" in rendered + assert "599" in rendered + + def test_trial_expiring_renders_with_context(self, email_scheduler): + template_html = email_scheduler._load_template("trial_expiring.html") + context = { + "user_name": "试用用户", + "days_remaining": "3", + "upgrade_link": "https://example.com", + "year": "2026", + } + rendered = email_scheduler._render_template(template_html, context) + assert "试用用户" in rendered + assert "3" in rendered + + def test_welcome_renders_with_context(self, email_scheduler): + template_html = email_scheduler._load_template("welcome.html") + context = { + "user_name": "新用户", + "dashboard_link": "https://example.com", + "diagnosis_link": "https://example.com/diagnosis", + "help_link": "https://example.com/help", + "year": "2026", + } + rendered = email_scheduler._render_template(template_html, context) + assert "新用户" in rendered + assert "3" in rendered + + +class TestSendTemplateEmail: + def test_send_template_email_with_welcome(self, mock_email_service): + result = mock_email_service.send_template_email( + to="test@example.com", + subject="欢迎", + template_name="welcome.html", + context={ + "user_name": "测试", + "dashboard_link": "https://example.com", + "diagnosis_link": "https://example.com/d", + "help_link": "https://example.com/h", + "year": "2026", + }, + ) + assert result.success is True + + def test_send_template_email_invalid_template(self, mock_email_service): + with pytest.raises(ValueError, match="模板文件不存在"): + mock_email_service.send_template_email( + to="test@example.com", + subject="Test", + template_name="nonexistent.html", + context={}, + ) + + def test_send_template_email_renders_context(self, mock_email_service): + result = mock_email_service.send_template_email( + to="test@example.com", + 
subject="周报", + template_name="geo_weekly_report.html", + context={ + "user_name": "渲染测试", + "score_change": "+10", + "current_score": "85", + "previous_score": "75", + "top_improved": "测试提升", + "top_declined": "测试下降", + "suggestions": "测试建议", + "report_link": "https://example.com", + "year": "2026", + }, + ) + assert result.success is True diff --git a/backend/tests/test_api/test_health_score_contract.py b/backend/tests/test_api/test_health_score_contract.py new file mode 100644 index 0000000..5d91974 --- /dev/null +++ b/backend/tests/test_api/test_health_score_contract.py @@ -0,0 +1,329 @@ +import hashlib +import uuid +from unittest.mock import AsyncMock, patch, MagicMock + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.api.deps import get_current_user, get_db +from app.database import Base +from app.main import app +from app.models.brand import Brand +from app.models.user import User +from app.services.auth import hash_password +from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisResult + + +def _make_user( + user_id: str | None = None, + email: str = "test@example.com", + plan: str = "free", +) -> User: + uid = user_id or str(uuid.uuid4()) + user = User( + id=uid, + email=email, + password=hash_password("Test@123456"), + firstName="Test", + lastName="User", + isActive=True, + emailVerified=True, + ) + user.plan = plan + user.max_queries = 50 if plan != "free" else 5 + return user + + +def _to_uuid(value: str | uuid.UUID) -> uuid.UUID: + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(str(value)) + + +@pytest_asyncio.fixture +async def async_engine(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await 
conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with maker() as session: + yield session + + +@pytest_asyncio.fixture +async def test_user(async_session): + user = _make_user(plan="free") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def test_brand(async_session, test_user): + brand = Brand( + id=uuid.uuid4(), + user_id=_to_uuid(test_user.id), + name="TestBrand", + aliases=["TB", "Test Brand"], + website="https://testbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def public_client(async_session): + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +class TestHealthScorePublicContract: + @pytest.mark.asyncio + async def test_returns_200_with_brand(self, public_client, test_brand): + mock_cache = AsyncMock() + mock_cache.get_json = AsyncMock(return_value=None) + mock_cache.set_json = AsyncMock() + + mock_result = MagicMock() + mock_result.to_dict.return_value = { + "overall_score": 45.0, + "health_level": "pass", + "health_level_label": "及格", + "dimensions": [ + {"name": "内容可提取性", "score": 12.0, "max_score": 20.0, "percentage": 60.0, "status": "pass"}, + {"name": "E-E-A-T信号", "score": 10.0, "max_score": 20.0, "percentage": 50.0, "status": "pass"}, + {"name": "引用就绪度", "score": 8.0, "max_score": 
15.0, "percentage": 53.3, "status": "pass"}, + {"name": "实体清晰度", "score": 8.0, "max_score": 15.0, "percentage": 53.3, "status": "pass"}, + {"name": "Schema标记", "score": 5.0, "max_score": 15.0, "percentage": 33.3, "status": "fail"}, + {"name": "主题权威", "score": 6.0, "max_score": 15.0, "percentage": 40.0, "status": "pass"}, + ], + "recommendations": [ + {"priority": "P0", "dimension": "内容可提取性", "title": "添加结构化数据", "description": "建议添加Schema标记"}, + ], + } + + with patch("app.api.health_score.get_cache_service", return_value=mock_cache), \ + patch("app.api.health_score.DataCollectorService") as MockCollector, \ + patch("app.api.health_score.GEODiagnosisService") as MockDiagService: + mock_collector = MockCollector.return_value + mock_collection = MagicMock() + mock_collection.diagnosis_input = GEODiagnosisInput() + mock_collector.collect = AsyncMock(return_value=mock_collection) + + mock_diag = MockDiagService.return_value + mock_diag.diagnose = MagicMock(return_value=mock_result) + + response = await public_client.get( + "/api/v1/public/health-score", + params={"brand": "TestBrand"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["brand_name"] == "TestBrand" + assert data["overall_score"] > 0 + assert data["health_level"] in ("excellent", "good", "pass", "danger") + assert data["is_full_report"] is False + assert len(data["dimensions"]) == 3 + dim_names = {d["name"] for d in data["dimensions"]} + assert "内容可提取性" in dim_names + assert "E-E-A-T信号" in dim_names + assert "引用就绪度" in dim_names + + @pytest.mark.asyncio + async def test_returns_cached_result(self, public_client): + cached_data = { + "brand_name": "CachedBrand", + "overall_score": 55.0, + "health_level": "pass", + "health_level_label": "及格", + "dimensions": [ + {"name": "内容可提取性", "score": 15.0, "max_score": 20.0, "percentage": 75.0, "status": "good"}, + {"name": "E-E-A-T信号", "score": 12.0, "max_score": 20.0, "percentage": 60.0, "status": "pass"}, + {"name": "引用就绪度", "score": 
10.0, "max_score": 15.0, "percentage": 66.7, "status": "pass"}, + ], + "recommendations": [], + "is_full_report": False, + "cached": False, + } + + mock_cache = AsyncMock() + mock_cache.get_json = AsyncMock(return_value=cached_data) + + with patch("app.api.health_score.get_cache_service", return_value=mock_cache): + response = await public_client.get( + "/api/v1/public/health-score", + params={"brand": "CachedBrand"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["cached"] is True + assert data["brand_name"] == "CachedBrand" + + @pytest.mark.asyncio + async def test_missing_brand_param_returns_422(self, public_client): + response = await public_client.get("/api/v1/public/health-score") + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_collection_failure_returns_default(self, public_client): + mock_cache = AsyncMock() + mock_cache.get_json = AsyncMock(return_value=None) + mock_cache.set_json = AsyncMock() + + with patch("app.api.health_score.get_cache_service", return_value=mock_cache), \ + patch("app.api.health_score.DataCollectorService") as MockCollector: + mock_collector = MockCollector.return_value + mock_collector.collect = AsyncMock(side_effect=Exception("采集失败")) + + response = await public_client.get( + "/api/v1/public/health-score", + params={"brand": "UnknownBrand"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["brand_name"] == "UnknownBrand" + assert data["overall_score"] == 0.0 + assert data["health_level"] == "danger" + + @pytest.mark.asyncio + async def test_competitors_param_accepted(self, public_client): + mock_cache = AsyncMock() + mock_cache.get_json = AsyncMock(return_value=None) + mock_cache.set_json = AsyncMock() + + mock_result = MagicMock() + mock_result.to_dict.return_value = { + "overall_score": 30.0, + "health_level": "pass", + "health_level_label": "及格", + "dimensions": [ + {"name": "内容可提取性", "score": 10.0, "max_score": 20.0, "percentage": 
50.0, "status": "pass"}, + {"name": "E-E-A-T信号", "score": 8.0, "max_score": 20.0, "percentage": 40.0, "status": "fail"}, + {"name": "引用就绪度", "score": 5.0, "max_score": 15.0, "percentage": 33.3, "status": "fail"}, + {"name": "实体清晰度", "score": 5.0, "max_score": 15.0, "percentage": 33.3, "status": "fail"}, + {"name": "Schema标记", "score": 3.0, "max_score": 15.0, "percentage": 20.0, "status": "fail"}, + {"name": "主题权威", "score": 4.0, "max_score": 15.0, "percentage": 26.7, "status": "fail"}, + ], + "recommendations": [], + } + + with patch("app.api.health_score.get_cache_service", return_value=mock_cache), \ + patch("app.api.health_score.DataCollectorService") as MockCollector, \ + patch("app.api.health_score.GEODiagnosisService") as MockDiagService: + mock_collector = MockCollector.return_value + mock_collection = MagicMock() + mock_collection.diagnosis_input = GEODiagnosisInput() + mock_collector.collect = AsyncMock(return_value=mock_collection) + + mock_diag = MockDiagService.return_value + mock_diag.diagnose = MagicMock(return_value=mock_result) + + response = await public_client.get( + "/api/v1/public/health-score", + params={"brand": "SomeBrand", "competitors": "CompA,CompB,CompC"}, + ) + + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_only_p0_recommendations_returned(self, public_client): + mock_cache = AsyncMock() + mock_cache.get_json = AsyncMock(return_value=None) + mock_cache.set_json = AsyncMock() + + mock_result = MagicMock() + mock_result.to_dict.return_value = { + "overall_score": 30.0, + "health_level": "pass", + "health_level_label": "及格", + "dimensions": [ + {"name": "内容可提取性", "score": 10.0, "max_score": 20.0, "percentage": 50.0, "status": "pass"}, + {"name": "E-E-A-T信号", "score": 8.0, "max_score": 20.0, "percentage": 40.0, "status": "fail"}, + {"name": "引用就绪度", "score": 5.0, "max_score": 15.0, "percentage": 33.3, "status": "fail"}, + ], + "recommendations": [ + {"priority": "P0", "dimension": "内容可提取性", "title": "P0建议", 
"description": "紧急"}, + {"priority": "P1", "dimension": "E-E-A-T信号", "title": "P1建议", "description": "重要"}, + {"priority": "P2", "dimension": "引用就绪度", "title": "P2建议", "description": "一般"}, + ], + } + + with patch("app.api.health_score.get_cache_service", return_value=mock_cache), \ + patch("app.api.health_score.DataCollectorService") as MockCollector, \ + patch("app.api.health_score.GEODiagnosisService") as MockDiagService: + mock_collector = MockCollector.return_value + mock_collection = MagicMock() + mock_collection.diagnosis_input = GEODiagnosisInput() + mock_collector.collect = AsyncMock(return_value=mock_collection) + + mock_diag = MockDiagService.return_value + mock_diag.diagnose = MagicMock(return_value=mock_result) + + response = await public_client.get( + "/api/v1/public/health-score", + params={"brand": "TestBrand"}, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data["recommendations"]) == 1 + assert data["recommendations"][0]["priority"] == "P0" + + @pytest.mark.asyncio + async def test_no_auth_required(self, public_client): + mock_cache = AsyncMock() + mock_cache.get_json = AsyncMock(return_value={ + "brand_name": "Test", + "overall_score": 50.0, + "health_level": "pass", + "health_level_label": "及格", + "dimensions": [], + "recommendations": [], + "is_full_report": False, + "cached": False, + }) + + with patch("app.api.health_score.get_cache_service", return_value=mock_cache): + response = await public_client.get( + "/api/v1/public/health-score", + params={"brand": "Test"}, + ) + + assert response.status_code == 200 diff --git a/backend/tests/test_api/test_onboarding_contract.py b/backend/tests/test_api/test_onboarding_contract.py new file mode 100644 index 0000000..f9fd9d5 --- /dev/null +++ b/backend/tests/test_api/test_onboarding_contract.py @@ -0,0 +1,453 @@ +import uuid +from unittest.mock import AsyncMock, patch + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient +from 
sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.api.deps import get_current_user, get_db +from app.database import Base +from app.main import app +from app.models.brand import Brand +from app.models.user import User +from app.services.auth import hash_password + + +def _make_user( + user_id: str | None = None, + email: str = "test@example.com", + plan: str = "free", +) -> User: + uid = user_id or str(uuid.uuid4()) + user = User( + id=uid, + email=email, + password=hash_password("Test@123456"), + firstName="Test", + lastName="User", + isActive=True, + emailVerified=True, + ) + user.plan = plan + user.max_queries = 50 if plan != "free" else 5 + return user + + +def _to_uuid(value: str | uuid.UUID) -> uuid.UUID: + if isinstance(value, uuid.UUID): + return value + return uuid.UUID(str(value)) + + +@pytest_asyncio.fixture +async def async_engine(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with maker() as session: + yield session + + +@pytest_asyncio.fixture +async def test_user(async_session): + user = _make_user(plan="free") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def paid_user(async_session): + user = _make_user(email="paid@example.com", plan="pro") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def test_brand(async_session, test_user): + brand = Brand( + 
id=uuid.uuid4(), + user_id=_to_uuid(test_user.id), + name="TestBrand", + aliases=["TB"], + website="https://testbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def paid_brand(async_session, paid_user): + brand = Brand( + id=uuid.uuid4(), + user_id=_to_uuid(paid_user.id), + name="PaidBrand", + aliases=["PB"], + website="https://paidbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def async_client(async_session, test_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return test_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +@pytest_asyncio.fixture +async def paid_client(async_session, paid_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return paid_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +class TestOnboardingStatusContract: + @pytest.mark.asyncio + async def test_status_incomplete_no_brand(self, async_client): + response = await async_client.get("/api/v1/onboarding/status") + assert response.status_code == 200 + data = 
response.json() + assert data["completed"] is False + assert data["brand_id"] is None + assert data["current_step"] == 1 + + @pytest.mark.asyncio + async def test_status_complete_with_brand(self, async_client, test_brand): + response = await async_client.get("/api/v1/onboarding/status") + assert response.status_code == 200 + data = response.json() + assert data["completed"] is True + assert data["brand_id"] == str(test_brand.id) + + +class TestOnboardingHealthReportContract: + @pytest.mark.asyncio + async def test_health_report_returns_diagnosis_data(self, async_client, test_brand): + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + + from app.services.diagnosis.data_collector import DataCollectionResult + from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + answer_ownership_rate=0.3, + ) + mock_collector.collect.return_value = DataCollectionResult( + diagnosis_input=input_data, + ) + + service = GEODiagnosisService() + result = service.diagnose(input_data) + + with patch("app.api.onboarding.GEODiagnosisService", return_value=service): + response = await async_client.get( + f"/api/v1/onboarding/health-report/{test_brand.id}" + ) + + assert response.status_code == 200 + data = response.json() + assert data["brand_id"] == str(test_brand.id) + assert data["brand_name"] == "TestBrand" + assert "overall_score" in data + assert "health_level" in data + assert "health_level_label" in data + assert "dimensions" in data + assert "recommendations" in data + assert "is_full_report" in data + assert isinstance(data["dimensions"], list) + assert isinstance(data["recommendations"], list) + + @pytest.mark.asyncio + async def test_health_report_brand_not_found(self, async_client): + response = await async_client.get( + 
f"/api/v1/onboarding/health-report/{uuid.uuid4()}" + ) + assert response.status_code == 404 + + @pytest.mark.asyncio + async def test_health_report_free_user_3_dimensions( + self, async_client, test_brand + ): + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + + from app.services.diagnosis.data_collector import DataCollectionResult + from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + has_author_bio=True, + author_credentials_complete=0.8, + has_organization=True, + content_depth_score=0.7, + answer_ownership_rate=0.3, + ) + mock_collector.collect.return_value = DataCollectionResult( + diagnosis_input=input_data, + ) + + service = GEODiagnosisService() + result = service.diagnose(input_data) + + with patch("app.api.onboarding.GEODiagnosisService", return_value=service): + response = await async_client.get( + f"/api/v1/onboarding/health-report/{test_brand.id}" + ) + + data = response.json() + assert data["is_full_report"] is False + dim_names = {d["name"] for d in data["dimensions"]} + assert len(dim_names) == 3 + assert "内容可提取性" in dim_names + assert "E-E-A-T信号" in dim_names + assert "引用就绪度" in dim_names + + @pytest.mark.asyncio + async def test_health_report_paid_user_6_dimensions( + self, paid_client, paid_brand + ): + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + + from app.services.diagnosis.data_collector import DataCollectionResult + from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + has_author_bio=True, + author_credentials_complete=0.8, + has_organization=True, + 
content_depth_score=0.7, + answer_ownership_rate=0.3, + ) + mock_collector.collect.return_value = DataCollectionResult( + diagnosis_input=input_data, + ) + + service = GEODiagnosisService() + result = service.diagnose(input_data) + + with patch("app.api.onboarding.GEODiagnosisService", return_value=service): + response = await paid_client.get( + f"/api/v1/onboarding/health-report/{paid_brand.id}" + ) + + data = response.json() + assert data["is_full_report"] is True + assert len(data["dimensions"]) == 6 + + @pytest.mark.asyncio + async def test_health_report_collection_failure_fallback( + self, async_client, test_brand + ): + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + mock_collector.collect.side_effect = Exception("DB error") + + response = await async_client.get( + f"/api/v1/onboarding/health-report/{test_brand.id}" + ) + + assert response.status_code == 200 + data = response.json() + assert data["overall_score"] == 0 + assert data["health_level"] == "danger" + assert len(data["dimensions"]) == 3 + + +class TestOnboardingActionSuggestionsContract: + @pytest.mark.asyncio + async def test_suggestions_have_paid_action_field( + self, async_client, test_brand + ): + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + + from app.services.diagnosis.data_collector import DataCollectionResult + from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + answer_ownership_rate=0.3, + ) + mock_collector.collect.return_value = DataCollectionResult( + diagnosis_input=input_data, + ) + + service = GEODiagnosisService() + result = service.diagnose(input_data) + + with patch("app.api.onboarding.GEODiagnosisService", return_value=service): 
+ response = await async_client.get( + f"/api/v1/onboarding/action-suggestions/{test_brand.id}" + ) + + assert response.status_code == 200 + data = response.json() + assert "suggestions" in data + assert isinstance(data["suggestions"], list) + for s in data["suggestions"]: + assert "is_paid_action" in s + assert "action_button_text" in s + assert isinstance(s["is_paid_action"], bool) + + @pytest.mark.asyncio + async def test_free_user_gets_upgrade_suggestion( + self, async_client, test_brand + ): + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + + from app.services.diagnosis.data_collector import DataCollectionResult + from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + answer_ownership_rate=0.3, + ) + mock_collector.collect.return_value = DataCollectionResult( + diagnosis_input=input_data, + ) + + service = GEODiagnosisService() + result = service.diagnose(input_data) + + with patch("app.api.onboarding.GEODiagnosisService", return_value=service): + response = await async_client.get( + f"/api/v1/onboarding/action-suggestions/{test_brand.id}" + ) + + data = response.json() + upgrade_suggestions = [ + s for s in data["suggestions"] if s["action_type"] == "upgrade" + ] + assert len(upgrade_suggestions) >= 1 + assert any(s["is_paid_action"] for s in upgrade_suggestions) + + @pytest.mark.asyncio + async def test_zero_score_suggestions_include_upgrade( + self, async_client, test_brand + ): + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + mock_collector.collect.side_effect = Exception("No data") + + response = await async_client.get( + f"/api/v1/onboarding/action-suggestions/{test_brand.id}" + ) + + assert 
response.status_code == 200 + data = response.json() + upgrade_suggestions = [ + s for s in data["suggestions"] if s["action_type"] == "upgrade" + ] + assert len(upgrade_suggestions) >= 1 + + +class TestOnboardingCreateBrandContract: + @pytest.mark.asyncio + async def test_create_brand_success(self, async_client): + response = await async_client.post( + "/api/v1/onboarding/brand", + json={"name": "NewBrand", "industry": "tech"}, + ) + assert response.status_code == 201 + data = response.json() + assert data["name"] == "NewBrand" + + @pytest.mark.asyncio + async def test_create_brand_short_name_rejected(self, async_client): + response = await async_client.post( + "/api/v1/onboarding/brand", + json={"name": "A"}, + ) + assert response.status_code == 422 + + +class TestOnboardingCompleteContract: + @pytest.mark.asyncio + async def test_complete_onboarding(self, async_client, test_brand): + response = await async_client.post( + f"/api/v1/onboarding/complete/{test_brand.id}" + ) + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + + @pytest.mark.asyncio + async def test_complete_onboarding_brand_not_found(self, async_client): + response = await async_client.post( + f"/api/v1/onboarding/complete/{uuid.uuid4()}" + ) + assert response.status_code == 404 diff --git a/backend/tests/test_api/test_payment_contract.py b/backend/tests/test_api/test_payment_contract.py new file mode 100644 index 0000000..36a0c1c --- /dev/null +++ b/backend/tests/test_api/test_payment_contract.py @@ -0,0 +1,372 @@ +import uuid +from unittest.mock import AsyncMock, patch, MagicMock + +import pytest +import pytest_asyncio +from fastapi import Depends, FastAPI +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.api.deps import get_current_user, get_db +from app.database import Base +from app.main import app +from 
app.middleware.subscription_enforcement import SubscriptionEnforcement +from app.models.payment_order import PaymentOrder as PaymentOrderModel +from app.models.user import User +from app.services.auth import hash_password +from app.services.payment.base import PaymentCallback + + +def _make_user( + user_id: str | None = None, + email: str = "test@example.com", + plan: str = "free", +) -> User: + uid = user_id or str(uuid.uuid4()) + user = User( + id=uid, + email=email, + password=hash_password("Test@123456"), + firstName="Test", + lastName="User", + isActive=True, + emailVerified=True, + ) + user.plan = plan + user.max_queries = 50 if plan != "free" else 5 + return user + + +@pytest_asyncio.fixture +async def async_engine(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with maker() as session: + yield session + + +@pytest_asyncio.fixture +async def free_user(async_session): + user = _make_user(plan="free") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def pro_user(async_session): + uid = str(uuid.uuid4()) + user = _make_user(user_id=uid, email="pro@example.com", plan="pro") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def enterprise_user(async_session): + uid = str(uuid.uuid4()) + user = _make_user(user_id=uid, email="enterprise@example.com", plan="enterprise") + async_session.add(user) + await async_session.commit() + await 
async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def client_with_free_user(async_session, free_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return free_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +@pytest_asyncio.fixture +async def client_with_pro_user(async_session, pro_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return pro_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +@pytest_asyncio.fixture +async def client_with_enterprise_user(async_session, enterprise_user): + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return enterprise_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +class TestCreatePaymentOrder: + @pytest.mark.asyncio + async def test_create_order_returns_201_with_order_id_and_pay_url(self, client_with_free_user): + response = await client_with_free_user.post( + "/api/v1/payments/orders", + json={"plan": "starter", "payment_provider": "wechat"}, + ) + assert response.status_code == 201 + data = response.json() + assert "order_id" in data + assert "pay_url" in data + assert 
data["amount"] == 199 + assert data["status"] == "pending" + assert data["currency"] == "CNY" + + @pytest.mark.asyncio + async def test_create_order_with_invalid_plan_returns_422(self, client_with_free_user): + response = await client_with_free_user.post( + "/api/v1/payments/orders", + json={"plan": "invalid_plan", "payment_provider": "wechat"}, + ) + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_create_order_for_free_plan_returns_400(self, client_with_free_user): + response = await client_with_free_user.post( + "/api/v1/payments/orders", + json={"plan": "free", "payment_provider": "wechat"}, + ) + assert response.status_code == 400 + + @pytest.mark.asyncio + async def test_create_order_pro_plan(self, client_with_free_user): + response = await client_with_free_user.post( + "/api/v1/payments/orders", + json={"plan": "pro", "payment_provider": "alipay"}, + ) + assert response.status_code == 201 + data = response.json() + assert data["amount"] == 599 + + @pytest.mark.asyncio + async def test_create_order_enterprise_plan(self, client_with_free_user): + response = await client_with_free_user.post( + "/api/v1/payments/orders", + json={"plan": "enterprise", "payment_provider": "wechat"}, + ) + assert response.status_code == 201 + data = response.json() + assert data["amount"] == 1999 + + +class TestWechatCallback: + @pytest.mark.asyncio + async def test_wechat_callback_updates_order_and_activates_subscription( + self, async_session, free_user + ): + order_id = uuid.uuid4() + order = PaymentOrderModel( + id=order_id, + user_id=free_user.id, + plan="starter", + amount=199, + payment_provider="wechat", + status="pending", + pay_url=f"mock://pay/{order_id}", + ) + async_session.add(order) + await async_session.commit() + + callback = PaymentCallback( + order_id=str(order_id), + payment_id=f"wx_pay_{order_id}", + amount=199, + status="success", + raw_data={"out_trade_no": str(order_id), "result_code": "SUCCESS"}, + ) + + from app.api.payments 
import _process_callback + result = await _process_callback(async_session, callback, "wechat") + await async_session.commit() + + await async_session.refresh(order) + assert order.status == "paid" + assert order.payment_id is not None + + +class TestAlipayCallback: + @pytest.mark.asyncio + async def test_alipay_callback_updates_order_and_activates_subscription( + self, async_session, free_user + ): + order_id = uuid.uuid4() + order = PaymentOrderModel( + id=order_id, + user_id=free_user.id, + plan="pro", + amount=599, + payment_provider="alipay", + status="pending", + pay_url=f"mock://pay/{order_id}", + ) + async_session.add(order) + await async_session.commit() + + callback = PaymentCallback( + order_id=str(order_id), + payment_id=f"ali_pay_{order_id}", + amount=599, + status="success", + raw_data={"out_trade_no": str(order_id), "trade_status": "TRADE_SUCCESS"}, + ) + + from app.api.payments import _process_callback + result = await _process_callback(async_session, callback, "alipay") + await async_session.commit() + + await async_session.refresh(order) + assert order.status == "paid" + assert order.payment_id is not None + + +class TestQueryOrder: + @pytest.mark.asyncio + async def test_query_order_returns_current_status(self, client_with_free_user, async_session, free_user): + order_id = uuid.uuid4() + order = PaymentOrderModel( + id=order_id, + user_id=free_user.id, + plan="starter", + amount=199, + payment_provider="wechat", + status="pending", + pay_url=f"mock://pay/{order_id}", + ) + async_session.add(order) + await async_session.commit() + + response = await client_with_free_user.get(f"/api/v1/payments/orders/{order_id}") + assert response.status_code == 200 + data = response.json() + assert data["order_id"] == str(order_id) + assert data["status"] == "pending" + assert data["plan"] == "starter" + assert data["amount"] == 199 + + @pytest.mark.asyncio + async def test_query_nonexistent_order_returns_404(self, client_with_free_user): + fake_id = uuid.uuid4() 
+ response = await client_with_free_user.get(f"/api/v1/payments/orders/{fake_id}") + assert response.status_code == 404 + + +class TestSubscriptionEnforcement: + @pytest.mark.asyncio + async def test_free_user_accessing_pro_only_endpoint_gets_403(self, client_with_free_user): + test_app = FastAPI() + + @test_app.get("/test/pro-only") + async def pro_only(user: User = Depends(SubscriptionEnforcement.require_plan("pro", "enterprise"))): + return {"message": "ok"} + + async def override_get_current_user(): + return _make_user(plan="free") + + test_app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=test_app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/test/pro-only") + + assert response.status_code == 403 + data = response.json() + assert data["detail"]["required_plan"] == "pro" + assert data["detail"]["current_plan"] == "free" + + @pytest.mark.asyncio + async def test_pro_user_accessing_pro_only_endpoint_succeeds(self, client_with_pro_user): + test_app = FastAPI() + + @test_app.get("/test/pro-only-2") + async def pro_only(user: User = Depends(SubscriptionEnforcement.require_plan("pro", "enterprise"))): + return {"message": "ok"} + + async def override_get_current_user(): + return _make_user(plan="pro") + + test_app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=test_app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/test/pro-only-2") + + assert response.status_code == 200 + + +class TestQuotaCheck: + @pytest.mark.asyncio + async def test_quota_check_returns_remaining_usage_info(self, client_with_free_user): + test_app = FastAPI() + + @test_app.get("/test/quota") + async def check_quota(quota=Depends(SubscriptionEnforcement.check_quota("queries"))): + return quota + + async def override_get_current_user(): + return 
_make_user(plan="free") + + async def override_get_db(): + yield AsyncMock() + + test_app.dependency_overrides[get_current_user] = override_get_current_user + test_app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=test_app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/test/quota") + + assert response.status_code == 200 + data = response.json() + assert data["resource"] == "queries" + assert data["plan"] == "free" + assert "remaining" in data diff --git a/tests/test_content_generation.py b/backend/tests/test_content_pipeline/test_content_generation.py similarity index 100% rename from tests/test_content_generation.py rename to backend/tests/test_content_pipeline/test_content_generation.py diff --git a/backend/tests/test_infrastructure/__init__.py b/backend/tests/test_infrastructure/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/test_infrastructure/test_middleware_unified.py b/backend/tests/test_infrastructure/test_middleware_unified.py new file mode 100644 index 0000000..f1fc1f1 --- /dev/null +++ b/backend/tests/test_infrastructure/test_middleware_unified.py @@ -0,0 +1,76 @@ +"""验证 middleware/ 统一导入结构。 + +合并策略: +- app/middleware/ 作为统一目录 +- monitoring/metrics.py (Prometheus指标定义) -> middleware/prometheus_metrics.py +- monitoring/middleware.py (MonitoringMiddleware) -> 合并到 middleware/metrics.py +- monitoring/agent_hooks.py -> middleware/agent_hooks.py +- monitoring/llm_metrics.py -> middleware/llm_metrics.py +- monitoring/ 目录已删除 +""" +import pytest + + +class TestUnifiedMiddlewareImports: + """测试统一后的 app.middleware 导入路径是否可用。""" + + def test_import_logging_filter(self): + from app.middleware.logging_filter import APIKeyFilter + assert APIKeyFilter is not None + + def test_import_logging_middleware(self): + from app.middleware.logging_middleware import RequestLoggingMiddleware + assert RequestLoggingMiddleware is not None + + def 
test_import_rate_limit(self): + from app.middleware.rate_limit import RateLimitMiddleware + assert RateLimitMiddleware is not None + + def test_import_request_id(self): + from app.middleware.request_id import RequestIdMiddleware, REQUEST_ID_HEADER + assert RequestIdMiddleware is not None + assert REQUEST_ID_HEADER == "X-Request-ID" + + def test_import_metrics_middleware(self): + from app.middleware.metrics import MetricsMiddleware + assert MetricsMiddleware is not None + + def test_import_monitoring_middleware_from_metrics(self): + from app.middleware.metrics import MonitoringMiddleware + assert MonitoringMiddleware is not None + + def test_import_prometheus_metrics(self): + from app.middleware.prometheus_metrics import ( + API_REQUESTS_TOTAL, + API_REQUEST_DURATION_SECONDS, + API_REQUESTS_IN_PROGRESS, + AGENT_EXECUTIONS_TOTAL, + AGENT_EXECUTION_DURATION_SECONDS, + AGENT_RUNNING_TASKS, + LLM_REQUESTS_TOTAL, + LLM_REQUEST_DURATION_SECONDS, + LLM_TOKENS_TOTAL, + LLM_COST_ESTIMATED, + BRAND_COUNT, + QUERY_COUNT_TOTAL, + CONTENT_GENERATED_TOTAL, + CITATION_DETECTED_TOTAL, + SERVICE_INFO, + ) + assert API_REQUESTS_TOTAL is not None + assert SERVICE_INFO is not None + + def test_import_agent_hooks(self): + from app.middleware.agent_hooks import agent_execution_context, record_agent_execution + assert agent_execution_context is not None + assert record_agent_execution is not None + + def test_import_llm_metrics(self): + from app.middleware.llm_metrics import get_llm_metrics, LLMMetricsWrapper + assert get_llm_metrics is not None + assert LLMMetricsWrapper is not None + + def test_import_excluded_paths(self): + from app.middleware.metrics import _SKIP_PATHS + assert isinstance(_SKIP_PATHS, set) + assert "/health" in _SKIP_PATHS diff --git a/tests/test_monitoring.py b/backend/tests/test_infrastructure/test_monitoring.py similarity index 99% rename from tests/test_monitoring.py rename to backend/tests/test_infrastructure/test_monitoring.py index 4e34d6e..a8c6639 100644 --- 
a/tests/test_monitoring.py +++ b/backend/tests/test_infrastructure/test_monitoring.py @@ -12,7 +12,7 @@ import pytest import pytest_asyncio from prometheus_client import REGISTRY -from app.monitoring.metrics import ( +from app.middleware.prometheus_metrics import ( API_REQUESTS_TOTAL, AGENT_EXECUTIONS_TOTAL, LLM_REQUESTS_TOTAL, diff --git a/tests/test_business_flow.py b/backend/tests/test_integration/test_business_flow.py similarity index 81% rename from tests/test_business_flow.py rename to backend/tests/test_integration/test_business_flow.py index a2da714..a782e7e 100644 --- a/tests/test_business_flow.py +++ b/backend/tests/test_integration/test_business_flow.py @@ -7,14 +7,8 @@ import pytest_asyncio from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.main import app -from app.api.deps import get_current_user -from app.database import get_db from app.models.citation_record import CitationRecord -from app.models.query import Query from app.models.query_task import QueryTask -from app.models.user import User -from app.services.auth import hash_password # --------------------------------------------------------------------------- @@ -22,14 +16,8 @@ from app.services.auth import hash_password # --------------------------------------------------------------------------- @pytest_asyncio.fixture async def user_a(test_session: AsyncSession): - """Create a real user in the test DB.""" - user = User( - email="user_a@example.com", - password_hash=hash_password("password123"), - name="User A", - plan="free", - max_queries=5, - ) + from tests.fixtures.auth import _make_user + user = _make_user(email="user_a@example.com", plan="free") test_session.add(user) await test_session.commit() await test_session.refresh(user) @@ -38,14 +26,8 @@ async def user_a(test_session: AsyncSession): @pytest_asyncio.fixture async def user_b(test_session: AsyncSession): - """Create another real user in the test DB.""" - user = User( - email="user_b@example.com", 
- password_hash=hash_password("password456"), - name="User B", - plan="free", - max_queries=5, - ) + from tests.fixtures.auth import _make_user + user = _make_user(email="user_b@example.com", plan="free") test_session.add(user) await test_session.commit() await test_session.refresh(user) @@ -53,24 +35,21 @@ async def user_b(test_session: AsyncSession): @pytest_asyncio.fixture -async def auth_client_a(async_client, user_a): - """Login user_a and return client with auth headers.""" - response = await async_client.post( +async def auth_client_a(plain_client, user_a): + response = await plain_client.post( "/api/v1/auth/login", - json={"email": "user_a@example.com", "password": "password123"}, + json={"email": "user_a@example.com", "password": "Test@123456"}, ) assert response.status_code == 200 token = response.json()["access_token"] - # Return a small helper or just the headers return {"Authorization": f"Bearer {token}"} @pytest_asyncio.fixture -async def auth_client_b(async_client, user_b): - """Login user_b and return auth headers.""" - response = await async_client.post( +async def auth_client_b(plain_client, user_b): + response = await plain_client.post( "/api/v1/auth/login", - json={"email": "user_b@example.com", "password": "password456"}, + json={"email": "user_b@example.com", "password": "Test@123456"}, ) assert response.status_code == 200 token = response.json()["access_token"] @@ -81,9 +60,8 @@ async def auth_client_b(async_client, user_b): # 1. 
完整流程:注册 -> 登录 -> 创建查询词 -> 查看列表 # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_full_user_flow(async_client, override_get_db): - # Register - reg_resp = await async_client.post( +async def test_full_user_flow(plain_client, override_get_db): + reg_resp = await plain_client.post( "/api/v1/auth/register", json={"email": "flow@example.com", "password": "flowpass", "name": "Flow User"}, ) @@ -92,7 +70,7 @@ async def test_full_user_flow(async_client, override_get_db): assert reg_data["email"] == "flow@example.com" # Login - login_resp = await async_client.post( + login_resp = await plain_client.post( "/api/v1/auth/login", json={"email": "flow@example.com", "password": "flowpass"}, ) @@ -101,7 +79,7 @@ async def test_full_user_flow(async_client, override_get_db): headers = {"Authorization": f"Bearer {token}"} # Create query - create_resp = await async_client.post( + create_resp = await plain_client.post( "/api/v1/queries/", headers=headers, json={ @@ -117,7 +95,7 @@ async def test_full_user_flow(async_client, override_get_db): assert query_data["target_brand"] == "FlowBrand" # List queries - list_resp = await async_client.get("/api/v1/queries/", headers=headers) + list_resp = await plain_client.get("/api/v1/queries/", headers=headers) assert list_resp.status_code == 200 list_data = list_resp.json() assert list_data["total"] == 1 @@ -129,11 +107,11 @@ async def test_full_user_flow(async_client, override_get_db): # 2. 
查询词生命周期:创建 -> 更新 -> 暂停 -> 恢复 -> 删除 # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_query_lifecycle(async_client, override_get_db, auth_client_a): +async def test_query_lifecycle(plain_client, override_get_db, auth_client_a): headers = auth_client_a # Create - create_resp = await async_client.post( + create_resp = await plain_client.post( "/api/v1/queries/", headers=headers, json={ @@ -147,7 +125,7 @@ async def test_query_lifecycle(async_client, override_get_db, auth_client_a): query_id = create_resp.json()["id"] # Update - update_resp = await async_client.put( + update_resp = await plain_client.put( f"/api/v1/queries/{query_id}", headers=headers, json={"keyword": "updated keyword", "frequency": "daily"}, @@ -157,7 +135,7 @@ async def test_query_lifecycle(async_client, override_get_db, auth_client_a): assert update_resp.json()["frequency"] == "daily" # Pause (update status to paused) - pause_resp = await async_client.put( + pause_resp = await plain_client.put( f"/api/v1/queries/{query_id}", headers=headers, json={"status": "paused"}, @@ -166,7 +144,7 @@ async def test_query_lifecycle(async_client, override_get_db, auth_client_a): assert pause_resp.json()["status"] == "paused" # Resume (update status back to active) - resume_resp = await async_client.put( + resume_resp = await plain_client.put( f"/api/v1/queries/{query_id}", headers=headers, json={"status": "active"}, @@ -175,14 +153,14 @@ async def test_query_lifecycle(async_client, override_get_db, auth_client_a): assert resume_resp.json()["status"] == "active" # Delete - delete_resp = await async_client.delete( + delete_resp = await plain_client.delete( f"/api/v1/queries/{query_id}", headers=headers, ) assert delete_resp.status_code == 204 # Verify deletion - get_resp = await async_client.get(f"/api/v1/queries/{query_id}", headers=headers) + get_resp = await plain_client.get(f"/api/v1/queries/{query_id}", headers=headers) assert 
get_resp.status_code == 404 @@ -190,12 +168,12 @@ async def test_query_lifecycle(async_client, override_get_db, auth_client_a): # 3. 查询数量限制:free 用户最多 5 个 # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_query_limit_free_user(async_client, override_get_db, auth_client_a): +async def test_query_limit_free_user(plain_client, override_get_db, auth_client_a): headers = auth_client_a # Create 5 queries (limit for free plan) for i in range(5): - resp = await async_client.post( + resp = await plain_client.post( "/api/v1/queries/", headers=headers, json={ @@ -208,7 +186,7 @@ async def test_query_limit_free_user(async_client, override_get_db, auth_client_ assert resp.status_code == 201, f"Failed to create query {i}" # 6th query should be rejected - resp = await async_client.post( + resp = await plain_client.post( "/api/v1/queries/", headers=headers, json={ @@ -227,12 +205,12 @@ async def test_query_limit_free_user(async_client, override_get_db, auth_client_ # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_citation_stats_correctness( - async_client, override_get_db, auth_client_a, test_session + plain_client, override_get_db, auth_client_a, test_session ): headers = auth_client_a # Create a query - create_resp = await async_client.post( + create_resp = await plain_client.post( "/api/v1/queries/", headers=headers, json={ @@ -277,7 +255,7 @@ async def test_citation_stats_correctness( await test_session.commit() # Call stats API - stats_resp = await async_client.get("/api/v1/citations/stats", headers=headers) + stats_resp = await plain_client.get("/api/v1/citations/stats", headers=headers) assert stats_resp.status_code == 200 stats = stats_resp.json() @@ -300,12 +278,12 @@ async def test_citation_stats_correctness( # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_export_csv( - 
async_client, override_get_db, auth_client_a, test_session + plain_client, override_get_db, auth_client_a, test_session ): headers = auth_client_a # Create query - create_resp = await async_client.post( + create_resp = await plain_client.post( "/api/v1/queries/", headers=headers, json={ @@ -331,7 +309,7 @@ async def test_export_csv( await test_session.commit() # Export CSV - export_resp = await async_client.get( + export_resp = await plain_client.get( f"/api/v1/reports/export/csv?query_id={query_id}", headers=headers, ) @@ -347,12 +325,12 @@ async def test_export_csv( # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_run_now_creates_query_task( - async_client, override_get_db, auth_client_a, test_session + plain_client, override_get_db, auth_client_a, test_session ): headers = auth_client_a # Create an active query - create_resp = await async_client.post( + create_resp = await plain_client.post( "/api/v1/queries/", headers=headers, json={ @@ -366,7 +344,7 @@ async def test_run_now_creates_query_task( query_id = create_resp.json()["id"] # Trigger run-now - run_resp = await async_client.post( + run_resp = await plain_client.post( f"/api/v1/queries/{query_id}/run-now", headers=headers, ) @@ -391,13 +369,13 @@ async def test_run_now_creates_query_task( # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_permission_isolation( - async_client, override_get_db, auth_client_a, auth_client_b + plain_client, override_get_db, auth_client_a, auth_client_b ): headers_a = auth_client_a headers_b = auth_client_b # User A creates a query - create_resp = await async_client.post( + create_resp = await plain_client.post( "/api/v1/queries/", headers=headers_a, json={ @@ -411,14 +389,14 @@ async def test_permission_isolation( query_id = create_resp.json()["id"] # User B tries to access User A's query - get_resp = await async_client.get( + get_resp = await 
plain_client.get( f"/api/v1/queries/{query_id}", headers=headers_b, ) assert get_resp.status_code == 404 # User B tries to update User A's query - put_resp = await async_client.put( + put_resp = await plain_client.put( f"/api/v1/queries/{query_id}", headers=headers_b, json={"keyword": "hacked"}, @@ -426,14 +404,14 @@ async def test_permission_isolation( assert put_resp.status_code == 404 # User B tries to delete User A's query - del_resp = await async_client.delete( + del_resp = await plain_client.delete( f"/api/v1/queries/{query_id}", headers=headers_b, ) assert del_resp.status_code == 404 # User B tries to run-now User A's query - run_resp = await async_client.post( + run_resp = await plain_client.post( f"/api/v1/queries/{query_id}/run-now", headers=headers_b, ) diff --git a/backend/tests/test_integration/test_monetization_flow.py b/backend/tests/test_integration/test_monetization_flow.py new file mode 100644 index 0000000..1644942 --- /dev/null +++ b/backend/tests/test_integration/test_monetization_flow.py @@ -0,0 +1,191 @@ +import uuid +from unittest.mock import AsyncMock, patch + +import pytest + +from tests.fixtures.auth import _make_user, _to_uuid + + +class TestMonetizationFlow: + @pytest.mark.asyncio + async def test_full_monetization_flow(self, async_client, async_session): + user = _make_user(email="monetization@example.com", plan="free") + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + + from app.api.deps import get_current_user, get_db + from app.main import app + + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + try: + brand_resp = await async_client.post( + "/api/v1/onboarding/brand", + json={"name": "MonoBrand", "industry": "technology"}, + ) + assert brand_resp.status_code == 201 + brand_data = brand_resp.json() + 
brand_id = brand_data["id"] + assert brand_data["name"] == "MonoBrand" + + onboarding_resp = await async_client.post( + f"/api/v1/onboarding/complete/{brand_id}" + ) + assert onboarding_resp.status_code == 200 + assert onboarding_resp.json()["success"] is True + + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + + from app.services.diagnosis.data_collector import DataCollectionResult + from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + has_author_bio=True, + author_credentials_complete=0.8, + has_organization=True, + content_depth_score=0.7, + answer_ownership_rate=0.3, + ) + mock_collector.collect.return_value = DataCollectionResult( + diagnosis_input=input_data, + ) + + service = GEODiagnosisService() + result = service.diagnose(input_data) + + with patch("app.api.onboarding.GEODiagnosisService", return_value=service): + health_resp = await async_client.get( + f"/api/v1/onboarding/health-report/{brand_id}" + ) + + assert health_resp.status_code == 200 + health_data = health_resp.json() + assert health_data["brand_id"] == brand_id + assert "overall_score" in health_data + assert "dimensions" in health_data + assert health_data["is_full_report"] is False + + order_resp = await async_client.post( + "/api/v1/payments/orders", + json={"plan": "pro", "payment_provider": "wechat"}, + ) + assert order_resp.status_code == 201 + order_data = order_resp.json() + order_id = order_data["order_id"] + assert "pay_url" in order_data + assert order_data["amount"] == 599 + assert order_data["status"] == "pending" + + from app.models.payment_order import PaymentOrder as PaymentOrderModel + from app.services.payment.base import PaymentCallback + + from sqlalchemy import select + stmt = select(PaymentOrderModel).where( + PaymentOrderModel.id == 
uuid.UUID(order_id) + ) + db_result = await async_session.execute(stmt) + order = db_result.scalar_one() + + callback = PaymentCallback( + order_id=str(order.id), + payment_id=f"wx_pay_{order.id}", + amount=599, + status="success", + raw_data={"out_trade_no": str(order.id), "result_code": "SUCCESS"}, + ) + + from app.api.payments import _process_callback + await _process_callback(async_session, callback, "wechat") + await async_session.commit() + + await async_session.refresh(order) + assert order.status == "paid" + assert order.payment_id is not None + + await async_session.refresh(user) + assert user.plan == "pro" + assert user.max_queries == 50 + + with patch("app.api.onboarding.DataCollectorService") as mock_collector_cls: + mock_collector = AsyncMock() + mock_collector_cls.return_value = mock_collector + + from app.services.diagnosis.data_collector import DataCollectionResult + from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput, GEODiagnosisService + + input_data = GEODiagnosisInput( + has_direct_answer=True, + has_brand_definition=True, + has_author_bio=True, + author_credentials_complete=0.8, + has_organization=True, + content_depth_score=0.7, + answer_ownership_rate=0.3, + ) + mock_collector.collect.return_value = DataCollectionResult( + diagnosis_input=input_data, + ) + + service = GEODiagnosisService() + result = service.diagnose(input_data) + + with patch("app.api.onboarding.GEODiagnosisService", return_value=service): + paid_health_resp = await async_client.get( + f"/api/v1/onboarding/health-report/{brand_id}" + ) + + assert paid_health_resp.status_code == 200 + paid_health_data = paid_health_resp.json() + assert paid_health_data["is_full_report"] is True + assert len(paid_health_data["dimensions"]) == 6 + + from app.models.diagnosis_record import DiagnosisRecord + + diag_record = DiagnosisRecord( + brand_id=uuid.UUID(brand_id), + user_id=_to_uuid(user.id), + diagnosis_type="geo", + status="completed", + 
overall_score=result.overall_score, + result_json=result.to_dict(), + ) + async_session.add(diag_record) + await async_session.commit() + await async_session.refresh(diag_record) + + attr_resp = await async_client.post( + "/api/v1/attribution/start", + json={"brand_id": brand_id}, + ) + assert attr_resp.status_code == 200 + attr_data = attr_resp.json() + assert "id" in attr_data + assert attr_data["brand_id"] == brand_id + assert attr_data["baseline_score"] == result.overall_score + assert attr_data["status"] == "tracking" + + roi_resp = await async_client.get( + f"/api/v1/attribution/roi/{brand_id}" + ) + assert roi_resp.status_code == 200 + roi_data = roi_resp.json() + assert "roi_percentage" in roi_data + assert "value_generated" in roi_data + assert "subscription_cost" in roi_data + assert roi_data["brand_name"] == "MonoBrand" + assert roi_data["current_plan"] == "pro" + + finally: + app.dependency_overrides.clear() diff --git a/backend/tests/test_models/test_citation_record.py b/backend/tests/test_models/test_citation_record.py new file mode 100644 index 0000000..8fff54f --- /dev/null +++ b/backend/tests/test_models/test_citation_record.py @@ -0,0 +1,151 @@ +"""Tests for CitationRecord.from_citation_result() factory method""" +import uuid + +import pytest + +from app.models.citation_record import CitationRecord + + +class TestFromCitationResultBasicFields: + """基本字段映射""" + + def test_basic_fields_mapped(self): + query_id = uuid.uuid4() + result = { + "cited": True, + "position": 2, + "citation_text": "Brand X is recommended", + "competitor_brands": ["Brand Y", "Brand Z"], + "raw_response": "Some raw response", + "confidence": 0.95, + "match_type": "direct", + } + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="kimi", + result=result, + ) + assert record.query_id == query_id + assert record.platform == "kimi" + assert record.cited is True + assert record.citation_position == 2 + assert record.citation_text == "Brand X is 
recommended" + assert record.competitor_brands == ["Brand Y", "Brand Z"] + assert record.confidence == 0.95 + assert record.match_type == "direct" + + def test_missing_optional_fields_use_defaults(self): + query_id = uuid.uuid4() + result = {"cited": False} + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="wenxin", + result=result, + ) + assert record.cited is False + assert record.citation_position is None + assert record.citation_text is None + assert record.competitor_brands == [] + assert record.raw_response == "" + assert record.confidence is None + assert record.match_type is None + + +class TestFromCitationResultSourceFields: + """引用源分析字段""" + + def test_source_fields_populated(self): + query_id = uuid.uuid4() + result = { + "cited": True, + "data_source": "ai_platform", + "source_urls": ["https://example.com"], + "source_titles": ["Example Title"], + "citation_contexts": ["context snippet"], + "ai_response_text": "AI said something", + } + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="kimi", + result=result, + ) + assert record.data_source == "ai_platform" + assert record.source_urls == ["https://example.com"] + assert record.source_titles == ["Example Title"] + assert record.citation_contexts == ["context snippet"] + assert record.ai_response_text == "AI said something" + + def test_source_fields_default_to_none(self): + query_id = uuid.uuid4() + result = {"cited": False} + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="kimi", + result=result, + ) + assert record.data_source is None + assert record.source_urls is None + assert record.source_titles is None + assert record.citation_contexts is None + assert record.ai_response_text == "" + + +class TestFromCitationResultSanitization: + """raw_response 和 ai_response_text 应该被清理""" + + def test_raw_response_sanitized(self): + query_id = uuid.uuid4() + result = { + "cited": True, + "raw_response": 
"Hello\x00World\x08!", + } + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="kimi", + result=result, + ) + # NULL 字节和 \x08 应该被移除 + assert "\x00" not in record.raw_response + assert "\x08" not in record.raw_response + assert "Hello" in record.raw_response + assert "World" in record.raw_response + + def test_ai_response_text_sanitized(self): + query_id = uuid.uuid4() + result = { + "cited": True, + "ai_response_text": "Text\x00with\x0bcontrol\x1fchars", + } + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="kimi", + result=result, + ) + assert "\x00" not in record.ai_response_text + assert "\x0b" not in record.ai_response_text + assert "\x1f" not in record.ai_response_text + + def test_newline_preserved_in_sanitization(self): + query_id = uuid.uuid4() + result = { + "cited": True, + "raw_response": "Line1\nLine2\tTab\rReturn", + } + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="kimi", + result=result, + ) + assert "\n" in record.raw_response + assert "\t" in record.raw_response + assert "\r" in record.raw_response + + def test_none_raw_response_becomes_empty_string(self): + query_id = uuid.uuid4() + result = {"cited": True, "raw_response": None} + record = CitationRecord.from_citation_result( + query_id=query_id, + platform="kimi", + result=result, + ) + assert record.raw_response == "" diff --git a/backend/tests/test_platform_adapters.py b/backend/tests/test_platform_adapters.py index 85f2c94..5b6ad67 100644 --- a/backend/tests/test_platform_adapters.py +++ b/backend/tests/test_platform_adapters.py @@ -10,13 +10,10 @@ AI平台适配器测试 - 验证各平台适配器是否正常工作 import pytest from unittest.mock import Mock, patch, AsyncMock, MagicMock -import sys -sys.path.insert(0, '/Users/Chiguyong/Code/Fischer/geo/backend') - -from app.workers.platforms.kimi import KimiAdapter -from app.workers.platforms.wenxin import WenxinAdapter -from app.workers.platforms.doubao import DoubaoAdapter +from 
app.services.ai_engine.kimi import KimiAdapter +from app.services.ai_engine.wenxin import WenxinAdapter +from app.services.ai_engine.doubao import DoubaoAdapter from app.workers.citation_extractor import ( extract_markdown_links, extract_urls_with_context, diff --git a/backend/tests/test_services/test_brand_citation_llm.py b/backend/tests/test_services/test_brand_citation_llm.py new file mode 100644 index 0000000..ae720d0 --- /dev/null +++ b/backend/tests/test_services/test_brand_citation_llm.py @@ -0,0 +1,221 @@ +"""Tests for BrandCitationLLMService - replacing LLMAdapter""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.services.llm.brand_citation_service import BrandCitationLLMService, BRAND_CITATION_PROMPT +from app.services.llm.base import LLMError + + +class TestBrandCitationLLMServiceInit: + """服务初始化测试""" + + def test_init_with_default_factory(self): + service = BrandCitationLLMService() + assert service is not None + assert service._provider_name is None + assert service._model is None + + def test_init_with_custom_provider(self): + service = BrandCitationLLMService(provider_name="deepseek", model="deepseek-chat") + assert service._provider_name == "deepseek" + assert service._model == "deepseek-chat" + + +class TestBrandCitationLLMServiceQuery: + """品牌引用查询测试""" + + @pytest.mark.asyncio + async def test_query_brand_citation_cited(self): + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 1, "citation_text": "Brand X is great", "sentiment": "positive", "confidence": 0.9}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( + keyword="AI搜索", + brand_name="Brand X", + brand_aliases=["BX"] + ) + assert result.cited is True + assert 
result.position == 1 + assert result.confidence == 0.9 + + @pytest.mark.asyncio + async def test_query_brand_citation_not_cited(self): + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": false, "position": null, "citation_text": null, "sentiment": "neutral", "confidence": 0.8}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( + keyword="AI搜索", + brand_name="Brand X", + brand_aliases=[] + ) + assert result.cited is False + assert result.position is None + assert result.citation_text is None + + @pytest.mark.asyncio + async def test_query_brand_citation_with_markdown_json(self): + """Test that markdown-wrapped JSON is handled via extract_json""" + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='```json\n{"cited": true, "position": 2, "citation_text": "test", "sentiment": "positive", "confidence": 0.7}\n```' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( + keyword="test", + brand_name="Brand X", + brand_aliases=[] + ) + assert result.cited is True + assert result.position == 2 + + @pytest.mark.asyncio + async def test_query_brand_citation_sentiment_positive(self): + """测试正面情感""" + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 2, "citation_text": "YYY品牌产品质量非常好", "sentiment": "positive", "confidence": 0.92}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with 
patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( + keyword="AI搜索", + brand_name="YYY", + brand_aliases=[] + ) + assert result.sentiment == "positive" + + @pytest.mark.asyncio + async def test_query_brand_citation_sentiment_negative(self): + """测试负面情感""" + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 3, "citation_text": "ZZZ品牌存在质量问题", "sentiment": "negative", "confidence": 0.88}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( + keyword="AI搜索", + brand_name="ZZZ", + brand_aliases=[] + ) + assert result.sentiment == "negative" + + @pytest.mark.asyncio + async def test_query_brand_citation_invalid_sentiment_defaults_neutral(self): + """测试无效sentiment值默认为neutral""" + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 1, "citation_text": "test", "sentiment": "unknown", "confidence": 0.5}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( + keyword="test", + brand_name="Test", + brand_aliases=[] + ) + assert result.sentiment == "neutral" + + @pytest.mark.asyncio + async def test_query_brand_citation_confidence_clamped(self): + """测试confidence值被钳制到0.0-1.0范围""" + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 1, "citation_text": "test", "sentiment": "neutral", "confidence": 1.5}' 
+ ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( + keyword="test", + brand_name="Test", + brand_aliases=[] + ) + assert result.confidence == 1.0 + + +class TestBrandCitationLLMServicePrompt: + """Prompt构建测试""" + + def test_build_prompt(self): + service = BrandCitationLLMService() + prompt = service._build_prompt("AI搜索", "Brand X", ["BX"]) + assert "Brand X" in prompt + assert "BX" in prompt + assert "AI搜索" in prompt + + def test_build_prompt_no_aliases(self): + service = BrandCitationLLMService() + prompt = service._build_prompt("AI搜索", "Brand X", []) + assert "Brand X" in prompt + assert "AI搜索" in prompt + assert "无" in prompt + + def test_brand_citation_prompt_constant_exists(self): + """验证BRAND_CITATION_PROMPT常量存在且包含关键占位符""" + assert "{keyword}" in BRAND_CITATION_PROMPT + assert "{brand_name}" in BRAND_CITATION_PROMPT + assert "{brand_aliases}" in BRAND_CITATION_PROMPT + + +class TestBrandCitationLLMServiceErrors: + """错误处理测试""" + + @pytest.mark.asyncio + async def test_llm_disabled_raises_error(self): + service = BrandCitationLLMService() + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = False + with pytest.raises(LLMError, match="LLM"): + await service.query_brand_citation("test", "Brand X", []) + + @pytest.mark.asyncio + async def test_llm_disabled_error_contains_config_guidance(self): + """LLM禁用时错误信息应包含配置指引""" + service = BrandCitationLLMService() + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = False + with pytest.raises(LLMError) as exc_info: + await service.query_brand_citation("test", "Brand X", []) + error_msg = str(exc_info.value) + assert "ENABLE_LLM" in error_msg + + @pytest.mark.asyncio + async def 
test_missing_required_fields_raises_error(self): + """响应缺少必需字段时抛出异常""" + service = BrandCitationLLMService() + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"invalid": "response"}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + with pytest.raises((LLMError, ValueError)): + await service.query_brand_citation("test", "Brand X", []) + + @pytest.mark.asyncio + async def test_no_mock_result_method(self): + """确保不存在_get_mock_result方法(与旧LLMAdapter的区别)""" + service = BrandCitationLLMService() + assert not hasattr(service, "_get_mock_result") diff --git a/backend/tests/test_services/test_citation_pattern.py b/backend/tests/test_services/test_citation_pattern.py index 25e5923..99d7a39 100644 --- a/backend/tests/test_services/test_citation_pattern.py +++ b/backend/tests/test_services/test_citation_pattern.py @@ -3,7 +3,7 @@ from datetime import UTC, datetime import pytest from app.services.ai_engine.base import AIQueryResult, CitationInfo, EngineType -from app.services.citation_pattern import ( +from app.services.citation.citation_pattern import ( AuthoritySignalAnalyzer, CitationFormatAnalyzer, CitationPattern, diff --git a/backend/tests/test_services/test_content_generation.py b/backend/tests/test_services/test_content_generation.py new file mode 100644 index 0000000..73b354b --- /dev/null +++ b/backend/tests/test_services/test_content_generation.py @@ -0,0 +1,420 @@ +"""Tests for ContentGenerationService - extracted from api/content.py + +TDD RED phase: tests for the service that extracts the 3-stage +content generation flow (generate -> de-AI -> GEO optimize) +out of the API handler. 
+""" +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.services.llm.base import LLMError, LLMResponse + + +def _make_llm_response(content: str) -> LLMResponse: + """Helper to create LLMResponse objects for mocking.""" + return LLMResponse(content=content, model="test-model") + + +class TestContentGenerationService: + """ContentGenerationService unit tests.""" + + @pytest.mark.asyncio + async def test_generate_content_basic_three_stages(self): + """Test basic 3-stage content generation (generate -> de-AI -> GEO optimize).""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw generated content"), + _make_llm_response("De-AIed content"), + _make_llm_response("GEO optimized content"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await service.generate_content( + keyword="AI搜索", + brand_name="Brand X", + platform="wenxin", + content_style="专业严谨", + word_count=2000, + run_deai=True, + run_geo=True, + ) + + assert result is not None + assert result["content"] == "De-AIed content" + assert result["optimized_content"] == "GEO optimized content" + assert result["pipeline_stages"] is not None + assert len(result["pipeline_stages"]) == 3 + assert result["pipeline_stages"][0]["stage"] == "content_generation" + assert result["pipeline_stages"][1]["stage"] == "deai" + assert result["pipeline_stages"][2]["stage"] == "geo_optimization" + + @pytest.mark.asyncio + async def test_generate_content_skip_deai(self): + """Test generation with run_deai=False skips the de-AI stage.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw 
generated content"), + _make_llm_response("GEO optimized content"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await service.generate_content( + keyword="test", + brand_name="Brand X", + platform="wenxin", + run_deai=False, + run_geo=True, + ) + + assert result["content"] == "Raw generated content" + assert result["optimized_content"] == "GEO optimized content" + assert len(result["pipeline_stages"]) == 2 + stage_names = [s["stage"] for s in result["pipeline_stages"]] + assert "deai" not in stage_names + + @pytest.mark.asyncio + async def test_generate_content_skip_geo(self): + """Test generation with run_geo=False skips the GEO optimization stage.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw generated content"), + _make_llm_response("De-AIed content"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await service.generate_content( + keyword="test", + brand_name="Brand X", + platform="wenxin", + run_deai=True, + run_geo=False, + ) + + assert result["content"] == "De-AIed content" + assert result["optimized_content"] == "De-AIed content" + assert len(result["pipeline_stages"]) == 2 + stage_names = [s["stage"] for s in result["pipeline_stages"]] + assert "geo_optimization" not in stage_names + + @pytest.mark.asyncio + async def test_generate_content_skip_both_stages(self): + """Test generation with both run_deai=False and run_geo=False.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw generated content"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await 
service.generate_content( + keyword="test", + brand_name="Brand X", + platform="wenxin", + run_deai=False, + run_geo=False, + ) + + assert result["content"] == "Raw generated content" + assert result["optimized_content"] == "Raw generated content" + assert len(result["pipeline_stages"]) == 1 + + @pytest.mark.asyncio + async def test_generate_content_with_knowledge_context(self): + """Test that knowledge_context is passed through to the generation prompt.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw content"), + _make_llm_response("De-AIed"), + _make_llm_response("Optimized"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await service.generate_content( + keyword="test keyword", + brand_name="Brand X", + platform="wenxin", + knowledge_context="Some knowledge base content", + ) + + # The first call to provider.chat should include the knowledge context + first_call_args = mock_provider.chat.call_args_list[0] + messages = first_call_args[0][0] + # Verify knowledge context appears in the rendered messages + all_content = " ".join(str(m) for m in messages) + assert "Some knowledge base content" in all_content + + @pytest.mark.asyncio + async def test_generate_content_saves_to_database(self): + """Test that generated content is saved to database when db and user_id are provided.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw content"), + _make_llm_response("De-AIed"), + _make_llm_response("Optimized"), + ] + mock_db = AsyncMock() + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await service.generate_content( + 
keyword="test", + brand_name="Brand X", + platform="wenxin", + db=mock_db, + user_id="test-user-123", + org_id=str(uuid.uuid4()), + ) + + # Verify db.add was called (Content + ContentVersion) + assert mock_db.add.call_count == 2 + mock_db.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_generate_content_no_db_no_save(self): + """Test that when db is not provided, no database operations occur.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw"), + _make_llm_response("De-AIed"), + _make_llm_response("Optimized"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await service.generate_content( + keyword="test", + brand_name="Brand X", + platform="wenxin", + ) + + # No content_id when db is not provided + assert result.get("content_id") is None + + @pytest.mark.asyncio + async def test_generate_content_llm_error_propagates(self): + """Test that LLMError from the provider propagates correctly.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = LLMError( + "API rate limit", provider="openai" + ) + + with patch.object(service, "_get_provider", return_value=mock_provider): + with pytest.raises(LLMError, match="API rate limit"): + await service.generate_content( + keyword="test", + brand_name="Brand X", + platform="wenxin", + ) + + @pytest.mark.asyncio + async def test_get_knowledge_context_with_ids(self): + """Test _get_knowledge_context retrieves context when knowledge_base_ids are provided.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + mock_db = AsyncMock() + + mock_rag = 
MagicMock() + mock_rag.search = AsyncMock(return_value=[ + {"content": "Knowledge chunk 1", "document_title": "Doc A"}, + {"content": "Knowledge chunk 2", "document_title": "Doc B"}, + ]) + + with patch( + "app.services.knowledge.rag_service.RAGService", + return_value=mock_rag, + ): + context = await service._get_knowledge_context( + db=mock_db, + brand_name="Brand X", + knowledge_base_ids=["kb-1"], + target_keyword="test keyword", + ) + + assert "Knowledge chunk 1" in context + assert "Knowledge chunk 2" in context + assert "Doc A" in context + + @pytest.mark.asyncio + async def test_get_knowledge_context_empty_ids(self): + """Test _get_knowledge_context returns empty string when no IDs provided.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + mock_db = AsyncMock() + + context = await service._get_knowledge_context( + db=mock_db, + brand_name="Brand X", + knowledge_base_ids=[], + target_keyword="test", + ) + + assert context == "" + + @pytest.mark.asyncio + async def test_get_knowledge_context_rag_failure_returns_empty(self): + """Test _get_knowledge_context returns empty string when RAG search fails.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + mock_db = AsyncMock() + + mock_rag = MagicMock() + mock_rag.search = AsyncMock(side_effect=Exception("RAG service down")) + + with patch( + "app.services.knowledge.rag_service.RAGService", + return_value=mock_rag, + ): + context = await service._get_knowledge_context( + db=mock_db, + brand_name="Brand X", + knowledge_base_ids=["kb-1"], + target_keyword="test", + ) + + assert context == "" + + @pytest.mark.asyncio + async def test_generate_content_passes_correct_prompt_variables(self): + """Test that the service passes correct variables to each prompt template stage.""" + from app.services.content.content_generation_service import ( + 
ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw content"), + _make_llm_response("De-AIed content"), + _make_llm_response("Optimized content"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + with patch( + "app.services.content.content_generation_service.CONTENT_GENERATOR_TEMPLATE" + ) as mock_gen_template, patch( + "app.services.content.content_generation_service.DEAI_TEMPLATE" + ) as mock_deai_template, patch( + "app.services.content.content_generation_service.GEO_OPTIMIZER_TEMPLATE" + ) as mock_geo_template: + mock_gen_template.render.return_value = [ + {"role": "user", "content": "gen prompt"} + ] + mock_deai_template.render.return_value = [ + {"role": "user", "content": "deai prompt"} + ] + mock_geo_template.render.return_value = [ + {"role": "user", "content": "geo prompt"} + ] + + await service.generate_content( + keyword="AI搜索", + brand_name="Brand X", + platform="微信公众号", + content_style="专业严谨", + word_count=3000, + knowledge_context="Some context", + ) + + # Verify CONTENT_GENERATOR_TEMPLATE.render called with correct variables + gen_call_kwargs = mock_gen_template.render.call_args[0][0] + assert gen_call_kwargs["topic_title"] == "AI搜索" + assert gen_call_kwargs["target_keyword"] == "AI搜索" + assert gen_call_kwargs["target_platform"] == "微信公众号" + assert gen_call_kwargs["content_style"] == "专业严谨" + assert gen_call_kwargs["word_count"] == "3000" + assert gen_call_kwargs["brand_name"] == "Brand X" + assert gen_call_kwargs["knowledge_context"] == "Some context" + + # Verify DEAI_TEMPLATE.render called with the generated content + deai_call_kwargs = mock_deai_template.render.call_args[0][0] + assert deai_call_kwargs["original_content"] == "Raw content" + + # Verify GEO_OPTIMIZER_TEMPLATE.render called with de-AIed content + geo_call_kwargs = mock_geo_template.render.call_args[0][0] + assert 
geo_call_kwargs["original_content"] == "De-AIed content" + assert geo_call_kwargs["target_keywords"] == "AI搜索" + assert geo_call_kwargs["target_platform"] == "微信公众号" + + @pytest.mark.asyncio + async def test_generate_content_default_parameters(self): + """Test that default parameter values are applied correctly.""" + from app.services.content.content_generation_service import ( + ContentGenerationService, + ) + + service = ContentGenerationService() + + mock_provider = AsyncMock() + mock_provider.chat.side_effect = [ + _make_llm_response("Raw"), + _make_llm_response("De-AIed"), + _make_llm_response("Optimized"), + ] + + with patch.object(service, "_get_provider", return_value=mock_provider): + result = await service.generate_content( + keyword="test", + brand_name="Brand X", + ) + + # Defaults: platform="通用", content_style="专业严谨", word_count=2000 + # run_deai=True, run_geo=True + assert result is not None + assert len(result["pipeline_stages"]) == 3 diff --git a/backend/tests/test_services/test_data_collector.py b/backend/tests/test_services/test_data_collector.py new file mode 100644 index 0000000..d17435c --- /dev/null +++ b/backend/tests/test_services/test_data_collector.py @@ -0,0 +1,381 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.database import Base +from app.models.brand import Brand +from app.models.citation_record import CitationRecord +from app.models.query import Query +from app.models.user import User +from app.services.auth import hash_password +from app.services.diagnosis.data_collector import DataCollectorService, DataCollectionResult +from app.services.diagnosis.geo_diagnosis import GEODiagnosisInput + + +@pytest_asyncio.fixture +async def async_engine(): + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": 
False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with maker() as session: + yield session + + +class TestWebsiteSignalParsing: + def test_parse_html_with_schema_org(self): + service = DataCollectorService.__new__(DataCollectorService) + service._db = None + + html = """ + + + + +

什么是TestBrand?

+

TestBrand是一家专注于技术创新的公司,为企业提供智能化解决方案。

+

如何使用TestBrand?

+
  • 步骤1
  • 步骤2
+ 关于我们 + 更新于2026年5月1日 + + """ + signals = service._parse_html_signals(html) + + assert signals["has_organization"] is True + assert signals["has_product"] is True + assert signals["has_qa_headings"] is True + assert signals["has_structured_data"] is True + assert signals["has_internal_links"] is True + assert signals["has_freshness_info"] is True + assert signals["has_brand_definition"] is True + assert signals["has_target_audience"] is True + + def test_parse_html_minimal(self): + service = DataCollectorService.__new__(DataCollectorService) + service._db = None + + html = "

Hello world

" + signals = service._parse_html_signals(html) + + assert signals["has_organization"] is False + assert signals["has_product"] is False + assert signals["has_qa_headings"] is False + + def test_parse_html_article_schema(self): + service = DataCollectorService.__new__(DataCollectorService) + service._db = None + + html = """ + + + + + + """ + signals = service._parse_html_signals(html) + + assert signals["has_article"] is True + assert signals["has_faq"] is True + assert signals["has_breadcrumb"] is True + + +class TestSignalApplication: + def test_apply_ai_signals(self): + service = DataCollectorService.__new__(DataCollectorService) + service._db = None + + inp = GEODiagnosisInput() + ai_data = { + "aor": 0.4, + "accuracy": 0.85, + "sov": 0.25, + "competitor_gap": 0.15, + "total_responses": 10, + "cited_count": 4, + "accurate_count": 3, + "has_author_bio": True, + "author_credentials_complete": 0.7, + "has_data_sources": True, + } + service._apply_ai_signals(inp, ai_data) + + assert inp.answer_ownership_rate == 0.4 + assert inp.citation_accuracy == 0.85 + assert inp.ai_sov == 0.25 + assert inp.competitor_gap == 0.15 + assert inp.total_ai_responses == 10 + assert inp.brand_mention_count == 4 + assert inp.accurate_citation_count == 3 + assert inp.has_author_bio is True + assert inp.author_credentials_complete == 0.7 + assert inp.has_data_sources is True + + def test_apply_citation_signals(self): + service = DataCollectorService.__new__(DataCollectorService) + service._db = None + + inp = GEODiagnosisInput() + citation_data = { + "aor": 0.5, + "accuracy": 0.9, + "sov": 0.3, + "competitor_gap": 0.1, + "total_responses": 20, + "cited_count": 10, + "accurate_count": 9, + "has_certifications": True, + "certification_count": 3, + "has_expert_endorsements": True, + "endorsement_count": 5, + "content_depth_score": 0.8, + "topic_coverage_ratio": 0.7, + "entity_consistency_score": 0.85, + "cluster_completeness": 0.6, + "total_content_count": 20, + "topic_cluster_count": 8, + 
} + service._apply_citation_signals(inp, citation_data) + + assert inp.answer_ownership_rate == 0.5 + assert inp.citation_accuracy == 0.9 + assert inp.has_certifications is True + assert inp.certification_count == 3 + assert inp.content_depth_score == 0.8 + + def test_apply_website_signals(self): + service = DataCollectorService.__new__(DataCollectorService) + service._db = None + + inp = GEODiagnosisInput() + website_data = { + "has_direct_answer": True, + "has_qa_headings": True, + "has_structured_data": True, + "has_internal_links": True, + "has_freshness_info": True, + "has_brand_definition": True, + "has_target_audience": True, + "has_unique_value": True, + "has_organization": True, + "has_product": True, + "has_article": True, + "has_faq": False, + "has_howto": False, + "has_breadcrumb": False, + } + service._apply_website_signals(inp, website_data) + + assert inp.has_direct_answer is True + assert inp.has_qa_headings is True + assert inp.has_organization is True + assert inp.has_product is True + assert inp.has_article is True + assert inp.has_faq is False + + def test_signals_merge_max_values(self): + service = DataCollectorService.__new__(DataCollectorService) + service._db = None + + inp = GEODiagnosisInput() + service._apply_ai_signals(inp, {"aor": 0.3, "accuracy": 0.7}) + service._apply_citation_signals(inp, {"aor": 0.5, "accuracy": 0.9}) + + assert inp.answer_ownership_rate == 0.5 + assert inp.citation_accuracy == 0.9 + + +class TestDataCollectorIntegration: + @pytest.mark.asyncio + async def test_collect_with_no_data_sources(self, async_session): + service = DataCollectorService(async_session) + + with patch.object( + service, "_collect_ai_platform_signals", new_callable=AsyncMock + ) as mock_ai, patch.object( + service, "_collect_citation_record_signals", new_callable=AsyncMock + ) as mock_cite, patch.object( + service, "_collect_website_signals", new_callable=AsyncMock + ) as mock_web: + mock_ai.return_value = { + "aor": 0.0, + "accuracy": 0.0, + 
"sov": 0.0, + "competitor_gap": 0.5, + "total_responses": 0, + "cited_count": 0, + "accurate_count": 0, + "metadata": {}, + } + mock_cite.return_value = { + "total_responses": 0, + "cited_count": 0, + "accurate_count": 0, + "aor": 0.0, + "accuracy": 0.0, + "sov": 0.0, + "competitor_gap": 0.0, + "metadata": {"records_found": 0}, + } + mock_web.return_value = {"metadata": {"skipped": True, "reason": "no_website"}} + + result = await service.collect(brand_name="UnknownBrand") + + assert isinstance(result, DataCollectionResult) + assert isinstance(result.diagnosis_input, GEODiagnosisInput) + assert result.diagnosis_input.has_industry_classification is False + + @pytest.mark.asyncio + async def test_collect_with_industry(self, async_session): + service = DataCollectorService(async_session) + + with patch.object( + service, "_collect_ai_platform_signals", new_callable=AsyncMock + ) as mock_ai, patch.object( + service, "_collect_citation_record_signals", new_callable=AsyncMock + ) as mock_cite, patch.object( + service, "_collect_website_signals", new_callable=AsyncMock + ) as mock_web: + mock_ai.return_value = { + "aor": 0.0, + "accuracy": 0.0, + "sov": 0.0, + "competitor_gap": 0.5, + "total_responses": 0, + "cited_count": 0, + "accurate_count": 0, + "metadata": {}, + } + mock_cite.return_value = { + "total_responses": 0, + "cited_count": 0, + "accurate_count": 0, + "aor": 0.0, + "accuracy": 0.0, + "sov": 0.0, + "competitor_gap": 0.0, + "metadata": {"records_found": 0}, + } + mock_web.return_value = {"metadata": {"skipped": True}} + + result = await service.collect( + brand_name="TestBrand", industry="technology" + ) + + assert result.diagnosis_input.has_industry_classification is True + + @pytest.mark.asyncio + async def test_collect_produces_nonzero_with_website_signals( + self, async_session + ): + service = DataCollectorService(async_session) + + with patch.object( + service, "_collect_ai_platform_signals", new_callable=AsyncMock + ) as mock_ai, patch.object( + 
service, "_collect_citation_record_signals", new_callable=AsyncMock + ) as mock_cite, patch.object( + service, "_collect_website_signals", new_callable=AsyncMock + ) as mock_web: + mock_ai.return_value = { + "aor": 0.2, + "accuracy": 0.6, + "sov": 0.1, + "competitor_gap": 0.3, + "total_responses": 5, + "cited_count": 1, + "accurate_count": 0, + "has_author_bio": True, + "author_credentials_complete": 0.5, + "has_data_sources": False, + "metadata": {}, + } + mock_cite.return_value = { + "total_responses": 0, + "cited_count": 0, + "accurate_count": 0, + "aor": 0.0, + "accuracy": 0.0, + "sov": 0.0, + "competitor_gap": 0.0, + "metadata": {"records_found": 0}, + } + mock_web.return_value = { + "has_direct_answer": True, + "has_qa_headings": True, + "has_structured_data": True, + "has_internal_links": True, + "has_freshness_info": True, + "has_brand_definition": True, + "has_target_audience": True, + "has_unique_value": True, + "has_organization": True, + "has_product": True, + "has_article": True, + "has_faq": False, + "has_howto": False, + "has_breadcrumb": False, + "metadata": {"url": "https://test.com"}, + } + + result = await service.collect( + brand_name="TestBrand", + website="https://test.com", + industry="technology", + ) + + from app.services.diagnosis.geo_diagnosis import GEODiagnosisService + + geo_service = GEODiagnosisService() + diagnosis = geo_service.diagnose(result.diagnosis_input) + + assert diagnosis.overall_score > 0 + assert len(diagnosis.dimensions) == 6 + + @pytest.mark.asyncio + async def test_collect_handles_channel_failure(self, async_session): + service = DataCollectorService(async_session) + + with patch.object( + service, "_collect_ai_platform_signals", new_callable=AsyncMock + ) as mock_ai, patch.object( + service, "_collect_citation_record_signals", new_callable=AsyncMock + ) as mock_cite, patch.object( + service, "_collect_website_signals", new_callable=AsyncMock + ) as mock_web: + mock_ai.side_effect = Exception("AI platform 
unavailable") + mock_cite.return_value = { + "total_responses": 0, + "cited_count": 0, + "accurate_count": 0, + "aor": 0.0, + "accuracy": 0.0, + "sov": 0.0, + "competitor_gap": 0.0, + "metadata": {"records_found": 0}, + } + mock_web.return_value = {"metadata": {"skipped": True}} + + result = await service.collect(brand_name="TestBrand") + + assert len(result.errors) >= 1 + assert any("ai_platform" in e for e in result.errors) diff --git a/backend/tests/test_services/test_detection_scheduler.py b/backend/tests/test_services/test_detection_scheduler.py index 6638cac..d56025b 100644 --- a/backend/tests/test_services/test_detection_scheduler.py +++ b/backend/tests/test_services/test_detection_scheduler.py @@ -132,7 +132,7 @@ class TestDetectionTaskModel: class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_create_task(self, async_session, test_brand, test_user): - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService service = DetectionSchedulerService() task_data = { @@ -153,7 +153,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_update_task(self, async_session, test_brand, test_user): from app.models.detection_task import DetectionTask - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService task = DetectionTask( brand_id=test_brand.id, @@ -182,7 +182,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_delete_task(self, async_session, test_brand, test_user): from app.models.detection_task import DetectionTask - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService task = DetectionTask( brand_id=test_brand.id, @@ -207,7 +207,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def 
test_get_tasks(self, async_session, test_brand, test_user): from app.models.detection_task import DetectionTask - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService for i in range(3): task = DetectionTask( @@ -228,7 +228,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_trigger_task(self, async_session, test_brand, test_user): from app.models.detection_task import DetectionTask - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService task = DetectionTask( brand_id=test_brand.id, @@ -252,7 +252,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_frequency_validation_hourly(self, async_session, test_brand, test_user): - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService service = DetectionSchedulerService() task_data = { @@ -267,7 +267,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_frequency_validation_daily(self, async_session, test_brand, test_user): - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService service = DetectionSchedulerService() task_data = { @@ -282,7 +282,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_frequency_validation_weekly(self, async_session, test_brand, test_user): - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService service = DetectionSchedulerService() task_data = { @@ -297,7 +297,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_frequency_validation_invalid(self, async_session, test_brand, 
test_user): - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService service = DetectionSchedulerService() task_data = { @@ -312,7 +312,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_execute_task_flow(self, async_session, test_brand, test_user): from app.models.detection_task import DetectionTask - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService task = DetectionTask( brand_id=test_brand.id, @@ -353,7 +353,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_delete_task_not_found(self, async_session, test_user): - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService service = DetectionSchedulerService() result = await service.delete_task(uuid.uuid4(), test_user.id, async_session) @@ -361,7 +361,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_update_task_not_found(self, async_session, test_user): - from app.services.detection_scheduler import DetectionSchedulerService, TaskNotFoundError + from app.services.detection.detection_scheduler import DetectionSchedulerService, TaskNotFoundError service = DetectionSchedulerService() with pytest.raises(TaskNotFoundError): @@ -369,7 +369,7 @@ class TestDetectionSchedulerService: @pytest.mark.asyncio async def test_get_tasks_empty(self, async_session, test_brand, test_user): - from app.services.detection_scheduler import DetectionSchedulerService + from app.services.detection.detection_scheduler import DetectionSchedulerService service = DetectionSchedulerService() tasks = await service.get_tasks(test_brand.id, test_user.id, async_session) diff --git a/backend/tests/test_services/test_geo_diagnosis.py 
b/backend/tests/test_services/test_geo_diagnosis.py index a513a4d..a3cb49d 100644 --- a/backend/tests/test_services/test_geo_diagnosis.py +++ b/backend/tests/test_services/test_geo_diagnosis.py @@ -4,7 +4,7 @@ GEO诊断服务单元测试 测试6大维度诊断逻辑、评分算法、推荐生成和服务类 """ import pytest -from app.services.geo_diagnosis import ( +from app.services.diagnosis.geo_diagnosis import ( GEODiagnosisService, GEODiagnosisInput, diagnose_content_extractability, diff --git a/backend/tests/test_services/test_imports.py b/backend/tests/test_services/test_imports.py new file mode 100644 index 0000000..f3ee6c5 --- /dev/null +++ b/backend/tests/test_services/test_imports.py @@ -0,0 +1,174 @@ +""" +Import verification tests for services reorganization. + +These tests verify that all import paths work correctly after +moving service files into subdirectories. +""" +import pytest + + +# ============================================================ +# New import path tests (should work after reorganization) +# ============================================================ + +class TestDiagnosisImports: + """Verify diagnosis subdirectory imports.""" + + def test_geo_diagnosis_import(self): + from app.services.diagnosis.geo_diagnosis import GEODiagnosisService + assert GEODiagnosisService is not None + + def test_geo_diagnosis_dataclasses(self): + from app.services.diagnosis.geo_diagnosis import ( + GEODiagnosisInput, + GEODiagnosisResult, + GEODimensionScore, + GEORecommendation, + DiagnosisItem, + ) + assert GEODiagnosisInput is not None + assert GEODiagnosisResult is not None + + def test_seo_diagnosis_import(self): + from app.services.diagnosis.seo_diagnosis import SEODiagnosisService + assert SEODiagnosisService is not None + + def test_seo_diagnosis_dataclasses(self): + from app.services.diagnosis.seo_diagnosis import ( + SEODiagnosisResult, + SEODimensionScore, + SEORecommendation, + DiagnosisStatus, + DimensionName, + ) + assert SEODiagnosisResult is not None + assert DiagnosisStatus is not None + + 
+class TestScoringImports: + """Verify scoring subdirectory imports.""" + + def test_scoring_service_import(self): + from app.services.scoring.scoring_service import ScoringService + assert ScoringService is not None + + def test_scoring_dataclasses(self): + from app.services.scoring.scoring_service import ( + ScoringResultV2, + DimensionScore, + ) + assert ScoringResultV2 is not None + assert DimensionScore is not None + + def test_get_health_level_reexport(self): + """scoring_service re-exports get_health_level from app.utils.health""" + from app.services.scoring.scoring_service import get_health_level + assert callable(get_health_level) + + +class TestAlertImports: + """Verify alert subdirectory imports.""" + + def test_alert_engine_import(self): + from app.services.alert.alert_engine import AlertEngine + assert AlertEngine is not None + + def test_alert_context_import(self): + from app.services.alert.alert_engine import AlertContext + assert AlertContext is not None + + +class TestCitationImports: + """Verify citation subdirectory imports.""" + + def test_citation_import(self): + from app.services.citation.citation import get_citations + assert callable(get_citations) + + def test_citation_stats_import(self): + from app.services.citation.citation import get_citation_stats + assert callable(get_citation_stats) + + def test_citation_pattern_import(self): + from app.services.citation.citation_pattern import CitationPatternEngine + assert CitationPatternEngine is not None + + def test_citation_pattern_dataclasses(self): + from app.services.citation.citation_pattern import ( + CitationPattern, + PatternAnalysisReport, + ) + assert CitationPattern is not None + assert PatternAnalysisReport is not None + + +class TestLLMImports: + """Verify llm subdirectory imports (new additions).""" + + def test_smart_router_import(self): + from app.services.llm.smart_router import SmartRouter + assert SmartRouter is not None + + def test_engine_selector_import(self): + from 
app.services.llm.engine_selector import EngineSelector + assert EngineSelector is not None + + def test_smart_router_dataclasses(self): + from app.services.llm.smart_router import ( + CostTier, + EngineCostProfile, + ENGINE_COST_PROFILES, + ) + assert CostTier is not None + assert ENGINE_COST_PROFILES is not None + + def test_llm_init_includes_new_exports(self): + """Verify llm/__init__.py includes SmartRouter and EngineSelector.""" + from app.services.llm import SmartRouter, EngineSelector + assert SmartRouter is not None + assert EngineSelector is not None + + +class TestDetectionImports: + """Verify detection subdirectory imports.""" + + def test_detection_scheduler_import(self): + from app.services.detection.detection_scheduler import DetectionSchedulerService + assert DetectionSchedulerService is not None + + def test_task_not_found_error_import(self): + from app.services.detection.detection_scheduler import TaskNotFoundError + assert TaskNotFoundError is not None + + +class TestAdvisorImports: + """Verify advisor subdirectory imports.""" + + def test_optimization_advisor_import(self): + from app.services.advisor.optimization_advisor import generate_suggestions + assert callable(generate_suggestions) + + def test_advisor_dataclasses(self): + from app.services.advisor.optimization_advisor import ( + SuggestionItem, + BrandAnalysisContext, + ) + assert SuggestionItem is not None + assert BrandAnalysisContext is not None + + +class TestAnalysisImports: + """Verify analysis subdirectory imports.""" + + def test_sentiment_service_import(self): + from app.services.analysis.sentiment_service import SentimentAnalysisService + assert SentimentAnalysisService is not None + + def test_sentiment_result_import(self): + from app.services.analysis.sentiment_service import SentimentResult + assert SentimentResult is not None + + def test_get_sentiment_service_import(self): + from app.services.analysis.sentiment_service import get_sentiment_service + assert 
callable(get_sentiment_service) + diff --git a/tests/test_knowledge_enhanced.py b/backend/tests/test_services/test_knowledge_enhanced.py similarity index 98% rename from tests/test_knowledge_enhanced.py rename to backend/tests/test_services/test_knowledge_enhanced.py index 0fc372e..27bd703 100644 --- a/tests/test_knowledge_enhanced.py +++ b/backend/tests/test_services/test_knowledge_enhanced.py @@ -292,14 +292,14 @@ class TestMarkdownParser: """测试解析带标题的Markdown""" parser = MarkdownParser() - content = b"""# 这是一个标题 + content = """# 这是一个标题 这是文档内容。 ## 子标题 更多内容。 -""" +""".encode("utf-8") doc = await parser.parse(content) @@ -312,8 +312,8 @@ class TestMarkdownParser: """测试解析不带标题的Markdown""" parser = MarkdownParser() - content = b"""这是文档内容,没有标题。 -""" + content = """这是文档内容,没有标题。 +""".encode("utf-8") doc = await parser.parse(content) @@ -329,10 +329,10 @@ class TestTextParser: """测试使用第一行作为标题""" parser = TextParser() - content = b"""这是第一行标题 + content = """这是第一行标题 这是第二行内容 这是第三行内容 -""" +""".encode("utf-8") doc = await parser.parse(content) diff --git a/tests/test_knowledge_graph.py b/backend/tests/test_services/test_knowledge_graph.py similarity index 100% rename from tests/test_knowledge_graph.py rename to backend/tests/test_services/test_knowledge_graph.py diff --git a/tests/test_llm_provider.py b/backend/tests/test_services/test_llm_provider.py similarity index 100% rename from tests/test_llm_provider.py rename to backend/tests/test_services/test_llm_provider.py diff --git a/tests/test_platform_rules.py b/backend/tests/test_services/test_platform_rules.py similarity index 100% rename from tests/test_platform_rules.py rename to backend/tests/test_services/test_platform_rules.py diff --git a/tests/test_rag_service.py b/backend/tests/test_services/test_rag_service.py similarity index 100% rename from tests/test_rag_service.py rename to backend/tests/test_services/test_rag_service.py diff --git a/backend/tests/test_services/test_scoring_service.py 
b/backend/tests/test_services/test_scoring_service.py index 2f2f8d5..09bd90a 100644 --- a/backend/tests/test_services/test_scoring_service.py +++ b/backend/tests/test_services/test_scoring_service.py @@ -1,5 +1,5 @@ import pytest -from app.services.scoring_service import ( +from app.services.scoring.scoring_service import ( ScoringService, calculate_mention_rate_score, calculate_sov_score, diff --git a/backend/tests/test_services/test_seo_diagnosis.py b/backend/tests/test_services/test_seo_diagnosis.py index c8ed525..08aed11 100644 --- a/backend/tests/test_services/test_seo_diagnosis.py +++ b/backend/tests/test_services/test_seo_diagnosis.py @@ -2,7 +2,7 @@ SEO诊断服务单元测试 """ import pytest -from app.services.seo_diagnosis import ( +from app.services.diagnosis.seo_diagnosis import ( SEODiagnosisService, SEODiagnosisResult, SEODimensionScore, diff --git a/backend/tests/test_services/test_smart_router_and_usage.py b/backend/tests/test_services/test_smart_router_and_usage.py index 1f8b994..961fa28 100644 --- a/backend/tests/test_services/test_smart_router_and_usage.py +++ b/backend/tests/test_services/test_smart_router_and_usage.py @@ -1,6 +1,6 @@ import pytest -from app.services.smart_router import ( +from app.services.llm.smart_router import ( ENGINE_COST_PROFILES, CostTier, EngineCostProfile, diff --git a/backend/tests/test_services/test_smart_router_key_integration.py b/backend/tests/test_services/test_smart_router_key_integration.py index 621abd5..05da280 100644 --- a/backend/tests/test_services/test_smart_router_key_integration.py +++ b/backend/tests/test_services/test_smart_router_key_integration.py @@ -1,8 +1,8 @@ import pytest from app.services.api_key_manager import APIKeyManager, KeySource -from app.services.smart_router import ENGINE_COST_PROFILES, CostTier, SmartRouter -from app.services.engine_selector import EngineSelector +from app.services.llm.smart_router import ENGINE_COST_PROFILES, CostTier, SmartRouter +from app.services.llm.engine_selector import 
EngineSelector class TestSmartRouterWithKeyManager: diff --git a/tests/test_topic_templates.py b/backend/tests/test_services/test_topic_templates.py similarity index 100% rename from tests/test_topic_templates.py rename to backend/tests/test_services/test_topic_templates.py diff --git a/backend/tests/test_utils/__init__.py b/backend/tests/test_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/test_utils/test_health.py b/backend/tests/test_utils/test_health.py new file mode 100644 index 0000000..de02fa0 --- /dev/null +++ b/backend/tests/test_utils/test_health.py @@ -0,0 +1,66 @@ +"""Tests for app.utils.health — get_health_level / get_health_level_label""" +import pytest + +from app.utils.health import get_health_level, get_health_level_label + + +# ── get_health_level ────────────────────────────────────────── + + +class TestGetHealthLevel: + """根据评分返回健康等级字符串""" + + def test_excellent_lower_bound(self): + assert get_health_level(80) == "excellent" + + def test_excellent_high(self): + assert get_health_level(100) == "excellent" + + def test_good_upper_bound(self): + assert get_health_level(79) == "good" + + def test_good_lower_bound(self): + assert get_health_level(60) == "good" + + def test_pass_upper_bound(self): + assert get_health_level(59) == "pass" + + def test_pass_lower_bound(self): + assert get_health_level(40) == "pass" + + def test_danger_upper_bound(self): + assert get_health_level(39) == "danger" + + def test_danger_zero(self): + assert get_health_level(0) == "danger" + + def test_negative_score(self): + assert get_health_level(-10) == "danger" + + def test_float_score(self): + assert get_health_level(80.5) == "excellent" + + +# ── get_health_level_label ──────────────────────────────────── + + +class TestGetHealthLevelLabel: + """根据等级返回中文标签""" + + @pytest.mark.parametrize( + "level, label", + [ + ("excellent", "优秀"), + ("good", "良好"), + ("pass", "及格"), + ("danger", "危险"), + ], + ) + def test_known_levels(self, 
level, label): + assert get_health_level_label(level) == label + + def test_unknown_level(self): + assert get_health_level_label("unknown") == "未知" + + def test_empty_string(self): + assert get_health_level_label("") == "未知" diff --git a/backend/tests/test_utils/test_json_extractor.py b/backend/tests/test_utils/test_json_extractor.py new file mode 100644 index 0000000..3fec83d --- /dev/null +++ b/backend/tests/test_utils/test_json_extractor.py @@ -0,0 +1,88 @@ +"""Tests for app.utils.json_extractor — extract_json""" +import json + +import pytest + +from app.utils.json_extractor import extract_json + + +class TestExtractJsonPlainObject: + """纯 JSON 对象""" + + def test_simple_object(self): + text = '{"key": "value"}' + result = extract_json(text) + assert json.loads(result) == {"key": "value"} + + def test_nested_object(self): + text = '{"a": {"b": [1, 2]}}' + result = extract_json(text) + assert json.loads(result) == {"a": {"b": [1, 2]}} + + +class TestExtractJsonMarkdownCodeBlock: + """JSON 包裹在 markdown 代码块中""" + + def test_json_code_block(self): + text = '```json\n{"key": "value"}\n```' + result = extract_json(text) + assert json.loads(result) == {"key": "value"} + + def test_plain_code_block(self): + text = '```\n{"key": "value"}\n```' + result = extract_json(text) + assert json.loads(result) == {"key": "value"} + + def test_code_block_with_surrounding_text(self): + text = 'Here is the result:\n```json\n{"key": "value"}\n```\nDone.' 
+ result = extract_json(text) + assert json.loads(result) == {"key": "value"} + + +class TestExtractJsonArray: + """JSON 数组""" + + def test_simple_array(self): + text = "[1, 2, 3]" + result = extract_json(text) + assert json.loads(result) == [1, 2, 3] + + def test_array_in_code_block(self): + text = "```json\n[1, 2, 3]\n```" + result = extract_json(text) + assert json.loads(result) == [1, 2, 3] + + def test_array_with_surrounding_text(self): + text = "Result: [1, 2, 3] done" + result = extract_json(text) + assert json.loads(result) == [1, 2, 3] + + +class TestExtractJsonWithSurroundingText: + """JSON 被周围文本包裹""" + + def test_object_with_surrounding_text(self): + text = 'Here is the result: {"key": "value"} done' + result = extract_json(text) + assert json.loads(result) == {"key": "value"} + + def test_nested_with_surrounding_text(self): + text = 'Output: {"a": {"b": [1, 2]}} end' + result = extract_json(text) + assert json.loads(result) == {"a": {"b": [1, 2]}} + + +class TestExtractJsonInvalidInput: + """无效输入应抛出 ValueError""" + + def test_no_json_at_all(self): + with pytest.raises(ValueError, match="无法从响应中提取JSON"): + extract_json("This is just plain text with no JSON") + + def test_empty_string(self): + with pytest.raises(ValueError, match="无法从响应中提取JSON"): + extract_json("") + + def test_unclosed_brace(self): + with pytest.raises(ValueError, match="无法从响应中提取JSON"): + extract_json('{"key": "value"') diff --git a/backend/tests/test_workers/test_llm_adapter.py b/backend/tests/test_workers/test_llm_adapter.py index 4c0deab..4dee5af 100644 --- a/backend/tests/test_workers/test_llm_adapter.py +++ b/backend/tests/test_workers/test_llm_adapter.py @@ -1,38 +1,38 @@ +"""品牌引用LLM服务测试 - 迁移自 test_llm_adapter.py + +原 LLMAdapter (System 3) 已被 BrandCitationLLMService (System 1) 替代。 +本文件保留了所有原有的测试场景,使用新的服务接口。 +""" import pytest from unittest.mock import AsyncMock, patch, MagicMock -from app.workers.llm_adapter import LLMAdapter, LLMAdapterError, BRAND_CITATION_PROMPT + +from 
app.services.llm.brand_citation_service import BrandCitationLLMService, BRAND_CITATION_PROMPT +from app.services.llm.base import LLMError -class TestLLMAdapter: - """LLM适配器测试""" +class TestBrandCitationLLMService: + """品牌引用LLM服务测试(原 LLMAdapter 测试迁移)""" @pytest.fixture - def llm_adapter(self): - """创建LLM适配器实例""" - return LLMAdapter() + def service(self): + """创建BrandCitationLLMService实例""" + return BrandCitationLLMService() @pytest.mark.asyncio - async def test_llm_adapter_cited_brand(self, llm_adapter): + async def test_cited_brand(self, service): """测试检测到品牌引用""" - mock_response = { - "cited": True, - "position": 1, - "citation_text": "XXX是一款非常优秀的品牌产品", - "sentiment": "positive", - "confidence": 0.95 - } - - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call: - mock_call.return_value = mock_response - - result = await llm_adapter.query_brand_citation( + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 1, "citation_text": "XXX是一款非常优秀的品牌产品", "sentiment": "positive", "confidence": 0.95}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( keyword="AI搜索", brand_name="XXX", brand_aliases=["品牌别名1", "品牌别名2"] ) - assert result.cited is True assert result.position == 1 assert result.citation_text == "XXX是一款非常优秀的品牌产品" @@ -40,135 +40,85 @@ class TestLLMAdapter: assert result.confidence == 0.95 @pytest.mark.asyncio - async def test_llm_adapter_not_cited(self, llm_adapter): + async def test_not_cited(self, service): """测试未检测到品牌引用""" - mock_response = { - "cited": False, - "position": None, - "citation_text": None, - "sentiment": "neutral", - "confidence": 0.90 - } - - with 
patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call: - mock_call.return_value = mock_response - - result = await llm_adapter.query_brand_citation( + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": false, "position": null, "citation_text": null, "sentiment": "neutral", "confidence": 0.90}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( keyword="AI搜索", brand_name="YYY", brand_aliases=[] ) - assert result.cited is False assert result.position is None assert result.citation_text is None assert result.sentiment == "neutral" @pytest.mark.asyncio - async def test_llm_adapter_sentiment_positive(self, llm_adapter): + async def test_sentiment_positive(self, service): """测试正面情感""" - mock_response = { - "cited": True, - "position": 2, - "citation_text": "YYY品牌产品质量非常好,用户口碑极佳", - "sentiment": "positive", - "confidence": 0.92 - } - - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call: - mock_call.return_value = mock_response - - result = await llm_adapter.query_brand_citation( + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 2, "citation_text": "YYY品牌产品质量非常好,用户口碑极佳", "sentiment": "positive", "confidence": 0.92}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( keyword="AI搜索", brand_name="YYY", brand_aliases=[] ) - assert 
result.sentiment == "positive" @pytest.mark.asyncio - async def test_llm_adapter_sentiment_negative(self, llm_adapter): + async def test_sentiment_negative(self, service): """测试负面情感""" - mock_response = { - "cited": True, - "position": 3, - "citation_text": "ZZZ品牌存在质量问题,遭到用户投诉", - "sentiment": "negative", - "confidence": 0.88 - } - - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call: - mock_call.return_value = mock_response - - result = await llm_adapter.query_brand_citation( + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 3, "citation_text": "ZZZ品牌存在质量问题,遭到用户投诉", "sentiment": "negative", "confidence": 0.88}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + result = await service.query_brand_citation( keyword="AI搜索", brand_name="ZZZ", brand_aliases=[] ) - assert result.sentiment == "negative" @pytest.mark.asyncio - async def test_llm_adapter_api_error_retry(self, llm_adapter): - """测试API错误重试""" - mock_success_response = { - "cited": True, - "position": 1, - "citation_text": "测试文本", - "sentiment": "neutral", - "confidence": 0.90 - } - - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call: - mock_call.side_effect = [ - Exception("API调用失败"), - Exception("API调用失败"), - mock_success_response - ] - - result = await llm_adapter.query_brand_citation( - keyword="AI搜索", - brand_name="测试品牌", - brand_aliases=[] - ) - - assert result.cited is True - assert mock_call.call_count == 3 - - @pytest.mark.asyncio - async def test_llm_adapter_parse_error(self, llm_adapter): + async def 
test_parse_error(self, service): """测试响应解析错误""" - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call: - mock_call.return_value = {"invalid": "response"} - - with pytest.raises(LLMAdapterError) as exc_info: - await llm_adapter.query_brand_citation( + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"invalid": "response"}' + ) + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True + with pytest.raises(LLMError) as exc_info: + await service.query_brand_citation( keyword="AI搜索", brand_name="测试品牌", brand_aliases=[] ) - error_msg = str(exc_info.value) assert "响应缺少必需字段" in error_msg or "解析响应失败" in error_msg - def test_build_prompt(self, llm_adapter): + def test_build_prompt(self, service): """测试Prompt构建""" - prompt = llm_adapter._build_prompt( + prompt = service._build_prompt( keyword="AI搜索", brand_name="测试品牌", brand_aliases=["别名1", "别名2"] ) - assert "AI搜索" in prompt assert "测试品牌" in prompt assert "别名1" in prompt diff --git a/backend/tests/test_workers/test_llm_adapter_no_mock.py b/backend/tests/test_workers/test_llm_adapter_no_mock.py index b596f56..82cdd65 100644 --- a/backend/tests/test_workers/test_llm_adapter_no_mock.py +++ b/backend/tests/test_workers/test_llm_adapter_no_mock.py @@ -1,25 +1,30 @@ +"""验证BrandCitationLLMService不再返回Mock数据,而是抛出明确错误 + +迁移自 test_llm_adapter_no_mock.py,原测试验证 LLMAdapter 不返回 Mock 数据。 +现在验证 BrandCitationLLMService (System 1) 的相同行为。 +""" import pytest -from unittest.mock import AsyncMock, patch, PropertyMock +from unittest.mock import AsyncMock, patch, MagicMock -from app.workers.llm_adapter import LLMAdapter, LLMAdapterError +from app.services.llm.brand_citation_service import BrandCitationLLMService +from app.services.llm.base import 
LLMError -class TestLLMAdapterNoMock: - """验证LLMAdapter不再返回Mock数据,而是抛出明确错误""" +class TestBrandCitationLLMServiceNoMock: + """验证BrandCitationLLMService不返回Mock数据,而是抛出明确错误""" @pytest.fixture - def adapter(self): - return LLMAdapter() + def service(self): + return BrandCitationLLMService() @pytest.mark.asyncio - async def test_enable_llm_false_raises_error(self, adapter): - """ENABLE_LLM=False时必须抛出LLMAdapterError,而非返回Mock数据""" - with patch("app.workers.llm_adapter.settings") as mock_settings: + async def test_enable_llm_false_raises_error(self, service): + """ENABLE_LLM=False时必须抛出LLMError,而非返回Mock数据""" + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: mock_settings.ENABLE_LLM = False - mock_settings.DEEPSEEK_API_KEY = "test-key" - with pytest.raises(LLMAdapterError) as exc_info: - await adapter.query_brand_citation( + with pytest.raises(LLMError) as exc_info: + await service.query_brand_citation( keyword="AI搜索", brand_name="测试品牌", brand_aliases=["别名1"], @@ -30,45 +35,18 @@ class TestLLMAdapterNoMock: assert "未启用" in error_msg @pytest.mark.asyncio - async def test_enable_llm_true_no_api_key_raises_error(self, adapter): - """ENABLE_LLM=True但无API Key时必须抛出LLMAdapterError""" - adapter.api_key = None + async def test_enable_llm_true_calls_provider(self, service): + """ENABLE_LLM=True时正常调用Provider""" + mock_provider = AsyncMock() + mock_provider.chat.return_value = MagicMock( + content='{"cited": true, "position": 1, "citation_text": "测试引用", "sentiment": "positive", "confidence": 0.95}' + ) - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True + with patch.object(service, '_get_provider', return_value=mock_provider): + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: + mock_settings.ENABLE_LLM = True - with pytest.raises(LLMAdapterError) as exc_info: - await adapter.query_brand_citation( - keyword="AI搜索", - brand_name="测试品牌", - brand_aliases=[], - ) - - error_msg = 
str(exc_info.value) - assert "API Key" in error_msg or "DEEPSEEK_API_KEY" in error_msg - - @pytest.mark.asyncio - async def test_enable_llm_true_with_key_calls_api(self, adapter): - """ENABLE_LLM=True且有Key时正常调用API""" - mock_response = { - "cited": True, - "position": 1, - "citation_text": "测试引用", - "sentiment": "positive", - "confidence": 0.95, - } - - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - mock_settings.OPENAI_API_KEY = None - mock_settings.DEEPSEEK_API_KEY = "sk-test-key" - - with patch.object( - adapter, "_call_deepseek", new_callable=AsyncMock - ) as mock_call: - mock_call.return_value = mock_response - - result = await adapter.query_brand_citation( + result = await service.query_brand_citation( keyword="AI搜索", brand_name="测试品牌", brand_aliases=[], @@ -79,43 +57,23 @@ class TestLLMAdapterNoMock: assert result.sentiment == "positive" def test_get_mock_result_method_removed(self): - """_get_mock_result方法必须已被删除""" - assert not hasattr(LLMAdapter, "_get_mock_result"), ( + """_get_mock_result方法必须已被删除(旧LLMAdapter遗留)""" + assert not hasattr(BrandCitationLLMService, "_get_mock_result"), ( "_get_mock_result方法仍然存在,必须删除" ) @pytest.mark.asyncio - async def test_error_message_user_friendly(self, adapter): + async def test_error_message_user_friendly(self, service): """错误信息必须对用户友好,包含配置指引""" - with patch("app.workers.llm_adapter.settings") as mock_settings: + with patch('app.services.llm.brand_citation_service.settings') as mock_settings: mock_settings.ENABLE_LLM = False - mock_settings.DEEPSEEK_API_KEY = "test-key" - with pytest.raises(LLMAdapterError) as exc_info: - await adapter.query_brand_citation( + with pytest.raises(LLMError) as exc_info: + await service.query_brand_citation( keyword="AI搜索", brand_name="测试品牌", brand_aliases=[], ) error_msg = str(exc_info.value) - assert "ENABLE_LLM=True" in error_msg - assert "DEEPSEEK_API_KEY" in error_msg - - @pytest.mark.asyncio - async def 
test_no_api_key_error_message_user_friendly(self, adapter): - """无API Key时错误信息必须包含配置指引""" - adapter.api_key = None - - with patch("app.workers.llm_adapter.settings") as mock_settings: - mock_settings.ENABLE_LLM = True - - with pytest.raises(LLMAdapterError) as exc_info: - await adapter.query_brand_citation( - keyword="AI搜索", - brand_name="测试品牌", - brand_aliases=[], - ) - - error_msg = str(exc_info.value) - assert "DEEPSEEK_API_KEY" in error_msg + assert "ENABLE_LLM" in error_msg diff --git a/docs/01-项目概览/architecture.md b/docs/01-项目概览/architecture.md index 25d0dc0..379a31c 100644 --- a/docs/01-项目概览/architecture.md +++ b/docs/01-项目概览/architecture.md @@ -70,6 +70,9 @@ │ ├───────────┬───────────┬───────────┬────────────┤ │ │ │ Citation │ Content │ DeAI │ GEO │ │ │ │ Detector │ Generator │ Agent │ Optimizer │ │ +│ ├───────────┼───────────┼───────────┼────────────┤ │ +│ │ Monitor │ Schema │Competitor│ Trend │ │ +│ │ Agent │ Advisor │ Analyzer │ Agent │ │ │ └───────────┴───────────┴───────────┴────────────┘ │ │ │ │ ┌─────────────────────────────────────────────────┐ │ @@ -89,11 +92,47 @@ | ContentGenerator | 内容生成 | 主题、规则库、品牌素材 | GEO优化内容 | | DeAIAgent | 去AI化 | AI生成内容 | 自然化内容 | | GEOOptimizer | GEO优化 | 原始内容、关键词策略 | 优化后内容 | +| MonitorAgent | 效果追踪 | 品牌ID、监测配置 | 监测记录、趋势数据 | +| SchemaAdvisor | Schema建议 | 品牌内容、行业类型 | Schema标记建议 | +| CompetitorAnalyzer | 竞品分析 | 品牌、竞品列表 | 竞品洞察报告 | +| TrendAgent | 趋势洞察 | 行业、关键词 | 趋势洞察数据 | **注意:** 当前项目中的`SEOOptimizer`实际执行的是GEO优化(内容结构化、实体优化),而非传统SEO优化(技术SEO)。建议在后续版本中明确区分: - **SEOOptimizer** → 技术SEO优化(网站技术层面) - **GEOOptimizer** → 内容实体优化(AI引用层面) +## GEO 业务闭环 + +``` +诊断 → 策略 → 方案 → 内容生成 → 效果追踪 + ↑ │ + └────────────────────────────────────┘ +``` + +**闭环流程说明:** + +| 阶段 | 描述 | 核心Agent/服务 | +|------|------|----------------| +| 诊断 | 品牌GEO现状评估,识别优化机会 | CitationDetector, GEOOptimizer | +| 策略 | 基于诊断结果制定GEO优化策略 | StrategyService, CompetitorAnalyzer | +| 方案 | 生成可执行的GEO优化方案 | StrategyService, SchemaAdvisor | +| 内容生成 | 按方案生成GEO优化内容 | ContentGenerator, DeAIAgent | 
+| 效果追踪 | 持续监测优化效果,驱动下一轮迭代 | MonitorAgent, TrendAgent | + +## 品牌评分体系 + +品牌评分V2采用5维度评分模型,全面衡量品牌在AI搜索引擎中的表现: + +| 维度 | 字段名 | 描述 | 评分范围 | +|------|--------|------|----------| +| 提及率 | mention_rate | 品牌在AI响应中的被提及频率 | 0-100 | +| 推荐排名 | recommendation_rank | 品牌在AI推荐中的排名位置 | 0-100 | +| 情感评分 | sentiment_score | AI对品牌的情感倾向评分 | 0-100 | +| 引用质量 | citation_quality | 品牌被引用内容的质量与权威性 | 0-100 | +| 竞争位置 | competitive_position | 品牌相对竞品的综合竞争位势 | 0-100 | + +**综合评分计算:** 综合品牌评分 = 加权平均(mention_rate, recommendation_rank, sentiment_score, citation_quality, competitive_position) + ## 内容生成Pipeline ``` @@ -118,4 +157,4 @@ API请求 → Prometheus指标 → Grafana可视化 ## 数据库设计 -核心表:users, organizations, brands, competitors, queries, citations, alerts, contents, knowledge_bases, knowledge_entities, knowledge_relations +核心表:users, organizations, brands, competitors, queries, citations, alerts, contents, knowledge_bases, knowledge_entities, knowledge_relations, geo_plans, geo_plan_actions, monitoring_records, content_baselines, schema_suggestions, competitor_insights, trend_insights diff --git a/docs/01-项目概览/changelog.md b/docs/01-项目概览/changelog.md index 4169335..b5adb21 100644 --- a/docs/01-项目概览/changelog.md +++ b/docs/01-项目概览/changelog.md @@ -1,5 +1,39 @@ # 变更日志 +## v2.1.0 (2026-05-31) + +### 新增功能 +- 4个新Agent: MonitorAgent, SchemaAdvisor, CompetitorAnalyzer, TrendAgent +- GEO方案自动生成系统 +- 竞品引导提示组件 +- 品牌评分V2体系 (mention_rate, recommendation_rank, sentiment_score, citation_quality, competitive_position) + +### 架构优化 +- BrandScoringDataService共享服务提取 +- N+1查询消除 +- Repository模式数据访问层完善 + +### Bug修复 +- 前端API参数不匹配修复 +- agentsApi死代码清理 +- CORS配置修正 +- 前端环境变量端口修正 + +### 新增API +- /api/v1/strategy +- /api/v1/monitoring +- /api/v1/competitor +- /api/v1/schema +- /api/v1/trends + +### 新增前端API客户端 +- monitoring.ts +- competitor-analysis.ts +- schema-advisor.ts +- trends.ts + +--- + ## v2.0.0 (当前版本) ### 新增功能 diff --git a/docs/01-项目概览/tech-stack.md b/docs/01-项目概览/tech-stack.md index d5cceda..c2158a9 100644 --- 
a/docs/01-项目概览/tech-stack.md +++ b/docs/01-项目概览/tech-stack.md @@ -52,13 +52,59 @@ geo/ │ │ ├── agent_framework/ # Agent框架 │ │ ├── models/ # 数据模型 │ │ ├── schemas/ # Pydantic模型 -│ │ ├── services/ # 业务逻辑 -│ │ ├── workers/ # 异步任务 -│ │ └── monitoring/ # 监控模块 +│ │ ├── services/ # 业务逻辑 +│ │ │ ├── ai_engine/ # AI引擎服务 +│ │ │ ├── llm/ # LLM服务 +│ │ │ ├── knowledge/ # 知识库服务 +│ │ │ ├── content/ # 内容服务 +│ │ │ ├── distribution/ # 分发服务 +│ │ │ ├── diagnosis/ # 诊断服务 +│ │ │ ├── citation/ # 引用服务 +│ │ │ ├── competitor/ # 竞品服务 +│ │ │ ├── monitoring/ # 监测服务 +│ │ │ ├── trend/ # 趋势服务 +│ │ │ ├── schema/ # Schema服务 +│ │ │ ├── scoring/ # 评分服务 +│ │ │ ├── strategy/ # 策略服务 +│ │ │ ├── alert/ # 告警服务 +│ │ │ ├── advisor/ # 顾问服务 +│ │ │ ├── analysis/ # 分析服务 +│ │ │ ├── detection/ # 检测服务 +│ │ │ └── analytics/ # 分析统计服务 +│ │ ├── repositories/ # 数据访问层 │ └── requirements.txt ├── frontend/ # Next.js 前端 │ ├── app/ # 页面 │ ├── components/ # 组件 -│ └── lib/ # 工具函数 +│ └── lib/ +│ ├── api/ # API客户端 +│ │ ├── analytics.ts +│ │ ├── api-keys.ts +│ │ ├── auth.ts +│ │ ├── brands.ts +│ │ ├── citations.ts +│ │ ├── clients.ts +│ │ ├── competitor-analysis.ts +│ │ ├── contents.ts +│ │ ├── dashboard.ts +│ │ ├── detection.ts +│ │ ├── diagnosis.ts +│ │ ├── distribution.ts +│ │ ├── health.ts +│ │ ├── image.ts +│ │ ├── knowledge.ts +│ │ ├── lifecycle.ts +│ │ ├── monitoring.ts +│ │ ├── onboarding.ts +│ │ ├── organization.ts +│ │ ├── platform-rules.ts +│ │ ├── queries.ts +│ │ ├── reports.ts +│ │ ├── schema-advisor.ts +│ │ ├── strategy.ts +│ │ ├── suggestions.ts +│ │ ├── trends.ts +│ │ └── usage.ts +│ └── ... 
# 其他工具函数 └── docs/ # 文档 ``` diff --git a/docs/02-模块说明/agent-framework.md b/docs/02-模块说明/agent-framework.md index 07fc8c3..f258e90 100644 --- a/docs/02-模块说明/agent-framework.md +++ b/docs/02-模块说明/agent-framework.md @@ -98,6 +98,65 @@ GEO平台的AI Agent框架是系统的核心智能层,采用模块化、可插 | 输入 | 原始内容、关键词策略 | | 输出 | 优化后的内容 | +### MonitorAgent (效果追踪Agent) + +| 属性 | 说明 | +|------|------| +| 文件 | `backend/app/agent_framework/agents/monitor_agent.py` | +| 职责 | 定期检测品牌在AI平台的引用变化,生成变化报告 | +| 输入 | 品牌名称、AI平台列表、检测周期 | +| 输出 | 引用变化报告(新增引用、丢失引用、变化趋势) | + +**报告类型**: +- `daily_report` - 每日引用变化快报 +- `weekly_report` - 周度引用趋势报告 +- `alert` - 异常变化告警(引用大幅下降等) + +### SchemaAdvisorAgent (Schema建议Agent) + +| 属性 | 说明 | +|------|------| +| 文件 | `backend/app/agent_framework/agents/schema_advisor_agent.py` | +| 职责 | 基于品牌诊断数据生成JSON-LD结构化数据建议 | +| 输入 | 品牌诊断数据、当前Schema配置、行业类型 | +| 输出 | JSON-LD结构化数据建议方案 | + +**建议类型**: +- `schema_add` - 新增Schema建议 +- `schema_modify` - 现有Schema修改建议 +- `schema_validate` - Schema合规性验证结果 + +### CompetitorAnalyzerAgent (竞品分析Agent) + +| 属性 | 说明 | +|------|------| +| 文件 | `backend/app/agent_framework/agents/competitor_analyzer_agent.py` | +| 职责 | 5维度分析竞品(引用量、引用质量、引用场景、内容策略、机会发现) | +| 输入 | 品牌名称、竞品列表、分析维度 | +| 输出 | 竞品分析报告 | + +**分析维度**: +- `citation_volume` - 引用量对比 +- `citation_quality` - 引用质量评估 +- `citation_scenario` - 引用场景分布 +- `content_strategy` - 内容策略分析 +- `opportunity_discovery` - 机会发现与建议 + +### TrendAgent (趋势洞察Agent) + +| 属性 | 说明 | +|------|------| +| 文件 | `backend/app/agent_framework/agents/trend_agent.py` | +| 职责 | 识别关键词和平台的引用趋势(上升/下降/平稳/热点) | +| 输入 | 关键词列表、AI平台列表、时间范围 | +| 输出 | 趋势分析报告 | + +**趋势类型**: +- `rising` - 上升趋势 +- `declining` - 下降趋势 +- `stable` - 平稳趋势 +- `hotspot` - 热点趋势 + ## 通信协议 基于数据库+Redis Queue的异步通信: diff --git a/docs/02-模块说明/agent-protocol.md b/docs/02-模块说明/agent-protocol.md index 924fa66..2960758 100644 --- a/docs/02-模块说明/agent-protocol.md +++ b/docs/02-模块说明/agent-protocol.md @@ -90,6 +90,9 @@ class AgentStatus(str, Enum): | RULE_CHECKER | 规则检查Agent | 内容合规审核 | | 
COMPETITOR_ANALYZER | 竞品分析Agent | 竞品数据收集分析 | | PERFORMANCE_TRACKER | 性能追踪Agent | 追踪内容表现 | +| SCHEMA_ADVISOR | Schema建议Agent | 生成JSON-LD结构化数据建议 | +| MONITOR_AGENT | 效果追踪Agent | 检测品牌引用变化,生成变化报告 | +| TREND_AGENT | 趋势洞察Agent | 识别关键词和平台的引用趋势 | ## 任务类型 diff --git a/docs/04-API文档/README.md b/docs/04-API文档/README.md index 6b7b682..d1c7e65 100644 --- a/docs/04-API文档/README.md +++ b/docs/04-API文档/README.md @@ -1,6 +1,6 @@ # API文档 -本目录包含GEO平台的API接口文档。 +本目录包含GEO平台的API接口文档,共33个API模块。 ## 文档内容 @@ -8,8 +8,33 @@ - [品牌API](./brands.md) - 品牌管理接口 - [内容API](./content.md) - 内容生成接口 - [知识库API](./knowledge.md) - 知识库接口 -- [诊断API](./diagnosis.md) - SEO/GEO诊断接口 ✅ 新增 +- [诊断API](./diagnosis.md) - SEO/GEO诊断接口 - [健康检查](./health.md) - 系统健康检查接口 +- [监测优化API](./analytics.md) - 监测数据分析接口 +- [告警通知API](./alerts.md) - 告警通知管理接口 +- [仪表盘API](./dashboard.md) - 仪表盘数据接口 +- [查询词API](./queries.md) - 查询词管理接口 +- [引用数据API](./citations.md) - 引用数据管理接口 +- [报告API](./reports.md) - 报告生成接口 +- [Agent管理API](./agents.md) - Agent管理接口 +- [生命周期API](./lifecycle.md) - 生命周期管理接口 +- [内容管理API](./contents.md) - 内容管理接口 +- [客户管理API](./clients.md) - 客户管理接口 +- [内容分发API](./distribution.md) - 内容分发接口 +- [优化建议API](./suggestions.md) - 优化建议接口 +- [引导流程API](./onboarding.md) - 引导流程接口 +- [平台规则API](./platform-rules.md) - 平台规则管理接口 +- [图片生成API](./image.md) - 图片生成接口 +- [组织管理API](./organization.md) - 组织管理接口 +- [检测任务API](./detection.md) - 检测任务接口 +- [GEO方案API](./strategy.md) - GEO方案生成接口 +- [AI引擎查询API](./ai-engines.md) - AI引擎查询接口 +- [效果追踪API](./monitoring.md) - 效果追踪监测接口 +- [竞品分析API](./competitor.md) - 竞品分析接口 +- [Schema建议API](./schema.md) - Schema标记建议接口 +- [趋势洞察API](./trends.md) - 趋势洞察接口 +- [API Key管理API](./api-keys.md) - API Key管理接口 +- [用量追踪API](./usage.md) - 用量追踪接口 ## 诊断模块API diff --git a/docs/05-部署运维/docker.md b/docs/05-部署运维/docker.md index fa19d3d..4750b1e 100644 --- a/docs/05-部署运维/docker.md +++ b/docs/05-部署运维/docker.md @@ -23,8 +23,12 @@ cp .env.example .env # 编辑.env配置必要参数 ``` +> **说明**:Docker Compose 默认读取项目根目录下的 `.env` 文件加载环境变量。所有配置项均可通过 `.env` 
文件统一管理,无需在 `docker-compose.yml` 中硬编码。 + ### 3. 启动服务 +> **注意**:开发模式使用 `uvicorn --reload` 启动,支持热重载;生产环境必须移除 `--reload` 参数,使用 `uvicorn app.main:app --host 0.0.0.0 --port 8000` 启动。 + ```bash # 开发环境 docker-compose up -d @@ -56,10 +60,24 @@ docker-compose logs -f ## 数据持久化 - volumes: - - postgres_data:/var/lib/postgresql/data - - redis_data:/data - - ./uploads:/app/uploads +通过Volume挂载确保容器重建后数据不丢失: + +```yaml +volumes: + - postgres_data:/var/lib/postgresql/data + - redis_data:/data + - ./uploads:/app/uploads + - ./logs:/app/logs + - ./backups:/app/backups +``` + +| 挂载路径 | 说明 | +|----------|------| +| `postgres_data` | PostgreSQL数据目录 | +| `redis_data` | Redis持久化数据 | +| `./uploads` | 用户上传文件 | +| `./logs` | 应用日志 | +| `./backups` | 数据库备份 | ## 健康检查 diff --git a/docs/05-部署运维/environment.md b/docs/05-部署运维/environment.md index 6cd9517..884764a 100644 --- a/docs/05-部署运维/environment.md +++ b/docs/05-部署运维/environment.md @@ -32,6 +32,8 @@ | `SECRET_KEY` | 应用密钥 | `your-secret-key` | | `CORS_ORIGINS` | 允许的跨域源 | `http://localhost:3000` | +> **注意**:生产环境必须配置具体域名,不能使用通配符`*` + ### 认证 | 变量 | 说明 | 示例 | @@ -60,6 +62,29 @@ | `QDRANT_HOST` | Qdrant地址 | `localhost` | | `QDRANT_PORT` | Qdrant端口 | `6333` | +### AI平台API密钥 + +| 变量 | 说明 | 示例 | +|------|------|------| +| `CHATGPT_API_KEY` | ChatGPT API密钥 | `sk-...` | +| `DEEPSEEK_API_KEY` | DeepSeek API密钥 | `sk-...` | +| `QWEN_API_KEY` | 通义千问API密钥 | `sk-...` | +| `KIMI_API_KEY` | Kimi API密钥 | `sk-...` | +| `DOUBAO_API_KEY` | 豆包API密钥 | `sk-...` | +| `GEMINI_API_KEY` | Gemini API密钥 | `AI...` | +| `PERPLEXITY_API_KEY` | Perplexity API密钥 | `pplx-...` | +| `YUANBAO_API_KEY` | 元宝API密钥 | `sk-...` | +| `WENXIN_API_KEY` | 文心一言API密钥 | `sk-...` | +| `ZHIPU_API_KEY` | 智谱AI API密钥 | `...` | +| `TONGYI_API_KEY` | 通义API密钥 | `sk-...` | + +### 功能开关 + +| 变量 | 说明 | 默认值 | +|------|------|------| +| `ENABLE_LLM` | 启用LLM增强功能 | `true` | +| `ENABLE_AGENT_FRAMEWORK` | 启用Agent框架 | `true` | + ## 开发环境示例 ```env diff --git
a/docs/brainstorms/2026-05-31-geo-next-phase-core-flow-repair-requirements.md b/docs/brainstorms/2026-05-31-geo-next-phase-core-flow-repair-requirements.md new file mode 100644 index 0000000..4f4ec1c --- /dev/null +++ b/docs/brainstorms/2026-05-31-geo-next-phase-core-flow-repair-requirements.md @@ -0,0 +1,204 @@ +--- +date: "2026-05-31" +topic: "geo-next-phase-core-flow-repair" +--- + +## Summary + +GEO 平台下一阶段的核心目标是实施变现闭环(诊断修复→免费获客→付费转化→执行闭环→效果归因),同时用 API 契约测试驱动核心功能修复,辅以少量 E2E 烟雾测试验证关键用户路径。原有的全面质量保障计划(E2E 覆盖、性能基线、安全扫描)降级为上线后工作。 + +--- + +## Problem Frame + +GEO 平台尚未上线,存在三个致命问题阻断变现闭环: + +1. **诊断系统形同虚设** — `backend/app/api/diagnosis.py` 第 75-77 行用 `GEODiagnosisInput()` 空参调用,所有字段默认 False/0,永远返回 0 分。产品核心价值为零。 +2. **内容 Pipeline 只格式化不生成** — `backend/app/services/content/content_pipeline.py` 只有 4 个阶段(RuleValidator→SensitiveFilter→SEOOptimizer→HTMLGenerator),没有 AI 内容生成阶段,核心付费功能缺失。 +3. **支付是模拟的** — `backend/app/services/subscription.py` 中 `payment_method="模拟支付"`,没有真实支付网关和功能限制中间件,用户无付费动力。 + +此外,Onboarding 缺少付费墙触发点、分发没有实际发布集成、监控没有归因逻辑、邮件服务未接入业务系统——各模块存在但互不连通。 + +与此同时,当前的质量保障计划(11 个 IU)将大量投入放在 E2E 测试、性能基线、安全扫描上——这些对未上线且核心功能断裂的产品来说 ROI 不高。核心矛盾:**测试体系是为稳定产品服务的,但产品本身还不稳定。** + +--- + +## Key Decisions + +**变现闭环优先于全面测试体系** — 产品核心价值为零(诊断永远 0 分),此时投入 E2E 测试、性能基线、安全扫描无法产生实际价值。先修通变现闭环,让产品可用且有收入,再建设全面质量保障。 + +**测试角色从"质量保障"转变为"行为契约"** — 不再是"写测试防止回归",而是"用测试定义正确行为,驱动修复"。为变现闭环的每个新/修改 API 编写契约测试,定义期望行为,驱动实现修复。 + +**API 契约测试为主,E2E 烟雾测试为辅** — API 测试反馈快(秒级)、定位精确、维护成本低;E2E 只覆盖 1-2 个最关键路径作为安全网。原 QA 计划的 3 个完整 E2E 套件(U3/U4/U5)缩减为 1-2 个烟雾测试。 + +**获客优先于数据护城河** — 行业基准数据是长期壁垒,但在没有验证付费意愿之前投入风险过高。先用免费 GEO 健康分验证需求,再积累数据资产。 + +**半自动执行优先于全自动** — 品牌方不会信任 AI 完全自主发布内容。执行闭环采用"AI 生成→人工审核→确认发布"模式。 + +**中国 AI 平台生态为差异化核心** — 海外 SEO/GEO 工具无法覆盖文心、Kimi、通义、豆包等中国 AI 平台,这是天然护城河。 + +**性能基线、安全扫描、全面 E2E 推迟到上线后** — 这些投入在有用户之前无法产生实际价值。上线后根据真实数据和用户反馈决定优先级。 + +--- + +## Requirements + +### 变现闭环核心修复 + +R1. 修复诊断系统:实现自动数据采集(AI 平台查询+网站爬取+CitationRecord 分析),让 `GEODiagnosisService.diagnose()` 接受真实数据输入,产出非零有差异的评分 + +R2. 
添加 AI 内容生成阶段:在 ContentPipeline 现有阶段之前添加 Stage 0(AI Generator),基于诊断结果+RAG 知识库+LLM 生成初稿,后续阶段不变 + +R3. 接入真实支付:微信支付+支付宝双通道,支付 Webhook 处理订阅状态更新,功能限制中间件按套餐控制使用量 + +### 获客与转化 + +R4. 建设免费 GEO 健康分公开页面:输入品牌名→30 秒出报告(综合评分+3 核心维度+3 竞品对比),无需注册,结果缓存 24 小时 + +R5. 重设计 Onboarding 流程:第一步为"查看 GEO 健康分"而非填表,在查看详细建议/执行优化/查看归因报告时嵌入升级提示 + +R6. 执行闭环:内容分发集成微信(半自动)/知乎(API)/头条(API),发布流程"AI 生成→人工审核→确认发布" + +R7. 效果归因系统:追踪已发布内容的 AI 引用变化,2-4 周归因窗口,ROI 计算+A/B 对比报告 + +### API 契约测试 + +R8. 为诊断 API 编写契约测试:POST /api/v1/diagnosis 触发异步诊断,GET /api/v1/diagnosis/{id}/result 返回非零评分 + +R9. 为公开健康分 API 编写契约测试:GET /api/v1/public/health-score?brand=XXX 无需认证,返回综合评分+3 维度+竞品对比 + +R10. 为内容生成 API 编写契约测试:POST /api/v1/content/generate 基于诊断+RAG 生成内容 + +R11. 为支付 API 编写契约测试:创建订单→支付回调→订阅激活的完整链路 + +R12. 编写跨步骤集成测试:品牌创建→诊断→方案→内容生成→效果追踪的完整链路 + +### E2E 烟雾测试 + +R13. 编写 1 个获客路径 E2E 烟雾测试:访问健康分页面→输入品牌名→看到报告→点击注册 + +R14. 编写 1 个核心流程 E2E 烟雾测试:登录→创建品牌→触发诊断→查看诊断结果 + +### 测试基础设施(延续) + +R15. 完成后端测试目录统一验证(原 QA 计划 U1) + +R16. 建立共享 fixture 体系(原 QA 计划 U2):API 契约测试完成后迁移到共享 fixture + +--- + +## Key Flows + +- F1. 变现闭环主流程 + - **Trigger:** 品牌市场人员访问平台 + - **Actors:** 品牌市场人员, 平台系统 + - **Steps:** (1) 输入品牌名查看免费 GEO 健康分 (2) 看到震撼的评分和竞品对比 (3) 注册查看详细建议 (4) 升级付费版执行优化 (5) AI 生成优化内容 (6) 人工审核确认发布 (7) 2-4 周后查看效果归因报告 (8) 续费 + - **Outcome:** 用户完成从"不知道 GEO"到"付费执行优化"到"续费"的完整闭环 + +- F2. API 契约驱动修复流程 + - **Trigger:** 开始修复某个断裂环节 + - **Actors:** 开发者 + - **Steps:** (1) 编写该环节的 API 契约测试,定义期望的输入/输出 (2) 运行测试,确认失败(红色) (3) 修复后端实现 (4) 运行测试,确认通过(绿色) (5) 编写下一个环节的契约测试 (6) 重复直到全链路通过 + - **Outcome:** 核心流程每步都有契约保护,断裂点被精确定位和修复 + +--- + +## Acceptance Examples + +- AE1. **诊断系统修复** — Covers R1, R8 + - **Given:** 品牌已创建 + - **When:** 调用诊断 API 并传入品牌 ID + - **Then:** 返回非零评分,且不同品牌评分有差异;免费版返回 3 核心维度,付费版返回全部 6 维度 + +- AE2. **免费健康分获客** — Covers R4, R9 + - **Given:** 未注册用户访问健康分页面 + - **When:** 输入品牌名 + - **Then:** 30 秒内生成包含综合评分、3 维度、3 竞品对比的报告;24 小时内重复查询返回缓存 + +- AE3. 
**AI 内容生成** — Covers R2, R10 + - **Given:** 付费用户有诊断结果和知识库 + - **When:** 触发内容生成 + - **Then:** 基于诊断问题+RAG 上下文生成针对性优化内容,内容通过 Pipeline 校验 + +- AE4. **支付闭环** — Covers R3, R11 + - **Given:** 用户选择升级套餐 + - **When:** 完成支付 + - **Then:** 订阅状态更新为 active,功能限制中间件解锁对应功能 + +- AE5. **E2E 烟雾测试** — Covers R13, R14 + - **Given:** 前后端服务运行 + - **When:** 执行 Playwright E2E 烟雾测试 + - **Then:** 获客路径和核心流程路径通过 + +--- + +## Success Criteria + +- 诊断系统产出真实非零评分,不同品牌有差异 +- 免费 GEO 健康分页面可公开访问,30 秒内出报告 +- AI 内容生成基于诊断结果产出可用内容 +- 支付闭环可走通(创建订单→支付→激活→功能解锁) +- 核心流程 API 契约测试全部通过 +- 2 个 E2E 烟雾测试通过 + +--- + +## Scope Boundaries + +**Deferred for later:** + +- 全面 E2E 测试覆盖(原 QA 计划 U3/U4/U5 → 上线后逐步扩展) +- 性能基线测试(原 QA 计划 U7 → 上线后有真实负载后建立) +- CI 安全扫描(原 QA 计划 U8 → 上线后集成) +- Alembic 迁移验证(原 QA 计划 U8 → 上线后添加) +- 前端组件测试(原 QA 计划 U10 → 上线后扩展) +- 行业 GEO 基准数据积累 +- API 市场开放和第三方集成 +- 白标报告和专属客户成功经理 +- 多语言和国际 AI 平台支持 + +**Outside this product's identity:** + +- 传统 SEO 工具功能(关键词排名、外链分析等) +- 广告投放和营销自动化 +- 社交媒体管理 +- 基础设施级别的渗透测试 + +--- + +## Dependencies / Assumptions + +- 变现闭环实施计划已存在:`docs/plans/2026-05-31-003-feat-geo-monetization-closed-loop-plan.md` +- 诊断自动数据采集依赖现有 AI 平台适配器(`backend/app/services/ai_engine/`) +- AI 内容生成依赖 LLM API Key 可用 +- 支付集成需要微信支付和支付宝商户号 +- 知乎/头条 API 发布需要 OAuth 授权和 API Key +- API 契约测试使用 httpx AsyncClient + 内存 SQLite,不依赖外部服务 +- E2E 烟雾测试需要前后端服务同时运行 +- 假设品牌方看到 GEO 健康分后会意识到问题并产生付费意愿(未验证,需 MVP 验证) + +--- + +## Outstanding Questions + +**Deferred to Planning:** + +- 效果归因的具体技术实现方案(基于引用检测还是基于评分变化) +- 付费版定价是否需要调整(当前 ¥199/599/1999 是否匹配价值感知) +- 微信公众号 OAuth 授权流程的具体实现方案 +- 知乎和头条 API 的申请和审核周期 +- 诊断自动数据采集的网站爬取可能遇到反爬机制的降级方案 + +--- + +## Sources / Research + +- 变现闭环需求文档:`docs/brainstorms/2026-05-31-geo-platform-monetization-closed-loop-requirements.md` +- 变现闭环实施计划:`docs/plans/2026-05-31-003-feat-geo-monetization-closed-loop-plan.md` +- 质量保障需求文档:`docs/brainstorms/2026-05-31-geo-platform-next-phase-quality-requirements.md` +- 质量保障实施计划:`docs/plans/2026-05-31-002-test-quality-assurance-system-plan.md` +- 
诊断系统根因:`backend/app/api/diagnosis.py` 第 75-77 行(空参调用) +- 内容 Pipeline:`backend/app/services/content/content_pipeline.py`(4 阶段,无 AI 生成) +- 订阅服务:`backend/app/services/subscription.py`(模拟支付) +- 现有后端测试:`backend/tests/` (~90 个文件) +- 现有 E2E 测试:`frontend/e2e/tests/` (7 个文件) diff --git a/docs/brainstorms/2026-05-31-geo-platform-acceptance-audit-requirements.md b/docs/brainstorms/2026-05-31-geo-platform-acceptance-audit-requirements.md new file mode 100644 index 0000000..8e78178 --- /dev/null +++ b/docs/brainstorms/2026-05-31-geo-platform-acceptance-audit-requirements.md @@ -0,0 +1,184 @@ +--- +date: "2026-05-31" +topic: "geo-platform-acceptance-audit" +--- + +## Summary + +GEO 平台全项目验收审计方案,定义功能完整性、安全性、性能、代码质量四个维度的验收标准与检查清单,包含已发现 42 项问题的分级修复计划,以及四类项目文档的更新规范,确保平台达到可交付质量水平。 + +## Problem Frame + +GEO 平台经过多轮迭代开发,已实现 33 个 API 路由模块、8 个 Agent、35 个数据模型、25+ 前端页面。但快速迭代带来了技术债务:4 个新 Agent 的前端 API 客户端尚未创建、部分 API 存在认证绕过风险、N+1 查询影响性能、API 文档覆盖率仅 15%。在交付前需要系统性的验收审计,识别并修复阻塞问题,同步更新文档使项目状态与代码一致。 + +--- + +## Key Decisions + +**审计维度覆盖全部四项而非选择性覆盖** — 功能完整性、安全性、性能、代码质量相互关联,部分安全漏洞隐藏在功能实现中,性能问题影响功能可用性,选择性覆盖会遗漏跨维度问题。 + +**Critical/High 问题作为验收阻塞项** — 20 个 Critical + High 问题必须在验收通过前修复,Medium/Low 问题可记录为技术债务在后续迭代处理。 + +**文档更新纳入验收流程** — 文档与代码不同步是当前最大运维风险(API 文档覆盖率仅 15%),文档更新不是附加任务而是验收的必要条件。 + +**自审计而非第三方审计** — 本方案定位为交付前自审计,确保基本质量门槛;渗透测试等深度安全审计作为后续独立项目规划。 + +--- + +## Requirements + +### 功能完整性 + +R1. 所有 33 个 API 路由模块必须有对应的前端 API 客户端模块,且接口签名与后端一致 + +R2. 8 个 Agent 均可通过 standalone 模式独立启动并执行其声明的 supported_tasks + +R3. 用户注册→登录→Onboarding→品牌创建→诊断→方案生成→内容生成→效果追踪的完整业务闭环可端到端走通 + +R4. 4 个新 Agent(MonitorAgent、SchemaAdvisor、CompetitorAnalyzer、TrendAgent)的前端 API 客户端模块必须创建,覆盖其全部 API 端点 + +R5. 前端所有页面的 API 调用参数(字段名、类型、顺序)与后端路由定义一致,无运行时 400/500 错误 + +R6. GEO 方案自动生成流程(诊断→策略→方案→行动)完整可用,竞品为可选输入 + +R7. 品牌评分(5 维度 V2 评分体系)在品牌详情页和仪表盘均可正确展示 + +### 安全性 + +R8. 所有业务 API 端点(除 `/health`、`/ready`、`/auth/login`、`/auth/register` 外)必须要求认证,不存在认证绕过路径 + +R9. analytics 路由的认证检查必须与业务路由一致,不能因中间件顺序或装饰器缺失而跳过 + +R10. 
API Key 通过加密存储(key_encryption.py),密钥不在日志、响应体、前端代码中明文暴露 + +R11. CORS 配置在生产环境必须限制 `allow_origins` 为具体域名,不能使用通配符 `*` + +R12. 用户输入必须经过 Pydantic Schema 校验,不存在未校验的请求参数直接进入数据库查询 + +R13. SQL 查询使用 SQLAlchemy ORM 参数化查询,不存在字符串拼接 SQL + +### 性能 + +R14. 品牌评分数据获取逻辑提取为共享服务(BrandScoringDataService),消除 dashboard.py 和 strategy.py 中的重复查询 + +R15. 消除 N+1 查询模式:品牌列表页、仪表盘、策略页等高频访问路径不得出现循环内单条查询 + +R16. AI 引擎批量查询使用 `queryBatch` 端点而非逐条查询 + +R17. Redis 缓存层对热点数据(品牌评分、仪表盘统计)生效,缓存命中率作为可观测指标 + +### 代码质量 + +R18. 前端 `agentsApi.enable`/`agentsApi.disable` 死代码必须移除或替换为实际可用的实现 + +R19. 前端 `onboardingApi.getCompetitorRecommendations` 参数签名必须与后端一致 + +R20. `models/__init__.py` 中标注为"重构后遗留"的模型导入必须验证其与数据库迁移的一致性 + +R21. `SEOOptimizer` 命名与实际功能(GEO 优化)不一致的问题必须在代码或文档中明确标注 + +R22. `content.py`(内容生产)与 `contents.py`(内容管理)的命名混淆必须在 API 文档中明确区分 + +R23. 前端 `lib/api/index.ts` 聚合导出必须覆盖所有 API 客户端模块,消除遗漏 + +--- + +## Key Flows + +- F1. 验收审计执行流程 + - **Trigger:** 审计方案文档确认后启动 + - **Actors:** 开发者、审计执行者 + - **Steps:** (1) 按维度执行检查清单 (2) 记录每项的通过/不通过状态 (3) Critical/High 不通过项进入修复流程 (4) 修复后回归验证 (5) 全部阻塞项通过后进入文档更新阶段 (6) 文档更新完成后签发验收报告 + - **Outcome:** 验收通过/不通过判定 + 问题修复记录 + 更新后的文档 + +- F2. 问题修复流程 + - **Trigger:** 审计检查发现 Critical 或 High 问题 + - **Actors:** 开发者 + - **Steps:** (1) 问题分级确认 (2) 按优先级排序(Critical > High > Medium > Low)(3) 修复实现 (4) 回归验证(确认修复未引入新问题)(5) 更新审计检查清单状态 + - **Outcome:** 问题状态从"不通过"变为"通过" + +- F3. 文档更新流程 + - **Trigger:** 代码修复完成后 + - **Actors:** 开发者 + - **Steps:** (1) 更新项目文档(README、架构图、API 文档)(2) 更新 AI 上下文文件(CLAUDE.md、AGENTS.md)(3) 更新设计文档(模块说明、流程图)(4) 更新部署文档(环境变量、Docker 配置)(5) 交叉验证文档与代码一致性 + - **Outcome:** 四类文档与代码实现完全一致 + +--- + +## Acceptance Examples + +- AE1. **认证绕过验证** — Covers R8, R9 + - **Given:** 未认证的 HTTP 请求 + - **When:** 请求 `/api/v1/analytics` 端点 + - **Then:** 返回 401 Unauthorized,不返回业务数据 + +- AE2. **新 Agent 前端可用性** — Covers R4 + - **Given:** 前端 API 客户端模块 `monitoring.ts`、`competitor-analysis.ts`、`schema-advisor.ts`、`trends.ts` 已创建 + - **When:** 前端页面调用这些模块的方法 + - **Then:** 请求正确到达后端对应端点,参数类型和顺序匹配,无 400/500 错误 + +- AE3. 
**N+1 查询消除** — Covers R14, R15 + - **Given:** 仪表盘页面加载 + - **When:** 后端处理 `/api/v1/dashboard` 请求 + - **Then:** 品牌评分数据通过共享服务一次查询获取,SQL 日志中不出现循环内单条查询模式 + +- AE4. **CORS 生产配置** — Covers R11 + - **Given:** 生产环境部署 + - **When:** CORS 中间件初始化 + - **Then:** `allow_origins` 不包含通配符 `*`,仅允许配置的域名 + +- AE5. **业务闭环端到端** — Covers R3 + - **Given:** 新注册用户完成 Onboarding + - **When:** 依次执行诊断→查看诊断报告→生成 GEO 方案→执行内容生成→查看效果追踪 + - **Then:** 每一步均可成功完成,数据在步骤间正确传递 + +--- + +## Success Criteria + +- 全部 6 个 Critical 问题修复并通过回归验证 +- 全部 14 个 High 问题修复并通过回归验证 +- API 文档覆盖率从 15% 提升至 80% 以上(至少覆盖 27/33 个路由模块) +- AI 上下文文件(CLAUDE.md / AGENTS.md)创建完成,准确反映项目当前架构 +- 四类文档更新完成且与代码实现一致 +- 端到端业务闭环无阻塞错误 + +--- + +## Scope Boundaries + +**Deferred for later:** + +- 自动化 E2E 测试框架搭建 +- 性能压测与负载测试方案 +- 第三方安全审计(渗透测试) +- CI/CD 流水线完善 +- 前端单元测试和集成测试覆盖 + +**Outside this product's identity:** + +- 基础设施级别的安全审计(网络层、K8s 配置) +- 用户体验走查和可用性测试 +- 多语言/国际化验证 + +--- + +## Dependencies / Assumptions + +- 后端服务可正常启动(PostgreSQL + Redis 可用) +- 前端开发服务器可正常启动 +- 至少一个 LLM Provider 的 API Key 可用(用于验证 Agent 功能) +- 当前审计基于代码静态分析 + 架构审查,不包含运行时动态分析 +- 数据库迁移状态与 models 定义一致(需在验收时验证) + +--- + +## Sources / Research + +- 审计扫描结果:Critical 6、High 14、Medium 14、Low 8 共 42 项问题 +- 项目架构文档:`docs/01-项目概览/architecture.md` +- Agent 框架协议:`docs/02-模块说明/agent-protocol.md` +- GEO 工作流分析:`docs/02-模块说明/geo-workflow-analysis.md` +- 新 Agent 实现计划:`docs/plans/2026-05-28-001-feat-geo-platform-new-agents-plan.md` +- 环境变量模板:`.env.example` +- Docker 部署配置:`docker-compose.yml` diff --git a/docs/brainstorms/2026-05-31-geo-platform-monetization-closed-loop-requirements.md b/docs/brainstorms/2026-05-31-geo-platform-monetization-closed-loop-requirements.md new file mode 100644 index 0000000..43dc50b --- /dev/null +++ b/docs/brainstorms/2026-05-31-geo-platform-monetization-closed-loop-requirements.md @@ -0,0 +1,127 @@ +--- +date: 2026-05-31 +topic: geo-platform-monetization-closed-loop +--- + +## Summary + +将现有GEO平台改造为以「免费GEO健康分」为获客入口、以「诊断→建议→执行→效果归因」为付费闭环的商业模式,同时聚焦中国AI平台生态构建差异化壁垒。 + +## 
Problem Frame + +品牌方市场团队目前完全没有关注过自己在AI搜索引擎中的表现。当ChatGPT、Kimi、DeepSeek等AI平台回答用户提问时,品牌是否被引用、引用是否正面、排名如何——这些直接影响品牌认知的数据,品牌方一无所知。现有的GEO平台已经具备了6维度诊断、内容生成、知识库等能力,但缺少三个关键环节:让目标客户意识到问题的获客机制、从诊断到效果验证的执行闭环、以及证明优化有效的效果归因。没有这三者,平台无法完成商业变现闭环。 + +## Key Decisions + +**获客优先于数据护城河。** 行业基准数据(方案C)是长期壁垒,但在没有验证付费意愿之前投入,风险过高。先用免费GEO健康分验证需求,再积累数据资产。 + +**半自动执行优先于全自动。** 品牌方不会信任AI完全自主发布内容。执行闭环采用"AI生成→人工审核→确认发布"模式,而非全自动发布。 + +**中国AI平台生态为差异化核心。** 海外SEO工具(Ahrefs/Semrush)和GEO工具(Profound/peecAI)无法覆盖文心、Kimi、通义、豆包等中国AI平台,这是天然护城河。 + +**免费诊断的"震撼值"驱动转化。** 免费GEO健康分必须足够震撼(竞品对比、行业排名、具体问题),才能让不知道GEO的品牌方产生紧迫感和付费意愿。 + +**免费诊断采用频率+深度双限制控制成本。** 免费版限制诊断频率(每个品牌每天1次,详细报告每7天1次)和诊断深度(免费版只看3个核心维度和3个竞品,付费版看全部6维度和更多竞品),在保持震撼值的同时控制API调用成本。 + +**三平台同步推进内容分发集成。** 微信公众号、知乎、头条三个平台同步推进发布集成,微信公众号采用半自动模式(生成内容+复制粘贴引导),知乎和头条利用较开放的API实现更自动化的发布。 + +## Actors + +- A1. **品牌市场人员** — 核心用户,需要了解品牌在AI搜索中的表现并采取行动 +- A2. **品牌决策者** — 查看GEO健康分报告后决定是否付费的管理层 +- A3. **平台系统** — 自动执行诊断、监控、归因等后台任务 + +## Requirements + +### 获客层:GEO健康分驱动增长 + +- R1. 提供无需注册的即时GEO健康分诊断,输入品牌名即可在30秒内生成可视化报告 +- R2. 免费报告包含:综合GEO评分、3个核心维度得分、与3个竞品的对比、最严重的3个问题概要;频率限制为每个品牌每天1次,详细报告每7天1次 +- R2b. 付费版解锁全部6维度诊断、更多竞品对比、无频率限制 +- R3. 详细修复建议和执行方案需注册后查看,注册免费 +- R4. 执行修复建议(内容生成、知识库搭建等)需升级付费版 +- R5. 诊断报告支持生成可分享链接和PDF,便于品牌市场人员向决策者汇报 + +### 转化层:从认知到付费 + +- R6. 重设计Onboarding流程,第一步为"查看你的GEO健康分"而非填表 +- R7. 注册后自动触发GEO变化周报邮件订阅,作为持续触达和转化手段 +- R8. 付费版升级触发点嵌入在以下场景:查看详细建议时、执行优化操作时、查看效果归因报告时 +- R9. 提供GEO评分提升承诺和效果保障机制,降低付费决策门槛 + +### 执行闭环层:诊断→建议→执行→效果归因 + +- R10. 打通内容分发到实际发布的链路,同步支持微信公众号(半自动:生成内容+复制粘贴引导)、知乎(API发布)、头条(API发布)三个平台 +- R11. 发布流程采用"AI生成→人工审核→确认发布"的半自动模式 +- R12. 建设效果归因系统,追踪已发布内容是否被AI搜索引擎引用,计算GEO ROI +- R13. 提供优化前后的A/B对比报告,量化GEO优化的实际效果 +- R14. 设定合理的归因时间窗口(建议2-4周),并在此期间持续监控变化趋势 + +### 差异化层:中国AI平台生态壁垒 + +- R15. 持续覆盖中国主流AI平台(文心、Kimi、通义、豆包、元宝、清言、天工等),确保海外工具无法替代 +- R16. 针对中国AI平台的引用偏好建立规则库,持续更新各平台的引用模式变化 +- R17. 知识库针对GEO场景深度优化,自动从品牌官网、产品文档中提取AI搜索引擎偏好的信息结构 + +## Key Flows + +- F1. 
免费获客转化流 + - **Trigger:** 品牌市场人员访问平台首页 + - **Actors:** A1, A3 + - **Steps:** 输入品牌名 → 系统即时生成GEO健康分报告 → 展示综合评分和竞品对比 → 提示"查看详细修复建议"需注册 → 注册后查看建议概要 → 执行建议需升级付费版 + - **Outcome:** 用户完成从"不知道GEO"到"付费执行优化"的转化 + +- F2. 执行闭环流 + - **Trigger:** 用户查看GEO优化建议并决定执行 + - **Actors:** A1, A3 + - **Steps:** 选择优化建议 → AI生成优化内容 → 人工审核确认 → 确认发布到目标平台 → 系统开始追踪该内容的AI引用情况 → 2-4周后生成效果归因报告 + - **Outcome:** 用户看到GEO优化的量化效果,形成续费动力 + +- F3. 效果归因流 + - **Trigger:** 内容发布后系统自动启动追踪 + - **Actors:** A3 + - **Steps:** 记录发布时的基线数据 → 定期检查AI搜索引擎是否引用该内容 → 计算引用变化和GEO评分变化 → 生成归因报告(优化前vs优化后) → 通知用户查看效果 + - **Outcome:** 用户获得GEO ROI的量化证明 + +## Acceptance Examples + +- AE1. **Covers R1, R2, R5.** Given 一个未注册用户访问平台, When 输入品牌名"某某科技", Then 30秒内生成包含综合评分、3个核心维度得分、与3个竞品的对比和可分享链接的GEO健康分报告 +- AE2. **Covers R3, R4.** Given 一个已注册免费用户查看GEO健康分报告, When 点击"查看详细修复建议", Then 显示建议概要;When 点击"执行优化", Then 提示需升级付费版 +- AE3. **Covers R10, R11, R12.** Given 一个付费用户选择执行某条GEO优化建议, When AI生成优化内容后, Then 进入人工审核环节;When 用户确认发布, Then 内容发布到目标平台且系统开始追踪AI引用变化 +- AE4. **Covers R13, R14.** Given 一条优化内容已发布3周, When 系统检测到该内容被Kimi引用, Then 生成归因报告显示优化前后的GEO评分变化和引用数据对比 + +## Success Criteria + +- 免费GEO健康分报告的完成率(输入品牌名到查看报告)> 60% +- 从免费报告到注册的转化率 > 15% +- 从注册到付费的转化率 > 5% +- 付费用户中能看到效果归因报告的比例 > 70% +- 付费用户月续费率 > 85% + +## Scope Boundaries + +**Deferred for later:** +- 行业GEO基准数据积累和行业白皮书发布 +- API市场开放和第三方集成 +- 白标报告和专属客户成功经理功能 +- 多语言和国际AI平台支持 + +**Outside this product's identity:** +- 传统SEO工具功能(关键词排名、外链分析等) +- 广告投放和营销自动化 +- 社交媒体管理 + +## Dependencies / Assumptions + +- 假设品牌方看到GEO健康分后会意识到问题并产生付费意愿(未验证,需通过MVP验证) +- 假设GEO优化效果可在2-4周内被观测到(需验证,不同AI平台更新周期不同) +- 假设中国主流内容平台(微信公众号、知乎、头条)的API开放程度足以支持半自动发布(需验证) +- 依赖现有6维度GEO诊断系统的准确性和稳定性 +- 依赖现有8个Agent框架的可扩展性 + +## Outstanding Questions + +**Deferred to Planning:** +- 效果归因的具体技术实现方案(基于引用检测还是基于评分变化) +- 付费版定价是否需要调整(当前¥199/599/1999是否匹配价值感知) +- GEO变化周报的邮件发送基础设施选型 diff --git a/docs/brainstorms/2026-05-31-geo-platform-next-phase-quality-requirements.md b/docs/brainstorms/2026-05-31-geo-platform-next-phase-quality-requirements.md new file mode 100644
index 0000000..0826d79 --- /dev/null +++ b/docs/brainstorms/2026-05-31-geo-platform-next-phase-quality-requirements.md @@ -0,0 +1,167 @@ +--- +date: "2026-05-31" +topic: "geo-platform-next-phase-quality" +--- + +## Summary + +GEO 平台下一阶段质量保障体系,采用双轨并行策略:轨道一统一测试基础设施(合并分裂的测试目录、完善 CI 流水线、建立共享 fixture 体系),轨道二同步推进核心业务 E2E 测试和四项专项测试(性能基准、安全扫描、数据库迁移验证、Agent 端到端测试),确保核心业务流程尽早获得回归保护。 + +## Problem Frame + +GEO 平台经过多轮迭代已具备 33 个 API、8 个 Agent、25+ 前端页面,但测试覆盖严重不足:后端两套测试目录分裂导致 17 个测试被 CI 忽略,前端仅 13 个单元测试且零组件测试,E2E 仅覆盖登录流程,核心业务(品牌创建→诊断→方案→内容→监控)完全无回归保护。CI 流水线不运行 E2E 测试,无安全扫描,无性能基线。在用户开始使用前,必须建立质量保障体系防止功能回归。 + +--- + +## Key Decisions + +**双轨并行而非分层递进** — 基础设施统一和 E2E 测试编写同步推进,因为两者几乎不涉及相同文件,并行没有实质冲突。E2E 测试初期用独立 setup,后续迁移到共享 fixture 是低成本重构。如果等基础设施全部就绪再写 E2E,核心流程的回归保护会延迟数周。 + +**E2E 测试仅覆盖 Chromium** — 先在一个浏览器上稳定运行,跨浏览器扩展作为后续迭代。Playwright 已配置 3 浏览器但当前 E2E 用例太少,扩展浏览器覆盖的 ROI 不高。 + +**性能测试先建基线再设阈值** — 在没有历史数据时设定 SLA 容易过严或过松。第一轮只采集数据建立基线,第二轮根据基线设定阈值和告警。 + +**安全扫描集成到 CI 而非独立流程** — bandit 和 npm audit 作为 CI 步骤运行,在 PR 级别就暴露问题,而不是等到发布前才扫描。 + +--- + +## Requirements + +### 轨道一:测试基础设施 + +R1. 将 `geo/tests/` 下的 17 个测试迁移至 `backend/tests/` 对应子目录,合并两套 conftest.py 的 fixture,删除旧目录 + +R2. CI 中 `pytest tests/` 命令能发现并运行全部后端测试(迁移后应 ≥77 个) + +R3. 将 `pytest-cov` 正式加入 `backend/requirements.txt`,不再依赖 CI 中临时安装 + +R4. 建立共享 fixture 体系:数据库会话(含自动 rollback)、认证 mock(JWT token + auth headers)、测试用户创建、httpx AsyncClient + +R5. 前端 vitest 覆盖范围扩展至 `lib/api/` 全部 27 个模块和关键页面组件 + +R6. 更新 `docs/03-开发指南/testing.md` 使其与实际目录结构、CI 配置、fixture 体系一致 + +### 轨道二:核心业务 E2E 测试 + +R7. 编写用户注册→登录→Onboarding→品牌创建的完整 E2E 测试 + +R8. 编写诊断→查看诊断报告→生成 GEO 方案的 E2E 测试 + +R9. 编写内容生成→查看内容→效果追踪的 E2E 测试 + +R10. E2E 测试在 CI 中运行(至少 Chromium),使用 PostgreSQL 和 Redis service container + +R11. E2E 测试失败时自动截图和录制视频,便于排查 + +### 专项测试 + +R12. 为 5-10 个高频 API 端点建立性能基线(p50/p95/p99 响应时间),首轮只采集不设阈值 + +R13. CI 中集成 bandit(Python 安全扫描)和 npm audit(Node.js 依赖安全检查),PR 级别阻断高危漏洞 + +R14. CI 中添加 Alembic 迁移验证步骤:`alembic upgrade head` 在空数据库上成功执行 + +R15. 
编写 Agent 框架端到端测试:任务创建→分发→执行→结果查询的完整链路(至少覆盖 CitationDetector 和 ContentGenerator) + +### 前端组件测试 + +R16. 为 5 个关键页面组件编写 React Testing Library 测试:Dashboard、品牌详情、诊断页、策略页、内容编辑器 + +R17. 为 4 个新 Agent 的前端 API 客户端模块编写单元测试:monitoring.ts、competitor-analysis.ts、schema-advisor.ts、trends.ts + +--- + +## Key Flows + +- F1. 双轨并行执行流程 + - **Trigger:** 需求文档确认后启动 + - **Actors:** 开发者 + - **Steps:** (1) 轨道一:统一测试目录→合并 fixture→修复 CI→扩展 vitest 覆盖 (2) 轨道二:编写核心 E2E→集成 CI E2E 步骤→编写专项测试 (3) 两条轨道完成后汇合:E2E 测试迁移到共享 fixture (4) 更新文档 + - **Outcome:** 完整的质量保障体系,CI 全量运行,核心流程有回归保护 + +- F2. E2E 测试编写流程 + - **Trigger:** 开始编写某个业务流程的 E2E 测试 + - **Actors:** 开发者 + - **Steps:** (1) 定义用户旅程步骤 (2) 编写 Playwright 测试用独立 setup (3) 本地验证通过 (4) 提交并在 CI 中验证 (5) 后续迁移到共享 fixture + - **Outcome:** 可在 CI 中稳定运行的 E2E 测试 + +--- + +## Acceptance Examples + +- AE1. **测试目录统一** — Covers R1, R2 + - **Given:** `geo/tests/` 下有 17 个测试文件 + - **When:** 迁移完成并运行 `cd backend && pytest tests/` + - **Then:** 全部测试被发现且通过,`geo/tests/` 目录已删除 + +- AE2. **核心业务 E2E** — Covers R7, R8, R9 + - **Given:** 后端服务 + PostgreSQL + Redis 运行 + - **When:** 执行 Playwright E2E 测试 + - **Then:** 注册→登录→Onboarding→品牌创建→诊断→方案→内容生成→效果追踪 全流程通过 + +- AE3. **CI 安全扫描** — Covers R13 + - **Given:** PR 中引入了有已知漏洞的依赖 + - **When:** CI 运行 PR Check + - **Then:** bandit 或 npm audit 报告高危漏洞,PR 被标记为检查失败 + +- AE4. **性能基线** — Covers R12 + - **Given:** 性能测试首次运行 + - **When:** 对 `/api/v1/dashboard/stats` 发送 100 次请求 + - **Then:** 采集 p50/p95/p99 响应时间并保存为基线数据,不设阈值不阻断 + +- AE5. 
**迁移验证** — Covers R14 + - **Given:** 新增了 Alembic 迁移脚本 + - **When:** CI 在空数据库上执行 `alembic upgrade head` + - **Then:** 迁移成功执行,无报错 + +--- + +## Success Criteria + +- 后端全部测试(≥77 个)在 CI 中通过,无目录分裂 +- 核心业务 E2E 测试(≥3 条用户旅程)在 CI 中稳定运行 +- 前端 vitest 覆盖率从当前极低水平提升至 lib/api/ 模块 80%+ +- CI 流水线包含:lint + 单元测试 + E2E 测试 + 安全扫描 + 迁移验证 +- 5-10 个高频 API 端点有性能基线数据 +- Agent 框架端到端测试覆盖至少 2 个 Agent +- testing.md 与实际项目结构一致 + +--- + +## Scope Boundaries + +**Deferred for later:** + +- 跨浏览器 E2E 测试(Firefox / WebKit) +- 覆盖率报告上传第三方服务(Codecov / Coveralls) +- 测试数据工厂(factory-boy / faker)— 当前用 fixture 足够 +- PR 评论中显示覆盖率变化 +- 生产环境监控告警集成 + +**Outside this product's identity:** + +- 新业务功能开发 +- UI/UX 改进和设计优化 +- 基础设施级别的渗透测试 +- 移动端适配测试 + +--- + +## Dependencies / Assumptions + +- CI 环境(GitHub Actions)支持 PostgreSQL 和 Redis service container +- E2E 测试需要后端服务可启动(所有依赖可用) +- Agent E2E 测试需要至少一个 LLM Provider 的 API Key 可用(或使用 mock) +- 性能基线数据需要在相对稳定的环境下采集,避免 CI 共享 runner 的噪声 +- 当前 `geo/tests/conftest.py` 中的 fixture(async_client、auth_token 等)可正确迁移 + +--- + +## Sources / Research + +- 现有测试配置:`backend/pyproject.toml`、`frontend/vitest.config.ts`、`frontend/playwright.config.ts` +- CI 配置:`.github/workflows/ci.yml`、`.github/workflows/pr-check.yml` +- 测试策略文档:`docs/03-开发指南/testing.md` +- 现有 E2E 测试:`frontend/e2e/tests/`(7 个文件) +- 现有后端测试:`backend/tests/`(~60 个)+ `geo/tests/`(17 个) +- 现有前端测试:`frontend/__tests__/`(13 个) diff --git a/docs/plans/2026-05-28-001-feat-geo-platform-new-agents-plan.md b/docs/plans/2026-05-28-001-feat-geo-platform-new-agents-plan.md new file mode 100644 index 0000000..7ec47f4 --- /dev/null +++ b/docs/plans/2026-05-28-001-feat-geo-platform-new-agents-plan.md @@ -0,0 +1,384 @@ +# GEO 平台新增 Agent 实现计划 + +**Status:** Draft +**Type:** feat +**Created:** 2026-05-28 +**Updated:** 2026-05-28 + +## Summary + +为 GEO 平台新增 4 个专业 Agent,构建完整的监测-分析-优化闭环: + +1. **MonitorAgent** — 效果追踪:定期检测内容发布后的引用变化 +2. **SchemaAdvisor** — Schema优化:根据诊断结果推荐具体的 Schema 标记方案 +3. **CompetitorAnalyzer** — 竞品分析:深度分析竞品内容策略、引用模式 +4. 
**TrendAgent** — 趋势洞察:分析行业 AI 搜索趋势、热点话题 + +--- + +## Problem Frame + +当前 GEO 平台已有 4 个 Agent(CitationDetector、ContentGenerator、DeAIAgent、GEOOptimizer),覆盖检测→生成→去AI化→优化链路。但缺少: + +- **效果追踪**:内容发布后无法自动追踪引用变化 +- **Schema 优化**:诊断发现 Schema 问题但没有 Agent 提供具体方案 +- **竞品深度分析**:仅做引用检测,缺乏策略层面的竞品分析 +- **趋势洞察**:无法感知行业 AI 搜索趋势和热点话题 + +这 4 个 Agent 的缺失导致 GEO 平台只能发现问题,无法主动建议行动方向。 + +--- + +## Scope Boundaries + +### In Scope +- 4 个新 Agent 的完整实现 +- Agent 注册和调度集成 +- 相应的数据模型和 API +- 前端展示(可选,取决于 API 完备性) + +### Out of Scope +- 不实现前端页面(只提供 API,供现有页面调用) +- 不实现自动执行闭环(内容发布后的自动追踪需要分发系统配合) +- TrendAgent 的第三方数据源集成(仅使用现有 AI 引擎返回的引用数据) + +--- + +## Key Technical Decisions + +### KTD-1: MonitorAgent 触发方式 + +**Decision:** 采用「定时任务 + 手动触发」双模式 + +**Rationale:** 自动追踪需要内容分发系统配合,目前分发系统仅支持手动发布。定时任务模式可覆盖已有内容的手动追踪需求。手动触发作为 API 入口,可被前端页面直接调用。 + +**Alternatives considered:** +- 仅定时任务:无法满足用户即时查询需求 +- 仅手动触发:无法实现自动化效果追踪 + +### KTD-2: SchemaAdvisor 生成策略 + +**Decision:** 采用「规则模板 + LLM 增强」混合模式 + +**Rationale:** Schema 标记有明确的规范(JSON-LD 格式),规则模板可覆盖 80% 常见场景。LLM 用于生成描述性内容(如 FAQ 的 questions/answers)和处理边界情况。 + +**Alternatives considered:** +- 纯规则:无法处理动态内容场景 +- 纯 LLM:Schema 格式要求严格,纯 LLM 生成可能不合规 + +### KTD-3: CompetitorAnalyzer 数据来源 + +**Decision:** 复用现有 CitationDetector 的检测数据 + +**Rationale:** 竞品分析依赖品牌和竞品的引用数据,现有 CitationDetector 已实现全量检测能力。新 Agent 只需在检测结果基础上做策略分析,避免重复建设。 + +**Alternatives considered:** +- 独立检测通道:增加 API 调用成本和数据冗余 +- 仅分析历史数据:无法获取最新竞品动态 + +### KTD-4: TrendAgent 趋势识别 + +**Decision:** 基于「品牌领域关键词 + AI 引擎引用模式」识别趋势 + +**Rationale:** TrendAgent 的定位是「GEO 趋势洞察」而非「通用舆情分析」。通过分析品牌核心关键词在 AI 引擎中的引用模式变化,可以识别用户关注度趋势,无需引入外部数据源。 + +**Alternatives considered:** +- 接入第三方舆情 API:成本高、数据不精准 +- 基于社交媒体趋势:与 GEO 关联度低 + +--- + +## High-Level Technical Design + +### 4 Agent 在现有架构中的位置 + +``` +┌──────────────────────────────────────────────────────────────────┐ +│ Agent Framework (Dispatcher) │ +└──────────────────────────────────────────────────────────────────┘ + │ + ┌───────────┬───────────┼───────────┬───────────┐ + ▼ ▼ ▼ ▼ ▼ 
+┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ +│Citation │ │Monitor │ │Schema │ │Competitor│ │Trend │ +│Detector │ │Agent 🆕 │ │Advisor 🆕│ │Analyzer 🆕│ │Agent 🆕 │ +│ │ │ │ │ │ │ │ │ │ +│引用检测 │ │效果追踪 │ │Schema │ │竞品策略 │ │趋势洞察 ││ +└────┬─────┘ └────┬─────┘ └────┬─────┘ └────┬─────┘ └────┬─────┘ + │ │ │ │ │ + ▼ ▼ ▼ ▼ ▼ + CitationRecord MonitoringRecord SchemaSuggestion CompetitorInsight TrendInsight + │ │ │ │ │ + ▼ ▼ ▼ ▼ ▼ + ScoringService GeoPlanGenerator DiagnosisReport StrategyPage Dashboard +``` + +### MonitorAgent 数据流 + +``` +内容发布 → 记录监测任务 → 定时执行检测 → 对比基线数据 → 生成变化报告 → 更新评分 + │ + ┌───────┴───────┐ + │ 变化类型判断 │ + ├───────────────┤ + │ + 引用量增加 │ → 正面反馈 + │ - 引用量下降 │ → 告警 + 建议 + │ = 无变化 │ → 持续观察 + └───────────────┘ +``` + +### SchemaAdvisor 建议生成 + +``` +诊断数据 → 识别 Schema 缺失维度 → 匹配模板 → 生成 JSON-LD → LLM 填充内容 → 验证格式 + │ + ┌───────┴───────┐ + │ 格式验证通过 │ → 返回建议 + │ 格式验证失败 │ → 回退到规则 + └───────────────┘ +``` + +--- + +## Implementation Units + +### U1. MonitorAgent — 效果追踪 Agent + +**Goal:** 实现内容发布后引用变化的自动追踪和对比分析 + +**Requirements:** R1(效果追踪) + +**Files:** +- `app/agent_framework/agents/monitor_agent.py` (新增) +- `app/models/monitoring.py` (新增) +- `app/schemas/monitoring.py` (新增) +- `app/services/monitoring/monitor_service.py` (新增) +- `app/api/monitoring.py` (新增) +- `tests/test_monitor_agent.py` (新增) + +**Approach:** +1. 数据模型:`MonitoringRecord` 记录监测任务,`ContentBaseline` 存储发布时的基线数据 +2. Agent 逻辑:接收监测任务 → 提取关键词和平台 → 调用 CitationService 检测 → 对比基线 → 生成变化报告 +3. 变化判断:基于引用量、情感倾向、引用排名三个维度计算变化幅度 +4. API:提供创建监测任务、查询监测历史、获取变化报告的端点 + +**Test scenarios:** +- Happy path: 内容发布后创建监测任务,定时检测对比基线,生成正面/负面/无变化报告 +- Edge case: 基线数据为空时,使用首次检测作为基线 +- Error path: 检测服务不可用时,记录失败状态并重试 +- Integration: 监测任务触发时,通过 Dispatcher 正确分发给 MonitorAgent + +--- + +### U2. 
SchemaAdvisor — Schema 优化建议 Agent + +**Goal:** 根据诊断结果生成具体的 Schema 标记方案 + +**Requirements:** R2(Schema 优化) + +**Files:** +- `app/agent_framework/agents/schema_advisor.py` (新增) +- `app/models/schema_suggestion.py` (新增) +- `app/schemas/schema_suggestion.py` (新增) +- `app/services/schema/schema_advisor_service.py` (新增) +- `app/api/schema_advisor.py` (新增) +- `tests/test_schema_advisor.py` (新增) + +**Approach:** +1. 诊断数据输入:接收 GEO 诊断结果,识别 Schema 缺失维度 +2. 模板匹配:5 种 Schema 类型 × 3 种场景 = 15 个预定义模板 + - Organization: 品牌主页 + - Product: 产品详情页 + - FAQPage: 常见问题页 + - Article: 博客/新闻页 + - LocalBusiness: 本地商户页 +3. LLM 增强:使用 LLM 生成 FAQ 的 questions/answers、Product 描述等自然语言内容 +4. 格式验证:使用 JSON-LD 解析器验证生成结果的语法正确性 +5. 优先级排序:基于诊断得分和实施难度排序建议 + +**Test scenarios:** +- Happy path: 诊断发现 FAQPage 缺失,生成完整的 FAQPage JSON-LD +- Edge case: LLM 生成内容格式错误,回退到规则模板 +- Error path: 模板不存在时,返回"暂不支持该场景"的友好提示 +- Format validation: 生成的 JSON-LD 通过 schema.org 验证 + +--- + +### U3. CompetitorAnalyzer — 竞品策略分析 Agent + +**Goal:** 在引用检测数据基础上,深度分析竞品的内容策略和引用模式 + +**Requirements:** R3(竞品分析) + +**Files:** +- `app/agent_framework/agents/competitor_analyzer.py` (新增) +- `app/models/competitor_insight.py` (新增) +- `app/schemas/competitor_insight.py` (新增) +- `app/services/competitor/competitor_analyzer_service.py` (新增) +- `app/api/competitor_analysis.py` (新增) +- `tests/test_competitor_analyzer.py` (新增) + +**Approach:** +1. 数据聚合:聚合品牌和所有竞品的引用检测数据 +2. 维度分析: + - 引用量对比:品牌 vs 竞品在各平台的引用次数 + - 引用质量对比:正面/中性/负面引用比例 + - 引用场景分析:竞品在哪些查询词下被引用 + - 内容策略推断:竞品引用内容的类型(产品介绍/评测/新闻等) +3. Gap 识别:对比发现品牌在哪些维度落后于竞品 +4. 机会发现:竞品未被引用的场景,即品牌的潜在机会 +5. 策略建议:基于分析结果生成"缩小差距"和"差异化竞争"两类建议 + +**Test scenarios:** +- Happy path: 品牌 vs 3个竞品,生成完整的竞品分析报告 +- Edge case: 竞品数据不足(少于5次引用),标记为"数据不足"并说明 +- Error path: 竞品数据获取失败,生成部分报告并提示缺失数据 +- Content strategy: 识别竞品的内容类型分布,给出差异化建议 + +--- + +### U4. 
TrendAgent — 趋势洞察 Agent + +**Goal:** 分析品牌领域关键词在 AI 引擎中的引用模式变化,识别趋势 + +**Requirements:** R4(趋势洞察) + +**Files:** +- `app/agent_framework/agents/trend_agent.py` (新增) +- `app/models/trend_insight.py` (新增) +- `app/schemas/trend_insight.py` (新增) +- `app/services/trend/trend_analyzer_service.py` (新增) +- `app/api/trends.py` (新增) +- `tests/test_trend_agent.py` (新增) + +**Approach:** +1. 时间序列分析:聚合品牌关键词在过去 N 天的引用数据(按天/周粒度) +2. 趋势识别算法: + - 上升趋势:引用量环比增长 > 20% + - 下降趋势:引用量环比下降 > 20% + - 平稳趋势:变化率在 ±20% 以内 +3. 热点发现:引用量突增的关键词/话题 +4. 平台差异:同一趋势在不同 AI 引擎的差异表现 +5. 原因推断:结合情感分析推断趋势变化的可能原因(正面事件/负面事件/行业热点) + +**Test scenarios:** +- Happy path: 分析过去30天数据,识别3个上升趋势、2个下降趋势、5个热点话题 +- Edge case: 数据不足30天(但不少于7天),使用可用数据进行趋势分析并注明 +- Error path: 引用数据获取失败或可用数据不足7天,返回"趋势分析需要至少7天数据"提示 +- Platform comparison: 识别同一趋势在文心 vs Kimi 的差异 + +--- + +### U5. Agent 注册和调度集成 + +**Goal:** 将 4 个新 Agent 注册到 Agent Framework,支持 Dispatcher 统一调度 + +**Requirements:** 所有 R1-R4 + +**Dependencies:** U1, U2, U3, U4 + +**Files:** +- `app/agent_framework/registry.py` (修改) +- `app/agent_framework/protocol.py` (修改) + +**Approach:** +1. 在 `protocol.py` 中定义 4 种新任务类型: + - `monitor_track` — 效果追踪 + - `schema_advise` — Schema 建议 + - `competitor_analyze` — 竞品分析 + - `trend_insight` — 趋势洞察 +2. 在 `registry.py` 中注册 4 个 Agent 类 +3. 验证 Dispatcher 可正确分发任务到对应 Agent + +**Test scenarios:** +- Happy path: Dispatcher 分发 `monitor_track` 任务到 MonitorAgent +- Happy path: Dispatcher 分发 `schema_advise` 任务到 SchemaAdvisor +- Edge case: 未知任务类型返回"Agent not found"错误 + +--- + +### U6. 数据库迁移 + +**Goal:** 创建新 Agent 所需的数据库表 + +**Dependencies:** U1, U2, U3, U4 + +**Files:** +- `alembic/versions/xxxx_add_new_agent_tables.py` (新增) + +**Approach:** +1. 生成 Alembic 迁移脚本 +2. 创建 5 张新表: + - `monitoring_records` — 监测记录 + - `content_baselines` — 内容基线 + - `schema_suggestions` — Schema 建议 + - `competitor_insights` — 竞品洞察 + - `trend_insights` — 趋势洞察 +3. 
执行迁移并验证 + +**Test scenarios:** +- Happy path: 迁移脚本成功执行,5张表创建完成 +- Rollback: 执行 downgrade 后,表被正确删除 + +--- + +## Open Questions + +| # | Question | Resolution | +|---|----------|------------| +| 1 | MonitorAgent 的定时检测间隔是多少? | 建议默认 24 小时,支持用户配置(1h/6h/24h/7d) | +| 2 | SchemaAdvisor 是否需要验证生成的 Schema 在目标网站上可正常解析? | 否,验证 JSON-LD 语法即可,网站解析验证超出当前范围 | +| 3 | TrendAgent 历史数据保留多久? | 建议保留 90 天,超出后归档或删除 | +| 4 | 4 个新 Agent 是否需要独立启动入口(类似 standalone.py)? | 否,通过 API 触发即可,无独立常驻需求 | + +--- + +## System-Wide Impact + +- **数据库**:新增 5 张表,需要 Alembic 迁移 +- **API 层**:新增 4 个路由模块(monitoring.py, schema_advisor.py, competitor_analysis.py, trends.py) +- **Agent Framework**:新增 4 种任务类型注册 +- **前端**:可选,取决于是否需要页面展示(当前计划仅提供 API) + +--- + +## Deferred to Follow-Up Work + +1. **前端监测页面**:MonitorAgent 的效果追踪报告需要页面展示 +2. **自动执行闭环**:内容发布后自动创建监测任务(依赖分发系统) +3. **告警集成**:监测到负面变化时自动触发告警通知 +4. **定时调度**:TrendAgent 的周期性趋势报告(依赖 APScheduler 集成) + +--- + +## Verification + +### 编译验证 +```bash +cd geo/backend && python -m py_compile app/agent_framework/agents/monitor_agent.py app/agent_framework/agents/schema_advisor.py app/agent_framework/agents/competitor_analyzer.py app/agent_framework/agents/trend_agent.py +``` + +### API 验证 +```bash +# MonitorAgent +curl -X POST http://localhost:8000/api/v1/monitoring/tasks -H "Authorization: Bearer $TOKEN" -d '{"content_id": "xxx", "brand_id": "yyy"}' + +# SchemaAdvisor +curl -X POST http://localhost:8000/api/v1/schema/advise -H "Authorization: Bearer $TOKEN" -d '{"brand_id": "yyy"}' + +# CompetitorAnalyzer +curl -X POST http://localhost:8000/api/v1/competitor/analyze -H "Authorization: Bearer $TOKEN" -d '{"brand_id": "yyy"}' + +# TrendAgent +curl -X POST http://localhost:8000/api/v1/trends/insight -H "Authorization: Bearer $TOKEN" -d '{"brand_id": "yyy", "days": 30}' +``` + +### Agent 调度验证 +```bash +# 通过 Agent Framework 调度 +curl -X POST http://localhost:8000/api/v1/agents/tasks -H "Authorization: Bearer $TOKEN" -d '{ + "agent_type": "monitor", + "task_type": "monitor_track", + 
"params": {"content_id": "xxx", "brand_id": "yyy"} +}' +``` diff --git a/docs/plans/2026-05-31-002-test-quality-assurance-system-plan.md b/docs/plans/2026-05-31-002-test-quality-assurance-system-plan.md new file mode 100644 index 0000000..e004d52 --- /dev/null +++ b/docs/plans/2026-05-31-002-test-quality-assurance-system-plan.md @@ -0,0 +1,332 @@ +--- +title: "test: GEO Platform Quality Assurance System" +type: test +status: active +date: "2026-05-31" +origin: docs/brainstorms/2026-05-31-geo-platform-next-phase-quality-requirements.md +--- + +## Summary + +GEO 平台质量保障体系建设,采用双轨并行策略:轨道一统一测试基础设施(合并分裂目录、完善 CI、建立共享 fixture),轨道二同步推进核心业务 E2E 测试和四项专项测试(性能基线、安全扫描、迁移验证、Agent E2E),确保核心业务流程尽早获得回归保护。 + +## Problem Frame + +GEO 平台经过多轮迭代已具备 33 个 API、8 个 Agent、25+ 前端页面,但测试覆盖严重不足:后端两套测试目录分裂导致 17 个测试被 CI 忽略,前端仅 13 个单元测试且零组件测试,E2E 仅覆盖登录流程,核心业务完全无回归保护。CI 不运行 E2E 测试,无安全扫描,无性能基线。在用户开始使用前,必须建立质量保障体系防止功能回归。 + +--- + +## Requirements + +### 测试基础设施 + +R1. 将 `geo/tests/` 下的 17 个测试迁移至 `backend/tests/` 对应子目录,合并两套 conftest.py 的 fixture,删除旧目录 + +R2. CI 中 `pytest tests/` 命令能发现并运行全部后端测试 + +R3. 将 `pytest-cov` 正式加入 `backend/requirements.txt` + +R4. 建立共享 fixture 体系:数据库会话(含自动 rollback)、认证 mock(JWT token + auth headers)、测试用户创建、httpx AsyncClient + +R5. 前端 vitest 覆盖范围扩展至 `lib/api/` 全部 27 个模块和关键页面组件 + +R6. 更新 `docs/03-开发指南/testing.md` 使其与实际目录结构、CI 配置、fixture 体系一致 + +### 核心业务 E2E 测试 + +R7. 编写用户注册→登录→Onboarding→品牌创建的完整 E2E 测试 + +R8. 编写诊断→查看诊断报告→生成 GEO 方案的 E2E 测试 + +R9. 编写内容生成→查看内容→效果追踪的 E2E 测试 + +R10. E2E 测试在 CI 中运行(至少 Chromium),使用 PostgreSQL 和 Redis service container + +R11. E2E 测试失败时自动截图和录制视频 + +### 专项测试 + +R12. 为 5-10 个高频 API 端点建立性能基线(p50/p95/p99 响应时间),首轮只采集不设阈值 + +R13. CI 中集成 bandit(Python 安全扫描)和 npm audit(Node.js 依赖安全检查),PR 级别阻断高危漏洞 + +R14. CI 中添加 Alembic 迁移验证步骤 + +R15. 编写 Agent 框架端到端测试:任务创建→分发→执行→结果查询的完整链路(至少覆盖 CitationDetector 和 ContentGenerator) + +### 前端组件测试 + +R16. 为 5 个关键页面组件编写 React Testing Library 测试 + +R17. 为 4 个新 Agent 的前端 API 客户端模块编写单元测试 + +--- + +## Key Technical Decisions + +KTD1. 
**双轨并行而非分层递进** — 基础设施统一和 E2E 测试编写同步推进,两者几乎不涉及相同文件,并行没有实质冲突。E2E 测试初期用独立 setup,后续迁移到共享 fixture 是低成本重构。 + +KTD2. **E2E 测试仅覆盖 Chromium** — 先在一个浏览器上稳定运行,跨浏览器扩展作为后续迭代。Playwright 已配置 3 浏览器但当前 E2E 用例太少,扩展浏览器覆盖的 ROI 不高。 + +KTD3. **性能测试先建基线再设阈值** — 在没有历史数据时设定 SLA 容易过严或过松。第一轮只采集数据建立基线,第二轮根据基线设定阈值和告警。 + +KTD4. **Agent E2E 测试优先使用真实 LLM 调用** — 可使用 LLM API Key 进行真实调用测试,确保 Agent 在实际 LLM 响应下的行为正确。CI 中通过环境变量注入 API Key,无 Key 时降级为 mock 模式。 + +KTD5. **安全扫描集成到 CI 而非独立流程** — bandit 和 npm audit 作为 CI 步骤运行,在 PR 级别就暴露问题。 + +--- + +## Implementation Units + +### U1. 统一后端测试目录 + +- **Goal:** 消除测试目录分裂,确保 CI 能发现并运行全部后端测试 +- **Requirements:** R1, R2, R3 +- **Dependencies:** none +- **Files:** + - `geo/tests/` (17 files to migrate) + - `backend/tests/conftest.py` (merge fixtures) + - `geo/tests/conftest.py` (merge then delete) + - `backend/requirements.txt` (add pytest-cov) + - `.github/workflows/ci.yml` (verify pytest discovers all tests) +- **Approach:** 逐个迁移 `geo/tests/` 下的测试文件到 `backend/tests/` 对应子目录(test_api/, test_services/ 等)。合并两套 conftest.py:保留 `backend/tests/conftest.py` 的内存 SQLite 引擎,从 `geo/tests/conftest.py` 迁入 async_client、auth_token、auth_headers、override_get_current_user 等 fixture。删除 `geo/tests/` 目录。将 pytest-cov 加入 requirements.txt。 +- **Patterns to follow:** `backend/tests/conftest.py` 现有 fixture 模式(async_engine, async_session, test_user) +- **Test scenarios:** + - `cd backend && pytest tests/` 发现并运行全部测试(≥77 个) + - 迁移后的测试功能与迁移前一致 + - CI 中 pytest 命令无需修改即可运行全量测试 +- **Verification:** `cd backend && pytest tests/ --tb=short` 全部通过,`geo/tests/` 目录已不存在 + +### U2. 
建立共享 fixture 体系 + +- **Goal:** 为后端集成测试和 E2E 测试提供可复用的测试基础设施 +- **Requirements:** R4 +- **Dependencies:** U1 +- **Files:** + - `backend/tests/conftest.py` (enhance with shared fixtures) + - `backend/tests/fixtures/` (new directory for modular fixtures) + - `backend/tests/fixtures/auth.py` (authentication fixtures) + - `backend/tests/fixtures/database.py` (database fixtures) + - `backend/tests/fixtures/client.py` (httpx client fixtures) + - `backend/tests/fixtures/brands.py` (brand and competitor test data) +- **Approach:** 在 `backend/tests/fixtures/` 下按领域拆分 fixture 模块。auth.py 提供 override_get_current_user、auth_token、auth_headers。database.py 提供 async_engine、async_session(含自动 rollback)。client.py 提供 async_client(httpx AsyncClient with app)。brands.py 提供预创建的测试品牌和竞品数据。conftest.py 通过 `pytest_plugins` 变量(pytest 插件机制)声明并加载 fixtures/ 下的各模块。 +- **Patterns to follow:** `geo/tests/conftest.py` 中已有的 fixture 实现(async_client, auth_token, override_get_current_user) +- **Test scenarios:** + - 使用 auth_headers fixture 的测试能成功调用需要认证的 API + - 使用 async_client fixture 的测试能发送 HTTP 请求到 FastAPI 应用 + - 数据库 fixture 在测试结束后自动 rollback,不污染数据库 + - 多个测试并行运行时 fixture 互不干扰 +- **Verification:** 使用新 fixture 编写 2-3 个示例测试并全部通过 + +### U3. 
核心业务 E2E — 用户注册到品牌创建 + +- **Goal:** 覆盖用户从注册到品牌创建的完整 Onboarding 流程 +- **Requirements:** R7, R11 +- **Dependencies:** none (独立 setup,不依赖 U2) +- **Files:** + - `frontend/e2e/tests/onboarding.spec.ts` (new) + - `frontend/e2e/fixtures/auth.ts` (new, independent setup) + - `frontend/playwright.config.ts` (verify screenshot/video config) +- **Approach:** 编写 Playwright 测试覆盖:注册新用户→登录→进入 Onboarding→填写品牌名称→添加竞品→选择平台→查看健康报告→完成引导。使用独立 setup(直接调用 API 创建测试数据),后续可迁移到共享 fixture。确保 playwright.config.ts 中 screenshot:on-failure 和 video:retain-on-failure 已配置。 +- **Patterns to follow:** `frontend/e2e/tests/login.spec.ts` 现有 E2E 测试模式 +- **Test scenarios:** + - Covers AE1 (from origin): 注册→登录→Onboarding→品牌创建全流程通过 + - 注册时输入无效邮箱显示错误提示 + - 品牌名称为空时无法提交 + - 不添加竞品也能完成 Onboarding + - 已注册用户登录后直接跳转 Dashboard +- **Verification:** `cd frontend && npx playwright test onboarding` 通过 + +### U4. 核心业务 E2E — 诊断到 GEO 方案 + +- **Goal:** 覆盖诊断→方案生成的核心业务闭环 +- **Requirements:** R8, R11 +- **Dependencies:** U3 (需要已登录用户和品牌) +- **Files:** + - `frontend/e2e/tests/diagnosis-strategy.spec.ts` (new) +- **Approach:** 在 U3 创建的品牌基础上,编写:进入诊断页面→触发诊断→查看诊断报告→点击"基于诊断制定 GEO 方案"→查看方案详情。诊断和方案生成可能需要较长时间,使用适当的等待策略。 +- **Patterns to follow:** `frontend/e2e/tests/login.spec.ts` 现有 E2E 测试模式 +- **Test scenarios:** + - Covers AE2 (from origin): 诊断→报告→方案生成全流程通过 + - 诊断页面正确显示品牌评分 + - 方案生成后显示行动项列表 + - 方案行动项状态可更新 +- **Verification:** `cd frontend && npx playwright test diagnosis-strategy` 通过 + +### U5. 
核心业务 E2E — 内容生成到效果追踪 + +- **Goal:** 覆盖内容生成→查看→效果追踪的内容工坊流程 +- **Requirements:** R9, R11 +- **Dependencies:** U3 (需要已登录用户和品牌) +- **Files:** + - `frontend/e2e/tests/content-monitoring.spec.ts` (new) +- **Approach:** 编写:进入内容工坊→输入关键词→生成内容→查看生成结果→进入效果追踪页面→查看监测数据。内容生成可能需要 LLM 调用,使用较长超时或 mock。 +- **Patterns to follow:** `frontend/e2e/tests/login.spec.ts` 现有 E2E 测试模式 +- **Test scenarios:** + - Covers AE3 (from origin): 内容生成→查看→效果追踪全流程通过 + - 内容生成页面正确显示生成状态 + - 生成完成后内容可查看 + - 效果追踪页面显示监测数据 +- **Verification:** `cd frontend && npx playwright test content-monitoring` 通过 + +### U6. CI 集成 E2E 测试 + +- **Goal:** 将 E2E 测试纳入 CI 流水线 +- **Requirements:** R10 +- **Dependencies:** U3, U4, U5 +- **Files:** + - `.github/workflows/ci.yml` (add E2E step) + - `.github/workflows/pr-check.yml` (add E2E step) +- **Approach:** 在 CI 中添加 E2E 测试步骤:启动 PostgreSQL 和 Redis service container → 启动后端服务 → 启动前端服务 → 运行 Playwright 测试。仅运行 Chromium 项目以控制时间。E2E 步骤放在单元测试之后,允许失败但不阻塞(初期),稳定后改为阻塞。 +- **Patterns to follow:** 现有 CI 中 PostgreSQL 和 Redis service container 配置 +- **Test scenarios:** + - PR 提交时 CI 自动运行 E2E 测试 + - E2E 测试失败时 E2E 步骤标记为失败(初期不阻塞合并,稳定后改为阻塞) + - 截图和视频作为 artifact 上传 +- **Verification:** 提交一个测试 PR,观察 CI 中 E2E 步骤是否正确运行 + +### U7. 
性能基线测试 + +- **Goal:** 为高频 API 端点建立性能基线数据 +- **Requirements:** R12 +- **Dependencies:** U2 +- **Files:** + - `backend/tests/performance/` (new directory) + - `backend/tests/performance/__init__.py` (new) + - `backend/tests/performance/test_api_baseline.py` (new) + - `backend/tests/performance/conftest.py` (new) +- **Approach:** 使用 httpx AsyncClient 对 5-10 个高频端点(dashboard/stats, brands, queries, citations, content/generate, strategy/generate, monitoring/brand/{id}, analytics/overview, diagnosis, ai-engines/query)发送多次请求,采集 p50/p95/p99 响应时间,输出为 JSON 基线文件。首轮只采集不设阈值。使用 pytest-benchmark 或自定义计时逻辑。 +- **Patterns to follow:** `backend/tests/` 现有测试结构 +- **Test scenarios:** + - Covers AE4 (from origin): 对 dashboard/stats 发送 100 次请求采集性能数据 + - 基线数据以 JSON 格式保存 + - 性能测试可在本地和 CI 中运行 + - 首轮不设阈值不阻断 +- **Verification:** `cd backend && pytest tests/performance/ --tb=short` 运行并输出基线数据 + +### U8. CI 安全扫描和迁移验证 + +- **Goal:** 在 CI 中集成安全扫描和数据库迁移验证 +- **Requirements:** R13, R14 +- **Dependencies:** none +- **Files:** + - `.github/workflows/ci.yml` (add security scan and migration steps) + - `.github/workflows/pr-check.yml` (add security scan and migration steps) +- **Approach:** 在 CI 中添加:1) `pip install bandit && bandit -r backend/app/ -ll` 扫描 Python 代码安全问题;2) `cd frontend && npm audit --audit-level=high` 检查 Node.js 依赖漏洞;3) `cd backend && alembic upgrade head` 在空数据库上验证迁移。安全扫描发现高危问题时 CI 失败。 +- **Patterns to follow:** 现有 CI 步骤结构 +- **Test scenarios:** + - Covers AE3 (from origin): PR 中引入有漏洞依赖时 CI 报告失败 + - Covers AE5 (from origin): 新增迁移脚本时 CI 验证迁移可执行 + - bandit 扫描 Python 代码中的安全问题 + - npm audit 检查前端依赖漏洞 + - alembic upgrade head 在空数据库上成功执行 +- **Verification:** 提交测试 PR,观察 CI 中安全扫描和迁移验证步骤 + +### U9. 
Agent 框架 E2E 测试 + +- **Goal:** 测试 Agent 框架的完整任务调度链路 +- **Requirements:** R15 +- **Dependencies:** U2 +- **Files:** + - `backend/tests/test_agent_framework/` (new or extend existing) + - `backend/tests/test_agent_framework/test_e2e_citation.py` (new) + - `backend/tests/test_agent_framework/test_e2e_content.py` (new) +- **Approach:** 测试完整链路:创建 Agent 任务→Dispatcher 分发→Agent 执行→结果写入数据库→查询结果。至少覆盖 CitationDetector(citation_detect 任务)和 ContentGenerator(content_generate 任务)。优先使用真实 LLM API Key 调用,CI 中通过环境变量注入;无 Key 时降级为 mock 模式使用预定义响应。 +- **Patterns to follow:** `backend/tests/test_agent_framework/` 现有 Agent 测试模式 +- **Test scenarios:** + - CitationDetector: 创建 citation_detect 任务→执行→结果包含引用数据 + - ContentGenerator: 创建 content_generate 任务→执行→结果包含生成内容 + - 任务状态从 pending→running→completed 正确流转 + - 任务失败时状态变为 failed 并记录错误信息 + - 无 Redis 时 Agent 以 standalone 模式运行 +- **Verification:** `cd backend && pytest tests/test_agent_framework/test_e2e_*.py --tb=short` 通过 + +### U10. 前端组件和 API 客户端测试 + +- **Goal:** 扩展前端测试覆盖至关键组件和新 API 客户端模块 +- **Requirements:** R16, R17 +- **Dependencies:** none +- **Files:** + - `frontend/__tests__/components/` (new directory) + - `frontend/__tests__/components/dashboard.test.tsx` (new) + - `frontend/__tests__/components/brand-detail.test.tsx` (new) + - `frontend/__tests__/components/diagnosis.test.tsx` (new) + - `frontend/__tests__/components/strategy.test.tsx` (new) + - `frontend/__tests__/components/content-editor.test.tsx` (new) + - `frontend/__tests__/api/monitoring.test.ts` (new) + - `frontend/__tests__/api/competitor-analysis.test.ts` (new) + - `frontend/__tests__/api/schema-advisor.test.ts` (new) + - `frontend/__tests__/api/trends.test.ts` (new) + - `frontend/vitest.config.ts` (extend coverage include) +- **Approach:** 使用 React Testing Library 为 5 个关键页面组件编写渲染和交互测试。使用 vitest 为 4 个新 API 客户端模块编写单元测试(mock fetchWithAuth,验证正确的 URL 和参数传递)。扩展 vitest.config.ts 的 coverage.include 配置。 +- **Patterns to follow:** `frontend/__tests__/` 现有测试模式(hooks, stores, lib) +- 
**Test scenarios:** + - Dashboard 组件正确渲染评分和平台数据 + - 品牌详情页正确显示 GEO 评分维度 + - 诊断页触发诊断后显示结果 + - 策略页显示 GEO 方案和行动项 + - 内容编辑器正确渲染和提交 + - monitoring.ts API 客户端调用正确的端点 + - competitor-analysis.ts API 客户端传递正确的参数 + - schema-advisor.ts API 客户端处理响应 + - trends.ts API 客户端处理查询参数 +- **Verification:** `cd frontend && npm run test:ci` 通过,覆盖率报告显示 lib/api/ 模块覆盖 + +### U11. 更新测试策略文档 + +- **Goal:** 使文档与实际项目结构和 CI 配置一致 +- **Requirements:** R6 +- **Dependencies:** U1, U2, U6, U8 +- **Files:** + - `docs/03-开发指南/testing.md` (update) +- **Approach:** 更新 testing.md 中的目录结构描述(test_api/, test_services/ 等按领域分类),更新 fixture 体系说明(fixtures/ 模块化 fixture),更新 CI 配置示例(包含 E2E 步骤、安全扫描、迁移验证),添加 E2E 测试编写指南。 +- **Patterns to follow:** 现有 testing.md 文档风格 +- **Test scenarios:** + - Test expectation: none — documentation update +- **Verification:** 文档内容与实际项目结构一致 + +--- + +## Scope Boundaries + +**Deferred for later:** + +- 跨浏览器 E2E 测试(Firefox / WebKit) +- 覆盖率报告上传第三方服务(Codecov / Coveralls) +- 测试数据工厂(factory-boy / faker) +- PR 评论中显示覆盖率变化 +- 生产环境监控告警集成 + +**Outside this product's identity:** + +- 新业务功能开发 +- UI/UX 改进和设计优化 +- 基础设施级别的渗透测试 +- 移动端适配测试 + +### Deferred to Follow-Up Work + +- E2E 测试迁移到共享 fixture(U3/U4/U5 完成后,将独立 setup 替换为 U2 的共享 fixture) +- 性能基线设定阈值和告警(U7 采集基线后,根据数据设定合理阈值) +- 跨浏览器 E2E 扩展(稳定 Chromium 后再扩展) + +--- + +## Risks & Dependencies + +- **E2E 测试稳定性** — 依赖后端服务启动和数据库初始化,CI 环境可能比本地更不稳定。缓解:E2E 初期允许失败不阻塞,稳定后再改为阻塞。 +- **LLM 依赖** — 内容生成和诊断 E2E 测试可能需要 LLM 调用。缓解:前端 E2E 使用 mock 或较长超时;Agent E2E 优先使用真实 LLM 调用(CI 中通过环境变量注入 API Key),无 Key 时降级为 mock 模式(与 KTD4 一致)。 +- **测试目录迁移风险** — 旧测试可能依赖 `geo/tests/conftest.py` 的特定 fixture,迁移后可能需要调整 import 路径。缓解:逐个迁移并验证。 +- **CI 时间增长** — 添加 E2E、安全扫描、迁移验证会延长 CI 运行时间。缓解:E2E 仅 Chromium,安全扫描仅高危阻断,迁移验证快速执行。 + +--- + +## Sources & Research + +- 现有测试配置:`backend/pyproject.toml`, `frontend/vitest.config.ts`, `frontend/playwright.config.ts` +- CI 配置:`.github/workflows/ci.yml`, `.github/workflows/pr-check.yml` +- 测试策略文档:`docs/03-开发指南/testing.md` +- 现有 E2E 测试:`frontend/e2e/tests/` (7 files) +- 现有后端测试:`backend/tests/` (~60 files) + `geo/tests/` (17 
files) +- 现有前端测试:`frontend/__tests__/` (13 files) +- Origin document: `docs/brainstorms/2026-05-31-geo-platform-next-phase-quality-requirements.md` diff --git a/docs/plans/2026-05-31-003-feat-geo-monetization-closed-loop-plan.md b/docs/plans/2026-05-31-003-feat-geo-monetization-closed-loop-plan.md new file mode 100644 index 0000000..0fac9d0 --- /dev/null +++ b/docs/plans/2026-05-31-003-feat-geo-monetization-closed-loop-plan.md @@ -0,0 +1,558 @@ +--- +title: "feat: GEO Platform Monetization Closed Loop" +type: feat +status: active +date: "2026-05-31" +origin: docs/brainstorms/2026-05-31-geo-next-phase-core-flow-repair-requirements.md +secondary-origin: docs/brainstorms/2026-05-31-geo-platform-monetization-closed-loop-requirements.md +--- + +## Summary + +修复 GEO 诊断系统(当前空输入=0 分)、建设免费 GEO 健康分获客入口、接入真实支付与功能限制、添加 AI 内容生成能力、打通内容分发与效果归因闭环,实现从获客到续费的完整商业变现链路。同时用 API 契约测试驱动核心功能修复,辅以少量 E2E 烟雾测试验证关键用户路径。 + +## Problem Frame + +GEO 平台已具备 8 个 Agent、6 维度诊断框架、内容 Pipeline、知识库 RAG 等模块,但三个致命问题阻断了变现闭环:(1) 诊断系统使用空输入调用,永远返回 0 分,产品核心价值为零;(2) 内容 Pipeline 只做格式化不生成内容,核心付费功能缺失;(3) 支付是模拟的,没有功能限制,用户无付费动力。此外,Onboarding 缺少付费墙触发点、分发没有实际发布集成、监控没有归因逻辑、邮件服务未接入业务系统——各模块存在但互不连通。 + +## Requirements + +Origin: `docs/brainstorms/2026-05-31-geo-next-phase-core-flow-repair-requirements.md` + +- R1. 修复诊断系统:实现自动数据采集,让诊断产出非零有差异的评分 → U1 +- R2. 添加 AI 内容生成阶段:在 ContentPipeline 前端添加 Stage 0 → U5 +- R3. 接入真实支付:微信支付+支付宝+功能限制中间件 → U4 +- R4. 建设免费 GEO 健康分公开页面 → U2 +- R5. 重设计 Onboarding 流程,嵌入付费墙触发点 → U3 +- R6. 执行闭环:内容分发集成微信/知乎/头条 → U5 +- R7. 效果归因系统 → U6 +- R8. 为诊断 API 编写契约测试 → U1 +- R9. 为公开健康分 API 编写契约测试 → U2 +- R10. 为内容生成 API 编写契约测试 → U5 +- R11. 为支付 API 编写契约测试 → U4 +- R12. 编写跨步骤集成测试 → U8 +- R13. 编写获客路径 E2E 烟雾测试 → U9 +- R14. 编写核心流程 E2E 烟雾测试 → U9 +- R15. 完成后端测试目录统一验证 → U8 +- R16. 建立共享 fixture 体系 → U8 + +## Key Technical Decisions + +KTD-1. 
**诊断自动数据采集采用"AI 平台查询+CitationRecord 分析"双通道方案(V1 跳过网站爬取)。** 通过调用现有 AI 平台适配器查询品牌关键词分析引用情况,通过分析已有 CitationRecord 数据聚合引用指标。网站爬取通道留待 V2 迭代,降低首版复杂度。AI 平台查询填充"引用就绪度"维度,CitationRecord 分析填充"E-E-A-T/主题权威"维度,品牌官网 URL 解析填充"内容可提取性/Schema 标记"维度(基于简单 HTTP GET + HTML 解析,非完整爬取)。 + +KTD-2. **免费 GEO 健康分页面为独立公开页面,不依赖现有 Dashboard。** 新建 `/health-score` 路由,无需认证即可访问。结果缓存 24 小时(按品牌名+竞品哈希),避免重复查询消耗 API 额度。 + +KTD-3. **支付集成采用微信支付+支付宝双通道。** 使用官方 SDK,后端实现支付回调 Webhook 处理订阅状态更新。 + +KTD-4. **AI 内容生成集成到现有 ContentPipeline 前端。** 在 Stage 1(RuleValidator) 之前添加 Stage 0(AI Generator),基于诊断结果+RAG 知识库+LLM 生成初稿,后续阶段不变。 + +KTD-5. **效果归因采用"内容发布时间窗口+引用变化关联"方案。** 发布内容时记录时间戳和基线引用数据,2-4 周内定期检查引用变化。 + +KTD-6. **微信公众号采用半自动模式。** 生成内容+格式化+一键复制,用户手动粘贴到公众号编辑器。知乎和头条利用 API 实现自动发布。 + +KTD-7. **API 契约测试驱动修复,E2E 仅做烟雾测试。** 每个核心 API 先写契约测试定义期望行为,再修复实现。E2E 只覆盖 2 个最关键路径(获客+核心流程),不追求全面覆盖。 + +## High-Level Technical Design + +```mermaid +flowchart TB + subgraph Phase1["Phase 1: 核心价值修复 + 契约测试"] + A[品牌名输入] --> B[自动数据采集] + B --> B1[AI平台查询] + B --> B2[CitationRecord分析] + B1 & B2 --> C[GEO诊断引擎] + C --> D[6维度评分] + D --> T1[诊断API契约测试] + end + + subgraph Phase2["Phase 2: 获客入口 + 契约测试"] + D --> E[免费健康分页面] + E --> F{查看详细建议?} + F -->|免费概要| G[注册] + F -->|完整分析| G + G --> H{执行优化?} + H -->|升级付费| I[订阅支付] + E --> T2[健康分API契约测试] + end + + subgraph Phase3["Phase 3: 执行闭环 + 契约测试"] + I --> J[AI内容生成] + J --> K[人工审核] + K --> L[发布到平台] + L --> M[效果归因追踪] + M --> N[ROI报告] + N --> O[续费/升级] + J --> T3[内容生成API契约测试] + I --> T4[支付API契约测试] + end + + subgraph Phase4["Phase 4: 集成验证"] + T1 & T2 & T3 & T4 --> T5[跨步骤集成测试] + T5 --> T6[E2E烟雾测试] + end +``` + +## Implementation Units + +### U1. 
GEO 诊断自动数据采集与修复 + +**Goal:** 修复当前诊断系统(空输入=0 分),实现自动数据采集填充 GEODiagnosisInput,让诊断产出真实有价值的分数。 + +**Requirements:** R1, R8 + +**Dependencies:** 无 + +**Files:** +- `backend/app/services/diagnosis/geo_diagnosis.py` — 修改诊断服务,添加自动数据采集 +- `backend/app/services/diagnosis/data_collector.py` — 新建自动数据采集服务 +- `backend/app/api/diagnosis.py` — 修改 API 端点,支持自动采集+异步诊断 +- `backend/app/schemas/diagnosis.py` — 新建诊断请求/响应 schema +- `backend/app/models/diagnosis_record.py` — 新建诊断历史记录模型 +- `backend/tests/test_api/test_diagnosis_contract.py` — 新建诊断 API 契约测试 +- `backend/tests/test_services/test_data_collector.py` — 新建数据采集服务测试 + +**Approach:** +1. 新建 `DataCollectorService`,两个采集通道(V1): + - AI 平台查询通道:复用现有 `ai_engine` 适配器,查询品牌相关关键词,分析引用情况填充"引用就绪度"维度 + - 历史数据通道:从 CitationRecord 聚合已有引用数据,填充"E-E-A-T/主题权威"维度 + - 品牌官网解析:简单 HTTP GET + HTML 解析(检查 Schema 标记、标题层级等),填充"内容可提取性/Schema 标记"维度 +2. 修改 `GEODiagnosisService.diagnose()` 接受 `DataCollectorOutput` 作为输入 +3. 修改诊断 API 端点:`POST /api/v1/diagnosis/geo/{brand_id}` 触发异步诊断,`GET /api/v1/diagnosis/geo/{brand_id}/result` 轮询结果 +4. 诊断结果持久化到 `diagnosis_records` 表,支持历史对比 +5. 免费版返回综合分+3 核心维度,付费版返回全部 6 维度 +6. **先写契约测试**:定义诊断 API 的期望行为(非零评分、维度结构、免费/付费差异),再修复实现 + +**Execution note:** 先写诊断 API 契约测试(红色),再实现 DataCollectorService 和修改诊断服务(绿色)。 + +**Patterns to follow:** +- 复用 `backend/app/services/ai_engine/` 下的平台适配器 +- 复用 `backend/app/workers/citation_engine.py` 的查询执行模式 +- 异步任务模式参考 `backend/app/services/detection/detection_scheduler.py` +- 契约测试参考 `backend/tests/conftest.py` 中的 async_client fixture + +**Test scenarios:** +- 输入有效品牌名,自动采集数据并返回非零诊断分数 +- 输入不存在的品牌名,返回低分但非零(基于 AI 平台查询结果) +- 免费用户请求诊断,只返回 3 个核心维度 +- 付费用户请求诊断,返回全部 6 维度 +- 同一品牌 24 小时内重复请求,返回缓存结果 +- 诊断结果持久化,第二次诊断可展示历史对比 +- 契约测试:POST /api/v1/diagnosis/geo/{brand_id} 返回 202 + task_id +- 契约测试:GET /api/v1/diagnosis/geo/{brand_id}/result 返回非零 overall_score + +**Verification:** 调用 `POST /api/v1/diagnosis/geo/{brand_id}` 返回真实非零评分,且不同品牌评分有差异。契约测试全部通过。 + +--- + +### U2. 
免费 GEO 健康分公开页面 + +**Goal:** 建设无需注册即可访问的 GEO 健康分页面,作为获客和市场教育的核心入口。 + +**Requirements:** R4, R9 + +**Dependencies:** U1 + +**Files:** +- `frontend/app/(public)/health-score/page.tsx` — 新建公开健康分页面 +- `frontend/app/(public)/health-score/components/ScoreCard.tsx` — 评分卡片组件 +- `frontend/app/(public)/health-score/components/CompetitorComparison.tsx` — 竞品对比组件 +- `frontend/app/(public)/health-score/components/ShareButton.tsx` — 分享按钮组件 +- `frontend/lib/api/health-score.ts` — 健康分 API 客户端 +- `backend/app/api/health_score.py` — 新建公开健康分 API(无需认证) +- `backend/app/schemas/health_score.py` — 健康分响应 schema +- `backend/tests/test_api/test_health_score_contract.py` — 新建健康分 API 契约测试 + +**Approach:** +1. 新建公开路由 `/health-score`,无需登录即可访问 +2. 页面核心交互:输入品牌名 → 30 秒内展示 GEO 健康分报告 +3. 报告内容:综合评分(大数字+健康等级色标)、3 核心维度雷达图、3 竞品对比柱状图、最严重 3 个问题概要 +4. "查看详细修复建议"按钮 → 触发注册弹窗 +5. 注册后自动关联本次诊断结果到用户账户 +6. 支持生成可分享链接(含品牌名参数)和 PDF 下载 +7. 后端:`GET /api/v1/public/health-score?brand=XXX` 无需认证,返回免费版诊断结果 +8. 结果缓存 24 小时(Redis),避免重复查询 + +**Patterns to follow:** +- 前端页面参考 `frontend/app/(dashboard)/onboarding/Step4HealthReport.tsx` 的报告展示模式 +- API 参考 `backend/app/api/diagnosis.py` 但无需认证 +- 分享功能参考 `backend/app/api/reports.py` 的 PDF 生成 + +**Test scenarios:** +- 未登录用户输入品牌名,30 秒内看到健康分报告 +- 报告包含综合评分、3 维度、3 竞品对比 +- 点击"查看详细建议"弹出注册弹窗 +- 注册后诊断结果自动关联到账户 +- 生成可分享链接,他人打开看到相同报告 +- 24 小时内重复查询同一品牌返回缓存结果 +- 契约测试:GET /api/v1/public/health-score?brand=XXX 返回 200 + 非零评分 +- 契约测试:无品牌名参数返回 422 + +**Verification:** 未登录状态下访问 `/health-score`,输入品牌名后看到完整报告且可分享。契约测试通过。 + +--- + +### U3. 
Onboarding 重设计与转化层 + +**Goal:** 重设计 Onboarding 流程以"查看 GEO 健康分"为第一步,嵌入付费墙触发点。 + +**Requirements:** R5 + +**Dependencies:** U1, U2 + +**Files:** +- `frontend/app/(dashboard)/onboarding/page.tsx` — 重写 Onboarding 主页面 +- `frontend/app/(dashboard)/onboarding/Step0HealthScore.tsx` — 新建第一步:健康分 +- `frontend/app/(dashboard)/onboarding/Step1BrandName.tsx` — 修改为简化版 +- `frontend/app/(dashboard)/onboarding/Step4HealthReport.tsx` — 修改为含升级提示 +- `frontend/app/(dashboard)/onboarding/Step5ActionSuggestions.tsx` — 修改为含执行按钮+付费墙 +- `frontend/components/subscription/UpgradePrompt.tsx` — 新建升级提示组件 + +**Approach:** +1. 新流程:Step0(输入品牌名看健康分) → Step1(注册/登录) → Step2(补充竞品信息) → Step3(查看详细报告+升级提示) → Step4(行动建议+执行按钮) +2. Step0 直接复用 U2 的健康分页面组件,嵌入 Onboarding 流程 +3. 在以下场景嵌入升级提示: + - 查看详细 6 维度分析时 → "升级 Pro 版查看完整分析" + - 点击执行优化建议时 → "升级 Starter 版开始优化" + - 查看效果归因报告时 → "升级 Pro 版追踪优化效果" +4. 升级提示组件 `UpgradePrompt` 统一管理,显示当前套餐限制和升级收益 + +**Patterns to follow:** +- 现有 Onboarding 步骤组件模式(Step1-5 的 props 接口) +- `useOnboardingData` hook 的数据管理模式 + +**Test scenarios:** +- 新用户进入 Onboarding,第一步看到健康分输入框 +- 输入品牌名后看到免费健康分报告 +- 注册后继续 Onboarding,看到详细报告含升级提示 +- 免费用户点击"执行优化"时看到升级弹窗 +- 已完成 Onboarding 的用户不再看到引导 + +**Verification:** 新用户从健康分开始 Onboarding,在关键节点看到升级提示。 + +--- + +### U4. 真实支付集成与功能限制 + +**Goal:** 接入微信支付+支付宝,实现功能限制中间件,让付费真正生效。 + +**Requirements:** R3, R11 + +**Dependencies:** U3 + +**Files:** +- `backend/app/services/payment/wechat_pay.py` — 新建微信支付服务 +- `backend/app/services/payment/alipay.py` — 新建支付宝服务 +- `backend/app/services/payment/base.py` — 新建支付基类 +- `backend/app/api/payments.py` — 新建支付 API 端点 +- `backend/app/middleware/subscription_enforcement.py` — 新建功能限制中间件 +- `backend/app/services/subscription.py` — 修改添加真实支付逻辑 +- `backend/app/api/subscriptions.py` — 修改添加支付 Webhook +- `backend/app/models/subscription.py` — 修改添加支付相关字段 +- `backend/tests/test_api/test_payment_contract.py` — 新建支付 API 契约测试 +- `backend/tests/test_services/test_payment.py` — 新建支付服务测试 + +**Approach:** +1. 
支付基类定义统一接口:`create_order`, `verify_callback`, `refund` +2. 微信支付:使用官方 SDK,实现 JSAPI 支付(H5 页面)和 Native 支付(扫码) +3. 支付宝:使用官方 SDK,实现手机网站支付 +4. 支付 Webhook:`POST /api/v1/payments/callback/wechat` 和 `/alipay`,验证签名后更新订阅状态 +5. 功能限制中间件:在 API 路由层检查用户套餐,超限返回 403+升级提示 +6. 限制维度:查询数/品牌数/诊断频率/内容生成数/知识库数/数据保留期 +7. **先写契约测试**:定义支付 API 的期望行为(创建订单、回调处理、功能限制) + +**Execution note:** 先写支付 API 契约测试(红色),再实现支付服务(绿色)。 + +**Patterns to follow:** +- 支付回调参考 `backend/app/api/auth.py` 的 Webhook 处理模式 +- 功能限制参考 `backend/app/middleware/rate_limit.py` 的中间件模式 +- 订阅逻辑复用 `backend/app/services/subscription.py` 的 PLANS 配置 + +**Test scenarios:** +- 创建支付订单,返回支付链接/二维码 +- 支付成功回调,订阅状态更新为 active +- 免费用户超出查询限制,返回 403+升级提示 +- Pro 用户在限制内正常使用 +- 支付失败回调,订阅状态不变 +- 订阅到期,自动降级为免费版 +- 契约测试:POST /api/v1/payments/orders 返回 201 + order_id + pay_url +- 契约测试:POST /api/v1/payments/callback/wechat 验证签名后更新订阅 + +**Verification:** 完整支付流程:创建订单→支付→回调→订阅激活→功能解锁。契约测试通过。 + +--- + +### U5. AI 内容生成与分发集成 + +**Goal:** 在 ContentPipeline 前端添加 AI 生成阶段,集成知乎/头条 API 发布和微信半自动发布。 + +**Requirements:** R2, R6, R10 + +**Dependencies:** U1, U4 + +**Files:** +- `backend/app/services/content/ai_generator.py` — 新建 AI 内容生成服务 +- `backend/app/services/content/content_pipeline.py` — 修改添加 Stage 0 +- `backend/app/services/distribution/publishers/zhihu_publisher.py` — 新建知乎发布器 +- `backend/app/services/distribution/publishers/toutiao_publisher.py` — 新建头条发布器 +- `backend/app/services/distribution/publishers/wechat_publisher.py` — 新建微信半自动发布器 +- `backend/app/services/distribution/publish_engine.py` — 新建发布引擎 +- `backend/app/api/distribution.py` — 修改添加发布 API +- `backend/tests/test_api/test_content_contract.py` — 新建内容生成 API 契约测试 +- `backend/tests/test_services/test_ai_generator.py` — 新建 AI 生成测试 +- `backend/tests/test_services/test_publishers.py` — 新建发布器测试 + +**Approach:** +1. AI 内容生成服务: + - 输入:诊断结果(问题维度)+ RAG 知识库上下文 + 目标关键词 + 目标平台 + - 调用 LLM 生成初稿,基于诊断结果针对性优化 + - 输出传入现有 Pipeline(RuleValidator → SensitiveFilter → SEOOptimizer → HTMLGenerator) +2. 
发布引擎: + - 知乎:通过知乎创作中心 API 发布文章(需 OAuth 授权) + - 头条:通过头条号 API 发布文章(需 API Key) + - 微信:生成格式化内容+复制引导,不直接调用 API +3. 发布流程:AI 生成 → 人工预览/编辑 → 确认发布 → 调用对应 Publisher → 记录发布状态 +4. 发布记录关联到 Content 模型,用于后续归因 + +**Patterns to follow:** +- AI 生成参考 `backend/app/services/content/content_generation_service.py` 的 LLM 调用模式 +- RAG 集成参考 `backend/app/services/knowledge/rag_service.py` +- 平台规则复用 `backend/app/services/distribution/platform_rules.py` +- Publisher 模式参考 `backend/app/services/ai_engine/` 的适配器模式 + +**Test scenarios:** +- 基于诊断结果+知识库上下文生成针对性优化内容 +- 生成内容通过 Pipeline 校验(RuleValidator + SensitiveFilter) +- 知乎 API 发布成功,返回文章 URL +- 头条 API 发布成功,返回文章 URL +- 微信半自动模式生成格式化内容+复制引导 +- 发布记录保存到 Content 模型,状态为 published +- 契约测试:POST /api/v1/content/generate 返回 201 + content_id + 生成内容 + +**Verification:** 从诊断建议到内容生成到发布的完整流程可走通。契约测试通过。 + +--- + +### U6. 效果归因系统与 ROI 报告 + +**Goal:** 建设效果归因系统,追踪已发布内容的 AI 引用变化,生成 ROI 报告证明付费价值。 + +**Requirements:** R7 + +**Dependencies:** U5 + +**Files:** +- `backend/app/services/attribution/attribution_engine.py` — 新建归因引擎 +- `backend/app/services/attribution/roi_calculator.py` — 新建 ROI 计算器 +- `backend/app/api/attribution.py` — 新建归因 API +- `backend/app/models/attribution_record.py` — 新建归因记录模型 +- `backend/app/services/monitoring/monitor_service.py` — 修改集成归因 +- `frontend/app/(dashboard)/dashboard/roi/page.tsx` — 新建 ROI 报告页面 +- `frontend/lib/api/attribution.ts` — 新建归因 API 客户端 +- `backend/tests/test_services/test_attribution.py` — 新建归因测试 + +**Approach:** +1. 归因引擎核心逻辑: + - 内容发布时记录:发布时间、目标关键词、发布前基线引用数据 + - 定期(每 3 天)检查目标关键词的引用变化 + - 归因判定:引用率上升 + 新引用内容与已发布内容语义相关 → 归因为该内容贡献 + - 归因窗口:2-4 周,窗口内持续追踪 +2. ROI 计算器:输入订阅费用+归因引用提升+行业平均引用价值,输出 GEO ROI 百分比 +3. A/B 对比报告:优化前诊断分数 vs 当前诊断分数,按维度对比 +4. 
效果保障:如果 2-4 周内无任何引用提升,提供额外优化建议或延长服务 + +**Patterns to follow:** +- 监控模式参考 `backend/app/services/monitoring/monitor_service.py` +- 定期任务参考 `backend/app/services/detection/detection_scheduler.py` +- 报告生成参考 `backend/app/api/reports.py` + +**Test scenarios:** +- 发布内容后自动创建归因追踪记录 +- 3 天后检查引用变化,记录 delta +- 引用率上升且语义相关,归因为该内容贡献 +- 归因窗口结束后生成 ROI 报告 +- 优化前后 A/B 对比报告展示各维度变化 +- 无引用提升时提供额外优化建议 + +**Verification:** 发布内容→2-4 周后→归因报告展示引用提升和 ROI。 + +--- + +### U7. 邮件集成与 Dashboard 变现 UI + +**Goal:** 将 EmailService 接入业务系统(周报、续费提醒),Dashboard 添加订阅状态、用量进度、ROI 面板。 + +**Requirements:** R7 (邮件部分) + +**Dependencies:** U4, U6 + +**Files:** +- `backend/app/services/email_service.py` — 修改添加新模板 +- `backend/app/services/email/email_scheduler.py` — 新建邮件调度器 +- `backend/app/templates/geo_weekly_report.html` — 新建周报模板 +- `backend/app/templates/renewal_reminder.html` — 新建续费提醒模板 +- `backend/app/templates/trial_expiring.html` — 新建试用期到期模板 +- `frontend/app/(dashboard)/dashboard/page.tsx` — 修改添加变现 UI +- `frontend/components/subscription/SubscriptionStatus.tsx` — 新建订阅状态组件 +- `frontend/components/subscription/UsageProgress.tsx` — 新建用量进度组件 +- `frontend/components/dashboard/ROICard.tsx` — 新建 ROI 卡片组件 + +**Approach:** +1. 邮件模板新增:GEO 变化周报、续费提醒(到期前 7 天/3 天/1 天)、试用期到期提醒、欢迎邮件 +2. 邮件调度器:APScheduler 定时任务,每周一发送周报,每日检查续费提醒 +3. Dashboard 变现 UI:顶部订阅状态栏、用量进度条、ROI 卡片、功能锁定 UI +4. 配置真实 SMTP(支持 SendGrid/阿里云邮件推送) + +**Patterns to follow:** +- 邮件模板参考 `backend/app/services/email_service.py` 现有模板格式 +- Dashboard 组件参考 `frontend/app/(dashboard)/dashboard/page.tsx` 的 KPI 卡片模式 +- 订阅状态参考 `backend/app/services/subscription.py` 的 PLANS 配置 + +**Test scenarios:** +- 每周一自动发送 GEO 变化周报给注册用户 +- 订阅到期前 7 天发送续费提醒 +- 试用期到期前 3 天发送提醒 +- Dashboard 显示当前套餐和用量进度 +- 超出用量时显示升级提示 +- ROI 卡片展示引用率提升和归因内容 + +**Verification:** 注册用户收到周报邮件,Dashboard 显示订阅状态和 ROI 数据。 + +--- + +### U8. 
API 契约集成测试与测试基础设施 + +**Goal:** 编写跨步骤集成测试验证完整链路,完成后端测试基础设施统一。 + +**Requirements:** R12, R15, R16 + +**Dependencies:** U1, U2, U4, U5 + +**Files:** +- `backend/tests/test_integration/test_monetization_flow.py` — 新建变现闭环集成测试 +- `backend/tests/conftest.py` — 增强 fixture 体系 +- `backend/tests/fixtures/auth.py` — 新建认证 fixture 模块 +- `backend/tests/fixtures/database.py` — 新建数据库 fixture 模块 +- `backend/tests/fixtures/client.py` — 新建 httpx 客户端 fixture 模块 +- `backend/tests/fixtures/brands.py` — 新建品牌测试数据 fixture 模块 + +**Approach:** +1. 编写跨步骤集成测试:品牌创建→诊断→方案→内容生成→效果追踪的完整链路 +2. 验证 U1-U5 的 API 契约测试在集成场景下仍然通过 +3. 完成后端测试目录统一验证(原 QA 计划 U1) +4. 建立共享 fixture 体系:按领域拆分到 `backend/tests/fixtures/` 下 + - auth.py:override_get_current_user、auth_token、auth_headers + - database.py:async_engine、async_session(含自动 rollback) + - client.py:async_client(httpx AsyncClient with app) + - brands.py:预创建的测试品牌和竞品数据 +5. conftest.py 通过 pytest plugin 机制自动加载 fixtures/ 下所有模块 + +**Patterns to follow:** +- `backend/tests/conftest.py` 现有 fixture 模式 +- `backend/tests/test_integration/test_business_flow.py` 现有集成测试模式 + +**Test scenarios:** +- 品牌创建→诊断→方案→内容生成→效果追踪全链路通过 +- 数据在步骤间正确传递(品牌 ID、诊断结果 ID、内容 ID) +- 共享 fixture 可被其他测试复用 +- 数据库 fixture 在测试结束后自动 rollback +- 多个测试并行运行时 fixture 互不干扰 + +**Verification:** `cd backend && pytest tests/test_integration/test_monetization_flow.py --tb=short` 通过。`cd backend && pytest tests/ -q` 全部测试被发现且无新增失败。 + +--- + +### U9. E2E 烟雾测试 + +**Goal:** 编写 2 个关键路径的 Playwright E2E 烟雾测试,验证最核心的用户路径端到端可用。 + +**Requirements:** R13, R14 + +**Dependencies:** U2, U3 + +**Files:** +- `frontend/e2e/tests/health-score-smoke.spec.ts` — 新建获客路径烟雾测试 +- `frontend/e2e/tests/core-flow-smoke.spec.ts` — 新建核心流程烟雾测试 + +**Approach:** +1. 获客路径烟雾测试:访问健康分页面→输入品牌名→看到报告→点击注册 +2. 核心流程烟雾测试:登录→创建品牌→触发诊断→查看诊断结果 +3. 使用独立 setup(直接调用 API 创建测试数据),不依赖共享 fixture +4. 
确保截图和视频录制配置正确(screenshot: only-on-failure, video: retain-on-failure) + +**Patterns to follow:** +- `frontend/e2e/tests/login.spec.ts` 现有 E2E 测试模式 +- `frontend/e2e/pages/login.page.ts` Page Object 模式 + +**Test scenarios:** +- 获客路径:未登录→访问 /health-score→输入品牌名→30 秒内看到报告→点击注册→注册弹窗出现 +- 核心流程:登录→创建品牌→品牌出现在 Dashboard→触发诊断→诊断结果非零 +- 测试失败时自动截图保存 + +**Verification:** `cd frontend && npx playwright test health-score-smoke core-flow-smoke` 通过。 + +## Scope Boundaries + +**Deferred for later:** + +- 网站爬取通道(诊断数据采集 V2) +- 全面 E2E 测试覆盖(原 QA 计划 U3/U4/U5 → 上线后逐步扩展) +- 性能基线测试(原 QA 计划 U7 → 上线后有真实负载后建立) +- CI 安全扫描(原 QA 计划 U8 → 上线后集成) +- Alembic 迁移验证(原 QA 计划 U8 → 上线后添加) +- 前端组件测试(原 QA 计划 U10 → 上线后扩展) +- 行业 GEO 基准数据积累 +- API 市场开放和第三方集成 +- 白标报告和专属客户成功经理 +- 多语言和国际 AI 平台支持 +- 知识库自动构建(从品牌官网自动提取 GEO 相关信息结构) +- 批量内容生成(一次生成多平台内容) +- 内容日历和排期执行系统 + +**Outside this product's identity:** + +- 传统 SEO 工具功能(关键词排名、外链分析等) +- 广告投放和营销自动化 +- 社交媒体管理 +- 基础设施级别的渗透测试 + +**Deferred to follow-up work:** + +- E2E 测试迁移到共享 fixture(U9 完成后,将独立 setup 替换为 U8 的共享 fixture) +- 性能基线设定阈值和告警(上线后根据数据设定合理阈值) +- 跨浏览器 E2E 扩展(稳定 Chromium 后再扩展) +- 知乎/头条 API OAuth 授权完整流程 +- 批量内容生成 + +## Open Questions + +- 微信公众号 OAuth 授权流程的具体实现方案需在 U5 实施时确认 +- 知乎和头条 API 的申请和审核周期可能影响 U5 的交付时间 +- 诊断自动数据采集的品牌官网解析可能遇到反爬机制,需准备降级方案 +- 真实 SMTP 配置需要用户提供邮件服务商账号信息 +- 效果归因的具体技术实现方案(基于引用检测还是基于评分变化) +- 付费版定价是否需要调整(当前 ¥199/599/1999 是否匹配价值感知) + +## Risks & Dependencies + +| Risk | Impact | Mitigation | +|------|--------|------------| +| 诊断数据采集不准确 | 用户不信任健康分,获客失效 | 先用 AI 平台查询数据(已有适配器),品牌官网解析作为增强 | +| 微信/知乎/头条 API 审核不通过 | 分发集成延期 | 微信降级为半自动,知乎/头条备选手动发布 | +| 支付 SDK 集成复杂度超预期 | 付费闭环延期 | 先实现支付宝(文档更友好),微信支付后续迭代 | +| GEO 效果归因窗口内无可见提升 | 用户认为产品无效 | 提供效果保障机制,延长归因窗口或提供额外优化 | +| 免费诊断 API 调用成本过高 | 运营成本失控 | 24 小时缓存+频率限制+深度限制 | +| LLM API Key 不可用影响内容生成 | 核心付费功能不可用 | 支持多 LLM Provider 降级,至少保证一个可用 | + +## Sources & Research + +- Origin requirements: `docs/brainstorms/2026-05-31-geo-next-phase-core-flow-repair-requirements.md` +- Secondary origin: 
`docs/brainstorms/2026-05-31-geo-platform-monetization-closed-loop-requirements.md` +- QA plan (downscoped): `docs/plans/2026-05-31-002-test-quality-assurance-system-plan.md` +- Existing diagnosis system: `backend/app/services/diagnosis/geo_diagnosis.py` +- Diagnosis API root cause: `backend/app/api/diagnosis.py` line 75 (`GEODiagnosisInput()`) +- Existing content pipeline: `backend/app/services/content/content_pipeline.py` +- Existing distribution rules: `backend/app/services/distribution/platform_rules.py` +- Existing monitoring: `backend/app/services/monitoring/monitor_service.py` +- Existing subscription: `backend/app/services/subscription.py` +- Existing email: `backend/app/services/email_service.py` +- Existing E2E tests: `frontend/e2e/tests/` diff --git a/frontend/__tests__/hooks/use-compare-data.test.ts b/frontend/__tests__/hooks/use-compare-data.test.ts new file mode 100644 index 0000000..e7168b5 --- /dev/null +++ b/frontend/__tests__/hooks/use-compare-data.test.ts @@ -0,0 +1,192 @@ +/** + * useCompareData Hook 单元测试 + */ + +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { renderHook, waitFor, act } from "@testing-library/react"; +import { SWRConfig } from "swr"; +import type { ReactNode } from "react"; + +vi.mock("@/lib/api/client", () => ({ + fetchWithAuth: vi.fn(), +})); + +import { fetchWithAuth } from "@/lib/api/client"; +const mockFetchWithAuth = vi.mocked(fetchWithAuth); + +vi.mock("next-auth/react", () => ({ + useSession: vi.fn(() => ({ + data: { accessToken: "test-token" }, + status: "authenticated", + })), +})); + +import { useCompareData } from "@/lib/hooks/use-compare-data"; +import type { BrandListResponse, CompareResponse } from "@/types/brand"; + +const noRetryOptions = { shouldRetryOnError: false }; + +/** 使用独立 SWR 缓存的 wrapper,避免测试间缓存冲突 */ +function createWrapper() { + return ({ children }: { children: ReactNode }) => + SWRConfig({ + value: { provider: () => new Map() }, + children, + }); +} + 
+describe("useCompareData", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + const mockBrandsResponse: BrandListResponse = { + items: [ + { + id: "brand-1", + name: "测试品牌", + aliases: [], + platforms: ["wenxin"], + frequency: "weekly", + status: "active", + score: 80, + last_queried_at: null, + next_query_at: null, + created_at: "2024-01-01", + }, + ], + total: 1, + }; + + const mockCompareResponse: CompareResponse = { + brand_id: "brand-1", + brand_name: "测试品牌", + items: [ + { + entity_id: "brand-1", + entity_name: "测试品牌", + entity_type: "brand", + mention_rate_score: 70, + sov_score: 65, + quality_score: 80, + overall_score: 75, + citation_count: 100, + dimensions: [], + overall_trend: "up", + overall_trend_value: 5, + }, + ], + radar_data: [], + }; + + it("应返回品牌列表数据", async () => { + mockFetchWithAuth.mockImplementation((url: string) => { + if (url.includes("/api/v1/brands")) return Promise.resolve(mockBrandsResponse); + return Promise.resolve({}); + }); + + const { result } = renderHook( + () => useCompareData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.brands).toEqual(mockBrandsResponse.items); + }); + }); + + it("selectedBrandId 为空时应暂停对比数据请求", () => { + mockFetchWithAuth.mockResolvedValue(mockBrandsResponse); + + const { result } = renderHook( + () => useCompareData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + expect(result.current.compareData).toBeUndefined(); + }); + + it("selectedBrandId 存在时应获取对比数据", async () => { + mockFetchWithAuth.mockImplementation((url: string) => { + if (url.includes("/compare")) return Promise.resolve(mockCompareResponse); + return Promise.resolve(mockBrandsResponse); + }); + + const { result } = renderHook( + () => useCompareData({ initialBrandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.compareData).toEqual(mockCompareResponse); + 
}); + }); + + it("应正确暴露 loading 状态", async () => { + mockFetchWithAuth.mockImplementation((url: string) => { + if (url.includes("/compare")) return Promise.resolve(mockCompareResponse); + return Promise.resolve(mockBrandsResponse); + }); + + const { result } = renderHook( + () => useCompareData({ initialBrandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + }); + + it("应正确暴露 error 状态", async () => { + mockFetchWithAuth.mockRejectedValue(new Error("网络错误")); + + const { result } = renderHook( + () => useCompareData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.error).toBeDefined(); + expect(result.current.error?.message).toBe("网络错误"); + }); + }); + + it("setSelectedBrandId 应更新选中品牌", async () => { + mockFetchWithAuth.mockImplementation((url: string) => { + if (url.includes("/compare")) return Promise.resolve(mockCompareResponse); + return Promise.resolve(mockBrandsResponse); + }); + + const { result } = renderHook( + () => useCompareData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.brands.length).toBeGreaterThan(0); + }); + + act(() => { + result.current.setSelectedBrandId("brand-1"); + }); + + expect(result.current.selectedBrandId).toBe("brand-1"); + }); + + it("应提供 refresh 方法刷新数据", async () => { + mockFetchWithAuth.mockResolvedValue(mockBrandsResponse); + + const { result } = renderHook( + () => useCompareData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + + expect(typeof result.current.refreshBrands).toBe("function"); + expect(typeof result.current.refreshCompare).toBe("function"); + }); +}); diff --git a/frontend/__tests__/hooks/use-content-data.test.ts b/frontend/__tests__/hooks/use-content-data.test.ts 
new file mode 100644 index 0000000..1ae6b1c --- /dev/null +++ b/frontend/__tests__/hooks/use-content-data.test.ts @@ -0,0 +1,141 @@ +/** + * useContentData Hook 单元测试 + */ + +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { renderHook, waitFor } from "@testing-library/react"; +import { SWRConfig } from "swr"; +import type { ReactNode } from "react"; + +vi.mock("@/lib/api/client", () => ({ + fetchWithAuth: vi.fn(), +})); + +import { fetchWithAuth } from "@/lib/api/client"; +const mockFetchWithAuth = vi.mocked(fetchWithAuth); + +import { useContentData } from "@/lib/hooks/use-content-data"; +import type { Content } from "@/lib/api/contents"; +import type { KnowledgeBase } from "@/lib/api/knowledge"; + +const noRetryOptions = { shouldRetryOnError: false }; + +function createWrapper() { + return ({ children }: { children: ReactNode }) => + SWRConfig({ value: { provider: () => new Map() }, children }); +} + +describe("useContentData", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + const mockContents: Content[] = [ + { + id: "content-1", + title: "测试内容", + body: "内容正文", + content_type: "article", + status: "draft", + author_id: "user-1", + tags: ["test"], + created_at: "2024-01-01", + updated_at: "2024-01-01", + }, + ]; + + const mockKnowledgeBases: KnowledgeBase[] = [ + { + id: "kb-1", + name: "测试知识库", + type: "enterprise", + document_count: 5, + status: "ready", + created_at: "2024-01-01", + }, + ]; + + it("应同时获取内容列表和知识库列表", async () => { + mockFetchWithAuth.mockImplementation((url: string) => { + if (url.includes("/api/v1/contents/")) return Promise.resolve(mockContents); + if (url.includes("/api/v1/knowledge/bases")) return Promise.resolve(mockKnowledgeBases); + return Promise.resolve([]); + }); + + const { result } = renderHook( + () => useContentData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + + 
expect(result.current.contents).toEqual(mockContents); + expect(result.current.knowledgeBases).toEqual(mockKnowledgeBases); + }); + + it("应正确暴露 isLoading 状态", async () => { + mockFetchWithAuth.mockImplementation((url: string) => { + if (url.includes("/api/v1/contents/")) return Promise.resolve(mockContents); + if (url.includes("/api/v1/knowledge/bases")) return Promise.resolve(mockKnowledgeBases); + return Promise.resolve([]); + }); + + const { result } = renderHook( + () => useContentData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + }); + + it("应正确暴露 error 状态", async () => { + mockFetchWithAuth.mockRejectedValue(new Error("数据加载失败")); + + const { result } = renderHook( + () => useContentData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.error).toBeDefined(); + expect(result.current.error?.message).toBe("数据加载失败"); + }); + }); + + it("应提供 refresh 方法", async () => { + mockFetchWithAuth.mockImplementation((url: string) => { + if (url.includes("/api/v1/contents/")) return Promise.resolve(mockContents); + if (url.includes("/api/v1/knowledge/bases")) return Promise.resolve(mockKnowledgeBases); + return Promise.resolve([]); + }); + + const { result } = renderHook( + () => useContentData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + + expect(typeof result.current.refreshContents).toBe("function"); + expect(typeof result.current.refreshKnowledgeBases).toBe("function"); + }); + + it("数据加载失败时 error 应包含错误信息", async () => { + mockFetchWithAuth.mockRejectedValue(new Error("网络异常")); + + const { result } = renderHook( + () => useContentData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.error?.message).toBe("网络异常"); + }); + }); +}); 
diff --git a/frontend/__tests__/hooks/use-onboarding-data.test.ts b/frontend/__tests__/hooks/use-onboarding-data.test.ts new file mode 100644 index 0000000..138f146 --- /dev/null +++ b/frontend/__tests__/hooks/use-onboarding-data.test.ts @@ -0,0 +1,174 @@ +/** + * useOnboardingData Hook 单元测试 + */ + +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { renderHook, waitFor, act } from "@testing-library/react"; +import { SWRConfig } from "swr"; +import type { ReactNode } from "react"; + +vi.mock("@/lib/api/client", () => ({ + fetchWithAuth: vi.fn(), +})); + +import { fetchWithAuth } from "@/lib/api/client"; +const mockFetchWithAuth = vi.mocked(fetchWithAuth); + +vi.mock("next-auth/react", () => ({ + useSession: vi.fn(() => ({ + data: { accessToken: "test-token" }, + status: "authenticated", + })), +})); + +import { useOnboardingData } from "@/lib/hooks/use-onboarding-data"; + +const noRetryOptions = { shouldRetryOnError: false }; + +function createWrapper() { + return ({ children }: { children: ReactNode }) => + SWRConfig({ value: { provider: () => new Map() }, children }); +} + +describe("useOnboardingData", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("应检查引导状态", async () => { + mockFetchWithAuth.mockResolvedValue({ completed: false }); + + const { result } = renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.onboardingStatus).toEqual({ completed: false }); + }); + }); + + it("引导已完成时应标记 isCompleted", async () => { + mockFetchWithAuth.mockResolvedValue({ completed: true }); + + const { result } = renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.onboardingStatus).toEqual({ completed: true }); + expect(result.current.isCompleted).toBe(true); + }); + }); + + it("应正确暴露 isLoading 状态", async () => { + 
mockFetchWithAuth.mockResolvedValue({ completed: false }); + + const { result } = renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + }); + + it("应正确暴露 error 状态", async () => { + mockFetchWithAuth.mockRejectedValue(new Error("检查引导状态失败")); + + const { result } = renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.error).toBeDefined(); + expect(result.current.error?.message).toBe("检查引导状态失败"); + }); + }); + + it("应提供 createBrand mutation 方法", () => { + mockFetchWithAuth.mockResolvedValue({ completed: false }); + + const { result } = renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + expect(typeof result.current.createBrand).toBe("function"); + }); + + it("createBrand 成功应返回 brand_id", async () => { + mockFetchWithAuth.mockImplementation((url: string, options?: RequestInit) => { + if (options?.method === "POST") return Promise.resolve({ brand_id: "new-brand-1" }); + return Promise.resolve({ completed: false }); + }); + + const { result } = renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + + let brandId: string | null = null; + await act(async () => { + brandId = await result.current.createBrand({ + name: "新品牌", + competitors: [], + platforms: ["wenxin"], + frequency: "weekly", + }); + }); + + expect(brandId).toBe("new-brand-1"); + }); + + it("createBrand 失败应返回 null 并设置 mutationError", async () => { + mockFetchWithAuth.mockImplementation((url: string, options?: RequestInit) => { + if (options?.method === "POST") return Promise.reject(new Error("创建品牌失败")); + return Promise.resolve({ completed: false }); + }); + + const { result } = 
renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + + let brandId: string | null = null; + await act(async () => { + brandId = await result.current.createBrand({ + name: "新品牌", + competitors: [], + platforms: ["wenxin"], + frequency: "weekly", + }); + }); + + expect(brandId).toBeNull(); + expect(result.current.mutationError).toBeInstanceOf(Error); + }); + + it("应提供 refresh 方法", async () => { + mockFetchWithAuth.mockResolvedValue({ completed: false }); + + const { result } = renderHook( + () => useOnboardingData({ swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + + expect(typeof result.current.refresh).toBe("function"); + }); +}); diff --git a/frontend/__tests__/hooks/use-suggestions-data.test.ts b/frontend/__tests__/hooks/use-suggestions-data.test.ts new file mode 100644 index 0000000..b589d26 --- /dev/null +++ b/frontend/__tests__/hooks/use-suggestions-data.test.ts @@ -0,0 +1,167 @@ +/** + * useSuggestionsData Hook 单元测试 + */ + +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { renderHook, waitFor } from "@testing-library/react"; +import { SWRConfig } from "swr"; +import type { ReactNode } from "react"; + +vi.mock("@/lib/api/client", () => ({ + fetchWithAuth: vi.fn(), +})); + +import { fetchWithAuth } from "@/lib/api/client"; +const mockFetchWithAuth = vi.mocked(fetchWithAuth); + +vi.mock("next-auth/react", () => ({ + useSession: vi.fn(() => ({ + data: { accessToken: "test-token" }, + status: "authenticated", + })), +})); + +import { useSuggestionsData } from "@/lib/hooks/use-suggestions-data"; +import type { SuggestionListResponse } from "@/types/suggestion"; + +const noRetryOptions = { shouldRetryOnError: false }; + +function createWrapper() { + return ({ children }: { children: ReactNode }) => + SWRConfig({ 
value: { provider: () => new Map() }, children }); +} + +describe("useSuggestionsData", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + const mockSuggestionsResponse: SuggestionListResponse = { + suggestions: [ + { + id: "sug-1", + brand_id: "brand-1", + type: "content_optimization", + priority: "high", + title: "优化内容", + description: "建议优化内容", + action: null, + expected_impact: null, + difficulty: "easy", + status: "pending", + generated_at: "2024-01-01", + updated_at: "2024-01-01", + batch_id: "batch-1", + source: "rule", + }, + ], + total: 1, + }; + + it("brandId 为空时应暂停建议请求", () => { + mockFetchWithAuth.mockResolvedValue({}); + + const { result } = renderHook(() => useSuggestionsData(), { wrapper: createWrapper() }); + + expect(result.current.suggestions).toBeUndefined(); + }); + + it("brandId 存在时应获取建议列表", async () => { + mockFetchWithAuth.mockResolvedValue(mockSuggestionsResponse); + + const { result } = renderHook( + () => useSuggestionsData({ brandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.suggestions).toEqual(mockSuggestionsResponse.suggestions); + }); + }); + + it("应支持筛选参数", async () => { + mockFetchWithAuth.mockResolvedValue(mockSuggestionsResponse); + + const { result } = renderHook( + () => useSuggestionsData({ + brandId: "brand-1", + filters: { type: "content_optimization", priority: "high", status: "pending" }, + swrOptions: noRetryOptions, + }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.suggestions).toEqual(mockSuggestionsResponse.suggestions); + }); + + const suggestionCall = mockFetchWithAuth.mock.calls.find( + (call) => typeof call[0] === "string" && call[0].includes("/suggestions") + ); + expect(suggestionCall).toBeDefined(); + const calledUrl = suggestionCall![0] as string; + expect(calledUrl).toContain("type=content_optimization"); + expect(calledUrl).toContain("priority=high"); + 
expect(calledUrl).toContain("status=pending"); + }); + + it("应正确暴露 isLoading 状态", async () => { + mockFetchWithAuth.mockResolvedValue(mockSuggestionsResponse); + + const { result } = renderHook( + () => useSuggestionsData({ brandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + }); + }); + + it("应正确暴露 error 状态", async () => { + mockFetchWithAuth.mockRejectedValue(new Error("加载建议失败")); + + const { result } = renderHook( + () => useSuggestionsData({ brandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + await waitFor(() => { + expect(result.current.error).toBeDefined(); + expect(result.current.error?.message).toBe("加载建议失败"); + }); + }); + + it("应提供 regenerate 方法", () => { + mockFetchWithAuth.mockResolvedValue(mockSuggestionsResponse); + + const { result } = renderHook( + () => useSuggestionsData({ brandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + expect(typeof result.current.regenerate).toBe("function"); + }); + + it("应提供 updateStatus 方法", () => { + mockFetchWithAuth.mockResolvedValue(mockSuggestionsResponse); + + const { result } = renderHook( + () => useSuggestionsData({ brandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + expect(typeof result.current.updateStatus).toBe("function"); + }); + + it("应提供 refresh 方法", () => { + mockFetchWithAuth.mockResolvedValue(mockSuggestionsResponse); + + const { result } = renderHook( + () => useSuggestionsData({ brandId: "brand-1", swrOptions: noRetryOptions }), + { wrapper: createWrapper() } + ); + + expect(typeof result.current.refresh).toBe("function"); + }); +}); diff --git a/frontend/__tests__/lib/action-suggestions.test.ts b/frontend/__tests__/lib/action-suggestions.test.ts new file mode 100644 index 0000000..04de5c3 --- /dev/null +++ b/frontend/__tests__/lib/action-suggestions.test.ts @@ -0,0 +1,157 @@ 
+/** + * 行动建议系统统一测试 + * 验证 ActionSuggestion 和 NextAction 使用统一的 ActionItem 类型 + */ + +import { describe, it, expect } from "vitest"; +import { generateActionSuggestions } from "@/lib/dashboard-health"; +import { generateNextActions } from "@/lib/next-action"; +import type { ActionItem, ActionPriority } from "@/types/suggestion"; +import type { ActionSuggestion } from "@/types/dashboard-health"; +import type { NextAction } from "@/types/next-action"; + +describe("ActionItem 统一类型", () => { + it("ActionItem 类型应包含所有必要字段", async () => { + const suggestionModule = await import("@/types/suggestion"); + // 验证类型可以被导入(编译时检查) + expect(suggestionModule).toBeDefined(); + }); + + it("ActionPriority 应为 'primary' | 'secondary' | 'optional'", async () => { + const suggestionModule = await import("@/types/suggestion"); + // ActionPriority 是类型,运行时不可直接检查,但我们可以验证模块导出 + expect(suggestionModule).toBeDefined(); + }); +}); + +describe("ActionSuggestion 兼容 ActionItem", () => { + it("generateActionSuggestions 返回的对象应兼容 ActionItem 结构", () => { + const stats = { + platformScores: [ + { + platform: "kimi", + score: 30, + competitor_score: 60, + competitor_name: "竞品A", + }, + ], + overallScore: 30, + hasQueries: true, + }; + + const suggestions = generateActionSuggestions(stats); + + expect(suggestions.length).toBeGreaterThan(0); + + for (const suggestion of suggestions) { + // ActionItem 必需字段 + expect(suggestion).toHaveProperty("id"); + expect(suggestion).toHaveProperty("title"); + expect(suggestion).toHaveProperty("description"); + expect(suggestion).toHaveProperty("icon"); + expect(suggestion).toHaveProperty("href"); + + // type 字段应兼容 ActionPriority + const validPriorities: ActionPriority[] = [ + "primary", + "secondary", + "optional", + ]; + expect(validPriorities).toContain(suggestion.type); + } + }); +}); + +describe("NextAction 兼容 ActionItem", () => { + it("generateNextActions 返回的对象应兼容 ActionItem 结构", () => { + const context = { + hasData: true, + hasBrands: true, + brandCount: 1, + 
overallScore: 50, + scoreChange: -5, + competitorCount: 2, + hasQueryHistory: true, + currentPage: "dashboard" as const, + }; + + const actions = generateNextActions(context); + + expect(actions.length).toBeGreaterThan(0); + + for (const action of actions) { + // ActionItem 必需字段 + expect(action).toHaveProperty("id"); + expect(action).toHaveProperty("title"); + expect(action).toHaveProperty("description"); + expect(action).toHaveProperty("icon"); + + // NextAction 应有 href 字段(映射自 actionUrl) + expect(action).toHaveProperty("href"); + + // priority 字段应兼容 ActionPriority + const validPriorities: ActionPriority[] = [ + "primary", + "secondary", + "optional", + ]; + expect(validPriorities).toContain(action.priority); + } + }); + + it("NextAction 的 href 应与 actionUrl 一致", () => { + const context = { + hasData: false, + hasBrands: false, + brandCount: 0, + overallScore: 0, + scoreChange: 0, + competitorCount: 0, + hasQueryHistory: false, + currentPage: "dashboard" as const, + }; + + const actions = generateNextActions(context); + + for (const action of actions) { + // href 应该存在且非空 + expect(action.href).toBeTruthy(); + expect(typeof action.href).toBe("string"); + } + }); +}); + +describe("两个系统生成结果的一致性", () => { + it("ActionSuggestion 和 NextAction 都能生成有效的行动建议", () => { + // dashboard-health 的 generateActionSuggestions + const dashboardSuggestions = generateActionSuggestions({ + platformScores: [ + { platform: "kimi", score: 30, competitor_score: 60 }, + ], + overallScore: 30, + hasQueries: true, + }); + + // next-action 的 generateNextActions + const nextActions = generateNextActions({ + hasData: true, + hasBrands: true, + brandCount: 1, + overallScore: 30, + scoreChange: -5, + competitorCount: 1, + hasQueryHistory: true, + currentPage: "dashboard", + }); + + // 两个系统都应该能生成建议 + expect(dashboardSuggestions.length).toBeGreaterThan(0); + expect(nextActions.length).toBeGreaterThan(0); + + // 所有建议都应有唯一 id + const dashboardIds = dashboardSuggestions.map((s) => s.id); + const 
nextActionIds = nextActions.map((a) => a.id); + expect(new Set(dashboardIds).size).toBe(dashboardIds.length); + expect(new Set(nextActionIds).size).toBe(nextActionIds.length); + }); +}); diff --git a/frontend/__tests__/lib/platforms.test.ts b/frontend/__tests__/lib/platforms.test.ts new file mode 100644 index 0000000..f97db4c --- /dev/null +++ b/frontend/__tests__/lib/platforms.test.ts @@ -0,0 +1,71 @@ +/** + * 平台映射统一测试 + * 验证 platforms.ts 作为唯一真实来源 + */ + +import { describe, it, expect } from "vitest"; +import { PLATFORM_MAP, PLATFORMS } from "@/lib/platforms"; + +describe("PLATFORM_MAP 统一平台映射", () => { + const EXPECTED_PLATFORMS = [ + "wenxin", + "kimi", + "tongyi", + "baidu_ai", + "yuanbao", + "qingyan", + "doubao", + "tiangong", + "xinghuo", + ] as const; + + it("PLATFORM_MAP 应包含全部9个平台", () => { + const keys = Object.keys(PLATFORM_MAP); + expect(keys).toHaveLength(9); + }); + + it("PLATFORM_MAP 应包含所有预期的平台键", () => { + for (const platform of EXPECTED_PLATFORMS) { + expect(PLATFORM_MAP).toHaveProperty(platform); + } + }); + + it("PLATFORM_MAP 必须包含 baidu_ai 和 yuanbao(dashboard-health.ts 的 PLATFORM_LABELS 缺失的)", () => { + expect(PLATFORM_MAP).toHaveProperty("baidu_ai"); + expect(PLATFORM_MAP).toHaveProperty("yuanbao"); + expect(PLATFORM_MAP["baidu_ai"]).toBe("百度AI搜索"); + expect(PLATFORM_MAP["yuanbao"]).toBe("腾讯元宝"); + }); + + it("每个平台的 label 应为非空字符串", () => { + for (const [key, label] of Object.entries(PLATFORM_MAP)) { + expect(label).toBeTruthy(); + expect(typeof label).toBe("string"); + expect(label.length).toBeGreaterThan(0); + } + }); + + it("PLATFORMS 数组应与 PLATFORM_MAP 键一致", () => { + const platformKeys = PLATFORMS.map((p) => p.key); + const mapKeys = Object.keys(PLATFORM_MAP); + expect(platformKeys.sort()).toEqual(mapKeys.sort()); + }); + + it("PLATFORMS 数组中每个条目的 label 应与 PLATFORM_MAP 一致", () => { + for (const platform of PLATFORMS) { + expect(platform.label).toBe(PLATFORM_MAP[platform.key]); + } + }); +}); + +describe("dashboard-health.ts 不再定义自己的 
PLATFORM_LABELS", () => { + it("dashboard-health.ts 不应导出 PLATFORM_LABELS", async () => { + const dashboardHealthModule = await import("@/types/dashboard-health"); + expect("PLATFORM_LABELS" in dashboardHealthModule).toBe(false); + }); + + it("dashboard-health.ts 不应导出 PLATFORM_ICONS", async () => { + const dashboardHealthModule = await import("@/types/dashboard-health"); + expect("PLATFORM_ICONS" in dashboardHealthModule).toBe(false); + }); +}); diff --git a/frontend/__tests__/types/health-level.test.ts b/frontend/__tests__/types/health-level.test.ts new file mode 100644 index 0000000..d2fabed --- /dev/null +++ b/frontend/__tests__/types/health-level.test.ts @@ -0,0 +1,133 @@ +/** + * HealthLevel 类型和配置的统一测试 + * 验证 dashboard-health.ts 作为唯一真实来源 + */ + +import { describe, it, expect } from "vitest"; +import { + HealthLevel, + HEALTH_LEVEL_CONFIG, +} from "@/types/dashboard-health"; +import { getHealthLevel } from "@/lib/dashboard-health"; + +describe("HealthLevel 类型统一", () => { + it("HealthLevel 应包含正确的4个值:excellent, good, pass, danger", () => { + const expectedLevels: HealthLevel[] = [ + "excellent", + "good", + "pass", + "danger", + ]; + // 验证类型推断正确 - 如果 HealthLevel 包含 "fair" 则编译失败 + const actualLevels = Object.keys(HEALTH_LEVEL_CONFIG) as HealthLevel[]; + expect(actualLevels.sort()).toEqual(expectedLevels.sort()); + }); + + it("HEALTH_LEVEL_CONFIG 应包含所有4个等级的配置", () => { + const keys = Object.keys(HEALTH_LEVEL_CONFIG); + expect(keys).toContain("excellent"); + expect(keys).toContain("good"); + expect(keys).toContain("pass"); + expect(keys).toContain("danger"); + expect(keys).toHaveLength(4); + }); + + it("HEALTH_LEVEL_CONFIG 不应包含 'fair' 键", () => { + const keys = Object.keys(HEALTH_LEVEL_CONFIG); + expect(keys).not.toContain("fair"); + }); + + it("每个 HEALTH_LEVEL_CONFIG 条目应包含完整的配置字段", () => { + for (const [key, config] of Object.entries(HEALTH_LEVEL_CONFIG)) { + expect(config).toHaveProperty("level", key); + expect(config).toHaveProperty("label"); + 
expect(config).toHaveProperty("icon"); + expect(config).toHaveProperty("color"); + expect(config).toHaveProperty("minScore"); + expect(config).toHaveProperty("maxScore"); + expect(config.color).toHaveProperty("bg"); + expect(config.color).toHaveProperty("text"); + expect(config.color).toHaveProperty("border"); + expect(config.label.length).toBeGreaterThan(0); + } + }); +}); + +describe("getHealthLevel 函数", () => { + it("评分 85 应返回 'excellent'", () => { + expect(getHealthLevel(85)).toBe("excellent"); + }); + + it("评分 65 应返回 'good'", () => { + expect(getHealthLevel(65)).toBe("good"); + }); + + it("评分 45 应返回 'pass'", () => { + expect(getHealthLevel(45)).toBe("pass"); + }); + + it("评分 20 应返回 'danger'", () => { + expect(getHealthLevel(20)).toBe("danger"); + }); + + it("getHealthLevel 不应返回 'fair'", () => { + // 遍历所有可能的分数,确保永远不会返回 "fair" + for (let score = 0; score <= 100; score++) { + const result = getHealthLevel(score); + expect(result).not.toBe("fair"); + } + }); + + it("边界值测试:80 为 excellent 下界", () => { + expect(getHealthLevel(80)).toBe("excellent"); + }); + + it("边界值测试:79 为 good 上界", () => { + expect(getHealthLevel(79)).toBe("good"); + }); + + it("边界值测试:60 为 good 下界", () => { + expect(getHealthLevel(60)).toBe("good"); + }); + + it("边界值测试:59 为 pass 上界", () => { + expect(getHealthLevel(59)).toBe("pass"); + }); + + it("边界值测试:40 为 pass 下界", () => { + expect(getHealthLevel(40)).toBe("pass"); + }); + + it("边界值测试:39 为 danger 上界", () => { + expect(getHealthLevel(39)).toBe("danger"); + }); + + it("边界值测试:0 为 danger", () => { + expect(getHealthLevel(0)).toBe("danger"); + }); + + it("边界值测试:100 为 excellent", () => { + expect(getHealthLevel(100)).toBe("excellent"); + }); +}); + +describe("onboarding.ts 不再定义自己的 HealthLevel", () => { + it("onboarding.ts 应从 dashboard-health 导入 HealthLevel 而非自己定义", async () => { + // 读取 onboarding.ts 源码,验证不再有独立的 HealthLevel 定义 + const onboardingModule = await import("@/types/onboarding"); + // onboarding 导出的 HealthLevel 应该与 dashboard-health 的完全一致 + 
const dashboardModule = await import("@/types/dashboard-health"); + + // 验证 onboarding 导出的 getHealthLevel 返回 "pass" 而非 "fair" + if ("getHealthLevel" in onboardingModule) { + const onboardingGetHealthLevel = onboardingModule.getHealthLevel as ( + score: number, + ) => string; + expect(onboardingGetHealthLevel(45)).toBe("pass"); + expect(onboardingGetHealthLevel(45)).not.toBe("fair"); + } + + // 验证 onboarding 导出的 HEALTH_LEVELS 是 dashboard-health 的 HEALTH_LEVEL_CONFIG 的别名 + expect(onboardingModule.HEALTH_LEVELS).toBe(dashboardModule.HEALTH_LEVEL_CONFIG); + }); +}); diff --git a/frontend/app/(dashboard)/brands/[id]/page.tsx b/frontend/app/(dashboard)/brands/[id]/page.tsx index 079a5c5..cc36752 100644 --- a/frontend/app/(dashboard)/brands/[id]/page.tsx +++ b/frontend/app/(dashboard)/brands/[id]/page.tsx @@ -16,7 +16,13 @@ import type { CreateCompetitorRequest, CompetitorRecommendationItem, CompetitorRecommendationResponse, + BrandScoreV2Response, } from "@/types/brand"; +import { + HEALTH_LEVEL_CONFIG, + DIMENSION_COLORS, +} from "@/types/dashboard-health"; +import { getHealthLevelClassName } from "@/lib/dashboard-health"; import { ArrowLeft, Star, @@ -75,6 +81,9 @@ export default function BrandDetailPage() { const [loadingRecommendations, setLoadingRecommendations] = useState(false); const [showRecommendations, setShowRecommendations] = useState(false); + const [scoreV2, setScoreV2] = useState(null); + const [scoreLoading, setScoreLoading] = useState(false); + const fetchBrandDetail = useCallback(async () => { const token = session?.accessToken; if (!token || !brandId) return; @@ -94,6 +103,23 @@ export default function BrandDetailPage() { fetchBrandDetail(); }, [fetchBrandDetail]); + useEffect(() => { + const fetchScore = async () => { + const token = session?.accessToken; + if (!token || !brandId) return; + try { + setScoreLoading(true); + const data = await api.brands.getScore(token, brandId); + setScoreV2(data as BrandScoreV2Response); + } catch { + // 评分获取失败不阻塞页面 
+ } finally { + setScoreLoading(false); + } + }; + fetchScore(); + }, [session?.accessToken, brandId]); + const handleQueryNow = async () => { if (!session?.accessToken || !brand) return; @@ -435,6 +461,66 @@ export default function BrandDetailPage() { )} + {/* GEO 评分详情 */} + + + GEO 评分详情 + {scoreV2 && ( + + {HEALTH_LEVEL_CONFIG[scoreV2.health_level].label} + + )} + + + {scoreLoading ? ( +
+ +
+ ) : scoreV2 ? ( +
+
+ + {scoreV2.overall_score} + + + {HEALTH_LEVEL_CONFIG[scoreV2.health_level].label} + +
+
+ {[ + scoreV2.mention_rate, + scoreV2.recommendation_rank, + scoreV2.sentiment_score, + scoreV2.citation_quality, + scoreV2.competitive_position, + ].map((dim) => ( +
+
+ {dim.name} + + {dim.score}/{dim.max_score} | {dim.percentage}% + +
+
+
+
+
+ ))} +
+
+ ) : ( +
+

暂无评分数据,请先执行查询

+
+ )} + + + {/* 监控平台 */} diff --git a/frontend/app/(dashboard)/compare/page.tsx b/frontend/app/(dashboard)/compare/page.tsx index 435fa9c..c086227 100644 --- a/frontend/app/(dashboard)/compare/page.tsx +++ b/frontend/app/(dashboard)/compare/page.tsx @@ -1,12 +1,10 @@ "use client"; -import { useEffect, useState } from "react"; -import { useSession } from "next-auth/react"; +import { useEffect } from "react"; import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; import { Skeleton } from "@/components/ui/skeleton"; import { ScoreCard } from "@/components/dashboard/ScoreCard"; import { CompetitorRadarChart } from "@/components/charts/CompetitorRadarChart"; -import { api } from "@/lib/api"; import { Star, TrendingUp, TrendingDown, Minus, BarChart3 } from "lucide-react"; import { Select, @@ -23,14 +21,10 @@ import { TableHeader, TableRow, } from "@/components/ui/table"; -import { - BrandListItem, - BrandListResponse, - CompareResponse, - CompareItem, -} from "@/types/brand"; +import { CompareItem } from "@/types/brand"; import { NextActionCard } from "@/components/dashboard/NextActionCard"; import type { ActionContext } from "@/types/next-action"; +import { useCompareData } from "@/lib/hooks/use-compare-data"; // 趋势图标组件 function TrendIcon({ trend, value }: { trend: string; value: number }) { @@ -75,61 +69,23 @@ const METRICS = [ ] as const; export default function ComparePage() { - const { data: session } = useSession(); - const [brands, setBrands] = useState([]); - const [selectedBrandId, setSelectedBrandId] = useState(""); - const [compareData, setCompareData] = useState(null); - const [loading, setLoading] = useState(true); - const [error, setError] = useState(null); + const { + brands, + selectedBrandId, + setSelectedBrandId, + compareData, + isLoading: loading, + error: apiError, + } = useCompareData(); - // 加载品牌列表 + // 品牌列表加载完成后自动选择第一个品牌 useEffect(() => { - const token = session?.accessToken; - if (!token) return; - - async function 
loadBrands() { - if (!token) return; - try { - setLoading(true); - const response = (await api.brands.list(token)) as BrandListResponse; - const brandItems = response.items || []; - setBrands(brandItems); - if (brandItems.length > 0 && !selectedBrandId) { - setSelectedBrandId(brandItems[0].id); - } - } catch (err) { - setError(err instanceof Error ? err.message : "加载品牌列表失败"); - } finally { - setLoading(false); - } + if (brands.length > 0 && !selectedBrandId) { + setSelectedBrandId(brands[0].id); } + }, [brands, selectedBrandId, setSelectedBrandId]); - loadBrands(); - }, [session?.accessToken, selectedBrandId]); - - // 加载对比数据 - useEffect(() => { - const token = session?.accessToken; - if (!token || !selectedBrandId) return; - - async function loadCompareData() { - if (!token || !selectedBrandId) return; - try { - setLoading(true); - const data = (await api.brands.getCompare( - token, - selectedBrandId, - )) as CompareResponse; - setCompareData(data); - } catch (err) { - setError(err instanceof Error ? err.message : "加载对比数据失败"); - } finally { - setLoading(false); - } - } - - loadCompareData(); - }, [session?.accessToken, selectedBrandId]); + const error = apiError?.message ?? null; if (loading) { return ( diff --git a/frontend/app/(dashboard)/dashboard/ai-engines/page.tsx b/frontend/app/(dashboard)/dashboard/ai-engines/page.tsx index 6fda963..8cf1c97 100644 --- a/frontend/app/(dashboard)/dashboard/ai-engines/page.tsx +++ b/frontend/app/(dashboard)/dashboard/ai-engines/page.tsx @@ -413,8 +413,8 @@ export default function AIEnginesPage() { const queryMutation = useApiMutation("/api/v1/ai-engines/query"); + brand_name: string; + }>("/api/v1/ai-engines/query-batch"); const brands = brandsData?.items ?? 
[]; @@ -441,7 +441,7 @@ export default function AIEnginesPage() { const result = await queryMutation.trigger({ engines: selectedEngines, query: queryText.trim(), - brand_id: selectedBrandId, + brand_name: brandName, }); if (result) { setQueryResults(result); diff --git a/frontend/app/(dashboard)/dashboard/content/page.tsx b/frontend/app/(dashboard)/dashboard/content/page.tsx index 016433a..6257bb0 100644 --- a/frontend/app/(dashboard)/dashboard/content/page.tsx +++ b/frontend/app/(dashboard)/dashboard/content/page.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useCallback, useEffect } from "react"; +import { useState, useCallback } from "react"; import { useRouter } from "next/navigation"; import { Card, CardContent } from "@/components/ui/card"; import { Button } from "@/components/ui/button"; @@ -38,12 +38,11 @@ import { RefreshCw, } from "lucide-react"; import { - contentsApi, contentGenerationApi, - knowledgeApi, + contentsApi, type Content, - type KnowledgeBase, } from "@/lib/api"; +import { useContentData } from "@/lib/hooks/use-content-data"; // ─── Types ─────────────────────────────────────────────────────────────────── @@ -250,11 +249,18 @@ function EmptyState({ onGenerate }: { onGenerate: () => void }) { export default function ContentPage() { const router = useRouter(); - // Data states - const [contents, setContents] = useState([]); - const [knowledgeBases, setKnowledgeBases] = useState([]); - const [pageLoading, setPageLoading] = useState(true); - const [pageError, setPageError] = useState(null); + // 使用 useContentData hook 获取数据 + const { + contents: hookContents, + knowledgeBases: hookKnowledgeBases, + isLoading: pageLoading, + error: pageApiError, + refreshContents, + } = useContentData(); + + const contents = hookContents ?? []; + const knowledgeBases = hookKnowledgeBases ?? []; + const pageError = pageApiError?.message ?? 
null; // Dialog states const [dialogOpen, setDialogOpen] = useState(false); @@ -275,28 +281,6 @@ export default function ContentPage() { pipelineStepsTemplate.map((s) => ({ ...s })) ); - // Fetch content list and knowledge bases on mount - useEffect(() => { - async function fetchPageData() { - try { - setPageLoading(true); - setPageError(null); - const [contentList, kbList] = await Promise.all([ - contentsApi.list(), - knowledgeApi.listBases(undefined, "enterprise"), - ]); - setContents(contentList ?? []); - setKnowledgeBases(kbList ?? []); - } catch (err) { - console.error("Content page fetch error:", err); - setPageError(err instanceof Error ? err.message : "数据加载失败"); - } finally { - setPageLoading(false); - } - } - fetchPageData(); - }, []); - const resetForm = useCallback(() => { setKeyword(""); setPlatform(""); @@ -391,7 +375,7 @@ export default function ContentPage() { content_type: "article", tags: [platform, keyword.trim()], }); - setContents((prev) => [saved, ...prev]); + refreshContents(); } catch { // Ignore save error, content was still generated } diff --git a/frontend/app/(dashboard)/dashboard/diagnosis/page.tsx b/frontend/app/(dashboard)/dashboard/diagnosis/page.tsx index f9d93f7..1f82aa8 100644 --- a/frontend/app/(dashboard)/dashboard/diagnosis/page.tsx +++ b/frontend/app/(dashboard)/dashboard/diagnosis/page.tsx @@ -1,6 +1,7 @@ "use client"; import { useState, useMemo } from "react"; +import { useRouter } from "next/navigation"; import { Card, CardContent, CardHeader, CardTitle, CardDescription } from "@/components/ui/card"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { Badge } from "@/components/ui/badge"; @@ -18,6 +19,7 @@ import { Zap, ArrowRight, RefreshCw, + Rocket, } from "lucide-react"; import { useApi } from "@/lib/hooks/use-api"; import { diagnosisApi } from "@/lib/api/diagnosis"; @@ -189,6 +191,7 @@ function getPriorityColor(priority: string) { } export default function DiagnosisPage() { + const 
router = useRouter(); const [activeTab, setActiveTab] = useState("combined"); const { data: brandsData, isLoading: brandsLoading, refresh: refreshBrands } = useApi( @@ -422,6 +425,13 @@ export default function DiagnosisPage() {
); })} +
)} diff --git a/frontend/app/(dashboard)/dashboard/monitoring/page.tsx b/frontend/app/(dashboard)/dashboard/monitoring/page.tsx index 7b1f585..d3cd222 100644 --- a/frontend/app/(dashboard)/dashboard/monitoring/page.tsx +++ b/frontend/app/(dashboard)/dashboard/monitoring/page.tsx @@ -47,7 +47,7 @@ interface AlertsResponse { } interface UnreadCountResponse { - count: number; + unread_count: number; } /* ─── Stat Card Component ────────────────────────────────────────────────────*/ @@ -322,7 +322,7 @@ export default function MonitoringPage() { }; const alerts = alertsData?.items ?? []; - const unreadCount = unreadData?.count ?? 0; + const unreadCount = unreadData?.unread_count ?? 0; const filteredAlerts = alerts.filter((alert) => { if (filterType !== "all" && alert.type !== filterType) return false; diff --git a/frontend/app/(dashboard)/dashboard/page.tsx b/frontend/app/(dashboard)/dashboard/page.tsx index 6b17839..b4e3c03 100644 --- a/frontend/app/(dashboard)/dashboard/page.tsx +++ b/frontend/app/(dashboard)/dashboard/page.tsx @@ -2,10 +2,7 @@ import { useRouter } from "next/navigation"; import Link from "next/link"; -import { - MetricCard, - StageProgress, -} from "@/components/business"; +import { MetricCard, StageProgress } from "@/components/business"; import { Button } from "@/components/ui/button"; import { Badge } from "@/components/ui/badge"; import { @@ -18,10 +15,18 @@ import { Plus, ArrowRight, Zap, + Lock, } from "lucide-react"; import { type GeoProject, type LifecycleStats } from "@/lib/api"; import { useApi } from "@/lib/hooks/use-api"; -import { LoadingState, ErrorState, EmptyState } from "@/components/ui/api-states"; +import { + LoadingState, + ErrorState, + EmptyState, +} from "@/components/ui/api-states"; +import { SubscriptionStatus } from "@/components/subscription/SubscriptionStatus"; +import { UsageProgress } from "@/components/subscription/UsageProgress"; +import { ROICard } from "@/components/dashboard/ROICard"; /* ─── Helpers 
─────────────────────────────────────────────────────────────────*/ @@ -105,11 +110,14 @@ export default function DashboardPage() { // "用户未关联组织" 类错误视为空状态 const isOrgError = (err: Error | undefined) => - err?.message.includes("未关联组织") || err?.message.includes("No organization"); + err?.message.includes("未关联组织") || + err?.message.includes("No organization"); const hasOrgError = isOrgError(projectsError) || isOrgError(statsError); const error = - !hasOrgError && (projectsError || statsError) ? projectsError || statsError : undefined; + !hasOrgError && (projectsError || statsError) + ? projectsError || statsError + : undefined; const safeProjects: GeoProject[] = hasOrgError ? [] : (projects ?? []); const safeStats: LifecycleStats | null = hasOrgError ? null : (stats ?? null); @@ -125,7 +133,9 @@ export default function DashboardPage() {
-
+
+ +
@@ -139,9 +149,9 @@ export default function DashboardPage() { return (
-

品牌健康中心

-
- +

品牌健康中心

+
+
); } @@ -152,7 +162,9 @@ export default function DashboardPage() {

品牌健康中心

-

GEO和SEO是AI营销时代的共生体

+

+ GEO和SEO是AI营销时代的共生体 +

} @@ -176,12 +188,52 @@ export default function DashboardPage() { const stages = buildStages(project.current_stage); const recommendation = getRecommendation(project.current_stage); - const citationRate = safeStats?.avg_ai_citation_rate != null - ? `${(safeStats.avg_ai_citation_rate * 100).toFixed(1)}%` - : "—"; + const citationRate = + safeStats?.avg_ai_citation_rate != null + ? `${(safeStats.avg_ai_citation_rate * 100).toFixed(1)}%` + : "—"; + + const userPlan = "free"; + const planExpiresAt = undefined; + const usageData = { + queries: { current: 2, limit: 3 }, + brands: { current: 1, limit: 1 }, + alerts: { current: 0, limit: 0 }, + }; + const roiData = { + roiPercentage: 0, + valueGenerated: 0, + subscriptionCost: 0, + }; + const isFreePlan = userPlan === "free"; return (
+ {/* Subscription Status Bar */} +
+ +
+ + + +
+
+ {/* 1. Page Title */}
@@ -230,6 +282,49 @@ export default function DashboardPage() { />
+ {/* ROI Card + Feature Lock */} +
+ + {isFreePlan && ( +
+

+ 解锁更多功能 +

+
+ {[ + { label: "多品牌监控", desc: "同时监控3个以上品牌" }, + { label: "无限告警", desc: "不限制告警通知数量" }, + { label: "竞品对比", desc: "完整的竞品雷达图分析" }, + { label: "AI优化建议", desc: "基于DeepSeek的个性化建议" }, + ].map((feature) => ( +
+ +
+

+ {feature.label} +

+

{feature.desc}

+
+
+ ))} +
+ + + +
+ )} +
+ {/* 3. Stage Progress */}
@@ -308,8 +403,12 @@ export default function DashboardPage() {
-

功能开发中

-

Agent状态监控即将上线

+

+ 功能开发中 +

+

+ Agent状态监控即将上线 +

diff --git a/frontend/app/(dashboard)/dashboard/roi/page.tsx b/frontend/app/(dashboard)/dashboard/roi/page.tsx new file mode 100644 index 0000000..5e6f735 --- /dev/null +++ b/frontend/app/(dashboard)/dashboard/roi/page.tsx @@ -0,0 +1,342 @@ +"use client"; + +import * as React from "react"; +import { useApi } from "@/lib/hooks/use-api"; +import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; +import { Badge } from "@/components/ui/badge"; +import { LoadingState, ErrorState, EmptyState } from "@/components/ui/api-states"; +import { + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/ui/table"; +import { + TrendingUp, + TrendingDown, + DollarSign, + Target, + ArrowUpRight, + ArrowDownRight, + BarChart3, +} from "lucide-react"; +import { cn } from "@/lib/utils"; +import type { ROIReport, ABComparisonResponse, BrandAttributionSummary } from "@/lib/api/attribution"; + +export default function ROIPage() { + const [selectedBrandId, setSelectedBrandId] = React.useState(null); + + const { + data: roiData, + isLoading: roiLoading, + error: roiError, + refresh: roiRefresh, + } = useApi( + selectedBrandId ? `/api/v1/attribution/roi/${selectedBrandId}` : null + ); + + const { + data: abData, + isLoading: abLoading, + } = useApi( + selectedBrandId ? `/api/v1/attribution/ab-comparison/${selectedBrandId}` : null + ); + + const { + data: summaryData, + isLoading: summaryLoading, + } = useApi( + selectedBrandId ? `/api/v1/attribution/brand/${selectedBrandId}` : null + ); + + if (!selectedBrandId) { + return ( +
+
+

效果归因与ROI报告

+

+ 追踪内容发布效果,计算投资回报率 +

+
+ } + message="请选择品牌" + description="在品牌管理页面选择一个品牌后查看ROI报告" + /> +
+ ); + } + + if (roiLoading || abLoading || summaryLoading) { + return ( +
+
+

效果归因与ROI报告

+

+ 追踪内容发布效果,计算投资回报率 +

+
+
+ {Array.from({ length: 4 }).map((_, i) => ( + + +
+
+
+
+ + + ))} +
+
+ ); + } + + if (roiError) { + return ( +
+
+

效果归因与ROI报告

+

+ 追踪内容发布效果,计算投资回报率 +

+
+ +
+ ); + } + + const roi = roiData; + const ab = abData; + + return ( +
+
+

效果归因与ROI报告

+

+ 追踪内容发布效果,计算投资回报率 +

+
+ + {roi && ( + <> +
+ + +
+
+

ROI

+

= 0 ? "text-emerald-600" : "text-red-600" + )}> + {roi.roi_percentage >= 0 ? "+" : ""}{roi.roi_percentage}% +

+
+
= 0 ? "bg-emerald-50" : "bg-red-50" + )}> + {roi.roi_percentage >= 0 + ? + : + } +
+
+
+
+ + + +
+
+

价值产出

+

+ ¥{roi.value_generated.toLocaleString()} +

+
+
+ +
+
+
+
+ + + +
+
+

订阅成本

+

+ ¥{roi.subscription_cost.toLocaleString()} +

+

+ 当前套餐: {roi.current_plan} +

+
+
+ +
+
+
+
+ + + +
+
+

分数提升

+

= 0 ? "text-emerald-600" : "text-red-600" + )}> + {roi.total_score_delta >= 0 ? "+" : ""}{roi.total_score_delta} +

+

+ 盈亏平衡: {roi.break_even_delta} +

+
+
+ +
+
+
+
+
+ + {ab && ab.dimensions.length > 0 && ( + + + + + A/B 对比分析 + + + +
+
+ 整体变化: + = 0 ? "text-emerald-600" : "text-red-600" + )}> + {ab.overall_delta >= 0 ? "+" : ""}{ab.overall_delta} + + + ({ab.overall_before} → {ab.overall_after}) + +
+
+ + + + 维度 + 发布前 + 发布后 + 变化 + 状态 + + + + {ab.dimensions.map((dim) => ( + + {dim.name} + {dim.before} + {dim.after} + 0 ? "text-emerald-600" : dim.delta < 0 ? "text-red-600" : "" + )}> + {dim.delta > 0 ? "+" : ""}{dim.delta} + + + {dim.improved ? ( + + + 提升 + + ) : ( + + + 下降 + + )} + + + ))} + +
+
+
+ )} + + + + 归因追踪记录 + + + {roi.tracking_records.length === 0 ? ( + } + message="暂无追踪记录" + description="发布内容后将自动开始归因追踪" + /> + ) : ( + + + + 状态 + 基准分数 + 当前分数 + 分数变化 + ROI + 创建时间 + + + + {roi.tracking_records.map((record) => ( + + + + {record.status === "tracking" ? "追踪中" : + record.status === "completed" ? "已完成" : "已过期"} + + + {record.baseline_score} + {record.current_score ?? "-"} + 0 ? "text-emerald-600" : + (record.score_delta ?? 0) < 0 ? "text-red-600" : "" + )}> + {record.score_delta != null + ? `${record.score_delta > 0 ? "+" : ""}${record.score_delta}` + : "-"} + + + {record.roi_percentage != null + ? `${record.roi_percentage > 0 ? "+" : ""}${record.roi_percentage}%` + : "-"} + + + {new Date(record.created_at).toLocaleDateString("zh-CN")} + + + ))} + +
+ )} +
+
+ + )} +
+ ); +} diff --git a/frontend/app/(dashboard)/dashboard/strategy/page.tsx b/frontend/app/(dashboard)/dashboard/strategy/page.tsx index 869dac3..da1b41c 100644 --- a/frontend/app/(dashboard)/dashboard/strategy/page.tsx +++ b/frontend/app/(dashboard)/dashboard/strategy/page.tsx @@ -1,22 +1,513 @@ "use client"; -import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; +import { useState, useEffect, useCallback } from "react"; +import { useSession } from "next-auth/react"; +import { Card, CardContent, CardHeader, CardTitle, CardDescription } from "@/components/ui/card"; +import { Button } from "@/components/ui/button"; +import { Badge } from "@/components/ui/badge"; +import { Skeleton } from "@/components/ui/skeleton"; +import { useApi } from "@/lib/hooks/use-api"; +import { strategyApi } from "@/lib/api/strategy"; +import type { GeoPlan, GeoPlanAction, GeoPlanListResponse } from "@/types/strategy"; +import { + Sparkles, + Loader2, + FileText, + Edit3, + Search, + Tag, + Target, + ArrowRight, + CheckCircle2, + Clock, + SkipForward, + Play, + Zap, + Calendar, + TrendingUp, + AlertTriangle, + Users, +} from "lucide-react"; +import { useRouter } from "next/navigation"; + +const ACTION_TYPE_CONFIG: Record< + GeoPlanAction["action_type"], + { label: string; icon: React.ReactNode; color: string; bgColor: string } +> = { + content_creation: { + label: "内容创建", + icon: , + color: "text-blue-600", + bgColor: "bg-blue-50", + }, + content_optimization: { + label: "内容优化", + icon: , + color: "text-emerald-600", + bgColor: "bg-emerald-50", + }, + query_expansion: { + label: "查询扩展", + icon: , + color: "text-purple-600", + bgColor: "bg-purple-50", + }, + schema_optimization: { + label: "Schema优化", + icon: , + color: "text-amber-600", + bgColor: "bg-amber-50", + }, + platform_targeting: { + label: "平台定向", + icon: , + color: "text-rose-600", + bgColor: "bg-rose-50", + }, +}; + +const PRIORITY_CONFIG: Record< + GeoPlanAction["priority"], + { label: string; 
className: string } +> = { + high: { label: "高优先级", className: "bg-red-100 text-red-700 hover:bg-red-100" }, + medium: { label: "中优先级", className: "bg-amber-100 text-amber-700 hover:bg-amber-100" }, + low: { label: "低优先级", className: "bg-blue-100 text-blue-700 hover:bg-blue-100" }, +}; + +const DIFFICULTY_CONFIG: Record< + GeoPlanAction["difficulty"], + { label: string; className: string } +> = { + easy: { label: "简单", className: "bg-emerald-50 text-emerald-700 hover:bg-emerald-50" }, + medium: { label: "中等", className: "bg-amber-50 text-amber-700 hover:bg-amber-50" }, + hard: { label: "困难", className: "bg-red-50 text-red-700 hover:bg-red-50" }, +}; + +const STATUS_CONFIG: Record< + GeoPlanAction["status"], + { label: string; className: string; icon: React.ReactNode } +> = { + pending: { + label: "待处理", + className: "bg-gray-100 text-gray-600 hover:bg-gray-100", + icon: , + }, + in_progress: { + label: "进行中", + className: "bg-blue-100 text-blue-700 hover:bg-blue-100", + icon: , + }, + completed: { + label: "已完成", + className: "bg-emerald-100 text-emerald-700 hover:bg-emerald-100", + icon: , + }, + skipped: { + label: "已跳过", + className: "bg-gray-50 text-gray-400 hover:bg-gray-50", + icon: , + }, +}; + +function ActionCard({ + action, + onExecute, + executing, +}: { + action: GeoPlanAction; + onExecute: (actionId: string) => void; + executing: boolean; +}) { + const typeConfig = ACTION_TYPE_CONFIG[action.action_type]; + const priorityConfig = PRIORITY_CONFIG[action.priority]; + const difficultyConfig = DIFFICULTY_CONFIG[action.difficulty]; + const statusConfig = STATUS_CONFIG[action.status]; + const canExecute = + (action.action_type === "content_creation" || action.action_type === "content_optimization") && + action.status !== "completed" && + action.status !== "skipped"; -export default function StrategyPage() { return ( -
-
-

策略制定

-

制定GEO优化策略、关键词规划与目标设定

+
+
+
+
+ {typeConfig.icon} +
+
+
+ {typeConfig.label} + {priorityConfig.label} +
+

{action.title}

+
+
+ + {statusConfig.icon} + {statusConfig.label} + +
+ +

{action.reason}

+ + {action.estimated_impact && ( +
+ + 预期效果: {action.estimated_impact} +
+ )} + +
+
+ + {difficultyConfig.label} + + {action.target_keyword && ( + + 关键词: {action.target_keyword} + + )} + {action.target_platform && ( + + 平台: {action.target_platform} + + )} +
+ {canExecute && ( + + )}
- - - 功能开发中 - - -

此功能正在开发中,敬请期待。

-
-
+
+ ); +} + +function WeeklyTimeline({ weeklyPlan }: { weeklyPlan: Record[] }) { + return ( + + + + + 周计划时间线 + + + +
+ {weeklyPlan.map((week, idx) => { + const weekNum = (week.week as number) ?? idx + 1; + const title = (week.title as string) ?? `第${weekNum}周`; + const tasks = (week.tasks as string[]) ?? []; + return ( +
+
+
+ {weekNum} +
+ {idx < weeklyPlan.length - 1 && ( +
+ )} +
+
+

{title}

+ {tasks.length > 0 && ( +
    + {tasks.map((task, taskIdx) => ( +
  • + + {task} +
  • + ))} +
+ )} +
+
+ ); + })} +
+ + + ); +} + +function CompetitorWarning({ brandId }: { brandId: string }) { + const router = useRouter(); + const { data: brandDetail } = useApi<{ competitors: unknown[] }>( + brandId ? `/api/v1/brands/${brandId}` : null + ); + const hasCompetitors = (brandDetail?.competitors?.length ?? 0) > 0; + + if (hasCompetitors) return null; + + return ( +
+ +
+

尚未添加竞品

+

+ 添加竞品后,AI将生成更精准的对比分析和优化建议,帮助您了解与竞品的差距并制定针对性策略。 +

+ +
+
+ ); +} + +export default function StrategyPage() { + const { data: session } = useSession(); + const token = (session as Record)?.accessToken as string | undefined; + const [brandId, setBrandId] = useState(""); + const [currentPlan, setCurrentPlan] = useState(null); + const [generating, setGenerating] = useState(false); + const [executingActionId, setExecutingActionId] = useState(null); + + const { data: brandsData } = useApi<{ items: { id: string }[] }>("/api/v1/brands/?limit=1"); + useEffect(() => { + if (brandsData?.items && brandsData.items.length > 0 && !brandId) { + setBrandId(brandsData.items[0].id); + } + }, [brandsData, brandId]); + + const { data: plansData, isLoading: plansLoading, refresh: refreshPlans } = useApi( + brandId ? `/api/v1/strategy/brand/${brandId}` : null + ); + + useEffect(() => { + if (plansData?.plans && plansData.plans.length > 0) { + const activePlan = plansData.plans.find((p) => p.status === "active") ?? plansData.plans[0]; + setCurrentPlan(activePlan); + } + }, [plansData]); + + const handleGeneratePlan = useCallback(async () => { + if (!token || !brandId) return; + setGenerating(true); + try { + const plan = await strategyApi.generatePlan(token, brandId) as GeoPlan; + setCurrentPlan(plan); + refreshPlans(); + } catch (err) { + console.error("Failed to generate plan:", err); + } finally { + setGenerating(false); + } + }, [token, brandId, refreshPlans]); + + const handleExecuteAction = useCallback(async (actionId: string) => { + if (!token) return; + setExecutingActionId(actionId); + try { + await strategyApi.executeAction(token, actionId); + await strategyApi.updateActionStatus(token, actionId, "completed"); + if (currentPlan) { + setCurrentPlan({ + ...currentPlan, + actions: currentPlan.actions.map((a) => + a.id === actionId ? 
{ ...a, status: "completed" as const } : a + ), + }); + } + } catch (err) { + console.error("Failed to execute action:", err); + } finally { + setExecutingActionId(null); + } + }, [token, currentPlan]); + + const completedCount = currentPlan?.actions.filter((a) => a.status === "completed").length ?? 0; + const totalActions = currentPlan?.actions.length ?? 0; + const progressPercent = totalActions > 0 ? Math.round((completedCount / totalActions) * 100) : 0; + + const weeklyPlan = currentPlan?.plan_data?.weekly_plan as Record[] | undefined; + + if (plansLoading && !currentPlan) { + return ( +
+
+
+

策略制定

+

制定GEO优化策略、关键词规划与目标设定

+
+ +
+
+ {[1, 2, 3].map((i) => ( + + + + + + + ))} +
+ + + {[1, 2, 3].map((i) => ( + + ))} + + +
+ ); + } + + if (!currentPlan) { + return ( +
+
+
+

策略制定

+

制定GEO优化策略、关键词规划与目标设定

+
+ +
+ + +
+ +
+

还没有优化方案

+

+ 基于诊断数据,AI将为您制定个性化GEO优化方案 +

+ +
+
+ {brandId && } +
+ ); + } + + return ( +
+
+
+

策略制定

+

制定GEO优化策略、关键词规划与目标设定

+
+ +
+ +
+ + +
诊断分数 → 目标分数
+
+ {currentPlan.diagnosis_score} + + {currentPlan.target_score} +
+
+
+ + +
预计周数
+
+ + {currentPlan.estimated_weeks} 周 +
+
+
+ + +
行动项进度
+
+ {completedCount} + / {totalActions} 完成 +
+
+
+
+ + +
+ + {brandId && } + + + + 行动项列表 + 按优先级排列的优化行动项,点击"AI生成内容"一键执行 + + + {currentPlan.actions + .sort((a, b) => a.sort_order - b.sort_order) + .map((action) => ( + + ))} + + + + {weeklyPlan && weeklyPlan.length > 0 && ( + + )}
); } diff --git a/frontend/app/(dashboard)/dashboard/suggestions/page.tsx b/frontend/app/(dashboard)/dashboard/suggestions/page.tsx index f4d0eb4..65e9be5 100644 --- a/frontend/app/(dashboard)/dashboard/suggestions/page.tsx +++ b/frontend/app/(dashboard)/dashboard/suggestions/page.tsx @@ -1,6 +1,6 @@ "use client"; -import { useEffect, useState, useCallback } from "react"; +import { useState, useEffect } from "react"; import { useSession } from "next-auth/react"; import { Card, CardContent } from "@/components/ui/card"; import { Button } from "@/components/ui/button"; @@ -12,7 +12,8 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import { api } from "@/lib/api"; +import { useApi } from "@/lib/hooks/use-api"; +import { useSuggestionsData } from "@/lib/hooks/use-suggestions-data"; import { SuggestionList } from "@/components/dashboard/SuggestionCard"; import type { Suggestion, @@ -64,10 +65,6 @@ interface ProgressStats { export default function SuggestionsPage() { const { data: session } = useSession(); - const [suggestions, setSuggestions] = useState([]); - const [loading, setLoading] = useState(true); - const [regenerating, setRegenerating] = useState(false); - const [error, setError] = useState(null); const [brandId, setBrandId] = useState(""); // 筛选条件 @@ -76,101 +73,33 @@ export default function SuggestionsPage() { const [statusFilter, setStatusFilter] = useState("all"); // 获取品牌ID + const { data: brandsData } = useApi<{ items: { id: string }[] }>("/api/v1/brands/?limit=1"); useEffect(() => { - const token = session?.accessToken; - if (!token) return; - - async function loadBrandId() { - try { - const brandsData = await api.brands.list(token!, { limit: 1 }); - const brands = brandsData as { items?: { id: string }[] }; - if (brands.items && brands.items.length > 0) { - setBrandId(brands.items[0].id); - } - } catch (err) { - console.error("获取品牌失败:", err); - } + if (brandsData?.items && brandsData.items.length > 0 && !brandId) { + 
setBrandId(brandsData.items[0].id); } + }, [brandsData, brandId]); - loadBrandId(); - }, [session?.accessToken]); + // 使用 useSuggestionsData hook 获取建议数据 + const { + suggestions, + isLoading: loading, + error: hookError, + refresh, + regenerate, + isRegenerating: regenerating, + regenerateError, + updateStatus: handleStatusChange, + } = useSuggestionsData({ + brandId, + filters: { type: typeFilter, priority: priorityFilter, status: statusFilter }, + }); - // 加载建议 - const loadSuggestions = useCallback(async () => { - const token = session?.accessToken; - if (!token || !brandId) return; - - try { - setLoading(true); - setError(null); - const params: Record = {}; - if (typeFilter !== "all") params.type = typeFilter; - if (priorityFilter !== "all") params.priority = priorityFilter; - if (statusFilter !== "all") params.status = statusFilter; - - const data = (await api.suggestions.getSuggestions( - token, - brandId, - params, - )) as SuggestionListResponse; - setSuggestions(data.suggestions || []); - } catch (err) { - console.error("加载建议失败:", err); - setError(err instanceof Error ? err.message : "加载建议失败"); - } finally { - setLoading(false); - } - }, [session?.accessToken, brandId, typeFilter, priorityFilter, statusFilter]); - - useEffect(() => { - loadSuggestions(); - }, [loadSuggestions]); - - // 重新生成建议 - const handleRegenerate = async () => { - const token = session?.accessToken; - if (!token || !brandId) return; - - try { - setRegenerating(true); - const data = (await api.suggestions.regenerateSuggestions( - token, - brandId, - )) as SuggestionListResponse; - setSuggestions(data.suggestions || []); - } catch (err) { - console.error("重新生成建议失败:", err); - setError(err instanceof Error ? 
err.message : "重新生成建议失败"); - } finally { - setRegenerating(false); - } - }; - - // 更新建议状态 - const handleStatusChange = async (suggestionId: string, newStatus: string) => { - const token = session?.accessToken; - if (!token || !brandId) return; - - try { - await api.suggestions.updateSuggestionStatus( - token, - brandId, - suggestionId, - newStatus, - ); - // 更新本地状态 - setSuggestions((prev) => - prev.map((s) => - s.id === suggestionId ? { ...s, status: newStatus as SuggestionStatus } : s, - ), - ); - } catch (err) { - console.error("更新建议状态失败:", err); - } - }; + const error = hookError?.message ?? regenerateError?.message ?? null; // 计算进度统计 - const progressStats: ProgressStats = suggestions.reduce( + const suggestionList = suggestions ?? []; + const progressStats: ProgressStats = suggestionList.reduce( (acc, s) => { acc.total++; acc[s.status as keyof ProgressStats] = @@ -190,7 +119,7 @@ export default function SuggestionsPage() { : 0; // 加载骨架屏 - if (loading && suggestions.length === 0) { + if (loading && suggestionList.length === 0) { return (
@@ -231,7 +160,7 @@ export default function SuggestionsPage() {

diff --git a/frontend/app/(dashboard)/onboarding/OnboardingProgress.tsx b/frontend/app/(dashboard)/onboarding/OnboardingProgress.tsx index ef77e9c..51648bb 100644 --- a/frontend/app/(dashboard)/onboarding/OnboardingProgress.tsx +++ b/frontend/app/(dashboard)/onboarding/OnboardingProgress.tsx @@ -16,14 +16,11 @@ export function OnboardingProgress({
{ONBOARDING_STEPS.map((step, index) => { - const stepNumber = index + 1; - const isCompleted = stepNumber < currentStep; - const isCurrent = stepNumber === currentStep; - const _isUpcoming = stepNumber > currentStep; + const isCompleted = step.id < currentStep; + const isCurrent = step.id === currentStep; return (
- {/* 步骤圆圈 */}
) : ( - {stepNumber} + {step.id + 1} )}
- {/* 步骤标题 - 只在当前和已完成步骤显示 */}
- {/* 连接线 */} {index < ONBOARDING_STEPS.length - 1 && (
void; + initialBrandName?: string; +} + +const DIMENSION_ICONS: Record = { + "内容可提取性": FileText, + "E-E-A-T信号": Shield, + "引用就绪度": Activity, +}; + +export function Step0HealthScore({ + onNext, + initialBrandName = "", +}: Step0HealthScoreProps) { + const [brandName, setBrandName] = useState(initialBrandName); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + const [result, setResult] = useState(null); + + const handleCheck = async () => { + const trimmed = brandName.trim(); + if (!trimmed || trimmed.length < 2) { + setError("请输入至少2个字符的品牌名称"); + return; + } + + setLoading(true); + setError(null); + + try { + const data = await healthScoreApi.getHealthScore(trimmed); + setResult(data); + } catch (err) { + setError(err instanceof Error ? err.message : "检测失败,请重试"); + } finally { + setLoading(false); + } + }; + + const handleNext = () => { + onNext(brandName.trim(), result); + }; + + const getStatusColor = (status: string) => { + switch (status) { + case "good": + return "text-emerald-600"; + case "warning": + return "text-amber-600"; + case "fail": + return "text-red-600"; + default: + return "text-muted-foreground"; + } + }; + + const getProgressColor = (percentage: number) => { + if (percentage >= 60) return "bg-emerald-500"; + if (percentage >= 30) return "bg-amber-500"; + return "bg-red-500"; + }; + + return ( +
+
+
+ +
+

免费检测您的GEO健康分

+

+ 输入品牌名称,即刻查看您的品牌在AI搜索中的表现 +

+
+ + + +
+
+ { + setBrandName(e.target.value); + setError(null); + }} + placeholder="输入品牌名称,例如:华为" + className="h-12 text-base" + maxLength={50} + autoFocus + onKeyDown={(e) => { + if (e.key === "Enter" && !loading) { + handleCheck(); + } + }} + /> + +
+ + {error && ( +
+ + {error} +
+ )} +
+
+
+ + {result && ( +
+ + +
+
+ + {result.overall_score} + + + /100 + +
+ + {result.health_level_label} + +

+ {result.brand_name} 的GEO健康分 +

+
+
+
+ + + +

核心维度评分

+ {result.dimensions.map((dim: HealthScoreDimension) => { + const Icon = DIMENSION_ICONS[dim.name] || Activity; + return ( +
+
+
+ + {dim.name} +
+ + {round(dim.percentage, 1)}% + +
+
+
+
+
+ ); + })} + + {result.recommendations.length > 0 && ( +
+

关键问题

+ {result.recommendations.slice(0, 3).map((rec, i) => ( +
+ + {rec.title} +
+ ))} +
+ )} + +
+

+ 免费版仅展示3个核心维度 · 升级Pro解锁完整6维度诊断 +

+
+ + + + + +

+ 免费注册即可保存品牌并获取持续监控 +

+
+ )} + + {!result && !loading && ( +
+
+ 无需注册,即刻查看您的GEO健康分 +
+ )} +
+ ); +} + +function round(value: number, decimals: number): number { + const factor = Math.pow(10, decimals); + return Math.round(value * factor) / factor; +} diff --git a/frontend/app/(dashboard)/onboarding/Step4HealthReport.tsx b/frontend/app/(dashboard)/onboarding/Step4HealthReport.tsx index 46c81e1..8d4abf6 100644 --- a/frontend/app/(dashboard)/onboarding/Step4HealthReport.tsx +++ b/frontend/app/(dashboard)/onboarding/Step4HealthReport.tsx @@ -17,6 +17,9 @@ import { BarChart3, Trophy, AlertTriangle, + Shield, + FileText, + Lock, } from "lucide-react"; import { api } from "@/lib/api"; import { PLATFORM_MAP } from "@/lib/platforms"; @@ -25,6 +28,22 @@ import { HEALTH_LEVELS, type BrandHealthReport, } from "@/types/onboarding"; +import { UpgradePrompt } from "@/components/subscription/UpgradePrompt"; + +interface HealthDimension { + name: string; + score: number; + max_score: number; + percentage: number; + status: string; +} + +interface HealthRecommendation { + priority: string; + dimension: string; + title: string; + description: string; +} interface Step4HealthReportProps { brandId: string; @@ -35,6 +54,17 @@ interface Step4HealthReportProps { onBack: () => void; } +const DIMENSION_ICONS: Record = { + "内容可提取性": FileText, + "E-E-A-T信号": Shield, + "引用就绪度": Activity, + "结构化数据": BarChart3, + "语义一致性": Shield, + "技术可访问性": Activity, +}; + +const LOCKED_DIMENSIONS = ["结构化数据", "语义一致性", "技术可访问性"]; + export function Step4HealthReport({ brandId, brandName, @@ -45,6 +75,9 @@ export function Step4HealthReport({ }: Step4HealthReportProps) { const { data: session } = useSession(); const [report, setReport] = useState(null); + const [dimensions, setDimensions] = useState([]); + const [recommendations, setRecommendations] = useState([]); + const [isFullReport, setIsFullReport] = useState(false); const [loading, setLoading] = useState(true); const [error, setError] = useState(null); @@ -57,8 +90,17 @@ export function Step4HealthReport({ const data = (await 
api.onboarding.getHealthReport( session.accessToken, brandId, - )) as BrandHealthReport; + )) as BrandHealthReport & { + dimensions?: HealthDimension[]; + recommendations?: HealthRecommendation[]; + is_full_report?: boolean; + health_level?: string; + health_level_label?: string; + }; setReport(data); + setDimensions(data.dimensions || []); + setRecommendations(data.recommendations || []); + setIsFullReport(data.is_full_report || false); } catch (err) { console.error("获取健康报告失败:", err); setError("获取健康报告失败,请重试"); @@ -69,7 +111,6 @@ export function Step4HealthReport({ useEffect(() => { fetchReport(); - // eslint-disable-next-line react-hooks/exhaustive-deps }, [session?.accessToken, brandId]); if (loading) { @@ -92,7 +133,6 @@ export function Step4HealthReport({
-
@@ -117,28 +157,54 @@ export function Step4HealthReport({ 上一步 - - + +
); } - // TypeScript 类型守卫:经过 loading 和 error 检查后,report 必定存在 if (!report) return null; const healthLevel = getHealthLevel(report.overall_score); const healthConfig = HEALTH_LEVELS[healthLevel]; - // 计算领先/落后竞品数量 + const getStatusColor = (status: string) => { + switch (status) { + case "good": + return "text-emerald-600"; + case "warning": + return "text-amber-600"; + case "fail": + return "text-red-600"; + default: + return "text-muted-foreground"; + } + }; + + const getProgressBg = (percentage: number) => { + if (percentage >= 60) return "bg-emerald-500"; + if (percentage >= 30) return "bg-amber-500"; + return "bg-red-500"; + }; + const leadingCount = report.competitor_scores.filter( (c) => c.is_leading, ).length; @@ -158,11 +224,9 @@ export function Step4HealthReport({

- {/* 综合评分卡片 */} - +
- {/* 评分大数字 */}
{report.overall_score} @@ -170,71 +234,167 @@ export function Step4HealthReport({
- {/* 健康等级标签 */} {healthConfig.label} - {/* 趋势指示 */} -
- - - 领先 {leadingCount} 个竞品 - - | - - 落后 {laggingCount} 个竞品 -
+ {report.competitor_scores.length > 0 && ( +
+ + + 领先 {leadingCount} 个竞品 + + | + + 落后 {laggingCount} 个竞品 +
+ )}
- {/* 平台评分 */} - - - - - 各平台评分 - - - -
- {Object.entries(report.platform_scores) - .filter(([key]) => platforms.includes(key)) - .map(([platform, score]) => { - const level = getHealthLevel(score); - const config = HEALTH_LEVELS[level]; - const platformName = PLATFORM_MAP[platform] || platform; - - return ( -
- - {platformName} - -
-
-
-
+ {dimensions.length > 0 && ( + + +
+ + + 维度评分 + + {!isFullReport && ( + + )} +
+
+ + {dimensions.map((dim) => { + const Icon = DIMENSION_ICONS[dim.name] || Activity; + return ( +
+
+
+ + {dim.name}
- - {score} + + {round(dim.percentage, 1)}%
- ); - })} -
-
-
+
+
+
+
+ ); + })} + + {!isFullReport && ( + <> + {LOCKED_DIMENSIONS.map((dimName) => ( +
+
+
+ + {dimName} +
+ + Pro + +
+
+
+
+
+ ))} + + )} + + + )} + + {recommendations.length > 0 && ( + + + 关键问题 + + +
+ {recommendations.map((rec, i) => ( +
+ +
+

{rec.title}

+

+ {rec.description} +

+
+
+ ))} +
+
+
+ )} + + {Object.keys(report.platform_scores).length > 0 && ( + + + + + 各平台评分 + + + +
+ {Object.entries(report.platform_scores) + .filter(([key]) => platforms.includes(key)) + .map(([platform, score]) => { + const level = getHealthLevel(score); + const config = HEALTH_LEVELS[level]; + const platformName = PLATFORM_MAP[platform] || platform; + + return ( +
+ + {platformName} + +
+
+
+
+
+ + {score} + +
+ ); + })} +
+ + + )} - {/* 竞品对比 */} {report.competitor_scores.length > 0 && ( @@ -291,7 +451,6 @@ export function Step4HealthReport({ )} - {/* 优劣势分析 */} 优劣势分析 @@ -330,6 +489,16 @@ export function Step4HealthReport({ + {!isFullReport && ( +
+ +
+ )} +
-
); } - // 错误状态 if (!loading && error) { return (
@@ -155,21 +187,33 @@ export function Step5ActionSuggestions({ ); } - // 按优先级分组 const highPriority = suggestions.filter((s) => s.priority === "high"); const mediumPriority = suggestions.filter((s) => s.priority === "medium"); const lowPriority = suggestions.filter((s) => s.priority === "low"); const renderSuggestionCard = ( - suggestion: ActionSuggestion, + suggestion: ActionSuggestionItem, index: number, ) => { const Icon = ACTION_ICONS[suggestion.action_type] || Target; const colors = PRIORITY_COLORS[suggestion.priority]; + const isPaid = suggestion.is_paid_action || false; + + const actionButton = suggestion.action_button_text ? ( + + ) : null; return (
+ {isPaid && ( + + + Pro + + )}

{suggestion.description}

+ {actionButton && ( +
{actionButton}
+ )}
@@ -214,11 +270,10 @@ export function Step5ActionSuggestions({

下一步行动建议

- 基于您的品牌 “{brandName}” 的表现, 我们为您准备了以下优化建议 + 基于您的品牌 “{brandName}” 的表现,我们为您准备了以下优化建议

- {/* 高优先级建议 */} {highPriority.length > 0 && ( @@ -235,7 +290,6 @@ export function Step5ActionSuggestions({ )} - {/* 中优先级建议 */} {mediumPriority.length > 0 && ( @@ -252,7 +306,6 @@ export function Step5ActionSuggestions({ )} - {/* 低优先级建议 */} {lowPriority.length > 0 && ( @@ -272,7 +325,13 @@ export function Step5ActionSuggestions({ )} - {error &&

{error}

} +
+ +
+
+ ); +} + +function ResultView({ + data, + onShowRegisterModal, + onShare, + onDownload, +}: { + data: HealthScoreResponse; + onShowRegisterModal: () => void; + onShare: () => void; + onDownload: () => void; +}) { + const level = getHealthLevel(data.overall_score); + const colors = HEALTH_COLORS[level]; + const topRecommendations = data.recommendations.slice(0, 3); + + return ( +
+
+

+ {data.brand_name} 健康分报告 +

+

以下是该品牌在AI搜索中的综合表现

+
+ + + +
+
+ {data.overall_score} + /100 +
+ + {colors.label} + + {data.cached && ( +

数据来自缓存

+ )} +
+
+
+ + + + + + 维度评分 + + + +
+ {data.dimensions.map((dim) => { + const dimLevel = getHealthLevel(dim.percentage); + const dimColors = HEALTH_COLORS[dimLevel]; + return ( +
+
+ {dim.name} + + {dim.score}/{dim.max_score} + +
+
+
+
+
+ ); + })} +
+ + + + {topRecommendations.length > 0 && ( + + + + + 关键问题 + + + +
+ {topRecommendations.map((rec, idx) => { + const priorityConfig = getPriorityConfig(rec.priority); + return ( +
+ + {priorityConfig.label} + +
+

{rec.title}

+

{rec.description}

+
+
+ ); + })} +
+
+
+ )} + +
+ + + +
+
+ ); +} + +function RegisterModal({ open, onClose }: { open: boolean; onClose: () => void }) { + if (!open) return null; + return ( +
+
e.stopPropagation()}> +
+
+ +
+

注册后查看完整报告

+

+ 注册 GEO 平台账户,即可解锁完整健康报告、详细修复建议和持续监控功能 +

+
+
+ + + + +
+
+
+ ); +} + +export default function HealthScorePage() { + const searchParams = useSearchParams(); + const brandFromUrl = searchParams.get("brand") || ""; + + const [brand, setBrand] = useState(brandFromUrl); + const [competitorsInput, setCompetitorsInput] = useState(""); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + const [result, setResult] = useState(null); + const [registerModalOpen, setRegisterModalOpen] = useState(false); + const [toast, setToast] = useState(null); + + useEffect(() => { + if (brandFromUrl) { + setBrand(brandFromUrl); + } + }, [brandFromUrl]); + + useEffect(() => { + if (toast) { + const timer = setTimeout(() => setToast(null), 2500); + return () => clearTimeout(timer); + } + }, [toast]); + + const handleCheck = useCallback(async () => { + const trimmedBrand = brand.trim(); + if (!trimmedBrand) return; + + const competitors = competitorsInput + .split(/[,,]/) // split on ASCII "," and full-width "," so lists typed with a Chinese IME are separated too + .map((s) => s.trim()) + .filter(Boolean) + .slice(0, 3); + + setLoading(true); + setError(null); + setResult(null); + + try { + const data = await healthScoreApi.getHealthScore(trimmedBrand, competitors); + setResult(data); + } catch (err) { + setError(err instanceof Error ? err.message : "检测失败,请稍后重试"); + } finally { + setLoading(false); + } + }, [brand, competitorsInput]); + + const handleShare = useCallback(async () => { + const url = `${window.location.origin}/health-score?brand=${encodeURIComponent(brand.trim())}`; + try { + await navigator.clipboard.writeText(url); + setToast("链接已复制到剪贴板"); + } catch { + setToast("复制失败,请手动复制链接"); + } + }, [brand]); + + const handleDownload = useCallback(() => { + setToast("PDF 下载功能即将推出"); + }, []); + + return ( +
+ {toast && ( +
+ {toast} +
+ )} + +
+
+ +
+

+ 品牌健康分检测 +

+

+ 免费检测您的品牌在AI搜索引擎中的表现,了解品牌可见度、推荐排名和情感倾向 +

+
+ + {!result && !loading && !error && ( + + +
+
+ +
+ setBrand(e.target.value)} + onKeyDown={(e) => { + if (e.key === "Enter") handleCheck(); + }} + className="h-12 text-base" + /> + +
+
+
+ + setCompetitorsInput(e.target.value)} + onKeyDown={(e) => { + if (e.key === "Enter") handleCheck(); + }} + className="h-10" + /> +
+
+
+
+ )} + + {loading && } + {error && !loading && ( + + )} + {result && !loading && ( + setRegisterModalOpen(true)} + onShare={handleShare} + onDownload={handleDownload} + /> + )} + +
+

+ 已有账户?{" "} + + 登录 + +

+
+ + setRegisterModalOpen(false)} /> +
+ ); +} diff --git a/frontend/app/(public)/layout.tsx b/frontend/app/(public)/layout.tsx new file mode 100644 index 0000000..8d82346 --- /dev/null +++ b/frontend/app/(public)/layout.tsx @@ -0,0 +1,30 @@ +import Link from "next/link"; +import { Button } from "@/components/ui/button"; +import { Activity } from "lucide-react"; + +export default function PublicLayout({ + children, +}: { + children: React.ReactNode; +}) { + return ( +
+
+
+ +
+ +
+ GEO 平台 + + + + +
+
+
{children}
+
+ ); +} diff --git a/frontend/components/dashboard/PlatformScoreList.tsx b/frontend/components/dashboard/PlatformScoreList.tsx index d86a990..abfe7a4 100644 --- a/frontend/components/dashboard/PlatformScoreList.tsx +++ b/frontend/components/dashboard/PlatformScoreList.tsx @@ -3,7 +3,8 @@ import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; import { cn } from "@/lib/utils"; import { TrendingUp, TrendingDown, Minus, AlertTriangle } from "lucide-react"; -import { PlatformScoreWithCompetitor, PLATFORM_LABELS, PLATFORM_ICONS } from "@/types/dashboard-health"; +import { PlatformScoreWithCompetitor } from "@/types/dashboard-health"; +import { PLATFORM_MAP as PLATFORM_LABELS, PLATFORM_ICONS } from "@/lib/platforms"; import { getHealthLevel, calculateCompetitorGap, getHealthLevelClassName } from "@/lib/dashboard-health"; interface PlatformScoreListProps { diff --git a/frontend/components/dashboard/ROICard.tsx b/frontend/components/dashboard/ROICard.tsx new file mode 100644 index 0000000..c162249 --- /dev/null +++ b/frontend/components/dashboard/ROICard.tsx @@ -0,0 +1,86 @@ +"use client"; + +import { cn } from "@/lib/utils"; +import { TrendingUp, TrendingDown, Minus } from "lucide-react"; + +interface ROICardProps { + roiPercentage: number; + valueGenerated: number; + subscriptionCost: number; + className?: string; +} + +export function ROICard({ + roiPercentage, + valueGenerated, + subscriptionCost, + className, +}: ROICardProps) { + const isPositive = roiPercentage > 0; + const isNeutral = roiPercentage === 0; + + const TrendIcon = isNeutral ? Minus : isPositive ? TrendingUp : TrendingDown; + const trendColor = isNeutral + ? "text-gray-500" + : isPositive + ? "text-emerald-600" + : "text-red-600"; + + const formatCurrency = (value: number) => + new Intl.NumberFormat("zh-CN", { + style: "currency", + currency: "CNY", + maximumFractionDigits: 0, + }).format(value); + + return ( +
+
+

投资回报率 (ROI)

+
+ + + {isPositive ? "+" : ""} + {roiPercentage.toFixed(1)}% + +
+
+ +

+ {roiPercentage.toFixed(1)}% +

+ +
+
+

创造价值

+

+ {formatCurrency(valueGenerated)} +

+
+
+

订阅成本

+

+ {formatCurrency(subscriptionCost)} +

+
+
+ +
+
+
+
+ ); +} diff --git a/frontend/components/layout/alert-bell.tsx b/frontend/components/layout/alert-bell.tsx index 4f7f58d..806c18d 100644 --- a/frontend/components/layout/alert-bell.tsx +++ b/frontend/components/layout/alert-bell.tsx @@ -119,7 +119,7 @@ export function AlertBell() { const handleMarkRead = async (alertId: string) => { if (!token) return; try { - await api.alerts.markRead(token, alertId); + await api.alerts.markRead(alertId, token); setAlerts((prev) => prev.map((a) => (a.id === alertId ? { ...a, is_read: true } : a)), ); diff --git a/frontend/components/subscription/SubscriptionStatus.tsx b/frontend/components/subscription/SubscriptionStatus.tsx new file mode 100644 index 0000000..8b3697f --- /dev/null +++ b/frontend/components/subscription/SubscriptionStatus.tsx @@ -0,0 +1,59 @@ +"use client"; + +import { cn } from "@/lib/utils"; +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Crown, ArrowUpRight } from "lucide-react"; +import Link from "next/link"; + +const PLAN_CONFIG: Record = { + free: { label: "免费版", variant: "secondary" }, + starter: { label: "入门版", variant: "info" }, + pro: { label: "专业版", variant: "default" }, + enterprise: { label: "企业版", variant: "primary" }, +}; + +interface SubscriptionStatusProps { + plan: string; + expiresAt?: string; + className?: string; +} + +export function SubscriptionStatus({ plan, expiresAt, className }: SubscriptionStatusProps) { + const config = PLAN_CONFIG[plan] || PLAN_CONFIG.free; + const showUpgrade = plan === "free" || plan === "starter"; + + const formatDate = (dateStr: string) => { + try { + return new Date(dateStr).toLocaleDateString("zh-CN", { + year: "numeric", + month: "long", + day: "numeric", + }); + } catch { + return dateStr; + } + }; + + return ( +
+
+ {plan !== "free" && } + {config.label} +
+ {expiresAt && ( + + 到期:{formatDate(expiresAt)} + + )} + {showUpgrade && ( + + + + )} +
+ ); +} diff --git a/frontend/components/subscription/UpgradePrompt.tsx b/frontend/components/subscription/UpgradePrompt.tsx new file mode 100644 index 0000000..734eefb --- /dev/null +++ b/frontend/components/subscription/UpgradePrompt.tsx @@ -0,0 +1,252 @@ +"use client"; + +import { useState } from "react"; +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "@/components/ui/dialog"; +import { Badge } from "@/components/ui/badge"; +import { + Crown, + Lock, + Sparkles, + BarChart3, + Brain, + Zap, +} from "lucide-react"; + +interface UpgradePromptProps { + trigger?: React.ReactNode; + feature?: string; + description?: string; + variant?: "inline" | "dialog" | "badge"; + className?: string; +} + +const PRO_FEATURES = [ + { + icon: BarChart3, + title: "完整6维度诊断", + description: "解锁结构化数据、语义一致性、技术可访问性3个高级维度", + }, + { + icon: Brain, + title: "深度竞品分析", + description: "详细对比竞品在各维度的表现,发现差异化机会", + }, + { + icon: Zap, + title: "AI优化方案", + description: "基于诊断结果自动生成可执行的GEO优化建议", + }, + { + icon: Sparkles, + title: "持续监控", + description: "每日自动检测品牌GEO健康分变化趋势", + }, +]; + +export function UpgradePrompt({ + trigger, + feature = "此功能", + description, + variant = "inline", + className = "", +}: UpgradePromptProps) { + const [open, setOpen] = useState(false); + + if (variant === "badge") { + return ( + + + + + Pro + + + + setOpen(false)} + /> + + + ); + } + + if (variant === "dialog") { + return ( + + + {trigger || ( + + )} + + + setOpen(false)} + /> + + + ); + } + + return ( +
+
+
+ +
+
+
+ {feature} + + + Pro + +
+

+ {description || `升级Pro版即可解锁${feature}功能`} +

+ + + + + + setOpen(false)} + /> + + +
+
+
+ ); +} + +function UpgradeDialogContent({ + feature, + description, + onClose, +}: { + feature: string; + description?: string; + onClose: () => void; +}) { + return ( + <> + + + + 升级到Pro版 + + + {description || `解锁${feature}等全部高级功能`} + + + +
+ {PRO_FEATURES.map((feat) => { + const Icon = feat.icon; + return ( +
+
+ +
+
+

{feat.title}

+

+ {feat.description} +

+
+
+ ); + })} +
+ +
+ + +
+ + ); +} + +export function UpgradeActionBadge({ className = "" }: { className?: string }) { + return ( + + + Pro + + ); +} + +export function PaidActionOverlay({ + children, + isPaidAction, + onUpgradeClick, +}: { + children: React.ReactNode; + isPaidAction: boolean; + onUpgradeClick?: () => void; +}) { + if (!isPaidAction) return <>{children}; + + return ( +
+
{children}
+
+ +
+
+ ); +} diff --git a/frontend/components/subscription/UsageProgress.tsx b/frontend/components/subscription/UsageProgress.tsx new file mode 100644 index 0000000..b75f7a0 --- /dev/null +++ b/frontend/components/subscription/UsageProgress.tsx @@ -0,0 +1,61 @@ +"use client"; + +import { cn } from "@/lib/utils"; + +interface UsageProgressProps { + label: string; + current: number; + limit: number; + unit?: string; + className?: string; +} + +export function UsageProgress({ label, current, limit, unit = "", className }: UsageProgressProps) { + const isUnlimited = limit === -1; + const percentage = isUnlimited ? 0 : limit > 0 ? Math.min((current / limit) * 100, 100) : 0; + const isWarning = !isUnlimited && percentage > 80; + const isCritical = !isUnlimited && percentage > 95; + + const barColor = isCritical + ? "bg-red-500" + : isWarning + ? "bg-amber-500" + : "bg-primary"; + + const textColor = isCritical + ? "text-red-600" + : isWarning + ? "text-amber-600" + : "text-gray-900"; + + return ( +
+
+ {label} + + {isUnlimited + ? `${current}${unit}` + : `${current}${unit} / ${limit}${unit}`} + +
+ {!isUnlimited && ( +
+
+
+ )} + {isUnlimited && ( +
+
+
+ )} + {!isUnlimited && isWarning && ( +

+ {isCritical ? "额度即将用尽" : "额度使用已超过80%"} +

+ )} +
+ ); +} diff --git a/frontend/e2e/pages/health-score.page.ts b/frontend/e2e/pages/health-score.page.ts new file mode 100644 index 0000000..c765c94 --- /dev/null +++ b/frontend/e2e/pages/health-score.page.ts @@ -0,0 +1,40 @@ +import { Page, Locator, expect } from "@playwright/test"; + +export class HealthScorePage { + readonly page: Page; + readonly brandInput: Locator; + readonly checkButton: Locator; + readonly scoreDisplay: Locator; + readonly healthLevelBadge: Locator; + readonly dimensionList: Locator; + readonly registerButton: Locator; + readonly errorMessage: Locator; + readonly loadingSpinner: Locator; + + constructor(page: Page) { + this.page = page; + this.brandInput = page.locator('input[placeholder*="品牌"]'); + this.checkButton = page.getByRole("button", { name: /检测/ }); + this.scoreDisplay = page.locator("text=/\\d+/100/").first(); + this.healthLevelBadge = page.locator("[data-slot='badge']").first(); + this.dimensionList = page.locator("text=维度评分"); // substring match: the public page titles this card "维度评分" (no "核心" prefix) + this.registerButton = page.getByRole("button", { name: /注册/ }); + this.errorMessage = page.locator(".text-destructive"); + this.loadingSpinner = page.locator(".animate-spin"); + } + + async goto() { + await this.page.goto("/health-score"); + await this.page.waitForLoadState("domcontentloaded"); + } + + async checkBrand(brandName: string) { + await this.brandInput.fill(brandName); + await this.checkButton.click(); + } + + async waitForResults(timeout = 30000) { + await this.page.waitForURL(/health-score/, { timeout }); + await expect(this.scoreDisplay).toBeVisible({ timeout }); + } +} diff --git a/frontend/e2e/tests/core-flow-smoke.spec.ts b/frontend/e2e/tests/core-flow-smoke.spec.ts new file mode 100644 index 0000000..8ac795a --- /dev/null +++ b/frontend/e2e/tests/core-flow-smoke.spec.ts @@ -0,0 +1,67 @@ +import { test, expect } from "@playwright/test"; +import { LoginPage } from "../pages/login.page"; +import { DashboardPage } from "../pages/dashboard.page"; + +const TEST_USER = { + email: 
process.env.E2E_TEST_EMAIL || "admin@example.com", + password: process.env.E2E_TEST_PASSWORD || "admin@123", +}; + +async function loginAndWait(page: import("@playwright/test").Page) { + const loginPage = new LoginPage(page); + await loginPage.goto(); + await loginPage.login(TEST_USER.email, TEST_USER.password); + try { + await page.waitForURL(/\/dashboard/, { timeout: 60000 }); + } catch { + const currentUrl = page.url(); + if (!currentUrl.includes("/dashboard")) { + await loginPage.goto(); + await loginPage.login(TEST_USER.email, TEST_USER.password); + await page.waitForURL(/\/dashboard/, { timeout: 60000 }); + } + } + await page.waitForLoadState("networkidle"); +} + +test.describe("核心流程烟雾测试", () => { + test("登录→创建品牌→触发诊断→查看诊断结果", async ({ page }) => { + const loginPage = new LoginPage(page); + const dashboardPage = new DashboardPage(page); + + await loginPage.goto(); + await loginPage.login(TEST_USER.email, TEST_USER.password); + + await page.waitForURL(/\/(dashboard|onboarding)/, { timeout: 15000 }); + + const currentUrl = page.url(); + if (currentUrl.includes("/onboarding")) { + const brandInput = page.locator('input[placeholder*="品牌"]'); + if (await brandInput.isVisible()) { + await brandInput.fill("测试品牌E2E"); + await page.getByRole("button", { name: /检测/ }).click(); + await expect(page.locator("text=核心维度评分")).toBeVisible({ timeout: 30000 }).catch(() => {}); + const nextButton = page.getByRole("button", { name: /注册|下一步/ }).first(); + if (await nextButton.isVisible()) { + await nextButton.click(); + } + } + } + + await page.goto("/dashboard"); + await page.waitForLoadState("domcontentloaded"); + + await expect(page.locator("body")).toBeVisible(); + }); + + test("Dashboard页面加载并显示关键元素", async ({ page }) => { + await loginAndWait(page); + + const dashboardPage = new DashboardPage(page); + await dashboardPage.goto(); + + await expect(page.locator("body")).toBeVisible(); + + await expect(page.locator("nav")).toBeVisible(); + }); +}); diff --git 
a/frontend/e2e/tests/health-score-smoke.spec.ts b/frontend/e2e/tests/health-score-smoke.spec.ts new file mode 100644 index 0000000..d975b05 --- /dev/null +++ b/frontend/e2e/tests/health-score-smoke.spec.ts @@ -0,0 +1,29 @@ +import { test, expect } from "@playwright/test"; +import { HealthScorePage } from "../pages/health-score.page"; + +test.describe("获客路径烟雾测试", () => { + test("未登录用户访问健康分页面,输入品牌名后看到报告", async ({ page }) => { + const healthScorePage = new HealthScorePage(page); + + await healthScorePage.goto(); + + await expect(healthScorePage.brandInput).toBeVisible(); + await expect(healthScorePage.checkButton).toBeVisible(); + + await healthScorePage.checkBrand("华为"); + + await expect(page.locator("text=维度评分")).toBeVisible({ timeout: 30000 }); // public result card is titled "维度评分"; substring match also covers "核心维度评分" + + await expect(page.locator("text=/\\d+/").first()).toBeVisible(); // several nodes contain digits; .first() avoids a strict-mode violation + + await expect(page.getByRole("button", { name: /注册/ })).toBeVisible(); + }); + + test("健康分页面支持URL参数预填品牌名", async ({ page }) => { + await page.goto("/health-score?brand=小米"); + await page.waitForLoadState("domcontentloaded"); + + const brandInput = page.locator('input[placeholder*="品牌"]'); + await expect(brandInput).toHaveValue("小米"); + }); +}); diff --git a/frontend/lib/api/agents.ts b/frontend/lib/api/agents.ts index ef4c34a..81fe762 100644 --- a/frontend/lib/api/agents.ts +++ b/frontend/lib/api/agents.ts @@ -69,16 +69,10 @@ export const agentsApi = { { method: "PUT", body: JSON.stringify(config) }, token ) as Promise, - enable: (token: string, agentId: string) => + updateStatus: (token: string, agentId: string, status: "idle" | "running" | "disabled") => fetchWithAuth( - `/api/v1/agents/${agentId}/enable`, - { method: "POST" }, - token - ) as Promise, - disable: (token: string, agentId: string) => - fetchWithAuth( - `/api/v1/agents/${agentId}/disable`, - { method: "POST" }, + `/api/v1/agents/${agentId}/config`, + { method: "PUT", body: JSON.stringify({ configs: { status } }) }, token ) as Promise, getLogs: (token: string, agentId: string) => diff --git 
a/frontend/lib/api/ai-engines.ts b/frontend/lib/api/ai-engines.ts index caf8a7a..2a2a71f 100644 --- a/frontend/lib/api/ai-engines.ts +++ b/frontend/lib/api/ai-engines.ts @@ -1,21 +1,29 @@ import { fetchWithAuth } from "./client"; import type { AIEngineType, AIEnginesResponse } from "@/types/ai-engines"; +function buildQuery(params: Record): string { + const qs = Object.entries(params) + .filter(([, v]) => v !== undefined) + .map(([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(String(v))}`) + .join("&"); + return qs ? `?${qs}` : ""; +} + export const aiEnginesApi = { - querySingle: (engineType: string, query: string, brandId: string) => + querySingle: (engineType: string, query: string, brandName: string) => fetchWithAuth("/api/v1/ai-engines/query", { method: "POST", - body: JSON.stringify({ engines: [engineType], query, brand_id: brandId }), + body: JSON.stringify({ engines: [engineType], query, brand_name: brandName }), }), - queryBatch: (engines: AIEngineType[], query: string, brandId: string) => - fetchWithAuth("/api/v1/ai-engines/query", { + queryBatch: (engines: AIEngineType[], query: string, brandName: string) => + fetchWithAuth("/api/v1/ai-engines/query-batch", { method: "POST", - body: JSON.stringify({ engines, query, brand_id: brandId }), + body: JSON.stringify({ engines, query, brand_name: brandName }), }), - getResults: (brandId: string) => - fetchWithAuth(`/api/v1/ai-engines/results/${brandId}`), + getResults: (params: { engines: string; query: string; brand_name: string; competitor_names?: string }, token?: string) => + fetchWithAuth(`/api/v1/ai-engines/results${buildQuery(params as Record)}`, {}, token), }; export const MOCK_AI_ENGINES_RESPONSE: AIEnginesResponse = { diff --git a/frontend/lib/api/alerts.ts b/frontend/lib/api/alerts.ts index b3cfa99..0df2a03 100644 --- a/frontend/lib/api/alerts.ts +++ b/frontend/lib/api/alerts.ts @@ -16,7 +16,7 @@ export const alertsApi = { /** 获取告警列表 */ getAlerts: ( token?: string, - params?: { limit?: 
number; offset?: number; is_read?: boolean } + params?: { limit?: number; skip?: number; is_read?: boolean } ) => fetchWithAuth(`/api/v1/alerts/${buildQuery(params || {})}`, {}, token), /** 标记单条告警已读 */ @@ -43,7 +43,30 @@ export const alertsApi = { ) => fetchWithAuth( "/api/v1/alerts/settings", - { method: "PUT", body: JSON.stringify(data) }, + { method: "PUT", body: JSON.stringify({ settings: data }) }, + token + ), + + /** 更新单条告警设置 */ + updateSingleSetting: ( + settingId: string, + data: { + enabled?: boolean; + threshold?: number; + }, + token?: string + ) => + fetchWithAuth( + `/api/v1/alerts/settings/${settingId}`, + { method: "PATCH", body: JSON.stringify(data) }, + token + ), + + /** 删除告警设置 */ + deleteSetting: (settingId: string, token?: string) => + fetchWithAuth( + `/api/v1/alerts/settings/${settingId}`, + { method: "DELETE" }, token ), }; diff --git a/frontend/lib/api/attribution.ts b/frontend/lib/api/attribution.ts new file mode 100644 index 0000000..9f6baf1 --- /dev/null +++ b/frontend/lib/api/attribution.ts @@ -0,0 +1,84 @@ +import { fetchWithAuth } from "./client"; + +export interface AttributionResponse { + id: string; + brand_id: string; + content_id: string | null; + baseline_score: number; + current_score: number | null; + score_delta: number | null; + status: string; + roi_percentage: number | null; + created_at: string; +} + +export interface BrandAttributionSummary { + brand_id: string; + records: AttributionResponse[]; + total_score_delta: number; + tracking_count: number; + completed_count: number; + positive_count: number; +} + +export interface ROIReport { + brand_id: string; + brand_name: string; + subscription_cost: number; + current_plan: string; + total_score_delta: number; + value_generated: number; + roi_percentage: number; + break_even_delta: number; + tracking_records: AttributionResponse[]; + ab_comparison: ABComparison | null; +} + +export interface ABComparison { + overall_before: number; + overall_after: number; + overall_delta: 
number; + dimensions: ABDimension[]; +} + +export interface ABDimension { + name: string; + before: number; + after: number; + delta: number; + improved: boolean; +} + +export interface ABComparisonResponse { + brand_id: string; + brand_name: string; + overall_before: number; + overall_after: number; + overall_delta: number; + dimensions: ABDimension[]; +} + +export const attributionApi = { + startTracking: async (token: string, data: { brand_id: string; content_id?: string }) => + fetchWithAuth("/api/v1/attribution/start", { + method: "POST", + body: JSON.stringify(data), + }, token) as Promise, + + getBrandSummary: async (token: string, brandId: string) => + fetchWithAuth(`/api/v1/attribution/brand/${brandId}`, {}, token) as Promise, + + getRecord: async (token: string, recordId: string) => + fetchWithAuth(`/api/v1/attribution/${recordId}`, {}, token) as Promise, + + checkAttribution: async (token: string, recordId: string) => + fetchWithAuth(`/api/v1/attribution/${recordId}/check`, { + method: "POST", + }, token) as Promise, + + getROIReport: async (token: string, brandId: string) => + fetchWithAuth(`/api/v1/attribution/roi/${brandId}`, {}, token) as Promise, + + getABComparison: async (token: string, brandId: string) => + fetchWithAuth(`/api/v1/attribution/ab-comparison/${brandId}`, {}, token) as Promise, +}; diff --git a/frontend/lib/api/brands.ts b/frontend/lib/api/brands.ts index d1f1c69..89e2f2b 100644 --- a/frontend/lib/api/brands.ts +++ b/frontend/lib/api/brands.ts @@ -86,4 +86,10 @@ export const brandsApi = { /** 获取品牌对比数据 */ getCompare: (token: string, brandId: string) => fetchWithAuth(`/api/v1/brands/${brandId}/compare`, {}, token), + + getScore: (token: string, brandId: string) => + fetchWithAuth(`/api/v1/brands/${brandId}/score/`, {}, token), + + getScoreHistory: (token: string, brandId: string, params?: { skip?: number; limit?: number }) => + fetchWithAuth(`/api/v1/brands/${brandId}/score/history/${buildQuery(params || {})}`, {}, token), }; diff 
--git a/frontend/lib/api/competitor-analysis.ts b/frontend/lib/api/competitor-analysis.ts new file mode 100644 index 0000000..19ed91c --- /dev/null +++ b/frontend/lib/api/competitor-analysis.ts @@ -0,0 +1,61 @@ +import { fetchWithAuth } from "./client"; + +function buildQuery(params: Record): string { + const qs = Object.entries(params) + .filter(([, v]) => v !== undefined) + .map(([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(String(v))}`) + .join("&"); + return qs ? `?${qs}` : ""; +} + +export interface CompetitorAnalysisRequest { + brand_id: string; + analysis_types?: string[]; + period_days?: number; +} + +export interface CompetitorInsight { + id: string; + brand_id: string; + insight_type: string; + competitor_name: string | null; + data: Record; + recommendations: string[] | null; + period_start: string | null; + period_end: string | null; + created_at: string; +} + +export interface CompetitorInsightList { + items: CompetitorInsight[]; + total: number; +} + +export interface CompetitorInsightResponse extends CompetitorInsight {} + +export interface CompetitorGapSummary { + competitor_name: string; + gap_score: number; + dimensions: Record; +} + +export const competitorAnalysisApi = { + analyze: (token: string, data: CompetitorAnalysisRequest) => + fetchWithAuth("/api/v1/competitor/analyze", { + method: "POST", + body: JSON.stringify(data), + }, token) as Promise, + + getBrandInsights: (token: string, brandId: string, params?: { skip?: number; limit?: number }) => + fetchWithAuth( + `/api/v1/competitor/brand/${brandId}${buildQuery(params || {})}`, + {}, + token + ) as Promise, + + getInsight: (token: string, insightId: string) => + fetchWithAuth(`/api/v1/competitor/${insightId}`, {}, token) as Promise, + + getGapSummary: (token: string, brandId: string) => + fetchWithAuth(`/api/v1/competitor/brand/${brandId}/gap-summary`, {}, token) as Promise, +}; diff --git a/frontend/lib/api/detection.ts b/frontend/lib/api/detection.ts new file mode 100644 
index 0000000..dc88941 --- /dev/null +++ b/frontend/lib/api/detection.ts @@ -0,0 +1,26 @@ +import { fetchWithAuth } from "./client"; + +function buildQuery(params: Record): string { + const qs = Object.entries(params) + .filter(([, v]) => v !== undefined) + .map(([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(String(v))}`) + .join("&"); + return qs ? `?${qs}` : ""; +} + +export const detectionApi = { + listTasks: (token?: string, params?: { skip?: number; limit?: number; status?: string }) => + fetchWithAuth(`/api/v1/detection/tasks${buildQuery(params || {})}`, {}, token), + + createTask: (data: Record, token?: string) => + fetchWithAuth("/api/v1/detection/tasks", { method: "POST", body: JSON.stringify(data) }, token), + + updateTask: (taskId: string, data: Record, token?: string) => + fetchWithAuth(`/api/v1/detection/tasks/${taskId}`, { method: "PUT", body: JSON.stringify(data) }, token), + + deleteTask: (taskId: string, token?: string) => + fetchWithAuth(`/api/v1/detection/tasks/${taskId}`, { method: "DELETE" }, token), + + triggerTask: (taskId: string, token?: string) => + fetchWithAuth(`/api/v1/detection/tasks/${taskId}/trigger`, { method: "POST" }, token), +}; diff --git a/frontend/lib/api/health-score.ts b/frontend/lib/api/health-score.ts new file mode 100644 index 0000000..fecc7e8 --- /dev/null +++ b/frontend/lib/api/health-score.ts @@ -0,0 +1,42 @@ +import { API_BASE } from "./client"; + +export interface HealthScoreDimension { + name: string; + score: number; + max_score: number; + percentage: number; + status: string; +} + +export interface HealthScoreRecommendation { + priority: string; + dimension: string; + title: string; + description: string; +} + +export interface HealthScoreResponse { + brand_name: string; + overall_score: number; + health_level: string; + health_level_label: string; + dimensions: HealthScoreDimension[]; + recommendations: HealthScoreRecommendation[]; + is_full_report: boolean; + cached: boolean; +} + +export const 
healthScoreApi = { + getHealthScore: async (brand: string, competitors?: string[]): Promise => { + const params = new URLSearchParams({ brand }); + if (competitors && competitors.length > 0) { + params.set("competitors", competitors.join(",")); + } + const res = await fetch(`${API_BASE}/api/v1/public/health-score?${params}`); + if (!res.ok) { + const error = await res.json().catch(() => ({})); + throw new Error(error.detail || `请求失败 (HTTP ${res.status})`); + } + return res.json(); + }, +}; diff --git a/frontend/lib/api/index.ts b/frontend/lib/api/index.ts index d44e205..e2fc51c 100644 --- a/frontend/lib/api/index.ts +++ b/frontend/lib/api/index.ts @@ -9,6 +9,7 @@ export { citationsApi } from "./citations"; export type { CitationRecord, CitationListResponse, CitationStats } from "./citations"; export { reportsApi } from "./reports"; export { subscriptionsApi } from "./subscriptions"; +export type { SubscriptionInfo } from "./subscriptions"; export { adminApi } from "./admin"; export type { AdminStatsData, AdminUser, AdminUserListResponse, AdminActionResponse } from "./admin"; export { agentsApi } from "./agents"; @@ -35,6 +36,42 @@ export type { InviteMemberPayload, UpdateMemberRolePayload, } from "./organization"; +export { detectionApi } from "./detection"; +export { strategyApi } from "./strategy"; +export { monitoringApi } from "./monitoring"; +export type { + MonitoringRecordCreate, + MonitoringRecord, + MonitoringRecordList, + MonitoringChangeReport, + MonitoringStatusUpdate, + MonitoringRecordResponse, +} from "./monitoring"; +export { competitorAnalysisApi } from "./competitor-analysis"; +export type { + CompetitorAnalysisRequest, + CompetitorInsight, + CompetitorInsightList, + CompetitorInsightResponse, + CompetitorGapSummary, +} from "./competitor-analysis"; +export { schemaAdvisorApi } from "./schema-advisor"; +export type { + SchemaAdviseRequest, + SchemaSuggestion, + SchemaSuggestionList, + SchemaSuggestionResponse, +} from "./schema-advisor"; +export { 
trendsApi } from "./trends"; +export type { + TrendInsightRequest, + TrendInsight, + TrendInsightList, + TrendInsightResponse, + TrendSummary, +} from "./trends"; +export { usageApi } from "./usage"; +export type { UsageQuota, UsageResponse } from "./usage"; // ── 类型导出 ─────────────────────────────────────────────────────────────────── export type { Agent, AgentRunLog } from "./agents"; @@ -118,6 +155,13 @@ import { suggestionsApi } from "./suggestions"; import { onboardingApi } from "./onboarding"; import { platformRulesApi } from "./platform-rules"; import { imageApi } from "./image"; +import { detectionApi } from "./detection"; +import { strategyApi } from "./strategy"; +import { monitoringApi } from "./monitoring"; +import { competitorAnalysisApi } from "./competitor-analysis"; +import { schemaAdvisorApi } from "./schema-advisor"; +import { trendsApi } from "./trends"; +import { usageApi } from "./usage"; /** * 聚合 API 对象,保持与原 `import { api } from "@/lib/api"` 的向后兼容。 @@ -144,4 +188,11 @@ export const api = { onboarding: onboardingApi, platformRules: platformRulesApi, image: imageApi, + detection: detectionApi, + strategy: strategyApi, + monitoring: monitoringApi, + competitorAnalysis: competitorAnalysisApi, + schemaAdvisor: schemaAdvisorApi, + trends: trendsApi, + usage: usageApi, }; diff --git a/frontend/lib/api/knowledge.ts b/frontend/lib/api/knowledge.ts index c004a28..a34c09b 100644 --- a/frontend/lib/api/knowledge.ts +++ b/frontend/lib/api/knowledge.ts @@ -114,4 +114,30 @@ export const knowledgeApi = { }, token ) as Promise, + + buildGraph: (kbId: string, token?: string) => + fetchWithAuth(`/api/v1/knowledge-bases/${kbId}/graph/build`, { method: "POST" }, token), + + getGraphStatistics: (kbId: string, token?: string) => + fetchWithAuth(`/api/v1/knowledge-bases/${kbId}/graph/statistics`, {}, token), + + searchEntities: (kbId: string, query: string, token?: string) => + 
fetchWithAuth(`/api/v1/knowledge-bases/${kbId}/graph/entities/search?q=${encodeURIComponent(query)}`, {}, token), + + getEntity: (kbId: string, entityId: string, token?: string) => + fetchWithAuth(`/api/v1/knowledge-bases/${kbId}/graph/entities/${entityId}`, {}, token), + + findPath: (kbId: string, params: { from_entity_id: string; to_entity_id: string }, token?: string) => + fetchWithAuth( + `/api/v1/knowledge-bases/${kbId}/graph/path?from_entity_id=${params.from_entity_id}&to_entity_id=${params.to_entity_id}`, + {}, + token + ), + + batchCreateEntities: (kbId: string, entities: Array>, token?: string) => + fetchWithAuth( + `/api/v1/knowledge-bases/${kbId}/entities/batch`, + { method: "POST", body: JSON.stringify({ entities }) }, + token + ), }; diff --git a/frontend/lib/api/monitoring.ts b/frontend/lib/api/monitoring.ts new file mode 100644 index 0000000..9c083b7 --- /dev/null +++ b/frontend/lib/api/monitoring.ts @@ -0,0 +1,89 @@ +import { fetchWithAuth } from "./client"; + +function buildQuery(params: Record): string { + const qs = Object.entries(params) + .filter(([, v]) => v !== undefined) + .map(([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(String(v))}`) + .join("&"); + return qs ? 
`?${qs}` : ""; +} + +export interface MonitoringRecordCreate { + brand_id: string; + content_id?: string; + query_keywords: string; + platform?: string; + check_interval_hours?: number; +} + +export interface MonitoringRecord { + id: string; + brand_id: string; + content_id: string | null; + query_keywords: string; + platform: string | null; + baseline_citation_count: number; + baseline_sentiment: number | null; + baseline_rank: number | null; + current_citation_count: number | null; + current_sentiment: number | null; + current_rank: number | null; + change_type: string | null; + change_details: Record | null; + check_interval_hours: number; + last_checked_at: string | null; + next_check_at: string | null; + status: string; + created_at: string; + updated_at: string; +} + +export interface MonitoringRecordList { + records: MonitoringRecord[]; + total: number; +} + +export interface MonitoringChangeReport { + monitoring_record_id: string; + brand_id: string; + change_type: string; + change_details: Record; + baseline: Record; + current: Record; + recommendations: string[]; +} + +export interface MonitoringStatusUpdate { + status: string; +} + +export interface MonitoringRecordResponse extends MonitoringRecord {} + +export const monitoringApi = { + createTask: (token: string, data: MonitoringRecordCreate) => + fetchWithAuth("/api/v1/monitoring/tasks", { + method: "POST", + body: JSON.stringify(data), + }, token) as Promise, + + getBrandRecords: (token: string, brandId: string, params?: { skip?: number; limit?: number }) => + fetchWithAuth( + `/api/v1/monitoring/brand/${brandId}${buildQuery(params || {})}`, + {}, + token + ) as Promise, + + getReport: (token: string, recordId: string) => + fetchWithAuth(`/api/v1/monitoring/${recordId}/report`, {}, token) as Promise, + + updateStatus: (token: string, recordId: string, data: MonitoringStatusUpdate) => + fetchWithAuth(`/api/v1/monitoring/${recordId}/status`, { + method: "PUT", + body: JSON.stringify(data), + }, token) 
as Promise, + + triggerCheck: (token: string, recordId: string) => + fetchWithAuth(`/api/v1/monitoring/${recordId}/check`, { + method: "POST", + }, token) as Promise, +}; diff --git a/frontend/lib/api/onboarding.ts b/frontend/lib/api/onboarding.ts index 4d36686..7d098f3 100644 --- a/frontend/lib/api/onboarding.ts +++ b/frontend/lib/api/onboarding.ts @@ -22,9 +22,9 @@ export const onboardingApi = { ), /** 获取竞品推荐 */ - getCompetitorRecommendations: (token: string, brandName: string) => + getCompetitorRecommendations: (token: string, brandId: string) => fetchWithAuth( - `/api/v1/onboarding/competitor-recommendations?brand_name=${encodeURIComponent(brandName)}`, + `/api/v1/onboarding/competitor-recommendations?brand_id=${encodeURIComponent(brandId)}`, {}, token ), diff --git a/frontend/lib/api/schema-advisor.ts b/frontend/lib/api/schema-advisor.ts new file mode 100644 index 0000000..ca9da06 --- /dev/null +++ b/frontend/lib/api/schema-advisor.ts @@ -0,0 +1,64 @@ +import { fetchWithAuth } from "./client"; + +function buildQuery(params: Record): string { + const qs = Object.entries(params) + .filter(([, v]) => v !== undefined) + .map(([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(String(v))}`) + .join("&"); + return qs ? 
`?${qs}` : ""; +} + +export interface SchemaAdviseRequest { + brand_id: string; + target_url?: string; + focus_dimensions?: string[]; +} + +export interface SchemaSuggestion { + id: string; + brand_id: string; + schema_type: string; + target_url: string | null; + json_ld: Record; + validation_status: string | null; + validation_errors: string[] | null; + priority: number | null; + status: string; + created_at: string; + updated_at: string; +} + +export interface SchemaSuggestionList { + suggestions: SchemaSuggestion[]; + total: number; +} + +export interface SchemaSuggestionResponse extends SchemaSuggestion {} + +export const schemaAdvisorApi = { + advise: (token: string, data: SchemaAdviseRequest) => + fetchWithAuth("/api/v1/schema/advise", { + method: "POST", + body: JSON.stringify(data), + }, token) as Promise, + + getBrandSuggestions: ( + token: string, + brandId: string, + params?: { status?: string; schema_type?: string; skip?: number; limit?: number } + ) => + fetchWithAuth( + `/api/v1/schema/brand/${brandId}${buildQuery(params || {})}`, + {}, + token + ) as Promise, + + getSuggestion: (token: string, suggestionId: string) => + fetchWithAuth(`/api/v1/schema/${suggestionId}`, {}, token) as Promise, + + updateStatus: (token: string, suggestionId: string, status: string) => + fetchWithAuth(`/api/v1/schema/${suggestionId}/status`, { + method: "PUT", + body: JSON.stringify({ status }), + }, token) as Promise, +}; diff --git a/frontend/lib/api/strategy.ts b/frontend/lib/api/strategy.ts new file mode 100644 index 0000000..87bc5d4 --- /dev/null +++ b/frontend/lib/api/strategy.ts @@ -0,0 +1,26 @@ +import { fetchWithAuth } from "./client"; + +export const strategyApi = { + generatePlan: (token: string, brandId: string, targetScore?: number) => + fetchWithAuth("/api/v1/strategy/generate", { + method: "POST", + body: JSON.stringify({ brand_id: brandId, target_score: targetScore ?? 
75 }), + }, token), + + getBrandPlans: (token: string, brandId: string) => + fetchWithAuth(`/api/v1/strategy/brand/${brandId}`, {}, token), + + getPlanDetail: (token: string, planId: string) => + fetchWithAuth(`/api/v1/strategy/${planId}`, {}, token), + + updateActionStatus: (token: string, actionId: string, status: string) => + fetchWithAuth(`/api/v1/strategy/actions/${actionId}/status`, { + method: "PUT", + body: JSON.stringify({ status }), + }, token), + + executeAction: (token: string, actionId: string) => + fetchWithAuth(`/api/v1/strategy/actions/${actionId}/execute`, { + method: "POST", + }, token), +}; diff --git a/frontend/lib/api/subscriptions.ts b/frontend/lib/api/subscriptions.ts index 30f6a46..2da25f1 100644 --- a/frontend/lib/api/subscriptions.ts +++ b/frontend/lib/api/subscriptions.ts @@ -1,12 +1,23 @@ import { API_BASE, fetchWithAuth } from "./client"; +export interface SubscriptionInfo { + id: string; + plan: string; + status: string; + start_date: string; + end_date: string; + amount: number | null; + payment_method: string | null; + created_at: string; +} + export const subscriptionsApi = { getPlans: async () => { const res = await fetch(`${API_BASE}/api/v1/subscriptions/plans`); if (!res.ok) throw new Error("获取套餐失败"); return res.json(); }, - getCurrent: async (token: string) => + getCurrent: async (token: string): Promise => fetchWithAuth("/api/v1/subscriptions/current", {}, token), subscribe: async (token: string, plan: string) => fetchWithAuth( diff --git a/frontend/lib/api/trends.ts b/frontend/lib/api/trends.ts new file mode 100644 index 0000000..1c0c682 --- /dev/null +++ b/frontend/lib/api/trends.ts @@ -0,0 +1,67 @@ +import { fetchWithAuth } from "./client"; + +function buildQuery(params: Record): string { + const qs = Object.entries(params) + .filter(([, v]) => v !== undefined) + .map(([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(String(v))}`) + .join("&"); + return qs ? 
`?${qs}` : ""; +} + +export interface TrendInsightRequest { + brand_id: string; + period_days?: number; + platforms?: string[]; + keywords?: string[]; +} + +export interface TrendInsight { + id: string; + brand_id: string; + insight_type: string; + data: Record; + recommendations: string[] | null; + period_start: string | null; + period_end: string | null; + created_at: string; +} + +export interface TrendInsightList { + items: TrendInsight[]; + total: number; +} + +export interface TrendInsightResponse extends TrendInsight {} + +export interface TrendSummary { + brand_id: string; + period_days: number; + trend_direction: string; + hotspot_keywords: string[]; + platform_comparison: Record; +} + +export const trendsApi = { + createInsight: (token: string, data: TrendInsightRequest) => + fetchWithAuth("/api/v1/trends/insight", { + method: "POST", + body: JSON.stringify(data), + }, token) as Promise, + + getBrandInsights: (token: string, brandId: string, params?: { skip?: number; limit?: number }) => + fetchWithAuth( + `/api/v1/trends/brand/${brandId}${buildQuery(params || {})}`, + {}, + token + ) as Promise, + + getSummary: (token: string, brandId: string, periodDays?: number) => + fetchWithAuth( + `/api/v1/trends/brand/${brandId}/summary${buildQuery({ period_days: periodDays })}`, + {}, + token + ) as Promise, + + getInsight: (token: string, insightId: string) => + fetchWithAuth(`/api/v1/trends/${insightId}`, {}, token) as Promise, +}; diff --git a/frontend/lib/api/usage.ts b/frontend/lib/api/usage.ts new file mode 100644 index 0000000..04ec895 --- /dev/null +++ b/frontend/lib/api/usage.ts @@ -0,0 +1,25 @@ +import { fetchWithAuth } from "./client"; + +export interface UsageQuota { + label: string; + current: number; + limit: number; + unit?: string; +} + +export interface UsageResponse { + plan: string; + quotas: UsageQuota[]; +} + +export const usageApi = { + getQuotas: async (token: string): Promise => + fetchWithAuth("/api/v1/usage/quotas", {}, token), + + 
getROI: async (token: string): Promise<{ + roi_percentage: number; + value_generated: number; + subscription_cost: number; + }> => + fetchWithAuth("/api/v1/usage/roi", {}, token), +}; diff --git a/frontend/lib/hooks/use-compare-data.ts b/frontend/lib/hooks/use-compare-data.ts new file mode 100644 index 0000000..0e624e9 --- /dev/null +++ b/frontend/lib/hooks/use-compare-data.ts @@ -0,0 +1,80 @@ +/** + * 竞品对比页面数据获取 Hook + * + * 封装品牌列表 + 对比数据的 SWR 请求,替代手动 useState + useEffect 模式。 + * - 品牌列表:mount 时自动获取 + * - 对比数据:依赖 selectedBrandId,为空时暂停请求 + */ + +import { useState, useCallback } from "react"; +import { useApi } from "./use-api"; +import type { SWRConfiguration } from "swr"; +import type { BrandListResponse, BrandListItem, CompareResponse } from "@/types/brand"; + +export interface UseCompareDataReturn { + /** 品牌列表 */ + brands: BrandListItem[]; + /** 当前选中的品牌 ID */ + selectedBrandId: string; + /** 设置选中的品牌 ID */ + setSelectedBrandId: (id: string) => void; + /** 对比数据 */ + compareData: CompareResponse | undefined; + /** 是否正在加载(品牌列表或对比数据任一加载中) */ + isLoading: boolean; + /** 错误信息 */ + error: Error | undefined; + /** 刷新品牌列表 */ + refreshBrands: () => void; + /** 刷新对比数据 */ + refreshCompare: () => void; +} + +export interface UseCompareDataOptions { + /** 初始选中的品牌 ID */ + initialBrandId?: string; + /** SWR 配置(用于测试时禁用重试等) */ + swrOptions?: SWRConfiguration; +} + +export function useCompareData( + options?: UseCompareDataOptions +): UseCompareDataReturn { + const initialBrandId = options?.initialBrandId; + const swrOptions = options?.swrOptions; + + const [selectedBrandId, setSelectedBrandId] = useState(initialBrandId ?? ""); + + // 品牌列表 + const { + data: brandsResponse, + isLoading: brandsLoading, + error: brandsError, + refresh: refreshBrands, + } = useApi("/api/v1/brands/", swrOptions); + + const brands: BrandListItem[] = brandsResponse?.items ?? []; + + // 对比数据 — 依赖 selectedBrandId,为空时暂停 + const compareUrl = selectedBrandId ? 
`/api/v1/brands/${selectedBrandId}/compare` : null; + const { + data: compareData, + isLoading: compareLoading, + error: compareError, + refresh: refreshCompare, + } = useApi(compareUrl, swrOptions); + + const isLoading = brandsLoading || compareLoading; + const error = brandsError || compareError; + + return { + brands, + selectedBrandId, + setSelectedBrandId, + compareData, + isLoading, + error, + refreshBrands, + refreshCompare, + }; +} diff --git a/frontend/lib/hooks/use-content-data.ts b/frontend/lib/hooks/use-content-data.ts new file mode 100644 index 0000000..a6eb83e --- /dev/null +++ b/frontend/lib/hooks/use-content-data.ts @@ -0,0 +1,65 @@ +/** + * 内容工坊页面数据获取 Hook + * + * 封装内容列表 + 知识库列表的 SWR 请求,替代手动 useState + useEffect 模式。 + * 两个请求并行发出,无数据依赖。 + */ + +import { useApi } from "./use-api"; +import type { SWRConfiguration } from "swr"; +import type { Content } from "@/lib/api/contents"; +import type { KnowledgeBase } from "@/lib/api/knowledge"; + +export interface UseContentDataReturn { + /** 内容列表 */ + contents: Content[] | undefined; + /** 知识库列表 */ + knowledgeBases: KnowledgeBase[] | undefined; + /** 是否正在加载(任一请求加载中) */ + isLoading: boolean; + /** 错误信息 */ + error: Error | undefined; + /** 刷新内容列表 */ + refreshContents: () => void; + /** 刷新知识库列表 */ + refreshKnowledgeBases: () => void; +} + +export interface UseContentDataOptions { + /** SWR 配置(用于测试时禁用重试等) */ + swrOptions?: SWRConfiguration; +} + +export function useContentData( + options?: UseContentDataOptions +): UseContentDataReturn { + const swrOptions = options?.swrOptions; + + // 内容列表 + const { + data: contents, + isLoading: contentsLoading, + error: contentsError, + refresh: refreshContents, + } = useApi("/api/v1/contents/", swrOptions); + + // 知识库列表(仅获取 enterprise 类型) + const { + data: knowledgeBases, + isLoading: kbLoading, + error: kbError, + refresh: refreshKnowledgeBases, + } = useApi("/api/v1/knowledge/bases/?type=enterprise", swrOptions); + + const isLoading = contentsLoading || kbLoading; + const 
error = contentsError || kbError; + + return { + contents, + knowledgeBases, + isLoading, + error, + refreshContents, + refreshKnowledgeBases, + }; +} diff --git a/frontend/lib/hooks/use-onboarding-data.ts b/frontend/lib/hooks/use-onboarding-data.ts new file mode 100644 index 0000000..7975c48 --- /dev/null +++ b/frontend/lib/hooks/use-onboarding-data.ts @@ -0,0 +1,95 @@ +/** + * 新用户引导页面数据获取 Hook + * + * 封装引导状态检查的 SWR 请求 + 品牌创建的 mutation,替代手动 useState + useEffect 模式。 + * - 引导状态检查:mount 时自动获取 + * - 品牌创建:mutation(由用户交互触发) + */ + +import { useCallback } from "react"; +import { useApi, useApiMutation } from "./use-api"; +import type { SWRConfiguration } from "swr"; + +interface OnboardingStatusResponse { + completed: boolean; +} + +interface CreateBrandResponse { + brand_id: string; +} + +export interface CreateBrandPayload { + name: string; + competitors: string[]; + platforms: string[]; + frequency: "daily" | "weekly" | "monthly"; +} + +export interface UseOnboardingDataReturn { + /** 引导状态数据 */ + onboardingStatus: OnboardingStatusResponse | undefined; + /** 是否已完成引导 */ + isCompleted: boolean; + /** 是否正在加载 */ + isLoading: boolean; + /** 错误信息 */ + error: Error | undefined; + /** 刷新引导状态 */ + refresh: () => void; + /** 创建品牌 */ + createBrand: (data: CreateBrandPayload) => Promise; + /** 是否正在创建品牌 */ + isCreatingBrand: boolean; + /** 创建品牌错误 */ + mutationError: Error | undefined; +} + +export interface UseOnboardingDataOptions { + /** SWR 配置(用于测试时禁用重试等) */ + swrOptions?: SWRConfiguration; +} + +export function useOnboardingData( + options?: UseOnboardingDataOptions +): UseOnboardingDataReturn { + const swrOptions = options?.swrOptions; + + // 引导状态检查 + const { + data: onboardingStatus, + isLoading, + error, + refresh, + } = useApi("/api/v1/onboarding/status", swrOptions); + + const isCompleted = onboardingStatus?.completed ?? 
false; + + // 创建品牌(mutation) + const { + trigger: createBrandTrigger, + isMutating: isCreatingBrand, + error: mutationError, + } = useApiMutation( + "/api/v1/onboarding/brand", + "POST" + ); + + const createBrand = useCallback( + async (data: CreateBrandPayload): Promise => { + const result = await createBrandTrigger(data); + return result?.brand_id ?? null; + }, + [createBrandTrigger] + ); + + return { + onboardingStatus, + isCompleted, + isLoading, + error, + refresh, + createBrand, + isCreatingBrand, + mutationError, + }; +} diff --git a/frontend/lib/hooks/use-suggestions-data.ts b/frontend/lib/hooks/use-suggestions-data.ts new file mode 100644 index 0000000..f77f825 --- /dev/null +++ b/frontend/lib/hooks/use-suggestions-data.ts @@ -0,0 +1,147 @@ +/** + * 优化建议页面数据获取 Hook + * + * 封装建议列表的 SWR 请求 + 重新生成/更新状态的 mutation,替代手动 useState + useEffect 模式。 + * - 建议列表:依赖 brandId + 筛选条件 + * - 重新生成:mutation + * - 更新状态:mutation + */ + +import { useCallback } from "react"; +import { useApi, useApiMutation } from "./use-api"; +import { fetchWithAuth } from "@/lib/api/client"; +import type { SWRConfiguration } from "swr"; +import type { SuggestionListResponse, SuggestionStatus } from "@/types/suggestion"; + +export interface SuggestionsFilters { + type?: string; + priority?: string; + status?: string; +} + +export interface UseSuggestionsDataOptions { + brandId?: string; + filters?: SuggestionsFilters; + /** SWR 配置(用于测试时禁用重试等) */ + swrOptions?: SWRConfiguration; +} + +export interface UseSuggestionsDataReturn { + /** 建议列表 */ + suggestions: SuggestionListResponse["suggestions"] | undefined; + /** 是否正在加载 */ + isLoading: boolean; + /** 错误信息 */ + error: Error | undefined; + /** 刷新建议列表 */ + refresh: () => void; + /** 重新生成建议 */ + regenerate: () => Promise; + /** 是否正在重新生成 */ + isRegenerating: boolean; + /** 重新生成错误 */ + regenerateError: Error | undefined; + /** 更新建议状态 */ + updateStatus: (suggestionId: string, newStatus: string) => Promise; + /** 是否正在更新状态 */ + isUpdatingStatus: boolean; +} 
+ +function buildSuggestionsUrl(brandId: string, filters?: SuggestionsFilters): string | null { + if (!brandId) return null; + + const params = new URLSearchParams(); + if (filters?.type && filters.type !== "all") params.set("type", filters.type); + if (filters?.priority && filters.priority !== "all") params.set("priority", filters.priority); + if (filters?.status && filters.status !== "all") params.set("status", filters.status); + + const qs = params.toString(); + return `/api/v1/brands/${brandId}/suggestions${qs ? `?${qs}` : ""}`; +} + +export function useSuggestionsData( + options?: UseSuggestionsDataOptions +): UseSuggestionsDataReturn { + const brandId = options?.brandId ?? ""; + const filters = options?.filters; + const swrOptions = options?.swrOptions; + + // 建议列表 + const suggestionsUrl = buildSuggestionsUrl(brandId, filters); + const { + data: suggestionsResponse, + isLoading, + error, + refresh, + mutate, + } = useApi(suggestionsUrl, swrOptions); + + const suggestions = suggestionsResponse?.suggestions; + + // 重新生成建议(mutation) + const { + trigger: regenerateTrigger, + isMutating: isRegenerating, + error: regenerateError, + } = useApiMutation( + brandId ? `/api/v1/brands/${brandId}/suggestions/regenerate` : "", + "POST" + ); + + const regenerate = useCallback(async (): Promise => { + const result = await regenerateTrigger(); + if (result) { + // 用新数据更新 SWR 缓存 + mutate(result); + } + return result; + }, [regenerateTrigger, mutate]); + + // 更新建议状态(mutation) + const { + isMutating: isUpdatingStatus, + } = useApiMutation( + brandId ? 
`/api/v1/brands/${brandId}/suggestions` : "", + "PUT" + ); + + const updateStatus = useCallback( + async (suggestionId: string, newStatus: string) => { + await fetchWithAuth( + `/api/v1/brands/${brandId}/suggestions/${suggestionId}/status`, + { + method: "PUT", + body: JSON.stringify({ status: newStatus }), + } + ); + // 乐观更新:直接修改本地缓存 + mutate( + (current) => { + if (!current) return current; + return { + ...current, + suggestions: current.suggestions.map((s) => + s.id === suggestionId + ? { ...s, status: newStatus as SuggestionStatus } + : s + ), + }; + }, + { revalidate: false } + ); + }, + [brandId, mutate] + ); + + return { + suggestions, + isLoading, + error, + refresh, + regenerate, + isRegenerating, + regenerateError, + updateStatus, + isUpdatingStatus, + }; +} diff --git a/frontend/lib/next-action.ts b/frontend/lib/next-action.ts index c05962c..016d890 100644 --- a/frontend/lib/next-action.ts +++ b/frontend/lib/next-action.ts @@ -2,7 +2,14 @@ * 下一步行动建议生成逻辑 */ -import type { ActionContext, ActionRule, NextAction } from "@/types/next-action"; +import type { + ActionContext, + ActionRule, + NextAction, +} from "@/types/next-action"; + +// 行动定义类型(与 ActionRule 中的 action 类型一致) +type ActionDefinition = Omit; // 生成唯一ID function generateActionId(): string { @@ -11,11 +18,12 @@ function generateActionId(): string { // 创建行动项 function createAction( - base: Omit, + base: ActionDefinition, priority: NextAction["priority"], ): NextAction { return { ...base, + href: base.actionUrl, // href 与 actionUrl 保持一致 id: generateActionId(), priority, }; @@ -27,7 +35,8 @@ const newUserRules: ActionRule[] = [ condition: (ctx) => !ctx.hasBrands || ctx.brandCount === 0, primaryAction: { title: "创建第一个品牌", - description: "开始您的品牌分析之旅,创建第一个品牌后即可获得详细的AI认知度评分。", + description: + "开始您的品牌分析之旅,创建第一个品牌后即可获得详细的AI认知度评分。", actionText: "创建品牌", actionUrl: "/brands", icon: "🎯", @@ -79,9 +88,7 @@ const lowScoreRules: ActionRule[] = [ const highScoreNoGrowthRules: ActionRule[] = [ { condition: (ctx) => - 
ctx.hasData && - ctx.overallScore >= 60 && - ctx.scoreChange <= 0, + ctx.hasData && ctx.overallScore >= 60 && ctx.scoreChange <= 0, primaryAction: { title: "设置涨粉预警", description: "您的品牌评分较高但增长停滞,设置预警监控潜在风险。", @@ -112,7 +119,8 @@ const competitorThreatRules: ActionRule[] = [ condition: (ctx) => ctx.competitorCount > 0 && ctx.hasData, primaryAction: { title: "查看差距分析", - description: "竞品正在逼近或超越您的品牌,点击查看详细分析并制定应对策略。", + description: + "竞品正在逼近或超越您的品牌,点击查看详细分析并制定应对策略。", actionText: "查看差距分析", actionUrl: "/compare", icon: "⚠️", @@ -181,22 +189,16 @@ export function generateNextActions(context: ActionContext): NextAction[] { for (const rule of allRules) { if (rule.condition(context)) { // 添加主要行动 - actions.push( - createAction(rule.primaryAction, "primary"), - ); + actions.push(createAction(rule.primaryAction, "primary")); // 添加次要行动 if (rule.secondaryAction) { - actions.push( - createAction(rule.secondaryAction, "secondary"), - ); + actions.push(createAction(rule.secondaryAction, "secondary")); } // 添加可选行动 if (rule.optionalAction) { - actions.push( - createAction(rule.optionalAction, "optional"), - ); + actions.push(createAction(rule.optionalAction, "optional")); } break; // 找到第一个匹配规则后停止 diff --git a/frontend/lib/platforms.ts b/frontend/lib/platforms.ts index fd36911..81a4d2d 100644 --- a/frontend/lib/platforms.ts +++ b/frontend/lib/platforms.ts @@ -10,6 +10,19 @@ export const PLATFORM_MAP: Record = { xinghuo: "讯飞星火", }; +// 平台图标映射(统一来源,包含所有9个平台) +export const PLATFORM_ICONS: Record = { + wenxin: "🧠", + kimi: "📖", + tongyi: "🏔️", + baidu_ai: "🔍", + yuanbao: "💎", + qingyan: "🔮", + doubao: "🥟", + tiangong: "⚔️", + xinghuo: "🔥", +}; + export const PLATFORMS = [ { key: "wenxin", label: "文心一言" }, { key: "kimi", label: "Kimi" }, diff --git a/frontend/playwright.config.ts b/frontend/playwright.config.ts index da0f81b..8574f2f 100644 --- a/frontend/playwright.config.ts +++ b/frontend/playwright.config.ts @@ -11,6 +11,7 @@ export default defineConfig({ baseURL: 
"http://localhost:3000", trace: "on-first-retry", screenshot: "only-on-failure", + video: "retain-on-failure", actionTimeout: 30000, }, diff --git a/frontend/types/ai-engines.ts b/frontend/types/ai-engines.ts index c8d59d9..d153e95 100644 --- a/frontend/types/ai-engines.ts +++ b/frontend/types/ai-engines.ts @@ -56,5 +56,5 @@ export interface AIEnginesResponse { export interface AIQueryRequest { engines: AIEngineType[]; query: string; - brand_id: string; + brand_name: string; } diff --git a/frontend/types/brand.ts b/frontend/types/brand.ts index b816577..a7a37b5 100644 --- a/frontend/types/brand.ts +++ b/frontend/types/brand.ts @@ -182,3 +182,104 @@ export const INDUSTRY_OPTIONS = [ { value: "real_estate", label: "房地产" }, { value: "other", label: "其他" }, ] as const; + +export interface DimensionScoreResponse { + name: string; + score: number; + max_score: number; + percentage: number; + detail: Record; +} + +export interface BrandScoreV2Response { + overall_score: number; + health_level: "excellent" | "good" | "pass" | "danger"; + mention_rate: DimensionScoreResponse; + recommendation_rank: DimensionScoreResponse; + sentiment_score: DimensionScoreResponse; + citation_quality: DimensionScoreResponse; + competitive_position: DimensionScoreResponse; + mention_rate_score: number; + sov_score: number; + quality_score: number; +} + +export interface BrandScoreHistoryItem { + date: string; + mention_rate_score: number; + sov_score: number; + quality_score: number; + overall_score: number; + total_queries: number; + cited_count: number; +} + +export interface BrandScoreHistoryResponse { + history: BrandScoreHistoryItem[]; + total: number; +} + +export interface GraphStatistics { + entity_count: number; + relationship_count: number; + entity_types: Record; + relationship_types: Record; +} + +export interface GraphEntity { + id: string; + name: string; + entity_type: string; + properties: Record; + confidence: string; + source_chunk_ids: string[]; +} + +export interface 
GraphPath { + nodes: GraphEntity[]; + edges: Array<{ + id: string; + source_id: string; + target_id: string; + relationship_type: string; + properties: Record; + }>; + length: number; +} + +// ── 检测任务相关类型 ──────────────────────────────────────────────────────────── + +export interface DetectionTask { + id: string; + name: string; + brand_name: string; + platforms: string[]; + keywords: string[]; + frequency: string; + status: "active" | "paused" | "completed"; + last_run_at: string | null; + next_run_at: string | null; + created_at: string; + updated_at: string; +} + +export interface CreateDetectionTaskRequest { + name: string; + brand_name: string; + platforms: string[]; + keywords: string[]; + frequency: string; +} + +export interface UpdateDetectionTaskRequest { + name?: string; + platforms?: string[]; + keywords?: string[]; + frequency?: string; + status?: "active" | "paused"; +} + +export interface DetectionTaskListResponse { + items: DetectionTask[]; + total: number; +} diff --git a/frontend/types/dashboard-health.ts b/frontend/types/dashboard-health.ts index 5831614..2448355 100644 --- a/frontend/types/dashboard-health.ts +++ b/frontend/types/dashboard-health.ts @@ -2,6 +2,11 @@ * 健康状态Dashboard类型定义 */ +import type { + ActionPriority as ActionPriorityBase, + ActionItem, +} from "./suggestion"; + // 健康等级 export type HealthLevel = "excellent" | "good" | "pass" | "danger"; @@ -115,53 +120,27 @@ export interface RecentQuery { queried_at: string; } -// 行动建议 -export interface ActionSuggestion { - id: string; - type: "primary" | "secondary" | "optional"; - title: string; - description: string; - icon: string; - href: string; - priority: number; +// 行动建议(扩展自统一 ActionItem 类型) +export type ActionPriority = ActionPriorityBase; +export interface ActionSuggestion extends Omit { + type: ActionPriority; + priority: number; // 排序优先级(数字越小越优先) } -// 平台中文名称映射 -export const PLATFORM_LABELS: Record = { - wenxin: "文心一言", - kimi: "Kimi", - tongyi: "通义千问", - doubao: "豆包", - 
xinghuo: "讯飞星火", - tiangong: "天工AI", - qingyan: "智谱清言", -}; - -// 平台图标映射 -export const PLATFORM_ICONS: Record = { - wenxin: "🧠", - kimi: "📖", - tongyi: "🏔️", - doubao: "🥟", - xinghuo: "🔥", - tiangong: "⚔️", - qingyan: "💎", -}; - // 维度名称映射 export const DIMENSION_LABELS: Record = { - "提及率": "提及率", - "推荐排名": "推荐排名", - "情感倾向": "情感倾向", - "引用质量": "引用质量", - "竞品对比": "竞品对比", + 提及率: "提及率", + 推荐排名: "推荐排名", + 情感倾向: "情感倾向", + 引用质量: "引用质量", + 竞品对比: "竞品对比", }; // 维度颜色映射 export const DIMENSION_COLORS: Record = { - "提及率": "bg-blue-500", - "推荐排名": "bg-purple-500", - "情感倾向": "bg-emerald-500", - "引用质量": "bg-amber-500", - "竞品对比": "bg-rose-500", + 提及率: "bg-blue-500", + 推荐排名: "bg-purple-500", + 情感倾向: "bg-emerald-500", + 引用质量: "bg-amber-500", + 竞品对比: "bg-rose-500", }; diff --git a/frontend/types/next-action.ts b/frontend/types/next-action.ts index 2c0141b..cb20ccb 100644 --- a/frontend/types/next-action.ts +++ b/frontend/types/next-action.ts @@ -2,18 +2,19 @@ * 下一步行动相关类型定义 */ -// 行动优先级 -export type ActionPriority = "primary" | "secondary" | "optional"; +import type { + ActionPriority as ActionPriorityBase, + ActionItem, +} from "./suggestion"; -// 行动项 -export interface NextAction { - id: string; - priority: ActionPriority; - title: string; - description: string; - actionText: string; - actionUrl: string; - icon: string; +// 从统一类型重新导出 +export type ActionPriority = ActionPriorityBase; +export type { ActionItem }; + +// 行动项(扩展自统一 ActionItem,保留 actionText 和 actionUrl 兼容字段) +export interface NextAction extends ActionItem { + actionText: string; // 行动按钮文字 + actionUrl: string; // 兼容旧字段,与 href 值相同 } // 用户状态上下文 @@ -38,12 +39,13 @@ export interface ActionGenerationParams { cachedActions?: NextAction[]; } -// 行动建议规则 +// 行动建议规则(action 定义中不需要 href,由 createAction 自动从 actionUrl 生成) +type ActionDefinition = Omit; export interface ActionRule { condition: (context: ActionContext) => boolean; - primaryAction: Omit; - secondaryAction: Omit | null; - optionalAction: Omit | null; + primaryAction: 
ActionDefinition; + secondaryAction: ActionDefinition | null; + optionalAction: ActionDefinition | null; } // 行动建议卡片Props diff --git a/frontend/types/onboarding.ts b/frontend/types/onboarding.ts index 073e712..1d4a3d2 100644 --- a/frontend/types/onboarding.ts +++ b/frontend/types/onboarding.ts @@ -2,64 +2,10 @@ * 新用户引导向导相关类型定义 */ -// 健康等级 -export type HealthLevel = "excellent" | "good" | "fair" | "danger"; +export type { HealthLevel } from "@/types/dashboard-health"; +export { HEALTH_LEVEL_CONFIG as HEALTH_LEVELS } from "@/types/dashboard-health"; +export { getHealthLevel } from "@/lib/dashboard-health"; -// 健康等级配置 -export interface HealthLevelConfig { - label: string; - color: string; - bgColor: string; - borderColor: string; - minScore: number; - maxScore: number; -} - -// 健康等级定义 -export const HEALTH_LEVELS: Record = { - excellent: { - label: "优秀", - color: "text-emerald-600", - bgColor: "bg-emerald-50", - borderColor: "border-emerald-200", - minScore: 80, - maxScore: 100, - }, - good: { - label: "良好", - color: "text-yellow-600", - bgColor: "bg-yellow-50", - borderColor: "border-yellow-200", - minScore: 60, - maxScore: 79, - }, - fair: { - label: "及格", - color: "text-orange-600", - bgColor: "bg-orange-50", - borderColor: "border-orange-200", - minScore: 40, - maxScore: 59, - }, - danger: { - label: "危险", - color: "text-red-600", - bgColor: "bg-red-50", - borderColor: "border-red-200", - minScore: 0, - maxScore: 39, - }, -}; - -// 获取健康等级 -export function getHealthLevel(score: number): HealthLevel { - if (score >= 80) return "excellent"; - if (score >= 60) return "good"; - if (score >= 40) return "fair"; - return "danger"; -} - -// 引导流程状态 export interface OnboardingState { currentStep: number; brandName: string; @@ -71,18 +17,18 @@ export interface OnboardingState { isSkipped: boolean; } -// 竞品推荐响应 export interface CompetitorRecommendation { id: string; name: string; reason: string; } -// 品牌健康报告 export interface BrandHealthReport { brand_id: string; brand_name: 
string; overall_score: number; + health_level?: string; + health_level_label?: string; platform_scores: Record; competitor_scores: Array<{ name: string; @@ -91,18 +37,32 @@ export interface BrandHealthReport { }>; strengths: string[]; weaknesses: string[]; + dimensions?: Array<{ + name: string; + score: number; + max_score: number; + percentage: number; + status: string; + }>; + recommendations?: Array<{ + priority: string; + dimension: string; + title: string; + description: string; + }>; + is_full_report?: boolean; } -// 行动建议 export interface ActionSuggestion { id: string; title: string; description: string; priority: "high" | "medium" | "low"; - action_type: "improve_platform" | "add_competitor" | "optimize_content" | "increase_frequency"; + action_type: string; + is_paid_action?: boolean; + action_button_text?: string; } -// 创建品牌请求(引导流程用) export interface OnboardingCreateBrandRequest { name: string; aliases?: string[]; @@ -111,7 +71,6 @@ export interface OnboardingCreateBrandRequest { frequency: "daily" | "weekly" | "monthly"; } -// 引导流程步骤定义 export interface OnboardingStep { id: number; title: string; @@ -120,9 +79,9 @@ export interface OnboardingStep { } export const ONBOARDING_STEPS: OnboardingStep[] = [ - { id: 1, title: "输入品牌名称", description: "输入您要监控的品牌名称", isSkippable: false }, + { id: 0, title: "健康分检测", description: "免费检测品牌GEO健康分", isSkippable: false }, + { id: 1, title: "创建品牌", description: "输入品牌名称开始监控", isSkippable: false }, { id: 2, title: "确认竞品", description: "选择与您品牌竞争的对手", isSkippable: true }, - { id: 3, title: "选择平台", description: "选择要监控的AI搜索平台", isSkippable: true }, - { id: 4, title: "健康报告", description: "查看您的品牌健康报告", isSkippable: false }, - { id: 5, title: "行动建议", description: "获取提升品牌曝光的建议", isSkippable: true }, + { id: 3, title: "健康报告", description: "查看详细诊断报告", isSkippable: false }, + { id: 4, title: "行动建议", description: "获取提升品牌曝光的建议", isSkippable: true }, ]; diff --git a/frontend/types/strategy.ts b/frontend/types/strategy.ts new file mode 
100644 index 0000000..2c9db85 --- /dev/null +++ b/frontend/types/strategy.ts @@ -0,0 +1,45 @@ +export interface GeoPlanAction { + id: string; + plan_id: string; + action_type: "content_creation" | "content_optimization" | "query_expansion" | "schema_optimization" | "platform_targeting"; + title: string; + description: string; + reason: string; + priority: "high" | "medium" | "low"; + status: "pending" | "in_progress" | "completed" | "skipped"; + target_keyword: string | null; + target_platform: string | null; + content_style: string | null; + estimated_impact: string | null; + difficulty: "easy" | "medium" | "hard"; + execution_params: Record | null; + sort_order: number; + completed_at: string | null; + created_at: string; +} + +export interface GeoPlan { + id: string; + brand_id: string; + title: string; + status: "draft" | "active" | "completed" | "archived"; + diagnosis_score: number; + target_score: number; + estimated_weeks: number; + plan_data: Record | null; + source: string; + actions: GeoPlanAction[]; + created_at: string; + updated_at: string; +} + +export interface GeoPlanListResponse { + plans: GeoPlan[]; + total: number; +} + +export interface GeoPlanActionExecuteResponse { + action_id: string; + content_id: string | null; + message: string; +} diff --git a/frontend/types/suggestion.ts b/frontend/types/suggestion.ts index 4936b2d..6b8e097 100644 --- a/frontend/types/suggestion.ts +++ b/frontend/types/suggestion.ts @@ -1,7 +1,25 @@ /** * 优化建议相关类型定义 + * 包含统一的行动建议基础类型(ActionItem, ActionPriority) */ +// ========== 统一行动建议基础类型 ========== + +// 行动优先级(ActionSuggestion 和 NextAction 共享) +export type ActionPriority = "primary" | "secondary" | "optional"; + +// 统一行动项接口 +export interface ActionItem { + id: string; + title: string; + description: string; + priority: ActionPriority; + icon: string; + href: string; +} + +// ========== 优化建议相关类型 ========== + // 建议类型 export type SuggestionType = | "content_optimization" diff --git a/tests/conftest.py b/tests/conftest.py 
deleted file mode 100644 index 97c87de..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,122 +0,0 @@ -import sys -from pathlib import Path - -sys.path.insert(0, str(Path(__file__).parent.parent / "backend")) - -import pytest -import pytest_asyncio -import uuid -from datetime import datetime -from unittest.mock import AsyncMock, patch - -from httpx import ASGITransport, AsyncClient -from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession -from sqlalchemy.pool import StaticPool - -from app.main import app -from app.api.deps import get_current_user -from app.database import Base, get_db -from app.services.auth import create_access_token - - -# --------------------------------------------------------------------------- -# 全局 mock:防止启动真实调度器 / Playwright 浏览器 -# --------------------------------------------------------------------------- -@pytest.fixture(scope="session", autouse=True) -def mock_scheduler(): - """Mock the query scheduler to prevent real background jobs in tests.""" - with patch("app.main.query_scheduler") as mock_sched: - mock_sched.start = lambda: None - mock_sched.shutdown = AsyncMock() - yield - - -# --------------------------------------------------------------------------- -# 内存数据库 fixture(供集成测试使用) -# --------------------------------------------------------------------------- -TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" - - -@pytest_asyncio.fixture -async def test_engine(): - """Create a fresh in-memory SQLite engine for each test function.""" - engine = create_async_engine( - TEST_DATABASE_URL, - echo=False, - future=True, - poolclass=StaticPool, - connect_args={"check_same_thread": False}, - ) - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - yield engine - await engine.dispose() - - -@pytest_asyncio.fixture -async def test_session(test_engine) -> AsyncSession: - """Yield an async session bound to the in-memory engine.""" - async_session = async_sessionmaker( - 
test_engine, class_=AsyncSession, expire_on_commit=False, autoflush=False, autocommit=False - ) - async with async_session() as session: - yield session - - -@pytest_asyncio.fixture -async def override_get_db(test_session): - """Override FastAPI get_db dependency to use the test session.""" - async def _get_db(): - yield test_session - - app.dependency_overrides[get_db] = _get_db - yield test_session - app.dependency_overrides.pop(get_db, None) - - -# --------------------------------------------------------------------------- -# Mock 用户 fixture(供单元测试使用) -# --------------------------------------------------------------------------- -@pytest.fixture -def mock_user(): - """Return a mock authenticated user.""" - user = AsyncMock() - user.id = uuid.UUID("12345678-1234-1234-1234-123456789abc") - user.email = "test@example.com" - user.name = "Test User" - user.plan = "free" - user.max_queries = 5 - user.is_active = True - user.created_at = datetime.now() - return user - - -@pytest.fixture -def override_get_current_user(mock_user): - """Override the get_current_user dependency to return a mock user.""" - async def _override(): - return mock_user - - app.dependency_overrides[get_current_user] = _override - yield - app.dependency_overrides.pop(get_current_user, None) - - -@pytest.fixture -def auth_token(mock_user): - """Generate a valid JWT access token for the mock user.""" - return create_access_token(data={"sub": str(mock_user.id)}) - - -@pytest.fixture -def auth_headers(auth_token): - """Return request headers containing the Bearer token.""" - return {"Authorization": f"Bearer {auth_token}"} - - -@pytest_asyncio.fixture -async def async_client(): - """Create an async HTTP client for testing the FastAPI app.""" - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://test") as client: - yield client diff --git a/tests/test_citation_engine.py b/tests/test_citation_engine.py deleted file mode 100644 index bb8757e..0000000 --- 
a/tests/test_citation_engine.py +++ /dev/null @@ -1,126 +0,0 @@ -import pytest - -from app.workers.citation_engine import BrandMatcher, CompetitorDetector - - -def test_brand_matcher_exact(): - matcher = BrandMatcher(target_brand="华为", brand_aliases=["Huawei"]) - result = matcher.match("华为是一家伟大的科技公司") - assert result["cited"] is True - assert result["match_type"] == "exact" - assert result["confidence"] == 1.0 - - -def test_brand_matcher_alias(): - matcher = BrandMatcher(target_brand="华为", brand_aliases=["Huawei"]) - result = matcher.match("Huawei makes great phones") - assert result["cited"] is True - assert result["match_type"] == "alias" - assert result["confidence"] == 0.9 - - -def test_brand_matcher_fuzzy(): - matcher = BrandMatcher(target_brand="华为") - # "华伟" is a fuzzy match to "华为" - result = matcher.match("华伟 是一家科技公司") - assert result["cited"] is True - assert result["match_type"] == "fuzzy" - assert result["confidence"] > 0.4 - - -def test_brand_matcher_no_match(): - matcher = BrandMatcher(target_brand="华为") - result = matcher.match("这是一段完全不相关的文本,没有任何品牌信息") - assert result["cited"] is False - assert result["match_type"] is None - assert result["confidence"] == 0.0 - - -def test_competitor_detector(): - detector = CompetitorDetector() - text = "中国平安和中国人寿都是大型保险公司" - competitors = detector.detect(text, target_brand="中国平安") - assert "中国人寿" in competitors - assert "中国平安" not in competitors - - -def test_citation_position(): - matcher = BrandMatcher(target_brand="华为") - text = "第一段介绍市场情况\n第二段提到华为的产品\n第三段是总结" - result = matcher.match(text) - assert result["cited"] is True - assert result["position"] == 2 - assert result["citation_text"] is not None - - -# --------------------------------------------------------------------------- -# 补充测试 -# --------------------------------------------------------------------------- -def test_brand_matcher_multiple_aliases(): - """多个别名时,应能匹配任意一个别名。""" - matcher = BrandMatcher(target_brand="华为", brand_aliases=["Huawei", "HW", 
"Honor"]) - # 匹配第二个别名 - result = matcher.match("HW released a new chip") - assert result["cited"] is True - assert result["match_type"] == "alias" - assert result["confidence"] == 0.9 - # 匹配第三个别名 - result2 = matcher.match("Honor phones are popular") - assert result2["cited"] is True - assert result2["match_type"] == "alias" - - -def test_brand_matcher_fuzzy_threshold_boundary(): - """模糊匹配阈值边界:ratio 恰好在 0.4 附近的情况。""" - matcher = BrandMatcher(target_brand="华为") - # "华伟" 与 "华为" 的 ratio 约为 0.5,应大于 0.4 - result = matcher.match("华伟 是一家科技公司") - assert result["cited"] is True - assert result["match_type"] == "fuzzy" - assert result["confidence"] > 0.4 - - # "苹果" vs "华为" ratio 很低,不应超过 0.4 - result2 = matcher.match("苹果科技公司") - assert result2["cited"] is False - assert result2["match_type"] is None - - -def test_brand_matcher_empty_text(): - """空字符串输入不应崩溃,应返回 cited=False。""" - matcher = BrandMatcher(target_brand="华为", brand_aliases=["Huawei"]) - result = matcher.match("") - assert result["cited"] is False - assert result["confidence"] == 0.0 - assert result["match_type"] is None - assert result["position"] is None - - -def test_competitor_detector_multi_industry(): - """竞品检测应能跨行业识别品牌。""" - detector = CompetitorDetector() - text = "华为和小米是科技公司,工商银行和招商银行是银行" - competitors_tech = detector.detect(text, target_brand="华为") - assert "小米" in competitors_tech - assert "腾讯" not in competitors_tech # 未在文本中出现 - assert "华为" not in competitors_tech - - competitors_finance = detector.detect(text, target_brand="工商银行") - assert "招商银行" in competitors_finance - assert "建设银行" not in competitors_finance # 未在文本中出现 - - -def test_citation_position_multiple_paragraphs(): - """品牌在不同段落位置出现时的 position 检测。""" - matcher = BrandMatcher(target_brand="华为") - - text_first = "华为位于第一段\n第二段没有\n第三段也没有" - result = matcher.match(text_first) - assert result["position"] == 1 - - text_third = "第一段没有\n第二段没有\n华为在第三段" - result = matcher.match(text_third) - assert result["position"] == 3 - - text_last = 
"第一段\n第二段\n第三段\n最后一段提到华为" - result = matcher.match(text_last) - assert result["position"] == 4 diff --git a/tests/test_content_agents.py b/tests/test_content_agents.py deleted file mode 100644 index 05844aa..0000000 --- a/tests/test_content_agents.py +++ /dev/null @@ -1,357 +0,0 @@ -"""Agent执行逻辑单元测试 - ContentGeneratorAgent / DeAIAgent / GEOOptimizerAgent - -测试策略: -- 使用 FakeLLMProvider mock LLM 调用,避免真实网络请求 -- patch BaseAgent.report_progress 避免 Redis / 数据库依赖 -- patch RAGService / AsyncSessionLocal 避免真实数据库访问 -""" - -import json -import uuid -from datetime import datetime, timezone -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from app.agent_framework.agents.content_generator_agent import ContentGeneratorAgent -from app.agent_framework.agents.deai_agent import DeAIAgent -from app.agent_framework.agents.geo_optimizer_agent import GEOOptimizerAgent -from app.agent_framework.protocol import TaskMessage -from app.services.llm import LLMProvider, LLMResponse, LLMError - - -# --------------------------------------------------------------------------- -# FakeLLMProvider - 测试用假LLM -# --------------------------------------------------------------------------- -class FakeLLMProvider(LLMProvider): - """测试用假LLM,返回预设响应""" - - def __init__(self, response_content: str = "fake response"): - self._response = response_content - - @property - def provider_name(self) -> str: - return "fake" - - @property - def model_name(self) -> str: - return "fake-model" - - @property - def max_context_length(self) -> int: - return 4096 - - async def chat(self, messages, **kwargs) -> LLMResponse: - return LLMResponse( - content=self._response, - model="fake-model", - usage={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, - ) - - async def chat_stream(self, messages, **kwargs): - for word in self._response.split(): - yield word + " " - - -# --------------------------------------------------------------------------- -# 测试辅助函数 -# 
--------------------------------------------------------------------------- -def _make_task(task_type: str, input_data: dict) -> TaskMessage: - return TaskMessage( - task_id=str(uuid.uuid4()), - agent_name="test_agent", - task_type=task_type, - priority=1, - input_data=input_data, - callback_url=None, - created_at=datetime.now(timezone.utc), - timeout_seconds=300, - ) - - -# --------------------------------------------------------------------------- -# ContentGeneratorAgent 测试 -# --------------------------------------------------------------------------- -@pytest.mark.asyncio -async def test_generate_topics_returns_parsed_json(): - """FakeLLM返回JSON数组,验证topics字段正确解析""" - agent = ContentGeneratorAgent() - fake_llm = FakeLLMProvider( - response_content='[{"title": "AI营销趋势", "reason": "热门话题"}]' - ) - - with patch.object(agent, "report_progress", new_callable=AsyncMock): - with patch( - "app.agent_framework.agents.content_generator_agent.LLMFactory.get_default", - return_value=fake_llm, - ): - task = _make_task("generate_topics", {"target_keyword": "AI营销"}) - result = await agent.execute(task) - - assert result.status == "completed" - assert "topics" in result.output_data - topics = result.output_data["topics"] - assert isinstance(topics, list) - assert topics[0]["title"] == "AI营销趋势" - assert topics[0]["reason"] == "热门话题" - - -@pytest.mark.asyncio -async def test_generate_article_success(): - """验证返回content字段""" - agent = ContentGeneratorAgent() - fake_llm = FakeLLMProvider(response_content="这是一篇测试文章") - - with patch.object(agent, "report_progress", new_callable=AsyncMock): - with patch( - "app.agent_framework.agents.content_generator_agent.LLMFactory.get_default", - return_value=fake_llm, - ): - task = _make_task("generate_article", {"target_keyword": "AI营销"}) - result = await agent.execute(task) - - assert result.status == "completed" - assert result.output_data["content"] == "这是一篇测试文章" - assert result.output_data["word_count"] == len("这是一篇测试文章") - assert "usage" in 
result.output_data - - -@pytest.mark.asyncio -async def test_generate_with_rag_context(): - """Mock RAGService,验证知识上下文被注入""" - agent = ContentGeneratorAgent() - fake_llm = FakeLLMProvider( - response_content='[{"title": "RAG测试选题", "reason": "测试"}]' - ) - - # Mock AsyncSessionLocal 上下文管理器 - mock_session = AsyncMock() - mock_session.__aenter__ = AsyncMock(return_value=mock_session) - mock_session.__aexit__ = AsyncMock(return_value=False) - - mock_local = MagicMock() - mock_local.return_value.__aenter__ = AsyncMock(return_value=mock_session) - mock_local.return_value.__aexit__ = AsyncMock(return_value=False) - - with patch.object(agent, "report_progress", new_callable=AsyncMock): - with patch("app.database.AsyncSessionLocal", mock_local): - with patch( - "app.services.knowledge.rag_service.RAGService" - ) as MockRAG: - mock_rag = MockRAG.return_value - mock_rag.search = AsyncMock( - return_value=[ - {"document_title": "知识库文档", "content": "相关知识内容"} - ] - ) - with patch( - "app.agent_framework.agents.content_generator_agent.LLMFactory.get_default", - return_value=fake_llm, - ): - task = _make_task( - "generate_topics", - {"target_keyword": "AI营销", "knowledge_base_ids": ["kb-1"]}, - ) - result = await agent.execute(task) - - assert result.status == "completed" - mock_rag.search.assert_awaited_once() - # 验证 search 调用参数 - call_kwargs = mock_rag.search.call_args.kwargs - assert call_kwargs["query"] == "AI营销" - assert call_kwargs["knowledge_base_ids"] == ["kb-1"] - - -@pytest.mark.asyncio -async def test_llm_error_returns_failed(): - """Mock LLM抛出LLMError,验证返回failed状态""" - agent = ContentGeneratorAgent() - - class ErrorLLM(FakeLLMProvider): - async def chat(self, messages, **kwargs) -> LLMResponse: - raise LLMError("API错误", provider="fake", status_code=500) - - error_llm = ErrorLLM() - - with patch.object(agent, "report_progress", new_callable=AsyncMock): - with patch( - "app.agent_framework.agents.content_generator_agent.LLMFactory.get_default", - return_value=error_llm, - 
): - task = _make_task("generate_topics", {"target_keyword": "AI营销"}) - result = await agent.execute(task) - - assert result.status == "failed" - assert "LLM调用失败" in result.error_message - - -@pytest.mark.asyncio -async def test_extract_json_from_code_block(): - """测试```json```包裹的JSON提取""" - agent = ContentGeneratorAgent() - text = '```json\n[{"title": "测试"}]\n```' - result = agent._extract_json(text) - assert result == '[{"title": "测试"}]' - - -# --------------------------------------------------------------------------- -# DeAIAgent 测试 -# --------------------------------------------------------------------------- -@pytest.mark.asyncio -async def test_deai_success(): - """正常处理返回success""" - agent = DeAIAgent() - fake_llm = FakeLLMProvider(response_content="去AI化后的内容") - - with patch.object(agent, "report_progress", new_callable=AsyncMock): - with patch( - "app.agent_framework.agents.deai_agent.LLMFactory.get_default", - return_value=fake_llm, - ): - task = _make_task("deai_process", {"content": "原始的AI生成内容"}) - result = await agent.execute(task) - - assert result.status == "completed" - assert result.output_data["content"] == "去AI化后的内容" - assert result.output_data["original_word_count"] == len("原始的AI生成内容") - assert result.output_data["processed_word_count"] == len("去AI化后的内容") - - -@pytest.mark.asyncio -async def test_deai_empty_content_fails(): - """空content返回failed""" - agent = DeAIAgent() - fake_llm = FakeLLMProvider(response_content="something") - - with patch.object(agent, "report_progress", new_callable=AsyncMock): - with patch( - "app.agent_framework.agents.deai_agent.LLMFactory.get_default", - return_value=fake_llm, - ): - task = _make_task("deai_process", {"content": ""}) - result = await agent.execute(task) - - assert result.status == "failed" - # ValueError 会被外层 except 捕获,error_message 包含原始异常信息 - assert "content" in result.error_message.lower() or "input_data" in result.error_message.lower() - - -@pytest.mark.asyncio -async def 
test_deai_temperature_is_high(): - """验证调用LLM时temperature=0.9""" - agent = DeAIAgent() - mock_provider = AsyncMock() - mock_provider.chat = AsyncMock( - return_value=LLMResponse( - content="processed", - model="fake", - usage={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, - ) - ) - - with patch.object(agent, "report_progress", new_callable=AsyncMock): - with patch( - "app.agent_framework.agents.deai_agent.LLMFactory.get_default", - return_value=mock_provider, - ): - task = _make_task("deai_process", {"content": "some content"}) - await agent.execute(task) - - mock_provider.chat.assert_awaited_once() - _, kwargs = mock_provider.chat.call_args - assert kwargs.get("temperature") == 0.9 - - -# --------------------------------------------------------------------------- -# GEOOptimizerAgent 测试 -# --------------------------------------------------------------------------- -@pytest.mark.asyncio -async def test_geo_optimize_json_response(): - """FakeLLM返回标准JSON,验证解析""" - agent = GEOOptimizerAgent() - fake_llm = FakeLLMProvider( - response_content=json.dumps( - { - "optimized_content": "优化后的文章", - "seo_score": 85, - "changes": ["优化了标题"], - } - ) - ) - - with patch.object(agent, "report_progress", new_callable=AsyncMock): - with patch( - "app.agent_framework.agents.geo_optimizer_agent.LLMFactory.get_default", - return_value=fake_llm, - ): - task = _make_task( - "geo_optimize", - {"content": "原始文章", "target_keywords": ["SEO"]}, - ) - result = await agent.execute(task) - - assert result.status == "completed" - assert result.output_data["optimized_content"] == "优化后的文章" - assert result.output_data["seo_score"] == 85 - assert result.output_data["changes"] == ["优化了标题"] - assert "usage" in result.output_data - - -@pytest.mark.asyncio -async def test_geo_optimize_fallback(): - """FakeLLM返回纯文本,验证降级处理""" - agent = GEOOptimizerAgent() - fake_llm = FakeLLMProvider(response_content="这不是JSON格式") - - with patch.object(agent, "report_progress", new_callable=AsyncMock): - 
with patch( - "app.agent_framework.agents.geo_optimizer_agent.LLMFactory.get_default", - return_value=fake_llm, - ): - task = _make_task( - "geo_optimize", - {"content": "原始文章", "target_keywords": ["SEO"]}, - ) - result = await agent.execute(task) - - assert result.status == "completed" - assert result.output_data["optimized_content"] == "这不是JSON格式" - assert result.output_data["seo_score"] is None - assert result.output_data["changes"] == ["LLM输出非标准格式,已返回原始优化结果"] - - -@pytest.mark.asyncio -async def test_geo_optimize_keywords_in_prompt(): - """验证关键词出现在渲染后的prompt variables中""" - agent = GEOOptimizerAgent() - - with patch.object(agent, "report_progress", new_callable=AsyncMock): - with patch( - "app.agent_framework.agents.geo_optimizer_agent.GEO_OPTIMIZER_TEMPLATE.render" - ) as mock_render: - mock_render.return_value = [ - {"role": "system", "content": "test prompt"} - ] - with patch( - "app.agent_framework.agents.geo_optimizer_agent.LLMFactory.get_default" - ) as mock_factory: - mock_provider = AsyncMock() - mock_provider.chat = AsyncMock( - return_value=LLMResponse( - content=json.dumps({"optimized_content": "test"}), - model="fake", - usage={}, - ) - ) - mock_factory.return_value = mock_provider - - task = _make_task( - "geo_optimize", - {"content": "原始文章", "target_keywords": ["SEO", "GEO优化"]}, - ) - await agent.execute(task) - - mock_render.assert_called_once() - variables = mock_render.call_args[0][0] - assert "SEO" in variables["target_keywords"] - assert "GEO优化" in variables["target_keywords"] diff --git a/tests/test_pipeline_engine.py b/tests/test_pipeline_engine.py deleted file mode 100644 index 664386a..0000000 --- a/tests/test_pipeline_engine.py +++ /dev/null @@ -1,254 +0,0 @@ -"""Pipeline 引擎单元测试""" -import pytest -import textwrap - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -SIMPLE_YAML = textwrap.dedent(""" -name: 
test_pipeline -version: "1.0" -description: "单元测试用Pipeline" -variables: - brand_name: "TestBrand" -stages: - - name: step1 - agent: content_agent - action: generate - inputs: - brand: "${brand_name}" - outputs: - - result - - name: step2 - agent: review_agent - action: review - depends_on: - - step1 - inputs: - content: "${stages.step1.outputs.result}" - outputs: - - reviewed -""") - -CYCLIC_YAML = textwrap.dedent(""" -name: cyclic_pipeline -stages: - - name: a - agent: agent1 - action: act - depends_on: - - b - - name: b - agent: agent2 - action: act - depends_on: - - a -""") - - -# --------------------------------------------------------------------------- -# PipelineLoader 测试 -# --------------------------------------------------------------------------- - -class TestPipelineLoader: - def test_load_valid_yaml(self): - """加载正常 YAML 返回 Pipeline 对象""" - from app.agent_framework.pipeline.loader import PipelineLoader - from app.agent_framework.pipeline.schema import Pipeline - - loader = PipelineLoader() - pipeline = loader.load_from_yaml(SIMPLE_YAML) - - assert isinstance(pipeline, Pipeline) - assert pipeline.name == "test_pipeline" - assert len(pipeline.stages) == 2 - assert pipeline.stages[0].name == "step1" - assert pipeline.stages[1].name == "step2" - - def test_dag_validation_detects_cycle(self): - """有环图抛出 PipelineCyclicError""" - from app.agent_framework.pipeline.loader import PipelineLoader, PipelineCyclicError - - loader = PipelineLoader() - with pytest.raises(PipelineCyclicError): - loader.load_from_yaml(CYCLIC_YAML) - - def test_dag_validation_passes_acyclic(self): - """无环图验证通过,不抛异常""" - from app.agent_framework.pipeline.loader import PipelineLoader - - loader = PipelineLoader() - pipeline = loader.load_from_yaml(SIMPLE_YAML) - assert pipeline is not None - - def test_topological_sort_order(self): - """拓扑排序结果尊重依赖:step1 必须在 step2 之前""" - from app.agent_framework.pipeline.loader import PipelineLoader - from app.agent_framework.pipeline.engine import 
PipelineEngine - - loader = PipelineLoader() - pipeline = loader.load_from_yaml(SIMPLE_YAML) - - engine = PipelineEngine(dispatcher=None) - sorted_stages = engine._topological_sort(pipeline.stages) - names = [s.name for s in sorted_stages] - - assert names.index("step1") < names.index("step2") - - def test_variable_resolution_simple(self): - """${var} 简单变量替换""" - from app.agent_framework.pipeline.loader import PipelineLoader - - result = PipelineLoader.resolve_variables("${brand_name}", {"brand_name": "MyBrand"}) - assert result == "MyBrand" - - def test_variable_resolution_nested(self): - """${stages.step1.outputs.result} 嵌套路径解析""" - from app.agent_framework.pipeline.loader import PipelineLoader - - context = { - "stages": { - "step1": { - "outputs": { - "result": "generated_content" - } - } - } - } - result = PipelineLoader.resolve_variables( - "${stages.step1.outputs.result}", context - ) - assert result == "generated_content" - - def test_variable_unresolved_kept(self): - """未定义变量保持 ${var} 原样""" - from app.agent_framework.pipeline.loader import PipelineLoader - - result = PipelineLoader.resolve_variables("${undefined_var}", {}) - assert result == "${undefined_var}" - - def test_variable_resolution_in_dict(self): - """dict 中的变量引用被递归解析""" - from app.agent_framework.pipeline.loader import PipelineLoader - - template = {"key": "${greeting}", "nested": {"val": "${name}"}} - context = {"greeting": "Hello", "name": "World"} - result = PipelineLoader.resolve_variables(template, context) - - assert result["key"] == "Hello" - assert result["nested"]["val"] == "World" - - -# --------------------------------------------------------------------------- -# PipelineEngine 测试 -# --------------------------------------------------------------------------- - -class TestPipelineEngine: - @pytest.mark.asyncio - async def test_dry_run_mode(self): - """dispatcher=None 时 dry-run 模式正常执行,返回 PipelineResult""" - from app.agent_framework.pipeline.loader import PipelineLoader - from 
app.agent_framework.pipeline.engine import PipelineEngine - from app.agent_framework.pipeline.schema import StageStatus - - loader = PipelineLoader() - pipeline = loader.load_from_yaml(SIMPLE_YAML) - - engine = PipelineEngine(dispatcher=None) - result = await engine.execute(pipeline, context={"brand_name": "TestBrand"}) - - assert result.pipeline_name == "test_pipeline" - assert result.status == StageStatus.COMPLETED - assert "step1" in result.stages_results - assert "step2" in result.stages_results - - @pytest.mark.asyncio - async def test_dry_run_stage_outputs(self): - """dry-run 模式下每个 stage 的 outputs 有值""" - from app.agent_framework.pipeline.loader import PipelineLoader - from app.agent_framework.pipeline.engine import PipelineEngine - from app.agent_framework.pipeline.schema import StageStatus - - loader = PipelineLoader() - pipeline = loader.load_from_yaml(SIMPLE_YAML) - - engine = PipelineEngine(dispatcher=None) - result = await engine.execute(pipeline) - - step1_result = result.stages_results["step1"] - assert step1_result.status == StageStatus.COMPLETED - assert "result" in step1_result.outputs - - def test_stage_timeout_config(self): - """超时配置正确传递到 stage""" - from app.agent_framework.pipeline.schema import PipelineStage - - stage = PipelineStage( - name="timed_stage", - agent="some_agent", - action="do_action", - timeout_seconds=60, - ) - assert stage.timeout_seconds == 60 - - def test_stage_retry_count(self): - """重试次数配置正确""" - from app.agent_framework.pipeline.schema import PipelineStage - - stage = PipelineStage( - name="retry_stage", - agent="some_agent", - action="do_action", - retry_count=3, - ) - assert stage.retry_count == 3 - - @pytest.mark.asyncio - async def test_pipeline_variables_override(self): - """外部 context 变量覆盖 pipeline 默认变量""" - from app.agent_framework.pipeline.loader import PipelineLoader - from app.agent_framework.pipeline.engine import PipelineEngine - - loader = PipelineLoader() - pipeline = loader.load_from_yaml(SIMPLE_YAML) - - # 
使用外部 context 覆盖 brand_name - engine = PipelineEngine(dispatcher=None) - result = await engine.execute(pipeline, context={"brand_name": "OverrideBrand"}) - - # 执行不应失败 - from app.agent_framework.pipeline.schema import StageStatus - assert result.status == StageStatus.COMPLETED - - def test_load_error_on_invalid_yaml(self): - """加载无效 YAML 抛出 PipelineLoadError""" - from app.agent_framework.pipeline.loader import PipelineLoader, PipelineLoadError - - loader = PipelineLoader() - with pytest.raises(PipelineLoadError): - loader.load_from_yaml("not: valid: yaml: [[[") - - def test_validate_dag_acyclic(self): - """validate_dag 对无环图返回 True""" - from app.agent_framework.pipeline.loader import PipelineLoader - from app.agent_framework.pipeline.schema import PipelineStage - - stages = [ - PipelineStage(name="a", agent="ag", action="act", depends_on=[]), - PipelineStage(name="b", agent="ag", action="act", depends_on=["a"]), - PipelineStage(name="c", agent="ag", action="act", depends_on=["b"]), - ] - assert PipelineLoader.validate_dag(stages) is True - - def test_validate_dag_cyclic(self): - """validate_dag 对有环图返回 False""" - from app.agent_framework.pipeline.loader import PipelineLoader - from app.agent_framework.pipeline.schema import PipelineStage - - stages = [ - PipelineStage(name="a", agent="ag", action="act", depends_on=["b"]), - PipelineStage(name="b", agent="ag", action="act", depends_on=["a"]), - ] - assert PipelineLoader.validate_dag(stages) is False diff --git a/tests/test_prompt_template.py b/tests/test_prompt_template.py deleted file mode 100644 index 385b492..0000000 --- a/tests/test_prompt_template.py +++ /dev/null @@ -1,264 +0,0 @@ -"""Prompt 模板单元测试""" -import pytest - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - -@pytest.fixture -def simple_template(): - """构造一个简单的 PromptTemplate""" - from app.agent_framework.prompts.base_template import 
PromptSection, PromptTemplate - - sections = PromptSection( - identity="你是一个AI助手", - context="品牌:${brand_name}", - instructions="请为 ${topic} 生成内容", - constraints="字数不超过500字", - output_format="输出 JSON 格式", - examples="示例:{ 'title': '...' }", - ) - return PromptTemplate(sections=sections) - - -# --------------------------------------------------------------------------- -# PromptTemplate 基本测试 -# --------------------------------------------------------------------------- - -class TestPromptTemplate: - def test_render_returns_messages(self, simple_template): - """render 输出 list[dict] 格式""" - messages = simple_template.render() - - assert isinstance(messages, list) - assert len(messages) >= 1 - for msg in messages: - assert "role" in msg - assert "content" in msg - assert msg["role"] in ("system", "user") - - def test_system_user_message_split(self, simple_template): - """system 含 identity+context,user 含 instructions""" - messages = simple_template.render(variables={"brand_name": "TestBrand", "topic": "AI"}) - - roles = [m["role"] for m in messages] - assert "system" in roles - assert "user" in roles - - system_msg = next(m for m in messages if m["role"] == "system") - user_msg = next(m for m in messages if m["role"] == "user") - - # identity 和 context 在 system - assert "你是一个AI助手" in system_msg["content"] - # instructions 在 user - assert "生成内容" in user_msg["content"] - - def test_variable_injection_simple(self, simple_template): - """${var} 被替换""" - messages = simple_template.render(variables={"brand_name": "MyBrand", "topic": "AI"}) - - all_content = " ".join(m["content"] for m in messages) - assert "MyBrand" in all_content - assert "AI" in all_content - # 原始占位符不应存在 - assert "${brand_name}" not in all_content - assert "${topic}" not in all_content - - def test_variable_injection_nested(self): - """${a.b.c} 嵌套路径解析""" - from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate - - sections = PromptSection( - 
instructions="平台:${platform.name},目标:${goal.type}", - ) - template = PromptTemplate(sections=sections) - - messages = template.render(variables={ - "platform": {"name": "微信公众号"}, - "goal": {"type": "品牌曝光"}, - }) - - user_msg = next(m for m in messages if m["role"] == "user") - assert "微信公众号" in user_msg["content"] - assert "品牌曝光" in user_msg["content"] - - def test_unresolved_variable_kept(self): - """未注入的变量保持 ${var} 原样""" - from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate - - sections = PromptSection( - instructions="主题是 ${undefined_var}", - ) - template = PromptTemplate(sections=sections) - messages = template.render(variables={}) - - user_msg = next(m for m in messages if m["role"] == "user") - assert "${undefined_var}" in user_msg["content"] - - def test_truncation_within_budget(self): - """短文本不被截断""" - from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate - - short_context = "简短的上下文内容" - sections = PromptSection( - identity="身份", - context=short_context, - ) - template = PromptTemplate(sections=sections) - messages = template.render(context_budget=3000) - - system_msg = next(m for m in messages if m["role"] == "system") - assert short_context in system_msg["content"] - assert "中间内容已省略" not in system_msg["content"] - - def test_truncation_exceeds_budget(self): - """超长文本被智能截断""" - from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate - - # 生成超过 budget 的文本(budget=10 token) - long_context = "这是一段非常长的上下文内容,包含大量文字。" * 200 - - sections = PromptSection( - identity="我是AI助手", - context=long_context, - ) - template = PromptTemplate(sections=sections) - # 设置很小的 budget 强制触发截断 - messages = template.render(context_budget=50) - - system_msg = next(m for m in messages if m["role"] == "system") - # 截断标记应该出现 - assert "中间内容已省略" in system_msg["content"] - - def test_render_with_no_variables(self, simple_template): - """不传变量也能正常 render(未解析变量保持原样)""" - messages = simple_template.render() - 
assert isinstance(messages, list) - assert len(messages) > 0 - - def test_render_empty_sections(self): - """空 sections 不返回空角色 message""" - from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate - - sections = PromptSection(instructions="做些事情") - template = PromptTemplate(sections=sections) - messages = template.render() - - # system 内容为空时不应出现 system message - roles = [m["role"] for m in messages] - assert "user" in roles - # 只有 user message(identity/context/constraints 都为空) - system_msgs = [m for m in messages if m["role"] == "system"] - for sm in system_msgs: - assert sm["content"].strip() != "" - - -# --------------------------------------------------------------------------- -# 5个 Template 全部能正常 render -# --------------------------------------------------------------------------- - -class TestAllTemplatesRender: - def test_topic_selector_template_renders(self): - """TOPIC_SELECTOR_TEMPLATE 能正常 render""" - from app.agent_framework.prompts import TOPIC_SELECTOR_TEMPLATE - - messages = TOPIC_SELECTOR_TEMPLATE.render(variables={ - "target_keyword": "AI营销", - "brand_name": "示例品牌", - "target_platform": "微信公众号", - }) - assert isinstance(messages, list) - assert len(messages) > 0 - for m in messages: - assert "role" in m - assert "content" in m - - def test_content_generator_template_renders(self): - """CONTENT_GENERATOR_TEMPLATE 能正常 render""" - from app.agent_framework.prompts import CONTENT_GENERATOR_TEMPLATE - - messages = CONTENT_GENERATOR_TEMPLATE.render(variables={ - "topic": "AI发展趋势", - "platform": "知乎", - }) - assert isinstance(messages, list) - assert len(messages) > 0 - - def test_deai_template_renders(self): - """DEAI_TEMPLATE 能正常 render""" - from app.agent_framework.prompts import DEAI_TEMPLATE - - messages = DEAI_TEMPLATE.render(variables={ - "content": "测试AI生成内容", - }) - assert isinstance(messages, list) - assert len(messages) > 0 - - def test_geo_optimizer_template_renders(self): - """GEO_OPTIMIZER_TEMPLATE 能正常 render""" - 
from app.agent_framework.prompts import GEO_OPTIMIZER_TEMPLATE - - messages = GEO_OPTIMIZER_TEMPLATE.render(variables={ - "content": "待优化的内容", - "platform": "小红书", - }) - assert isinstance(messages, list) - assert len(messages) > 0 - - def test_rule_checker_template_renders(self): - """RULE_CHECKER_TEMPLATE 能正常 render""" - from app.agent_framework.prompts import RULE_CHECKER_TEMPLATE - - messages = RULE_CHECKER_TEMPLATE.render(variables={ - "content": "待检查的内容", - "platform": "微信公众号", - }) - assert isinstance(messages, list) - assert len(messages) > 0 - - -# --------------------------------------------------------------------------- -# PromptTemplate._inject 直接测试 -# --------------------------------------------------------------------------- - -class TestPromptTemplateInject: - @pytest.fixture - def template(self): - from app.agent_framework.prompts.base_template import PromptSection, PromptTemplate - return PromptTemplate(sections=PromptSection()) - - def test_inject_simple_var(self, template): - """简单变量注入""" - result = template._inject("Hello ${name}", {"name": "World"}) - assert result == "Hello World" - - def test_inject_multiple_vars(self, template): - """多变量注入""" - result = template._inject("${a} and ${b}", {"a": "foo", "b": "bar"}) - assert result == "foo and bar" - - def test_inject_nested_path(self, template): - """嵌套路径注入""" - result = template._inject("${x.y}", {"x": {"y": "deep_value"}}) - assert result == "deep_value" - - def test_inject_missing_var_kept(self, template): - """缺失变量保持原样""" - result = template._inject("${missing}", {}) - assert result == "${missing}" - - def test_inject_empty_text(self, template): - """空文本原样返回""" - result = template._inject("", {"x": "val"}) - assert result == "" - - def test_estimate_tokens_chinese(self, template): - """中文字符 token 估算:每字约 1 token""" - text = "你好世界" # 4 个中文字符 - tokens = template._estimate_tokens(text) - assert tokens == 4 - - def test_estimate_tokens_empty(self, template): - """空文本 token 为 0""" - assert 
template._estimate_tokens("") == 0 diff --git a/tests/test_queries.py b/tests/test_queries.py deleted file mode 100644 index 786d3e8..0000000 --- a/tests/test_queries.py +++ /dev/null @@ -1,153 +0,0 @@ -import uuid -from datetime import datetime -from unittest.mock import AsyncMock, patch - -import pytest - -from app.api.deps import get_current_user - - -@pytest.fixture -def mock_query(): - """Return a mock query object.""" - q = AsyncMock() - q.id = uuid.UUID("22345678-1234-1234-1234-123456789abc") - q.user_id = uuid.UUID("12345678-1234-1234-1234-123456789abc") - q.keyword = "test keyword" - q.target_brand = "TestBrand" - q.brand_aliases = [] - q.platforms = ["wenxin", "kimi"] - q.frequency = "weekly" - q.status = "active" - q.last_queried_at = None - q.next_query_at = datetime.now() - q.created_at = datetime.now() - q.updated_at = datetime.now() - return q - - -@pytest.mark.asyncio -async def test_create_query_success( - async_client, override_get_current_user, auth_headers, mock_query -): - with patch("app.api.queries.create_query", return_value=mock_query): - response = await async_client.post( - "/api/v1/queries/", - headers=auth_headers, - json={ - "keyword": "test keyword", - "target_brand": "TestBrand", - "platforms": ["wenxin", "kimi"], - "frequency": "weekly", - }, - ) - assert response.status_code == 201 - data = response.json() - assert data["keyword"] == "test keyword" - assert data["target_brand"] == "TestBrand" - - -@pytest.mark.asyncio -async def test_create_query_exceeds_limit( - async_client, override_get_current_user, auth_headers -): - with patch( - "app.api.queries.create_query", - side_effect=PermissionError("Query limit exceeded"), - ): - response = await async_client.post( - "/api/v1/queries/", - headers=auth_headers, - json={ - "keyword": "test keyword", - "target_brand": "TestBrand", - "platforms": ["wenxin"], - "frequency": "daily", - }, - ) - assert response.status_code == 403 - data = response.json() - assert "Query limit exceeded" in 
data["detail"] - - -@pytest.mark.asyncio -async def test_list_queries( - async_client, override_get_current_user, auth_headers, mock_query -): - with patch("app.api.queries.get_queries", return_value=([mock_query], 1)): - response = await async_client.get("/api/v1/queries/", headers=auth_headers) - assert response.status_code == 200 - data = response.json() - assert data["total"] == 1 - assert len(data["items"]) == 1 - assert data["items"][0]["keyword"] == "test keyword" - - -@pytest.mark.asyncio -async def test_update_query( - async_client, override_get_current_user, auth_headers, mock_query -): - updated_query = AsyncMock() - updated_query.id = mock_query.id - updated_query.keyword = "updated keyword" - updated_query.target_brand = "TestBrand" - updated_query.brand_aliases = [] - updated_query.platforms = ["wenxin"] - updated_query.frequency = "daily" - updated_query.status = "active" - updated_query.last_queried_at = None - updated_query.next_query_at = datetime.now() - updated_query.created_at = datetime.now() - updated_query.updated_at = datetime.now() - - with patch("app.api.queries.update_query", return_value=updated_query): - response = await async_client.put( - f"/api/v1/queries/{mock_query.id}", - headers=auth_headers, - json={"keyword": "updated keyword", "frequency": "daily"}, - ) - assert response.status_code == 200 - data = response.json() - assert data["keyword"] == "updated keyword" - assert data["frequency"] == "daily" - - -@pytest.mark.asyncio -async def test_delete_query( - async_client, override_get_current_user, auth_headers, mock_query -): - with patch("app.api.queries.delete_query", return_value=True): - response = await async_client.delete( - f"/api/v1/queries/{mock_query.id}", - headers=auth_headers, - ) - assert response.status_code == 204 - - -@pytest.mark.asyncio -async def test_query_not_found( - async_client, override_get_current_user, auth_headers -): - non_existent_id = uuid.UUID("33333333-3333-3333-3333-333333333333") - with 
patch("app.api.queries.get_query", return_value=None): - response = await async_client.get( - f"/api/v1/queries/{non_existent_id}", - headers=auth_headers, - ) - assert response.status_code == 404 - data = response.json() - assert "Query not found" in data["detail"] - - -@pytest.mark.asyncio -async def test_query_belongs_to_other_user( - async_client, override_get_current_user, auth_headers -): - other_user_query_id = uuid.UUID("44444444-4444-4444-4444-444444444444") - # Simulate that the query does not belong to the current user by returning None - with patch("app.api.queries.get_query", return_value=None): - response = await async_client.get( - f"/api/v1/queries/{other_user_query_id}", - headers=auth_headers, - ) - assert response.status_code == 404 diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py deleted file mode 100644 index 115365a..0000000 --- a/tests/test_scheduler.py +++ /dev/null @@ -1,122 +0,0 @@ -import uuid -from datetime import datetime, timezone, timedelta -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from app.models.query import Query -from app.models.user import User -from app.services.auth import hash_password -from app.workers.citation_engine import CitationEngine -from app.workers.scheduler import QueryScheduler - - -# --------------------------------------------------------------------------- -# 1. 调度器启动 / 关闭 -# --------------------------------------------------------------------------- -@pytest.mark.asyncio -async def test_scheduler_start_stop(): - scheduler = QueryScheduler() - scheduler.engine = AsyncMock() - - scheduler.start() - # Verify the scheduled job was added - job = scheduler.scheduler.get_job("check_queries") - assert job is not None - assert job.name == "检查并执行到期的查询任务" - - await scheduler.shutdown() - # Verify engine.close was awaited - scheduler.engine.close.assert_awaited_once() - - -# --------------------------------------------------------------------------- -# 2. 
查询任务筛选:只选择 status=active 且 next_query_at <= now 的查询 -# --------------------------------------------------------------------------- -@pytest.mark.asyncio -async def test_scheduler_query_filtering(test_session): - # Create a user first - user = User( - email="sched@test.com", - password_hash=hash_password("pass"), - name="Scheduler", - ) - test_session.add(user) - await test_session.commit() - await test_session.refresh(user) - - now = datetime.now(timezone.utc) - - # q1: active and overdue -> should be picked - q1 = Query( - user_id=user.id, - keyword="overdue", - target_brand="B1", - status="active", - next_query_at=now - timedelta(hours=1), - ) - # q2: active but in the future -> should NOT be picked - q2 = Query( - user_id=user.id, - keyword="future", - target_brand="B2", - status="active", - next_query_at=now + timedelta(days=1), - ) - # q3: paused and overdue -> should NOT be picked - q3 = Query( - user_id=user.id, - keyword="paused", - target_brand="B3", - status="paused", - next_query_at=now - timedelta(hours=1), - ) - - test_session.add_all([q1, q2, q3]) - await test_session.commit() - - scheduler = QueryScheduler() - scheduler.engine = AsyncMock() - - # Mock AsyncSessionLocal so scheduler uses our test session - mock_local = MagicMock() - mock_local.return_value.__aenter__ = AsyncMock(return_value=test_session) - mock_local.return_value.__aexit__ = AsyncMock(return_value=False) - - with patch("app.workers.scheduler.AsyncSessionLocal", mock_local): - await scheduler.check_and_execute_queries() - - # execute_query should be called exactly once (for q1) - scheduler.engine.execute_query.assert_called_once() - called_query = scheduler.engine.execute_query.call_args[0][0] - assert called_query.keyword == "overdue" - - -# --------------------------------------------------------------------------- -# 3. 
频率计算:daily 和 weekly 的 next_query_at 正确计算 -# --------------------------------------------------------------------------- -@pytest.mark.asyncio -async def test_scheduler_frequency_calculation_daily(): - engine = CitationEngine() - now = datetime.utcnow() - result = engine._calculate_next_query_at("daily") - expected = now + timedelta(days=1) - assert abs((result - expected).total_seconds()) < 5 - - -@pytest.mark.asyncio -async def test_scheduler_frequency_calculation_weekly(): - engine = CitationEngine() - now = datetime.utcnow() - result = engine._calculate_next_query_at("weekly") - expected = now + timedelta(days=7) - assert abs((result - expected).total_seconds()) < 5 - - -@pytest.mark.asyncio -async def test_scheduler_frequency_calculation_default(): - engine = CitationEngine() - now = datetime.utcnow() - result = engine._calculate_next_query_at(None) - expected = now + timedelta(days=7) - assert abs((result - expected).total_seconds()) < 5