feat(geo): U1-U9 monetization closed loop implementation

U1: GEO diagnosis auto data collection (DataCollectorService + 11 tests)
U2: Free GEO health score public page (HealthScoreAPI + 7 tests)
U3: Onboarding redesign with conversion layer (Step0 + UpgradePrompt + 14 tests)
U4: Real payment integration (WeChat/Alipay gateway + PaymentOrder + 12 tests)
U5: AI content generation & distribution (Publishers + PublishEngine + 11 tests)
U6: Attribution system & ROI reports (AttributionEngine + ROICalculator + 12 tests)
U7: Email integration & Dashboard monetization UI (EmailScheduler + templates + 22 tests)
U8: Integration tests & shared fixtures (monetization flow + fixture consolidation + 12 tests)
U9: E2E smoke tests (HealthScorePage + acquisition/core flow + 4 E2E cases)

Total: 101 new backend tests passing, 4 E2E test cases
All third-party integrations default to mock mode and are managed through the config center
chiguyong 2026-06-01 07:39:05 +08:00
parent 900a90ba84
commit b41da42d74
293 changed files with 25006 additions and 4779 deletions


@ -0,0 +1,38 @@
---
description: CodeGraph MCP usage guide — when to use which tool
alwaysApply: true
---
<!-- CODEGRAPH_START -->
## CodeGraph
This project has a CodeGraph MCP server (`codegraph_*` tools) configured. CodeGraph is a tree-sitter-parsed knowledge graph of every symbol, edge, and file. Reads are sub-millisecond and return structural information grep cannot.
### When to prefer codegraph over native search
Use codegraph for **structural** questions — what calls what, what would break, where is X defined, what is X's signature. Use native grep/read only for **literal text** queries (string contents, comments, log messages) or after you already have a specific file open.
| Question | Tool |
|---|---|
| "Where is X defined?" / "Find symbol named X" | `codegraph_search` |
| "What calls function Y?" | `codegraph_callers` |
| "What does Y call?" | `codegraph_callees` |
| "What would break if I changed Z?" | `codegraph_impact` |
| "Show me Y's signature / source / docstring" | `codegraph_node` |
| "Give me focused context for a task/area" | `codegraph_context` |
| "See several related symbols' source at once" | `codegraph_explore` |
| "What files exist under path/" | `codegraph_files` |
| "Is the index healthy?" | `codegraph_status` |
### Rules of thumb
- **Answer directly — don't delegate exploration.** For "how does X work" / architecture / trace questions, answer with 2-3 codegraph calls: `codegraph_context` first, then ONE `codegraph_explore` for the source of the symbols it surfaces. Codegraph IS the pre-built index, so spawning a separate file-reading sub-task/agent — or running a grep + read loop — repeats work codegraph already did and costs more for the same answer.
- **Trust codegraph results.** They come from a full AST parse. Do NOT re-verify them with grep — that's slower, less accurate, and wastes context.
- **Don't grep first** when looking up a symbol by name. `codegraph_search` is faster and returns kind + location + signature in one call.
- **Don't chain `codegraph_search` + `codegraph_node`** when you just want context — `codegraph_context` is one call.
- **Don't loop `codegraph_node` over many symbols** — one `codegraph_explore` call returns several symbols' source grouped in a single capped call, while each separate node/Read call re-reads the whole context and costs far more.
- **Index lag**: the file watcher debounces ~500ms behind writes; don't re-query immediately after editing a file in the same turn.
### If `.codegraph/` doesn't exist
The MCP server returns "not initialized." Ask the user: *"I notice this project doesn't have CodeGraph initialized. Want me to run `codegraph init -i` to build the index?"*
<!-- CODEGRAPH_END -->

AGENTS.md Normal file

@ -0,0 +1,44 @@
# GEO Platform Agents
## Agent Registry
| Agent | Name | Type | Supported Tasks | File |
|-------|------|------|----------------|------|
| CitationDetectorAgent | citation_detector | CITATION_DETECTOR | citation_detect, citation_batch | backend/app/agent_framework/agents/citation_detector.py |
| ContentGeneratorAgent | content_generator | CONTENT_GENERATOR | content_generate, content_regenerate | backend/app/agent_framework/agents/content_generator_agent.py |
| DeAIAgent | deai_agent | DEAI_AGENT | deai_process | backend/app/agent_framework/agents/deai_agent.py |
| GEOOptimizerAgent | geo_optimizer | GEO_OPTIMIZER | geo_optimize | backend/app/agent_framework/agents/geo_optimizer_agent.py |
| MonitorAgent | monitor | PERFORMANCE_TRACKER | monitor_track, monitor_check_single | backend/app/agent_framework/agents/monitor_agent.py |
| SchemaAdvisorAgent | schema_advisor | SCHEMA_ADVISOR | schema_advise | backend/app/agent_framework/agents/schema_advisor.py |
| CompetitorAnalyzerAgent | competitor_analyzer | COMPETITOR_ANALYZER | competitor_analyze, competitor_gap_analysis | backend/app/agent_framework/agents/competitor_analyzer.py |
| TrendAgent | trend_agent | TREND_AGENT | trend_insight, trend_hotspot | backend/app/agent_framework/agents/trend_agent.py |
## Running Agents
### Standalone Mode (No Redis Required)
```bash
cd geo/backend
python3 -m app.agent_framework.standalone [agent_name|all]
```
### With Redis Queue
Agents auto-register via the Registry when Redis is available, and tasks are dispatched via TaskDispatcher.
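
The register-then-dispatch flow can be sketched in memory, without Redis. This is only an illustration under stated assumptions: `SketchTask`, `SketchAgent`, `SketchRegistry`, `dispatch`, and `build_registry` are hypothetical names, not the project's `Registry` or `TaskDispatcher`, whose real interfaces are not shown here.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class SketchTask:
    """Hypothetical stand-in for TaskMessage; only the fields this sketch needs."""
    task_id: str
    task_type: str
    input_data: dict = field(default_factory=dict)


class SketchAgent:
    """Hypothetical stand-in for a BaseAgent subclass."""

    def __init__(self, name: str, supported_tasks: list[str]):
        self.name = name
        self.supported_tasks = supported_tasks

    async def execute(self, task: SketchTask) -> dict:
        # A real agent would do work here; the sketch only echoes the routing.
        return {"task_id": task.task_id, "handled_by": self.name}


class SketchRegistry:
    """Maps each task type to the agent that declared support for it."""

    def __init__(self):
        self._by_task_type: dict[str, SketchAgent] = {}

    def register(self, agent: SketchAgent) -> None:
        for task_type in agent.supported_tasks:
            self._by_task_type[task_type] = agent

    def resolve(self, task_type: str) -> SketchAgent:
        try:
            return self._by_task_type[task_type]
        except KeyError:
            raise ValueError(f"No agent registered for task type '{task_type}'") from None


async def dispatch(registry: SketchRegistry, task: SketchTask) -> dict:
    """Look up the agent for the task type and run the task on it."""
    agent = registry.resolve(task.task_type)
    return await agent.execute(task)


def build_registry() -> SketchRegistry:
    # Agent names and task types are taken from the Agent Registry table.
    registry = SketchRegistry()
    registry.register(SketchAgent("monitor", ["monitor_track", "monitor_check_single"]))
    registry.register(SketchAgent("trend_agent", ["trend_insight", "trend_hotspot"]))
    return registry
```

Routing by task type keeps callers unaware of which agent handles a task; an unknown task type fails fast with a `ValueError` instead of being silently dropped.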
## Agent Framework Components
- BaseAgent: Abstract base class (backend/app/agent_framework/base.py)
- Dispatcher: Task distribution (backend/app/agent_framework/dispatcher.py)
- Registry: Agent registration (backend/app/agent_framework/registry.py)
- Protocol: Message types and AgentType enum (backend/app/agent_framework/protocol.py)
- ConfigManager: Agent configuration (backend/app/agent_framework/config_manager.py)
## New Agent Data Models
- MonitoringRecord + ContentBaseline (backend/app/models/monitoring.py)
- SchemaSuggestion (backend/app/models/schema_suggestion.py)
- CompetitorInsight (backend/app/models/competitor_insight.py)
- TrendInsight (backend/app/models/trend_insight.py)
## New Agent API Endpoints
- /api/v1/monitoring - MonitorAgent API
- /api/v1/competitor - CompetitorAnalyzer API
- /api/v1/schema - SchemaAdvisor API
- /api/v1/trends - TrendAgent API

CLAUDE.md Normal file

@ -0,0 +1,58 @@
# GEO Platform - AI Context
## Project Overview
GEO (Generative Engine Optimization) SaaS platform that helps brands get cited by AI search engines (ChatGPT, Perplexity, DeepSeek, etc.).
## Tech Stack
- Backend: FastAPI + SQLAlchemy 2.0 (async) + PostgreSQL 15 + Redis 7
- Frontend: Next.js 14 + React 18 + TypeScript + Tailwind CSS + shadcn/ui
- AI: Multi-LLM provider (DeepSeek, OpenAI, etc.) + RAG knowledge base
- Infrastructure: Docker Compose + Prometheus metrics
## Architecture
- Backend API: 33 route modules under backend/app/api/
- Agent Framework: 8 agents (CitationDetector, ContentGenerator, DeAIAgent, GEOOptimizer, MonitorAgent, SchemaAdvisor, CompetitorAnalyzer, TrendAgent)
- Data Models: 35 models under backend/app/models/
- Services: 80+ service files organized by domain under backend/app/services/
- Repositories: 12 repository files under backend/app/repositories/ for data access
- Frontend: 25+ pages under frontend/app/(dashboard)/
- Frontend API Clients: 27 modules under frontend/lib/api/
## Key Patterns
- Repository pattern for data access (backend/app/repositories/)
- Agent Framework: BaseAgent abstract class → concrete agents, Dispatcher + Registry + Redis Queue
- Shared BrandScoringDataService for V2 5-dimension scoring
- Content Pipeline: Generate → DeAI → GEO Optimize → HTML Output
- GEO Workflow: Diagnosis → Strategy → Plan → Content Generation → Monitoring
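
The Content Pipeline order (Generate → DeAI → GEO Optimize → HTML Output) can be pictured as a chain of async stages, each consuming the previous stage's output. The sketch below is an assumption-laden illustration: every stage body is a placeholder, and `run_pipeline`, `PIPELINE`, and the stage function names are hypothetical rather than taken from the codebase.

```python
import asyncio
import html
from typing import Awaitable, Callable

Stage = Callable[[str], Awaitable[str]]


async def generate(topic: str) -> str:
    # Placeholder for the LLM generation step.
    return f"Draft about {topic}"


async def de_ai(text: str) -> str:
    # Placeholder for the DeAI rewrite step.
    return text.replace("Draft", "Article")


async def geo_optimize(text: str) -> str:
    # Placeholder for the GEO optimization step.
    return text + " (with cited sources)"


async def to_html(text: str) -> str:
    # Escape the text so it is safe to embed in markup.
    return f"<article><p>{html.escape(text)}</p></article>"


PIPELINE: list[Stage] = [generate, de_ai, geo_optimize, to_html]


async def run_pipeline(topic: str, stages: list[Stage] = PIPELINE) -> str:
    """Feed each stage's output into the next, in order."""
    value = topic
    for stage in stages:
        value = await stage(value)
    return value


if __name__ == "__main__":
    print(asyncio.run(run_pipeline("GEO & SEO")))
```

Keeping the stages in a list makes the order explicit and lets a test run a shortened pipeline by passing its own `stages`.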
## API Conventions
- All routes under /api/v1/ prefix
- Authentication via JWT Bearer token (get_current_user dependency)
- Pydantic schemas for request/response validation
- Async SQLAlchemy ORM with AsyncSession
## Frontend Conventions
- API clients in frontend/lib/api/ using fetchWithAuth from client.ts
- Barrel export in frontend/lib/api/index.ts
- Types co-located with API modules
- Dashboard pages under frontend/app/(dashboard)/dashboard/
## Environment
- Backend runs on port 8000
- Frontend runs on port 3001 (dev)
- PostgreSQL on port 5432
- Redis on port 6379
- .env.example has all required variables
## Common Commands
- Start backend: cd geo/backend && python3 -m uvicorn app.main:app --reload --port 8000
- Start frontend: cd geo/frontend && npm run dev
- Start all agents: cd geo/backend && python3 -m app.agent_framework.standalone all
- Docker: docker-compose up -d
## Important Notes
- SEOOptimizer in code actually performs GEO optimization, not traditional SEO
- content.py (content generation) vs contents.py (content management) - different modules
- Analytics endpoints all require authentication via _get_org_id or get_current_user
- CORS: dev mode uses allow_origins=["*"]; production must restrict allow_origins to an explicit list of trusted origins
- Brand scoring uses V2 5-dimension system: mention_rate, recommendation_rank, sentiment_score, citation_quality, competitive_position
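
How the five dimensions might combine into one number is not stated above. As a minimal sketch, assuming each dimension is already normalized to 0-100 and assuming equal weights (both are assumptions, as are the names `WEIGHTS` and `overall_score`), a weighted average would look like this:

```python
DIMENSIONS = (
    "mention_rate",
    "recommendation_rank",
    "sentiment_score",
    "citation_quality",
    "competitive_position",
)

# Assumption: equal weights. The real weighting is not given in this document.
WEIGHTS = {name: 1 / len(DIMENSIONS) for name in DIMENSIONS}


def overall_score(scores: dict[str, float]) -> float:
    """Weighted average of the five dimension scores, rounded to 2 decimals."""
    missing = [name for name in DIMENSIONS if name not in scores]
    if missing:
        raise ValueError(f"Missing dimensions: {missing}")
    total = sum(WEIGHTS[name] * scores[name] for name in DIMENSIONS)
    return round(total, 2)
```

Rejecting incomplete input keeps a missing dimension from being silently scored as zero.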


@ -4,10 +4,12 @@ from .citation_detector import CitationDetectorAgent
from .content_generator_agent import ContentGeneratorAgent
from .deai_agent import DeAIAgent
from .geo_optimizer_agent import GEOOptimizerAgent
from .trend_agent import TrendAgent
__all__ = [
"CitationDetectorAgent",
"ContentGeneratorAgent",
"DeAIAgent",
"GEOOptimizerAgent",
"TrendAgent",
]


@ -1,9 +1,9 @@
"""CitationDetector Agent - wraps the existing CitationEngine as an Agent"""
import logging
import re
import time
from datetime import datetime, timezone
import uuid
from datetime import datetime, timedelta, timezone
from sqlalchemy import select
from app.agent_framework.base import BaseAgent
from app.agent_framework.protocol import (
@ -16,19 +16,13 @@ from app.agent_framework.protocol import (
from app.database import AsyncSessionLocal
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.workers.citation_engine import CitationEngine
from app.models.query_task import QueryTask
from app.services.ai_engine.platform_bridge import execute_single_platform
logger = logging.getLogger(__name__)
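class CitationDetectorAgent(BaseAgent):
"""
Citation detection Agent that wraps the existing CitationEngine as a BaseAgent implementation.
Supported task types:
- citation_detect: run full citation detection across all platforms of a query
- citation_detect_single: run citation detection on a single platform
"""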
def __init__(self):
super().__init__(
@ -36,7 +30,6 @@ class CitationDetectorAgent(BaseAgent):
agent_type=AgentType.CITATION_DETECTOR,
version="1.0.0",
)
self._engine = CitationEngine()
def get_capabilities(self) -> AgentCapability:
return AgentCapability(
@ -49,7 +42,6 @@ class CitationDetectorAgent(BaseAgent):
)
async def execute(self, task: TaskMessage) -> TaskResult:
"""Run a citation detection task."""
started_at = datetime.now(timezone.utc)
start_time = time.monotonic()
@ -94,17 +86,11 @@ class CitationDetectorAgent(BaseAgent):
)
async def _execute_full_detect(self, task: TaskMessage) -> dict:
"""
Run full citation detection across every platform of the query.
input_data must contain: query_id (str)
"""
query_id = task.input_data.get("query_id")
if not query_id:
raise ValueError("input_data must contain 'query_id'")
async with AsyncSessionLocal() as db:
from sqlalchemy import select
stmt = select(Query).where(Query.id == query_id)
result = await db.execute(stmt)
query = result.scalar_one_or_none()
@ -112,23 +98,75 @@ class CitationDetectorAgent(BaseAgent):
if not query:
raise ValueError(f"Query {query_id} not found")
# Report progress: detection started
await self.report_progress(
task_id=task.task_id,
progress=0.1,
message=f"Starting citation detection for query '{query.keyword}'",
)
records = await self._engine.execute_query(query, db)
records: list[CitationRecord] = []
platforms = query.platforms or ["wenxin", "kimi"]
brand_aliases = query.brand_aliases or []
for i, platform_name in enumerate(platforms):
progress = 0.1 + 0.8 * (i / len(platforms))
await self.report_progress(
task_id=task.task_id,
progress=progress,
message=f"Detecting on platform '{platform_name}' ({i+1}/{len(platforms)})",
)
task_obj = await self._get_or_create_task(db, query.id, platform_name)
task_obj.status = "running"
task_obj.started_at = datetime.utcnow()
task_obj.error_message = None
await db.commit()
try:
detect_result = await execute_single_platform(
keyword=query.keyword,
platform=platform_name,
target_brand=query.target_brand,
brand_aliases=brand_aliases,
)
record = CitationRecord.from_citation_result(
query_id=query.id,
platform=platform_name,
result=detect_result,
)
db.add(record)
records.append(record)
task_obj.status = "success"
task_obj.completed_at = datetime.utcnow()
await db.commit()
except Exception as e:
logger.error(f"平台 {platform_name} 查询失败: {e}")
task_obj.status = "failed"
task_obj.error_message = str(e)
task_obj.completed_at = datetime.utcnow()
record = CitationRecord.from_citation_result(
query_id=query.id,
platform=platform_name,
result={"cited": False, "raw_response": str(e)},
)
db.add(record)
records.append(record)
await db.commit()
query.last_queried_at = datetime.utcnow()
query.next_query_at = self._calculate_next_query_at(query.frequency)
await db.commit()
# Report progress: detection finished
await self.report_progress(
task_id=task.task_id,
progress=1.0,
message=f"Detection completed: {len(records)} records found",
)
# Build the output
record_summaries = []
for r in records:
record_summaries.append({
@ -148,10 +186,6 @@ class CitationDetectorAgent(BaseAgent):
}
async def _execute_single_detect(self, task: TaskMessage) -> dict:
"""
Run citation detection on a single platform.
input_data must contain: keyword, platform, target_brand, brand_aliases (optional)
"""
keyword = task.input_data.get("keyword")
platform = task.input_data.get("platform")
target_brand = task.input_data.get("target_brand")
@ -162,56 +196,118 @@ class CitationDetectorAgent(BaseAgent):
"input_data must contain 'keyword', 'platform', 'target_brand'"
)
# Report progress
await self.report_progress(
task_id=task.task_id,
progress=0.2,
message=f"Querying platform '{platform}' with keyword '{keyword}'",
)
result = await self._engine.execute_single_platform(
result = await execute_single_platform(
keyword=keyword,
platform=platform,
target_brand=target_brand,
brand_aliases=brand_aliases,
)
# Report progress: done
await self.report_progress(
task_id=task.task_id,
progress=1.0,
message=f"Single platform detection completed on '{platform}'",
)
# Strip raw_response to avoid returning an oversized payload
output = {k: v for k, v in result.items() if k != "raw_response"}
return output
# -----------------------------------------------------------------------
# Backward compatibility: keep the original CitationEngine synchronous call interface
# -----------------------------------------------------------------------
async def execute_query_compat(self, query: Query, db) -> list[CitationRecord]:
"""
Backward-compatible method so the existing scheduler can keep calling it.
Signature matches CitationEngine.execute_query exactly.
"""
return await self._engine.execute_query(query, db)
records: list[CitationRecord] = []
platforms = query.platforms or ["wenxin", "kimi"]
brand_aliases = query.brand_aliases or []
for platform_name in platforms:
task_obj = await self._get_or_create_task(db, query.id, platform_name)
task_obj.status = "running"
task_obj.started_at = datetime.utcnow()
task_obj.error_message = None
await db.commit()
try:
detect_result = await execute_single_platform(
keyword=query.keyword,
platform=platform_name,
target_brand=query.target_brand,
brand_aliases=brand_aliases,
)
record = CitationRecord.from_citation_result(
query_id=query.id,
platform=platform_name,
result=detect_result,
)
db.add(record)
records.append(record)
task_obj.status = "success"
task_obj.completed_at = datetime.utcnow()
await db.commit()
except Exception as e:
logger.error(f"平台 {platform_name} 查询失败: {e}")
task_obj.status = "failed"
task_obj.error_message = str(e)
task_obj.completed_at = datetime.utcnow()
record = CitationRecord.from_citation_result(
query_id=query.id,
platform=platform_name,
result={"cited": False, "raw_response": str(e)},
)
db.add(record)
records.append(record)
await db.commit()
query.last_queried_at = datetime.utcnow()
query.next_query_at = self._calculate_next_query_at(query.frequency)
await db.commit()
return records
async def execute_single_platform_compat(
self, keyword: str, platform: str, target_brand: str, brand_aliases: list
) -> dict:
"""
Backward-compatible method so the existing scheduler can keep calling it.
Signature matches CitationEngine.execute_single_platform exactly.
"""
return await self._engine.execute_single_platform(
return await execute_single_platform(
keyword=keyword,
platform=platform,
target_brand=target_brand,
brand_aliases=brand_aliases,
)
async def close(self):
"""Close the underlying engine."""
await self._engine.close()
async def _get_or_create_task(self, db, query_id: uuid.UUID, platform: str) -> QueryTask:
stmt = select(QueryTask).where(
QueryTask.query_id == query_id,
QueryTask.platform == platform,
)
result = await db.execute(stmt)
task_obj = result.scalar_one_or_none()
if not task_obj:
task_obj = QueryTask(
query_id=query_id,
platform=platform,
status="pending",
)
db.add(task_obj)
await db.commit()
await db.refresh(task_obj)
return task_obj
@staticmethod
def _calculate_next_query_at(frequency: str | None) -> datetime:
now = datetime.utcnow()
freq_map = {
"daily": timedelta(days=1),
"weekly": timedelta(days=7),
"monthly": timedelta(days=30),
}
delta = freq_map.get(frequency or "weekly", timedelta(days=7))
return now + delta
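
The scheduling fallback in `_calculate_next_query_at` can be checked in isolation. The sketch below mirrors the frequency map from the method above but takes `now` as a parameter so the result is deterministic; that parameter and the function name `next_query_at` are assumptions made for illustration, since the method above reads the clock itself.

```python
from datetime import datetime, timedelta, timezone
from typing import Optional

FREQ_MAP = {
    "daily": timedelta(days=1),
    "weekly": timedelta(days=7),
    "monthly": timedelta(days=30),  # a fixed 30 days, not a calendar month
}


def next_query_at(frequency: Optional[str], now: datetime) -> datetime:
    # Missing or unrecognized frequencies fall back to weekly.
    delta = FREQ_MAP.get(frequency or "weekly", timedelta(days=7))
    return now + delta


if __name__ == "__main__":
    start = datetime(2026, 6, 1, tzinfo=timezone.utc)
    for freq in ("daily", "monthly", None, "hourly"):
        print(freq, next_query_at(freq, start).date())
```

Note that an unrecognized value such as "hourly" is scheduled weekly rather than rejected, which matches the lookup-with-default behavior of the original method.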


@ -0,0 +1,163 @@
import logging
import time
from datetime import datetime, timezone
from app.agent_framework.base import BaseAgent
from app.agent_framework.protocol import (
AgentCapability,
AgentType,
TaskMessage,
TaskResult,
TaskStatus,
)
from app.services.llm import LLMFactory, LLMError
logger = logging.getLogger(__name__)
class CompetitorAnalyzerAgent(BaseAgent):
def __init__(self):
super().__init__(
name="competitor_analyzer",
agent_type=AgentType.COMPETITOR_ANALYZER,
version="1.0.0",
)
def get_capabilities(self) -> AgentCapability:
return AgentCapability(
agent_name=self.name,
agent_type=self.agent_type,
version=self.version,
supported_tasks=["competitor_analyze", "competitor_gap_analysis"],
max_concurrency=2,
description="竞品策略分析Agent对比品牌与竞品的引用数据识别差距领域发现机会点生成策略建议",
)
async def execute(self, task: TaskMessage) -> TaskResult:
started_at = datetime.now(timezone.utc)
start_time = time.monotonic()
try:
if task.task_type == "competitor_gap_analysis":
output = await self._gap_analysis(task)
else:
output = await self._analyze(task)
elapsed = time.monotonic() - start_time
return TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.COMPLETED,
output_data=output,
error_message=None,
started_at=started_at,
completed_at=datetime.now(timezone.utc),
metrics={
"elapsed_seconds": round(elapsed, 2),
"task_type": task.task_type,
},
)
except LLMError as e:
elapsed = time.monotonic() - start_time
logger.error(f"CompetitorAnalyzer LLM error on task {task.task_id}: {e}")
return TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.FAILED,
output_data=None,
error_message=f"LLM调用失败: {e}",
started_at=started_at,
completed_at=datetime.now(timezone.utc),
metrics={
"elapsed_seconds": round(elapsed, 2),
"task_type": task.task_type,
},
)
except Exception as e:
elapsed = time.monotonic() - start_time
logger.error(f"CompetitorAnalyzer task {task.task_id} failed: {e}")
return TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.FAILED,
output_data=None,
error_message=str(e),
started_at=started_at,
completed_at=datetime.now(timezone.utc),
metrics={
"elapsed_seconds": round(elapsed, 2),
"task_type": task.task_type,
},
)
async def _analyze(self, task: TaskMessage) -> dict:
from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService
input_data = task.input_data
brand_id = input_data.get("brand_id")
analysis_types = input_data.get("analysis_types")
period_days = input_data.get("period_days", 30)
if not brand_id:
raise ValueError("input_data必须包含'brand_id'字段")
await self.report_progress(
task_id=task.task_id,
progress=0.1,
message="开始竞品策略分析...",
)
service = CompetitorAnalyzerService()
result = await service.analyze_competitor(
brand_id=brand_id,
analysis_types=analysis_types,
period_days=period_days,
progress_callback=lambda p, m: self.report_progress(
task_id=task.task_id, progress=p, message=m,
),
)
await self.report_progress(
task_id=task.task_id,
progress=1.0,
message="竞品策略分析完成",
)
return result
async def _gap_analysis(self, task: TaskMessage) -> dict:
from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService
input_data = task.input_data
brand_id = input_data.get("brand_id")
period_days = input_data.get("period_days", 30)
if not brand_id:
raise ValueError("input_data必须包含'brand_id'字段")
await self.report_progress(
task_id=task.task_id,
progress=0.1,
message="开始竞品差距分析...",
)
service = CompetitorAnalyzerService()
result = await service.analyze_competitor(
brand_id=brand_id,
analysis_types=["citation_gap", "platform_coverage", "query_overlap"],
period_days=period_days,
progress_callback=lambda p, m: self.report_progress(
task_id=task.task_id, progress=p, message=m,
),
)
await self.report_progress(
task_id=task.task_id,
progress=1.0,
message="竞品差距分析完成",
)
return result


@ -2,7 +2,6 @@
import json
import logging
import re
import time
from datetime import datetime, timezone
@ -16,6 +15,7 @@ from app.agent_framework.protocol import (
TaskStatus,
)
from app.services.llm import LLMFactory, LLMError
from app.utils.json_extractor import extract_json
logger = logging.getLogger(__name__)
@ -165,7 +165,7 @@ class ContentGeneratorAgent(BaseAgent):
# 4. Parse the JSON output
try:
topics = json.loads(self._extract_json(response.content))
topics = json.loads(extract_json(response.content))
except json.JSONDecodeError:
topics = [{"title": response.content, "reason": "LLM输出解析失败返回原始内容"}]
@ -277,22 +277,3 @@ class ContentGeneratorAgent(BaseAgent):
logger.warning(f"RAG检索失败跳过知识库上下文: {e}")
return "暂无相关知识库内容"
def _extract_json(self, text: str) -> str:
"""Extract JSON from an LLM response, which may be wrapped in markdown."""
# Try to extract the content of a fenced json block
match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL)
if match:
return match.group(1).strip()
# Otherwise look for JSON that starts with [ or {
for i, c in enumerate(text):
if c in '[{':
depth = 0
for j in range(i, len(text)):
if text[j] in '[{':
depth += 1
elif text[j] in ']}':
depth -= 1
if depth == 0:
return text[i : j + 1]
return text


@ -2,7 +2,6 @@
import json
import logging
import re
import time
from datetime import datetime, timezone
@ -16,6 +15,7 @@ from app.agent_framework.protocol import (
TaskStatus,
)
from app.services.llm import LLMFactory, LLMError
from app.utils.json_extractor import extract_json
logger = logging.getLogger(__name__)
@ -155,7 +155,7 @@ class GEOOptimizerAgent(BaseAgent):
# Try to parse the JSON output
try:
result = json.loads(self._extract_json(response.content))
result = json.loads(extract_json(response.content))
result["usage"] = response.usage
# Report progress: done
await self.report_progress(
@ -178,20 +178,3 @@ class GEOOptimizerAgent(BaseAgent):
"changes": ["LLM输出非标准格式已返回原始优化结果"],
"usage": response.usage,
}
def _extract_json(self, text: str) -> str:
"""Extract JSON from the response."""
match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL)
if match:
return match.group(1).strip()
for i, c in enumerate(text):
if c == '{':
depth = 0
for j in range(i, len(text)):
if text[j] == '{':
depth += 1
elif text[j] == '}':
depth -= 1
if depth == 0:
return text[i : j + 1]
return text
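
The two removed `_extract_json` methods were near duplicates, now replaced by `app.utils.json_extractor.extract_json`. That module's source is not part of this excerpt, so the following is only a sketch of what a shared helper could look like, not the actual implementation. It keeps the same two-step strategy (fenced block first, then bracket matching) and, as an addition that the removed methods did not have, skips brackets that appear inside JSON string literals. `_balanced_end` is a hypothetical helper name.

```python
import json
import re

# Build the fence marker programmatically to avoid a literal fence in this sample.
_TICKS = "`" * 3
_FENCE_RE = re.compile(_TICKS + r"(?:json)?\s*\n?(.*?)\n?" + _TICKS, re.DOTALL)
_OPENERS = "[{"
_CLOSERS = "]}"


def _balanced_end(text: str, start: int):
    """Return the index where the bracket opened at start closes, or None.

    Bracket kinds are not distinguished, matching the removed methods.
    """
    depth = 0
    in_string = False
    escaped = False
    for j in range(start, len(text)):
        c = text[j]
        if in_string:
            # Inside a string literal, only an unescaped quote ends it.
            if escaped:
                escaped = False
            elif c == "\\":
                escaped = True
            elif c == '"':
                in_string = False
        elif c == '"':
            in_string = True
        elif c in _OPENERS:
            depth += 1
        elif c in _CLOSERS:
            depth -= 1
            if depth == 0:
                return j
    return None


def extract_json(text: str) -> str:
    """Return the first JSON-looking span in text, or text unchanged."""
    fenced = _FENCE_RE.search(text)
    if fenced:
        return fenced.group(1).strip()
    for start, char in enumerate(text):
        if char in _OPENERS:
            end = _balanced_end(text, start)
            if end is not None:
                return text[start : end + 1]
    return text
```

Callers would still wrap the result in `json.loads` and handle `json.JSONDecodeError`, as both agents do, because the helper only locates a candidate span and does not validate it.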


@ -0,0 +1,231 @@
import logging
import time
import uuid
from datetime import datetime, timezone
from sqlalchemy import select
from app.agent_framework.base import BaseAgent
from app.agent_framework.protocol import (
AgentCapability,
AgentType,
TaskMessage,
TaskResult,
TaskStatus,
)
from app.database import AsyncSessionLocal
from app.models.monitoring import MonitoringRecord
from app.models.query import Query
from app.models.brand import Brand
from app.services.monitoring.monitor_service import MonitorService
logger = logging.getLogger(__name__)
class MonitorAgent(BaseAgent):
def __init__(self):
super().__init__(
name="monitor",
agent_type=AgentType.PERFORMANCE_TRACKER,
version="1.0.0",
)
def get_capabilities(self) -> AgentCapability:
return AgentCapability(
agent_name=self.name,
agent_type=self.agent_type,
version=self.version,
supported_tasks=["monitor_track", "monitor_check_single"],
max_concurrency=3,
description="效果追踪Agent监测品牌引用量、情感、排名变化生成变化报告",
)
async def execute(self, task: TaskMessage) -> TaskResult:
started_at = datetime.now(timezone.utc)
start_time = time.monotonic()
try:
task_type = task.task_type
if task_type == "monitor_track":
output = await self._monitor_track(task)
elif task_type == "monitor_check_single":
output = await self._monitor_check_single(task)
else:
raise ValueError(f"不支持的任务类型: {task_type}")
elapsed = time.monotonic() - start_time
return TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.COMPLETED,
output_data=output,
error_message=None,
started_at=started_at,
completed_at=datetime.now(timezone.utc),
metrics={
"elapsed_seconds": round(elapsed, 2),
"task_type": task_type,
},
)
except Exception as e:
elapsed = time.monotonic() - start_time
logger.error(f"MonitorAgent task {task.task_id} failed: {e}")
return TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.FAILED,
output_data=None,
error_message=str(e),
started_at=started_at,
completed_at=datetime.now(timezone.utc),
metrics={
"elapsed_seconds": round(elapsed, 2),
"task_type": task.task_type,
},
)
async def _monitor_track(self, task: TaskMessage) -> dict:
input_data = task.input_data
brand_id = input_data.get("brand_id")
if not brand_id:
raise ValueError("input_data必须包含'brand_id'字段")
brand_id = uuid.UUID(brand_id)
await self.report_progress(
task_id=task.task_id,
progress=0.1,
message="开始效果追踪...",
)
async with AsyncSessionLocal() as db:
stmt = select(Brand).where(Brand.id == brand_id)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
if not brand:
raise ValueError(f"品牌不存在: {brand_id}")
queries_stmt = select(Query).where(Query.target_brand == brand.name)
queries_result = await db.execute(queries_stmt)
queries = list(queries_result.scalars().all())
if not queries:
await self.report_progress(
task_id=task.task_id,
progress=1.0,
message="品牌下无关联查询,效果追踪完成",
)
return {
"brand_id": str(brand_id),
"brand_name": brand.name,
"total_queries": 0,
"reports": [],
}
total_queries = len(queries)
reports = []
service = MonitorService()
monitoring_stmt = select(MonitoringRecord).where(
MonitoringRecord.brand_id == brand_id,
MonitoringRecord.status == "active",
)
monitoring_result = await db.execute(monitoring_stmt)
monitoring_records = list(monitoring_result.scalars().all())
for idx, record in enumerate(monitoring_records):
progress = 0.2 + (0.7 * idx / max(len(monitoring_records), 1))
await self.report_progress(
task_id=task.task_id,
progress=progress,
message=f"正在检测第 {idx + 1}/{len(monitoring_records)} 条监测记录...",
)
updated_record = await service.check_and_compare(db, record.id)
if updated_record:
report = await service.generate_change_report(db, updated_record.id)
if report:
reports.append(report)
await self.report_progress(
task_id=task.task_id,
progress=1.0,
message="效果追踪完成",
)
return {
"brand_id": str(brand_id),
"brand_name": brand.name,
"total_queries": total_queries,
"checked_records": len(monitoring_records),
"reports": reports,
}
async def _monitor_check_single(self, task: TaskMessage) -> dict:
input_data = task.input_data
brand_id = input_data.get("brand_id")
keyword = input_data.get("keyword")
platform = input_data.get("platform")
if not brand_id:
raise ValueError("input_data必须包含'brand_id'字段")
brand_id = uuid.UUID(brand_id)
await self.report_progress(
task_id=task.task_id,
progress=0.1,
message="开始单关键词检测...",
)
async with AsyncSessionLocal() as db:
service = MonitorService()
await self.report_progress(
task_id=task.task_id,
progress=0.3,
message="正在创建监测记录...",
)
record = await service.create_monitoring_record(
db=db,
brand_id=brand_id,
query_keywords=keyword,
platform=platform,
check_interval_hours=input_data.get("check_interval_hours", 24),
)
await self.report_progress(
task_id=task.task_id,
progress=0.6,
message="正在执行检测对比...",
)
updated_record = await service.check_and_compare(db, record.id)
await self.report_progress(
task_id=task.task_id,
progress=0.8,
message="正在生成变化报告...",
)
report = None
if updated_record:
report = await service.generate_change_report(db, updated_record.id)
await self.report_progress(
task_id=task.task_id,
progress=1.0,
message="单关键词检测完成",
)
return {
"record_id": str(record.id),
"brand_id": str(brand_id),
"keyword": keyword,
"platform": platform,
"change_type": updated_record.change_type if updated_record else None,
"report": report,
}


@ -0,0 +1,398 @@
import copy
import json
import logging
import time
from datetime import datetime, timezone
from app.agent_framework.base import BaseAgent
from app.agent_framework.prompts.schema_advisor import SCHEMA_ADVISOR_TEMPLATE
from app.agent_framework.protocol import (
AgentCapability,
AgentType,
TaskMessage,
TaskResult,
TaskStatus,
)
from app.services.llm import LLMFactory, LLMError
from app.utils.json_extractor import extract_json
logger = logging.getLogger(__name__)
SCHEMA_TEMPLATES = {
"Organization": {
"@context": "https://schema.org",
"@type": "Organization",
"name": "",
"description": "",
"url": "",
"logo": "",
"sameAs": [],
"contactPoint": {
"@type": "ContactPoint",
"contactType": "customer service",
"telephone": "",
},
},
"Product": {
"@context": "https://schema.org",
"@type": "Product",
"name": "",
"description": "",
"brand": {"@type": "Brand", "name": ""},
"offers": {
"@type": "Offer",
"priceCurrency": "CNY",
"availability": "https://schema.org/InStock",
},
},
"FAQPage": {
"@context": "https://schema.org",
"@type": "FAQPage",
"mainEntity": [
{
"@type": "Question",
"name": "",
"acceptedAnswer": {
"@type": "Answer",
"text": "",
},
}
],
},
"Article": {
"@context": "https://schema.org",
"@type": "Article",
"headline": "",
"description": "",
"author": {"@type": "Organization", "name": ""},
"datePublished": "",
"image": "",
},
"LocalBusiness": {
"@context": "https://schema.org",
"@type": "LocalBusiness",
"name": "",
"address": {
"@type": "PostalAddress",
"streetAddress": "",
"addressLocality": "",
"addressRegion": "",
"postalCode": "",
"addressCountry": "CN",
},
"geo": {
"@type": "GeoCoordinates",
"latitude": "",
"longitude": "",
},
"telephone": "",
"openingHours": "",
},
}
DIMENSION_SCHEMA_MAP = {
"schema_marketing": ["Organization", "LocalBusiness"],
"entity_clarity": ["Organization", "Product"],
"citation_readiness": ["FAQPage", "Article"],
"brand_visibility": ["Organization", "Product"],
"local_seo": ["LocalBusiness"],
}
PRIORITY_THRESHOLD = {
"high": 30.0,
"medium": 60.0,
}
DIFFICULTY_MAP = {
"Organization": "easy",
"Product": "medium",
"FAQPage": "medium",
"Article": "easy",
"LocalBusiness": "hard",
}
class SchemaAdvisorAgent(BaseAgent):
def __init__(self):
super().__init__(
name="schema_advisor",
agent_type=AgentType.SCHEMA_ADVISOR,
version="1.0.0",
)
def get_capabilities(self) -> AgentCapability:
return AgentCapability(
agent_name=self.name,
agent_type=self.agent_type,
version=self.version,
supported_tasks=["schema_advise"],
max_concurrency=2,
description="Schema优化建议Agent识别Schema缺失维度生成JSON-LD结构化数据建议",
)
async def execute(self, task: TaskMessage) -> TaskResult:
started_at = datetime.now(timezone.utc)
start_time = time.monotonic()
try:
output = await self._advise(task)
elapsed = time.monotonic() - start_time
return TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.COMPLETED,
output_data=output,
error_message=None,
started_at=started_at,
completed_at=datetime.now(timezone.utc),
metrics={
"elapsed_seconds": round(elapsed, 2),
"task_type": task.task_type,
},
)
except LLMError as e:
elapsed = time.monotonic() - start_time
logger.error(f"SchemaAdvisor LLM error on task {task.task_id}: {e}")
return TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.FAILED,
output_data=None,
error_message=f"LLM调用失败: {e}",
started_at=started_at,
completed_at=datetime.now(timezone.utc),
metrics={
"elapsed_seconds": round(elapsed, 2),
"task_type": task.task_type,
},
)
except Exception as e:
elapsed = time.monotonic() - start_time
logger.error(f"SchemaAdvisor task {task.task_id} failed: {e}")
return TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.FAILED,
output_data=None,
error_message=str(e),
started_at=started_at,
completed_at=datetime.now(timezone.utc),
metrics={
"elapsed_seconds": round(elapsed, 2),
"task_type": task.task_type,
},
)
async def _advise(self, task: TaskMessage) -> dict:
input_data = task.input_data
brand_id = input_data.get("brand_id")
diagnosis_data = input_data.get("diagnosis_data", {})
brand_info = input_data.get("brand_info", {})
focus_dimensions = input_data.get("focus_dimensions")
if not brand_id:
raise ValueError("input_data must contain the 'brand_id' field")
await self.report_progress(
task_id=task.task_id,
progress=0.1,
message="Starting Schema suggestion analysis...",
)
missing_dimensions = self._identify_missing_dimensions(diagnosis_data, focus_dimensions)
await self.report_progress(
task_id=task.task_id,
progress=0.3,
message=f"Identified {len(missing_dimensions)} dimensions with missing Schema markup...",
)
matched = self._match_templates(missing_dimensions)
await self.report_progress(
task_id=task.task_id,
progress=0.5,
message="Predefined templates matched; starting LLM fill...",
)
filled = await self._fill_with_llm(matched, brand_info)
await self.report_progress(
task_id=task.task_id,
progress=0.8,
message="LLM fill complete; validating JSON-LD format...",
)
validated = self._validate_and_sort(filled)
await self.report_progress(
task_id=task.task_id,
progress=1.0,
message="Schema suggestions generated",
)
return {
"brand_id": brand_id,
"suggestions": validated,
"total": len(validated),
}
def _identify_missing_dimensions(
self,
diagnosis_data: dict,
focus_dimensions: list[str] | None = None,
) -> list[dict]:
dimensions = []
dimension_scores = diagnosis_data.get("dimensions", {})
for dim_name, dim_info in dimension_scores.items():
if dim_name not in DIMENSION_SCHEMA_MAP:
continue
if focus_dimensions and dim_name not in focus_dimensions:
continue
score = dim_info.get("score", 0) if isinstance(dim_info, dict) else dim_info
max_score = dim_info.get("max_score", 100) if isinstance(dim_info, dict) else 100
percentage = (score / max_score * 100) if max_score > 0 else 0
if percentage < 80:
dimensions.append({
"dimension": dim_name,
"current_score": round(score, 2),
"max_score": max_score,
"percentage": round(percentage, 2),
})
if not dimensions and diagnosis_data:
overall = diagnosis_data.get("overall_score", 0)
if overall < 80:
for dim_name in DIMENSION_SCHEMA_MAP:
if focus_dimensions and dim_name not in focus_dimensions:
continue
dimensions.append({
"dimension": dim_name,
"current_score": 0,
"max_score": 100,
"percentage": 0,
})
return dimensions
def _match_templates(self, missing_dimensions: list[dict]) -> list[dict]:
matched = []
seen_types = set()
for dim in missing_dimensions:
schema_types = DIMENSION_SCHEMA_MAP.get(dim["dimension"], [])
for schema_type in schema_types:
if schema_type in seen_types:
continue
seen_types.add(schema_type)
template = SCHEMA_TEMPLATES.get(schema_type)
if template:
percentage = dim["percentage"]
if percentage < PRIORITY_THRESHOLD["high"]:
priority = "high"
elif percentage < PRIORITY_THRESHOLD["medium"]:
priority = "medium"
else:
priority = "low"
matched.append({
"schema_type": schema_type,
"priority": priority,
"diagnosis_dimensions": {
"dimension": dim["dimension"],
"current_score": dim["current_score"],
"max_score": dim["max_score"],
"percentage": dim["percentage"],
},
"json_ld_template": copy.deepcopy(template),
"implementation_difficulty": DIFFICULTY_MAP.get(schema_type, "medium"),
})
return matched
async def _fill_with_llm(self, matched: list[dict], brand_info: dict) -> list[dict]:
provider = LLMFactory.get_default()
results = []
for item in matched:
schema_type = item["schema_type"]
try:
variables = {
"brand_name": brand_info.get("name", ""),
"brand_website": brand_info.get("website", ""),
"brand_industry": brand_info.get("industry", ""),
"schema_type": schema_type,
"diagnosis_data": json.dumps(item.get("diagnosis_dimensions", {}), ensure_ascii=False),
"existing_schemas": "",
}
messages = SCHEMA_ADVISOR_TEMPLATE.render(variables)
response = await provider.chat(
messages,
temperature=0.3,
max_tokens=2048,
)
filled = json.loads(extract_json(response.content))
item["json_ld_filled"] = filled
item["estimated_impact"] = self._generate_impact_description(
schema_type, item.get("diagnosis_dimensions", {}).get("dimension", "")
)
except (json.JSONDecodeError, LLMError, ValueError) as e:
logger.warning(f"LLM fill for Schema {schema_type} failed: {e}")
item["json_ld_filled"] = None
item["estimated_impact"] = self._generate_impact_description(
schema_type, item.get("diagnosis_dimensions", {}).get("dimension", "")
)
results.append(item)
return results
def _validate_json_ld(self, json_ld: dict) -> dict:
errors = []
warnings = []
if not json_ld:
return {"is_valid": False, "errors": ["JSON-LD is empty"], "warnings": []}
if "@context" not in json_ld:
errors.append("Missing @context field")
if "@type" not in json_ld:
errors.append("Missing @type field")
if "@context" in json_ld and json_ld["@context"] != "https://schema.org":
warnings.append(f"Non-standard @context value: {json_ld.get('@context')}")
if "@type" in json_ld and json_ld["@type"] not in SCHEMA_TEMPLATES:
warnings.append(f"@type is not a recommended type: {json_ld.get('@type')}")
# json.dumps raises TypeError for unserializable values and ValueError for circular references
try:
json.dumps(json_ld)
except (TypeError, ValueError) as e:
errors.append(f"JSON serialization failed: {e}")
return {
"is_valid": len(errors) == 0,
"errors": errors,
"warnings": warnings,
}
def _validate_and_sort(self, items: list[dict]) -> list[dict]:
validated = []
for item in items:
json_ld_filled = item.get("json_ld_filled")
if json_ld_filled:
validation = self._validate_json_ld(json_ld_filled)
item["validation_errors"] = None if validation["is_valid"] else {"errors": validation["errors"], "warnings": validation["warnings"]}
else:
item["validation_errors"] = {"errors": ["JSON-LD fill failed"], "warnings": []}
validated.append(item)
priority_order = {"high": 0, "medium": 1, "low": 2}
validated.sort(key=lambda x: priority_order.get(x.get("priority", "medium"), 1))
return validated
def _generate_impact_description(self, schema_type: str, dimension: str) -> str:
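impacts = {
"Organization": "Strengthens brand entity recognition and raises the likelihood that AI search engines understand and cite the brand",
"Product": "Improves rich-snippet display for products in search results, increasing click-through and citation rates",
"FAQPage": "Increases opportunities for FAQ rich snippets and raises the likelihood of direct citation in AI answers",
"Article": "Improves the structured representation of article content so AI search engines can better understand and cite it",
"LocalBusiness": "Strengthens local search visibility and raises the citation rate for location-related queries",
}
return impacts.get(schema_type, f"Improves the score and AI citation rate of the {dimension} dimension")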
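
The score-to-priority logic in `_identify_missing_dimensions` and `_match_templates` can be sketched in isolation. This is a minimal standalone sketch, not code from the commit: the helper names `score_percentage`, `needs_suggestion`, and `classify_priority` are hypothetical, while the thresholds (30 and 60) and the 80% cut-off are copied from the code above.

```python
# Standalone sketch; helper names are hypothetical, numeric values are copied from the diff.
PRIORITY_THRESHOLD = {"high": 30.0, "medium": 60.0}
MISSING_CUTOFF = 80.0  # dimensions scoring below 80% count as missing


def score_percentage(score: float, max_score: float = 100) -> float:
    """Return score as a percentage of max_score, or 0.0 when max_score is not positive."""
    return (score / max_score * 100) if max_score > 0 else 0.0


def needs_suggestion(percentage: float) -> bool:
    """A dimension qualifies for Schema suggestions when it is below the cut-off."""
    return percentage < MISSING_CUTOFF


def classify_priority(percentage: float) -> str:
    """Map a percentage to 'high', 'medium', or 'low' priority."""
    if percentage < PRIORITY_THRESHOLD["high"]:
        return "high"
    if percentage < PRIORITY_THRESHOLD["medium"]:
        return "medium"
    return "low"
```

Both thresholds are exclusive upper bounds, so a dimension at exactly 30% is classified as medium and one at exactly 60% as low.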

@ -0,0 +1,166 @@
import logging
import time
from datetime import datetime, timezone
from app.agent_framework.base import BaseAgent
from app.agent_framework.protocol import (
AgentCapability,
AgentType,
TaskMessage,
TaskResult,
TaskStatus,
)
logger = logging.getLogger(__name__)
class TrendAgent(BaseAgent):
def __init__(self):
super().__init__(
name="trend_agent",
agent_type=AgentType.TREND_AGENT,
version="1.0.0",
)
def get_capabilities(self) -> AgentCapability:
return AgentCapability(
agent_name=self.name,
agent_type=self.agent_type,
version=self.version,
supported_tasks=["trend_insight", "trend_hotspot"],
max_concurrency=2,
description="Trend insight agent: analyzes brand citation trends, identifies hot topics, infers the causes of changes, and generates recommendations",
)
async def execute(self, task: TaskMessage) -> TaskResult:
started_at = datetime.now(timezone.utc)
start_time = time.monotonic()
try:
if task.task_type == "trend_insight":
output = await self._trend_insight(task)
elif task.task_type == "trend_hotspot":
output = await self._trend_hotspot(task)
else:
raise ValueError(f"Unsupported task type: {task.task_type}")
elapsed = time.monotonic() - start_time
return TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.COMPLETED,
output_data=output,
error_message=None,
started_at=started_at,
completed_at=datetime.now(timezone.utc),
metrics={
"elapsed_seconds": round(elapsed, 2),
"task_type": task.task_type,
},
)
except Exception as e:
elapsed = time.monotonic() - start_time
logger.error(f"TrendAgent task {task.task_id} failed: {e}")
return TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.FAILED,
output_data=None,
error_message=str(e),
started_at=started_at,
completed_at=datetime.now(timezone.utc),
metrics={
"elapsed_seconds": round(elapsed, 2),
"task_type": task.task_type,
},
)
async def _trend_insight(self, task: TaskMessage) -> dict:
input_data = task.input_data
brand_id = input_data.get("brand_id")
days = input_data.get("days", 30)
platforms = input_data.get("platforms")
keywords = input_data.get("keywords")
if not brand_id:
raise ValueError("input_data must contain the 'brand_id' field")
await self.report_progress(
task_id=task.task_id,
progress=0.1,
message="Starting trend insight analysis...",
)
from app.database import AsyncSessionLocal
from app.services.trend.trend_analyzer_service import TrendAnalyzerService
async with AsyncSessionLocal() as db:
service = TrendAnalyzerService(db)
await self.report_progress(
task_id=task.task_id,
progress=0.3,
message="Fetching historical citation data...",
)
await self.report_progress(
task_id=task.task_id,
progress=0.5,
message="Running time-series analysis...",
)
result = await service.analyze_trends(
brand_id=brand_id,
days=days,
platforms=platforms,
keywords=keywords,
)
await self.report_progress(
task_id=task.task_id,
progress=1.0,
message="Trend insight analysis complete",
)
return result
async def _trend_hotspot(self, task: TaskMessage) -> dict:
input_data = task.input_data
brand_id = input_data.get("brand_id")
days = input_data.get("days", 30)
if not brand_id:
raise ValueError("input_data must contain the 'brand_id' field")
await self.report_progress(
task_id=task.task_id,
progress=0.1,
message="Starting hot-topic analysis...",
)
from app.database import AsyncSessionLocal
from app.services.trend.trend_analyzer_service import TrendAnalyzerService
async with AsyncSessionLocal() as db:
service = TrendAnalyzerService(db)
await self.report_progress(
task_id=task.task_id,
progress=0.5,
message="Detecting keywords/topics with a surge in citations...",
)
result = await service.get_hotspots(
brand_id=brand_id,
days=days,
)
await self.report_progress(
task_id=task.task_id,
progress=1.0,
message="Hot-topic analysis complete",
)
return result

@ -21,6 +21,33 @@ from app.config import settings
logger = logging.getLogger(__name__)
# Module-level lazy singleton for TaskDispatcher — avoids creating a new
# dispatcher on every method call while still deferring the import to
# prevent circular-dependency issues at module-load time.
_dispatcher_instance = None
def _get_dispatcher():
"""Return a cached TaskDispatcher instance (lazy singleton)."""
global _dispatcher_instance
if _dispatcher_instance is None:
from app.agent_framework.dispatcher import TaskDispatcher
_dispatcher_instance = TaskDispatcher(settings.REDIS_URL)
return _dispatcher_instance
# Module-level lazy singleton for AgentRegistry — same rationale.
_registry_instance = None
def _get_registry():
"""Return a cached AgentRegistry instance (lazy singleton)."""
global _registry_instance
if _registry_instance is None:
from app.agent_framework.registry import AgentRegistry
_registry_instance = AgentRegistry()
return _registry_instance
class BaseAgent(ABC):
"""Base class for all agents; defines the standard lifecycle."""
@ -40,6 +67,10 @@ class BaseAgent(ABC):
def status(self) -> AgentStatus:
return self._status
@property
def is_distributed(self) -> bool:
return self._redis is not None
@abstractmethod
async def execute(self, task: TaskMessage) -> TaskResult:
"""Core task execution logic; subclasses must implement this."""
@ -51,48 +82,47 @@ class BaseAgent(ABC):
...
async def start(self):
-"""Start the agent: register with the Registry and begin listening on the task queue."""
logger.info(f"Starting agent '{self.name}' (type={self.agent_type}, version={self.version})")
-# Initialize the Redis connection
-self._redis = aioredis.from_url(
-settings.REDIS_URL,
-decode_responses=True,
-)
+try:
+self._redis = aioredis.from_url(
+settings.REDIS_URL,
+decode_responses=True,
+)
+await self._redis.ping()
-# Register with the Registry
-from app.agent_framework.registry import AgentRegistry
+registry = _get_registry()
+capability = self.get_capabilities()
+await registry.register(capability, endpoint=f"agent:{self.name}")
-registry = AgentRegistry()
-capability = self.get_capabilities()
-await registry.register(capability, endpoint=f"agent:{self.name}")
+self._status = AgentStatus.ONLINE
-# Update the status
-self._status = AgentStatus.ONLINE
+capability = self.get_capabilities()
+max_concurrency = getattr(capability, 'max_concurrency', 1) or 1
+self._semaphore = asyncio.Semaphore(max_concurrency)
+logger.info(
+f"Agent '{self.name}' concurrency limit set to {max_concurrency}"
+)
-# Initialize the Semaphore from the capabilities' max_concurrency
-capability = self.get_capabilities()
-max_concurrency = getattr(capability, 'max_concurrency', 1) or 1
-self._semaphore = asyncio.Semaphore(max_concurrency)
-logger.info(
-f"Agent '{self.name}' concurrency limit set to {max_concurrency}"
-)
+self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
+self._listen_task = asyncio.create_task(self._listen_for_tasks())
-# Start the heartbeat
-self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
+logger.info(f"Agent '{self.name}' started in distributed mode")
+except Exception as e:
+self._redis = None
+self._status = AgentStatus.ONLINE
-# Start listening on the task queue
-self._listen_task = asyncio.create_task(self._listen_for_tasks())
+capability = self.get_capabilities()
+max_concurrency = getattr(capability, 'max_concurrency', 1) or 1
+self._semaphore = asyncio.Semaphore(max_concurrency)
-logger.info(f"Agent '{self.name}' started successfully")
+logger.warning(f"Agent '{self.name}' started in local mode (Redis unavailable: {e})")
async def stop(self):
"""Stop the agent: unregister and stop listening."""
logger.info(f"Stopping agent '{self.name}'")
self._status = AgentStatus.OFFLINE
-# Cancel the listener task
if self._listen_task and not self._listen_task.done():
self._listen_task.cancel()
try:
@ -100,7 +130,6 @@ class BaseAgent(ABC):
except asyncio.CancelledError:
pass
-# Cancel the heartbeat task
if self._heartbeat_task and not self._heartbeat_task.done():
self._heartbeat_task.cancel()
try:
@ -108,14 +137,10 @@ class BaseAgent(ABC):
except asyncio.CancelledError:
pass
-# Unregister
-from app.agent_framework.registry import AgentRegistry
+if self._redis is not None:
+registry = _get_registry()
+await registry.unregister(self.name)
-registry = AgentRegistry()
-await registry.unregister(self.name)
-# Close the Redis connection
if self._redis:
await self._redis.close()
self._redis = None
@ -123,13 +148,10 @@ class BaseAgent(ABC):
async def heartbeat(self):
"""Periodic heartbeat report."""
-from app.agent_framework.registry import AgentRegistry
-registry = AgentRegistry()
+registry = _get_registry()
await registry.update_heartbeat(self.name)
async def report_progress(self, task_id: str, progress: float, message: str):
-"""Report task progress."""
progress_obj = TaskProgress(
task_id=task_id,
agent_name=self.name,
@ -138,7 +160,6 @@ class BaseAgent(ABC):
updated_at=datetime.now(timezone.utc),
)
-# Publish progress via Redis Pub/Sub
if self._redis:
try:
await self._redis.publish(
@ -148,11 +169,11 @@ class BaseAgent(ABC):
except Exception as e:
logger.warning(f"Failed to publish progress for task {task_id}: {e}")
-# Also update the database
-from app.agent_framework.dispatcher import TaskDispatcher
-dispatcher = TaskDispatcher(settings.REDIS_URL)
-await dispatcher.handle_progress(progress_obj)
+try:
+dispatcher = _get_dispatcher()
+await dispatcher.handle_progress(progress_obj)
+except Exception as e:
+logger.warning(f"Failed to report progress to dispatcher for task {task_id}: {e}")
async def _heartbeat_loop(self):
"""Heartbeat loop."""
@ -198,7 +219,6 @@ class BaseAgent(ABC):
await self._execute_task(task)
async def _execute_task(self, task: TaskMessage):
-"""Execute a single task."""
self._running_tasks.add(task.task_id)
self._status = AgentStatus.BUSY
@ -209,15 +229,12 @@ class BaseAgent(ABC):
result.started_at = started_at
result.completed_at = datetime.now(timezone.utc)
-# Handle the result
-from app.agent_framework.dispatcher import TaskDispatcher
-dispatcher = TaskDispatcher(settings.REDIS_URL)
-await dispatcher.handle_result(result)
+if self._redis is not None:
+dispatcher = _get_dispatcher()
+await dispatcher.handle_result(result)
except Exception as e:
logger.error(f"Agent '{self.name}' task {task.task_id} failed: {e}")
-# Build the failure result
error_result = TaskResult(
task_id=task.task_id,
agent_name=self.name,
@ -228,10 +245,9 @@ class BaseAgent(ABC):
completed_at=datetime.now(timezone.utc),
metrics=None,
)
-from app.agent_framework.dispatcher import TaskDispatcher
-dispatcher = TaskDispatcher(settings.REDIS_URL)
-await dispatcher.handle_result(error_result)
+if self._redis is not None:
+dispatcher = _get_dispatcher()
+await dispatcher.handle_result(error_result)
finally:
self._running_tasks.discard(task.task_id)
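
The two patterns this change to the base agent introduces, a module-level lazy singleton and a fallback to local mode when the backend is unreachable, can be illustrated without Redis. This is a minimal sketch under stated assumptions: `InMemoryRegistry`, `start_agent`, `reachable`, and `unreachable` are hypothetical stand-ins and are not part of the commit.

```python
import asyncio

_registry_instance = None  # module-level cache, mirroring _registry_instance in the diff


class InMemoryRegistry:
    """Hypothetical stand-in for AgentRegistry; it only records registered names."""

    def __init__(self) -> None:
        self.registered: list[str] = []


def get_registry() -> InMemoryRegistry:
    """Create the registry on first use and reuse it afterwards (lazy singleton)."""
    global _registry_instance
    if _registry_instance is None:
        _registry_instance = InMemoryRegistry()
    return _registry_instance


async def start_agent(name: str, connect) -> str:
    """Return 'distributed' if connect() succeeds, otherwise fall back to 'local'."""
    try:
        await connect()  # stands in for the Redis ping in BaseAgent.start
        get_registry().registered.append(name)
        return "distributed"
    except Exception:
        # No registration happens in local mode, matching the diff's except branch.
        return "local"


async def reachable() -> None:
    """Simulates a backend that answers the ping."""
    return None


async def unreachable() -> None:
    """Simulates a backend that cannot be reached."""
    raise ConnectionError("backend unavailable")
```

Calling `asyncio.run(start_agent("demo", unreachable))` exercises the fallback path. One trade-off of a module-level singleton is that all agents in a process share the same instance, which is the intent here but makes test isolation something to handle explicitly.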

@ -15,6 +15,7 @@ from .content_generator import CONTENT_GENERATOR_TEMPLATE
from .deai_agent import DEAI_TEMPLATE
from .geo_optimizer import GEO_OPTIMIZER_TEMPLATE
from .rule_checker import RULE_CHECKER_TEMPLATE
from .schema_advisor import SCHEMA_ADVISOR_TEMPLATE
from .topic_selector import TOPIC_SELECTOR_TEMPLATE
__all__ = [
@ -25,4 +26,5 @@ __all__ = [
"DEAI_TEMPLATE",
"GEO_OPTIMIZER_TEMPLATE",
"RULE_CHECKER_TEMPLATE",
"SCHEMA_ADVISOR_TEMPLATE",
]

@ -0,0 +1,70 @@
from .base_template import PromptSection, PromptTemplate
SCHEMA_ADVISOR_TEMPLATE = PromptTemplate(
PromptSection(
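identity="""You are a technical expert in Schema.org structured data and JSON-LD.
You understand in depth how search engines and AI models (such as ChatGPT, Perplexity, and Kimi) parse and use structured data,
and you know how precise Schema markup improves a brand's visibility and citation rate in AI search results.
The JSON-LD you generate strictly follows the Schema.org specification so that search engines can parse it correctly.""",
context="""## Brand information
- Brand name: ${brand_name}
- Website: ${brand_website}
- Industry: ${brand_industry}
## Diagnosis data
${diagnosis_data}
## Existing Schema markup
${existing_schemas}
## Target Schema type
${schema_type}""",
instructions="""Based on the brand information and diagnosis data above, generate complete JSON-LD structured data for the brand.
Requirements:
1. Content
- Every field must be filled with real, specific content; do not leave fields empty
- Basic information such as the brand name and website must match the data provided
- Descriptive text should be professional, accurate, and reflect the brand's character
2. Requirements by Schema type
- Organization: include name, description, url, logo, sameAs (social media links), contactPoint
- Product: include name, description, brand, offers, aggregateRating (if available)
- FAQPage: generate 3-5 high-quality FAQ questions and answers relevant to the brand's industry; they should be natural and informative
- Article: include headline, author, datePublished, description, image
- LocalBusiness: include name, address (full address structure), geo, telephone, openingHours
3. Language
- All natural-language content must use the same language as the brand name
- Technical fields (@type, @context) stay in English
4. Structural completeness
- @context and @type are required
- Nested objects must be complete; do not omit required sub-properties""",
constraints="""## Constraints
- Strictly follow the Schema.org specification; do not use non-standard properties
- @context must be "https://schema.org"
- @type must be a valid type defined by Schema.org
- Do not invent brand information that does not exist; if there is no real address, the LocalBusiness address may use a placeholder structure
- FAQ questions must be ones that users would realistically search for
- For any URL field without a real value, use an empty string
- Do not include HTML tags in the JSON-LD""",
output_format="""## Output format
Output the filled JSON-LD as JSON:
```json
{
"@context": "https://schema.org",
"@type": "...",
"...": "..."
}
```
Output only the JSON-LD object, with no explanatory text.""",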
)
)
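
The output contract stated in the prompt (a JSON-serializable object with `@context` set to `https://schema.org` and an `@type`) can be checked with a few lines of Python. This is a minimal sketch, assuming a hypothetical helper `check_json_ld` and made-up example values; it is not the validation code the agent itself uses.

```python
import json


def check_json_ld(doc: dict) -> list[str]:
    """Return a list of problems found in a JSON-LD document (empty when none)."""
    problems: list[str] = []
    if doc.get("@context") != "https://schema.org":
        problems.append("@context must be https://schema.org")
    if "@type" not in doc:
        problems.append("missing @type")
    try:
        json.dumps(doc)
    except (TypeError, ValueError) as exc:
        # TypeError: unserializable value; ValueError: circular reference
        problems.append(f"not serializable: {exc}")
    return problems


# Hypothetical filled template; the brand name and URL are illustrative only.
example = {
    "@context": "https://schema.org",
    "@type": "Organization",
    "name": "Example Brand",
    "url": "https://example.com",
}
```

Running `check_json_ld(example)` returns an empty list, while a document that omits `@context` or contains a value such as a Python `set` yields one problem entry each.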

@ -6,7 +6,6 @@ from enum import Enum
class AgentType(str, Enum):
"""Agent type enum."""
CITATION_DETECTOR = "citation_detector"
CONTENT_GENERATOR = "content_generator"
DEAI_AGENT = "deai_agent"
@ -14,6 +13,9 @@ class AgentType(str, Enum):
RULE_CHECKER = "rule_checker"
COMPETITOR_ANALYZER = "competitor_analyzer"
PERFORMANCE_TRACKER = "performance_tracker"
SCHEMA_ADVISOR = "schema_advisor"
MONITOR_AGENT = "monitor_agent"
TREND_AGENT = "trend_agent"
class TaskStatus(str, Enum):

@ -0,0 +1,62 @@
import asyncio
import argparse
import logging
import sys
from app.agent_framework.agents.citation_detector import CitationDetectorAgent
from app.agent_framework.agents.content_generator_agent import ContentGeneratorAgent
from app.agent_framework.agents.deai_agent import DeAIAgent
from app.agent_framework.agents.geo_optimizer_agent import GEOOptimizerAgent
from app.agent_framework.agents.monitor_agent import MonitorAgent
from app.agent_framework.agents.schema_advisor import SchemaAdvisorAgent
from app.agent_framework.agents.competitor_analyzer import CompetitorAnalyzerAgent
from app.agent_framework.agents.trend_agent import TrendAgent
AGENTS = {
"citation_detector": CitationDetectorAgent,
"content_generator": ContentGeneratorAgent,
"deai_agent": DeAIAgent,
"geo_optimizer": GEOOptimizerAgent,
"monitor": MonitorAgent,
"schema_advisor": SchemaAdvisorAgent,
"competitor_analyzer": CompetitorAnalyzerAgent,
"trend_agent": TrendAgent,
"all": None,
}
async def run_agent(name: str):
if name == "all":
agents = [cls() for cls in AGENTS.values() if cls is not None]
else:
cls = AGENTS.get(name)
if cls is None:
print(f"Unknown agent: {name}")
sys.exit(1)
agents = [cls()]
for agent in agents:
await agent.start()
print(f"Agent(s) running: {[a.name for a in agents]}")
try:
await asyncio.Future()
except asyncio.CancelledError:
pass
finally:
for agent in agents:
await agent.stop()
def main():
parser = argparse.ArgumentParser(description="Run GEO Agent(s)")
parser.add_argument("agent", choices=list(AGENTS.keys()), help="Agent name or 'all'")
parser.add_argument("--log-level", default="INFO", help="Logging level")
args = parser.parse_args()
logging.basicConfig(level=getattr(logging, args.log_level.upper()))
asyncio.run(run_agent(args.agent))
if __name__ == "__main__":
main()

@ -46,6 +46,7 @@ class QueryResultResponse(BaseModel):
brand_context: str | None
competitor_contexts: list[str]
response_time_ms: int
timestamp: str
model_config = {"from_attributes": True}
@ -61,6 +62,7 @@ class CitationRateResponse(BaseModel):
class BatchQueryResponse(BaseModel):
results: list[QueryResultResponse]
citation_rate: CitationRateResponse
avg_response_time_ms: float = 0.0
def _result_to_response(r: AIQueryResult) -> QueryResultResponse:
@ -73,6 +75,7 @@ def _result_to_response(r: AIQueryResult) -> QueryResultResponse:
brand_context=r.brand_context,
competitor_contexts=r.competitor_contexts,
response_time_ms=r.response_time_ms,
timestamp=r.timestamp.isoformat(),
)
@ -97,9 +100,16 @@ async def _execute_batch(
engine_types, query, brand_name, competitor_names
)
citation_rate = service.calculate_citation_rate(results)
+response_results = [_result_to_response(r) for r in results]
+avg_response_time = (
+sum(r.response_time_ms for r in results) / len(results)
+if results
+else 0.0
+)
return BatchQueryResponse(
-results=[_result_to_response(r) for r in results],
+results=response_results,
citation_rate=CitationRateResponse(**citation_rate),
+avg_response_time_ms=avg_response_time,
)
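
The guarded average added in this hunk avoids a `ZeroDivisionError` when the batch is empty. A minimal standalone sketch follows; the `Result` dataclass and `average_response_time` are hypothetical stand-ins for `AIQueryResult` and the inline expression above.

```python
from dataclasses import dataclass


@dataclass
class Result:
    """Hypothetical stand-in for AIQueryResult; only the field used here."""
    response_time_ms: int


def average_response_time(results: list[Result]) -> float:
    """Mean response time in milliseconds, or 0.0 for an empty batch."""
    if not results:
        return 0.0
    return sum(r.response_time_ms for r in results) / len(results)
```

Returning `0.0` for an empty batch keeps the response field a float, at the cost of not distinguishing "no results" from "instant results"; callers that need that distinction can check the length of `results`.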

@ -24,7 +24,7 @@ from app.schemas.alert import (
AlertSettingListResponse,
AlertSettingBulkUpdate,
)
-from app.services.alert_engine import AlertEngine
+from app.services.alert.alert_engine import AlertEngine
logger = logging.getLogger(__name__)

@ -6,7 +6,7 @@ from pydantic import BaseModel
from app.api.deps import get_current_user
from app.models.user import User
from app.services.api_key_manager import APIKeyManager, KeySource, KeyStatus
-from app.services.smart_router import ENGINE_COST_PROFILES
+from app.services.llm.smart_router import ENGINE_COST_PROFILES
logger = logging.getLogger(__name__)

@ -0,0 +1,300 @@
import logging
import uuid
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.attribution_record import AttributionRecord
from app.models.brand import Brand
from app.models.diagnosis_record import DiagnosisRecord
from app.models.user import User
from app.services.attribution.attribution_engine import AttributionEngine
from app.services.attribution.roi_calculator import ROICalculator
logger = logging.getLogger(__name__)
router = APIRouter()
PLAN_COSTS = {
"free": 0.0,
"starter": 99.0,
"pro": 299.0,
"enterprise": 999.0,
}
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
class StartTrackingRequest(BaseModel):
brand_id: str
content_id: str | None = None
class AttributionResponse(BaseModel):
id: str
brand_id: str
content_id: str | None
baseline_score: float
current_score: float | None
score_delta: float | None
status: str
roi_percentage: float | None
created_at: datetime
model_config = {"from_attributes": True}
class ROIReport(BaseModel):
brand_id: str
brand_name: str
subscription_cost: float
current_plan: str
total_score_delta: float
value_generated: float
roi_percentage: float
break_even_delta: float
tracking_records: list[AttributionResponse]
ab_comparison: dict | None
class ABComparisonResponse(BaseModel):
brand_id: str
brand_name: str
overall_before: float
overall_after: float
overall_delta: float
dimensions: list[dict]
def _record_to_response(record: AttributionRecord) -> AttributionResponse:
return AttributionResponse(
id=str(record.id),
brand_id=str(record.brand_id),
content_id=str(record.content_id) if record.content_id else None,
baseline_score=record.baseline_score,
current_score=record.current_score,
score_delta=record.score_delta,
status=record.status,
roi_percentage=record.roi_percentage,
created_at=record.created_at,
)
async def _get_brand_or_404(
brand_id: uuid.UUID,
current_user: User,
db: AsyncSession,
) -> Brand:
user_uuid = _to_uuid(current_user.id)
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == user_uuid)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
if not brand:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Brand not found",
)
return brand
@router.post("/start", response_model=AttributionResponse)
async def start_tracking(
body: StartTrackingRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
brand_id = _to_uuid(body.brand_id)
brand = await _get_brand_or_404(brand_id, current_user, db)
content_id = _to_uuid(body.content_id) if body.content_id else None
engine = AttributionEngine()
record = await engine.start_tracking(db, brand.id, content_id, current_user.id)
return _record_to_response(record)
@router.get("/brand/{brand_id}")
async def get_brand_attribution(
brand_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
brand = await _get_brand_or_404(brand_id, current_user, db)
engine = AttributionEngine()
summary = await engine.get_brand_attribution_summary(db, brand.id)
summary["records"] = [_record_to_response(r) for r in summary["records"]]
return summary
@router.get("/{record_id}", response_model=AttributionResponse)
async def get_attribution_record(
record_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
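# Scope the lookup to brands owned by the current user so that other users' records are not exposed
stmt = (
select(AttributionRecord)
.join(Brand, Brand.id == AttributionRecord.brand_id)
.where(
AttributionRecord.id == record_id,
Brand.user_id == _to_uuid(current_user.id),
)
)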
result = await db.execute(stmt)
record = result.scalar_one_or_none()
if not record:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Attribution record not found",
)
return _record_to_response(record)
@router.post("/{record_id}/check", response_model=AttributionResponse)
async def check_attribution(
record_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
engine = AttributionEngine()
try:
record = await engine.check_attribution(db, record_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Attribution record not found",
)
return _record_to_response(record)
@router.get("/roi/{brand_id}", response_model=ROIReport)
async def get_roi_report(
brand_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
brand = await _get_brand_or_404(brand_id, current_user, db)
engine = AttributionEngine()
summary = await engine.get_brand_attribution_summary(db, brand.id)
user_plan = getattr(current_user, "plan", "free") or "free"
subscription_cost = PLAN_COSTS.get(user_plan, 0.0)
calculator = ROICalculator()
roi_data = calculator.calculate_roi(
subscription_cost=subscription_cost,
score_delta=summary["total_score_delta"],
attribution_records=summary["records"],
)
ab_comparison = None
baseline_record = (
await db.execute(
select(DiagnosisRecord)
.where(
DiagnosisRecord.brand_id == brand.id,
DiagnosisRecord.status == "completed",
)
.order_by(DiagnosisRecord.completed_at.asc())
.limit(1)
)
).scalar_one_or_none()
latest_record = (
await db.execute(
select(DiagnosisRecord)
.where(
DiagnosisRecord.brand_id == brand.id,
DiagnosisRecord.status == "completed",
)
.order_by(DiagnosisRecord.completed_at.desc())
.limit(1)
)
).scalar_one_or_none()
if baseline_record and latest_record and baseline_record.id != latest_record.id:
before_dims = baseline_record.result_json.get("dimensions", []) if baseline_record.result_json else []
after_dims = latest_record.result_json.get("dimensions", []) if latest_record.result_json else []
before_map = {d.get("name"): d for d in before_dims}
after_map = {d.get("name"): d for d in after_dims}
ab_comparison = calculator.generate_ab_comparison(
before_score=baseline_record.overall_score or 0,
after_score=latest_record.overall_score or 0,
before_dimensions=before_map,
after_dimensions=after_map,
)
return ROIReport(
brand_id=str(brand.id),
brand_name=brand.name,
subscription_cost=subscription_cost,
current_plan=user_plan,
total_score_delta=summary["total_score_delta"],
value_generated=roi_data["value_generated"],
roi_percentage=roi_data["roi_percentage"],
break_even_delta=roi_data["break_even_delta"],
tracking_records=[_record_to_response(r) for r in summary["records"]],
ab_comparison=ab_comparison,
)
@router.get("/ab-comparison/{brand_id}", response_model=ABComparisonResponse)
async def get_ab_comparison(
brand_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
brand = await _get_brand_or_404(brand_id, current_user, db)
baseline_record = (
await db.execute(
select(DiagnosisRecord)
.where(
DiagnosisRecord.brand_id == brand.id,
DiagnosisRecord.status == "completed",
)
.order_by(DiagnosisRecord.completed_at.asc())
.limit(1)
)
).scalar_one_or_none()
latest_record = (
await db.execute(
select(DiagnosisRecord)
.where(
DiagnosisRecord.brand_id == brand.id,
DiagnosisRecord.status == "completed",
)
.order_by(DiagnosisRecord.completed_at.desc())
.limit(1)
)
).scalar_one_or_none()
if not baseline_record or not latest_record:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No diagnosis data available; cannot generate an A/B comparison",
)
calculator = ROICalculator()
before_dims = baseline_record.result_json.get("dimensions", []) if baseline_record.result_json else []
after_dims = latest_record.result_json.get("dimensions", []) if latest_record.result_json else []
before_map = {d.get("name"): d for d in before_dims}
after_map = {d.get("name"): d for d in after_dims}
comparison = calculator.generate_ab_comparison(
before_score=baseline_record.overall_score or 0,
after_score=latest_record.overall_score or 0,
before_dimensions=before_map,
after_dimensions=after_map,
)
return ABComparisonResponse(
brand_id=str(brand.id),
brand_name=brand.name,
overall_before=comparison["overall_before"],
overall_after=comparison["overall_after"],
overall_delta=comparison["overall_delta"],
dimensions=comparison["dimensions"],
)

@ -1,5 +1,6 @@
"""Brands API endpoints."""
import json
import logging
import uuid
from typing import Annotated
@ -14,9 +15,12 @@ from app.api.scoring import router as scoring_router
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.query import Query as QueryModel
from app.schemas.brand import BrandCreate, BrandUpdate, BrandResponse, BrandListResponse
from app.services.cache import get_cache_service, TTL_BRANDS
logger = logging.getLogger(__name__)
router = APIRouter()
# Include competitors router under brands
@ -127,6 +131,28 @@ async def update_brand(
# Update only provided fields
update_data = brand_data.model_dump(exclude_unset=True)
# Detect a brand rename and sync target_brand and brand_aliases on the related queries
old_name = brand.name
new_name = update_data.get("name")
if new_name and new_name != old_name:
queries_stmt = select(QueryModel).where(
QueryModel.user_id == current_user.id,
QueryModel.target_brand == old_name,
)
queries_result = await db.execute(queries_stmt)
related_queries = queries_result.scalars().all()
for query in related_queries:
query.target_brand = new_name
# If brand_aliases contains the old name, update it as well
if query.brand_aliases and old_name in query.brand_aliases:
query.brand_aliases = [new_name if a == old_name else a for a in query.brand_aliases]
if related_queries:
logger.info(f"Brand renamed from '{old_name}' to '{new_name}', synced {len(related_queries)} queries")
for field, value in update_data.items():
setattr(brand, field, value)
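
The alias rewrite performed during a brand rename can be shown on plain lists. This is a minimal sketch; `rename_in_aliases` is a hypothetical helper, and the brand names are illustrative rather than taken from the commit.

```python
from typing import Optional


def rename_in_aliases(
    aliases: Optional[list[str]], old_name: str, new_name: str
) -> Optional[list[str]]:
    """Replace every exact occurrence of old_name with new_name.

    The input is returned unchanged when it is empty, None, or does not
    contain old_name, mirroring the guard used in update_brand.
    """
    if not aliases or old_name not in aliases:
        return aliases
    return [new_name if alias == old_name else alias for alias in aliases]
```

The comparison is an exact, case-sensitive match, so an alias such as `"ACME Inc"` is left alone when renaming `"Acme"`. Building a new list, rather than mutating in place, also matters for ORM-mapped JSON or array columns, where reassignment is typically what marks the attribute as changed.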

@ -13,7 +13,7 @@ from app.schemas.citation import (
CitationListResponse,
CitationStatsResponse,
)
from app.services.citation import (
from app.services.citation.citation import (
get_citation_stats,
get_citations,
)


@ -0,0 +1,135 @@
import logging
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.competitor_insight import CompetitorInsight
from app.schemas.competitor_insight import (
CompetitorAnalysisRequest,
CompetitorInsightResponse,
CompetitorInsightList,
CompetitorGapSummary,
)
from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService
logger = logging.getLogger(__name__)
router = APIRouter()
async def _get_brand_if_owned(
brand_id: uuid.UUID,
current_user: User,
db: AsyncSession,
) -> Brand:
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
if not brand:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="品牌不存在",
)
return brand
@router.post("/analyze", response_model=CompetitorInsightList)
async def analyze_competitor(
request: CompetitorAnalysisRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await _get_brand_if_owned(request.brand_id, current_user, db)
service = CompetitorAnalyzerService()
try:
result = await service.analyze_competitor(
brand_id=request.brand_id,
analysis_types=request.analysis_types,
period_days=request.period_days or 30,
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
stmt = (
select(CompetitorInsight)
.where(CompetitorInsight.brand_id == request.brand_id)
.order_by(CompetitorInsight.created_at.desc())
)
db_result = await db.execute(stmt)
insights = list(db_result.scalars().all())
return {"items": insights, "total": len(insights)}
@router.get("/brand/{brand_id}", response_model=CompetitorInsightList)
async def get_brand_insights(
brand_id: uuid.UUID,
skip: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await _get_brand_if_owned(brand_id, current_user, db)
count_stmt = select(func.count()).select_from(CompetitorInsight).where(
CompetitorInsight.brand_id == brand_id,
)
count_result = await db.execute(count_stmt)
total = count_result.scalar_one()
stmt = (
select(CompetitorInsight)
.where(CompetitorInsight.brand_id == brand_id)
.order_by(CompetitorInsight.created_at.desc())
.offset(skip)
.limit(limit)
)
result = await db.execute(stmt)
insights = list(result.scalars().all())
return {"items": insights, "total": total}
@router.get("/{insight_id}", response_model=CompetitorInsightResponse)
async def get_insight_detail(
insight_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
stmt = select(CompetitorInsight).where(CompetitorInsight.id == insight_id)
result = await db.execute(stmt)
insight = result.scalar_one_or_none()
if not insight:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="洞察不存在",
)
await _get_brand_if_owned(insight.brand_id, current_user, db)
return insight
@router.get("/brand/{brand_id}/gap-summary", response_model=list[CompetitorGapSummary])
async def get_gap_summary(
brand_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
brand = await _get_brand_if_owned(brand_id, current_user, db)
service = CompetitorAnalyzerService()
gap_summaries = await service.calculate_gap_score(db, brand_id, brand.name)
return gap_summaries
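
The list endpoints above return an `{"items": ..., "total": ...}` envelope, where `total` counts all matching rows and `items` holds one `skip`/`limit` window. A minimal in-memory sketch of that contract, assuming the same bounds as the `Query(...)` declarations (`skip >= 0`, `1 <= limit <= 100`); `paginate` is a hypothetical name used only for illustration:

```python
from typing import Any, Sequence


def paginate(rows: Sequence[Any], skip: int = 0, limit: int = 20) -> dict[str, Any]:
    """Slice rows into one page and report the unsliced total.

    In-memory analogue of the count query plus offset/limit query
    used by get_brand_insights.
    """
    if skip < 0:
        raise ValueError("skip must be >= 0")
    if not 1 <= limit <= 100:
        raise ValueError("limit must be between 1 and 100")
    return {"items": list(rows[skip:skip + limit]), "total": len(rows)}
```

Note that `analyze_competitor` in the same file returns every insight and sets `total` to the length of that list, so its envelope has the same shape but is not paginated.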


@ -21,6 +21,7 @@ from app.schemas.competitor import (
CompetitorRecommendationItem,
CompetitorRecommendationResponse,
)
from app.utils.json_extractor import extract_json
logger = logging.getLogger(__name__)
@ -363,7 +364,7 @@ async def _get_llm_recommendations(
raise ValueError("API返回空响应")
# Extract the JSON
json_str = _extract_json_from_text(content)
json_str = extract_json(content)
data = json.loads(json_str)
recommendations = []
@ -393,32 +394,6 @@ async def _get_llm_recommendations(
raise
def _extract_json_from_text(text: str) -> str:
"""Extract a JSON string from text."""
import re
# Try parsing the text directly
try:
json.loads(text)
return text
except json.JSONDecodeError:
pass
# Try extracting from a fenced code block
json_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```'
match = re.search(json_pattern, text)
if match:
return match.group(1).strip()
# Fall back to the span between the first { and the last }
first_brace = text.find('{')
last_brace = text.rfind('}')
if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
return text[first_brace:last_brace + 1]
raise ValueError(f"无法从响应中提取JSON: {text[:200]}")
def _get_industry_label(industry: str | None) -> str:
"""Return the Chinese label for an industry."""
labels = {
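
This hunk replaces the local `_extract_json_from_text` with a shared `extract_json` import whose body is not shown in the diff. A minimal sketch of what such a helper could look like, assuming it keeps the three-step fallback of the removed function; `extract_json_sketch` is a hypothetical name, and the real `app.utils.json_extractor.extract_json` may differ:

```python
import json
import re

# Built from parts so this example does not embed a literal code fence.
_FENCE = "`" * 3
_FENCED_BLOCK = re.compile(_FENCE + r"(?:json)?\s*([\s\S]*?)\s*" + _FENCE)


def extract_json_sketch(text: str) -> str:
    """Return the substring of text most likely to be a JSON document."""
    # Step 1: the whole text already parses as JSON.
    try:
        json.loads(text)
        return text
    except json.JSONDecodeError:
        pass

    # Step 2: take the body of the first fenced code block, if any.
    match = _FENCED_BLOCK.search(text)
    if match:
        return match.group(1).strip()

    # Step 3: take everything from the first "{" to the last "}".
    first, last = text.find("{"), text.rfind("}")
    if first != -1 and last > first:
        return text[first:last + 1]

    raise ValueError(f"no JSON found in response: {text[:200]!r}")
```

Steps 2 and 3 return a candidate string without validating it, so the caller's `json.loads` can still raise, as it could with the removed function.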


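@ -1,18 +1,28 @
"""Content production API - chains the Agent Pipeline"""
"""Content production API - chains the Agent Pipeline
Business logic is delegated to ContentGenerationService; the API layer only handles:
1. Request parsing and parameter validation
2. Calling the service layer
3. Formatting the response
"""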
import json
import logging
import re
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.brand import Brand
from app.models.content import Content, ContentVersion
from app.models.diagnosis_record import DiagnosisRecord
from app.models.user import User
from app.services.content.content_generation_service import ContentGenerationService
from app.services.llm import LLMError
logger = logging.getLogger(__name__)
@ -29,6 +39,7 @@ class ContentGenerateRequest(BaseModel):
brand_description: str = ""
run_deai: bool = True
run_geo: bool = True
use_agent_framework: bool = False
class ContentGenerateResponse(BaseModel):
@ -41,44 +52,6 @@ class ContentGenerateResponse(BaseModel):
pipeline_stages: list[dict] = []  # Summary of each stage's execution result
async def _get_knowledge_context(
db: AsyncSession,
brand_name: str,
knowledge_base_ids: list[str],
target_keyword: str,
) -> str:
"""
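Retrieve context relevant to the query from the knowledge base.
If knowledge base IDs are given, call RAGService.search to fetch related content;
otherwise return an empty string so the rest of the flow is unaffected.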
"""
if not knowledge_base_ids:
return ""
try:
from app.services.knowledge.rag_service import RAGService
rag_service = RAGService()
results = await rag_service.search(
session=db,
query=f"{brand_name} {target_keyword}" if brand_name else target_keyword,
knowledge_base_ids=knowledge_base_ids,
top_k=3,
)
if results:
context_parts = []
for r in results:
content = r.get("content", "")
title = r.get("document_title", "")
if content:
context_parts.append(f"[{title}] {content}")
return "\n".join(context_parts)
return ""
except Exception as e:
logger.warning(f"知识库检索失败,将不使用知识库上下文: {e}")
return ""
@router.post("/generate", response_model=ContentGenerateResponse)
async def generate_content(
req: ContentGenerateRequest,
@ -88,115 +61,132 @@ async def generate_content(
"""
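Generate content in one step: run the pipeline synchronously and store the result in the database.
Flow: ContentGenerator → DeAI → GEOOptimizer
Flow: ContentGenerator -> DeAI -> GEOOptimizer
Business logic is delegated to ContentGenerationService.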
"""
from app.services.llm import LLMError, LLMFactory
from app.agent_framework.prompts import (
CONTENT_GENERATOR_TEMPLATE,
DEAI_TEMPLATE,
GEO_OPTIMIZER_TEMPLATE,
)
org_id = getattr(current_user, "organization_id", None)
if not org_id:
raise HTTPException(status_code=403, detail="用户未关联组织")
stages = []
try:
provider = LLMFactory.get_default()
# Fetch the knowledge base context
knowledge_context = await _get_knowledge_context(
db, req.brand_name, req.knowledge_base_ids, req.target_keyword
service = ContentGenerationService()
result = await service.generate_content(
keyword=req.target_keyword,
brand_name=req.brand_name,
platform=req.target_platform,
content_style=req.content_style,
word_count=req.word_count,
knowledge_base_ids=req.knowledge_base_ids,
db=db,
user_id=current_user.id,
org_id=org_id,
run_deai=req.run_deai,
run_geo=req.run_geo,
use_agent_framework=req.use_agent_framework,
)
# Stage 1: content generation
gen_variables = {
"topic_title": req.target_keyword,
"target_keyword": req.target_keyword,
"target_platform": req.target_platform,
"content_angle": "综合分析",
"content_style": req.content_style,
"word_count": str(req.word_count),
"brand_name": req.brand_name,
"knowledge_context": knowledge_context,
}
messages = CONTENT_GENERATOR_TEMPLATE.render(gen_variables)
response = await provider.chat(messages, temperature=0.7, max_tokens=req.word_count * 2)
content = response.content
stages.append({"stage": "content_generation", "status": "success", "word_count": len(content)})
# Stage 2: de-AI rewriting (optional)
if req.run_deai:
deai_variables = {
"original_content": content,
"target_style": "自然流畅",
"preserve_structure": "",
}
messages = DEAI_TEMPLATE.render(deai_variables)
response = await provider.chat(messages, temperature=0.9, max_tokens=len(content) * 2)
content = response.content
stages.append({"stage": "deai", "status": "success"})
# Stage 3: GEO optimization (optional)
optimized = content
seo_score = None
if req.run_geo:
geo_variables = {
"original_content": content,
"target_keywords": req.target_keyword,
"target_platform": req.target_platform,
"optimization_level": "moderate",
}
messages = GEO_OPTIMIZER_TEMPLATE.render(geo_variables)
response = await provider.chat(messages, temperature=0.5, max_tokens=len(content) * 2)
optimized = response.content
stages.append({"stage": "geo_optimization", "status": "success"})
# ---- Save to the database ----
content_obj = Content(
organization_id=org_id,
title=req.target_keyword,
content_type="article",
body=optimized,
status="draft",
target_platforms=[req.target_platform] if req.target_platform else [],
keywords=[req.target_keyword],
extra_metadata={
"original_content": content if content != optimized else None,
"pipeline_stages": stages,
"seo_score": seo_score,
"brand_name": req.brand_name,
"content_style": req.content_style,
"word_count_target": req.word_count,
},
created_by=current_user.id,
current_version=1,
)
db.add(content_obj)
await db.flush() # get content_obj.id
# Create the version record (initial version)
version = ContentVersion(
content_id=content_obj.id,
version_number=1,
title=req.target_keyword,
body=optimized,
change_summary="Pipeline自动生成",
created_by=current_user.id,
)
db.add(version)
await db.commit()
await db.refresh(content_obj)
return ContentGenerateResponse(
status="success",
content=content,
optimized_content=optimized,
seo_score=seo_score,
content_id=str(content_obj.id),
pipeline_stages=stages,
content=result["content"],
optimized_content=result["optimized_content"],
seo_score=result["seo_score"],
content_id=result["content_id"],
pipeline_stages=result["pipeline_stages"],
)
except LLMError as e:
raise HTTPException(status_code=502, detail=f"LLM调用失败: {str(e)}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"内容生成异常: {str(e)}")
class GEOContentGenerateRequest(BaseModel):
brand_id: str
target_keywords: list[str]
platform: str = "通用"
content_style: str = "专业严谨"
word_count: int = 2000
knowledge_base_ids: list[str] = []
run_deai: bool = True
run_geo: bool = True
class GEOContentGenerateResponse(BaseModel):
content_id: Optional[str] = None
content: str = ""
optimized_content: str = ""
seo_score: Optional[int] = None
pipeline_stages: list[dict] = []
@router.post("/generate-geo", response_model=GEOContentGenerateResponse, status_code=201)
async def generate_geo_content(
req: GEOContentGenerateRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
org_id = getattr(current_user, "organization_id", None)
if not org_id:
raise HTTPException(status_code=403, detail="用户未关联组织")
from sqlalchemy import select
try:
brand_uuid = uuid.UUID(req.brand_id)
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid brand_id format: {req.brand_id}")
brand_stmt = select(Brand).where(Brand.id == brand_uuid)
brand_result = await db.execute(brand_stmt)
brand = brand_result.scalar_one_or_none()
if not brand:
raise HTTPException(status_code=404, detail=f"Brand not found: {req.brand_id}")
diagnosis_context = ""
diag_stmt = (
select(DiagnosisRecord)
.where(DiagnosisRecord.brand_id == brand_uuid, DiagnosisRecord.status == "completed")
.order_by(DiagnosisRecord.created_at.desc())
.limit(1)  # scalar_one_or_none() raises MultipleResultsFound when more than one row matches
)
diag_result = await db.execute(diag_stmt)
diagnosis = diag_result.scalar_one_or_none()
if diagnosis and diagnosis.result_json:
result_json = diagnosis.result_json
weak_dimensions = []
if isinstance(result_json, dict):
dimensions = result_json.get("dimensions", {})
for dim_name, dim_data in dimensions.items():
if isinstance(dim_data, dict) and dim_data.get("score", 100) < 60:
weak_dimensions.append(dim_name)
if weak_dimensions:
diagnosis_context = f"基于诊断结果,以下维度需要重点优化:{', '.join(weak_dimensions)}。请围绕这些维度生成针对性内容。"
keyword = "".join(req.target_keywords)
if diagnosis_context:
keyword = f"{keyword}{diagnosis_context}"
try:
service = ContentGenerationService()
result = await service.generate_content(
keyword=keyword,
brand_name=brand.name,
platform=req.platform,
content_style=req.content_style,
word_count=req.word_count,
knowledge_base_ids=req.knowledge_base_ids,
db=db,
user_id=current_user.id,
org_id=org_id,
run_deai=req.run_deai,
run_geo=req.run_geo,
)
return GEOContentGenerateResponse(
content_id=result["content_id"],
content=result["content"],
optimized_content=result["optimized_content"],
seo_score=result["seo_score"],
pipeline_stages=result["pipeline_stages"],
)
except LLMError as e:
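
The `generate-geo` handler above picks out "weak" dimensions (score below 60) from the latest diagnosis and appends them to the prompt. A minimal sketch of that selection as a pure function, assuming `result_json` has the shape `{"dimensions": {name: {"score": number}}}` implied by the handler; `weak_dimensions` is a hypothetical name, and the extra type guards are additions not present in the commit:

```python
from typing import Any


def weak_dimensions(result_json: Any, threshold: float = 60) -> list[str]:
    """Return names of dimensions whose score is below threshold.

    Entries with a missing score default to 100 (not weak), matching
    the handler. Malformed entries are skipped rather than raising.
    """
    if not isinstance(result_json, dict):
        return []
    dimensions = result_json.get("dimensions", {})
    if not isinstance(dimensions, dict):
        return []

    weak: list[str] = []
    for name, data in dimensions.items():
        if not isinstance(data, dict):
            continue
        score = data.get("score", 100)
        if isinstance(score, (int, float)) and score < threshold:
            weak.append(name)
    return weak
```

Keeping this logic outside the route makes the threshold easy to test without a database or an LLM call.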


@ -1,6 +1,5 @@
"""Dashboard API endpoints."""
import uuid
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, Query
from sqlalchemy import select, func, Integer
@ -19,325 +18,15 @@ from app.schemas.dashboard import (
PlatformScoreItem,
RecentQueryItem,
)
from app.services.scoring_service import ScoringService, get_health_level
from app.services.sentiment_service import get_sentiment_service
from app.schemas.scoring import CitationResult
from app.services.scoring.scoring_service import get_health_level
from app.services.scoring.brand_scoring_data_service import (
get_brand_scoring_data_service,
REQUIRED_PLATFORMS,
)
from app.services.cache import get_cache_service, TTL_DASHBOARD
router = APIRouter()
# Required platforms for the dashboard
REQUIRED_PLATFORMS = [
"wenxin", # 文心一言
"kimi", # Kimi
"tongyi", # 通义千问
"doubao", # 豆包
"xinghuo", # 讯飞星火
"tiangong", # 天工AI
"qingyan", # 智谱清言
]
async def _get_brand_score_by_platform(
db: AsyncSession,
user_id: uuid.UUID,
brand_id: uuid.UUID,
) -> dict[str, float]:
"""
Calculate brand score by platform.
Returns a dict mapping platform key to score (0-100).
"""
brand_stmt = select(Brand).where(
Brand.id == brand_id,
Brand.user_id == user_id,
)
brand_result = await db.execute(brand_stmt)
brand = brand_result.scalar_one_or_none()
if not brand:
return {platform: 0.0 for platform in REQUIRED_PLATFORMS}
queries_stmt = select(QueryModel).where(
QueryModel.user_id == user_id,
QueryModel.target_brand == brand.name,
)
queries_result = await db.execute(queries_stmt)
queries = list(queries_result.scalars().all())
if not queries:
return {platform: 0.0 for platform in REQUIRED_PLATFORMS}
query_ids = [q.id for q in queries]
platform_scores = {platform: 0.0 for platform in REQUIRED_PLATFORMS}
for platform in REQUIRED_PLATFORMS:
citation_stmt = select(
func.count().label("total"),
func.sum(
func.cast(
func.case((CitationRecord.cited == True, 1), else_=0),
Integer
)
).label("cited")
).where(
CitationRecord.query_id.in_(query_ids),
CitationRecord.platform == platform,
)
result = await db.execute(citation_stmt)
row = result.one()
total = row.total or 0
cited = row.cited or 0
if total > 0:
citation_rate = cited / total
platform_scores[platform] = round(citation_rate * 100, 2)
else:
platform_scores[platform] = 0.0
return platform_scores
async def _get_competitor_scores_by_platform(
db: AsyncSession,
user_id: uuid.UUID,
brand_id: uuid.UUID,
) -> dict[str, float]:
"""
Calculate average competitor score by platform.
Returns a dict mapping platform key to average competitor score (0-100).
"""
competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id)
competitor_result = await db.execute(competitor_stmt)
competitors = list(competitor_result.scalars().all())
if not competitors:
return {platform: 0.0 for platform in REQUIRED_PLATFORMS}
competitor_names = [c.name for c in competitors]
competitor_queries_stmt = select(QueryModel).where(
QueryModel.user_id == user_id,
QueryModel.target_brand.in_(competitor_names),
)
competitor_queries_result = await db.execute(competitor_queries_stmt)
competitor_queries = list(competitor_queries_result.scalars().all())
if not competitor_queries:
return {platform: 0.0 for platform in REQUIRED_PLATFORMS}
competitor_query_ids = [q.id for q in competitor_queries]
platform_scores = {platform: 0.0 for platform in REQUIRED_PLATFORMS}
for platform in REQUIRED_PLATFORMS:
citation_stmt = select(
func.count().label("total"),
func.sum(
func.cast(
func.case((CitationRecord.cited == True, 1), else_=0),
Integer
)
).label("cited")
).where(
CitationRecord.query_id.in_(competitor_query_ids),
CitationRecord.platform == platform,
)
result = await db.execute(citation_stmt)
row = result.one()
total = row.total or 0
cited = row.cited or 0
if total > 0:
citation_rate = cited / total
platform_scores[platform] = round(citation_rate * 100, 2)
else:
platform_scores[platform] = 0.0
return platform_scores
async def _calculate_overall_score_v2(
db: AsyncSession,
user_id: uuid.UUID,
brand_id: uuid.UUID,
) -> tuple[float, float, list[DimensionScoreItem], int, int]:
"""
Calculate V2 overall score, change from yesterday, dimension scores,
and competitor position.
Returns: (overall_score, change_from_yesterday, dimensions, ahead_count, behind_count)
"""
brand_stmt = select(Brand).where(
Brand.id == brand_id,
Brand.user_id == user_id,
)
brand_result = await db.execute(brand_stmt)
brand = brand_result.scalar_one_or_none()
if not brand:
return 0.0, 0.0, [], 0, 0
# Get queries for this brand
queries_stmt = select(QueryModel).where(
QueryModel.user_id == user_id,
QueryModel.target_brand == brand.name,
)
queries_result = await db.execute(queries_stmt)
queries = list(queries_result.scalars().all())
if not queries:
return 0.0, 0.0, [], 0, 0
query_ids = [q.id for q in queries]
# Get all citations
citations_stmt = select(CitationRecord).where(
CitationRecord.query_id.in_(query_ids),
)
citations_result = await db.execute(citations_stmt)
all_citations = list(citations_result.scalars().all())
total_queries = len(all_citations)
brand_citations = [c for c in all_citations if c.cited]
# Get competitor data
competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id)
competitor_result = await db.execute(competitor_stmt)
competitors = list(competitor_result.scalars().all())
competitor_names = [c.name for c in competitors]
# Calculate competitor mentions
competitor_mentions: dict[str, int] = {}
for comp_name in competitor_names:
count = sum(
1 for c in all_citations
if c.cited and c.competitor_brands
and comp_name in c.competitor_brands
)
if count > 0:
competitor_mentions[comp_name] = count
# Sentiment analysis - prefer persisted sentiment field
sentiment_service = get_sentiment_service()
sentiment_counts = {"positive": 0, "neutral": 0, "negative": 0}
for citation in brand_citations:
if citation.sentiment and citation.sentiment in ("positive", "neutral", "negative"):
sentiment_counts[citation.sentiment] += 1
else:
content = citation.raw_response or citation.citation_text or ""
if content.strip():
try:
result = await sentiment_service.analyze(
brand_name=brand.name,
content=content,
)
sentiment_counts[result.sentiment] += 1
except Exception:
sentiment_counts["neutral"] += 1
else:
sentiment_counts["neutral"] += 1
# Build citation results
citation_results = [
CitationResult(
cited=c.cited,
position=c.citation_position,
citation_text=c.citation_text,
sentiment="neutral", # simplified for dashboard
confidence=c.confidence or 0.0,
)
for c in brand_citations
]
# Extract positions
positions = [c.citation_position for c in brand_citations if c.cited]
# Calculate V2 score
scoring_service = ScoringService()
v2_result = scoring_service.calculate_v2(
mentioned_count=len(brand_citations),
total_queries=total_queries,
positions=positions,
sentiment_counts=sentiment_counts,
citations=citation_results,
brand_mentions=len(brand_citations),
competitor_mentions=competitor_mentions,
)
# Build dimension items
dimensions = [
DimensionScoreItem(
name=v2_result.mention_rate.name,
score=round(v2_result.mention_rate.score, 2),
max_score=v2_result.mention_rate.max_score,
percentage=round(v2_result.mention_rate.percentage, 2),
),
DimensionScoreItem(
name=v2_result.recommendation_rank.name,
score=round(v2_result.recommendation_rank.score, 2),
max_score=v2_result.recommendation_rank.max_score,
percentage=round(v2_result.recommendation_rank.percentage, 2),
),
DimensionScoreItem(
name=v2_result.sentiment_score.name,
score=round(v2_result.sentiment_score.score, 2),
max_score=v2_result.sentiment_score.max_score,
percentage=round(v2_result.sentiment_score.percentage, 2),
),
DimensionScoreItem(
name=v2_result.citation_quality.name,
score=round(v2_result.citation_quality.score, 2),
max_score=v2_result.citation_quality.max_score,
percentage=round(v2_result.citation_quality.percentage, 2),
),
DimensionScoreItem(
name=v2_result.competitive_position.name,
score=round(v2_result.competitive_position.score, 2),
max_score=v2_result.competitive_position.max_score,
percentage=round(v2_result.competitive_position.percentage, 2),
),
]
# Calculate competitor ahead/behind
ahead_count = sum(
1 for count in competitor_mentions.values()
if len(brand_citations) > count
)
behind_count = sum(
1 for count in competitor_mentions.values()
if len(brand_citations) <= count
)
# Calculate change from yesterday
today = datetime.now().date()
yesterday = today - timedelta(days=1)
today_citations = [
c for c in all_citations
if c.queried_at.date() == today
]
yesterday_citations = [
c for c in all_citations
if c.queried_at.date() == yesterday
]
today_cited = sum(1 for c in today_citations if c.cited)
today_total = len(today_citations)
today_score = (today_cited / today_total * 100) if today_total > 0 else 0.0
yesterday_cited = sum(1 for c in yesterday_citations if c.cited)
yesterday_total = len(yesterday_citations)
yesterday_score = (yesterday_cited / yesterday_total * 100) if yesterday_total > 0 else 0.0
change = round(today_score - yesterday_score, 2)
return v2_result.overall_score, change, dimensions, ahead_count, behind_count
@router.get("/stats", response_model=DashboardStatsResponse)
async def get_dashboard_stats(
@ -357,7 +46,6 @@ async def get_dashboard_stats(
- Recent query records
"""
cache = get_cache_service()
# If brand_id is not set yet, look up the user's first brand in the database
if brand_id is None:
brand_stmt = select(Brand).where(Brand.user_id == current_user.id).limit(1)
brand_result = await db.execute(brand_stmt)
@ -379,34 +67,67 @@ async def get_dashboard_stats(
total_platforms=7,
)
# Try reading from the cache (TTL: 2 minutes)
cache_key = f"dashboard:stats:{current_user.id}:{brand_id}"
cached = await cache.get_json(cache_key)
if cached is not None:
return cached
# Get brand name
brand_stmt = select(Brand).where(Brand.id == brand_id)
brand_result = await db.execute(brand_stmt)
brand = brand_result.scalar_one_or_none()
brand_name = brand.name if brand else None
# Calculate V2 overall score and dimensions
overall_score, score_change, dimensions, ahead_count, behind_count = (
await _calculate_overall_score_v2(db, current_user.id, brand_id)
scoring_data_service = get_brand_scoring_data_service()
scoring_data = await scoring_data_service.get_brand_scoring_data(
db, current_user.id, brand
)
# Get platform scores
platform_scores_dict = await _get_brand_score_by_platform(
overall_score = scoring_data.v2_result.overall_score
score_change = scoring_data.change_from_yesterday
dimensions = [
DimensionScoreItem(
name=scoring_data.v2_result.mention_rate.name,
score=round(scoring_data.v2_result.mention_rate.score, 2),
max_score=scoring_data.v2_result.mention_rate.max_score,
percentage=round(scoring_data.v2_result.mention_rate.percentage, 2),
),
DimensionScoreItem(
name=scoring_data.v2_result.recommendation_rank.name,
score=round(scoring_data.v2_result.recommendation_rank.score, 2),
max_score=scoring_data.v2_result.recommendation_rank.max_score,
percentage=round(scoring_data.v2_result.recommendation_rank.percentage, 2),
),
DimensionScoreItem(
name=scoring_data.v2_result.sentiment_score.name,
score=round(scoring_data.v2_result.sentiment_score.score, 2),
max_score=scoring_data.v2_result.sentiment_score.max_score,
percentage=round(scoring_data.v2_result.sentiment_score.percentage, 2),
),
DimensionScoreItem(
name=scoring_data.v2_result.citation_quality.name,
score=round(scoring_data.v2_result.citation_quality.score, 2),
max_score=scoring_data.v2_result.citation_quality.max_score,
percentage=round(scoring_data.v2_result.citation_quality.percentage, 2),
),
DimensionScoreItem(
name=scoring_data.v2_result.competitive_position.name,
score=round(scoring_data.v2_result.competitive_position.score, 2),
max_score=scoring_data.v2_result.competitive_position.max_score,
percentage=round(scoring_data.v2_result.competitive_position.percentage, 2),
),
]
ahead_count = scoring_data.competitor_data.get("ahead_count", 0)
behind_count = scoring_data.competitor_data.get("behind_count", 0)
platform_scores_dict = scoring_data.platform_scores
competitor_scores_dict = await scoring_data_service.get_competitor_platform_scores(
db, current_user.id, brand_id
)
# Get competitor platform scores
competitor_scores_dict = await _get_competitor_scores_by_platform(
db, current_user.id, brand_id
)
# Get first competitor name for display
competitor_stmt = select(Competitor).where(
Competitor.brand_id == brand_id
).limit(1)
@ -424,10 +145,8 @@ async def get_dashboard_stats(
for platform, score in platform_scores_dict.items()
]
# Count monitored platforms
monitored = sum(1 for s in platform_scores_dict.values() if s > 0)
# Get recent queries (last 10)
recent_queries_stmt = (
select(QueryModel)
.where(QueryModel.user_id == current_user.id)
@ -437,7 +156,6 @@ async def get_dashboard_stats(
recent_queries_result = await db.execute(recent_queries_stmt)
recent_queries_list = list(recent_queries_result.scalars().all())
# Get citation count for each query
recent_queries = []
for query in recent_queries_list:
citation_count_stmt = select(
@ -460,7 +178,6 @@ async def get_dashboard_stats(
queried_at=query.last_queried_at or query.created_at,
))
# Health level
health_level = get_health_level(overall_score)
response = DashboardStatsResponse(
@ -477,7 +194,6 @@ async def get_dashboard_stats(
brand_name=brand_name,
)
# Write the result to the cache (TTL: 2 minutes)
await cache.set_json(
cache_key,
response.model_dump(mode="json"),
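
The refactored `get_dashboard_stats` still builds five near-identical `DimensionScoreItem` objects by hand. One way to shrink that is to loop over the attribute names. A minimal sketch, assuming each dimension object exposes `name`, `score`, `max_score` and `percentage` as in the hunk above; `Dimension` and `DimensionItem` are hypothetical stand-ins for the project's own result and schema types:

```python
from dataclasses import dataclass


@dataclass
class Dimension:
    """Stand-in for one per-dimension result on v2_result."""
    name: str
    score: float
    max_score: float
    percentage: float


@dataclass
class DimensionItem:
    """Stand-in for the DimensionScoreItem response schema."""
    name: str
    score: float
    max_score: float
    percentage: float


# Attribute names read from v2_result, in display order.
DIMENSION_FIELDS = (
    "mention_rate",
    "recommendation_rank",
    "sentiment_score",
    "citation_quality",
    "competitive_position",
)


def build_dimension_items(v2_result) -> list[DimensionItem]:
    """Build one rounded response item per scoring dimension."""
    items = []
    for field in DIMENSION_FIELDS:
        dim = getattr(v2_result, field)
        items.append(
            DimensionItem(
                name=dim.name,
                score=round(dim.score, 2),
                max_score=dim.max_score,
                percentage=round(dim.percentage, 2),
            )
        )
    return items
```

Adding a sixth dimension would then be a one-line change to `DIMENSION_FIELDS`.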


@ -16,7 +16,7 @@ from app.schemas.detection_task import (
DetectionTaskUpdate,
DetectionTriggerResponse,
)
from app.services.detection_scheduler import DetectionSchedulerService, TaskNotFoundError
from app.services.detection.detection_scheduler import DetectionSchedulerService, TaskNotFoundError
logger = logging.getLogger(__name__)


@ -1,6 +1,8 @
"""Diagnosis API endpoints - provide SEO and GEO diagnosis."""
import asyncio
import logging
import uuid
from datetime import UTC, datetime
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
@ -11,13 +13,32 @@ from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.services.seo_diagnosis import SEODiagnosisService
from app.services.geo_diagnosis import GEODiagnosisService, GEODiagnosisInput
from app.models.diagnosis_record import DiagnosisRecord
from app.schemas.diagnosis import (
GEODiagnosisHistoryItem,
GEODiagnosisHistoryResponse,
GEODiagnosisResponse,
GEODiagnosisResultResponse,
GEODiagnosisTaskResponse,
GEODiagnosisTriggerRequest,
)
from app.services.diagnosis.data_collector import DataCollectorService
from app.services.diagnosis.seo_diagnosis import SEODiagnosisService
from app.services.diagnosis.geo_diagnosis import GEODiagnosisService, GEODiagnosisInput
from app.utils.health import get_health_level_label
logger = logging.getLogger(__name__)
router = APIRouter()
_FREE_TIER_DIMENSIONS = {"内容可提取性", "E-E-A-T信号", "引用就绪度"}
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
@router.get("/seo/{brand_id}")
async def get_seo_diagnosis(
@ -25,24 +46,14 @@ async def get_seo_diagnosis(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
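Get the SEO diagnosis result for a brand.
Returns a five-dimension SEO diagnosis:
- Technical SEO (25)
- On-page SEO (20)
- Content quality (20)
- Backlink analysis (15)
- User experience (20)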
"""
brand = await _get_brand_or_404(brand_id, current_user, db)
try:
service = SEODiagnosisService()
result = service.diagnose()
logger.info(f"SEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={result.overall_score}")
return result.to_dict()
except Exception as e:
logger.error(f"SEO诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True)
@ -52,39 +63,158 @@ async def get_seo_diagnosis(
)
@router.get("/geo/{brand_id}")
async def get_geo_diagnosis(
@router.post("/geo/{brand_id}", status_code=status.HTTP_202_ACCEPTED)
async def trigger_geo_diagnosis(
brand_id: uuid.UUID,
body: GEODiagnosisTriggerRequest = GEODiagnosisTriggerRequest(),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
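Get the GEO diagnosis result for a brand.
Returns a six-dimension GEO diagnosis:
- Content extractability (20)
- Entity clarity (15)
- E-E-A-T signals (20)
- Schema markup (15)
- Topical authority (15)
- Citation readiness (15)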
"""
brand = await _get_brand_or_404(brand_id, current_user, db)
try:
input_data = GEODiagnosisInput()
service = GEODiagnosisService()
result = service.diagnose(input_data)
logger.info(f"GEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={result.overall_score}")
return result.to_dict()
except Exception as e:
logger.error(f"GEO诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="GEO诊断服务异常,请稍后重试",
if not body.force_refresh:
existing = await _find_recent_completed(db, brand_id, hours=24)
if existing:
return GEODiagnosisTaskResponse(
task_id=str(existing.id),
brand_id=str(brand_id),
status="completed",
)
record = DiagnosisRecord(
brand_id=brand_id,
user_id=_to_uuid(current_user.id),
diagnosis_type="geo",
status="pending",
)
db.add(record)
await db.commit()
await db.refresh(record)
asyncio.create_task(
_run_geo_diagnosis(
record_id=record.id,
brand_id=brand_id,
brand_name=brand.name,
brand_aliases=brand.aliases or [],
website=brand.website,
industry=brand.industry,
user_id=_to_uuid(current_user.id),
)
)
return GEODiagnosisTaskResponse(
task_id=str(record.id),
brand_id=str(brand_id),
status="pending",
)
@router.get("/geo/{brand_id}/result")
async def get_geo_diagnosis_result(
brand_id: uuid.UUID,
task_id: str | None = None,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await _get_brand_or_404(brand_id, current_user, db)
if task_id:
try:
tid = uuid.UUID(task_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="无效的task_id",
)
stmt = select(DiagnosisRecord).where(
DiagnosisRecord.id == tid,
DiagnosisRecord.brand_id == brand_id,
)
else:
stmt = (
select(DiagnosisRecord)
.where(
DiagnosisRecord.brand_id == brand_id,
DiagnosisRecord.diagnosis_type == "geo",
)
.order_by(DiagnosisRecord.created_at.desc())
.limit(1)
)
result = await db.execute(stmt)
record = result.scalar_one_or_none()
if not record:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="诊断记录不存在",
)
if record.status in ("pending", "running"):
return GEODiagnosisResultResponse(
task_id=str(record.id),
brand_id=str(brand_id),
status=record.status,
)
if record.status == "failed":
return GEODiagnosisResultResponse(
task_id=str(record.id),
brand_id=str(brand_id),
status="failed",
error=record.error_message,
)
user_plan = getattr(current_user, "plan", None) or "free"
is_paid = user_plan not in ("free", None)
diagnosis_resp = _build_diagnosis_response(record.result_json, is_paid)
return GEODiagnosisResultResponse(
task_id=str(record.id),
brand_id=str(brand_id),
status="completed",
result=diagnosis_resp,
)
@router.get("/geo/{brand_id}/history")
async def get_geo_diagnosis_history(
brand_id: uuid.UUID,
limit: int = 10,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await _get_brand_or_404(brand_id, current_user, db)
stmt = (
select(DiagnosisRecord)
.where(
DiagnosisRecord.brand_id == brand_id,
DiagnosisRecord.diagnosis_type == "geo",
DiagnosisRecord.status == "completed",
)
.order_by(DiagnosisRecord.created_at.desc())
.limit(limit)
)
result = await db.execute(stmt)
records = result.scalars().all()
items = [
GEODiagnosisHistoryItem(
task_id=str(r.id),
overall_score=r.overall_score or 0,
health_level=r.result_json.get("health_level", "danger") if r.result_json else "danger",
created_at=r.created_at.isoformat(),
completed_at=r.completed_at.isoformat() if r.completed_at else None,
)
for r in records
]
return GEODiagnosisHistoryResponse(
brand_id=str(brand_id),
history=items,
)
@router.get("/combined/{brand_id}")
@ -93,29 +223,32 @@ async def get_combined_diagnosis(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
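Get the combined diagnosis result for a brand.
Combines the SEO and GEO diagnoses and returns the combined score and the detailed results.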
"""
brand = await _get_brand_or_404(brand_id, current_user, db)
try:
seo_service = SEODiagnosisService()
seo_result = seo_service.diagnose()
collector = DataCollectorService(db)
collection = await collector.collect(
brand_name=brand.name,
brand_aliases=brand.aliases or [],
website=brand.website,
industry=brand.industry,
)
geo_service = GEODiagnosisService()
geo_result = geo_service.diagnose(GEODiagnosisInput())
geo_result = geo_service.diagnose(collection.diagnosis_input)
combined_score = round((seo_result.overall_score + geo_result.overall_score) / 2, 2)
logger.info(
f"综合诊断完成: brand_id={brand_id}, brand={brand.name}, "
f"seo_score={seo_result.overall_score}, "
f"geo_score={geo_result.overall_score}, "
f"combined_score={combined_score}"
)
return {
"seo_score": seo_result.overall_score,
"geo_score": geo_result.overall_score,
@ -131,20 +264,131 @@ async def get_combined_diagnosis(
)
async def _run_geo_diagnosis(
record_id: uuid.UUID,
brand_id: uuid.UUID,
brand_name: str,
brand_aliases: list[str],
website: str | None,
industry: str | None,
user_id: uuid.UUID,
) -> None:
from app.database import AsyncSessionLocal
async with AsyncSessionLocal() as db:
try:
stmt = select(DiagnosisRecord).where(DiagnosisRecord.id == record_id)
result = await db.execute(stmt)
record = result.scalar_one_or_none()
if not record:
return
record.status = "running"
await db.commit()
collector = DataCollectorService(db)
collection = await collector.collect(
brand_name=brand_name,
brand_aliases=brand_aliases,
website=website,
industry=industry,
)
geo_service = GEODiagnosisService()
diagnosis_result = geo_service.diagnose(collection.diagnosis_input)
record.status = "completed"
record.overall_score = diagnosis_result.overall_score
record.result_json = diagnosis_result.to_dict()
record.completed_at = datetime.now(UTC)
record.collection_metadata = collection.metadata
if collection.errors:
record.collection_metadata["errors"] = collection.errors
await db.commit()
logger.info(
f"GEO诊断完成: brand_id={brand_id}, brand={brand_name}, "
f"score={diagnosis_result.overall_score}"
)
except Exception as e:
logger.error(f"GEO诊断任务失败: record_id={record_id}, error={e}", exc_info=True)
try:
stmt = select(DiagnosisRecord).where(DiagnosisRecord.id == record_id)
result = await db.execute(stmt)
record = result.scalar_one_or_none()
if record:
record.status = "failed"
record.error_message = str(e)
await db.commit()
except Exception:
logger.error(f"更新失败状态也失败: record_id={record_id}", exc_info=True)
def _build_diagnosis_response(result_json: dict | None, is_paid: bool) -> GEODiagnosisResponse:
if not result_json:
return GEODiagnosisResponse(
overall_score=0,
health_level="danger",
health_level_label=get_health_level_label("danger"),
dimensions=[],
recommendations=[],
is_full_report=is_paid,
)
dimensions = result_json.get("dimensions", [])
if not is_paid:
dimensions = [d for d in dimensions if d.get("name") in _FREE_TIER_DIMENSIONS]
recommendations = result_json.get("recommendations", [])
if not is_paid:
recommendations = [r for r in recommendations if r.get("priority") == "P0"]
return GEODiagnosisResponse(
overall_score=result_json.get("overall_score", 0),
health_level=result_json.get("health_level", "danger"),
health_level_label=result_json.get("health_level_label", get_health_level_label("danger")),
dimensions=dimensions,
recommendations=recommendations,
is_full_report=is_paid,
)
async def _find_recent_completed(
db: AsyncSession, brand_id: uuid.UUID, hours: int = 24
) -> DiagnosisRecord | None:
from datetime import timedelta
cutoff = datetime.now(UTC) - timedelta(hours=hours)
stmt = (
select(DiagnosisRecord)
.where(
DiagnosisRecord.brand_id == brand_id,
DiagnosisRecord.diagnosis_type == "geo",
DiagnosisRecord.status == "completed",
DiagnosisRecord.completed_at >= cutoff,
)
.order_by(DiagnosisRecord.completed_at.desc())
.limit(1)
)
result = await db.execute(stmt)
return result.scalar_one_or_none()
async def _get_brand_or_404(
brand_id: uuid.UUID,
current_user: User,
db: AsyncSession,
) -> Brand:
"""Fetch the brand owned by the current user, or raise a 404 error."""
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
user_uuid = _to_uuid(current_user.id)
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == user_uuid)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
if not brand:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="品牌不存在",
)
return brand


@ -4,6 +4,7 @@ import uuid
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
@ -25,7 +26,9 @@ from app.schemas.distribution import (
)
from app.services.distribution.formatter import ContentFormatter
from app.services.distribution.platform_rules import PLATFORM_RULES, PlatformRuleEngine
from app.services.distribution.publish_engine import PublishEngine
from app.services.distribution.publish_strategy import PublishStrategyService
from app.services.distribution.publishers.base import PublishResult
router = APIRouter()
@ -235,3 +238,72 @@ async def create_schedule(
tips=tips,
created_at=schedule.created_at.strftime("%Y-%m-%d %H:%M:%S"),
)
class PublishRequest(BaseModel):
content_id: str = Field(min_length=1)
platforms: list[str] = Field(min_length=1)
class PublishStatusItem(BaseModel):
platform: str
status: str
article_url: str | None = None
published_at: str | None = None
class PublishStatusResponse(BaseModel):
platforms: list[PublishStatusItem]
class PublishResponse(BaseModel):
results: list[PublishResult]
_publish_engine = PublishEngine()
@router.post("/publish", response_model=PublishResponse)
async def publish_content(
req: PublishRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
org_id = current_user.organization_id
if not org_id:
raise HTTPException(status_code=403, detail="用户未关联组织")
try:
results = await _publish_engine.publish_content(
content_id=req.content_id,
platforms=req.platforms,
db=db,
user_id=str(current_user.id),
org_id=str(org_id),
)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
return PublishResponse(results=results)
@router.get("/publish/{content_id}/status", response_model=PublishStatusResponse)
async def get_publish_status(
content_id: str,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
try:
status_list = await _publish_engine.get_publish_status(
content_id=content_id,
db=db,
)
except ValueError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Content not found: {content_id}",
)
return PublishStatusResponse(
platforms=[PublishStatusItem(**s) for s in status_list]
)


@ -0,0 +1,144 @@
import hashlib
import logging
from fastapi import APIRouter, Depends, Query
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.brand import Brand
from app.schemas.health_score import (
HealthScoreDimension,
HealthScoreRecommendation,
HealthScoreResponse,
)
from app.services.cache import get_cache_service
from app.services.diagnosis.data_collector import DataCollectorService
from app.services.diagnosis.geo_diagnosis import GEODiagnosisService
from app.utils.health import get_health_level, get_health_level_label
logger = logging.getLogger(__name__)
router = APIRouter()
_FREE_TIER_DIMENSIONS = {"内容可提取性", "E-E-A-T信号", "引用就绪度"}
_CACHE_TTL = 86400
def _build_default_response(brand_name: str) -> HealthScoreResponse:
return HealthScoreResponse(
brand_name=brand_name,
overall_score=0.0,
health_level="danger",
health_level_label=get_health_level_label("danger"),
dimensions=[
HealthScoreDimension(
name=d,
score=0.0,
max_score=0.0,
percentage=0.0,
status="fail",
)
for d in sorted(_FREE_TIER_DIMENSIONS)
],
recommendations=[],
is_full_report=False,
cached=False,
)
@router.get("/health-score", response_model=HealthScoreResponse)
async def get_public_health_score(
brand: str = Query(..., min_length=1),
competitors: str = Query(default=""),
db: AsyncSession = Depends(get_db),
):
cache = get_cache_service()
cache_key = f"health_score:{hashlib.md5(brand.strip().lower().encode()).hexdigest()}"
cached_data = await cache.get_json(cache_key)
if cached_data is not None:
cached_data["cached"] = True
return cached_data
brand_name = brand.strip()
competitor_list = [c.strip() for c in competitors.split(",") if c.strip()][:3]
brand_aliases: list[str] = []
website: str | None = None
industry: str | None = None
stmt = select(Brand).where(Brand.name == brand_name).limit(1)
result = await db.execute(stmt)
brand_record = result.scalar_one_or_none()
if brand_record:
brand_aliases = brand_record.aliases or []
website = brand_record.website
industry = brand_record.industry
try:
collector = DataCollectorService(db)
collection = await collector.collect(
brand_name=brand_name,
brand_aliases=brand_aliases,
website=website,
industry=industry,
)
geo_service = GEODiagnosisService()
diagnosis_result = geo_service.diagnose(collection.diagnosis_input)
except Exception as e:
logger.error(f"健康分数据采集失败: brand={brand_name}, error={e}", exc_info=True)
default_resp = _build_default_response(brand_name)
await cache.set_json(cache_key, default_resp.model_dump(mode="json"), expire=_CACHE_TTL)
return default_resp
all_dimensions = diagnosis_result.to_dict().get("dimensions", [])
free_dimensions = [d for d in all_dimensions if d.get("name") in _FREE_TIER_DIMENSIONS]
overall_score = round(sum(d.get("score", 0.0) for d in free_dimensions), 2)
health_level = get_health_level(overall_score)
dimension_responses = [
HealthScoreDimension(
name=d.get("name", ""),
score=round(d.get("score", 0.0), 2),
max_score=d.get("max_score", 0.0),
percentage=round(d.get("percentage", 0.0), 2),
status=d.get("status", "fail"),
)
for d in free_dimensions
]
all_recommendations = diagnosis_result.to_dict().get("recommendations", [])
free_recommendations = [
r for r in all_recommendations
if r.get("priority") == "P0" and r.get("dimension") in _FREE_TIER_DIMENSIONS
]
recommendation_responses = [
HealthScoreRecommendation(
priority=r.get("priority", "P0"),
dimension=r.get("dimension", ""),
title=r.get("title", ""),
description=r.get("description", ""),
)
for r in free_recommendations
]
response = HealthScoreResponse(
brand_name=brand_name,
overall_score=overall_score,
health_level=health_level,
health_level_label=get_health_level_label(health_level),
dimensions=dimension_responses,
recommendations=recommendation_responses,
is_full_report=False,
cached=False,
)
await cache.set_json(cache_key, response.model_dump(mode="json"), expire=_CACHE_TTL)
return response
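
A minimal sketch of a cache-key helper for the public health-score endpoint. It normalizes the brand name before hashing, so inputs that differ only in case or surrounding whitespace share one cache entry. It also illustrates one possible policy of keeping the fallback response for a shorter time than a real result, whereas the commit uses the single `_CACHE_TTL` for both. `health_score_cache_key`, `cache_ttl`, and the 300-second value are assumptions made for illustration.

```python
import hashlib

CACHE_TTL_OK = 86400    # same value as _CACHE_TTL in the endpoint
CACHE_TTL_FAILED = 300  # assumed value, chosen only for illustration


def health_score_cache_key(brand: str) -> str:
    """Build the cache key from the stripped, lowercased brand name."""
    normalized = brand.strip().lower()
    digest = hashlib.md5(normalized.encode("utf-8")).hexdigest()
    return f"health_score:{digest}"


def cache_ttl(collection_failed: bool) -> int:
    """Keep fallback responses only briefly so a transient failure can recover."""
    return CACHE_TTL_FAILED if collection_failed else CACHE_TTL_OK
```

Whether a shorter TTL for failures is desirable depends on how expensive a retry of `DataCollectorService.collect` is, which this diff does not show.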


@ -0,0 +1,207 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.monitoring import MonitoringRecord
from app.schemas.monitoring import (
MonitoringRecordCreate,
MonitoringRecordResponse,
MonitoringRecordList,
MonitoringChangeReport,
MonitoringStatusUpdate,
)
from app.services.monitoring.monitor_service import MonitorService
router = APIRouter()
async def _get_brand_with_access(
brand_id: uuid.UUID,
db: AsyncSession,
current_user: User,
) -> Brand:
stmt = select(Brand).where(
Brand.id == brand_id,
Brand.user_id == current_user.id,
)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
if not brand:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="品牌不存在",
)
return brand
def _record_to_response(record: MonitoringRecord) -> MonitoringRecordResponse:
return MonitoringRecordResponse(
id=record.id,
brand_id=record.brand_id,
content_id=record.content_id,
query_keywords=record.query_keywords,
platform=record.platform,
baseline_citation_count=record.baseline_citation_count,
baseline_sentiment=record.baseline_sentiment,
baseline_rank=record.baseline_rank,
current_citation_count=record.current_citation_count,
current_sentiment=record.current_sentiment,
current_rank=record.current_rank,
change_type=record.change_type,
change_details=record.change_details,
check_interval_hours=record.check_interval_hours,
last_checked_at=record.last_checked_at,
next_check_at=record.next_check_at,
status=record.status,
created_at=record.created_at,
updated_at=record.updated_at,
)
@router.post("/tasks", response_model=MonitoringRecordResponse)
async def create_monitoring_task(
request: MonitoringRecordCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await _get_brand_with_access(request.brand_id, db, current_user)
service = MonitorService()
record = await service.create_monitoring_record(
db=db,
brand_id=request.brand_id,
content_id=request.content_id,
query_keywords=request.query_keywords,
platform=request.platform,
check_interval_hours=request.check_interval_hours,
)
return _record_to_response(record)
@router.get("/brand/{brand_id}", response_model=MonitoringRecordList)
async def get_brand_monitoring(
brand_id: uuid.UUID,
skip: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await _get_brand_with_access(brand_id, db, current_user)
service = MonitorService()
records, total = await service.get_brand_monitoring(
db=db,
brand_id=brand_id,
skip=skip,
limit=limit,
)
return MonitoringRecordList(
records=[_record_to_response(r) for r in records],
total=total,
)
@router.get("/{record_id}/report", response_model=MonitoringChangeReport)
async def get_change_report(
record_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
stmt = select(MonitoringRecord).where(MonitoringRecord.id == record_id)
result = await db.execute(stmt)
record = result.scalar_one_or_none()
if not record:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="监测记录不存在",
)
await _get_brand_with_access(record.brand_id, db, current_user)
service = MonitorService()
report = await service.generate_change_report(db, record_id)
if not report:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="变化报告不存在",
)
return MonitoringChangeReport(
monitoring_record_id=uuid.UUID(report["monitoring_record_id"]),
brand_id=uuid.UUID(report["brand_id"]),
change_type=report["change_type"],
change_details=report["change_details"],
baseline=report["baseline"],
current=report["current"],
recommendations=report["recommendations"],
)
@router.put("/{record_id}/status", response_model=MonitoringRecordResponse)
async def update_monitoring_status(
record_id: uuid.UUID,
status_update: MonitoringStatusUpdate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
valid_statuses = {"active", "paused", "completed"}
if status_update.status not in valid_statuses:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"无效的状态值,支持: {', '.join(sorted(valid_statuses))}",
)
stmt = select(MonitoringRecord).where(MonitoringRecord.id == record_id)
result = await db.execute(stmt)
record = result.scalar_one_or_none()
if not record:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="监测记录不存在",
)
await _get_brand_with_access(record.brand_id, db, current_user)
record.status = status_update.status
await db.commit()
await db.refresh(record)
return _record_to_response(record)
@router.post("/{record_id}/check", response_model=MonitoringRecordResponse)
async def trigger_manual_check(
record_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
stmt = select(MonitoringRecord).where(MonitoringRecord.id == record_id)
result = await db.execute(stmt)
record = result.scalar_one_or_none()
if not record:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="监测记录不存在",
)
await _get_brand_with_access(record.brand_id, db, current_user)
service = MonitorService()
updated_record = await service.check_and_compare(db, record_id)
if not updated_record:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="检测执行失败",
)
return _record_to_response(updated_record)


@ -20,15 +20,24 @@ from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.competitor import Competitor
from app.models.citation_record import CitationRecord
from app.models.query import Query as QueryModel
from app.services.scoring_service import ScoringService
from app.schemas.brand import BrandCreate, BrandResponse
from app.services.diagnosis.data_collector import DataCollectorService
from app.services.diagnosis.geo_diagnosis import GEODiagnosisService
from app.utils.health import get_health_level, get_health_level_label
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/onboarding", tags=["onboarding"])
_FREE_TIER_DIMENSIONS = {"内容可提取性", "E-E-A-T信号", "引用就绪度"}
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
# ------------------------------------------------------------------
# Request / Response schemas
@ -60,23 +69,43 @@ class CompetitorRecommendationSimpleResponse(BaseModel):
recommendations: list[CompetitorRecommendationSimple]
class HealthDimensionItem(BaseModel):
name: str
score: float
max_score: float
percentage: float
status: str
class HealthRecommendationItem(BaseModel):
priority: str
dimension: str
title: str
description: str
class HealthReportResponse(BaseModel):
"""Initial health score report."""
brand_id: str
brand_name: str
overall_score: float
health_level: str
health_level_label: str
platform_scores: dict
strengths: list[str]
weaknesses: list[str]
competitor_scores: list[dict]
dimensions: list[HealthDimensionItem] = []
recommendations: list[HealthRecommendationItem] = []
is_full_report: bool = False
class ActionSuggestion(BaseModel):
"""Action suggestion item."""
title: str
description: str
priority: str # high / medium / low
action_type: str # e.g. coverage, keyword, sentiment, platform
priority: str
action_type: str
is_paid_action: bool = False
action_button_text: str = ""
class ActionSuggestionsResponse(BaseModel):
@ -105,7 +134,7 @@ async def get_onboarding_status(
- completed=True with brand_id set: onboarding is complete
- completed=False, current_step=1: a brand still needs to be created
"""
stmt = select(Brand).where(Brand.user_id == current_user.id)
stmt = select(Brand).where(Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
@ -144,7 +173,7 @@ async def create_onboarding_brand(
)
brand = Brand(
user_id=current_user.id,
user_id=_to_uuid(current_user.id),
name=full_brand_data.name,
aliases=full_brand_data.aliases,
website=full_brand_data.website,
@ -156,6 +185,38 @@ async def create_onboarding_brand(
await db.commit()
await db.refresh(brand)
# Auto-create a default query (subject to the max_queries limit)
try:
current_query_count_stmt = select(func.count()).select_from(QueryModel).where(
QueryModel.user_id == current_user.id
)
current_query_count_result = await db.execute(current_query_count_stmt)
current_query_count = current_query_count_result.scalar_one()
max_queries = getattr(current_user, "max_queries", 3) # free tier defaults to 3
if current_query_count < max_queries:
default_query = QueryModel(
user_id=current_user.id,
keyword=f"{brand.name} 推荐",
target_brand=brand.name,
brand_aliases=brand.aliases or [],
platforms=brand.platforms or ["wenxin", "kimi"],
frequency=brand.frequency or "weekly",
status="active",
)
db.add(default_query)
await db.commit()
await db.refresh(default_query)
logger.info(f"Auto-created default query for brand '{brand.name}'")
else:
logger.info(
f"Skipped auto-creating default query for brand '{brand.name}': "
f"query limit reached ({current_query_count}/{max_queries})"
)
except Exception as e:
logger.warning(f"Failed to auto-create default query for brand '{brand.name}': {e}")
return brand
@ -171,8 +232,7 @@ async def get_onboarding_competitor_recommendations(
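Reuses the recommendation logic from brands/competitors.
Supports two modes: LLM-based recommendations and rule-based recommendations.
"""
# Verify brand ownership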
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
@ -182,7 +242,6 @@ async def get_onboarding_competitor_recommendations(
detail="品牌不存在",
)
# Fetch existing competitor names (to exclude them)
existing_stmt = select(Competitor.name).where(Competitor.brand_id == brand_id)
existing_result = await db.execute(existing_stmt)
existing_names = [row[0] for row in existing_result.all()]
@ -228,14 +287,8 @@ async def get_onboarding_health_report(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
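Get the brand's initial health score report.
Computed from the brand's citation statistics in the citation_records table.
If there is no citation data, an initialized state is returned (overall_score: 0).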
"""
# Verify brand ownership
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
user_uuid = _to_uuid(current_user.id)
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == user_uuid)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
@ -245,7 +298,55 @@ async def get_onboarding_health_report(
detail="品牌不存在",
)
# Fetch the queries associated with the brand
user_plan = getattr(current_user, "plan", None) or "free"
is_paid = user_plan != "free"
try:
collector = DataCollectorService(db)
collection = await collector.collect(
brand_name=brand.name,
brand_aliases=brand.aliases or [],
website=brand.website,
industry=brand.industry,
)
geo_service = GEODiagnosisService()
diagnosis_result = geo_service.diagnose(collection.diagnosis_input)
except Exception as e:
logger.error(f"Onboarding健康报告采集失败: brand_id={brand_id}, error={e}", exc_info=True)
return HealthReportResponse(
brand_id=str(brand.id),
brand_name=brand.name,
overall_score=0.0,
health_level="danger",
health_level_label=get_health_level_label("danger"),
platform_scores={},
strengths=["品牌已创建,等待数据采集"],
weaknesses=["数据采集失败,请稍后重试"],
competitor_scores=[],
dimensions=[
HealthDimensionItem(name=d, score=0.0, max_score=0.0, percentage=0.0, status="fail")
for d in sorted(_FREE_TIER_DIMENSIONS)
],
recommendations=[],
is_full_report=is_paid,
)
result_dict = diagnosis_result.to_dict()
all_dimensions = result_dict.get("dimensions", [])
all_recommendations = result_dict.get("recommendations", [])
if is_paid:
filtered_dimensions = all_dimensions
filtered_recommendations = all_recommendations
else:
filtered_dimensions = [d for d in all_dimensions if d.get("name") in _FREE_TIER_DIMENSIONS]
filtered_recommendations = [r for r in all_recommendations if r.get("priority") == "P0"]
overall_score = round(diagnosis_result.overall_score, 2)
health_level = get_health_level(overall_score)
platform_scores: dict[str, float] = {}
queries_stmt = select(QueryModel).where(
QueryModel.user_id == current_user.id,
QueryModel.target_brand == brand.name,
@ -253,138 +354,89 @@ async def get_onboarding_health_report(
queries_result = await db.execute(queries_stmt)
queries = list(queries_result.scalars().all())
# No query data: return the initialized state
if not queries:
return HealthReportResponse(
brand_id=str(brand.id),
brand_name=brand.name,
overall_score=0.0,
platform_scores={},
strengths=["品牌已创建,等待数据采集"],
weaknesses=["尚无AI平台引用数据需等待查询执行"],
competitor_scores=[],
if queries:
from app.models.citation_record import CitationRecord
query_ids = [q.id for q in queries]
citations_stmt = select(CitationRecord).where(
CitationRecord.query_id.in_(query_ids),
)
citations_result = await db.execute(citations_stmt)
citations = list(citations_result.scalars().all())
query_ids = [q.id for q in queries]
platforms_seen: dict[str, dict] = {}
for c in citations:
p = c.platform or "unknown"
if p not in platforms_seen:
platforms_seen[p] = {"total": 0, "cited": 0}
platforms_seen[p]["total"] += 1
if c.cited:
platforms_seen[p]["cited"] += 1
# Fetch citation records
citations_stmt = select(CitationRecord).where(
CitationRecord.query_id.in_(query_ids),
)
citations_result = await db.execute(citations_stmt)
citations = list(citations_result.scalars().all())
for p, data in platforms_seen.items():
rate = (data["cited"] / data["total"] * 100) if data["total"] > 0 else 0.0
platform_scores[p] = round(rate, 2)
total = len(citations)
cited = [c for c in citations if c.cited]
# Compute per-platform scores
platform_scores: dict[str, float] = {}
platforms_seen: dict[str, dict] = {} # {platform: {total, cited}}
for c in citations:
p = c.platform or "unknown"
if p not in platforms_seen:
platforms_seen[p] = {"total": 0, "cited": 0}
platforms_seen[p]["total"] += 1
if c.cited:
platforms_seen[p]["cited"] += 1
for p, data in platforms_seen.items():
rate = (data["cited"] / data["total"] * 100) if data["total"] > 0 else 0.0
platform_scores[p] = round(rate, 2)
# Compute overall_score with ScoringService
scoring_service = ScoringService()
sentiment_counts = {"positive": 0, "neutral": 0, "negative": 0}
for c in cited:
sentiment = c.sentiment or "neutral"
if sentiment in sentiment_counts:
sentiment_counts[sentiment] += 1
from app.schemas.scoring import CitationResult
citation_results = [
CitationResult(
cited=c.cited,
position=c.citation_position,
citation_text=c.citation_text,
sentiment=c.sentiment or "neutral",
confidence=c.confidence or 0.0,
)
for c in cited
]
positions = [c.citation_position for c in cited if c.cited]
# Fetch competitor information
competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id)
competitor_result = await db.execute(competitor_stmt)
competitors = list(competitor_result.scalars().all())
competitor_names = [c.name for c in competitors]
competitor_mentions: dict[str, int] = {}
for comp_name in competitor_names:
count = sum(
1 for c in citations
if c.cited and c.competitor_brands and comp_name in c.competitor_brands
)
if count > 0:
competitor_mentions[comp_name] = count
v2_result = scoring_service.calculate_v2(
mentioned_count=len(cited),
total_queries=total,
positions=positions,
sentiment_counts=sentiment_counts,
citations=citation_results,
brand_mentions=len(cited),
competitor_mentions=competitor_mentions,
)
# Build strengths/weaknesses
strengths = []
weaknesses = []
if total == 0:
strengths.append("品牌已创建")
weaknesses.append("尚无引用数据")
else:
mention_rate = len(cited) / total * 100 if total > 0 else 0
if mention_rate >= 50:
strengths.append(f"提及率较高 ({round(mention_rate, 1)}%)")
else:
weaknesses.append(f"提及率偏低 ({round(mention_rate, 1)}%)")
for d in filtered_dimensions:
pct = d.get("percentage", 0)
if pct >= 60:
strengths.append(f"{d.get('name', '')}表现良好 ({round(pct, 1)}%)")
elif pct > 0:
weaknesses.append(f"{d.get('name', '')}有待提升 ({round(pct, 1)}%)")
for p, score in platform_scores.items():
if score >= 60:
strengths.append(f"{p} 平台表现良好 ({score}%)")
elif score > 0:
weaknesses.append(f"{p} 平台覆盖率不足 ({score}%)")
if not strengths:
strengths.append("已有初步诊断数据")
if not weaknesses:
weaknesses.append("暂无明显短板")
if sentiment_counts["positive"] > sentiment_counts["negative"]:
strengths.append("情感倾向正面")
elif sentiment_counts["negative"] > sentiment_counts["positive"]:
weaknesses.append("情感倾向偏负面")
competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id)
competitor_result = await db.execute(competitor_stmt)
competitors = list(competitor_result.scalars().all())
if not strengths:
strengths.append("已有初步引用数据")
if not weaknesses:
weaknesses.append("暂无明显短板")
# Competitor scores
competitor_scores = []
for comp_name, mentions in competitor_mentions.items():
comp_score = round(mentions / total * 100, 2) if total > 0 else 0.0
for comp in competitors:
competitor_scores.append({
"name": comp_name,
"score": comp_score,
"name": comp.name,
"score": 0.0,
"is_leading": False,
})
dimension_items = [
HealthDimensionItem(
name=d.get("name", ""),
score=round(d.get("score", 0.0), 2),
max_score=d.get("max_score", 0.0),
percentage=round(d.get("percentage", 0.0), 2),
status=d.get("status", "fail"),
)
for d in filtered_dimensions
]
recommendation_items = [
HealthRecommendationItem(
priority=r.get("priority", "P0"),
dimension=r.get("dimension", ""),
title=r.get("title", ""),
description=r.get("description", ""),
)
for r in filtered_recommendations
]
return HealthReportResponse(
brand_id=str(brand.id),
brand_name=brand.name,
overall_score=round(v2_result.overall_score, 2),
overall_score=overall_score,
health_level=health_level,
health_level_label=get_health_level_label(health_level),
platform_scores=platform_scores,
strengths=strengths,
weaknesses=weaknesses,
competitor_scores=competitor_scores,
dimensions=dimension_items,
recommendations=recommendation_items,
is_full_report=is_paid,
)
@ -394,21 +446,21 @@ async def get_onboarding_action_suggestions(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
Generate action suggestions from the health report (rule-engine based, no LLM required).
"""
# Fetch the health report data first (reuses the existing logic)
report = await get_onboarding_health_report(brand_id, current_user, db)
user_plan = getattr(current_user, "plan", None) or "free"
is_paid = user_plan != "free"
suggestions = []
# Rule engine: generate suggestions from the scores and platform data
if report.overall_score < 20:
suggestions.append(ActionSuggestion(
title="提升 AI 平台覆盖率",
description=f"当前综合评分仅 {report.overall_score}品牌在AI搜索中几乎未被提及。建议增加查询词覆盖面让AI平台更频繁地引用品牌。",
priority="high",
action_type="coverage",
is_paid_action=False,
action_button_text="设置查询词",
))
if report.overall_score < 50:
@ -417,9 +469,10 @@ async def get_onboarding_action_suggestions(
description="品牌在关键查询词下的提及率偏低,建议调整查询关键词策略,聚焦行业核心术语。",
priority="high",
action_type="keyword",
is_paid_action=False,
action_button_text="优化关键词",
))
# Per-platform suggestions
for platform, score in report.platform_scores.items():
if score < 30:
suggestions.append(ActionSuggestion(
@ -427,28 +480,32 @@ async def get_onboarding_action_suggestions(
description=f"品牌在 {platform} 平台的引用率仅为 {score}%,需要针对性优化该平台的内容策略。",
priority="medium",
action_type="platform",
is_paid_action=False,
action_button_text=f"优化{platform}",
))
# Sentiment suggestions
if "情感倾向偏负面" in report.weaknesses:
suggestions.append(ActionSuggestion(
title="改善品牌情感倾向",
description="AI平台对品牌的情感评价偏负面建议发布正面品牌内容、优化品牌描述以改善情感得分。",
priority="medium",
action_type="sentiment",
))
for dim in report.dimensions:
if dim.percentage < 40 and dim.name not in _FREE_TIER_DIMENSIONS:
suggestions.append(ActionSuggestion(
title=f"提升{dim.name}得分",
description=f"{dim.name}当前得分仅 {round(dim.percentage, 1)}%,解锁详细诊断和优化方案。",
priority="medium",
action_type="dimension",
is_paid_action=True,
action_button_text="升级解锁",
))
# Competitor comparison suggestions
for comp in report.competitor_scores:
if comp["score"] > report.overall_score:
if comp.get("score", 0) > report.overall_score:
suggestions.append(ActionSuggestion(
title=f"应对竞品 {comp['name']} 威胁",
description=f"竞品 {comp['name']} 评分 ({comp['score']}) 高于本品牌 ({report.overall_score}),建议分析竞品优势领域并制定差异化策略。",
priority="high",
action_type="competitive",
is_paid_action=True,
action_button_text="查看竞品分析",
))
# If there is no citation data, fall back to basic suggestions
if report.overall_score == 0:
suggestions = [
ActionSuggestion(
@ -456,28 +513,45 @@ async def get_onboarding_action_suggestions(
description="品牌尚无查询数据,建议首先设置与品牌最相关的核心查询词,让系统开始数据采集。",
priority="high",
action_type="keyword",
is_paid_action=False,
action_button_text="设置查询词",
),
ActionSuggestion(
title="添加竞品对比",
description="添加主要竞品以便进行对比分析,了解品牌在市场中的定位。",
priority="medium",
action_type="coverage",
is_paid_action=False,
action_button_text="添加竞品",
),
ActionSuggestion(
-title="完善品牌信息",
-description="补充品牌别名、网站、行业等详细信息有助于提升AI平台识别率",
+title="解锁完整6维度诊断",
+description="免费版仅展示3个核心维度升级后可查看完整6维度诊断报告和深度优化方案",
priority="medium",
-action_type="brand_info",
+action_type="upgrade",
is_paid_action=True,
action_button_text="升级Pro",
),
]
# Ensure there is at least one suggestion
if not suggestions:
suggestions.append(ActionSuggestion(
title="持续监测品牌表现",
description="品牌表现良好,建议持续监测并保持当前策略。",
priority="low",
action_type="monitor",
is_paid_action=False,
action_button_text="查看Dashboard",
))
if not is_paid:
suggestions.append(ActionSuggestion(
title="升级Pro解锁完整诊断",
description="免费版仅展示3个核心维度和P0建议。升级Pro可获取完整6维度诊断、深度竞品分析和AI优化方案。",
priority="low",
action_type="upgrade",
is_paid_action=True,
action_button_text="升级Pro",
))
return ActionSuggestionsResponse(suggestions=suggestions)
@ -496,8 +570,7 @@ async def complete_onboarding(
The User model currently has no dedicated onboarding_completed field;
creating a brand is treated as completing onboarding.
"""
-# Verify brand ownership
-stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
+stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
@ -507,8 +580,6 @@ async def complete_onboarding(
detail="品牌不存在",
)
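-# Creating the brand already marks onboarding as complete; no extra field update is needed
-# If a dedicated field is needed later, add user.onboarding_completed via an alembic migration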
logger.info(f"User {current_user.id} completed onboarding with brand {brand_id}")
return OnboardingCompleteResponse(success=True)
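
Review note: the threshold rules in `get_onboarding_action_suggestions` can be condensed into a small pure function, which makes the ordering of the rules easier to check. This is a simplified sketch and not the committed code. `SketchSuggestion` and `build_suggestions` are hypothetical names, the zero-score, dimension, and competitor branches are left out, and it assumes the final upgrade prompt is appended independently of the "at least one suggestion" fallback (the flattened diff does not show the indentation).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchSuggestion:
    """Hypothetical stand-in for ActionSuggestion with only the fields used here."""
    action_type: str
    priority: str
    is_paid_action: bool


def build_suggestions(
    overall_score: float,
    platform_scores: dict[str, float],
    is_paid: bool,
) -> list[SketchSuggestion]:
    suggestions: list[SketchSuggestion] = []
    # Overall-score thresholds taken from the diff: < 20 and < 50.
    if overall_score < 20:
        suggestions.append(SketchSuggestion("coverage", "high", False))
    if overall_score < 50:
        suggestions.append(SketchSuggestion("keyword", "high", False))
    # One suggestion per platform scoring below 30.
    for _platform, score in platform_scores.items():
        if score < 30:
            suggestions.append(SketchSuggestion("platform", "medium", False))
    # Fallback so the list is never empty.
    if not suggestions:
        suggestions.append(SketchSuggestion("monitor", "low", False))
    # Free users always get an upgrade prompt at the end (assumed placement).
    if not is_paid:
        suggestions.append(SketchSuggestion("upgrade", "low", True))
    return suggestions
```

Keeping the rules free of database and request objects like this would also let them be unit-tested without fixtures.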

backend/app/api/payments.py Normal file

@ -0,0 +1,274 @@
import logging
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.payment_order import PaymentOrder as PaymentOrderModel
from app.models.user import User
from app.services.payment import get_payment_gateway
from app.services.subscription import PLANS, subscribe
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/payments", tags=["支付"])
class CreateOrderRequest(BaseModel):
plan: str
payment_provider: str = "wechat"
class CreateOrderResponse(BaseModel):
order_id: str
pay_url: str
amount: float
currency: str = "CNY"
status: str = "pending"
class OrderStatusResponse(BaseModel):
order_id: str
status: str
plan: str
amount: float
payment_provider: str
payment_id: str | None = None
created_at: str | None = None
paid_at: str | None = None
class RefundRequest(BaseModel):
reason: str = ""
@router.post("/orders", response_model=CreateOrderResponse, status_code=status.HTTP_201_CREATED)
async def create_payment_order(
request: CreateOrderRequest,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
plan_data = PLANS.get(request.plan)
if plan_data is None:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"无效的套餐: {request.plan}",
)
if plan_data["price"] == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="免费套餐无需支付",
)
order_id = uuid.uuid4()
amount = plan_data["price"]
gateway = get_payment_gateway(request.payment_provider)
payment_order = await gateway.create_order(
order_id=str(order_id),
amount=amount,
description=f"GEO平台{plan_data['name']}订阅",
user_id=current_user.id,
plan=request.plan,
)
db_order = PaymentOrderModel(
id=order_id,
user_id=current_user.id,
plan=request.plan,
amount=amount,
payment_provider=request.payment_provider,
status="pending",
pay_url=payment_order.pay_url,
)
db.add(db_order)
await db.commit()
return CreateOrderResponse(
order_id=str(order_id),
pay_url=payment_order.pay_url,
amount=amount,
status="pending",
)
@router.post("/callback/wechat")
async def wechat_pay_callback(request: Request):
body = await request.form()
request_data = dict(body)
gateway = get_payment_gateway("wechat")
callback = await gateway.verify_callback(request_data)
return await _handle_payment_callback(request_data, callback, "wechat")
@router.post("/callback/alipay")
async def alipay_callback(request: Request):
body = await request.form()
request_data = dict(body)
gateway = get_payment_gateway("alipay")
callback = await gateway.verify_callback(request_data)
return await _handle_payment_callback(request_data, callback, "alipay")
async def _handle_payment_callback(request_data: dict, callback, provider: str):
from app.database import AsyncSessionLocal
async with AsyncSessionLocal() as db:
try:
result = await _process_callback(db, callback, provider)
await db.commit()
return result
except Exception as e:
logger.error(f"[PaymentCallback] 处理回调异常: {e}", exc_info=True)
await db.rollback()
if provider == "wechat":
return _wechat_fail_response()
return "fail"
async def _process_callback(db: AsyncSession, callback, provider: str):
stmt = select(PaymentOrderModel).where(
PaymentOrderModel.id == uuid.UUID(callback.order_id)
)
result = await db.execute(stmt)
order = result.scalar_one_or_none()
if order is None:
logger.warning(f"[PaymentCallback] 订单不存在: order_id={callback.order_id}")
if provider == "wechat":
return _wechat_fail_response()
return "fail"
if callback.status == "success":
order.status = "paid"
order.payment_id = callback.payment_id
order.callback_data = callback.raw_data
order.paid_at = datetime.now(timezone.utc)
await subscribe(db, order.user_id, order.plan)
logger.info(
f"[PaymentCallback] 支付成功: order_id={callback.order_id}, "
f"plan={order.plan}, provider={provider}"
)
else:
order.status = "failed"
order.callback_data = callback.raw_data
if provider == "wechat":
return _wechat_success_response()
return "success"
def _wechat_success_response():
from fastapi.responses import Response
return Response(content="<xml><return_code><![CDATA[SUCCESS]]></return_code></xml>", media_type="application/xml")
def _wechat_fail_response():
from fastapi.responses import Response
return Response(content="<xml><return_code><![CDATA[FAIL]]></return_code></xml>", media_type="application/xml")
@router.get("/orders/{order_id}", response_model=OrderStatusResponse)
async def query_order_status(
order_id: str,
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
try:
oid = uuid.UUID(order_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="无效的订单ID",
)
stmt = select(PaymentOrderModel).where(PaymentOrderModel.id == oid)
result = await db.execute(stmt)
order = result.scalar_one_or_none()
if order is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="订单不存在",
)
if order.user_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="无权查看此订单",
)
return OrderStatusResponse(
order_id=str(order.id),
status=order.status,
plan=order.plan,
amount=order.amount,
payment_provider=order.payment_provider,
payment_id=order.payment_id,
created_at=order.created_at.isoformat() if order.created_at else None,
paid_at=order.paid_at.isoformat() if order.paid_at else None,
)
@router.post("/refund/{order_id}")
async def refund_order(
order_id: str,
body: RefundRequest = RefundRequest(),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
user_plan = getattr(current_user, "plan", "free") or "free"
if user_plan != "enterprise":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="仅企业管理员可执行退款操作",
)
try:
oid = uuid.UUID(order_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="无效的订单ID",
)
stmt = select(PaymentOrderModel).where(PaymentOrderModel.id == oid)
result = await db.execute(stmt)
order = result.scalar_one_or_none()
if order is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="订单不存在",
)
if order.status != "paid":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="仅已支付订单可退款",
)
gateway = get_payment_gateway(order.payment_provider)
success = await gateway.refund(order_id, order.amount, body.reason)
if success:
order.status = "refunded"
await db.commit()
return {"message": "退款成功", "order_id": order_id}
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="退款失败",
)
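
Review note: `_process_callback` marks the order as paid and calls `subscribe` on every successful notification, so a redelivered notification would run the subscription step again. One possible guard is to decide the state transition in a pure function first and only subscribe when the order actually moves to `paid`. The sketch below is an assumption about how that could look; `Transition` and `resolve_callback_transition` are hypothetical names that do not appear in the commit.

```python
from dataclasses import dataclass

# Statuses after which a callback should no longer change the order (assumed).
TERMINAL_STATUSES = {"paid", "refunded"}


@dataclass(frozen=True)
class Transition:
    new_status: str
    should_subscribe: bool


def resolve_callback_transition(current_status: str, callback_status: str) -> Transition:
    """Decide what a payment callback should do to an order in `current_status`."""
    if current_status in TERMINAL_STATUSES:
        # Duplicate or late notification: acknowledge it but change nothing.
        return Transition(current_status, False)
    if callback_status == "success":
        return Transition("paid", True)
    return Transition("failed", False)
```

With a helper like this, the handler would still return the provider's success acknowledgement for duplicates, which is what stops further redelivery.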


@ -1,11 +1,5 @@
"""
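Platform health check API - verifies the status of each AI platform adapter
Endpoint: GET /api/platforms/health
Returns: configuration status and health information for each platform adapter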
"""
import logging
import os
from typing import Annotated
from fastapi import APIRouter, Depends
@ -18,7 +12,6 @@ router = APIRouter(prefix="/platforms", tags=["platforms"])
class PlatformHealthStatus:
"""Platform health status"""
def __init__(
self,
@ -37,22 +30,38 @@ class PlatformHealthStatus:
self.message = message
+_PLATFORM_URLS = {
+"kimi": "https://kimi.moonshot.cn",
+"wenxin": "https://yiyan.baidu.com",
+"doubao": "https://www.doubao.com/",
+}
+def _check_api_key_health(
+platform_name: str,
+env_key_name: str,
+url: str,
+) -> PlatformHealthStatus:
+api_key = os.getenv(env_key_name, "")
+api_key_set = bool(api_key and api_key.strip())
+configured = api_key_set
+return PlatformHealthStatus(
+name=platform_name,
+url=url,
+configured=configured,
+api_key_set=api_key_set,
+status="configured" if configured else "not_configured",
+message="API Key已配置" if configured else "API Key未配置",
+)
def check_kimi_health() -> PlatformHealthStatus:
"""Check Kimi platform health"""
try:
-from app.workers.platforms.kimi import KimiAdapter
-adapter = KimiAdapter()
-api_key_set = bool(adapter.api_key and adapter.api_key.strip())
-configured = adapter.is_configured
-return PlatformHealthStatus(
-name="kimi",
-url=adapter.platform_url,
-configured=configured,
-api_key_set=api_key_set,
-status="configured" if configured else "not_configured",
-message="API Key已配置" if configured else "API Key未配置",
+return _check_api_key_health(
+platform_name="kimi",
+env_key_name="MOONSHOT_API_KEY",
+url=_PLATFORM_URLS["kimi"],
)
except Exception as e:
logger.error(f"Kimi健康检查失败: {e}")
@ -65,18 +74,16 @@ def check_kimi_health() -> PlatformHealthStatus:
def check_wenxin_health() -> PlatformHealthStatus:
"""Check Wenxin platform health"""
try:
-from app.workers.platforms.wenxin import WenxinAdapter
-adapter = WenxinAdapter()
-api_key_set = bool(adapter.api_key and adapter.api_key.strip())
-secret_key_set = bool(adapter.secret_key and adapter.secret_key.strip())
-configured = adapter.is_configured
+api_key = os.getenv("BAIDU_QIANFAN_API_KEY", "")
+secret_key = os.getenv("BAIDU_QIANFAN_SECRET_KEY", "")
+api_key_set = bool(api_key and api_key.strip())
+secret_key_set = bool(secret_key and secret_key.strip())
+configured = api_key_set and secret_key_set
return PlatformHealthStatus(
name="wenxin",
-url=adapter.platform_url,
+url=_PLATFORM_URLS["wenxin"],
configured=configured,
api_key_set=api_key_set,
status="configured" if configured else "not_configured",
@ -93,21 +100,11 @@ def check_wenxin_health() -> PlatformHealthStatus:
def check_doubao_health() -> PlatformHealthStatus:
"""Check Doubao platform health"""
try:
-from app.workers.platforms.doubao import DoubaoAdapter
-adapter = DoubaoAdapter()
-api_key_set = bool(adapter.api_key and adapter.api_key.strip())
-configured = adapter.is_configured
-return PlatformHealthStatus(
-name="doubao",
-url=adapter.platform_url,
-configured=configured,
-api_key_set=api_key_set,
-status="configured" if configured else "not_configured",
-message="API Key已配置" if configured else "API Key未配置",
+return _check_api_key_health(
+platform_name="doubao",
+env_key_name="DOUBAO_API_KEY",
+url=_PLATFORM_URLS["doubao"],
)
except Exception as e:
logger.error(f"豆包健康检查失败: {e}")
@ -120,7 +117,6 @@ def check_doubao_health() -> PlatformHealthStatus:
def check_all_platforms() -> dict:
"""Check the health of all platforms"""
platforms = [
check_kimi_health(),
check_wenxin_health(),
@ -136,29 +132,12 @@ def check_all_platforms() -> dict:
@router.get("/health")
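async def get_platform_health():
-"""
-Get the health status of all AI platform adapters
-For each platform, returns:
-- name: platform name
-- configured: whether the platform is configured
-- url: platform URL
-- api_key_set: whether the API key is set
-- status: health status (configured / not_configured / error)
-- message: status message
-"""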
health_info = check_all_platforms()
return health_info
@router.get("/health/{platform_name}")
async def get_platform_health_by_name(platform_name: str):
-"""
-Get the health status of the given platform adapter
-Args:
-platform_name: platform name (kimi / wenxin / doubao)
-"""
if platform_name == "kimi":
result = vars(check_kimi_health())
elif platform_name == "wenxin":


@ -9,7 +9,7 @@ from app.database import get_db
from app.models.user import User
from app.schemas.citation import RunNowResponse
from app.schemas.query import QueryCreate, QueryListResponse, QueryResponse, QueryUpdate
-from app.services.citation import trigger_query_now
+from app.services.citation.citation import trigger_query_now
from app.services.query import create_query, delete_query, get_queries, get_query, update_query
router = APIRouter()


@ -1,23 +1,90 @@
import logging
import uuid
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.responses import Response
from app.api.deps import get_current_user
from app.database import get_db
from app.models.brand import Brand
from app.models.user import User
-from app.services.citation import export_citations_csv, export_citations_pdf
+from app.schemas.scoring import CitationResult
+from app.services.citation.citation import export_citations_csv, export_citations_pdf
+from app.services.scoring.scoring_service import ScoringService, ScoringResultV2
logger = logging.getLogger(__name__)
router = APIRouter()
async def _compute_v2_scores(
db: AsyncSession,
user_id: uuid.UUID,
brand_id: uuid.UUID,
) -> ScoringResultV2 | None:
try:
from app.api.scoring import (
_get_citations_for_brand,
_analyze_sentiments_for_citations,
)
total_queries, brand_citations, _, competitor_mentions = (
await _get_citations_for_brand(db, user_id, brand_id)
)
if total_queries == 0:
return None
brand_stmt = select(Brand).where(
Brand.id == brand_id, Brand.user_id == user_id
)
brand_result = await db.execute(brand_stmt)
brand = brand_result.scalar_one_or_none()
if not brand:
return None
sentiment_counts = await _analyze_sentiments_for_citations(
brand_name=brand.name,
brand_citations=brand_citations,
)
citation_results = [
CitationResult(
cited=c.cited,
position=c.citation_position,
citation_text=c.citation_text,
sentiment=c.sentiment or "neutral",
confidence=c.confidence or 0.0,
)
for c in brand_citations
]
positions = [c.citation_position for c in brand_citations if c.cited]
scoring_service = ScoringService()
return scoring_service.calculate_v2(
mentioned_count=len(brand_citations),
total_queries=total_queries,
positions=positions,
sentiment_counts=sentiment_counts,
citations=citation_results,
brand_mentions=len(brand_citations),
competitor_mentions=competitor_mentions,
)
except Exception:
logger.warning("V2 scoring failed for brand %s", brand_id, exc_info=True)
return None
@router.get("/export/csv")
async def export_report(
query_id: uuid.UUID = Query(...),
brand_id: Optional[uuid.UUID] = Query(None),
format: str = Query("csv"),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
@ -29,7 +96,13 @@ async def export_report(
)
try:
-csv_content = await export_citations_csv(db, current_user.id, query_id)
+v2_result = None
+if brand_id is not None:
+v2_result = await _compute_v2_scores(db, current_user.id, brand_id)
+csv_content = await export_citations_csv(
+db, current_user.id, query_id, v2_result=v2_result
+)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -51,11 +124,18 @@ async def export_report(
@router.get("/export/pdf")
async def export_pdf(
query_id: Optional[uuid.UUID] = None,
+brand_id: Optional[uuid.UUID] = Query(None),
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
try:
-pdf_bytes = await export_citations_pdf(db, current_user.id, query_id)
+v2_result = None
+if brand_id is not None:
+v2_result = await _compute_v2_scores(db, current_user.id, brand_id)
+pdf_bytes = await export_citations_pdf(
+db, current_user.id, query_id, v2_result=v2_result
+)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,


@ -0,0 +1,248 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.schema_suggestion import SchemaSuggestion
from app.schemas.schema_suggestion import (
SchemaAdviseRequest,
SchemaSuggestionResponse,
SchemaSuggestionList,
SchemaValidationResult,
SchemaStatusUpdateRequest,
)
from app.services.schema.schema_advisor_service import SchemaAdvisorService
from app.services.scoring.scoring_service import ScoringService
router = APIRouter()
async def _get_brand_with_access(
brand_id: uuid.UUID,
db: AsyncSession,
current_user: User,
) -> Brand:
stmt = select(Brand).where(
Brand.id == brand_id,
Brand.user_id == current_user.id,
)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
if not brand:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="品牌不存在",
)
return brand
async def _get_brand_diagnosis_data(
db: AsyncSession,
user_id: uuid.UUID,
brand: Brand,
) -> dict:
from app.models.query import Query as QueryModel
from app.models.citation_record import CitationRecord
from app.models.competitor import Competitor
from app.schemas.scoring import CitationResult
from app.services.analysis.sentiment_service import get_sentiment_service
queries_stmt = select(QueryModel).where(
QueryModel.user_id == user_id,
QueryModel.target_brand == brand.name,
)
queries_result = await db.execute(queries_stmt)
queries = list(queries_result.scalars().all())
if not queries:
scoring_service = ScoringService()
empty_result = scoring_service.calculate_v2(
mentioned_count=0,
total_queries=0,
positions=[],
sentiment_counts={"positive": 0, "neutral": 0, "negative": 0},
citations=[],
brand_mentions=0,
competitor_mentions={},
)
return empty_result.to_dict()
query_ids = [q.id for q in queries]
citations_stmt = select(CitationRecord).where(
CitationRecord.query_id.in_(query_ids),
)
citations_result = await db.execute(citations_stmt)
all_citations = list(citations_result.scalars().all())
total_queries = len(all_citations)
brand_citations = [c for c in all_citations if c.cited]
competitor_stmt = select(Competitor).where(Competitor.brand_id == brand.id)
competitor_result = await db.execute(competitor_stmt)
competitors = list(competitor_result.scalars().all())
competitor_names = [c.name for c in competitors]
competitor_mentions: dict[str, int] = {}
for comp_name in competitor_names:
count = sum(
1 for c in all_citations
if c.cited and c.competitor_brands
and comp_name in c.competitor_brands
)
if count > 0:
competitor_mentions[comp_name] = count
sentiment_service = get_sentiment_service()
sentiment_counts = {"positive": 0, "neutral": 0, "negative": 0}
for citation in brand_citations:
if citation.sentiment and citation.sentiment in ("positive", "neutral", "negative"):
sentiment_counts[citation.sentiment] += 1
else:
content = citation.raw_response or citation.citation_text or ""
if content.strip():
try:
result = await sentiment_service.analyze(
brand_name=brand.name,
content=content,
)
sentiment_counts[result.sentiment] += 1
except Exception:
sentiment_counts["neutral"] += 1
else:
sentiment_counts["neutral"] += 1
citation_results = [
CitationResult(
cited=c.cited,
position=c.citation_position,
citation_text=c.citation_text,
sentiment="neutral",
confidence=c.confidence or 0.0,
)
for c in brand_citations
]
positions = [c.citation_position for c in brand_citations if c.cited]
scoring_service = ScoringService()
v2_result = scoring_service.calculate_v2(
mentioned_count=len(brand_citations),
total_queries=total_queries,
positions=positions,
sentiment_counts=sentiment_counts,
citations=citation_results,
brand_mentions=len(brand_citations),
competitor_mentions=competitor_mentions,
)
return v2_result.to_dict()
@router.post("/advise", response_model=SchemaSuggestionList)
async def generate_schema_advise(
request: SchemaAdviseRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
brand = await _get_brand_with_access(request.brand_id, db, current_user)
diagnosis_data = await _get_brand_diagnosis_data(db, current_user.id, brand)
brand_info = {
"name": brand.name,
"website": brand.website or "",
"industry": brand.industry or "",
}
service = SchemaAdvisorService()
suggestions = await service.generate_suggestions(
db=db,
brand_id=brand.id,
diagnosis_data=diagnosis_data,
brand_info=brand_info,
target_url=request.target_url,
focus_dimensions=request.focus_dimensions,
)
return SchemaSuggestionList(
suggestions=[SchemaSuggestionResponse.model_validate(s) for s in suggestions],
total=len(suggestions),
)
@router.get("/brand/{brand_id}", response_model=SchemaSuggestionList)
async def get_brand_schema_suggestions(
brand_id: uuid.UUID,
status_filter: str | None = Query(None, alias="status", description="按状态筛选"),
schema_type: str | None = Query(None, description="按Schema类型筛选"),
skip: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await _get_brand_with_access(brand_id, db, current_user)
service = SchemaAdvisorService()
suggestions, total = await service.get_suggestions(
db=db,
brand_id=brand_id,
status_filter=status_filter,
schema_type=schema_type,
skip=skip,
limit=limit,
)
return SchemaSuggestionList(
suggestions=[SchemaSuggestionResponse.model_validate(s) for s in suggestions],
total=total,
)
@router.get("/{suggestion_id}", response_model=SchemaSuggestionResponse)
async def get_schema_suggestion_detail(
suggestion_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
service = SchemaAdvisorService()
suggestion = await service.get_suggestion_by_id(db, suggestion_id)
if not suggestion:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="建议不存在",
)
brand = await _get_brand_with_access(suggestion.brand_id, db, current_user)
return SchemaSuggestionResponse.model_validate(suggestion)
@router.put("/{suggestion_id}/status", response_model=SchemaSuggestionResponse)
async def update_schema_suggestion_status(
suggestion_id: uuid.UUID,
status_update: SchemaStatusUpdateRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
valid_statuses = {"pending", "applied", "dismissed"}
if status_update.status not in valid_statuses:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"无效的状态值,支持: {', '.join(sorted(valid_statuses))}",
)
service = SchemaAdvisorService()
suggestion = await service.get_suggestion_by_id(db, suggestion_id)
if not suggestion:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="建议不存在",
)
await _get_brand_with_access(suggestion.brand_id, db, current_user)
updated = await service.update_status(db, suggestion_id, status_update.status)
return SchemaSuggestionResponse.model_validate(updated)
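
Review note: the competitor-mention loop in `_get_brand_diagnosis_data` can be read as a pure counting step. The sketch below restates it over plain dictionaries so the rule is easy to verify in isolation. `count_competitor_mentions` is a hypothetical helper, and it assumes `competitor_brands` is a list of names; if the column stores a single string, `in` would do a substring match instead.

```python
from collections.abc import Iterable, Mapping, Sequence
from typing import Any


def count_competitor_mentions(
    citations: Iterable[Mapping[str, Any]],
    competitor_names: Sequence[str],
) -> dict[str, int]:
    """Count cited records that mention each competitor; omit zero counts."""
    rows = list(citations)  # materialize so the rows can be scanned once per name
    counts: dict[str, int] = {}
    for name in competitor_names:
        count = sum(
            1
            for row in rows
            if row.get("cited")
            and row.get("competitor_brands")
            and name in row["competitor_brands"]
        )
        if count > 0:
            counts[name] = count
    return counts
```

As in the diff, competitors with no mentions are left out of the result rather than reported as zero.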


@ -27,8 +27,8 @@ from app.schemas.scoring import (
DimensionCompareItem,
CitationResult,
)
-from app.services.scoring_service import ScoringService, get_health_level
-from app.services.sentiment_service import get_sentiment_service
+from app.services.scoring.scoring_service import ScoringService, get_health_level
+from app.services.analysis.sentiment_service import get_sentiment_service
logger = logging.getLogger(__name__)
@ -446,7 +446,7 @@ async def get_brand_score(
# Trigger alert detection asynchronously (does not affect the main flow)
try:
-from app.services.alert_engine import AlertEngine
+from app.services.alert.alert_engine import AlertEngine
alert_engine = AlertEngine(db)
# Get the set of platforms that already have mentions

backend/app/api/strategy.py Normal file

@ -0,0 +1,388 @@
import uuid
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.geo_plan import GeoPlan, GeoPlanAction
from app.schemas.geo_plan import (
GeoPlanGenerateRequest,
GeoPlanResponse,
GeoPlanListResponse,
GeoPlanActionResponse,
GeoPlanActionUpdateStatus,
GeoPlanActionExecuteResponse,
)
from app.services.scoring.brand_scoring_data_service import get_brand_scoring_data_service
from app.services.strategy.geo_plan_generator import generate_geo_plan
from app.services.content.content_generation_service import ContentGenerationService
router = APIRouter()
async def _get_brand_with_access(
brand_id: uuid.UUID,
db: AsyncSession,
current_user: User,
) -> Brand:
stmt = select(Brand).where(
Brand.id == brand_id,
Brand.user_id == current_user.id,
)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
if not brand:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="品牌不存在",
)
return brand
async def _get_brand_scoring_data(
db: AsyncSession,
user_id: uuid.UUID,
brand: Brand,
) -> tuple:
scoring_data_service = get_brand_scoring_data_service()
scoring_data = await scoring_data_service.get_brand_scoring_data(db, user_id, brand)
return (
scoring_data.v2_result,
scoring_data.competitor_data,
scoring_data.sentiment_counts,
scoring_data.platform_scores,
scoring_data.total_queries,
scoring_data.mentioned_count,
)
def _plan_to_response(plan: GeoPlan) -> GeoPlanResponse:
actions = [
GeoPlanActionResponse(
id=action.id,
plan_id=action.plan_id,
action_type=action.action_type,
title=action.title,
description=action.description,
reason=action.reason,
priority=action.priority,
status=action.status,
target_keyword=action.target_keyword,
target_platform=action.target_platform,
content_style=action.content_style,
estimated_impact=action.estimated_impact,
difficulty=action.difficulty,
execution_params=action.execution_params,
sort_order=action.sort_order,
completed_at=action.completed_at,
created_at=action.created_at,
)
for action in sorted(plan.actions, key=lambda a: a.sort_order)
]
return GeoPlanResponse(
id=plan.id,
brand_id=plan.brand_id,
title=plan.title,
status=plan.status,
diagnosis_score=plan.diagnosis_score,
target_score=plan.target_score,
estimated_weeks=plan.estimated_weeks,
plan_data=plan.plan_data,
source=plan.source,
actions=actions,
created_at=plan.created_at,
updated_at=plan.updated_at,
)
@router.post("/generate", response_model=GeoPlanResponse)
async def generate_geo_plan_endpoint(
request: GeoPlanGenerateRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
brand = await _get_brand_with_access(request.brand_id, db, current_user)
(
v2_result,
competitor_data,
sentiment_data,
platform_scores,
total_queries,
mentioned_count,
) = await _get_brand_scoring_data(db, current_user.id, brand)
target_score = request.target_score or 75
plan_data = await generate_geo_plan(
brand_name=brand.name,
scoring_result=v2_result,
target_score=target_score,
total_queries=total_queries,
platform_scores=platform_scores,
competitor_data=competitor_data,
)
from app.config import settings
source = "llm" if (settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY) else "rule"
organization_id = current_user.id
org_stmt = select(func.count()).select_from(
select(1).where(True).subquery()
)
db_plan = GeoPlan(
organization_id=organization_id,
brand_id=brand.id,
title=plan_data.title,
status="draft",
diagnosis_score=int(round(v2_result.overall_score)),
target_score=target_score,
estimated_weeks=plan_data.estimated_weeks,
plan_data={
"weekly_plan": plan_data.weekly_plan,
},
source=source,
created_by=current_user.id,
)
db.add(db_plan)
await db.flush()
for idx, action_item in enumerate(plan_data.actions):
db_action = GeoPlanAction(
plan_id=db_plan.id,
action_type=action_item.action_type,
title=action_item.title,
description=action_item.description,
reason=action_item.reason,
priority=action_item.priority,
status="pending",
target_keyword=action_item.target_keyword,
target_platform=action_item.target_platform,
content_style=action_item.content_style,
estimated_impact=action_item.estimated_impact,
difficulty=action_item.difficulty,
execution_params=action_item.execution_params,
sort_order=idx,
)
db.add(db_action)
await db.commit()
await db.refresh(db_plan)
stmt = (
select(GeoPlan)
.options(selectinload(GeoPlan.actions))
.where(GeoPlan.id == db_plan.id)
)
result = await db.execute(stmt)
db_plan = result.scalar_one()
action_stmt = select(GeoPlanAction).where(
GeoPlanAction.plan_id == db_plan.id
).order_by(GeoPlanAction.sort_order)
action_result = await db.execute(action_stmt)
db_plan.actions = list(action_result.scalars().all())
return _plan_to_response(db_plan)
@router.get("/brand/{brand_id}", response_model=GeoPlanListResponse)
async def get_brand_plans(
brand_id: uuid.UUID,
skip: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await _get_brand_with_access(brand_id, db, current_user)
count_stmt = select(func.count()).select_from(GeoPlan).where(
GeoPlan.brand_id == brand_id,
)
count_result = await db.execute(count_stmt)
total = count_result.scalar_one()
stmt = (
select(GeoPlan)
.where(GeoPlan.brand_id == brand_id)
.order_by(GeoPlan.created_at.desc())
.offset(skip)
.limit(limit)
)
result = await db.execute(stmt)
plans = list(result.scalars().all())
plan_responses = []
for plan in plans:
action_stmt = select(GeoPlanAction).where(
GeoPlanAction.plan_id == plan.id
).order_by(GeoPlanAction.sort_order)
action_result = await db.execute(action_stmt)
plan.actions = list(action_result.scalars().all())
plan_responses.append(_plan_to_response(plan))
return GeoPlanListResponse(plans=plan_responses, total=total)
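
Review note: `get_brand_plans` runs one `GeoPlanAction` query per plan. An alternative worth considering is to load the actions for the whole page of plans in a single query (for example by filtering on `GeoPlanAction.plan_id.in_(plan_ids)`) and group them in memory. The grouping step might look like the sketch below; `ActionRow` and `group_actions_by_plan` are hypothetical names used only for illustration.

```python
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass


@dataclass(frozen=True)
class ActionRow:
    """Hypothetical stand-in for a GeoPlanAction row."""
    plan_id: str
    sort_order: int
    title: str


def group_actions_by_plan(actions: Iterable[ActionRow]) -> dict[str, list[ActionRow]]:
    """Group actions by plan_id, each group ordered by sort_order."""
    grouped: dict[str, list[ActionRow]] = defaultdict(list)
    for action in actions:
        grouped[action.plan_id].append(action)
    for rows in grouped.values():
        rows.sort(key=lambda a: a.sort_order)
    return dict(grouped)
```

Plans without any actions would simply be absent from the mapping, so a caller would use `grouped.get(plan.id, [])` when building each response.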
@router.get("/{plan_id}", response_model=GeoPlanResponse)
async def get_plan_detail(
plan_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
stmt = select(GeoPlan).where(GeoPlan.id == plan_id)
result = await db.execute(stmt)
plan = result.scalar_one_or_none()
if not plan:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="方案不存在",
)
brand = await _get_brand_with_access(plan.brand_id, db, current_user)
action_stmt = select(GeoPlanAction).where(
GeoPlanAction.plan_id == plan.id
).order_by(GeoPlanAction.sort_order)
action_result = await db.execute(action_stmt)
plan.actions = list(action_result.scalars().all())
return _plan_to_response(plan)
@router.put("/actions/{action_id}/status", response_model=GeoPlanActionResponse)
async def update_action_status(
action_id: uuid.UUID,
status_update: GeoPlanActionUpdateStatus,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
valid_statuses = {"pending", "in_progress", "completed", "skipped"}
if status_update.status not in valid_statuses:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"无效的状态值,支持: {', '.join(sorted(valid_statuses))}",
)
stmt = select(GeoPlanAction).where(GeoPlanAction.id == action_id)
result = await db.execute(stmt)
action = result.scalar_one_or_none()
if not action:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="行动项不存在",
)
plan_stmt = select(GeoPlan).where(GeoPlan.id == action.plan_id)
plan_result = await db.execute(plan_stmt)
plan = plan_result.scalar_one()
await _get_brand_with_access(plan.brand_id, db, current_user)
action.status = status_update.status
if status_update.status == "completed":
action.completed_at = datetime.now()
await db.commit()
await db.refresh(action)
return GeoPlanActionResponse(
id=action.id,
plan_id=action.plan_id,
action_type=action.action_type,
title=action.title,
description=action.description,
reason=action.reason,
priority=action.priority,
status=action.status,
target_keyword=action.target_keyword,
target_platform=action.target_platform,
content_style=action.content_style,
estimated_impact=action.estimated_impact,
difficulty=action.difficulty,
execution_params=action.execution_params,
sort_order=action.sort_order,
completed_at=action.completed_at,
created_at=action.created_at,
)
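`update_action_status` validates the incoming value against a plain set of strings. A minimal sketch of the same check expressed as an `Enum`, which keeps the supported values in one place and lists them in definition order. `ActionStatus` and `parse_status` are hypothetical names, not part of this commit.

```python
from enum import Enum


class ActionStatus(str, Enum):
    """The four status values accepted by the endpoint above."""
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    SKIPPED = "skipped"


def parse_status(raw: str) -> ActionStatus:
    """Return the matching member or raise ValueError listing the supported values."""
    try:
        return ActionStatus(raw)
    except ValueError:
        supported = ", ".join(member.value for member in ActionStatus)
        raise ValueError(f"invalid status {raw!r}; supported: {supported}") from None
```

In a FastAPI handler the `ValueError` would still need to be mapped to a 400 response, as the existing code does with `HTTPException`.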
@router.post("/actions/{action_id}/execute", response_model=GeoPlanActionExecuteResponse)
async def execute_action(
action_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
stmt = select(GeoPlanAction).where(GeoPlanAction.id == action_id)
result = await db.execute(stmt)
action = result.scalar_one_or_none()
if not action:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="行动项不存在",
)
plan_stmt = select(GeoPlan).where(GeoPlan.id == action.plan_id)
plan_result = await db.execute(plan_stmt)
plan = plan_result.scalar_one()
brand = await _get_brand_with_access(plan.brand_id, db, current_user)
if action.action_type not in ("content_creation", "content_optimization"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"行动类型 '{action.action_type}' 不支持一键执行,仅支持 content_creation 和 content_optimization",
)
params = action.execution_params or {}
keyword = params.get("keyword", action.target_keyword or brand.name)
platform = params.get("platform", action.target_platform or "通用")
style = params.get("style", action.content_style or "专业严谨")
word_count = params.get("word_count", 2000)
knowledge_base_ids = params.get("knowledge_base_ids")
content_service = ContentGenerationService()
try:
gen_result = await content_service.generate_content(
keyword=keyword,
brand_name=brand.name,
platform=platform,
content_style=style,
word_count=word_count,
knowledge_base_ids=knowledge_base_ids,
db=db,
user_id=current_user.id,
org_id=str(plan.organization_id),
run_deai=True,
run_geo=True,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"内容生成失败: {str(e)}",
)
action.status = "completed"
action.completed_at = datetime.now()
await db.commit()
await db.refresh(action)
content_id = gen_result.get("content_id")
return GeoPlanActionExecuteResponse(
action_id=action.id,
content_id=content_id,
message="内容生成成功" if content_id else "内容生成完成(未持久化)",
)
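`execute_action` resolves each generation parameter from `execution_params` first, then from the action's own column, then from a default. A minimal sketch of that precedence as a pure function. Note one deliberate difference from the handler: `dict.get(key, default)` returns `None` when the key is present with a `None` value, whereas this sketch falls through to the next candidate, and it treats only `None` (not empty strings) as unset. `first_set` and `resolve_generation_params` are hypothetical helpers.

```python
from typing import Any, Optional


def first_set(*candidates: Any) -> Any:
    """Return the first candidate that is not None (None if all are)."""
    for value in candidates:
        if value is not None:
            return value
    return None


def resolve_generation_params(
    params: Optional[dict],
    target_keyword: Optional[str],
    target_platform: Optional[str],
    content_style: Optional[str],
    brand_name: str,
) -> dict:
    """Apply the execution_params -> action column -> default precedence."""
    params = params or {}
    return {
        "keyword": first_set(params.get("keyword"), target_keyword, brand_name),
        "platform": first_set(params.get("platform"), target_platform, "通用"),
        "style": first_set(params.get("style"), content_style, "专业严谨"),
        "word_count": first_set(params.get("word_count"), 2000),
    }


resolved = resolve_generation_params(
    {"keyword": None, "word_count": 1500},
    target_keyword="GEO optimization",
    target_platform=None,
    content_style=None,
    brand_name="Acme",
)
```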

@@ -22,9 +22,9 @@ from app.schemas.suggestion import (
SuggestionHistoryResponse,
)
from app.schemas.scoring import CitationResult
from app.services.scoring_service import ScoringService
from app.services.sentiment_service import get_sentiment_service
from app.services.optimization_advisor import (
from app.services.scoring.scoring_service import ScoringService
from app.services.analysis.sentiment_service import get_sentiment_service
from app.services.advisor.optimization_advisor import (
generate_suggestions,
build_context_from_scoring_result,
)
@@ -163,7 +163,7 @@ async def _get_brand_scoring_data(
)
# Compute per-platform scores
from app.api.dashboard import REQUIRED_PLATFORMS
from app.services.scoring.brand_scoring_data_service import REQUIRED_PLATFORMS
platform_scores: dict[str, float] = {}
for platform in REQUIRED_PLATFORMS:
platform_citations = [c for c in all_citations if c.platform == platform]

backend/app/api/trends.py (new file, 124 lines)

@@ -0,0 +1,124 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.trend_insight import TrendInsight
from app.schemas.trend_insight import (
TrendInsightRequest,
TrendInsightResponse,
TrendInsightList,
TrendSummary,
)
from app.services.trend.trend_analyzer_service import TrendAnalyzerService
router = APIRouter()
async def _get_brand_with_access(
brand_id: uuid.UUID,
db: AsyncSession,
current_user: User,
) -> Brand:
stmt = select(Brand).where(
Brand.id == brand_id,
Brand.user_id == current_user.id,
)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
if not brand:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="品牌不存在",
)
return brand
@router.post("/insight", response_model=TrendInsightResponse)
async def create_trend_insight(
request: TrendInsightRequest,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await _get_brand_with_access(request.brand_id, db, current_user)
service = TrendAnalyzerService(db)
result = await service.analyze_trends(
brand_id=request.brand_id,
days=request.period_days,
platforms=request.platforms,
keywords=request.keywords,
)
if result.get("status") == "insufficient_data":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=result.get("message", "数据不足"),
)
insight_id = uuid.UUID(result["insight_id"])
insight = await service.get_insight_by_id(insight_id)
if insight is None:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="洞察创建失败",
)
return insight
@router.get("/brand/{brand_id}", response_model=TrendInsightList)
async def list_trend_insights(
brand_id: uuid.UUID,
skip: int = Query(0, ge=0),
limit: int = Query(20, ge=1, le=100),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await _get_brand_with_access(brand_id, db, current_user)
service = TrendAnalyzerService(db)
items, total = await service.get_insights(
brand_id=brand_id,
skip=skip,
limit=limit,
)
return TrendInsightList(items=items, total=total)
@router.get("/brand/{brand_id}/summary", response_model=TrendSummary)
async def get_trend_summary(
brand_id: uuid.UUID,
period_days: int = Query(30, ge=7, le=365),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
await _get_brand_with_access(brand_id, db, current_user)
service = TrendAnalyzerService(db)
summary = await service.get_summary(
brand_id=brand_id,
days=period_days,
)
return TrendSummary(**summary)
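@router.get("/{insight_id}", response_model=TrendInsightResponse)
async def get_trend_insight_detail(
    insight_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    service = TrendAnalyzerService(db)
    insight = await service.get_insight_by_id(insight_id)
    if insight is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="洞察不存在",  # "Insight not found"
        )
    # Enforce brand ownership like the other endpoints in this router.
    # Assumes TrendInsight exposes a brand_id column (not shown in this diff).
    await _get_brand_with_access(insight.brand_id, db, current_user)
    return insight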

@@ -64,6 +64,40 @@ class Settings(BaseSettings):
# AI platform API rate limit (requests per minute)
API_RATE_LIMIT_RPM: int = 10
# Payment Gateway Configuration
WECHAT_PAY_MCH_ID: str = ""
WECHAT_PAY_API_KEY: str = ""
WECHAT_PAY_APP_ID: str = ""
WECHAT_PAY_CERT_PATH: str = ""
WECHAT_PAY_NOTIFY_URL: str = ""
ALIPAY_APP_ID: str = ""
ALIPAY_PRIVATE_KEY_PATH: str = ""
ALIPAY_PUBLIC_KEY_PATH: str = ""
ALIPAY_NOTIFY_URL: str = ""
PAYMENT_MODE: str = "mock"
ZHIHU_CLIENT_ID: str = ""
ZHIHU_CLIENT_SECRET: str = ""
ZHIHU_ACCESS_TOKEN: str = ""
TOUTIAO_APP_ID: str = ""
TOUTIAO_APP_SECRET: str = ""
TOUTIAO_ACCESS_TOKEN: str = ""
WECHAT_OFFICIAL_APP_ID: str = ""
WECHAT_OFFICIAL_APP_SECRET: str = ""
SMTP_HOST: str = ""
SMTP_PORT: int = 587
SMTP_USER: str = ""
SMTP_PASSWORD: str = ""
SMTP_FROM_EMAIL: str = "noreply@geo-platform.com"
SMTP_FROM_NAME: str = "GEO平台"
EMAIL_MODE: str = "mock"
SENDGRID_API_KEY: str = ""
ALIYUN_MAIL_ACCESS_KEY: str = ""
ALIYUN_MAIL_ACCESS_SECRET: str = ""
ALIYUN_MAIL_REGION: str = "cn-hangzhou"
DISTRIBUTION_MODE: str = "mock"
@field_validator("JWT_SECRET")
@classmethod
def validate_jwt_secret(cls, v: str) -> str:
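The settings above default `PAYMENT_MODE`, `EMAIL_MODE`, and `DISTRIBUTION_MODE` to `"mock"` and leave every credential empty. A minimal sketch of how such a mode flag could be normalized and checked before a real integration is enabled. The diff only shows the value `"mock"`; the `"live"` value, `normalize_mode`, and `missing_credentials` are assumptions made for illustration.

```python
from typing import Dict, List

# "mock" comes from the settings above; "live" is an assumed counterpart.
ALLOWED_MODES = frozenset({"mock", "live"})


def normalize_mode(name: str, value: str) -> str:
    """Lower-case and validate a *_MODE setting."""
    mode = value.strip().lower()
    if mode not in ALLOWED_MODES:
        raise ValueError(f"{name} must be one of {sorted(ALLOWED_MODES)}, got {value!r}")
    return mode


def missing_credentials(mode: str, credentials: Dict[str, str]) -> List[str]:
    """List the empty credential fields that a non-mock mode would need."""
    if mode == "mock":
        return []
    return sorted(key for key, val in credentials.items() if not val)
```

A check like this could run at startup so that switching away from mock mode with blank keys fails early rather than on the first payment or email call.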

@@ -43,14 +43,21 @@ from app.api.ai_engines import router as ai_engines_router
from app.api.detection import router as detection_router
from app.api.api_keys import router as api_keys_router
from app.api.usage import router as usage_router
from app.api.strategy import router as strategy_router
from app.api.competitor_analysis import router as competitor_analysis_router
from app.api.trends import router as trends_router
from app.api.schema_advisor import router as schema_advisor_router
from app.api.monitoring import router as monitoring_router
from app.api.health_score import router as health_score_router
from app.api.payments import router as payments_router
from app.api.attribution import router as attribution_router
from app.config import settings
from app.database import engine, Base
from app.schemas.common import ErrorResponse, ErrorCode
from app.middleware.rate_limit import RateLimitMiddleware
from app.middleware.logging_middleware import RequestLoggingMiddleware
from app.middleware.request_id import RequestIdMiddleware
from app.middleware.metrics import MetricsMiddleware
from app.monitoring.middleware import MonitoringMiddleware
from app.middleware.metrics import MetricsMiddleware, MonitoringMiddleware
from app.database import get_db
from app.workers.scheduler import query_scheduler
@@ -59,7 +66,13 @@ from app.workers.scheduler import query_scheduler
async def lifespan(app: FastAPI):
import app.models
import app.monitoring
import app.middleware.prometheus_metrics
from app.middleware.prometheus_metrics import SERVICE_INFO
import os
SERVICE_INFO.info({
"version": "1.0.0",
"environment": os.getenv("ENVIRONMENT", "development"),
})
async with engine.begin() as conn:
await conn.execute(text("SELECT 1"))
@@ -120,10 +133,14 @@ _allow_origins = [origin.strip() for origin in settings.CORS_ORIGINS.split(",")
if not _allow_origins:
_allow_origins = ["http://localhost:3000"]
import os
_is_dev = os.getenv("ENVIRONMENT", "development") == "development"
app.add_middleware(
CORSMiddleware,
allow_origins=_allow_origins,
allow_credentials=True,
allow_origins=_allow_origins if not _is_dev else ["*"],
allow_credentials=not _is_dev,
allow_methods=["*"],
allow_headers=["*"],
)
@@ -174,6 +191,14 @@ app.include_router(ai_engines_router, prefix="/api/v1/ai-engines", tags=["AI引
app.include_router(detection_router, prefix="/api/v1/detection", tags=["定时检测任务"])
app.include_router(api_keys_router, prefix="/api/v1/api-keys", tags=["API Key管理"])
app.include_router(usage_router, prefix="/api/v1/usage", tags=["用量追踪"])
app.include_router(strategy_router, prefix="/api/v1/strategy", tags=["GEO方案"])
app.include_router(competitor_analysis_router, prefix="/api/v1/competitor", tags=["竞品分析"])
app.include_router(schema_advisor_router, prefix="/api/v1/schema", tags=["Schema建议"])
app.include_router(trends_router, prefix="/api/v1/trends", tags=["趋势洞察"])
app.include_router(monitoring_router, prefix="/api/v1/monitoring", tags=["效果追踪"])
app.include_router(health_score_router, prefix="/api/v1/public", tags=["公开API"])
app.include_router(payments_router)
app.include_router(attribution_router, prefix="/api/v1/attribution", tags=["效果归因"])
@app.get("/health", tags=["可观测性"])

@@ -3,7 +3,7 @@ import time
from contextlib import asynccontextmanager
from typing import Optional
from app.monitoring.metrics import (
from app.middleware.prometheus_metrics import (
AGENT_EXECUTIONS_TOTAL,
AGENT_EXECUTION_DURATION_SECONDS,
AGENT_RUNNING_TASKS,

@@ -2,7 +2,7 @@
import time
from typing import Optional
from app.monitoring.metrics import (
from app.middleware.prometheus_metrics import (
LLM_REQUESTS_TOTAL,
LLM_REQUEST_DURATION_SECONDS,
LLM_TOKENS_TOTAL,

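@@ -1,9 +1,19 @@
"""Request metrics middleware: timing, slow-request alerts, response-time header."""
"""Request metrics middleware: timing, slow-request alerts, response-time header, Prometheus metrics collection.
Merged from the former middleware/metrics.py (MetricsMiddleware) and monitoring/middleware.py (MonitoringMiddleware).
"""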
import time
import logging
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from app.middleware.prometheus_metrics import (
API_REQUESTS_TOTAL,
API_REQUEST_DURATION_SECONDS,
API_REQUESTS_IN_PROGRESS,
)
logger = logging.getLogger("geo.metrics")
@@ -11,14 +21,14 @@ logger = logging.getLogger("geo.metrics")
SLOW_REQUEST_THRESHOLD = 1.0
# Path prefixes skipped by metrics collection (health checks and other high-frequency, low-value paths)
_SKIP_PATHS = {"/health", "/ready", "/docs", "/openapi.json", "/favicon.ico"}
_SKIP_PATHS = {"/health", "/ready", "/docs", "/openapi.json", "/favicon.ico", "/metrics"}
class MetricsMiddleware(BaseHTTPMiddleware):
"""Record the duration of every HTTP request and:
- write X-Response-Time to the response headers
- log slow requests above the threshold at WARNING level with structured fields
- reserve integration points for Sentry / Prometheus (marked with TODO comments)
- reserve an integration point for Sentry (marked with a TODO comment)
"""
async def dispatch(self, request: Request, call_next) -> Response:
@@ -51,10 +61,82 @@ class MetricsMiddleware(BaseHTTPMiddleware):
else:
logger.debug("Request completed", extra=log_extra)
# TODO: integrate Prometheus Counter/Histogram
# metrics_registry.http_request_duration.observe(duration, labels={...})
# TODO: integrate Sentry performance monitoring
# if sentry_sdk: sentry_sdk.set_measurement("response_time_ms", duration_ms)
return response
class MonitoringMiddleware(BaseHTTPMiddleware):
    """API monitoring middleware that collects Prometheus metrics.

    - Records total requests, latency distribution, and in-flight requests
    - Normalizes endpoint labels by replacing ID segments in the path
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Skip excluded paths
        if request.url.path in _SKIP_PATHS:
            return await call_next(request)
        # Derive the endpoint identifier used as a metric label
        endpoint = self._get_endpoint_label(request)
        # Increment the in-flight request gauge
        API_REQUESTS_IN_PROGRESS.labels(
            method=request.method,
            endpoint=endpoint,
        ).inc()
        # Record the start time
        start_time = time.perf_counter()
        # Default to 500 so the finally block always has a status, including
        # when call_next raises or the task is cancelled.
        status_code = 500
        try:
            # Run the request
            response = await call_next(request)
            status_code = response.status_code
        finally:
            # Compute the elapsed time
            duration = time.perf_counter() - start_time
            # Record the metrics
            API_REQUESTS_TOTAL.labels(
                method=request.method,
                endpoint=endpoint,
                status=str(status_code),
            ).inc()
            API_REQUEST_DURATION_SECONDS.labels(
                method=request.method,
                endpoint=endpoint,
            ).observe(duration)
            # Decrement the in-flight request gauge
            API_REQUESTS_IN_PROGRESS.labels(
                method=request.method,
                endpoint=endpoint,
            ).dec()
        return response

    def _get_endpoint_label(self, request: Request) -> str:
        """Derive the endpoint label."""
        path = request.url.path
        # Normalize the path (replace IDs and similar parameters)
        parts = path.strip("/").split("/")
        # Handle the common pattern /api/v1/resources/{id}
        if len(parts) >= 4 and parts[0] == "api":
            resource = parts[2]
            action = parts[3]
            # Map to a canonical label. Note: isdigit() only matches numeric
            # IDs; a UUID segment is returned verbatim inside the label.
            if action.isdigit():
                return f"{resource}_detail"
            return f"{resource}_{action}"
        return "other"
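`_get_endpoint_label` collapses the fourth path segment into `_detail` only when `isdigit()` is true, while the route parameters in this commit are typed `uuid.UUID`. A minimal sketch of a normalizer that also recognizes UUID segments, which would keep the Prometheus label set bounded. `endpoint_label` and `_looks_like_id` are hypothetical names; this is one possible approach, not the commit's implementation.

```python
import uuid


def _looks_like_id(segment: str) -> bool:
    """True for purely numeric segments and for UUID strings."""
    if segment.isdigit():
        return True
    try:
        uuid.UUID(segment)
    except ValueError:
        return False
    return True


def endpoint_label(path: str) -> str:
    """Map /api/v1/<resource>/<segment> to a low-cardinality label."""
    parts = path.strip("/").split("/")
    if len(parts) >= 4 and parts[0] == "api":
        resource, action = parts[2], parts[3]
        if _looks_like_id(action):
            return f"{resource}_detail"
        return f"{resource}_{action}"
    return "other"
```

An alternative with the same goal would be to label by the matched route template instead of parsing the raw path.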

@@ -0,0 +1,57 @@
from fastapi import Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.services.subscription import PLANS
class SubscriptionEnforcement:
    @staticmethod
    def require_plan(*allowed_plans: str):
        async def _check(current_user: User = Depends(get_current_user)):
            user_plan = getattr(current_user, "plan", "free") or "free"
            if user_plan not in allowed_plans:
                raise HTTPException(
                    status_code=403,
                    detail={
                        # "This feature requires the {plan} plan or above"
                        "message": f"此功能需要 {allowed_plans[0]} 及以上套餐",
                        "required_plan": allowed_plans[0],
                        "current_plan": user_plan,
                        "upgrade_url": "/api/v1/subscriptions/plans",
                    },
                )
            return current_user

        return _check

    @staticmethod
    def check_quota(resource: str):
        async def _check(
            current_user: User = Depends(get_current_user),
            db: AsyncSession = Depends(get_db),
        ):
            user_plan = getattr(current_user, "plan", "free") or "free"
            plan_config = PLANS.get(user_plan, PLANS["free"])
            if resource == "queries":
                limit = plan_config.get("max_queries", 3)
                # NOTE: max_queries reads like a limit rather than a usage
                # counter; confirm which User field tracks consumed queries.
                current_usage = getattr(current_user, "max_queries", limit) or limit
                # Treat -1 as unlimited, matching the other resources.
                remaining = limit if limit == -1 else max(0, limit - current_usage)
            elif resource == "brands":
                # Reports the plan limit; current usage is not subtracted here.
                limit = plan_config.get("max_brands", 1)
                remaining = limit if limit == -1 else max(0, limit)
            elif resource == "alerts":
                # Reports the plan limit; current usage is not subtracted here.
                limit = plan_config.get("max_alerts_per_month", 0)
                remaining = limit if limit == -1 else max(0, limit)
            else:
                remaining = 0
            return {
                "user_id": current_user.id,
                "plan": user_plan,
                "resource": resource,
                "remaining": remaining,
                "unlimited": remaining == -1,
            }

        return _check
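The 403 message produced by `require_plan` says the feature needs the named plan "or above" (及以上套餐), while the check itself is a membership test against `allowed_plans`. A minimal sketch of a rank-based comparison that would match the wording. Only the `"free"` plan name appears in this diff; the `"pro"` and `"enterprise"` tiers, their ordering, and `meets_minimum_plan` are assumptions for illustration.

```python
from typing import Optional

# Hypothetical tier ordering; only "free" is confirmed by the source.
PLAN_RANK = {"free": 0, "pro": 1, "enterprise": 2}


def meets_minimum_plan(user_plan: Optional[str], minimum_plan: str) -> bool:
    """True when the user's plan ranks at or above the required plan.

    Missing or unknown plans are treated as the lowest tier.
    """
    user_rank = PLAN_RANK.get(user_plan or "free", 0)
    return user_rank >= PLAN_RANK[minimum_plan]
```

With a ranking like this a route could declare a single minimum plan instead of enumerating every plan that is allowed.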

@@ -33,6 +33,14 @@ from app.models.alert import Alert
from app.models.alert_setting import AlertSetting
from app.models.detection_task import DetectionTask
from app.models.usage_record import UsageRecord
from app.models.geo_plan import GeoPlan, GeoPlanAction
from app.models.trend_insight import TrendInsight
from app.models.competitor_insight import CompetitorInsight
from app.models.schema_suggestion import SchemaSuggestion
from app.models.monitoring import MonitoringRecord, ContentBaseline
from app.models.diagnosis_record import DiagnosisRecord
from app.models.payment_order import PaymentOrder
from app.models.attribution_record import AttributionRecord
__all__ = [
"User",
@@ -76,4 +84,14 @@ __all__ = [
"AlertSetting",
"DetectionTask",
"UsageRecord",
"GeoPlan",
"GeoPlanAction",
"CompetitorInsight",
"SchemaSuggestion",
"TrendInsight",
"MonitoringRecord",
"ContentBaseline",
"DiagnosisRecord",
"PaymentOrder",
"AttributionRecord",
]

@@ -0,0 +1,65 @@
import uuid
from datetime import datetime
from sqlalchemy import String, Float, Integer, ForeignKey, Index, func, Text
from sqlalchemy import Uuid
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base, JSONType
class AttributionRecord(Base):
__tablename__ = "attribution_records"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
user_id: Mapped[str] = mapped_column(
Text,
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
)
brand_id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("brands.id", ondelete="CASCADE"),
nullable=False,
)
content_id: Mapped[uuid.UUID | None] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("contents.id", ondelete="SET NULL"),
nullable=True,
)
baseline_score: Mapped[float] = mapped_column(Float, nullable=False)
current_score: Mapped[float | None] = mapped_column(Float, nullable=True)
score_delta: Mapped[float | None] = mapped_column(Float, nullable=True)
attribution_window_days: Mapped[int] = mapped_column(
Integer, server_default="28", nullable=False,
)
published_at: Mapped[datetime | None] = mapped_column(nullable=True)
window_end_at: Mapped[datetime | None] = mapped_column(nullable=True)
status: Mapped[str] = mapped_column(
String(20), server_default="tracking", nullable=False,
)
attributed_dimensions: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
roi_percentage: Mapped[float | None] = mapped_column(Float, nullable=True)
created_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
brand: Mapped["Brand"] = relationship("Brand")
content: Mapped["Content | None"] = relationship("Content")
__table_args__ = (
Index("idx_attribution_records_brand_id", "brand_id"),
Index("idx_attribution_records_user_id", "user_id"),
Index("idx_attribution_records_status", "status"),
Index("idx_attribution_records_content_id", "content_id"),
)
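`AttributionRecord` stores a baseline score, a current score, a window length that defaults to 28 days, and the derived `score_delta` and `window_end_at` columns. A minimal sketch of how those derived values could be computed, assuming `score_delta` is current minus baseline and `window_end_at` is `published_at` plus the window; neither rule is shown in this diff, and the ROI formula is left out because the source does not define it. `window_end` and `score_delta` are hypothetical helper names.

```python
from datetime import datetime, timedelta
from typing import Optional


def window_end(published_at: datetime, window_days: int = 28) -> datetime:
    """End of the attribution window (assumed: publish time plus window length)."""
    return published_at + timedelta(days=window_days)


def score_delta(baseline: float, current: Optional[float]) -> Optional[float]:
    """Score change since the baseline, or None while no current score exists."""
    if current is None:
        return None
    return round(current - baseline, 2)


published = datetime(2026, 6, 1, 7, 39, 5)
end = window_end(published)
```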

@@ -62,7 +62,12 @@ class Brand(Base):
"Suggestion", back_populates="brand", cascade="all, delete-orphan"
)
schema_suggestions: Mapped[list["SchemaSuggestion"]] = relationship(
"SchemaSuggestion", back_populates="brand", cascade="all, delete-orphan"
)
# Import at bottom to avoid circular import
from app.models.competitor import Competitor # noqa: E402, F401
from app.models.suggestion import Suggestion # noqa: E402, F401
from app.models.schema_suggestion import SchemaSuggestion # noqa: E402, F401

@@ -6,6 +6,7 @@ from sqlalchemy import Uuid, JSON
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
from app.utils.text import sanitize_raw_response
class CitationRecord(Base):
@@ -75,3 +76,40 @@ class CitationRecord(Base):
Index("idx_citation_records_queried_at", "queried_at"),
Index("idx_citation_records_platform", "platform"),
)
    @classmethod
    def from_citation_result(
        cls,
        query_id: uuid.UUID,
        platform: str,
        result: dict,
    ) -> "CitationRecord":
        """Create a CitationRecord instance from a citation detection result dict.

        Handles field mapping, default values, and sanitizing of
        raw_response / ai_response_text in one place.

        Args:
            query_id: ID of the associated query
            platform: platform name
            result: citation detection result dict

        Returns:
            A CitationRecord instance (not persisted)
        """
        return cls(
            query_id=query_id,
            platform=platform,
            cited=result.get("cited", False),
            citation_position=result.get("position"),
            citation_text=result.get("citation_text"),
            competitor_brands=result.get("competitor_brands", []),
            raw_response=sanitize_raw_response(result.get("raw_response", "")),
            confidence=result.get("confidence"),
            match_type=result.get("match_type"),
            # Citation source analysis fields
            data_source=result.get("data_source"),
            source_urls=result.get("source_urls"),
            source_titles=result.get("source_titles"),
            citation_contexts=result.get("citation_contexts"),
            ai_response_text=sanitize_raw_response(result.get("ai_response_text", "")),
        )

@@ -0,0 +1,64 @@
import uuid
from datetime import datetime
from sqlalchemy import String, Float, Integer, ForeignKey, Index, func
from sqlalchemy import Uuid
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.types import TypeDecorator, JSON
from app.database import Base
class JSONType(TypeDecorator):
impl = JSON
cache_ok = True
def load_dialect_impl(self, dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(JSONB())
return dialect.type_descriptor(JSON())
class CompetitorInsight(Base):
__tablename__ = "competitor_insights"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
brand_id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("brands.id", ondelete="CASCADE"),
nullable=False,
)
competitor_name: Mapped[str] = mapped_column(String(100), nullable=False)
analysis_type: Mapped[str] = mapped_column(
String(50), nullable=False,
)
insight_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
citation_count_brand: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
citation_count_competitor: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
sentiment_brand: Mapped[float | None] = mapped_column(Float, nullable=True)
sentiment_competitor: Mapped[float | None] = mapped_column(Float, nullable=True)
platform_breakdown: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
gap_analysis: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
opportunity_areas: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
recommendations: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
confidence: Mapped[str] = mapped_column(String(20), default="medium", nullable=False)
period_days: Mapped[int] = mapped_column(Integer, default=30, nullable=False)
created_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
__table_args__ = (
Index("idx_competitor_insights_brand_id", "brand_id"),
Index("idx_competitor_insights_analysis_type", "analysis_type"),
)

@@ -0,0 +1,40 @@
import uuid
from datetime import datetime
from sqlalchemy import String, Uuid, JSON, Float, Text, ForeignKey, Index, func
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class DiagnosisRecord(Base):
__tablename__ = "diagnosis_records"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
brand_id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("brands.id", ondelete="CASCADE"),
nullable=False,
)
user_id: Mapped[uuid.UUID] = mapped_column(Uuid(as_uuid=True), nullable=False)
diagnosis_type: Mapped[str] = mapped_column(
String(20), default="geo", nullable=False
)
status: Mapped[str] = mapped_column(String(20), default="pending", nullable=False)
overall_score: Mapped[float | None] = mapped_column(Float, nullable=True)
result_json: Mapped[dict | None] = mapped_column(JSON, nullable=True)
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
collection_metadata: Mapped[dict | None] = mapped_column(JSON, nullable=True)
created_at: Mapped[datetime] = mapped_column(
server_default=func.now(), nullable=False
)
completed_at: Mapped[datetime | None] = mapped_column(nullable=True)
__table_args__ = (
Index("idx_diagnosis_records_brand_id", "brand_id"),
Index("idx_diagnosis_records_user_id", "user_id"),
Index("idx_diagnosis_records_status", "status"),
Index("idx_diagnosis_records_created_at", "created_at"),
)

@@ -0,0 +1,112 @@
import uuid
from datetime import datetime
from sqlalchemy import String, Integer, Text, ForeignKey, Index, func
from sqlalchemy import Uuid
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base, JSONType
class GeoPlan(Base):
__tablename__ = "geo_plans"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
organization_id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("organizations.id", ondelete="CASCADE"),
nullable=False,
)
brand_id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("brands.id", ondelete="CASCADE"),
nullable=False,
)
title: Mapped[str] = mapped_column(String(500), nullable=False)
status: Mapped[str] = mapped_column(
String(20), server_default="draft", nullable=False,
)
diagnosis_score: Mapped[int] = mapped_column(Integer, nullable=False)
target_score: Mapped[int] = mapped_column(Integer, nullable=False)
estimated_weeks: Mapped[int] = mapped_column(Integer, nullable=False)
plan_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
source: Mapped[str] = mapped_column(String(20), nullable=False, default="rule")
created_by: Mapped[str | None] = mapped_column(
String(36),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
created_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
brand: Mapped["Brand"] = relationship("Brand")
creator: Mapped["User | None"] = relationship(
"User", foreign_keys=[created_by]
)
__table_args__ = (
Index("idx_geo_plans_brand_id", "brand_id"),
Index("idx_geo_plans_status", "status"),
Index("idx_geo_plans_organization_id", "organization_id"),
)
class GeoPlanAction(Base):
__tablename__ = "geo_plan_actions"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
plan_id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("geo_plans.id", ondelete="CASCADE"),
nullable=False,
)
action_type: Mapped[str] = mapped_column(String(50), nullable=False)
title: Mapped[str] = mapped_column(String(500), nullable=False)
description: Mapped[str] = mapped_column(Text, nullable=False)
reason: Mapped[str] = mapped_column(Text, nullable=False)
priority: Mapped[str] = mapped_column(String(10), nullable=False)
status: Mapped[str] = mapped_column(
String(20), server_default="pending", nullable=False,
)
target_keyword: Mapped[str | None] = mapped_column(String(200), nullable=True)
target_platform: Mapped[str | None] = mapped_column(String(50), nullable=True)
content_style: Mapped[str | None] = mapped_column(String(50), nullable=True)
estimated_impact: Mapped[str | None] = mapped_column(String(500), nullable=True)
difficulty: Mapped[str] = mapped_column(String(10), nullable=False)
execution_params: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
sort_order: Mapped[int] = mapped_column(
Integer, server_default="0", nullable=False,
)
completed_at: Mapped[datetime | None] = mapped_column(nullable=True)
created_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
plan: Mapped["GeoPlan"] = relationship("GeoPlan")
__table_args__ = (
Index("idx_geo_plan_actions_plan_id", "plan_id"),
Index("idx_geo_plan_actions_status", "status"),
Index("idx_geo_plan_actions_priority", "priority"),
)

@@ -0,0 +1,100 @@
import uuid
from datetime import datetime
from sqlalchemy import String, Integer, Float, ForeignKey, Index, func
from sqlalchemy import Uuid
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base, JSONType
class MonitoringRecord(Base):
__tablename__ = "monitoring_records"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
brand_id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("brands.id", ondelete="CASCADE"),
nullable=False,
)
content_id: Mapped[str | None] = mapped_column(String(36), nullable=True)
query_keywords: Mapped[str | None] = mapped_column(String(500), nullable=True)
platform: Mapped[str | None] = mapped_column(String(50), nullable=True)
baseline_citation_count: Mapped[int] = mapped_column(
Integer, server_default="0", nullable=False,
)
baseline_sentiment: Mapped[float | None] = mapped_column(Float, nullable=True)
baseline_rank: Mapped[int | None] = mapped_column(Integer, nullable=True)
current_citation_count: Mapped[int] = mapped_column(
Integer, server_default="0", nullable=False,
)
current_sentiment: Mapped[float | None] = mapped_column(Float, nullable=True)
current_rank: Mapped[int | None] = mapped_column(Integer, nullable=True)
change_type: Mapped[str | None] = mapped_column(String(20), nullable=True)
change_details: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
check_interval_hours: Mapped[int] = mapped_column(
Integer, server_default="24", nullable=False,
)
last_checked_at: Mapped[datetime | None] = mapped_column(nullable=True)
next_check_at: Mapped[datetime | None] = mapped_column(nullable=True)
status: Mapped[str] = mapped_column(
String(20), server_default="active", nullable=False,
)
created_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
brand: Mapped["Brand"] = relationship("Brand")
baselines: Mapped[list["ContentBaseline"]] = relationship(
"ContentBaseline", back_populates="monitoring_record", cascade="all, delete-orphan",
)
__table_args__ = (
Index("idx_monitoring_records_brand_id", "brand_id"),
Index("idx_monitoring_records_status", "status"),
Index("idx_monitoring_records_next_check_at", "next_check_at"),
)
class ContentBaseline(Base):
__tablename__ = "content_baselines"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
monitoring_record_id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("monitoring_records.id", ondelete="CASCADE"),
nullable=False,
)
brand_name: Mapped[str] = mapped_column(String(100), nullable=False)
keyword: Mapped[str] = mapped_column(String(200), nullable=False)
platform: Mapped[str] = mapped_column(String(50), nullable=False)
citation_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
sentiment_score: Mapped[float | None] = mapped_column(Float, nullable=True)
rank_position: Mapped[int | None] = mapped_column(Integer, nullable=True)
snapshot_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
recorded_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
nullable=False,
)
monitoring_record: Mapped["MonitoringRecord"] = relationship(
"MonitoringRecord", back_populates="baselines",
)
__table_args__ = (
Index("idx_content_baselines_monitoring_record_id", "monitoring_record_id"),
)

@@ -0,0 +1,108 @@
import uuid
from datetime import datetime
from sqlalchemy import String, Integer, Float, ForeignKey, Index, func
from sqlalchemy import Uuid
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base, JSONType
class MonitoringRecord(Base):
__tablename__ = "monitoring_records"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
user_id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
)
brand_id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("brands.id", ondelete="CASCADE"),
nullable=False,
)
content_id: Mapped[uuid.UUID | None] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("contents.id", ondelete="SET NULL"),
nullable=True,
)
query_id: Mapped[uuid.UUID | None] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("queries.id", ondelete="SET NULL"),
nullable=True,
)
task_type: Mapped[str] = mapped_column(String(50), nullable=False)
status: Mapped[str] = mapped_column(
String(20), server_default="pending", nullable=False,
)
baseline_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
current_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
change_report: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
interval_hours: Mapped[int] = mapped_column(
Integer, server_default="24", nullable=False,
)
last_checked_at: Mapped[datetime | None] = mapped_column(nullable=True)
next_check_at: Mapped[datetime | None] = mapped_column(nullable=True)
created_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
brand: Mapped["Brand"] = relationship("Brand")
user: Mapped["User"] = relationship("User")
__table_args__ = (
Index("idx_monitoring_records_user_id", "user_id"),
Index("idx_monitoring_records_brand_id", "brand_id"),
Index("idx_monitoring_records_status", "status"),
Index("idx_monitoring_records_next_check_at", "next_check_at"),
)
class ContentBaseline(Base):
__tablename__ = "content_baselines"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
brand_id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("brands.id", ondelete="CASCADE"),
nullable=False,
)
content_id: Mapped[uuid.UUID | None] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("contents.id", ondelete="SET NULL"),
nullable=True,
)
citation_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
positive_ratio: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
avg_rank: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
platform_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
recorded_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
nullable=False,
)
created_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
nullable=False,
)
brand: Mapped["Brand"] = relationship("Brand")
__table_args__ = (
Index("idx_content_baselines_brand_id", "brand_id"),
Index("idx_content_baselines_content_id", "content_id"),
)

View File

@ -0,0 +1,40 @@
import uuid
from datetime import datetime
from sqlalchemy import String, Float, DateTime, ForeignKey, func, Uuid
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base, JSONType
class PaymentOrder(Base):
__tablename__ = "payment_orders"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
user_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
)
plan: Mapped[str] = mapped_column(String(20), nullable=False)
amount: Mapped[float] = mapped_column(Float, nullable=False)
currency: Mapped[str] = mapped_column(String(10), default="CNY")
payment_provider: Mapped[str] = mapped_column(String(20), nullable=False)
payment_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
status: Mapped[str] = mapped_column(String(20), default="pending")
pay_url: Mapped[str | None] = mapped_column(String(1024), nullable=True)
callback_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
created_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
paid_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)

View File

@ -0,0 +1,74 @@
import uuid
from datetime import datetime
from sqlalchemy import String, Text, DateTime, ForeignKey, Index, func, Float
from sqlalchemy import Uuid
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base, JSONType
class SchemaSuggestion(Base):
__tablename__ = "schema_suggestions"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
brand_id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("brands.id", ondelete="CASCADE"),
nullable=False,
)
schema_type: Mapped[str] = mapped_column(
String(50), nullable=False,
)
target_url: Mapped[str | None] = mapped_column(
String(500), nullable=True,
)
json_ld_template: Mapped[dict] = mapped_column(
JSONType, nullable=False, default=dict,
)
json_ld_filled: Mapped[dict | None] = mapped_column(
JSONType, nullable=True,
)
priority: Mapped[str] = mapped_column(
String(20), nullable=False, default="medium",
)
status: Mapped[str] = mapped_column(
String(20), nullable=False, default="pending",
)
diagnosis_dimensions: Mapped[dict | None] = mapped_column(
JSONType, nullable=True,
)
implementation_difficulty: Mapped[str] = mapped_column(
String(20), nullable=False, default="medium",
)
estimated_impact: Mapped[str | None] = mapped_column(
Text, nullable=True,
)
validation_errors: Mapped[dict | None] = mapped_column(
JSONType, nullable=True,
)
created_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
brand: Mapped["Brand"] = relationship("Brand", back_populates="schema_suggestions")
__table_args__ = (
Index("idx_schema_suggestions_brand_id", "brand_id"),
Index("idx_schema_suggestions_status", "status"),
Index("idx_schema_suggestions_schema_type", "schema_type"),
Index("idx_schema_suggestions_brand_status", "brand_id", "status"),
)
from app.models.brand import Brand # noqa: E402, F401

View File

@ -0,0 +1,54 @@
import uuid
from datetime import datetime
from sqlalchemy import String, Integer, Float, ForeignKey, Index, func, Text
from sqlalchemy import Uuid, JSON
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base
class TrendInsight(Base):
__tablename__ = "trend_insights"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
brand_id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
ForeignKey("brands.id", ondelete="CASCADE"),
nullable=False,
)
trend_type: Mapped[str] = mapped_column(String(20), nullable=False)
keyword: Mapped[str | None] = mapped_column(String(200), nullable=True)
platform: Mapped[str | None] = mapped_column(String(50), nullable=True)
period_start: Mapped[datetime] = mapped_column(nullable=False)
period_end: Mapped[datetime] = mapped_column(nullable=False)
data_points: Mapped[list | None] = mapped_column(JSON, nullable=True)
change_rate: Mapped[float | None] = mapped_column(Float, nullable=True)
absolute_change: Mapped[int | None] = mapped_column(Integer, nullable=True)
sentiment_trend: Mapped[dict | None] = mapped_column(JSON, nullable=True)
cause_analysis: Mapped[str | None] = mapped_column(Text, nullable=True)
recommendations: Mapped[list | None] = mapped_column(JSON, nullable=True)
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.5)
severity: Mapped[str] = mapped_column(
String(20), nullable=False, server_default="info",
)
created_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
nullable=False,
)
updated_at: Mapped[datetime] = mapped_column(
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
__table_args__ = (
Index("idx_trend_insights_brand_id", "brand_id"),
Index("idx_trend_insights_trend_type", "trend_type"),
Index("idx_trend_insights_created_at", "created_at"),
Index("idx_trend_insights_period_start", "period_start"),
)

View File

@ -1,13 +0,0 @@
"""Monitoring module."""
import os
from app.monitoring.metrics import *
from app.monitoring.middleware import MonitoringMiddleware
from app.monitoring.agent_hooks import agent_execution_context, record_agent_execution
from app.monitoring.llm_metrics import get_llm_metrics, LLMMetricsWrapper
# Set service info
SERVICE_INFO.info({
    "version": "1.0.0",
    "environment": os.getenv("ENVIRONMENT", "development"),
})

View File

@ -1,86 +0,0 @@
"""Monitoring middleware."""
import time
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from app.monitoring.metrics import (
    API_REQUESTS_TOTAL,
    API_REQUEST_DURATION_SECONDS,
    API_REQUESTS_IN_PROGRESS,
)
# Paths excluded from metrics collection
EXCLUDED_PATHS = {"/health", "/ready", "/metrics", "/docs", "/openapi.json"}
class MonitoringMiddleware(BaseHTTPMiddleware):
    """API monitoring middleware."""
    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Skip excluded paths
        if request.url.path in EXCLUDED_PATHS:
            return await call_next(request)
        # Extract the endpoint identifier (used as a metric label)
        endpoint = self._get_endpoint_label(request)
        # Increment the in-progress request gauge
        API_REQUESTS_IN_PROGRESS.labels(
            method=request.method,
            endpoint=endpoint
        ).inc()
        # Record the start time
        start_time = time.perf_counter()
        # Assume failure until a response arrives, so the finally block
        # always has a status to report, even on cancellation
        status_code = 500
        try:
            # Execute the request
            response = await call_next(request)
            status_code = response.status_code
        finally:
            # Compute the elapsed time
            duration = time.perf_counter() - start_time
            # Record the metrics
            API_REQUESTS_TOTAL.labels(
                method=request.method,
                endpoint=endpoint,
                status=str(status_code)
            ).inc()
            API_REQUEST_DURATION_SECONDS.labels(
                method=request.method,
                endpoint=endpoint
            ).observe(duration)
            # Decrement the in-progress request gauge
            API_REQUESTS_IN_PROGRESS.labels(
                method=request.method,
                endpoint=endpoint
            ).dec()
        return response
    def _get_endpoint_label(self, request: Request) -> str:
        """Extract the endpoint label."""
        path = request.url.path
        # Normalize the path (replace IDs and other parameters)
        parts = path.strip("/").split("/")
        # Handle the common pattern: /api/v1/resources/{id}
        if len(parts) >= 4 and parts[0] == "api":
            resource = parts[2]
            action = parts[3]
            # Map to a canonical label
            if action.isdigit():
                return f"{resource}_detail"
            return f"{resource}_{action}"
        return "other"
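The label normalization in `_get_endpoint_label` can be exercised on its own. The sketch below restates the same rule as a free function (`endpoint_label` is a hypothetical name used only for illustration) so the mapping from path to label is easy to see:

```python
def endpoint_label(path: str) -> str:
    """Map a request path to a low-cardinality metric label."""
    parts = path.strip("/").split("/")
    # Only paths shaped like /api/<version>/<resource>/<action-or-id> are labeled
    if len(parts) >= 4 and parts[0] == "api":
        resource, action = parts[2], parts[3]
        if action.isdigit():
            return f"{resource}_detail"
        return f"{resource}_{action}"
    return "other"


for path in ("/api/v1/brands/42", "/api/v1/brands/search", "/api/v1/brands", "/health"):
    print(path, "->", endpoint_label(path))
```

Note that only purely numeric IDs collapse to `_detail`. A UUID segment is not all digits, so each distinct UUID would produce its own label, which matters because the models in this commit use UUID primary keys.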

View File

@ -1,7 +1,31 @@
from app.repositories.api_key_repository import APIKeyRepository
from app.repositories.usage_repository import UsageRepository
from app.repositories.brand_repository import BrandRepository
from app.repositories.query_repository import QueryRepository
from app.repositories.citation_repository import CitationRepository
from app.repositories.content_repository import ContentRepository
from app.repositories.knowledge_repository import KnowledgeRepository
from app.repositories.alert_repository import AlertRepository
from app.repositories.subscription_repository import SubscriptionRepository
from app.repositories.organization_repository import OrganizationRepository
from app.repositories.user_repository import UserRepository
from app.repositories.detection_task_repository import DetectionTaskRepository
from app.repositories.suggestion_repository import SuggestionRepository
from app.repositories.competitor_repository import CompetitorRepository
__all__ = [
"APIKeyRepository",
"UsageRepository",
"BrandRepository",
"QueryRepository",
"CitationRepository",
"ContentRepository",
"KnowledgeRepository",
"AlertRepository",
"SubscriptionRepository",
"OrganizationRepository",
"UserRepository",
"DetectionTaskRepository",
"SuggestionRepository",
"CompetitorRepository",
]

View File

@ -0,0 +1,77 @@
import uuid
from typing import Optional
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.alert import Alert
class AlertRepository:
def __init__(self, session: AsyncSession):
self.session = session
async def get_by_id(self, id: uuid.UUID) -> Optional[Alert]:
result = await self.session.execute(
select(Alert).where(Alert.id == id)
)
return result.scalar_one_or_none()
async def list_by_user(
self, user_id: uuid.UUID, *, skip: int = 0, limit: int = 100
) -> list[Alert]:
result = await self.session.execute(
select(Alert)
.where(Alert.user_id == user_id)
.order_by(Alert.created_at.desc())
.offset(skip)
.limit(limit)
)
return list(result.scalars().all())
async def count_by_user(self, user_id: uuid.UUID) -> int:
result = await self.session.execute(
select(func.count()).select_from(Alert).where(Alert.user_id == user_id)
)
return result.scalar_one()
async def get_unread_by_user(self, user_id: uuid.UUID) -> list[Alert]:
result = await self.session.execute(
select(Alert).where(
Alert.user_id == user_id,
Alert.is_read.is_(False),
).order_by(Alert.created_at.desc())
)
return list(result.scalars().all())
async def mark_as_read(self, alert_id: uuid.UUID) -> bool:
instance = await self.get_by_id(alert_id)
if instance is None:
return False
instance.is_read = True
await self.session.flush()
return True
async def create(self, **kwargs) -> Alert:
instance = Alert(**kwargs)
self.session.add(instance)
await self.session.flush()
return instance
async def update(self, id: uuid.UUID, **kwargs) -> Optional[Alert]:
instance = await self.get_by_id(id)
if instance is None:
return None
for key, value in kwargs.items():
if hasattr(instance, key):
setattr(instance, key, value)
await self.session.flush()
return instance
async def delete(self, id: uuid.UUID) -> bool:
instance = await self.get_by_id(id)
if instance is None:
return False
await self.session.delete(instance)
await self.session.flush()
return True
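Every repository's `update` method uses the same `hasattr`/`setattr` loop, which silently ignores keys that do not match an attribute. A small sketch of that behavior on a plain dataclass, with a variant that reports the skipped keys; `FakeAlert` and `apply_updates` are hypothetical stand-ins, not code from this commit:

```python
from dataclasses import dataclass


@dataclass
class FakeAlert:
    title: str
    is_read: bool = False


def apply_updates(instance: object, **kwargs) -> list[str]:
    """Set attributes that exist on the instance and return the keys that were skipped."""
    skipped = []
    for key, value in kwargs.items():
        if hasattr(instance, key):
            setattr(instance, key, value)
        else:
            skipped.append(key)
    return skipped


alert = FakeAlert(title="Citation drop")
skipped = apply_updates(alert, is_read=True, is_raed=True)
print(alert.is_read, skipped)  # True ['is_raed']
```

Returning or logging the skipped keys is one way to surface a misspelled field name that would otherwise be dropped without any error.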

View File

@ -0,0 +1,71 @@
import uuid
from typing import Optional
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.brand import Brand
class BrandRepository:
def __init__(self, session: AsyncSession):
self.session = session
async def get_by_id(self, id: uuid.UUID) -> Optional[Brand]:
result = await self.session.execute(
select(Brand).where(Brand.id == id)
)
return result.scalar_one_or_none()
async def list_by_user(
self, user_id: uuid.UUID, *, skip: int = 0, limit: int = 100
) -> list[Brand]:
result = await self.session.execute(
select(Brand)
.where(Brand.user_id == user_id)
.order_by(Brand.created_at.desc())
.offset(skip)
.limit(limit)
)
return list(result.scalars().all())
async def count_by_user(self, user_id: uuid.UUID) -> int:
result = await self.session.execute(
select(func.count()).select_from(Brand).where(Brand.user_id == user_id)
)
return result.scalar_one()
async def get_by_name_and_organization(
self, name: str, organization_id: uuid.UUID
) -> Optional[Brand]:
result = await self.session.execute(
select(Brand).where(
Brand.name == name,
# NOTE: the organization_id argument is matched against Brand.user_id.
Brand.user_id == organization_id,
)
)
return result.scalar_one_or_none()
async def create(self, **kwargs) -> Brand:
instance = Brand(**kwargs)
self.session.add(instance)
await self.session.flush()
return instance
async def update(self, id: uuid.UUID, **kwargs) -> Optional[Brand]:
instance = await self.get_by_id(id)
if instance is None:
return None
for key, value in kwargs.items():
if hasattr(instance, key):
setattr(instance, key, value)
await self.session.flush()
return instance
async def delete(self, id: uuid.UUID) -> bool:
instance = await self.get_by_id(id)
if instance is None:
return False
await self.session.delete(instance)
await self.session.flush()
return True
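The `list_by_*` and `count_by_*` pairs take `skip`/`limit` and return a total, which is enough to build page-based pagination on top. A minimal sketch, assuming 1-based page numbers; `page_to_window` and `total_pages` are hypothetical helpers:

```python
import math


def page_to_window(page: int, page_size: int = 20) -> tuple[int, int]:
    """Translate a 1-based page number into (skip, limit) arguments."""
    if page < 1 or page_size < 1:
        raise ValueError("page and page_size must be >= 1")
    return (page - 1) * page_size, page_size


def total_pages(total: int, page_size: int = 20) -> int:
    """Number of pages needed to show `total` rows."""
    return math.ceil(total / page_size) if total > 0 else 0


skip, limit = page_to_window(3, page_size=20)
print(skip, limit)           # 40 20
print(total_pages(101, 20))  # 6
```

The resulting values would be passed as, for example, `list_by_user(user_id, skip=skip, limit=limit)`, with `count_by_user` supplying `total`.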

View File

@ -0,0 +1,84 @@
import uuid
from typing import Optional
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.citation_record import CitationRecord
class CitationRepository:
def __init__(self, session: AsyncSession):
self.session = session
async def get_by_id(self, id: uuid.UUID) -> Optional[CitationRecord]:
result = await self.session.execute(
select(CitationRecord).where(CitationRecord.id == id)
)
return result.scalar_one_or_none()
async def list_by_query(
self, query_id: uuid.UUID, *, skip: int = 0, limit: int = 100
) -> list[CitationRecord]:
result = await self.session.execute(
select(CitationRecord)
.where(CitationRecord.query_id == query_id)
.order_by(CitationRecord.queried_at.desc())
.offset(skip)
.limit(limit)
)
return list(result.scalars().all())
async def count_by_query(self, query_id: uuid.UUID) -> int:
result = await self.session.execute(
select(func.count()).select_from(CitationRecord).where(
CitationRecord.query_id == query_id
)
)
return result.scalar_one()
async def get_by_query_and_platform(
self, query_id: uuid.UUID, platform: str
) -> Optional[CitationRecord]:
result = await self.session.execute(
select(CitationRecord).where(
CitationRecord.query_id == query_id,
CitationRecord.platform == platform,
)
)
return result.scalar_one_or_none()
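async def count_cited_by_brand(self, brand_name: str) -> int:
# Imported locally, as in the other repositories, to avoid import cycles.
from app.models.query import Query
result = await self.session.execute(
select(func.count())
.select_from(CitationRecord)
.join(CitationRecord.query)
.where(
CitationRecord.cited.is_(True),
# Restrict the count to queries that target the given brand.
Query.target_brand == brand_name,
)
)
return result.scalar_one()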
async def create(self, **kwargs) -> CitationRecord:
instance = CitationRecord(**kwargs)
self.session.add(instance)
await self.session.flush()
return instance
async def update(self, id: uuid.UUID, **kwargs) -> Optional[CitationRecord]:
instance = await self.get_by_id(id)
if instance is None:
return None
for key, value in kwargs.items():
if hasattr(instance, key):
setattr(instance, key, value)
await self.session.flush()
return instance
async def delete(self, id: uuid.UUID) -> bool:
instance = await self.get_by_id(id)
if instance is None:
return False
await self.session.delete(instance)
await self.session.flush()
return True

View File

@ -0,0 +1,71 @@
import uuid
from typing import Optional
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.competitor import Competitor
class CompetitorRepository:
def __init__(self, session: AsyncSession):
self.session = session
async def get_by_id(self, id: uuid.UUID) -> Optional[Competitor]:
result = await self.session.execute(
select(Competitor).where(Competitor.id == id)
)
return result.scalar_one_or_none()
async def list_by_brand(
self, brand_id: uuid.UUID, *, skip: int = 0, limit: int = 100
) -> list[Competitor]:
result = await self.session.execute(
select(Competitor)
.where(Competitor.brand_id == brand_id)
.order_by(Competitor.created_at.desc())
.offset(skip)
.limit(limit)
)
return list(result.scalars().all())
async def count_by_brand(self, brand_id: uuid.UUID) -> int:
result = await self.session.execute(
select(func.count()).select_from(Competitor).where(
Competitor.brand_id == brand_id
)
)
return result.scalar_one()
async def get_by_brand(self, brand_name: str) -> list[Competitor]:
from app.models.brand import Brand
result = await self.session.execute(
select(Competitor)
.join(Brand, Competitor.brand_id == Brand.id)
.where(Brand.name == brand_name)
)
return list(result.scalars().all())
async def create(self, **kwargs) -> Competitor:
instance = Competitor(**kwargs)
self.session.add(instance)
await self.session.flush()
return instance
async def update(self, id: uuid.UUID, **kwargs) -> Optional[Competitor]:
instance = await self.get_by_id(id)
if instance is None:
return None
for key, value in kwargs.items():
if hasattr(instance, key):
setattr(instance, key, value)
await self.session.flush()
return instance
async def delete(self, id: uuid.UUID) -> bool:
instance = await self.get_by_id(id)
if instance is None:
return False
await self.session.delete(instance)
await self.session.flush()
return True

View File

@ -0,0 +1,68 @@
import uuid
from typing import Optional
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.content import Content
class ContentRepository:
def __init__(self, session: AsyncSession):
self.session = session
async def get_by_id(self, id: uuid.UUID) -> Optional[Content]:
result = await self.session.execute(
select(Content).where(Content.id == id)
)
return result.scalar_one_or_none()
async def list_by_organization(
self, organization_id: uuid.UUID, *, skip: int = 0, limit: int = 100
) -> list[Content]:
result = await self.session.execute(
select(Content)
.where(Content.organization_id == organization_id)
.order_by(Content.created_at.desc())
.offset(skip)
.limit(limit)
)
return list(result.scalars().all())
async def count_by_organization(self, organization_id: uuid.UUID) -> int:
result = await self.session.execute(
select(func.count()).select_from(Content).where(
Content.organization_id == organization_id
)
)
return result.scalar_one()
async def get_by_status(self, status: str) -> list[Content]:
result = await self.session.execute(
select(Content).where(Content.status == status)
)
return list(result.scalars().all())
async def create(self, **kwargs) -> Content:
instance = Content(**kwargs)
self.session.add(instance)
await self.session.flush()
return instance
async def update(self, id: uuid.UUID, **kwargs) -> Optional[Content]:
instance = await self.get_by_id(id)
if instance is None:
return None
for key, value in kwargs.items():
if hasattr(instance, key):
setattr(instance, key, value)
await self.session.flush()
return instance
async def delete(self, id: uuid.UUID) -> bool:
instance = await self.get_by_id(id)
if instance is None:
return False
await self.session.delete(instance)
await self.session.flush()
return True

View File

@ -0,0 +1,68 @@
import uuid
from typing import Optional
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.detection_task import DetectionTask
class DetectionTaskRepository:
def __init__(self, session: AsyncSession):
self.session = session
async def get_by_id(self, id: uuid.UUID) -> Optional[DetectionTask]:
result = await self.session.execute(
select(DetectionTask).where(DetectionTask.id == id)
)
return result.scalar_one_or_none()
async def list_by_user(
self, user_id: uuid.UUID, *, skip: int = 0, limit: int = 100
) -> list[DetectionTask]:
result = await self.session.execute(
select(DetectionTask)
.where(DetectionTask.user_id == user_id)
.order_by(DetectionTask.created_at.desc())
.offset(skip)
.limit(limit)
)
return list(result.scalars().all())
async def count_by_user(self, user_id: uuid.UUID) -> int:
result = await self.session.execute(
select(func.count()).select_from(DetectionTask).where(
DetectionTask.user_id == user_id
)
)
return result.scalar_one()
async def get_active_tasks(self) -> list[DetectionTask]:
result = await self.session.execute(
select(DetectionTask).where(DetectionTask.is_active.is_(True))
)
return list(result.scalars().all())
async def create(self, **kwargs) -> DetectionTask:
instance = DetectionTask(**kwargs)
self.session.add(instance)
await self.session.flush()
return instance
async def update(self, id: uuid.UUID, **kwargs) -> Optional[DetectionTask]:
instance = await self.get_by_id(id)
if instance is None:
return None
for key, value in kwargs.items():
if hasattr(instance, key):
setattr(instance, key, value)
await self.session.flush()
return instance
async def delete(self, id: uuid.UUID) -> bool:
instance = await self.get_by_id(id)
if instance is None:
return False
await self.session.delete(instance)
await self.session.flush()
return True

View File

@ -0,0 +1,70 @@
import uuid
from typing import Optional
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.knowledge import KnowledgeBase
class KnowledgeRepository:
def __init__(self, session: AsyncSession):
self.session = session
async def get_by_id(self, id: uuid.UUID) -> Optional[KnowledgeBase]:
result = await self.session.execute(
select(KnowledgeBase).where(KnowledgeBase.id == id)
)
return result.scalar_one_or_none()
async def list_by_organization(
self, organization_id: uuid.UUID, *, skip: int = 0, limit: int = 100
) -> list[KnowledgeBase]:
result = await self.session.execute(
select(KnowledgeBase)
.where(KnowledgeBase.organization_id == organization_id)
.order_by(KnowledgeBase.created_at.desc())
.offset(skip)
.limit(limit)
)
return list(result.scalars().all())
async def count_by_organization(self, organization_id: uuid.UUID) -> int:
result = await self.session.execute(
select(func.count()).select_from(KnowledgeBase).where(
KnowledgeBase.organization_id == organization_id
)
)
return result.scalar_one()
async def get_by_organization(self, organization_id: uuid.UUID) -> list[KnowledgeBase]:
result = await self.session.execute(
select(KnowledgeBase).where(
KnowledgeBase.organization_id == organization_id
)
)
return list(result.scalars().all())
async def create(self, **kwargs) -> KnowledgeBase:
instance = KnowledgeBase(**kwargs)
self.session.add(instance)
await self.session.flush()
return instance
async def update(self, id: uuid.UUID, **kwargs) -> Optional[KnowledgeBase]:
instance = await self.get_by_id(id)
if instance is None:
return None
for key, value in kwargs.items():
if hasattr(instance, key):
setattr(instance, key, value)
await self.session.flush()
return instance
async def delete(self, id: uuid.UUID) -> bool:
instance = await self.get_by_id(id)
if instance is None:
return False
await self.session.delete(instance)
await self.session.flush()
return True

View File

@ -0,0 +1,75 @@
import uuid
from typing import Optional
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.organization import Organization
class OrganizationRepository:
def __init__(self, session: AsyncSession):
self.session = session
async def get_by_id(self, id: uuid.UUID) -> Optional[Organization]:
result = await self.session.execute(
select(Organization).where(Organization.id == id)
)
return result.scalar_one_or_none()
async def get_by_slug(self, slug: str) -> Optional[Organization]:
result = await self.session.execute(
select(Organization).where(Organization.slug == slug)
)
return result.scalar_one_or_none()
async def list_all(self, *, skip: int = 0, limit: int = 100) -> list[Organization]:
result = await self.session.execute(
select(Organization)
.order_by(Organization.created_at.desc())
.offset(skip)
.limit(limit)
)
return list(result.scalars().all())
async def count_all(self) -> int:
result = await self.session.execute(
select(func.count()).select_from(Organization)
)
return result.scalar_one()
async def get_by_owner(self, user_id: str) -> list[Organization]:
from app.models.organization import OrgMember
result = await self.session.execute(
select(Organization)
.join(OrgMember, OrgMember.organization_id == Organization.id)
.where(
OrgMember.user_id == user_id,
OrgMember.role == "owner",
)
)
return list(result.scalars().all())
async def create(self, **kwargs) -> Organization:
instance = Organization(**kwargs)
self.session.add(instance)
await self.session.flush()
return instance
async def update(self, id: uuid.UUID, **kwargs) -> Optional[Organization]:
instance = await self.get_by_id(id)
if instance is None:
return None
for key, value in kwargs.items():
if hasattr(instance, key):
setattr(instance, key, value)
await self.session.flush()
return instance
async def delete(self, id: uuid.UUID) -> bool:
instance = await self.get_by_id(id)
if instance is None:
return False
await self.session.delete(instance)
await self.session.flush()
return True

View File

@ -0,0 +1,72 @@
import uuid
from typing import Optional
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.query import Query
class QueryRepository:
def __init__(self, session: AsyncSession):
self.session = session
async def get_by_id(self, id: uuid.UUID) -> Optional[Query]:
result = await self.session.execute(
select(Query).where(Query.id == id)
)
return result.scalar_one_or_none()
async def list_by_user(
self, user_id: str, *, skip: int = 0, limit: int = 100
) -> list[Query]:
result = await self.session.execute(
select(Query)
.where(Query.user_id == user_id)
.order_by(Query.created_at.desc())
.offset(skip)
.limit(limit)
)
return list(result.scalars().all())
async def count_by_user(self, user_id: str) -> int:
result = await self.session.execute(
select(func.count()).select_from(Query).where(Query.user_id == user_id)
)
return result.scalar_one()
async def get_by_brand(self, brand_name: str) -> list[Query]:
result = await self.session.execute(
select(Query).where(Query.target_brand == brand_name)
)
return list(result.scalars().all())
async def get_active_queries(self) -> list[Query]:
result = await self.session.execute(
select(Query).where(Query.status == "active")
)
return list(result.scalars().all())
async def create(self, **kwargs) -> Query:
instance = Query(**kwargs)
self.session.add(instance)
await self.session.flush()
return instance
async def update(self, id: uuid.UUID, **kwargs) -> Optional[Query]:
instance = await self.get_by_id(id)
if instance is None:
return None
for key, value in kwargs.items():
if hasattr(instance, key):
setattr(instance, key, value)
await self.session.flush()
return instance
async def delete(self, id: uuid.UUID) -> bool:
instance = await self.get_by_id(id)
if instance is None:
return False
await self.session.delete(instance)
await self.session.flush()
return True

View File

@ -0,0 +1,70 @@
import uuid
from typing import Optional
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.subscription import Subscription
class SubscriptionRepository:
def __init__(self, session: AsyncSession):
self.session = session
async def get_by_id(self, id: uuid.UUID) -> Optional[Subscription]:
result = await self.session.execute(
select(Subscription).where(Subscription.id == id)
)
return result.scalar_one_or_none()
async def list_by_user(
self, user_id: str, *, skip: int = 0, limit: int = 100
) -> list[Subscription]:
result = await self.session.execute(
select(Subscription)
.where(Subscription.user_id == user_id)
.order_by(Subscription.created_at.desc())
.offset(skip)
.limit(limit)
)
return list(result.scalars().all())
async def count_by_user(self, user_id: str) -> int:
result = await self.session.execute(
select(func.count()).select_from(Subscription).where(
Subscription.user_id == user_id
)
)
return result.scalar_one()
async def get_by_organization(self, organization_id: uuid.UUID) -> list[Subscription]:
result = await self.session.execute(
select(Subscription).where(
# NOTE: the organization_id argument is matched against Subscription.user_id.
Subscription.user_id == organization_id
)
)
return list(result.scalars().all())
async def create(self, **kwargs) -> Subscription:
instance = Subscription(**kwargs)
self.session.add(instance)
await self.session.flush()
return instance
async def update(self, id: uuid.UUID, **kwargs) -> Optional[Subscription]:
instance = await self.get_by_id(id)
if instance is None:
return None
for key, value in kwargs.items():
if hasattr(instance, key):
setattr(instance, key, value)
await self.session.flush()
return instance
async def delete(self, id: uuid.UUID) -> bool:
instance = await self.get_by_id(id)
if instance is None:
return False
await self.session.delete(instance)
await self.session.flush()
return True

View File

@ -0,0 +1,71 @@
import uuid
from typing import Optional

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.suggestion import Suggestion


class SuggestionRepository:
    def __init__(self, session: AsyncSession):
        self.session = session

    async def get_by_id(self, id: uuid.UUID) -> Optional[Suggestion]:
        result = await self.session.execute(
            select(Suggestion).where(Suggestion.id == id)
        )
        return result.scalar_one_or_none()

    async def list_by_brand(
        self, brand_id: uuid.UUID, *, skip: int = 0, limit: int = 100
    ) -> list[Suggestion]:
        result = await self.session.execute(
            select(Suggestion)
            .where(Suggestion.brand_id == brand_id)
            .order_by(Suggestion.generated_at.desc())
            .offset(skip)
            .limit(limit)
        )
        return list(result.scalars().all())

    async def count_by_brand(self, brand_id: uuid.UUID) -> int:
        result = await self.session.execute(
            select(func.count()).select_from(Suggestion).where(
                Suggestion.brand_id == brand_id
            )
        )
        return result.scalar_one()

    async def get_by_brand(self, brand_name: str) -> list[Suggestion]:
        from app.models.brand import Brand

        result = await self.session.execute(
            select(Suggestion)
            .join(Brand, Suggestion.brand_id == Brand.id)
            .where(Brand.name == brand_name)
        )
        return list(result.scalars().all())

    async def create(self, **kwargs) -> Suggestion:
        instance = Suggestion(**kwargs)
        self.session.add(instance)
        await self.session.flush()
        return instance

    async def update(self, id: uuid.UUID, **kwargs) -> Optional[Suggestion]:
        instance = await self.get_by_id(id)
        if instance is None:
            return None
        for key, value in kwargs.items():
            if hasattr(instance, key):
                setattr(instance, key, value)
        await self.session.flush()
        return instance

    async def delete(self, id: uuid.UUID) -> bool:
        instance = await self.get_by_id(id)
        if instance is None:
            return False
        await self.session.delete(instance)
        await self.session.flush()
        return True

View File

@ -0,0 +1,62 @@
from typing import Optional

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.user import User


class UserRepository:
    def __init__(self, session: AsyncSession):
        self.session = session

    async def get_by_id(self, id: str) -> Optional[User]:
        result = await self.session.execute(
            select(User).where(User.id == id)
        )
        return result.scalar_one_or_none()

    async def get_by_email(self, email: str) -> Optional[User]:
        result = await self.session.execute(
            select(User).where(User.email == email)
        )
        return result.scalar_one_or_none()

    async def list_all(self, *, skip: int = 0, limit: int = 100) -> list[User]:
        result = await self.session.execute(
            select(User)
            .order_by(User.createdAt.desc())
            .offset(skip)
            .limit(limit)
        )
        return list(result.scalars().all())

    async def count_all(self) -> int:
        result = await self.session.execute(
            select(func.count()).select_from(User)
        )
        return result.scalar_one()

    async def create(self, **kwargs) -> User:
        instance = User(**kwargs)
        self.session.add(instance)
        await self.session.flush()
        return instance

    async def update(self, id: str, **kwargs) -> Optional[User]:
        instance = await self.get_by_id(id)
        if instance is None:
            return None
        for key, value in kwargs.items():
            if hasattr(instance, key):
                setattr(instance, key, value)
        await self.session.flush()
        return instance

    async def delete(self, id: str) -> bool:
        instance = await self.get_by_id(id)
        if instance is None:
            return False
        await self.session.delete(instance)
        await self.session.flush()
        return True

View File

@ -0,0 +1,48 @@
import uuid
from datetime import datetime
from typing import Optional

from pydantic import BaseModel, Field


class CompetitorAnalysisRequest(BaseModel):
    brand_id: uuid.UUID = Field(..., description="Brand ID")
    analysis_types: list[str] | None = Field(
        None,
        description="Analysis types: citation_gap/content_strategy/platform_coverage/query_overlap/differentiation",
    )
    period_days: int | None = Field(30, description="Analysis period in days")


class CompetitorInsightResponse(BaseModel):
    id: uuid.UUID = Field(..., description="Insight ID")
    brand_id: uuid.UUID = Field(..., description="Brand ID")
    competitor_name: str = Field(..., description="Competitor name")
    analysis_type: str = Field(..., description="Analysis type")
    insight_data: dict | None = Field(None, description="Insight data")
    citation_count_brand: int = Field(0, description="Brand citation count")
    citation_count_competitor: int = Field(0, description="Competitor citation count")
    sentiment_brand: float | None = Field(None, description="Brand sentiment score")
    sentiment_competitor: float | None = Field(None, description="Competitor sentiment score")
    platform_breakdown: dict | None = Field(None, description="Platform breakdown")
    gap_analysis: dict | None = Field(None, description="Gap analysis")
    opportunity_areas: dict | None = Field(None, description="Opportunity areas")
    recommendations: dict | None = Field(None, description="Strategy recommendations")
    confidence: str = Field("medium", description="Confidence: high/medium/low")
    period_days: int = Field(30, description="Analysis period in days")
    created_at: datetime = Field(..., description="Creation time")
    updated_at: datetime = Field(..., description="Last update time")

    model_config = {"from_attributes": True}


class CompetitorInsightList(BaseModel):
    items: list[CompetitorInsightResponse] = Field(default_factory=list, description="Insight list")
    total: int = Field(0, description="Total count")


class CompetitorGapSummary(BaseModel):
    brand_name: str = Field(..., description="Brand name")
    competitor_name: str = Field(..., description="Competitor name")
    gap_dimensions: list[dict] = Field(default_factory=list, description="Gap dimensions")
    overall_gap_score: float = Field(0.0, description="Overall gap score (0-100)")

View File

@ -0,0 +1,71 @@
from __future__ import annotations

from pydantic import BaseModel, Field


class GEODiagnosisTriggerRequest(BaseModel):
    force_refresh: bool = Field(default=False)


class GEODiagnosisTaskResponse(BaseModel):
    task_id: str
    brand_id: str
    status: str


class GEODimensionItemResponse(BaseModel):
    name: str
    status: str
    description: str
    suggestion: str
    score: float
    max_score: float


class GEODimensionResponse(BaseModel):
    name: str
    score: float
    max_score: float
    percentage: float
    status: str
    items: list[GEODimensionItemResponse]
    detail: dict


class GEORecommendationResponse(BaseModel):
    priority: str
    dimension: str
    title: str
    description: str
    impact: str
    effort: str


class GEODiagnosisResponse(BaseModel):
    overall_score: float
    health_level: str
    health_level_label: str
    dimensions: list[GEODimensionResponse]
    recommendations: list[GEORecommendationResponse]
    is_full_report: bool = False


class GEODiagnosisResultResponse(BaseModel):
    task_id: str
    brand_id: str
    status: str
    result: GEODiagnosisResponse | None = None
    error: str | None = None


class GEODiagnosisHistoryItem(BaseModel):
    task_id: str
    overall_score: float
    health_level: str
    created_at: str
    completed_at: str | None = None


class GEODiagnosisHistoryResponse(BaseModel):
    brand_id: str
    history: list[GEODiagnosisHistoryItem]

View File

@ -0,0 +1,69 @@
import uuid
from datetime import datetime

from pydantic import BaseModel, Field


class GeoPlanGenerateRequest(BaseModel):
    brand_id: uuid.UUID = Field(..., description="Brand ID")
    target_score: int | None = Field(75, description="Target score")


class GeoPlanActionResponse(BaseModel):
    id: uuid.UUID = Field(..., description="Action item ID")
    plan_id: uuid.UUID = Field(..., description="Plan ID")
    action_type: str = Field(..., description="Action type")
    title: str = Field(..., description="Action title")
    description: str = Field(..., description="Detailed description")
    reason: str = Field(..., description="Rationale derived from diagnosis data")
    priority: str = Field(..., description="Priority: high/medium/low")
    status: str = Field(..., description="Status: pending/in_progress/completed/skipped")
    target_keyword: str | None = Field(None, description="Prefilled keyword")
    target_platform: str | None = Field(None, description="Prefilled platform")
    content_style: str | None = Field(None, description="Prefilled content style")
    estimated_impact: str | None = Field(None, description="Expected impact")
    difficulty: str = Field(..., description="Difficulty: easy/medium/hard")
    execution_params: dict | None = Field(None, description="One-click execution parameters")
    sort_order: int = Field(0, description="Sort order")
    completed_at: datetime | None = Field(None, description="Completion time")
    created_at: datetime = Field(..., description="Creation time")

    # Pydantic v2 style config; the class-based `Config` form is deprecated.
    model_config = {"from_attributes": True}


class GeoPlanResponse(BaseModel):
    id: uuid.UUID = Field(..., description="Plan ID")
    brand_id: uuid.UUID = Field(..., description="Brand ID")
    title: str = Field(..., description="Plan title")
    status: str = Field(..., description="Status: draft/active/completed/archived")
    diagnosis_score: int = Field(..., description="Diagnosis score")
    target_score: int = Field(..., description="Target score")
    estimated_weeks: int = Field(..., description="Estimated number of weeks")
    plan_data: dict | None = Field(None, description="Detailed plan data")
    source: str = Field(..., description="Generation source: rule/llm")
    actions: list[GeoPlanActionResponse] = Field(
        default_factory=list, description="Action items"
    )
    created_at: datetime = Field(..., description="Creation time")
    updated_at: datetime = Field(..., description="Last update time")

    model_config = {"from_attributes": True}


class GeoPlanListResponse(BaseModel):
    plans: list[GeoPlanResponse] = Field(default_factory=list, description="Plan list")
    total: int = Field(..., description="Total count")


class GeoPlanActionUpdateStatus(BaseModel):
    status: str = Field(
        ..., description="New status: pending/in_progress/completed/skipped"
    )


class GeoPlanActionExecuteResponse(BaseModel):
    action_id: uuid.UUID = Field(..., description="Action item ID")
    content_id: str | None = Field(None, description="Generated content ID")
    message: str = Field(..., description="Execution result message")

View File

@ -0,0 +1,34 @@
from __future__ import annotations

from pydantic import BaseModel, Field


class HealthScoreDimension(BaseModel):
    name: str
    score: float
    max_score: float
    percentage: float
    status: str


class HealthScoreRecommendation(BaseModel):
    priority: str
    dimension: str
    title: str
    description: str


class HealthScoreResponse(BaseModel):
    brand_name: str
    overall_score: float
    health_level: str
    health_level_label: str
    dimensions: list[HealthScoreDimension]
    recommendations: list[HealthScoreRecommendation]
    is_full_report: bool = False
    cached: bool = False


class HealthScoreRequest(BaseModel):
    brand: str
    competitors: list[str] = Field(default_factory=list)

View File

@ -0,0 +1,72 @@
import uuid
from datetime import datetime

from pydantic import BaseModel, Field


class MonitoringRecordCreate(BaseModel):
    brand_id: uuid.UUID = Field(..., description="Brand ID")
    content_id: str | None = Field(None, description="Content ID")
    query_keywords: str | None = Field(None, description="Query keywords")
    platform: str | None = Field(None, description="Platform")
    check_interval_hours: int = Field(24, description="Check interval (hours)")


class MonitoringRecordResponse(BaseModel):
    id: uuid.UUID = Field(..., description="Record ID")
    brand_id: uuid.UUID = Field(..., description="Brand ID")
    content_id: str | None = Field(None, description="Content ID")
    query_keywords: str | None = Field(None, description="Query keywords")
    platform: str | None = Field(None, description="Platform")
    baseline_citation_count: int = Field(0, description="Baseline citation count")
    baseline_sentiment: float | None = Field(None, description="Baseline sentiment score")
    baseline_rank: int | None = Field(None, description="Baseline rank")
    current_citation_count: int = Field(0, description="Current citation count")
    current_sentiment: float | None = Field(None, description="Current sentiment score")
    current_rank: int | None = Field(None, description="Current rank")
    change_type: str | None = Field(None, description="Change type: positive/negative/neutral")
    change_details: dict | None = Field(None, description="Change details")
    check_interval_hours: int = Field(24, description="Check interval (hours)")
    last_checked_at: datetime | None = Field(None, description="Last check time")
    next_check_at: datetime | None = Field(None, description="Next check time")
    status: str = Field("active", description="Status: active/paused/completed")
    created_at: datetime = Field(..., description="Creation time")
    updated_at: datetime = Field(..., description="Last update time")

    # Pydantic v2 style config; the class-based `Config` form is deprecated.
    model_config = {"from_attributes": True}


class MonitoringRecordList(BaseModel):
    records: list[MonitoringRecordResponse] = Field(default_factory=list, description="Monitoring records")
    total: int = Field(..., description="Total count")


class MonitoringChangeReport(BaseModel):
    monitoring_record_id: uuid.UUID = Field(..., description="Monitoring record ID")
    brand_id: uuid.UUID = Field(..., description="Brand ID")
    change_type: str = Field(..., description="Change type: positive/negative/neutral")
    change_details: dict | None = Field(None, description="Change details")
    baseline: dict = Field(default_factory=dict, description="Baseline data")
    current: dict = Field(default_factory=dict, description="Current data")
    recommendations: list[str] = Field(default_factory=list, description="Recommendations")


class ContentBaselineResponse(BaseModel):
    id: uuid.UUID = Field(..., description="Baseline ID")
    monitoring_record_id: uuid.UUID = Field(..., description="Monitoring record ID")
    brand_name: str = Field(..., description="Brand name")
    keyword: str = Field(..., description="Keyword")
    platform: str = Field(..., description="Platform")
    citation_count: int = Field(0, description="Citation count")
    sentiment_score: float | None = Field(None, description="Sentiment score")
    rank_position: int | None = Field(None, description="Rank position")
    snapshot_data: dict | None = Field(None, description="Snapshot data")
    recorded_at: datetime = Field(..., description="Recorded at")

    model_config = {"from_attributes": True}


class MonitoringStatusUpdate(BaseModel):
    status: str = Field(..., description="New status: active/paused/completed")

View File

@ -0,0 +1,45 @@
import uuid
from datetime import datetime

from pydantic import BaseModel, Field


class SchemaAdviseRequest(BaseModel):
    brand_id: uuid.UUID = Field(..., description="Brand ID")
    target_url: str | None = Field(None, description="Target page URL")
    focus_dimensions: list[str] | None = Field(None, description="Diagnosis dimensions to focus on")


class SchemaSuggestionResponse(BaseModel):
    id: uuid.UUID = Field(..., description="Suggestion ID")
    brand_id: uuid.UUID = Field(..., description="Brand ID")
    schema_type: str = Field(..., description="Schema type: Organization/Product/FAQPage/Article/LocalBusiness")
    target_url: str | None = Field(None, description="Target page URL")
    json_ld_template: dict = Field(..., description="JSON-LD template")
    json_ld_filled: dict | None = Field(None, description="Filled JSON-LD")
    priority: str = Field(default="medium", description="Priority: high/medium/low")
    status: str = Field(default="pending", description="Status: pending/applied/dismissed")
    diagnosis_dimensions: dict | None = Field(None, description="Diagnosis dimension data")
    implementation_difficulty: str = Field(default="medium", description="Implementation difficulty: easy/medium/hard")
    estimated_impact: str | None = Field(None, description="Expected impact description")
    validation_errors: dict | None = Field(None, description="Validation errors")
    created_at: datetime = Field(..., description="Creation time")
    updated_at: datetime = Field(..., description="Last update time")

    # Pydantic v2 style config; the class-based `Config` form is deprecated.
    model_config = {"from_attributes": True}


class SchemaSuggestionList(BaseModel):
    suggestions: list[SchemaSuggestionResponse] = Field(default_factory=list, description="Suggestion list")
    total: int = Field(..., description="Total count")


class SchemaValidationResult(BaseModel):
    is_valid: bool = Field(..., description="Whether the schema is valid")
    errors: list[str] = Field(default_factory=list, description="Errors")
    warnings: list[str] = Field(default_factory=list, description="Warnings")


class SchemaStatusUpdateRequest(BaseModel):
    status: str = Field(..., description="New status: pending/applied/dismissed")

View File

@ -0,0 +1,48 @@
import uuid
from datetime import datetime

from pydantic import BaseModel, Field


class TrendInsightRequest(BaseModel):
    brand_id: uuid.UUID = Field(..., description="Brand ID")
    period_days: int = Field(30, ge=7, le=365, description="Analysis period in days")
    platforms: list[str] | None = Field(None, description="Platforms to filter by")
    keywords: list[str] | None = Field(None, description="Keywords to filter by")


class TrendInsightResponse(BaseModel):
    id: uuid.UUID
    brand_id: uuid.UUID
    trend_type: str
    keyword: str | None
    platform: str | None
    period_start: datetime
    period_end: datetime
    data_points: list | None
    change_rate: float | None
    absolute_change: int | None
    sentiment_trend: dict | None
    cause_analysis: str | None
    recommendations: list | None
    confidence: float
    severity: str
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}


class TrendInsightList(BaseModel):
    items: list[TrendInsightResponse]
    total: int


class TrendSummary(BaseModel):
    brand_id: uuid.UUID
    period_days: int
    rising_count: int = 0
    declining_count: int = 0
    hotspot_count: int = 0
    top_keywords: list[str] = Field(default_factory=list)
    platform_highlights: dict = Field(default_factory=dict)

View File

@ -0,0 +1,17 @@
from .optimization_advisor import (
generate_suggestions,
generate_rule_based_suggestions,
generate_llm_suggestions,
build_context_from_scoring_result,
SuggestionItem,
BrandAnalysisContext,
)
__all__ = [
"generate_suggestions",
"generate_rule_based_suggestions",
"generate_llm_suggestions",
"build_context_from_scoring_result",
"SuggestionItem",
"BrandAnalysisContext",
]

View File

@ -25,7 +25,8 @@ from dataclasses import dataclass, field
 from typing import Any
 from app.config import settings
-from app.services.scoring_service import ScoringResultV2
+from app.services.scoring.scoring_service import ScoringResultV2
+from app.utils.json_extractor import extract_json
 logger = logging.getLogger(__name__)
@ -525,7 +526,7 @@ async def generate_llm_suggestions(
 raise ValueError("LLM returned an empty response")
 # Extract JSON
-json_str = _extract_json(content)
+json_str = extract_json(content)
 result = json.loads(json_str)
 # Parse suggestions
@ -573,32 +574,6 @@ async def generate_llm_suggestions(
 return generate_rule_based_suggestions(ctx)
-def _extract_json(text: str) -> str:
-    """Extract JSON from text."""
-    import re
-    # Try parsing the text directly
-    try:
-        json.loads(text)
-        return text
-    except json.JSONDecodeError:
-        pass
-    # Try extracting from a fenced code block
-    json_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```'
-    match = re.search(json_pattern, text)
-    if match:
-        return match.group(1).strip()
-    # Fall back to the span from the first "{" to the last "}"
-    first_brace = text.find('{')
-    last_brace = text.rfind('}')
-    if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
-        return text[first_brace:last_brace + 1]
-    raise ValueError(f"Could not extract JSON from response: {text[:200]}")
 # ============================================================
 # Main entry point: generate optimization suggestions
 # ============================================================
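
The helper removed in this hunk is now imported from `app.utils.json_extractor`, whose body is not part of the visible diff. The sketch below is only an illustration of what such a shared helper could look like, assuming it keeps the same three-stage fallback as the removed `_extract_json` (direct parse, fenced block, outermost braces) and still raises `ValueError`. The fence marker is built at runtime so the example can sit inside a fenced block itself.

```python
import json
import re

# Three backticks, built indirectly so this example can live inside a fence.
_FENCE = "`" * 3
_FENCED_BLOCK = re.compile(_FENCE + r"(?:json)?\s*([\s\S]*?)\s*" + _FENCE)


def extract_json(text: str) -> str:
    """Return the JSON portion of an LLM response as a string.

    Hypothetical stand-in for app.utils.json_extractor.extract_json.
    """
    # Stage 1: the whole response is already valid JSON.
    try:
        json.loads(text)
        return text
    except json.JSONDecodeError:
        pass

    # Stage 2: JSON wrapped in a fenced code block.
    match = _FENCED_BLOCK.search(text)
    if match:
        return match.group(1).strip()

    # Stage 3: everything between the first "{" and the last "}".
    first, last = text.find("{"), text.rfind("}")
    if first != -1 and last > first:
        return text[first:last + 1]

    raise ValueError(f"Could not extract JSON from response: {text[:200]!r}")


fenced = "Here you go:\n" + _FENCE + 'json\n{"a": 1}\n' + _FENCE
print(extract_json(fenced))
```

Raising `ValueError` in one place lets each caller decide how to surface the failure; the sentiment service in this commit, for example, re-raises it as `RuntimeError`.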

View File

@ -6,6 +6,10 @@ from .doubao import DoubaoAdapter
from .gemini import GeminiAdapter
from .kimi import KimiAdapter
from .perplexity import PerplexityAdapter
from .platform_bridge import (
execute_single_platform,
query_platform_raw,
)
from .qwen import QwenAdapter
from .wenxin import WenxinAdapter
from .yuanbao import YuanbaoAdapter
@ -25,4 +29,6 @@ __all__ = [
"QwenAdapter",
"GeminiAdapter",
"BatchQueryService",
"execute_single_platform",
"query_platform_raw",
]

View File

@ -0,0 +1,295 @@
import logging
import os
import re
import time
from urllib.parse import quote
from datetime import UTC, datetime

import httpx

from .base import AIEngineAdapter, AIQueryResult, EngineType

logger = logging.getLogger(__name__)

# Platforms that are backed by a real AI engine adapter.
_PLATFORM_NAME_MAP: dict[str, EngineType] = {
    "wenxin": EngineType.WENXIN,
    "kimi": EngineType.KIMI,
    "doubao": EngineType.DOUBAO,
    "tongyi": EngineType.QWEN,
    "deepseek": EngineType.DEEPSEEK,
    "chatgpt": EngineType.CHATGPT,
    "perplexity": EngineType.PERPLEXITY,
    "gemini": EngineType.GEMINI,
    "yuanbao": EngineType.YUANBAO,
}

# Platforms without an adapter; they are served from web search results.
_SEARCH_ONLY_PLATFORMS = {"qingyan", "tiangong", "xinghuo"}


def get_engine_type_for_platform(platform_name: str) -> EngineType | None:
    return _PLATFORM_NAME_MAP.get(platform_name)


def is_search_only_platform(platform_name: str) -> bool:
    return platform_name in _SEARCH_ONLY_PLATFORMS


async def search_wikipedia(keyword: str, max_chars: int = 2000) -> str:
    search_url = "https://zh.wikipedia.org/w/api.php"
    headers = {
        "User-Agent": "GEO-Citation-Bot/1.0 (contact@example.com)",
    }

    # Step 1: find the best-matching page title.
    async with httpx.AsyncClient(timeout=30) as client:
        search_resp = await client.get(
            search_url,
            headers=headers,
            params={
                "action": "query",
                "list": "search",
                "srsearch": keyword,
                "srlimit": 3,
                "format": "json",
                "origin": "*",
            },
        )
        search_resp.raise_for_status()
        search_data = search_resp.json()

    search_results = search_data.get("query", {}).get("search", [])
    if not search_results:
        return ""
    title = search_results[0]["title"]

    # Step 2: fetch a plain-text extract of that page.
    async with httpx.AsyncClient(timeout=30) as client:
        extract_resp = await client.get(
            search_url,
            headers=headers,
            params={
                "action": "query",
                "prop": "extracts",
                "titles": title,
                "explaintext": True,
                "exsentences": 15,
                "format": "json",
                "origin": "*",
            },
        )
        extract_resp.raise_for_status()
        extract_data = extract_resp.json()

    pages = extract_data.get("query", {}).get("pages", {})
    for page in pages.values():
        extract = page.get("extract", "")
        if extract:
            # Drop footnote markers such as "[12]" and collapse whitespace.
            extract = re.sub(r'\[\d+\]', '', extract)
            extract = re.sub(r'\s+', ' ', extract).strip()
            return extract[:max_chars]
    return ""


def _strip_html(raw: str) -> str:
    # Strip tags first so that escaped markup such as "&lt;b&gt;" is not
    # decoded into a tag and then removed.
    text = re.sub(r"<[^>]+>", "", raw)
    text = text.replace("&nbsp;", " ")
    text = text.replace("&quot;", '"')
    text = text.replace("&lt;", "<")
    text = text.replace("&gt;", ">")
    text = text.replace("&#39;", "'")
    # Decode "&amp;" last so "&amp;lt;" becomes "&lt;" rather than "<".
    text = text.replace("&amp;", "&")
    text = re.sub(r"\s+", " ", text).strip()
    return text
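
The entity handling in `_strip_html` could also be delegated to the standard library. The following is a minimal sketch, not the code in this commit: the name `strip_html` is hypothetical, and it assumes that decoding every HTML entity (not only the six listed above) is acceptable. Tags are removed before entities are decoded, so escaped markup in a snippet survives as text.

```python
import html
import re


def strip_html(raw: str) -> str:
    """Remove tags, then decode entities, then collapse whitespace."""
    # Remove real tags while entities are still escaped.
    text = re.sub(r"<[^>]+>", "", raw)
    # html.unescape decodes each entity exactly once.
    text = html.unescape(text)
    # \s also matches the non-breaking space produced by "&nbsp;".
    return re.sub(r"\s+", " ", text).strip()


print(strip_html("<b>Tom &amp; Jerry</b>"))
print(strip_html("use the &lt;b&gt; tag"))
```

The order matters: decoding first would turn `&lt;b&gt;` into `<b>`, which the tag pattern would then delete.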
async def search_duckduckgo(query: str, max_results: int = 5) -> str:
    url = f"https://html.duckduckgo.com/html/?q={quote(query)}"
    headers = {
        "User-Agent": (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
        ),
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
        "Accept-Language": "zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7",
    }
    try:
        async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
            resp = await client.get(url, headers=headers)
            resp.raise_for_status()
            html = resp.text

        if (
            "web-result" not in html
            and "result__snippet" not in html
            and "result__title" not in html
        ):
            raise RuntimeError("DuckDuckGo returned a non-result page")

        results: list[str] = []

        # First pass: match title and snippet together inside each result block.
        result_blocks = re.findall(
            r'<div class="result[^"]*"[^>]*>.*?<h[^>]*class="result__title"[^>]*>.*?<a[^>]*>(.*?)</a>.*?</h[^>]*>.*?<a[^>]*class="result__snippet"[^>]*>(.*?)</a>.*?</div>',
            html,
            re.DOTALL | re.IGNORECASE,
        )
        if result_blocks:
            for title_raw, snippet_raw in result_blocks[:max_results]:
                title = _strip_html(title_raw)
                snippet = _strip_html(snippet_raw)
                if title or snippet:
                    results.append(f"{title}\n{snippet}")

        # Second pass: collect titles and snippets separately and pair them by index.
        if not results:
            snippets = re.findall(
                r'<a[^>]*class="result__snippet"[^>]*>(.*?)</a>', html, re.DOTALL | re.IGNORECASE
            )
            titles = re.findall(
                r'<h[^>]*class="result__title"[^>]*>.*?<a[^>]*>(.*?)</a>.*?</h[^>]*>',
                html,
                re.DOTALL | re.IGNORECASE,
            )
            for i in range(min(len(titles), len(snippets), max_results)):
                title = _strip_html(titles[i])
                snippet = _strip_html(snippets[i])
                if title or snippet:
                    results.append(f"{title}\n{snippet}")

        if results:
            return "\n\n".join(results)
        raise RuntimeError("No results parsed from DuckDuckGo")
    except Exception as e:
        # Any failure above (network, HTTP status, parsing) falls back to Wikipedia.
        logger.warning("DuckDuckGo search failed: %s; falling back to Wikipedia", e)
        wiki_text = await search_wikipedia(query, max_chars=2000)
        if wiki_text:
            return wiki_text
        raise RuntimeError(f"All search sources failed: {e}") from e


async def fetch_search_content(platform_name: str, keyword: str) -> str:
    logger.info("[%s] search query: %s", platform_name, keyword)
    return await search_duckduckgo(keyword, max_results=5)


class SearchOnlyAdapter(AIEngineAdapter):
    def __init__(self, platform_name: str, **kwargs):
        self._platform_name = platform_name
        super().__init__(**kwargs)

    def get_engine_type(self) -> EngineType:
        # Placeholder value: search-only platforms have no entry in
        # _PLATFORM_NAME_MAP; the real platform name is carried in
        # AIQueryResult.metadata["platform_name"].
        return EngineType.DEEPSEEK

    def _get_env_key(self) -> str | None:
        return ""

    async def query(
        self,
        query: str,
        brand_name: str,
        competitor_names: list[str] | None = None,
    ) -> AIQueryResult:
        start_time = time.perf_counter()
        content = await fetch_search_content(self._platform_name, query)
        elapsed_ms = int((time.perf_counter() - start_time) * 1000)
        has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
            content, brand_name, competitor_names
        )
        return AIQueryResult(
            engine_type=self.get_engine_type(),
            query=query,
            raw_response=content,
            citations=[],
            has_brand_citation=has_brand,
            has_competitor_citation=has_comp,
            brand_context=brand_ctx,
            competitor_contexts=comp_ctx,
            response_time_ms=elapsed_ms,
            timestamp=datetime.now(UTC),
            metadata={"platform_name": self._platform_name, "mode": "search_only"},
        )


async def query_platform_raw(
    platform_name: str,
    keyword: str,
    brand_name: str = "",
    competitor_names: list[str] | None = None,
) -> str:
    from .batch_query import _build_adapters

    # Search-only platforms bypass the adapters entirely.
    if is_search_only_platform(platform_name):
        content = await fetch_search_content(platform_name, keyword)
        return f"[data_source: search_engine]\n{content}"

    engine_type = get_engine_type_for_platform(platform_name)
    if engine_type is None:
        raise ValueError(f"Unsupported platform: {platform_name}")

    adapters = _build_adapters()
    adapter = adapters.get(engine_type.value)
    if adapter is None:
        raise ValueError(f"No adapter registered for platform {platform_name}")

    result = await adapter.query(keyword, brand_name, competitor_names)
    return f"[data_source: ai_platform]\n{result.raw_response}"


_SUPPORTED_PLATFORMS = {
    "wenxin", "kimi", "doubao", "tongyi",
    "qingyan", "tiangong", "xinghuo",
}


async def execute_single_platform(
    keyword: str,
    platform: str,
    target_brand: str,
    brand_aliases: list,
) -> dict:
    if platform not in _SUPPORTED_PLATFORMS:
        raise ValueError(f"Unsupported platform: {platform}")

    from app.workers.citation_extractor import analyze_citations

    search_keyword = f"{keyword} {target_brand}"
    raw_response = await query_platform_raw(
        platform_name=platform,
        keyword=search_keyword,
        brand_name=target_brand,
    )
    citation_analysis = analyze_citations(raw_response)

    from app.workers.citation_engine import BrandMatcher, CompetitorDetector

    matcher = BrandMatcher(target_brand=target_brand, brand_aliases=brand_aliases)
    match_result = matcher.match(citation_analysis.clean_response)

    competitor_detector = CompetitorDetector()
    competitor_brands = competitor_detector.detect(
        citation_analysis.clean_response, target_brand
    )

    source_urls = [
        c.source_url for c in citation_analysis.citations if c.source_url
    ]
    source_titles = [
        c.source_title for c in citation_analysis.citations if c.source_title
    ]
    citation_contexts = [
        c.citation_context for c in citation_analysis.citations if c.citation_context
    ]

    return {
        "cited": match_result["cited"],
        "confidence": match_result["confidence"],
        "match_type": match_result["match_type"],
        "position": match_result["position"],
        "citation_text": match_result["citation_text"],
        "competitor_brands": competitor_brands,
        "raw_response": raw_response,
        "data_source": citation_analysis.data_source,
        "source_urls": source_urls,
        "source_titles": source_titles,
        "citation_contexts": citation_contexts,
        "ai_response_text": citation_analysis.clean_response,
    }
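
The routing rule in `query_platform_raw` can be reduced to a pure function for illustration. This is a simplified sketch, not the module's code: `data_source_for` and `tag_response` are hypothetical names, the platform sets are copied from the module above, and the adapter lookup and network calls are left out.

```python
SEARCH_ONLY_PLATFORMS = {"qingyan", "tiangong", "xinghuo"}
AI_PLATFORMS = {
    "wenxin", "kimi", "doubao", "tongyi", "deepseek",
    "chatgpt", "perplexity", "gemini", "yuanbao",
}


def data_source_for(platform_name: str) -> str:
    """Return the data_source tag a platform's response would carry."""
    # Search-only platforms are checked first, as in query_platform_raw.
    if platform_name in SEARCH_ONLY_PLATFORMS:
        return "search_engine"
    if platform_name in AI_PLATFORMS:
        return "ai_platform"
    raise ValueError(f"Unsupported platform: {platform_name}")


def tag_response(platform_name: str, content: str) -> str:
    """Prefix the content with its data_source header line."""
    return f"[data_source: {data_source_for(platform_name)}]\n{content}"


print(tag_response("qingyan", "search snippet"))
print(tag_response("kimi", "model answer"))
```

Note that `execute_single_platform` accepts only the seven names in `_SUPPORTED_PLATFORMS`, so a platform such as `deepseek` can be routed by `query_platform_raw` but is rejected one level up.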

View File

@ -0,0 +1,11 @@
from .alert_engine import (
AlertEngine,
AlertContext,
DEFAULT_ALERT_CONFIGS,
)
__all__ = [
"AlertEngine",
"AlertContext",
"DEFAULT_ALERT_CONFIGS",
]

View File

@ -0,0 +1,13 @@
from .sentiment_service import (
SentimentAnalysisService,
SentimentResult,
SentimentCache,
get_sentiment_service,
)
__all__ = [
"SentimentAnalysisService",
"SentimentResult",
"SentimentCache",
"get_sentiment_service",
]

View File

@ -5,11 +5,11 @@ import asyncio
 import hashlib
 import json
 import logging
-import re
 import time
 from typing import Optional
 from app.config import settings
+from app.utils.json_extractor import extract_json
 logger = logging.getLogger(__name__)
@ -276,31 +276,11 @@ class SentimentAnalysisService:
             raise RuntimeError("API returned an empty response")
         # Extract JSON
-        json_str = self._extract_json(content)
-        return json.loads(json_str)
-    def _extract_json(self, text: str) -> str:
-        """Extract JSON from text."""
-        # Try parsing the text directly
         try:
-            json.loads(text)
-            return text
-        except json.JSONDecodeError:
-            pass
-        # Try extracting from a fenced code block
-        json_pattern = r"```(?:json)?\s*([\s\S]*?)\s*```"
-        match = re.search(json_pattern, text)
-        if match:
-            return match.group(1).strip()
-        # Fall back to the span from the first "{" to the last "}"
-        first_brace = text.find("{")
-        last_brace = text.rfind("}")
-        if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
-            return text[first_brace : last_brace + 1]
-        raise RuntimeError(f"Could not extract JSON from response: {text[:200]}")
+            json_str = extract_json(content)
+        except ValueError as e:
+            raise RuntimeError(str(e)) from e
+        return json.loads(json_str)
     def _parse_response(self, response: dict) -> SentimentResult:
         """Parse the API response."""

View File

@ -0,0 +1,4 @@
from app.services.attribution.attribution_engine import AttributionEngine
from app.services.attribution.roi_calculator import ROICalculator
__all__ = ["AttributionEngine", "ROICalculator"]

View File

@ -0,0 +1,150 @@
import logging
import uuid
from datetime import UTC, datetime, timedelta

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.attribution_record import AttributionRecord
from app.models.diagnosis_record import DiagnosisRecord

logger = logging.getLogger(__name__)


class AttributionEngine:
    async def start_tracking(
        self,
        db: AsyncSession,
        brand_id: uuid.UUID,
        content_id: uuid.UUID | None,
        user_id: str,
    ) -> AttributionRecord:
        # The latest completed diagnosis score becomes the baseline.
        baseline_score = await self._get_latest_score(db, brand_id)
        now = datetime.now(UTC)
        record = AttributionRecord(
            user_id=user_id,
            brand_id=brand_id,
            content_id=content_id,
            baseline_score=baseline_score,
            published_at=now,
            # Attribution window: 28 days from publication.
            window_end_at=now + timedelta(days=28),
            status="tracking",
        )
        db.add(record)
        await db.commit()
        await db.refresh(record)
        return record

    async def check_attribution(
        self,
        db: AsyncSession,
        record_id: uuid.UUID,
    ) -> AttributionRecord:
        stmt = select(AttributionRecord).where(AttributionRecord.id == record_id)
        result = await db.execute(stmt)
        record = result.scalar_one_or_none()
        if not record:
            raise ValueError(f"AttributionRecord {record_id} not found")

        current_score = await self._get_latest_score(db, record.brand_id)
        record.current_score = current_score
        record.score_delta = round(current_score - record.baseline_score, 2)

        # Baseline: last diagnosis completed before publication; current: latest overall.
        baseline_dims = await self._get_latest_dimensions(db, record.brand_id, record.published_at)
        current_dims = await self._get_latest_dimensions(db, record.brand_id, None)
        if baseline_dims and current_dims:
            record.attributed_dimensions = self._compute_dimension_deltas(
                baseline_dims, current_dims
            )

        now = datetime.now(UTC)
        if record.window_end_at:
            window_end = record.window_end_at
            # Treat naive timestamps from the database as UTC.
            if window_end.tzinfo is None:
                window_end = window_end.replace(tzinfo=UTC)
            if now >= window_end:
                record.status = "completed"
            elif record.score_delta and record.score_delta > 0:
                record.status = "tracking"

        await db.commit()
        await db.refresh(record)
        return record

    async def get_brand_attribution_summary(
        self,
        db: AsyncSession,
        brand_id: uuid.UUID,
    ) -> dict:
        stmt = (
            select(AttributionRecord)
            .where(AttributionRecord.brand_id == brand_id)
            .order_by(AttributionRecord.created_at.desc())
        )
        result = await db.execute(stmt)
        records = result.scalars().all()

        total_delta = sum(r.score_delta or 0 for r in records)
        tracking_count = sum(1 for r in records if r.status == "tracking")
        completed_count = sum(1 for r in records if r.status == "completed")
        positive_count = sum(1 for r in records if (r.score_delta or 0) > 0)

        return {
            "brand_id": str(brand_id),
            "records": records,
            "total_score_delta": round(total_delta, 2),
            "tracking_count": tracking_count,
            "completed_count": completed_count,
            "positive_count": positive_count,
        }

    async def _get_latest_score(self, db: AsyncSession, brand_id: uuid.UUID) -> float:
        stmt = (
            select(DiagnosisRecord)
            .where(
                DiagnosisRecord.brand_id == brand_id,
                DiagnosisRecord.status == "completed",
            )
            .order_by(DiagnosisRecord.completed_at.desc())
            .limit(1)
        )
        result = await db.execute(stmt)
        record = result.scalar_one_or_none()
        if record and record.overall_score is not None:
            return float(record.overall_score)
        logger.warning("No completed DiagnosisRecord for brand %s, using 0 as baseline", brand_id)
        return 0.0

    async def _get_latest_dimensions(
        self,
        db: AsyncSession,
        brand_id: uuid.UUID,
        before: datetime | None,
    ) -> list | None:
        stmt = (
            select(DiagnosisRecord)
            .where(
                DiagnosisRecord.brand_id == brand_id,
                DiagnosisRecord.status == "completed",
            )
            .order_by(DiagnosisRecord.completed_at.desc())
            .limit(1)
        )
        if before:
            stmt = stmt.where(DiagnosisRecord.completed_at < before)
        result = await db.execute(stmt)
        record = result.scalar_one_or_none()
        if record and record.result_json:
            # "dimensions" is consumed as a list of dicts by _compute_dimension_deltas.
            return record.result_json.get("dimensions")
        return None

    def _compute_dimension_deltas(self, before_dims: list, after_dims: list) -> dict:
        before_map = {d.get("name"): d.get("score", 0) for d in before_dims}
        after_map = {d.get("name"): d.get("score", 0) for d in after_dims}
        deltas = {}
        # Only dimensions present in the "after" snapshot are reported.
        for name in after_map:
            b = before_map.get(name, 0)
            a = after_map[name]
            deltas[name] = {"before": b, "after": a, "delta": round(a - b, 2)}
        return deltas
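
To make the behaviour of `_compute_dimension_deltas` concrete, here is a standalone sketch of the same logic as a free function. The function name `dimension_deltas` and the sample dimension names are made up for illustration; only the `name`/`score` keys come from the code above.

```python
def dimension_deltas(before_dims: list[dict], after_dims: list[dict]) -> dict:
    """Per-dimension score change, keyed by dimension name."""
    before_map = {d.get("name"): d.get("score", 0) for d in before_dims}
    after_map = {d.get("name"): d.get("score", 0) for d in after_dims}
    deltas = {}
    # Iterate over the "after" snapshot only, as the engine does.
    for name, after in after_map.items():
        before = before_map.get(name, 0)
        deltas[name] = {"before": before, "after": after, "delta": round(after - before, 2)}
    return deltas


before = [{"name": "content", "score": 10}, {"name": "authority", "score": 5}]
after = [{"name": "content", "score": 14.5}, {"name": "schema", "score": 3}]
print(dimension_deltas(before, after))
```

A dimension that is new in the later diagnosis (`schema`) is compared against a baseline of 0, while one that exists only in the baseline (`authority`) is dropped from the result.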

View File

@ -0,0 +1,59 @@
from app.models.attribution_record import AttributionRecord


class ROICalculator:
    # Value attributed to each point of score improvement (see calculate_roi).
    INDUSTRY_AVG_CITATION_VALUE = 50.0

    def calculate_roi(
        self,
        subscription_cost: float,
        score_delta: float,
        attribution_records: list[AttributionRecord],
    ) -> dict:
        """Estimate ROI from the score improvement.

        attribution_records is accepted but not used by the current formula.
        """
        value_generated = score_delta * self.INDUSTRY_AVG_CITATION_VALUE
        if subscription_cost > 0:
            roi_percentage = round(
                (value_generated - subscription_cost) / subscription_cost * 100, 2
            )
        else:
            # Avoid division by zero for free plans.
            roi_percentage = 0.0
        break_even_delta = self.estimate_break_even(subscription_cost)
        return {
            "roi_percentage": roi_percentage,
            "value_generated": round(value_generated, 2),
            "cost": subscription_cost,
            "break_even_delta": round(break_even_delta, 2),
        }

    def generate_ab_comparison(
        self,
        before_score: float,
        after_score: float,
        before_dimensions: dict,
        after_dimensions: dict,
    ) -> dict:
        """Compare overall and per-dimension scores before and after."""
        overall_delta = round(after_score - before_score, 2)
        dimensions = []
        # Sorted so the output order is deterministic.
        all_names = sorted(set(before_dimensions) | set(after_dimensions))
        for name in all_names:
            b = before_dimensions.get(name, {}).get("score", 0)
            a = after_dimensions.get(name, {}).get("score", 0)
            delta = round(a - b, 2)
            dimensions.append({
                "name": name,
                "before": b,
                "after": a,
                "delta": delta,
                "improved": delta > 0,
            })
        return {
            "overall_before": before_score,
            "overall_after": after_score,
            "overall_delta": overall_delta,
            "dimensions": dimensions,
        }

    def estimate_break_even(self, subscription_cost: float) -> float:
        """Return the score delta at which generated value equals the cost."""
        if self.INDUSTRY_AVG_CITATION_VALUE == 0:
            return 0.0
        return subscription_cost / self.INDUSTRY_AVG_CITATION_VALUE
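A worked example of the ROI arithmetic above, written as a standalone sketch so it runs without the app's models. The helper name `roi_summary` and the sample cost of 299 are assumptions for illustration; the constant mirrors `INDUSTRY_AVG_CITATION_VALUE`.

```python
CITATION_VALUE = 50.0  # mirrors ROICalculator.INDUSTRY_AVG_CITATION_VALUE


def roi_summary(subscription_cost: float, score_delta: float) -> dict:
    """Reproduce the calculate_roi formula without the unused attribution records."""
    value_generated = score_delta * CITATION_VALUE
    if subscription_cost > 0:
        roi = round((value_generated - subscription_cost) / subscription_cost * 100, 2)
    else:
        roi = 0.0  # free plan: ROI is reported as 0 rather than dividing by zero
    return {
        "roi_percentage": roi,
        "value_generated": round(value_generated, 2),
        "break_even_delta": round(subscription_cost / CITATION_VALUE, 2),
    }


# Hypothetical plan: cost 299, score improved by 10 points.
paid = roi_summary(subscription_cost=299.0, score_delta=10.0)
free = roi_summary(subscription_cost=0.0, score_delta=10.0)
```

With these inputs the generated value is 10 × 50 = 500, so ROI is (500 − 299) / 299 ≈ 67.22%, and the plan breaks even at a delta of 299 / 50 = 5.98 points.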

View File

@ -0,0 +1,34 @@
from .citation import (
    get_citations,
    get_citation_stats,
    trigger_query_now,
    export_citations_pdf,
    export_citations_csv,
    PLATFORM_NAMES,
)
from .citation_pattern import (
    CitationPatternEngine,
    CitationPattern,
    PatternAnalysisReport,
    ContentStructureAnalyzer,
    AuthoritySignalAnalyzer,
    CitationFormatAnalyzer,
    EnginePreferenceAnalyzer,
)

__all__ = [
    "get_citations",
    "get_citation_stats",
    "trigger_query_now",
    "export_citations_pdf",
    "export_citations_csv",
    "PLATFORM_NAMES",
    "CitationPatternEngine",
    "CitationPattern",
    "PatternAnalysisReport",
    "ContentStructureAnalyzer",
    "AuthoritySignalAnalyzer",
    "CitationFormatAnalyzer",
    "EnginePreferenceAnalyzer",
]

View File

@ -1,9 +1,15 @@
from __future__ import annotations
import asyncio
import csv
import io
import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.services.scoring.scoring_service import ScoringResultV2
from sqlalchemy import func, select, and_, cast, Integer
from sqlalchemy.ext.asyncio import AsyncSession
@ -13,7 +19,7 @@ from app.database import AsyncSessionLocal
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.models.query_task import QueryTask
-from app.workers.citation_engine import CitationEngine
+from app.services.ai_engine.platform_bridge import execute_single_platform as _execute_single_platform_bridge
logger = logging.getLogger(__name__)
@ -304,8 +310,10 @@ async def _execute_query_tasks(
brand_aliases: list,
user_id: uuid.UUID | None = None,
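):
-    """Run query tasks in the background."""
-    engine = CitationEngine()
+    """Run query tasks in the background via the Agent framework, falling back to the direct engine on failure."""
+    from app.agent_framework.agents.citation_detector import CitationDetectorAgent
+    agent = CitationDetectorAgent()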
try:
async with AsyncSessionLocal() as db:
# Verify that the query belongs to this user
@ -330,7 +338,8 @@ async def _execute_query_tasks(
task.error_message = None
await db.commit()
-citation_result = await engine.execute_single_platform(
+citation_result = await _execute_single_platform_via_agent(
+agent=agent,
keyword=keyword,
platform=task.platform,
target_brand=target_brand,
@ -338,16 +347,10 @@ async def _execute_query_tasks(
)
if citation_result:
-record = CitationRecord(
+record = CitationRecord.from_citation_result(
query_id=query_id,
platform=task.platform,
-cited=citation_result.get("cited", False),
-citation_position=citation_result.get("position"),
-citation_text=citation_result.get("citation_text"),
-competitor_brands=citation_result.get("competitor_brands", []),
-raw_response=citation_result.get("raw_response", ""),
-confidence=citation_result.get("confidence"),
-match_type=citation_result.get("match_type"),
+result=citation_result,
)
db.add(record)
@ -366,7 +369,34 @@ async def _execute_query_tasks(
except Exception as e:
logger.error(f"Query engine execution failed: {e}")
finally:
-await engine.close()
+await agent.close()


async def _execute_single_platform_via_agent(
    agent,
    keyword: str,
    platform: str,
    target_brand: str,
    brand_aliases: list,
) -> dict:
    """Run single-platform detection via the agent; fall back to the direct engine on any error."""
    try:
        return await agent.execute_single_platform_compat(
            keyword=keyword,
            platform=platform,
            target_brand=target_brand,
            brand_aliases=brand_aliases,
        )
    except Exception as agent_err:
        # The two message parts were previously concatenated with no separator.
        logger.warning(
            "Agent framework single-platform detection failed (%s): %s; "
            "falling back to the direct engine",
            platform,
            agent_err,
        )
        return await _execute_single_platform_bridge(
            keyword=keyword,
            platform=platform,
            target_brand=target_brand,
            brand_aliases=brand_aliases,
        )


PLATFORM_NAMES = {
@ -386,6 +416,7 @@ async def export_citations_pdf(
db: AsyncSession,
user_id: uuid.UUID,
query_id: uuid.UUID | None = None,
v2_result: ScoringResultV2 | None = None,
) -> bytes:
"""Generate the report in PDF format."""
import os
@ -505,6 +536,20 @@ async def export_citations_pdf(
pdf.cell(col_widths[i], 7, d, border=1)
pdf.ln()
if v2_result is not None:
pdf.add_page()
pdf.set_font_size(16)
pdf.cell(0, 12, "四、V2 品牌可见性评分", new_x="LMARGIN", new_y="NEXT")
pdf.set_font_size(11)
pdf.cell(0, 8, f"综合评分: {v2_result.overall_score:.2f}/100", new_x="LMARGIN", new_y="NEXT")
pdf.cell(0, 8, f"健康等级: {v2_result.health_level}", new_x="LMARGIN", new_y="NEXT")
pdf.ln(5)
pdf.cell(0, 8, f"提及率: {v2_result.mention_rate.score:.2f}/{v2_result.mention_rate.max_score:.0f} ({v2_result.mention_rate.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT")
pdf.cell(0, 8, f"推荐排名: {v2_result.recommendation_rank.score:.2f}/{v2_result.recommendation_rank.max_score:.0f} ({v2_result.recommendation_rank.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT")
pdf.cell(0, 8, f"情感倾向: {v2_result.sentiment_score.score:.2f}/{v2_result.sentiment_score.max_score:.0f} ({v2_result.sentiment_score.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT")
pdf.cell(0, 8, f"引用质量: {v2_result.citation_quality.score:.2f}/{v2_result.citation_quality.max_score:.0f} ({v2_result.citation_quality.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT")
pdf.cell(0, 8, f"竞品对比: {v2_result.competitive_position.score:.2f}/{v2_result.competitive_position.max_score:.0f} ({v2_result.competitive_position.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT")
return pdf.output()
@ -512,6 +557,7 @@ async def export_citations_csv(
db: AsyncSession,
user_id: uuid.UUID,
query_id: uuid.UUID,
v2_result: ScoringResultV2 | None = None,
) -> str:
query = await _verify_query_ownership(db, query_id, user_id)
if query is None:
@ -527,7 +573,7 @@ async def export_citations_csv(
output = io.StringIO()
writer = csv.writer(output)
-writer.writerow([
+headers = [
"查询关键词",
"目标品牌",
"查询日期",
@ -538,7 +584,18 @@ async def export_citations_csv(
"匹配置信度",
"匹配类型",
"竞争品牌",
-])
+]
if v2_result is not None:
headers.extend([
"overall_score",
"health_level",
"mention_rate",
"recommendation_rank",
"sentiment_score",
"citation_quality",
"competitive_position",
])
writer.writerow(headers)
total_queries = len(records)
total_citations = 0
@ -570,7 +627,7 @@ async def export_citations_csv(
if record.confidence is not None:
confidence_str = f"{record.confidence:.2f}"
-writer.writerow([
+row = [
query.keyword,
query.target_brand,
date_str,
@ -581,7 +638,18 @@ async def export_citations_csv(
confidence_str,
match_type_display,
", ".join(record.competitor_brands) if record.competitor_brands else "",
-])
+]
if v2_result is not None:
row.extend([
round(v2_result.overall_score, 2),
v2_result.health_level,
round(v2_result.mention_rate.score, 2),
round(v2_result.recommendation_rank.score, 2),
round(v2_result.sentiment_score.score, 2),
round(v2_result.citation_quality.score, 2),
round(v2_result.competitive_position.score, 2),
])
writer.writerow(row)
# Summary statistics
writer.writerow([])
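The agent-first, engine-second behaviour introduced by `_execute_single_platform_via_agent` in this diff can be sketched as a generic helper. This is an illustrative sketch only: `run_with_fallback`, `flaky_agent`, and `direct_engine` are hypothetical names standing in for the agent call and the platform bridge.

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def run_with_fallback(primary, fallback, **kwargs) -> dict:
    """Await primary(**kwargs); on any exception, log it and await fallback(**kwargs)."""
    try:
        return await primary(**kwargs)
    except Exception as err:
        logger.warning("primary path failed (%s); falling back", err)
        return await fallback(**kwargs)


async def flaky_agent(**kwargs) -> dict:
    # Hypothetical stand-in for the agent call; always fails here.
    raise RuntimeError("agent unavailable")


async def direct_engine(**kwargs) -> dict:
    # Hypothetical stand-in for the direct engine bridge.
    return {"cited": True, "platform": kwargs["platform"]}


result = asyncio.run(
    run_with_fallback(flaky_agent, direct_engine, keyword="demo", platform="demo_platform")
)
```

Catching bare `Exception` keeps the fallback broad, which matches the diff, but it also hides programming errors in the primary path; narrowing the exception types is worth considering once the agent's failure modes are known.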

View File

@ -0,0 +1,749 @@
import json
import logging
import uuid
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Awaitable, Callable

from sqlalchemy import select, func, and_
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import AsyncSessionLocal
from app.models.brand import Brand
from app.models.competitor import Competitor
from app.models.competitor_insight import CompetitorInsight
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.services.llm import LLMFactory, LLMError
from app.utils.json_extractor import extract_json

logger = logging.getLogger(__name__)

VALID_ANALYSIS_TYPES = [
    "citation_gap",
    "content_strategy",
    "platform_coverage",
    "query_overlap",
    "differentiation",
]


class CompetitorAnalyzerService:
    async def analyze_competitor(
        self,
        brand_id: uuid.UUID,
        analysis_types: list[str] | None = None,
        period_days: int = 30,
        # The callback is awaited below, so it must be an async callable.
        progress_callback: Callable[[float, str], Awaitable[None]] | None = None,
    ) -> dict:
        if analysis_types is None:
            analysis_types = VALID_ANALYSIS_TYPES
        invalid = set(analysis_types) - set(VALID_ANALYSIS_TYPES)
        if invalid:
            raise ValueError(f"不支持的分析类型: {', '.join(invalid)}")

        async with AsyncSessionLocal() as session:
            brand = await session.get(Brand, brand_id)
            if not brand:
                raise ValueError(f"品牌不存在: {brand_id}")

            if progress_callback:
                await progress_callback(0.1, "获取竞品列表...")
            competitors = await self._get_competitors(session, brand_id)
            if not competitors:
                raise ValueError("未找到竞品数据")

            if progress_callback:
                await progress_callback(0.2, "聚合品牌引用数据...")
            brand_citation_data = await self._aggregate_citation_data(
                session, brand_id, brand.name, period_days,
            )

            results = []
            total = len(competitors)
            for i, competitor in enumerate(competitors):
                if progress_callback:
                    progress = 0.25 + (0.5 * i / total)
                    await progress_callback(progress, f"分析竞品 {competitor.name}...")
                competitor_citation_data = await self._aggregate_citation_data(
                    session, brand_id, competitor.name, period_days,
                )
                for analysis_type in analysis_types:
                    insight = await self._build_insight(
                        session=session,
                        brand_id=brand_id,
                        brand_name=brand.name,
                        competitor=competitor,
                        analysis_type=analysis_type,
                        brand_data=brand_citation_data,
                        competitor_data=competitor_citation_data,
                        period_days=period_days,
                    )
                    session.add(insight)
                    results.append(insight)

            await session.commit()
            for r in results:
                await session.refresh(r)

            return {
                "brand_id": str(brand_id),
                "brand_name": brand.name,
                "insights": [
                    {
                        "id": str(r.id),
                        "competitor_name": r.competitor_name,
                        "analysis_type": r.analysis_type,
                        "citation_count_brand": r.citation_count_brand,
                        "citation_count_competitor": r.citation_count_competitor,
                        "sentiment_brand": r.sentiment_brand,
                        "sentiment_competitor": r.sentiment_competitor,
                        "platform_breakdown": r.platform_breakdown,
                        "gap_analysis": r.gap_analysis,
                        "opportunity_areas": r.opportunity_areas,
                        "recommendations": r.recommendations,
                        "confidence": r.confidence,
                        "period_days": r.period_days,
                        "created_at": r.created_at.isoformat() if r.created_at else None,
                    }
                    for r in results
                ],
                "total": len(results),
            }

async def compare_citation_volume(
self,
brand_data: dict,
competitor_data: dict,
) -> dict:
brand_count = brand_data["citation_count"]
competitor_count = competitor_data["citation_count"]
total = brand_count + competitor_count
return {
"brand": brand_count,
"competitor": competitor_count,
"diff": brand_count - competitor_count,
"brand_share": round(brand_count / total, 4) if total > 0 else 0.0,
"competitor_share": round(competitor_count / total, 4) if total > 0 else 0.0,
"by_platform": self._compare_platform_citations(
brand_data, competitor_data,
),
}
async def compare_citation_quality(
self,
brand_data: dict,
competitor_data: dict,
) -> dict:
brand_positive = brand_data.get("positive_ratio", 0.0)
competitor_positive = competitor_data.get("positive_ratio", 0.0)
brand_rank = brand_data.get("avg_rank", 0.0)
competitor_rank = competitor_data.get("avg_rank", 0.0)
return {
"sentiment": {
"brand_positive_ratio": brand_positive,
"competitor_positive_ratio": competitor_positive,
"diff": round(brand_positive - competitor_positive, 4),
},
"ranking": {
"brand_avg_rank": brand_rank,
"competitor_avg_rank": competitor_rank,
"diff": round(brand_rank - competitor_rank, 2),
},
"brand_sentiment_breakdown": brand_data.get("sentiment_breakdown", {}),
"competitor_sentiment_breakdown": competitor_data.get("sentiment_breakdown", {}),
}
async def analyze_content_strategy(
self,
brand_data: dict,
competitor_data: dict,
) -> dict:
brand_types = brand_data.get("content_types", {})
competitor_types = competitor_data.get("content_types", {})
all_types = set(brand_types.keys()) | set(competitor_types.keys())
type_comparison = {}
for ct in all_types:
type_comparison[ct] = {
"brand": brand_types.get(ct, 0),
"competitor": competitor_types.get(ct, 0),
}
competitor_only_types = set(competitor_types.keys()) - set(brand_types.keys())
brand_only_types = set(brand_types.keys()) - set(competitor_types.keys())
return {
"type_comparison": type_comparison,
"competitor_unique_types": list(competitor_only_types),
"brand_unique_types": list(brand_only_types),
"competitor_top_types": sorted(
competitor_types.items(), key=lambda x: x[1], reverse=True,
)[:5],
}
async def identify_opportunities(
self,
brand_data: dict,
competitor_data: dict,
comparison: dict,
) -> dict:
opportunities = []
brand_platforms = set(brand_data["by_platform"].keys())
competitor_platforms = set(competitor_data["by_platform"].keys())
brand_only = brand_platforms - competitor_platforms
if brand_only:
for platform in brand_only:
opportunities.append({
"area": f"platform_{platform}",
"description": f"品牌在{platform}平台有引用而竞品没有,可加大投入建立差异化优势",
"potential": "high",
"action": f"增加在{platform}平台的内容投放和优化",
})
competitor_only = competitor_platforms - brand_platforms
if competitor_only:
for platform in competitor_only:
opportunities.append({
"area": f"platform_{platform}",
"description": f"竞品在{platform}平台有引用而品牌没有,存在进入机会",
"potential": "medium",
"action": f"研究{platform}平台的内容偏好,制定进入策略",
})
citation_volume = comparison.get("citation_volume", {})
if citation_volume.get("diff", 0) > 0:
opportunities.append({
"area": "citation_volume_advantage",
"description": "品牌引用量高于竞品,可强化品牌权威性传播",
"potential": "high",
"action": "收集高引用内容案例,扩大品牌影响力",
})
quality = comparison.get("quality", {})
sentiment = quality.get("sentiment", {})
if sentiment.get("diff", 0) > 0.1:
opportunities.append({
"area": "sentiment_advantage",
"description": "品牌正面引用比例显著高于竞品,可强化正面形象传播",
"potential": "high",
"action": "收集正面引用案例,制作品牌优势内容",
})
if not opportunities:
opportunities.append({
"area": "general",
"description": "当前数据未发现明显差异化机会,建议持续监测并积累更多数据",
"potential": "low",
"action": "增加查询频率和覆盖平台,积累更多引用数据",
})
return {
"opportunities": opportunities,
"total_opportunities": len(opportunities),
"high_potential_count": sum(1 for o in opportunities if o["potential"] == "high"),
}

    async def generate_recommendations(
        self,
        brand_name: str,
        competitor_name: str,
        comparison: dict,
        gaps: dict,
        opportunities: dict,
        data_sufficiency: str,
    ) -> dict:
        # The prompt stays in Chinese so that the LLM output matches the
        # Chinese fallback strategies produced by _default_strategy.
        prompt = f"""你是一个专业的GEO(Generative Engine Optimization)策略分析师。
请基于以下品牌与竞品的引用对比数据,生成策略建议。

品牌: {brand_name}
竞品: {competitor_name}
数据充分性: {data_sufficiency}

对比数据:
{json.dumps(comparison, ensure_ascii=False, indent=2)}

差距分析:
{json.dumps(gaps, ensure_ascii=False, indent=2)}

机会发现:
{json.dumps(opportunities, ensure_ascii=False, indent=2)}

请返回JSON格式(不要包含其他文字):
{{
  "gap_closing_strategies": [
    {{"strategy": "策略描述", "priority": "high/medium/low", "expected_impact": "预期效果"}}
  ],
  "differentiation_strategies": [
    {{"strategy": "策略描述", "priority": "high/medium/low", "expected_impact": "预期效果"}}
  ],
  "quick_wins": [
    {{"action": "行动描述", "effort": "low/medium/high", "timeline": "预计时间"}}
  ],
  "long_term_recommendations": [
    {{"recommendation": "建议描述", "rationale": "理由"}}
  ]
}}"""
        try:
            provider = LLMFactory.get_default()
            response = await provider.chat(
                [{"role": "user", "content": prompt}],
                temperature=0.3,
                max_tokens=2000,
            )
            result = json.loads(extract_json(response.content))
            result["usage"] = response.usage
            return result
        except (LLMError, json.JSONDecodeError, ValueError) as e:
            logger.warning("LLM strategy generation failed; using the default strategy: %s", e)
            return self._default_strategy(gaps, opportunities)

async def calculate_gap_score(
self,
db: AsyncSession,
brand_id: uuid.UUID,
brand_name: str,
) -> list[dict]:
stmt = (
select(CompetitorInsight)
.where(CompetitorInsight.brand_id == brand_id)
.order_by(CompetitorInsight.created_at.desc())
)
result = await db.execute(stmt)
insights = list(result.scalars().all())
competitor_map: dict[str, list[CompetitorInsight]] = defaultdict(list)
for insight in insights:
competitor_map[insight.competitor_name].append(insight)
summaries = []
for comp_name, comp_insights in competitor_map.items():
gap_dimensions = []
score_components = []
for insight in comp_insights:
gap = insight.gap_analysis or {}
if not gap:
continue
for g in gap.get("gaps", []):
dimension = g.get("dimension", "unknown")
severity = g.get("severity", "low")
gap_value = g.get("gap", 0)
severity_score = {"high": 30, "medium": 15, "low": 5}.get(severity, 5)
score_components.append(severity_score)
gap_dimensions.append({
"dimension": dimension,
"severity": severity,
"gap": gap_value,
"analysis_type": insight.analysis_type,
})
overall_score = min(sum(score_components), 100) if score_components else 0.0
summaries.append({
"brand_name": brand_name,
"competitor_name": comp_name,
"gap_dimensions": gap_dimensions,
"overall_gap_score": round(overall_score, 2),
})
return summaries
async def _build_insight(
self,
session: AsyncSession,
brand_id: uuid.UUID,
brand_name: str,
competitor: Competitor,
analysis_type: str,
brand_data: dict,
competitor_data: dict,
period_days: int,
) -> CompetitorInsight:
comparison = {}
comparison["citation_volume"] = await self.compare_citation_volume(
brand_data, competitor_data,
)
comparison["quality"] = await self.compare_citation_quality(
brand_data, competitor_data,
)
gap = self._identify_gaps(comparison, brand_name, competitor.name)
opportunities = await self.identify_opportunities(
brand_data, competitor_data, comparison,
)
data_sufficiency = self._assess_data_sufficiency(brand_data, competitor_data)
insight_data = {}
if analysis_type == "content_strategy":
insight_data = await self.analyze_content_strategy(
brand_data, competitor_data,
)
elif analysis_type == "platform_coverage":
insight_data = comparison["citation_volume"]["by_platform"]
elif analysis_type == "query_overlap":
insight_data = await self._analyze_query_overlap(
session, brand_id, brand_name, competitor.name, period_days,
)
elif analysis_type == "differentiation":
insight_data = {
"brand_unique_platforms": list(
set(brand_data["by_platform"].keys()) - set(competitor_data["by_platform"].keys())
),
"competitor_unique_platforms": list(
set(competitor_data["by_platform"].keys()) - set(brand_data["by_platform"].keys())
),
"sentiment_diff": comparison["quality"].get("sentiment", {}),
}
recommendations = await self.generate_recommendations(
brand_name=brand_name,
competitor_name=competitor.name,
comparison=comparison,
gaps=gap,
opportunities=opportunities,
data_sufficiency=data_sufficiency,
)
confidence = self._determine_confidence(brand_data, competitor_data)
return CompetitorInsight(
brand_id=brand_id,
competitor_name=competitor.name,
analysis_type=analysis_type,
insight_data=insight_data if insight_data else None,
citation_count_brand=brand_data["citation_count"],
citation_count_competitor=competitor_data["citation_count"],
sentiment_brand=brand_data.get("positive_ratio"),
sentiment_competitor=competitor_data.get("positive_ratio"),
platform_breakdown=comparison["citation_volume"]["by_platform"],
gap_analysis=gap,
opportunity_areas=opportunities,
recommendations=recommendations,
confidence=confidence,
period_days=period_days,
)
async def _get_competitors(
self,
db: AsyncSession,
brand_id: uuid.UUID,
) -> list[Competitor]:
stmt = select(Competitor).where(Competitor.brand_id == brand_id)
result = await db.execute(stmt)
return list(result.scalars().all())
async def _aggregate_citation_data(
self,
db: AsyncSession,
brand_id: uuid.UUID,
target_name: str,
period_days: int = 30,
) -> dict:
since = datetime.utcnow() - timedelta(days=period_days)
query_stmt = select(Query).where(Query.brand_id == brand_id)
query_result = await db.execute(query_stmt)
queries = list(query_result.scalars().all())
if not queries:
return {
"citation_count": 0,
"positive_ratio": 0.0,
"avg_rank": 0.0,
"by_platform": {},
"content_types": {},
"sentiment_breakdown": {"positive": 0, "neutral": 0, "negative": 0},
"total_records": 0,
}
query_ids = [q.id for q in queries]
query_aliases = set()
for q in queries:
query_aliases.add(q.target_brand.lower())
if q.brand_aliases:
for alias in q.brand_aliases:
query_aliases.add(alias.lower())
conditions = [CitationRecord.query_id.in_(query_ids)]
if since:
conditions.append(CitationRecord.queried_at >= since)
stmt = select(CitationRecord).where(and_(*conditions))
result = await db.execute(stmt)
records = list(result.scalars().all())
target_lower = target_name.lower()
matching_records = []
for record in records:
if record.cited and record.competitor_brands:
is_target = False
for cb in record.competitor_brands:
if isinstance(cb, str) and cb.lower() == target_lower:
is_target = True
break
elif isinstance(cb, str) and cb.lower() in query_aliases:
is_target = True
break
if is_target:
matching_records.append(record)
elif record.cited and not record.competitor_brands:
matching_records.append(record)
total_citations = len(matching_records)
if total_citations == 0:
return {
"citation_count": 0,
"positive_ratio": 0.0,
"avg_rank": 0.0,
"by_platform": {},
"content_types": {},
"sentiment_breakdown": {"positive": 0, "neutral": 0, "negative": 0},
"total_records": len(records),
}
sentiment_breakdown = {"positive": 0, "neutral": 0, "negative": 0}
for r in matching_records:
s = r.sentiment or "neutral"
if s in sentiment_breakdown:
sentiment_breakdown[s] += 1
else:
sentiment_breakdown["neutral"] += 1
positive_count = sentiment_breakdown["positive"]
positive_ratio = positive_count / total_citations if total_citations > 0 else 0.0
ranks = [
r.citation_position for r in matching_records
if r.citation_position is not None and r.citation_position > 0
]
avg_rank = sum(ranks) / len(ranks) if ranks else 0.0
by_platform = defaultdict(lambda: {"citations": 0, "positive": 0, "ranks": []})
for r in matching_records:
platform = r.platform
by_platform[platform]["citations"] += 1
if r.sentiment == "positive":
by_platform[platform]["positive"] += 1
if r.citation_position is not None and r.citation_position > 0:
by_platform[platform]["ranks"].append(r.citation_position)
platform_stats = {}
for platform, data in by_platform.items():
platform_stats[platform] = {
"citations": data["citations"],
"positive_ratio": data["positive"] / data["citations"] if data["citations"] > 0 else 0.0,
"avg_rank": sum(data["ranks"]) / len(data["ranks"]) if data["ranks"] else 0.0,
}
content_types = defaultdict(int)
for r in matching_records:
match_type = r.match_type or "unknown"
content_types[match_type] += 1
return {
"citation_count": total_citations,
"positive_ratio": round(positive_ratio, 4),
"avg_rank": round(avg_rank, 2),
"by_platform": platform_stats,
"content_types": dict(content_types),
"sentiment_breakdown": sentiment_breakdown,
"total_records": len(records),
}
def _compare_platform_citations(
self,
brand_data: dict,
competitor_data: dict,
) -> dict:
all_platforms = set(brand_data["by_platform"].keys()) | set(competitor_data["by_platform"].keys())
result = {}
for platform in all_platforms:
bp = brand_data["by_platform"].get(platform, {"citations": 0, "positive_ratio": 0.0, "avg_rank": 0.0})
cp = competitor_data["by_platform"].get(platform, {"citations": 0, "positive_ratio": 0.0, "avg_rank": 0.0})
result[platform] = {
"brand": bp,
"competitor": cp,
}
return result
def _identify_gaps(
self,
comparison: dict,
brand_name: str,
competitor_name: str,
) -> dict:
gaps = []
volume = comparison.get("citation_volume", {})
citation_diff = volume.get("diff", 0)
if citation_diff < 0:
gaps.append({
"dimension": "citation_count",
"brand_value": volume.get("brand", 0),
"competitor_value": volume.get("competitor", 0),
"gap": abs(citation_diff),
"severity": "high" if abs(citation_diff) >= 5 else "medium" if abs(citation_diff) >= 2 else "low",
})
quality = comparison.get("quality", {})
sentiment = quality.get("sentiment", {})
positive_diff = sentiment.get("diff", 0)
if positive_diff < -0.1:
gaps.append({
"dimension": "positive_ratio",
"brand_value": sentiment.get("brand_positive_ratio", 0),
"competitor_value": sentiment.get("competitor_positive_ratio", 0),
"gap": abs(positive_diff),
"severity": "high" if abs(positive_diff) >= 0.3 else "medium" if abs(positive_diff) >= 0.15 else "low",
})
ranking = quality.get("ranking", {})
rank_diff = ranking.get("diff", 0)
if rank_diff > 1.0:
gaps.append({
"dimension": "avg_rank",
"brand_value": ranking.get("brand_avg_rank", 0),
"competitor_value": ranking.get("competitor_avg_rank", 0),
"gap": abs(rank_diff),
"severity": "high" if abs(rank_diff) >= 3.0 else "medium" if abs(rank_diff) >= 2.0 else "low",
})
for platform, data in volume.get("by_platform", {}).items():
brand_citations = data.get("brand", {}).get("citations", 0)
competitor_citations = data.get("competitor", {}).get("citations", 0)
if competitor_citations > brand_citations + 2:
gaps.append({
"dimension": f"platform_{platform}",
"brand_value": brand_citations,
"competitor_value": competitor_citations,
"gap": competitor_citations - brand_citations,
"severity": "high" if (competitor_citations - brand_citations) >= 5 else "medium",
})
return {
"brand_name": brand_name,
"competitor_name": competitor_name,
"gaps": gaps,
"total_gaps": len(gaps),
"high_severity_count": sum(1 for g in gaps if g["severity"] == "high"),
}
async def _analyze_query_overlap(
self,
db: AsyncSession,
brand_id: uuid.UUID,
brand_name: str,
competitor_name: str,
period_days: int,
) -> dict:
since = datetime.utcnow() - timedelta(days=period_days)
stmt = select(Query).where(
Query.brand_id == brand_id,
Query.created_at >= since,
)
result = await db.execute(stmt)
queries = list(result.scalars().all())
brand_keywords = set()
competitor_keywords = set()
for q in queries:
keyword = q.keyword.lower()
brand_keywords.add(keyword)
if competitor_name.lower() in keyword or any(
a.lower() in keyword for a in (q.brand_aliases or [])
):
competitor_keywords.add(keyword)
overlap = brand_keywords & competitor_keywords
brand_only = brand_keywords - competitor_keywords
competitor_only = competitor_keywords - brand_keywords
return {
"brand_keyword_count": len(brand_keywords),
"competitor_keyword_count": len(competitor_keywords),
"overlap_count": len(overlap),
"overlap_keywords": list(overlap)[:20],
"brand_only_count": len(brand_only),
"competitor_only_count": len(competitor_only),
"overlap_ratio": round(len(overlap) / len(brand_keywords), 4) if brand_keywords else 0.0,
}
def _assess_data_sufficiency(
self,
brand_data: dict,
competitor_data: dict,
) -> str:
brand_count = brand_data["citation_count"]
competitor_count = competitor_data["citation_count"]
min_count = min(brand_count, competitor_count)
if min_count > 10:
return "sufficient"
elif min_count >= 5:
return "limited"
else:
return "insufficient"
def _determine_confidence(
self,
brand_data: dict,
competitor_data: dict,
) -> str:
brand_count = brand_data["citation_count"]
competitor_count = competitor_data["citation_count"]
min_count = min(brand_count, competitor_count)
if min_count > 20:
return "high"
elif min_count >= 5:
return "medium"
else:
return "low"
def _default_strategy(self, gaps: dict, opportunities: dict) -> dict:
gap_strategies = []
for gap in gaps.get("gaps", []):
gap_strategies.append({
"strategy": f"提升{gap['dimension']}维度表现,缩小与竞品差距",
"priority": gap["severity"],
"expected_impact": f"预计可将{gap['dimension']}差距缩小{gap['gap'] * 0.5:.1f}",
})
diff_strategies = []
for opp in opportunities.get("opportunities", []):
if opp["potential"] in ("high", "medium"):
diff_strategies.append({
"strategy": opp["action"],
"priority": opp["potential"],
"expected_impact": "建立差异化竞争优势",
})
return {
"gap_closing_strategies": gap_strategies[:5],
"differentiation_strategies": diff_strategies[:5],
"quick_wins": [],
"long_term_recommendations": [
{
"recommendation": "持续监测竞品引用数据变化,定期更新策略",
"rationale": "GEO优化是长期过程需要持续迭代",
}
],
}
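The gap score computed in `calculate_gap_score` above is a capped sum of per-gap severity weights. A minimal standalone sketch follows; the function name `overall_gap_score` and the sample gaps are assumptions for illustration, while the weights (30/15/5) and the cap of 100 are taken from the code above.

```python
SEVERITY_WEIGHTS = {"high": 30, "medium": 15, "low": 5}


def overall_gap_score(gaps: list[dict]) -> float:
    """Sum severity weights over all gaps, capped at 100; 0.0 when there are no gaps."""
    if not gaps:
        return 0.0
    total = sum(
        # Unknown or missing severities are weighted like "low".
        SEVERITY_WEIGHTS.get(gap.get("severity", "low"), 5)
        for gap in gaps
    )
    return float(min(total, 100))


# Hypothetical gap lists.
mixed = overall_gap_score([{"severity": "high"}, {"severity": "medium"}, {"severity": "low"}])
capped = overall_gap_score([{"severity": "high"}] * 4)
unknown = overall_gap_score([{"severity": "critical"}])
```

Because the score only counts gaps by severity, four high-severity gaps already saturate it at 100; the size of each gap (`gap` value) does not affect the total.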

View File

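@ -0,0 +1,493 @@
"""ContentGenerationService: content generation service.

Business logic layer extracted from api/content.py. Responsible for:
1. The three-stage content generation pipeline (generate -> de-AI -> GEO optimize)
2. Knowledge-base context retrieval
3. Persisting generation results
4. Agent framework integration (optional)

The API layer only parses requests and formats responses; all business logic is delegated to this service.
"""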
import asyncio
import logging
import uuid
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from app.agent_framework.prompts import (
CONTENT_GENERATOR_TEMPLATE,
DEAI_TEMPLATE,
GEO_OPTIMIZER_TEMPLATE,
)
from app.models.content import Content, ContentVersion
from app.services.llm import LLMFactory, LLMError
logger = logging.getLogger(__name__)


class ContentGenerationService:
    """Content generation service wrapping the three-stage pipeline and persistence logic."""

    def _get_provider(self):
        """Return the default LLM provider. Subclasses or tests may override this."""
        return LLMFactory.get_default()

    async def _get_knowledge_context(
        self,
        db: AsyncSession,
        brand_name: str,
        knowledge_base_ids: list[str],
        target_keyword: str,
    ) -> str:
        """
        Retrieve context relevant to the query from the knowledge base.

        If knowledge base IDs are given, call RAGService.search for related content;
        otherwise return an empty string so the rest of the pipeline is unaffected.
        """
        if not knowledge_base_ids:
            return ""
        try:
            from app.services.knowledge.rag_service import RAGService

            rag_service = RAGService()
            results = await rag_service.search(
                session=db,
                query=f"{brand_name} {target_keyword}" if brand_name else target_keyword,
                knowledge_base_ids=knowledge_base_ids,
                top_k=3,
            )
            if results:
                context_parts = []
                for r in results:
                    content = r.get("content", "")
                    title = r.get("document_title", "")
                    if content:
                        context_parts.append(f"[{title}] {content}")
                return "\n".join(context_parts)
            return ""
        except Exception as e:
            logger.warning(
                "Knowledge base retrieval failed; continuing without knowledge context: %s", e
            )
            return ""

    async def _poll_task_result(
        self,
        dispatcher,
        task_id: str,
        timeout: int = 300,
    ) -> dict:
        """
        Poll the Agent framework for a task result.

        Args:
            dispatcher: TaskDispatcher instance
            task_id: ID of the dispatched task
            timeout: Timeout in seconds

        Returns:
            dict: The task's output_data

        Raises:
            TimeoutError: The task timed out
            Exception: The task failed or was cancelled
        """
        from app.agent_framework.protocol import TaskStatus

        elapsed = 0.0
        poll_interval = 1.0
        while elapsed < timeout:
            await asyncio.sleep(poll_interval)
            elapsed += poll_interval
            task_status = await dispatcher.get_task_status(task_id)
            status = task_status.get("status")
            if status == TaskStatus.COMPLETED:
                return task_status.get("output_data", {})
            elif status == TaskStatus.FAILED:
                error_msg = task_status.get("error_message", "Unknown error")
                raise Exception(f"Agent 任务执行失败: {error_msg}")
            elif status == TaskStatus.CANCELLED:
                raise Exception(f"Agent 任务被取消: {task_id}")
        raise TimeoutError(f"Agent 任务超时 ({timeout}s): {task_id}")

async def _execute_via_agent_framework(
self,
keyword: str,
brand_name: str,
platform: str,
content_style: str,
word_count: int,
knowledge_context: str,
knowledge_base_ids: list[str] | None,
run_deai: bool,
run_geo: bool,
db: AsyncSession | None,
user_id: str | None,
org_id: str | None,
) -> dict:
"""
通过 Agent 框架执行三阶段内容生成流程
依次分发任务到 content_generatordeai_agentgeo_optimizer
并轮询等待每个阶段的结果失败时抛出异常由调用方决定是否回退
Returns:
dict: generate_content 返回格式一致的结果字典
Raises:
Exception: Agent 框架不可用或任务执行失败时
"""
from app.agent_framework.dispatcher import TaskDispatcher
from app.agent_framework.protocol import TaskMessage
from app.config import settings
dispatcher = TaskDispatcher(settings.REDIS_URL)
stages = []
try:
# ---- Stage 1: 内容生成 ----
logger.info(f"通过 Agent 框架执行内容生成: keyword={keyword}")
task_id = str(uuid.uuid4())
task_message = TaskMessage(
task_id=task_id,
agent_name="content_generator",
task_type="generate_article",
priority=0,
input_data={
"target_keyword": keyword,
"brand_name": brand_name,
"target_platform": platform,
"knowledge_base_ids": knowledge_base_ids or [],
"word_count": word_count,
"content_style": content_style,
"knowledge_context": knowledge_context,
},
callback_url=None,
created_at=datetime.now(timezone.utc),
timeout_seconds=300,
)
dispatched_id = await dispatcher.dispatch(
task_message,
organization_id=org_id,
created_by=user_id,
)
gen_result = await self._poll_task_result(
dispatcher, dispatched_id, timeout=300
)
content = gen_result.get("content", "")
stages.append(
{
"stage": "content_generation",
"status": "success",
"word_count": len(content),
}
)
            # ---- Stage 2: de-AI rewriting (optional) ----
            if run_deai:
                logger.info("Running de-AI rewriting via Agent framework")
                task_id = str(uuid.uuid4())
                task_message = TaskMessage(
                    task_id=task_id,
                    agent_name="deai_agent",
                    task_type="deai_process",
                    priority=0,
                    input_data={
                        "content": content,
                        "platform": platform,
                        "style": "自然流畅",
                        "preserve_structure": True,
                    },
                    callback_url=None,
                    created_at=datetime.now(timezone.utc),
                    timeout_seconds=180,
                )
                dispatched_id = await dispatcher.dispatch(
                    task_message,
                    organization_id=org_id,
                    created_by=user_id,
                )
                deai_result = await self._poll_task_result(
                    dispatcher, dispatched_id, timeout=180
                )
                # Keep the stage-1 content if the agent returns no "content" key
                content = deai_result.get("content", content)
                stages.append({"stage": "deai", "status": "success"})
            # ---- Stage 3: GEO optimization (optional) ----
            optimized = content
            seo_score = None
            if run_geo:
                logger.info("Running GEO optimization via Agent framework")
                task_id = str(uuid.uuid4())
                task_message = TaskMessage(
                    task_id=task_id,
                    agent_name="geo_optimizer",
                    task_type="geo_optimize",
                    priority=0,
                    input_data={
                        "content": content,
                        "target_keywords": [keyword],
                        "target_platform": platform,
                        "optimization_level": "moderate",
                    },
                    callback_url=None,
                    created_at=datetime.now(timezone.utc),
                    timeout_seconds=180,
                )
                dispatched_id = await dispatcher.dispatch(
                    task_message,
                    organization_id=org_id,
                    created_by=user_id,
                )
                geo_result = await self._poll_task_result(
                    dispatcher, dispatched_id, timeout=180
                )
                # Fall back to the unoptimized content if the agent omits it
                optimized = geo_result.get("optimized_content", content)
                seo_score = geo_result.get("seo_score")
                stages.append({"stage": "geo_optimization", "status": "success"})
            # ---- Persistence (optional) ----
            content_id = None
            if db and user_id and org_id:
                content_obj = Content(
                    organization_id=org_id,
                    title=keyword,
                    content_type="article",
                    body=optimized,
                    status="draft",
                    target_platforms=[platform] if platform else [],
                    keywords=[keyword],
                    extra_metadata={
                        # Store the pre-optimization text only when GEO changed it
                        "original_content": content if content != optimized else None,
                        "pipeline_stages": stages,
                        "seo_score": seo_score,
                        "brand_name": brand_name,
                        "content_style": content_style,
                        "word_count_target": word_count,
                        "execution_mode": "agent_framework",
                    },
                    created_by=user_id,
                    current_version=1,
                )
                db.add(content_obj)
                # Flush so content_obj.id is available for the version row
                await db.flush()
                version = ContentVersion(
                    content_id=content_obj.id,
                    version_number=1,
                    title=keyword,
                    body=optimized,
                    change_summary="Agent框架Pipeline自动生成",
                    created_by=user_id,
                )
                db.add(version)
                await db.commit()
                await db.refresh(content_obj)
                content_id = str(content_obj.id)
            logger.info("Content generation via Agent framework completed")
            return {
                "content": content,
                "optimized_content": optimized,
                "seo_score": seo_score,
                "content_id": content_id,
                "pipeline_stages": stages,
            }
        finally:
            await dispatcher.close()
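
    async def generate_content(
        self,
        keyword: str,
        brand_name: str = "",
        platform: str = "通用",
        content_style: str = "专业严谨",
        word_count: int = 2000,
        knowledge_context: str = "",
        knowledge_base_ids: list[str] | None = None,
        db: AsyncSession | None = None,
        user_id: str | None = None,
        org_id: str | None = None,
        run_deai: bool = True,
        run_geo: bool = True,
        use_agent_framework: bool = False,
    ) -> dict:
        """
        Run the three-stage content generation pipeline.

        Stages:
        1. Content generation (CONTENT_GENERATOR_TEMPLATE)
        2. De-AI rewriting (DEAI_TEMPLATE, optional)
        3. GEO optimization (GEO_OPTIMIZER_TEMPLATE, optional)

        If db, user_id, and org_id are all provided, the result is persisted
        to the database.

        Args:
            keyword: Target keyword.
            brand_name: Brand name.
            platform: Target platform; defaults to "通用" (general).
            content_style: Content style; defaults to "专业严谨" (professional and rigorous).
            word_count: Target word count; defaults to 2000.
            knowledge_context: Knowledge-base context passed in directly; takes
                precedence when provided.
            knowledge_base_ids: Knowledge-base IDs used for RAG retrieval.
            db: Database session (optional); the result is persisted when provided.
            user_id: User ID (optional); required for persistence.
            org_id: Organization ID (optional); required for persistence.
            run_deai: Whether to run de-AI rewriting; defaults to True.
            run_geo: Whether to run GEO optimization; defaults to True.
            use_agent_framework: Whether to run through the Agent framework;
                defaults to False. When True, tasks are dispatched to Agents
                via TaskDispatcher. If the Agent framework is unavailable, the
                call automatically falls back to direct invocation.

        Returns:
            dict: {
                "content": str,  # content after de-AI rewriting (or the original generated content)
                "optimized_content": str,  # content after GEO optimization (or identical to content)
                "seo_score": int | None,
                "content_id": str | None,  # database record ID
                "pipeline_stages": list[dict],
            }

        Raises:
            LLMError: If an LLM call fails.
        """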
        # ---- Agent framework path ----
        if use_agent_framework:
            try:
                logger.info("Attempting content generation via Agent framework")
                return await self._execute_via_agent_framework(
                    keyword=keyword,
                    brand_name=brand_name,
                    platform=platform,
                    content_style=content_style,
                    word_count=word_count,
                    knowledge_context=knowledge_context,
                    knowledge_base_ids=knowledge_base_ids,
                    run_deai=run_deai,
                    run_geo=run_geo,
                    db=db,
                    user_id=user_id,
                    org_id=org_id,
                )
            except Exception as e:
                logger.warning(
                    f"Agent framework execution failed; falling back to direct invocation: {e}"
                )
                # Fall through to the direct-invocation logic below
        # ---- Direct-invocation path (original logic) ----
        provider = self._get_provider()
        stages = []
        # If no knowledge context was passed in but knowledge-base IDs and a
        # db session are available, retrieve the context
        if not knowledge_context and knowledge_base_ids and db:
            knowledge_context = await self._get_knowledge_context(
                db, brand_name, knowledge_base_ids, keyword
            )
        # ---- Stage 1: content generation ----
        gen_variables = {
            "topic_title": keyword,
            "target_keyword": keyword,
            "target_platform": platform,
            "content_angle": "综合分析",
            "content_style": content_style,
            "word_count": str(word_count),
            "brand_name": brand_name,
            "knowledge_context": knowledge_context,
        }
        messages = CONTENT_GENERATOR_TEMPLATE.render(gen_variables)
        response = await provider.chat(
            messages, temperature=0.7, max_tokens=word_count * 2
        )
        content = response.content
        stages.append(
            {"stage": "content_generation", "status": "success", "word_count": len(content)}
        )
        # ---- Stage 2: de-AI rewriting (optional) ----
        if run_deai:
            deai_variables = {
                "original_content": content,
                "target_style": "自然流畅",
                "preserve_structure": "",
            }
            messages = DEAI_TEMPLATE.render(deai_variables)
            response = await provider.chat(
                messages, temperature=0.9, max_tokens=len(content) * 2
            )
            content = response.content
            stages.append({"stage": "deai", "status": "success"})
        # ---- Stage 3: GEO optimization (optional) ----
        optimized = content
        # seo_score is never assigned on this path, so it is returned as None
        seo_score = None
        if run_geo:
            geo_variables = {
                "original_content": content,
                "target_keywords": keyword,
                "target_platform": platform,
                "optimization_level": "moderate",
            }
            messages = GEO_OPTIMIZER_TEMPLATE.render(geo_variables)
            response = await provider.chat(
                messages, temperature=0.5, max_tokens=len(content) * 2
            )
            optimized = response.content
            stages.append({"stage": "geo_optimization", "status": "success"})
        # ---- Persistence (optional) ----
        content_id = None
        if db and user_id and org_id:
            content_obj = Content(
                organization_id=org_id,
                title=keyword,
                content_type="article",
                body=optimized,
                status="draft",
                target_platforms=[platform] if platform else [],
                keywords=[keyword],
                extra_metadata={
                    # Store the pre-optimization text only when GEO changed it
                    "original_content": content if content != optimized else None,
                    "pipeline_stages": stages,
                    "seo_score": seo_score,
                    "brand_name": brand_name,
                    "content_style": content_style,
                    "word_count_target": word_count,
                },
                created_by=user_id,
                current_version=1,
            )
            db.add(content_obj)
            # Flush so content_obj.id is available for the version row
            await db.flush()
            version = ContentVersion(
                content_id=content_obj.id,
                version_number=1,
                title=keyword,
                body=optimized,
                change_summary="Pipeline自动生成",
                created_by=user_id,
            )
            db.add(version)
            await db.commit()
            await db.refresh(content_obj)
            content_id = str(content_obj.id)
        return {
            "content": content,
            "optimized_content": optimized,
            "seo_score": seo_score,
            "content_id": content_id,
            "pipeline_stages": stages,
        }

View File

@ -0,0 +1,9 @@
from .detection_scheduler import (
    DetectionSchedulerService,
    TaskNotFoundError,
)

__all__ = [
    "DetectionSchedulerService",
    "TaskNotFoundError",
]

Some files were not shown because too many files have changed in this diff