Merge branch 'feat/geo-monetization-closed-loop' into main
This commit is contained in:
commit
394ddfbc61
|
|
@ -0,0 +1,38 @@
|
|||
---
|
||||
description: CodeGraph MCP usage guide — when to use which tool
|
||||
alwaysApply: true
|
||||
---
|
||||
<!-- CODEGRAPH_START -->
|
||||
## CodeGraph
|
||||
|
||||
This project has a CodeGraph MCP server (`codegraph_*` tools) configured. CodeGraph is a tree-sitter-parsed knowledge graph of every symbol, edge, and file. Reads are sub-millisecond and return structural information grep cannot.
|
||||
|
||||
### When to prefer codegraph over native search
|
||||
|
||||
Use codegraph for **structural** questions — what calls what, what would break, where is X defined, what is X's signature. Use native grep/read only for **literal text** queries (string contents, comments, log messages) or after you already have a specific file open.
|
||||
|
||||
| Question | Tool |
|
||||
|---|---|
|
||||
| "Where is X defined?" / "Find symbol named X" | `codegraph_search` |
|
||||
| "What calls function Y?" | `codegraph_callers` |
|
||||
| "What does Y call?" | `codegraph_callees` |
|
||||
| "What would break if I changed Z?" | `codegraph_impact` |
|
||||
| "Show me Y's signature / source / docstring" | `codegraph_node` |
|
||||
| "Give me focused context for a task/area" | `codegraph_context` |
|
||||
| "See several related symbols' source at once" | `codegraph_explore` |
|
||||
| "What files exist under path/" | `codegraph_files` |
|
||||
| "Is the index healthy?" | `codegraph_status` |
|
||||
|
||||
### Rules of thumb
|
||||
|
||||
- **Answer directly — don't delegate exploration.** For "how does X work" / architecture / trace questions, answer with 2-3 codegraph calls: `codegraph_context` first, then ONE `codegraph_explore` for the source of the symbols it surfaces. Codegraph IS the pre-built index, so spawning a separate file-reading sub-task/agent — or running a grep + read loop — repeats work codegraph already did and costs more for the same answer.
|
||||
- **Trust codegraph results.** They come from a full AST parse. Do NOT re-verify them with grep — that's slower, less accurate, and wastes context.
|
||||
- **Don't grep first** when looking up a symbol by name. `codegraph_search` is faster and returns kind + location + signature in one call.
|
||||
- **Don't chain `codegraph_search` + `codegraph_node`** when you just want context — `codegraph_context` is one call.
|
||||
- **Don't loop `codegraph_node` over many symbols** — one `codegraph_explore` call returns several symbols' source grouped in a single capped call, while each separate node/Read call re-reads the whole context and costs far more.
|
||||
- **Index lag**: the file watcher debounces ~500ms behind writes; don't re-query immediately after editing a file in the same turn.
|
||||
|
||||
### If `.codegraph/` doesn't exist
|
||||
|
||||
The MCP server returns "not initialized." Ask the user: *"I notice this project doesn't have CodeGraph initialized. Want me to run `codegraph init -i` to build the index?"*
|
||||
<!-- CODEGRAPH_END -->
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
# GEO Platform Agents
|
||||
|
||||
## Agent Registry
|
||||
|
||||
| Agent | Name | Type | Supported Tasks | File |
|
||||
|-------|------|------|----------------|------|
|
||||
| CitationDetectorAgent | citation_detector | CITATION_DETECTOR | citation_detect, citation_batch | backend/app/agent_framework/agents/citation_detector.py |
|
||||
| ContentGeneratorAgent | content_generator | CONTENT_GENERATOR | content_generate, content_regenerate | backend/app/agent_framework/agents/content_generator_agent.py |
|
||||
| DeAIAgent | deai_agent | DEAI_AGENT | deai_process | backend/app/agent_framework/agents/deai_agent.py |
|
||||
| GEOOptimizerAgent | geo_optimizer | GEO_OPTIMIZER | geo_optimize | backend/app/agent_framework/agents/geo_optimizer.py |
|
||||
| MonitorAgent | monitor | PERFORMANCE_TRACKER | monitor_track, monitor_check_single | backend/app/agent_framework/agents/monitor_agent.py |
|
||||
| SchemaAdvisorAgent | schema_advisor | SCHEMA_ADVISOR | schema_advise | backend/app/agent_framework/agents/schema_advisor.py |
|
||||
| CompetitorAnalyzerAgent | competitor_analyzer | COMPETITOR_ANALYZER | competitor_analyze, competitor_gap_analysis | backend/app/agent_framework/agents/competitor_analyzer.py |
|
||||
| TrendAgent | trend_agent | TREND_AGENT | trend_insight, trend_hotspot | backend/app/agent_framework/agents/trend_agent.py |
|
||||
|
||||
## Running Agents
|
||||
|
||||
### Standalone Mode (No Redis Required)
|
||||
```bash
|
||||
cd geo/backend
|
||||
python3 -m app.agent_framework.standalone [agent_name|all]
|
||||
```
|
||||
|
||||
### With Redis Queue
|
||||
Agents auto-register via Registry when Redis is available. Tasks dispatched via TaskDispatcher.
|
||||
|
||||
## Agent Framework Components
|
||||
- BaseAgent: Abstract base class (backend/app/agent_framework/base_agent.py)
|
||||
- Dispatcher: Task distribution (backend/app/agent_framework/dispatcher.py)
|
||||
- Registry: Agent registration (backend/app/agent_framework/registry.py)
|
||||
- Protocol: Message types and AgentType enum (backend/app/agent_framework/protocol.py)
|
||||
- ConfigManager: Agent configuration (backend/app/agent_framework/config_manager.py)
|
||||
|
||||
## New Agent Data Models
|
||||
- MonitoringRecord + ContentBaseline (backend/app/models/monitoring.py)
|
||||
- SchemaSuggestion (backend/app/models/schema_suggestion.py)
|
||||
- CompetitorInsight (backend/app/models/competitor_insight.py)
|
||||
- TrendInsight (backend/app/models/trend_insight.py)
|
||||
|
||||
## New Agent API Endpoints
|
||||
- /api/v1/monitoring - MonitorAgent API
|
||||
- /api/v1/competitor - CompetitorAnalyzer API
|
||||
- /api/v1/schema - SchemaAdvisor API
|
||||
- /api/v1/trends - TrendAgent API
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
# GEO Platform - AI Context
|
||||
|
||||
## Project Overview
|
||||
GEO (Generative Engine Optimization) SaaS platform that helps brands get cited by AI search engines (ChatGPT, Perplexity, DeepSeek, etc.).
|
||||
|
||||
## Tech Stack
|
||||
- Backend: FastAPI + SQLAlchemy 2.0 (async) + PostgreSQL 15 + Redis 7
|
||||
- Frontend: Next.js 14 + React 18 + TypeScript + Tailwind CSS + shadcn/ui
|
||||
- AI: Multi-LLM provider (DeepSeek, OpenAI, etc.) + RAG knowledge base
|
||||
- Infrastructure: Docker Compose + Prometheus metrics
|
||||
|
||||
## Architecture
|
||||
- Backend API: 33 route modules under backend/app/api/
|
||||
- Agent Framework: 8 agents (CitationDetector, ContentGenerator, DeAIAgent, GEOOptimizer, MonitorAgent, SchemaAdvisor, CompetitorAnalyzer, TrendAgent)
|
||||
- Data Models: 35 models under backend/app/models/
|
||||
- Services: 80+ service files organized by domain under backend/app/services/
|
||||
- Repositories: 12 repository files under backend/app/repositories/ for data access
|
||||
- Frontend: 25+ pages under frontend/app/(dashboard)/
|
||||
- Frontend API Clients: 27 modules under frontend/lib/api/
|
||||
|
||||
## Key Patterns
|
||||
- Repository pattern for data access (backend/app/repositories/)
|
||||
- Agent Framework: BaseAgent abstract class → concrete agents, Dispatcher + Registry + Redis Queue
|
||||
- Shared BrandScoringDataService for V2 5-dimension scoring
|
||||
- Content Pipeline: Generate → DeAI → GEO Optimize → HTML Output
|
||||
- GEO Workflow: Diagnosis → Strategy → Plan → Content Generation → Monitoring
|
||||
|
||||
## API Conventions
|
||||
- All routes under /api/v1/ prefix
|
||||
- Authentication via JWT Bearer token (get_current_user dependency)
|
||||
- Pydantic schemas for request/response validation
|
||||
- Async SQLAlchemy ORM with AsyncSession
|
||||
|
||||
## Frontend Conventions
|
||||
- API clients in frontend/lib/api/ using fetchWithAuth from client.ts
|
||||
- Barrel export in frontend/lib/api/index.ts
|
||||
- Types co-located with API modules
|
||||
- Dashboard pages under frontend/app/(dashboard)/dashboard/
|
||||
|
||||
## Environment
|
||||
- Backend runs on port 8000
|
||||
- Frontend runs on port 3001 (dev)
|
||||
- PostgreSQL on port 5432
|
||||
- Redis on port 6379
|
||||
- .env.example has all required variables
|
||||
|
||||
## Common Commands
|
||||
- Start backend: cd geo/backend && python3 -m uvicorn app.main:app --reload --port 8000
|
||||
- Start frontend: cd geo/frontend && npm run dev
|
||||
- Start all agents: cd geo/backend && python3 -m app.agent_framework.standalone all
|
||||
- Docker: docker-compose up -d
|
||||
|
||||
## Important Notes
|
||||
- SEOOptimizer in code actually performs GEO optimization, not traditional SEO
|
||||
- content.py (content generation) vs contents.py (content management) - different modules
|
||||
- Analytics endpoints all require authentication via _get_org_id or get_current_user
|
||||
- CORS: dev mode uses allow_origins=["*"], production must restrict
|
||||
- Brand scoring uses V2 5-dimension system: mention_rate, recommendation_rank, sentiment_score, citation_quality, competitive_position
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
"""add monetization tables: diagnosis_records, attribution_records, payment_orders
|
||||
|
||||
Revision ID: f063b3da67b6
|
||||
Revises: g1h2i3j4kl56
|
||||
Create Date: 2026-06-01 07:40:07.419407
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "f063b3da67b6"
|
||||
down_revision: Union[str, Sequence[str], None] = "g1h2i3j4kl56"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"diagnosis_records",
|
||||
sa.Column("id", sa.Uuid(), nullable=False),
|
||||
sa.Column("brand_id", sa.Uuid(), nullable=False),
|
||||
sa.Column("user_id", sa.Uuid(), nullable=False),
|
||||
sa.Column("diagnosis_type", sa.String(length=20), nullable=False),
|
||||
sa.Column("status", sa.String(length=20), nullable=False),
|
||||
sa.Column("overall_score", sa.Float(), nullable=True),
|
||||
sa.Column("result_json", sa.JSON(), nullable=True),
|
||||
sa.Column("error_message", sa.Text(), nullable=True),
|
||||
sa.Column("collection_metadata", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("completed_at", sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["brand_id"], ["brands.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("idx_diagnosis_records_brand_id", "diagnosis_records", ["brand_id"])
|
||||
op.create_index("idx_diagnosis_records_user_id", "diagnosis_records", ["user_id"])
|
||||
op.create_index("idx_diagnosis_records_status", "diagnosis_records", ["status"])
|
||||
op.create_index("idx_diagnosis_records_created_at", "diagnosis_records", ["created_at"])
|
||||
|
||||
op.create_table(
|
||||
"attribution_records",
|
||||
sa.Column("id", sa.Uuid(), nullable=False),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
sa.Column("brand_id", sa.Uuid(), nullable=False),
|
||||
sa.Column("content_id", sa.Uuid(), nullable=True),
|
||||
sa.Column("baseline_score", sa.Float(), nullable=False),
|
||||
sa.Column("current_score", sa.Float(), nullable=True),
|
||||
sa.Column("score_delta", sa.Float(), nullable=True),
|
||||
sa.Column("attribution_window_days", sa.Integer(), server_default="28", nullable=False),
|
||||
sa.Column("published_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("window_end_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("status", sa.String(length=20), server_default="tracking", nullable=False),
|
||||
sa.Column("attributed_dimensions", sa.JSON(), nullable=True),
|
||||
sa.Column("roi_percentage", sa.Float(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False),
|
||||
sa.ForeignKeyConstraint(["brand_id"], ["brands.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["content_id"], ["contents.id"], ondelete="SET NULL"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("idx_attribution_records_brand_id", "attribution_records", ["brand_id"])
|
||||
op.create_index("idx_attribution_records_user_id", "attribution_records", ["user_id"])
|
||||
op.create_index("idx_attribution_records_status", "attribution_records", ["status"])
|
||||
op.create_index("idx_attribution_records_content_id", "attribution_records", ["content_id"])
|
||||
|
||||
op.drop_table("payment_orders")
|
||||
op.create_table(
|
||||
"payment_orders",
|
||||
sa.Column("id", sa.Uuid(), nullable=False),
|
||||
sa.Column("user_id", sa.String(length=36), nullable=False),
|
||||
sa.Column("plan", sa.String(length=20), nullable=False),
|
||||
sa.Column("amount", sa.Float(), nullable=False),
|
||||
sa.Column("currency", sa.String(length=10), nullable=False),
|
||||
sa.Column("payment_provider", sa.String(length=20), nullable=False),
|
||||
sa.Column("payment_id", sa.String(length=255), nullable=True),
|
||||
sa.Column("status", sa.String(length=20), nullable=False),
|
||||
sa.Column("pay_url", sa.String(length=1024), nullable=True),
|
||||
sa.Column("callback_data", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("now()"), nullable=False),
|
||||
sa.Column("paid_at", sa.DateTime(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("payment_orders")
|
||||
op.create_table(
|
||||
"payment_orders",
|
||||
sa.Column("id", sa.TEXT(), nullable=False),
|
||||
sa.Column("orderNo", sa.TEXT(), nullable=False),
|
||||
sa.Column("userId", sa.TEXT(), nullable=False),
|
||||
sa.Column("channelId", sa.TEXT(), nullable=True),
|
||||
sa.Column("subject", sa.TEXT(), nullable=False),
|
||||
sa.Column("body", sa.TEXT(), nullable=True),
|
||||
sa.Column("amount", sa.NUMERIC(precision=12, scale=2), nullable=False),
|
||||
sa.Column("currency", sa.TEXT(), server_default="'CNY'", nullable=False),
|
||||
sa.Column("status", sa.TEXT(), server_default="'pending'", nullable=False),
|
||||
sa.Column("clientIp", sa.TEXT(), nullable=True),
|
||||
sa.Column("userAgent", sa.TEXT(), nullable=True),
|
||||
sa.Column("metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("channelOrderNo", sa.TEXT(), nullable=True),
|
||||
sa.Column("createdAt", postgresql.TIMESTAMP(precision=3), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updatedAt", postgresql.TIMESTAMP(precision=3), nullable=False),
|
||||
sa.Column("paidAt", postgresql.TIMESTAMP(precision=3), nullable=True),
|
||||
sa.Column("cancelledAt", postgresql.TIMESTAMP(precision=3), nullable=True),
|
||||
sa.ForeignKeyConstraint(["channelId"], ["payment_channels.id"], onupdate="CASCADE", ondelete="SET NULL"),
|
||||
sa.ForeignKeyConstraint(["userId"], ["users.id"], onupdate="CASCADE", ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(op.f("payment_orders_userId_idx"), "payment_orders", ["userId"])
|
||||
op.create_index(op.f("payment_orders_status_idx"), "payment_orders", ["status"])
|
||||
op.create_index(op.f("payment_orders_orderNo_key"), "payment_orders", ["orderNo"], unique=True)
|
||||
op.create_index(op.f("payment_orders_orderNo_idx"), "payment_orders", ["orderNo"])
|
||||
op.create_index(op.f("payment_orders_createdAt_idx"), "payment_orders", ["createdAt"])
|
||||
|
||||
op.drop_index("idx_attribution_records_content_id", table_name="attribution_records")
|
||||
op.drop_index("idx_attribution_records_status", table_name="attribution_records")
|
||||
op.drop_index("idx_attribution_records_user_id", table_name="attribution_records")
|
||||
op.drop_index("idx_attribution_records_brand_id", table_name="attribution_records")
|
||||
op.drop_table("attribution_records")
|
||||
|
||||
op.drop_index("idx_diagnosis_records_created_at", table_name="diagnosis_records")
|
||||
op.drop_index("idx_diagnosis_records_status", table_name="diagnosis_records")
|
||||
op.drop_index("idx_diagnosis_records_user_id", table_name="diagnosis_records")
|
||||
op.drop_index("idx_diagnosis_records_brand_id", table_name="diagnosis_records")
|
||||
op.drop_table("diagnosis_records")
|
||||
|
|
@ -4,10 +4,12 @@ from .citation_detector import CitationDetectorAgent
|
|||
from .content_generator_agent import ContentGeneratorAgent
|
||||
from .deai_agent import DeAIAgent
|
||||
from .geo_optimizer_agent import GEOOptimizerAgent
|
||||
from .trend_agent import TrendAgent
|
||||
|
||||
__all__ = [
|
||||
"CitationDetectorAgent",
|
||||
"ContentGeneratorAgent",
|
||||
"DeAIAgent",
|
||||
"GEOOptimizerAgent",
|
||||
"TrendAgent",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
"""CitationDetector Agent - 将现有 CitationEngine 封装为 Agent"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agent_framework.base import BaseAgent
|
||||
from app.agent_framework.protocol import (
|
||||
|
|
@ -16,19 +16,13 @@ from app.agent_framework.protocol import (
|
|||
from app.database import AsyncSessionLocal
|
||||
from app.models.citation_record import CitationRecord
|
||||
from app.models.query import Query
|
||||
from app.workers.citation_engine import CitationEngine
|
||||
from app.models.query_task import QueryTask
|
||||
from app.services.ai_engine.platform_bridge import execute_single_platform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CitationDetectorAgent(BaseAgent):
|
||||
"""
|
||||
引用检测 Agent:将现有 CitationEngine 封装为 BaseAgent 实现。
|
||||
|
||||
支持的任务类型:
|
||||
- citation_detect: 执行完整的引用检测(遍历 query 的所有平台)
|
||||
- citation_detect_single: 执行单个平台的引用检测
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
|
@ -36,7 +30,6 @@ class CitationDetectorAgent(BaseAgent):
|
|||
agent_type=AgentType.CITATION_DETECTOR,
|
||||
version="1.0.0",
|
||||
)
|
||||
self._engine = CitationEngine()
|
||||
|
||||
def get_capabilities(self) -> AgentCapability:
|
||||
return AgentCapability(
|
||||
|
|
@ -49,7 +42,6 @@ class CitationDetectorAgent(BaseAgent):
|
|||
)
|
||||
|
||||
async def execute(self, task: TaskMessage) -> TaskResult:
|
||||
"""执行引用检测任务"""
|
||||
started_at = datetime.now(timezone.utc)
|
||||
start_time = time.monotonic()
|
||||
|
||||
|
|
@ -94,17 +86,11 @@ class CitationDetectorAgent(BaseAgent):
|
|||
)
|
||||
|
||||
async def _execute_full_detect(self, task: TaskMessage) -> dict:
|
||||
"""
|
||||
执行完整的引用检测(遍历 query 的所有平台)。
|
||||
input_data 需包含: query_id (str)
|
||||
"""
|
||||
query_id = task.input_data.get("query_id")
|
||||
if not query_id:
|
||||
raise ValueError("input_data must contain 'query_id'")
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
from sqlalchemy import select
|
||||
|
||||
stmt = select(Query).where(Query.id == query_id)
|
||||
result = await db.execute(stmt)
|
||||
query = result.scalar_one_or_none()
|
||||
|
|
@ -112,23 +98,75 @@ class CitationDetectorAgent(BaseAgent):
|
|||
if not query:
|
||||
raise ValueError(f"Query {query_id} not found")
|
||||
|
||||
# 上报进度:开始检测
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.1,
|
||||
message=f"Starting citation detection for query '{query.keyword}'",
|
||||
)
|
||||
|
||||
records = await self._engine.execute_query(query, db)
|
||||
records: list[CitationRecord] = []
|
||||
platforms = query.platforms or ["wenxin", "kimi"]
|
||||
brand_aliases = query.brand_aliases or []
|
||||
|
||||
for i, platform_name in enumerate(platforms):
|
||||
progress = 0.1 + 0.8 * (i / len(platforms))
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=progress,
|
||||
message=f"Detecting on platform '{platform_name}' ({i+1}/{len(platforms)})",
|
||||
)
|
||||
|
||||
task_obj = await self._get_or_create_task(db, query.id, platform_name)
|
||||
task_obj.status = "running"
|
||||
task_obj.started_at = datetime.utcnow()
|
||||
task_obj.error_message = None
|
||||
await db.commit()
|
||||
|
||||
try:
|
||||
detect_result = await execute_single_platform(
|
||||
keyword=query.keyword,
|
||||
platform=platform_name,
|
||||
target_brand=query.target_brand,
|
||||
brand_aliases=brand_aliases,
|
||||
)
|
||||
|
||||
record = CitationRecord.from_citation_result(
|
||||
query_id=query.id,
|
||||
platform=platform_name,
|
||||
result=detect_result,
|
||||
)
|
||||
db.add(record)
|
||||
records.append(record)
|
||||
|
||||
task_obj.status = "success"
|
||||
task_obj.completed_at = datetime.utcnow()
|
||||
await db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"平台 {platform_name} 查询失败: {e}")
|
||||
task_obj.status = "failed"
|
||||
task_obj.error_message = str(e)
|
||||
task_obj.completed_at = datetime.utcnow()
|
||||
|
||||
record = CitationRecord.from_citation_result(
|
||||
query_id=query.id,
|
||||
platform=platform_name,
|
||||
result={"cited": False, "raw_response": str(e)},
|
||||
)
|
||||
db.add(record)
|
||||
records.append(record)
|
||||
await db.commit()
|
||||
|
||||
query.last_queried_at = datetime.utcnow()
|
||||
query.next_query_at = self._calculate_next_query_at(query.frequency)
|
||||
await db.commit()
|
||||
|
||||
# 上报进度:检测完成
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=1.0,
|
||||
message=f"Detection completed: {len(records)} records found",
|
||||
)
|
||||
|
||||
# 构建输出
|
||||
record_summaries = []
|
||||
for r in records:
|
||||
record_summaries.append({
|
||||
|
|
@ -148,10 +186,6 @@ class CitationDetectorAgent(BaseAgent):
|
|||
}
|
||||
|
||||
async def _execute_single_detect(self, task: TaskMessage) -> dict:
|
||||
"""
|
||||
执行单个平台的引用检测。
|
||||
input_data 需包含: keyword, platform, target_brand, brand_aliases(optional)
|
||||
"""
|
||||
keyword = task.input_data.get("keyword")
|
||||
platform = task.input_data.get("platform")
|
||||
target_brand = task.input_data.get("target_brand")
|
||||
|
|
@ -162,56 +196,118 @@ class CitationDetectorAgent(BaseAgent):
|
|||
"input_data must contain 'keyword', 'platform', 'target_brand'"
|
||||
)
|
||||
|
||||
# 上报进度
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.2,
|
||||
message=f"Querying platform '{platform}' with keyword '{keyword}'",
|
||||
)
|
||||
|
||||
result = await self._engine.execute_single_platform(
|
||||
result = await execute_single_platform(
|
||||
keyword=keyword,
|
||||
platform=platform,
|
||||
target_brand=target_brand,
|
||||
brand_aliases=brand_aliases,
|
||||
)
|
||||
|
||||
# 上报进度:完成
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=1.0,
|
||||
message=f"Single platform detection completed on '{platform}'",
|
||||
)
|
||||
|
||||
# 清理 raw_response 以避免返回过大的数据
|
||||
output = {k: v for k, v in result.items() if k != "raw_response"}
|
||||
return output
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# 向后兼容:保留原有 CitationEngine 的同步调用接口
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
async def execute_query_compat(self, query: Query, db) -> list[CitationRecord]:
|
||||
"""
|
||||
向后兼容方法:供现有 scheduler 继续调用。
|
||||
签名与 CitationEngine.execute_query 完全一致。
|
||||
"""
|
||||
return await self._engine.execute_query(query, db)
|
||||
records: list[CitationRecord] = []
|
||||
platforms = query.platforms or ["wenxin", "kimi"]
|
||||
brand_aliases = query.brand_aliases or []
|
||||
|
||||
for platform_name in platforms:
|
||||
task_obj = await self._get_or_create_task(db, query.id, platform_name)
|
||||
task_obj.status = "running"
|
||||
task_obj.started_at = datetime.utcnow()
|
||||
task_obj.error_message = None
|
||||
await db.commit()
|
||||
|
||||
try:
|
||||
detect_result = await execute_single_platform(
|
||||
keyword=query.keyword,
|
||||
platform=platform_name,
|
||||
target_brand=query.target_brand,
|
||||
brand_aliases=brand_aliases,
|
||||
)
|
||||
|
||||
record = CitationRecord.from_citation_result(
|
||||
query_id=query.id,
|
||||
platform=platform_name,
|
||||
result=detect_result,
|
||||
)
|
||||
db.add(record)
|
||||
records.append(record)
|
||||
|
||||
task_obj.status = "success"
|
||||
task_obj.completed_at = datetime.utcnow()
|
||||
await db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"平台 {platform_name} 查询失败: {e}")
|
||||
task_obj.status = "failed"
|
||||
task_obj.error_message = str(e)
|
||||
task_obj.completed_at = datetime.utcnow()
|
||||
|
||||
record = CitationRecord.from_citation_result(
|
||||
query_id=query.id,
|
||||
platform=platform_name,
|
||||
result={"cited": False, "raw_response": str(e)},
|
||||
)
|
||||
db.add(record)
|
||||
records.append(record)
|
||||
await db.commit()
|
||||
|
||||
query.last_queried_at = datetime.utcnow()
|
||||
query.next_query_at = self._calculate_next_query_at(query.frequency)
|
||||
await db.commit()
|
||||
|
||||
return records
|
||||
|
||||
async def execute_single_platform_compat(
|
||||
self, keyword: str, platform: str, target_brand: str, brand_aliases: list
|
||||
) -> dict:
|
||||
"""
|
||||
向后兼容方法:供现有 scheduler 继续调用。
|
||||
签名与 CitationEngine.execute_single_platform 完全一致。
|
||||
"""
|
||||
return await self._engine.execute_single_platform(
|
||||
return await execute_single_platform(
|
||||
keyword=keyword,
|
||||
platform=platform,
|
||||
target_brand=target_brand,
|
||||
brand_aliases=brand_aliases,
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""关闭底层引擎"""
|
||||
await self._engine.close()
|
||||
async def _get_or_create_task(self, db, query_id: uuid.UUID, platform: str) -> QueryTask:
|
||||
stmt = select(QueryTask).where(
|
||||
QueryTask.query_id == query_id,
|
||||
QueryTask.platform == platform,
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
task_obj = result.scalar_one_or_none()
|
||||
|
||||
if not task_obj:
|
||||
task_obj = QueryTask(
|
||||
query_id=query_id,
|
||||
platform=platform,
|
||||
status="pending",
|
||||
)
|
||||
db.add(task_obj)
|
||||
await db.commit()
|
||||
await db.refresh(task_obj)
|
||||
|
||||
return task_obj
|
||||
|
||||
@staticmethod
|
||||
def _calculate_next_query_at(frequency: str | None) -> datetime:
|
||||
now = datetime.utcnow()
|
||||
freq_map = {
|
||||
"daily": timedelta(days=1),
|
||||
"weekly": timedelta(days=7),
|
||||
"monthly": timedelta(days=30),
|
||||
}
|
||||
delta = freq_map.get(frequency or "weekly", timedelta(days=7))
|
||||
return now + delta
|
||||
|
|
|
|||
|
|
@ -0,0 +1,163 @@
|
|||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.agent_framework.base import BaseAgent
|
||||
from app.agent_framework.protocol import (
|
||||
AgentCapability,
|
||||
AgentType,
|
||||
TaskMessage,
|
||||
TaskResult,
|
||||
TaskStatus,
|
||||
)
|
||||
from app.services.llm import LLMFactory, LLMError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompetitorAnalyzerAgent(BaseAgent):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="competitor_analyzer",
|
||||
agent_type=AgentType.COMPETITOR_ANALYZER,
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
def get_capabilities(self) -> AgentCapability:
|
||||
return AgentCapability(
|
||||
agent_name=self.name,
|
||||
agent_type=self.agent_type,
|
||||
version=self.version,
|
||||
supported_tasks=["competitor_analyze", "competitor_gap_analysis"],
|
||||
max_concurrency=2,
|
||||
description="竞品策略分析Agent:对比品牌与竞品的引用数据,识别差距领域,发现机会点,生成策略建议",
|
||||
)
|
||||
|
||||
async def execute(self, task: TaskMessage) -> TaskResult:
|
||||
started_at = datetime.now(timezone.utc)
|
||||
start_time = time.monotonic()
|
||||
|
||||
try:
|
||||
if task.task_type == "competitor_gap_analysis":
|
||||
output = await self._gap_analysis(task)
|
||||
else:
|
||||
output = await self._analyze(task)
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
return TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
status=TaskStatus.COMPLETED,
|
||||
output_data=output,
|
||||
error_message=None,
|
||||
started_at=started_at,
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metrics={
|
||||
"elapsed_seconds": round(elapsed, 2),
|
||||
"task_type": task.task_type,
|
||||
},
|
||||
)
|
||||
|
||||
except LLMError as e:
|
||||
elapsed = time.monotonic() - start_time
|
||||
logger.error(f"CompetitorAnalyzer LLM error on task {task.task_id}: {e}")
|
||||
return TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
status=TaskStatus.FAILED,
|
||||
output_data=None,
|
||||
error_message=f"LLM调用失败: {e}",
|
||||
started_at=started_at,
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metrics={
|
||||
"elapsed_seconds": round(elapsed, 2),
|
||||
"task_type": task.task_type,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.monotonic() - start_time
|
||||
logger.error(f"CompetitorAnalyzer task {task.task_id} failed: {e}")
|
||||
return TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
status=TaskStatus.FAILED,
|
||||
output_data=None,
|
||||
error_message=str(e),
|
||||
started_at=started_at,
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metrics={
|
||||
"elapsed_seconds": round(elapsed, 2),
|
||||
"task_type": task.task_type,
|
||||
},
|
||||
)
|
||||
|
||||
async def _analyze(self, task: TaskMessage) -> dict:
|
||||
from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService
|
||||
|
||||
input_data = task.input_data
|
||||
brand_id = input_data.get("brand_id")
|
||||
analysis_types = input_data.get("analysis_types")
|
||||
period_days = input_data.get("period_days", 30)
|
||||
|
||||
if not brand_id:
|
||||
raise ValueError("input_data必须包含'brand_id'字段")
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.1,
|
||||
message="开始竞品策略分析...",
|
||||
)
|
||||
|
||||
service = CompetitorAnalyzerService()
|
||||
result = await service.analyze_competitor(
|
||||
brand_id=brand_id,
|
||||
analysis_types=analysis_types,
|
||||
period_days=period_days,
|
||||
progress_callback=lambda p, m: self.report_progress(
|
||||
task_id=task.task_id, progress=p, message=m,
|
||||
),
|
||||
)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=1.0,
|
||||
message="竞品策略分析完成",
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _gap_analysis(self, task: TaskMessage) -> dict:
|
||||
from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService
|
||||
|
||||
input_data = task.input_data
|
||||
brand_id = input_data.get("brand_id")
|
||||
period_days = input_data.get("period_days", 30)
|
||||
|
||||
if not brand_id:
|
||||
raise ValueError("input_data必须包含'brand_id'字段")
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.1,
|
||||
message="开始竞品差距分析...",
|
||||
)
|
||||
|
||||
service = CompetitorAnalyzerService()
|
||||
result = await service.analyze_competitor(
|
||||
brand_id=brand_id,
|
||||
analysis_types=["citation_gap", "platform_coverage", "query_overlap"],
|
||||
period_days=period_days,
|
||||
progress_callback=lambda p, m: self.report_progress(
|
||||
task_id=task.task_id, progress=p, message=m,
|
||||
),
|
||||
)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=1.0,
|
||||
message="竞品差距分析完成",
|
||||
)
|
||||
|
||||
return result
|
||||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
|
@ -16,6 +15,7 @@ from app.agent_framework.protocol import (
|
|||
TaskStatus,
|
||||
)
|
||||
from app.services.llm import LLMFactory, LLMError
|
||||
from app.utils.json_extractor import extract_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -165,7 +165,7 @@ class ContentGeneratorAgent(BaseAgent):
|
|||
|
||||
# 4. 解析JSON输出
|
||||
try:
|
||||
topics = json.loads(self._extract_json(response.content))
|
||||
topics = json.loads(extract_json(response.content))
|
||||
except json.JSONDecodeError:
|
||||
topics = [{"title": response.content, "reason": "LLM输出解析失败,返回原始内容"}]
|
||||
|
||||
|
|
@ -277,22 +277,3 @@ class ContentGeneratorAgent(BaseAgent):
|
|||
logger.warning(f"RAG检索失败,跳过知识库上下文: {e}")
|
||||
|
||||
return "暂无相关知识库内容"
|
||||
|
||||
def _extract_json(self, text: str) -> str:
|
||||
"""从LLM响应中提取JSON(可能被markdown包裹)"""
|
||||
# 尝试提取```json ... ```中的内容
|
||||
match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
# 尝试直接找 [ 或 { 开头的JSON
|
||||
for i, c in enumerate(text):
|
||||
if c in '[{':
|
||||
depth = 0
|
||||
for j in range(i, len(text)):
|
||||
if text[j] in '[{':
|
||||
depth += 1
|
||||
elif text[j] in ']}':
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[i : j + 1]
|
||||
return text
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
|
@ -16,6 +15,7 @@ from app.agent_framework.protocol import (
|
|||
TaskStatus,
|
||||
)
|
||||
from app.services.llm import LLMFactory, LLMError
|
||||
from app.utils.json_extractor import extract_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -155,7 +155,7 @@ class GEOOptimizerAgent(BaseAgent):
|
|||
|
||||
# 尝试解析JSON输出
|
||||
try:
|
||||
result = json.loads(self._extract_json(response.content))
|
||||
result = json.loads(extract_json(response.content))
|
||||
result["usage"] = response.usage
|
||||
# 上报进度:完成
|
||||
await self.report_progress(
|
||||
|
|
@ -178,20 +178,3 @@ class GEOOptimizerAgent(BaseAgent):
|
|||
"changes": ["LLM输出非标准格式,已返回原始优化结果"],
|
||||
"usage": response.usage,
|
||||
}
|
||||
|
||||
def _extract_json(self, text: str) -> str:
|
||||
"""从响应中提取JSON"""
|
||||
match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', text, re.DOTALL)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
for i, c in enumerate(text):
|
||||
if c == '{':
|
||||
depth = 0
|
||||
for j in range(i, len(text)):
|
||||
if text[j] == '{':
|
||||
depth += 1
|
||||
elif text[j] == '}':
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return text[i : j + 1]
|
||||
return text
|
||||
|
|
|
|||
|
|
@ -0,0 +1,231 @@
|
|||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agent_framework.base import BaseAgent
|
||||
from app.agent_framework.protocol import (
|
||||
AgentCapability,
|
||||
AgentType,
|
||||
TaskMessage,
|
||||
TaskResult,
|
||||
TaskStatus,
|
||||
)
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.monitoring import MonitoringRecord
|
||||
from app.models.query import Query
|
||||
from app.models.brand import Brand
|
||||
from app.services.monitoring.monitor_service import MonitorService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MonitorAgent(BaseAgent):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="monitor",
|
||||
agent_type=AgentType.PERFORMANCE_TRACKER,
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
def get_capabilities(self) -> AgentCapability:
|
||||
return AgentCapability(
|
||||
agent_name=self.name,
|
||||
agent_type=self.agent_type,
|
||||
version=self.version,
|
||||
supported_tasks=["monitor_track", "monitor_check_single"],
|
||||
max_concurrency=3,
|
||||
description="效果追踪Agent:监测品牌引用量、情感、排名变化,生成变化报告",
|
||||
)
|
||||
|
||||
async def execute(self, task: TaskMessage) -> TaskResult:
|
||||
started_at = datetime.now(timezone.utc)
|
||||
start_time = time.monotonic()
|
||||
|
||||
try:
|
||||
task_type = task.task_type
|
||||
if task_type == "monitor_track":
|
||||
output = await self._monitor_track(task)
|
||||
elif task_type == "monitor_check_single":
|
||||
output = await self._monitor_check_single(task)
|
||||
else:
|
||||
raise ValueError(f"不支持的任务类型: {task_type}")
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
return TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
status=TaskStatus.COMPLETED,
|
||||
output_data=output,
|
||||
error_message=None,
|
||||
started_at=started_at,
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metrics={
|
||||
"elapsed_seconds": round(elapsed, 2),
|
||||
"task_type": task_type,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.monotonic() - start_time
|
||||
logger.error(f"MonitorAgent task {task.task_id} failed: {e}")
|
||||
return TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
status=TaskStatus.FAILED,
|
||||
output_data=None,
|
||||
error_message=str(e),
|
||||
started_at=started_at,
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metrics={
|
||||
"elapsed_seconds": round(elapsed, 2),
|
||||
"task_type": task.task_type,
|
||||
},
|
||||
)
|
||||
|
||||
async def _monitor_track(self, task: TaskMessage) -> dict:
|
||||
input_data = task.input_data
|
||||
brand_id = input_data.get("brand_id")
|
||||
if not brand_id:
|
||||
raise ValueError("input_data必须包含'brand_id'字段")
|
||||
|
||||
brand_id = uuid.UUID(brand_id)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.1,
|
||||
message="开始效果追踪...",
|
||||
)
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
stmt = select(Brand).where(Brand.id == brand_id)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
if not brand:
|
||||
raise ValueError(f"品牌不存在: {brand_id}")
|
||||
|
||||
queries_stmt = select(Query).where(Query.target_brand == brand.name)
|
||||
queries_result = await db.execute(queries_stmt)
|
||||
queries = list(queries_result.scalars().all())
|
||||
|
||||
if not queries:
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=1.0,
|
||||
message="品牌下无关联查询,效果追踪完成",
|
||||
)
|
||||
return {
|
||||
"brand_id": str(brand_id),
|
||||
"brand_name": brand.name,
|
||||
"total_queries": 0,
|
||||
"reports": [],
|
||||
}
|
||||
|
||||
total_queries = len(queries)
|
||||
reports = []
|
||||
service = MonitorService()
|
||||
|
||||
monitoring_stmt = select(MonitoringRecord).where(
|
||||
MonitoringRecord.brand_id == brand_id,
|
||||
MonitoringRecord.status == "active",
|
||||
)
|
||||
monitoring_result = await db.execute(monitoring_stmt)
|
||||
monitoring_records = list(monitoring_result.scalars().all())
|
||||
|
||||
for idx, record in enumerate(monitoring_records):
|
||||
progress = 0.2 + (0.7 * idx / max(len(monitoring_records), 1))
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=progress,
|
||||
message=f"正在检测第 {idx + 1}/{len(monitoring_records)} 条监测记录...",
|
||||
)
|
||||
|
||||
updated_record = await service.check_and_compare(db, record.id)
|
||||
if updated_record:
|
||||
report = await service.generate_change_report(db, updated_record.id)
|
||||
if report:
|
||||
reports.append(report)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=1.0,
|
||||
message="效果追踪完成",
|
||||
)
|
||||
|
||||
return {
|
||||
"brand_id": str(brand_id),
|
||||
"brand_name": brand.name,
|
||||
"total_queries": total_queries,
|
||||
"checked_records": len(monitoring_records),
|
||||
"reports": reports,
|
||||
}
|
||||
|
||||
async def _monitor_check_single(self, task: TaskMessage) -> dict:
|
||||
input_data = task.input_data
|
||||
brand_id = input_data.get("brand_id")
|
||||
keyword = input_data.get("keyword")
|
||||
platform = input_data.get("platform")
|
||||
|
||||
if not brand_id:
|
||||
raise ValueError("input_data必须包含'brand_id'字段")
|
||||
|
||||
brand_id = uuid.UUID(brand_id)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.1,
|
||||
message="开始单关键词检测...",
|
||||
)
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
service = MonitorService()
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.3,
|
||||
message="正在创建监测记录...",
|
||||
)
|
||||
|
||||
record = await service.create_monitoring_record(
|
||||
db=db,
|
||||
brand_id=brand_id,
|
||||
query_keywords=keyword,
|
||||
platform=platform,
|
||||
check_interval_hours=input_data.get("check_interval_hours", 24),
|
||||
)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.6,
|
||||
message="正在执行检测对比...",
|
||||
)
|
||||
|
||||
updated_record = await service.check_and_compare(db, record.id)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.8,
|
||||
message="正在生成变化报告...",
|
||||
)
|
||||
|
||||
report = None
|
||||
if updated_record:
|
||||
report = await service.generate_change_report(db, updated_record.id)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=1.0,
|
||||
message="单关键词检测完成",
|
||||
)
|
||||
|
||||
return {
|
||||
"record_id": str(record.id),
|
||||
"brand_id": str(brand_id),
|
||||
"keyword": keyword,
|
||||
"platform": platform,
|
||||
"change_type": updated_record.change_type if updated_record else None,
|
||||
"report": report,
|
||||
}
|
||||
|
|
@ -0,0 +1,398 @@
|
|||
import copy
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.agent_framework.base import BaseAgent
|
||||
from app.agent_framework.prompts.schema_advisor import SCHEMA_ADVISOR_TEMPLATE
|
||||
from app.agent_framework.protocol import (
|
||||
AgentCapability,
|
||||
AgentType,
|
||||
TaskMessage,
|
||||
TaskResult,
|
||||
TaskStatus,
|
||||
)
|
||||
from app.services.llm import LLMFactory, LLMError
|
||||
from app.utils.json_extractor import extract_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SCHEMA_TEMPLATES = {
|
||||
"Organization": {
|
||||
"@context": "https://schema.org",
|
||||
"@type": "Organization",
|
||||
"name": "",
|
||||
"description": "",
|
||||
"url": "",
|
||||
"logo": "",
|
||||
"sameAs": [],
|
||||
"contactPoint": {
|
||||
"@type": "ContactPoint",
|
||||
"contactType": "customer service",
|
||||
"telephone": "",
|
||||
},
|
||||
},
|
||||
"Product": {
|
||||
"@context": "https://schema.org",
|
||||
"@type": "Product",
|
||||
"name": "",
|
||||
"description": "",
|
||||
"brand": {"@type": "Brand", "name": ""},
|
||||
"offers": {
|
||||
"@type": "Offer",
|
||||
"priceCurrency": "CNY",
|
||||
"availability": "https://schema.org/InStock",
|
||||
},
|
||||
},
|
||||
"FAQPage": {
|
||||
"@context": "https://schema.org",
|
||||
"@type": "FAQPage",
|
||||
"mainEntity": [
|
||||
{
|
||||
"@type": "Question",
|
||||
"name": "",
|
||||
"acceptedAnswer": {
|
||||
"@type": "Answer",
|
||||
"text": "",
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
"Article": {
|
||||
"@context": "https://schema.org",
|
||||
"@type": "Article",
|
||||
"headline": "",
|
||||
"description": "",
|
||||
"author": {"@type": "Organization", "name": ""},
|
||||
"datePublished": "",
|
||||
"image": "",
|
||||
},
|
||||
"LocalBusiness": {
|
||||
"@context": "https://schema.org",
|
||||
"@type": "LocalBusiness",
|
||||
"name": "",
|
||||
"address": {
|
||||
"@type": "PostalAddress",
|
||||
"streetAddress": "",
|
||||
"addressLocality": "",
|
||||
"addressRegion": "",
|
||||
"postalCode": "",
|
||||
"addressCountry": "CN",
|
||||
},
|
||||
"geo": {
|
||||
"@type": "GeoCoordinates",
|
||||
"latitude": "",
|
||||
"longitude": "",
|
||||
},
|
||||
"telephone": "",
|
||||
"openingHours": "",
|
||||
},
|
||||
}
|
||||
|
||||
DIMENSION_SCHEMA_MAP = {
|
||||
"schema_marketing": ["Organization", "LocalBusiness"],
|
||||
"entity_clarity": ["Organization", "Product"],
|
||||
"citation_readiness": ["FAQPage", "Article"],
|
||||
"brand_visibility": ["Organization", "Product"],
|
||||
"local_seo": ["LocalBusiness"],
|
||||
}
|
||||
|
||||
PRIORITY_THRESHOLD = {
|
||||
"high": 30.0,
|
||||
"medium": 60.0,
|
||||
}
|
||||
|
||||
DIFFICULTY_MAP = {
|
||||
"Organization": "easy",
|
||||
"Product": "medium",
|
||||
"FAQPage": "medium",
|
||||
"Article": "easy",
|
||||
"LocalBusiness": "hard",
|
||||
}
|
||||
|
||||
|
||||
class SchemaAdvisorAgent(BaseAgent):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="schema_advisor",
|
||||
agent_type=AgentType.SCHEMA_ADVISOR,
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
def get_capabilities(self) -> AgentCapability:
|
||||
return AgentCapability(
|
||||
agent_name=self.name,
|
||||
agent_type=self.agent_type,
|
||||
version=self.version,
|
||||
supported_tasks=["schema_advise"],
|
||||
max_concurrency=2,
|
||||
description="Schema优化建议Agent:识别Schema缺失维度,生成JSON-LD结构化数据建议",
|
||||
)
|
||||
|
||||
async def execute(self, task: TaskMessage) -> TaskResult:
|
||||
started_at = datetime.now(timezone.utc)
|
||||
start_time = time.monotonic()
|
||||
|
||||
try:
|
||||
output = await self._advise(task)
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
return TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
status=TaskStatus.COMPLETED,
|
||||
output_data=output,
|
||||
error_message=None,
|
||||
started_at=started_at,
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metrics={
|
||||
"elapsed_seconds": round(elapsed, 2),
|
||||
"task_type": task.task_type,
|
||||
},
|
||||
)
|
||||
|
||||
except LLMError as e:
|
||||
elapsed = time.monotonic() - start_time
|
||||
logger.error(f"SchemaAdvisor LLM error on task {task.task_id}: {e}")
|
||||
return TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
status=TaskStatus.FAILED,
|
||||
output_data=None,
|
||||
error_message=f"LLM调用失败: {e}",
|
||||
started_at=started_at,
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metrics={
|
||||
"elapsed_seconds": round(elapsed, 2),
|
||||
"task_type": task.task_type,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.monotonic() - start_time
|
||||
logger.error(f"SchemaAdvisor task {task.task_id} failed: {e}")
|
||||
return TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
status=TaskStatus.FAILED,
|
||||
output_data=None,
|
||||
error_message=str(e),
|
||||
started_at=started_at,
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metrics={
|
||||
"elapsed_seconds": round(elapsed, 2),
|
||||
"task_type": task.task_type,
|
||||
},
|
||||
)
|
||||
|
||||
async def _advise(self, task: TaskMessage) -> dict:
|
||||
input_data = task.input_data
|
||||
brand_id = input_data.get("brand_id")
|
||||
diagnosis_data = input_data.get("diagnosis_data", {})
|
||||
brand_info = input_data.get("brand_info", {})
|
||||
focus_dimensions = input_data.get("focus_dimensions")
|
||||
|
||||
if not brand_id:
|
||||
raise ValueError("input_data必须包含'brand_id'字段")
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.1,
|
||||
message="开始Schema建议分析...",
|
||||
)
|
||||
|
||||
missing_dimensions = self._identify_missing_dimensions(diagnosis_data, focus_dimensions)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.3,
|
||||
message=f"识别到{len(missing_dimensions)}个Schema缺失维度...",
|
||||
)
|
||||
|
||||
matched = self._match_templates(missing_dimensions)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.5,
|
||||
message="匹配预定义模板完成,开始LLM填充...",
|
||||
)
|
||||
|
||||
filled = await self._fill_with_llm(matched, brand_info)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.8,
|
||||
message="LLM填充完成,验证JSON-LD格式...",
|
||||
)
|
||||
|
||||
validated = self._validate_and_sort(filled)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=1.0,
|
||||
message="Schema建议生成完成",
|
||||
)
|
||||
|
||||
return {
|
||||
"brand_id": brand_id,
|
||||
"suggestions": validated,
|
||||
"total": len(validated),
|
||||
}
|
||||
|
||||
def _identify_missing_dimensions(
|
||||
self,
|
||||
diagnosis_data: dict,
|
||||
focus_dimensions: list[str] | None = None,
|
||||
) -> list[dict]:
|
||||
dimensions = []
|
||||
dimension_scores = diagnosis_data.get("dimensions", {})
|
||||
for dim_name, dim_info in dimension_scores.items():
|
||||
if dim_name not in DIMENSION_SCHEMA_MAP:
|
||||
continue
|
||||
if focus_dimensions and dim_name not in focus_dimensions:
|
||||
continue
|
||||
score = dim_info.get("score", 0) if isinstance(dim_info, dict) else dim_info
|
||||
max_score = dim_info.get("max_score", 100) if isinstance(dim_info, dict) else 100
|
||||
percentage = (score / max_score * 100) if max_score > 0 else 0
|
||||
if percentage < 80:
|
||||
dimensions.append({
|
||||
"dimension": dim_name,
|
||||
"current_score": round(score, 2),
|
||||
"max_score": max_score,
|
||||
"percentage": round(percentage, 2),
|
||||
})
|
||||
if not dimensions and diagnosis_data:
|
||||
overall = diagnosis_data.get("overall_score", 0)
|
||||
if overall < 80:
|
||||
for dim_name in DIMENSION_SCHEMA_MAP:
|
||||
if focus_dimensions and dim_name not in focus_dimensions:
|
||||
continue
|
||||
dimensions.append({
|
||||
"dimension": dim_name,
|
||||
"current_score": 0,
|
||||
"max_score": 100,
|
||||
"percentage": 0,
|
||||
})
|
||||
return dimensions
|
||||
|
||||
def _match_templates(self, missing_dimensions: list[dict]) -> list[dict]:
|
||||
matched = []
|
||||
seen_types = set()
|
||||
for dim in missing_dimensions:
|
||||
schema_types = DIMENSION_SCHEMA_MAP.get(dim["dimension"], [])
|
||||
for schema_type in schema_types:
|
||||
if schema_type in seen_types:
|
||||
continue
|
||||
seen_types.add(schema_type)
|
||||
template = SCHEMA_TEMPLATES.get(schema_type)
|
||||
if template:
|
||||
percentage = dim["percentage"]
|
||||
if percentage < PRIORITY_THRESHOLD["high"]:
|
||||
priority = "high"
|
||||
elif percentage < PRIORITY_THRESHOLD["medium"]:
|
||||
priority = "medium"
|
||||
else:
|
||||
priority = "low"
|
||||
matched.append({
|
||||
"schema_type": schema_type,
|
||||
"priority": priority,
|
||||
"diagnosis_dimensions": {
|
||||
"dimension": dim["dimension"],
|
||||
"current_score": dim["current_score"],
|
||||
"max_score": dim["max_score"],
|
||||
"percentage": dim["percentage"],
|
||||
},
|
||||
"json_ld_template": copy.deepcopy(template),
|
||||
"implementation_difficulty": DIFFICULTY_MAP.get(schema_type, "medium"),
|
||||
})
|
||||
return matched
|
||||
|
||||
async def _fill_with_llm(self, matched: list[dict], brand_info: dict) -> list[dict]:
|
||||
provider = LLMFactory.get_default()
|
||||
results = []
|
||||
for item in matched:
|
||||
schema_type = item["schema_type"]
|
||||
try:
|
||||
variables = {
|
||||
"brand_name": brand_info.get("name", ""),
|
||||
"brand_website": brand_info.get("website", ""),
|
||||
"brand_industry": brand_info.get("industry", ""),
|
||||
"schema_type": schema_type,
|
||||
"diagnosis_data": json.dumps(item.get("diagnosis_dimensions", {}), ensure_ascii=False),
|
||||
"existing_schemas": "无",
|
||||
}
|
||||
messages = SCHEMA_ADVISOR_TEMPLATE.render(variables)
|
||||
response = await provider.chat(
|
||||
messages,
|
||||
temperature=0.3,
|
||||
max_tokens=2048,
|
||||
)
|
||||
filled = json.loads(extract_json(response.content))
|
||||
item["json_ld_filled"] = filled
|
||||
item["estimated_impact"] = self._generate_impact_description(
|
||||
schema_type, item.get("diagnosis_dimensions", {}).get("dimension", "")
|
||||
)
|
||||
except (json.JSONDecodeError, LLMError, ValueError) as e:
|
||||
logger.warning(f"LLM填充Schema {schema_type} 失败: {e}")
|
||||
item["json_ld_filled"] = None
|
||||
item["estimated_impact"] = self._generate_impact_description(
|
||||
schema_type, item.get("diagnosis_dimensions", {}).get("dimension", "")
|
||||
)
|
||||
results.append(item)
|
||||
return results
|
||||
|
||||
def _validate_json_ld(self, json_ld: dict) -> dict:
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
if not json_ld:
|
||||
return {"is_valid": False, "errors": ["JSON-LD为空"], "warnings": []}
|
||||
|
||||
if "@context" not in json_ld:
|
||||
errors.append("缺少@context字段")
|
||||
|
||||
if "@type" not in json_ld:
|
||||
errors.append("缺少@type字段")
|
||||
|
||||
if "@context" in json_ld and json_ld["@context"] != "https://schema.org":
|
||||
warnings.append(f"@context值非标准: {json_ld.get('@context')}")
|
||||
|
||||
if "@type" in json_ld and json_ld["@type"] not in SCHEMA_TEMPLATES:
|
||||
warnings.append(f"@type非推荐类型: {json_ld.get('@type')}")
|
||||
|
||||
try:
|
||||
json.dumps(json_ld)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
errors.append(f"JSON序列化失败: {e}")
|
||||
|
||||
return {
|
||||
"is_valid": len(errors) == 0,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
}
|
||||
|
||||
def _validate_and_sort(self, items: list[dict]) -> list[dict]:
|
||||
validated = []
|
||||
for item in items:
|
||||
json_ld_filled = item.get("json_ld_filled")
|
||||
if json_ld_filled:
|
||||
validation = self._validate_json_ld(json_ld_filled)
|
||||
item["validation_errors"] = None if validation["is_valid"] else {"errors": validation["errors"], "warnings": validation["warnings"]}
|
||||
else:
|
||||
item["validation_errors"] = {"errors": ["JSON-LD填充失败"], "warnings": []}
|
||||
validated.append(item)
|
||||
priority_order = {"high": 0, "medium": 1, "low": 2}
|
||||
validated.sort(key=lambda x: priority_order.get(x.get("priority", "medium"), 1))
|
||||
return validated
|
||||
|
||||
def _generate_impact_description(self, schema_type: str, dimension: str) -> str:
|
||||
impacts = {
|
||||
"Organization": "增强品牌实体识别,提升AI搜索引擎对品牌的理解和引用概率",
|
||||
"Product": "提升产品在搜索结果中的富摘要展示,增加点击率和引用率",
|
||||
"FAQPage": "增加FAQ富摘要展示机会,提升在AI回答中的直接引用概率",
|
||||
"Article": "优化文章内容的结构化表达,提升AI搜索引擎的内容理解和引用",
|
||||
"LocalBusiness": "增强本地搜索可见性,提升地理位置相关查询的引用率",
|
||||
}
|
||||
return impacts.get(schema_type, f"提升{dimension}维度的得分和AI引用率")
|
||||
|
|
@ -0,0 +1,166 @@
|
|||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.agent_framework.base import BaseAgent
|
||||
from app.agent_framework.protocol import (
|
||||
AgentCapability,
|
||||
AgentType,
|
||||
TaskMessage,
|
||||
TaskResult,
|
||||
TaskStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TrendAgent(BaseAgent):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="trend_agent",
|
||||
agent_type=AgentType.TREND_AGENT,
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
def get_capabilities(self) -> AgentCapability:
|
||||
return AgentCapability(
|
||||
agent_name=self.name,
|
||||
agent_type=self.agent_type,
|
||||
version=self.version,
|
||||
supported_tasks=["trend_insight", "trend_hotspot"],
|
||||
max_concurrency=2,
|
||||
description="趋势洞察Agent:分析品牌引用趋势、识别热点话题、推断变化原因并生成建议",
|
||||
)
|
||||
|
||||
async def execute(self, task: TaskMessage) -> TaskResult:
|
||||
started_at = datetime.now(timezone.utc)
|
||||
start_time = time.monotonic()
|
||||
|
||||
try:
|
||||
if task.task_type == "trend_insight":
|
||||
output = await self._trend_insight(task)
|
||||
elif task.task_type == "trend_hotspot":
|
||||
output = await self._trend_hotspot(task)
|
||||
else:
|
||||
raise ValueError(f"不支持的任务类型: {task.task_type}")
|
||||
|
||||
elapsed = time.monotonic() - start_time
|
||||
return TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
status=TaskStatus.COMPLETED,
|
||||
output_data=output,
|
||||
error_message=None,
|
||||
started_at=started_at,
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metrics={
|
||||
"elapsed_seconds": round(elapsed, 2),
|
||||
"task_type": task.task_type,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.monotonic() - start_time
|
||||
logger.error(f"TrendAgent task {task.task_id} failed: {e}")
|
||||
return TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
status=TaskStatus.FAILED,
|
||||
output_data=None,
|
||||
error_message=str(e),
|
||||
started_at=started_at,
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metrics={
|
||||
"elapsed_seconds": round(elapsed, 2),
|
||||
"task_type": task.task_type,
|
||||
},
|
||||
)
|
||||
|
||||
async def _trend_insight(self, task: TaskMessage) -> dict:
|
||||
input_data = task.input_data
|
||||
brand_id = input_data.get("brand_id")
|
||||
days = input_data.get("days", 30)
|
||||
platforms = input_data.get("platforms")
|
||||
keywords = input_data.get("keywords")
|
||||
|
||||
if not brand_id:
|
||||
raise ValueError("input_data必须包含'brand_id'字段")
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.1,
|
||||
message="开始趋势洞察分析...",
|
||||
)
|
||||
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.services.trend.trend_analyzer_service import TrendAnalyzerService
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
service = TrendAnalyzerService(db)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.3,
|
||||
message="获取历史引用数据...",
|
||||
)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.5,
|
||||
message="执行时间序列分析...",
|
||||
)
|
||||
|
||||
result = await service.analyze_trends(
|
||||
brand_id=brand_id,
|
||||
days=days,
|
||||
platforms=platforms,
|
||||
keywords=keywords,
|
||||
)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=1.0,
|
||||
message="趋势洞察分析完成",
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _trend_hotspot(self, task: TaskMessage) -> dict:
|
||||
input_data = task.input_data
|
||||
brand_id = input_data.get("brand_id")
|
||||
days = input_data.get("days", 30)
|
||||
|
||||
if not brand_id:
|
||||
raise ValueError("input_data必须包含'brand_id'字段")
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.1,
|
||||
message="开始热点话题分析...",
|
||||
)
|
||||
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.services.trend.trend_analyzer_service import TrendAnalyzerService
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
service = TrendAnalyzerService(db)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=0.5,
|
||||
message="检测引用量突增的关键词/话题...",
|
||||
)
|
||||
|
||||
result = await service.get_hotspots(
|
||||
brand_id=brand_id,
|
||||
days=days,
|
||||
)
|
||||
|
||||
await self.report_progress(
|
||||
task_id=task.task_id,
|
||||
progress=1.0,
|
||||
message="热点话题分析完成",
|
||||
)
|
||||
|
||||
return result
|
||||
|
|
@ -21,6 +21,33 @@ from app.config import settings
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level lazy singleton for TaskDispatcher — avoids creating a new
|
||||
# dispatcher on every method call while still deferring the import to
|
||||
# prevent circular-dependency issues at module-load time.
|
||||
_dispatcher_instance = None
|
||||
|
||||
|
||||
def _get_dispatcher():
|
||||
"""Return a cached TaskDispatcher instance (lazy singleton)."""
|
||||
global _dispatcher_instance
|
||||
if _dispatcher_instance is None:
|
||||
from app.agent_framework.dispatcher import TaskDispatcher
|
||||
_dispatcher_instance = TaskDispatcher(settings.REDIS_URL)
|
||||
return _dispatcher_instance
|
||||
|
||||
|
||||
# Module-level lazy singleton for AgentRegistry — same rationale.
|
||||
_registry_instance = None
|
||||
|
||||
|
||||
def _get_registry():
|
||||
"""Return a cached AgentRegistry instance (lazy singleton)."""
|
||||
global _registry_instance
|
||||
if _registry_instance is None:
|
||||
from app.agent_framework.registry import AgentRegistry
|
||||
_registry_instance = AgentRegistry()
|
||||
return _registry_instance
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
"""所有 Agent 的基类,定义标准生命周期"""
|
||||
|
|
@ -40,6 +67,10 @@ class BaseAgent(ABC):
|
|||
def status(self) -> AgentStatus:
|
||||
return self._status
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
return self._redis is not None
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, task: TaskMessage) -> TaskResult:
|
||||
"""执行任务的核心逻辑,子类必须实现"""
|
||||
|
|
@ -51,48 +82,47 @@ class BaseAgent(ABC):
|
|||
...
|
||||
|
||||
async def start(self):
|
||||
"""启动 Agent:注册到 Registry,开始监听任务队列"""
|
||||
logger.info(f"Starting agent '{self.name}' (type={self.agent_type}, version={self.version})")
|
||||
|
||||
# 初始化 Redis 连接
|
||||
self._redis = aioredis.from_url(
|
||||
settings.REDIS_URL,
|
||||
decode_responses=True,
|
||||
)
|
||||
try:
|
||||
self._redis = aioredis.from_url(
|
||||
settings.REDIS_URL,
|
||||
decode_responses=True,
|
||||
)
|
||||
await self._redis.ping()
|
||||
|
||||
# 注册到 Registry
|
||||
from app.agent_framework.registry import AgentRegistry
|
||||
registry = _get_registry()
|
||||
capability = self.get_capabilities()
|
||||
await registry.register(capability, endpoint=f"agent:{self.name}")
|
||||
|
||||
registry = AgentRegistry()
|
||||
capability = self.get_capabilities()
|
||||
await registry.register(capability, endpoint=f"agent:{self.name}")
|
||||
self._status = AgentStatus.ONLINE
|
||||
|
||||
# 更新状态
|
||||
self._status = AgentStatus.ONLINE
|
||||
capability = self.get_capabilities()
|
||||
max_concurrency = getattr(capability, 'max_concurrency', 1) or 1
|
||||
self._semaphore = asyncio.Semaphore(max_concurrency)
|
||||
logger.info(
|
||||
f"Agent '{self.name}' concurrency limit set to {max_concurrency}"
|
||||
)
|
||||
|
||||
# 根据 capabilities 的 max_concurrency 初始化 Semaphore
|
||||
capability = self.get_capabilities()
|
||||
max_concurrency = getattr(capability, 'max_concurrency', 1) or 1
|
||||
self._semaphore = asyncio.Semaphore(max_concurrency)
|
||||
logger.info(
|
||||
f"Agent '{self.name}' concurrency limit set to {max_concurrency}"
|
||||
)
|
||||
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
||||
self._listen_task = asyncio.create_task(self._listen_for_tasks())
|
||||
|
||||
# 启动心跳
|
||||
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
||||
logger.info(f"Agent '{self.name}' started in distributed mode")
|
||||
except Exception as e:
|
||||
self._redis = None
|
||||
self._status = AgentStatus.ONLINE
|
||||
|
||||
# 开始监听任务队列
|
||||
self._listen_task = asyncio.create_task(self._listen_for_tasks())
|
||||
capability = self.get_capabilities()
|
||||
max_concurrency = getattr(capability, 'max_concurrency', 1) or 1
|
||||
self._semaphore = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
logger.info(f"Agent '{self.name}' started successfully")
|
||||
logger.warning(f"Agent '{self.name}' started in local mode (Redis unavailable: {e})")
|
||||
|
||||
async def stop(self):
|
||||
"""停止 Agent:注销,停止监听"""
|
||||
logger.info(f"Stopping agent '{self.name}'")
|
||||
|
||||
self._status = AgentStatus.OFFLINE
|
||||
|
||||
# 取消监听任务
|
||||
if self._listen_task and not self._listen_task.done():
|
||||
self._listen_task.cancel()
|
||||
try:
|
||||
|
|
@ -100,7 +130,6 @@ class BaseAgent(ABC):
|
|||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 取消心跳任务
|
||||
if self._heartbeat_task and not self._heartbeat_task.done():
|
||||
self._heartbeat_task.cancel()
|
||||
try:
|
||||
|
|
@ -108,14 +137,10 @@ class BaseAgent(ABC):
|
|||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 注销
|
||||
from app.agent_framework.registry import AgentRegistry
|
||||
if self._redis is not None:
|
||||
registry = _get_registry()
|
||||
await registry.unregister(self.name)
|
||||
|
||||
registry = AgentRegistry()
|
||||
await registry.unregister(self.name)
|
||||
|
||||
# 关闭 Redis 连接
|
||||
if self._redis:
|
||||
await self._redis.close()
|
||||
self._redis = None
|
||||
|
||||
|
|
@ -123,13 +148,10 @@ class BaseAgent(ABC):
|
|||
|
||||
async def heartbeat(self):
|
||||
"""定期心跳上报"""
|
||||
from app.agent_framework.registry import AgentRegistry
|
||||
|
||||
registry = AgentRegistry()
|
||||
registry = _get_registry()
|
||||
await registry.update_heartbeat(self.name)
|
||||
|
||||
async def report_progress(self, task_id: str, progress: float, message: str):
|
||||
"""上报任务进度"""
|
||||
progress_obj = TaskProgress(
|
||||
task_id=task_id,
|
||||
agent_name=self.name,
|
||||
|
|
@ -138,7 +160,6 @@ class BaseAgent(ABC):
|
|||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# 通过 Redis Pub/Sub 发布进度
|
||||
if self._redis:
|
||||
try:
|
||||
await self._redis.publish(
|
||||
|
|
@ -148,11 +169,11 @@ class BaseAgent(ABC):
|
|||
except Exception as e:
|
||||
logger.warning(f"Failed to publish progress for task {task_id}: {e}")
|
||||
|
||||
# 同时更新数据库
|
||||
from app.agent_framework.dispatcher import TaskDispatcher
|
||||
|
||||
dispatcher = TaskDispatcher(settings.REDIS_URL)
|
||||
await dispatcher.handle_progress(progress_obj)
|
||||
try:
|
||||
dispatcher = _get_dispatcher()
|
||||
await dispatcher.handle_progress(progress_obj)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to report progress to dispatcher for task {task_id}: {e}")
|
||||
|
||||
async def _heartbeat_loop(self):
|
||||
"""心跳循环"""
|
||||
|
|
@ -198,7 +219,6 @@ class BaseAgent(ABC):
|
|||
await self._execute_task(task)
|
||||
|
||||
async def _execute_task(self, task: TaskMessage):
|
||||
"""执行单个任务"""
|
||||
self._running_tasks.add(task.task_id)
|
||||
self._status = AgentStatus.BUSY
|
||||
|
||||
|
|
@ -209,15 +229,12 @@ class BaseAgent(ABC):
|
|||
result.started_at = started_at
|
||||
result.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
# 处理结果
|
||||
from app.agent_framework.dispatcher import TaskDispatcher
|
||||
|
||||
dispatcher = TaskDispatcher(settings.REDIS_URL)
|
||||
await dispatcher.handle_result(result)
|
||||
if self._redis is not None:
|
||||
dispatcher = _get_dispatcher()
|
||||
await dispatcher.handle_result(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent '{self.name}' task {task.task_id} failed: {e}")
|
||||
# 构建失败结果
|
||||
error_result = TaskResult(
|
||||
task_id=task.task_id,
|
||||
agent_name=self.name,
|
||||
|
|
@ -228,10 +245,9 @@ class BaseAgent(ABC):
|
|||
completed_at=datetime.now(timezone.utc),
|
||||
metrics=None,
|
||||
)
|
||||
from app.agent_framework.dispatcher import TaskDispatcher
|
||||
|
||||
dispatcher = TaskDispatcher(settings.REDIS_URL)
|
||||
await dispatcher.handle_result(error_result)
|
||||
if self._redis is not None:
|
||||
dispatcher = _get_dispatcher()
|
||||
await dispatcher.handle_result(error_result)
|
||||
|
||||
finally:
|
||||
self._running_tasks.discard(task.task_id)
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from .content_generator import CONTENT_GENERATOR_TEMPLATE
|
|||
from .deai_agent import DEAI_TEMPLATE
|
||||
from .geo_optimizer import GEO_OPTIMIZER_TEMPLATE
|
||||
from .rule_checker import RULE_CHECKER_TEMPLATE
|
||||
from .schema_advisor import SCHEMA_ADVISOR_TEMPLATE
|
||||
from .topic_selector import TOPIC_SELECTOR_TEMPLATE
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -25,4 +26,5 @@ __all__ = [
|
|||
"DEAI_TEMPLATE",
|
||||
"GEO_OPTIMIZER_TEMPLATE",
|
||||
"RULE_CHECKER_TEMPLATE",
|
||||
"SCHEMA_ADVISOR_TEMPLATE",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,70 @@
|
|||
from .base_template import PromptSection, PromptTemplate
|
||||
|
||||
SCHEMA_ADVISOR_TEMPLATE = PromptTemplate(
|
||||
PromptSection(
|
||||
identity="""你是一位精通Schema.org结构化数据和JSON-LD的技术专家。
|
||||
你深刻理解搜索引擎和AI模型(如ChatGPT、Perplexity、Kimi)如何解析和利用结构化数据,
|
||||
知道如何通过精准的Schema标记提升品牌在AI搜索结果中的可见性和引用率。
|
||||
你生成的JSON-LD严格遵循Schema.org规范,确保可被搜索引擎正确解析。""",
|
||||
|
||||
context="""## 品牌信息
|
||||
- 品牌名称:${brand_name}
|
||||
- 网站:${brand_website}
|
||||
- 行业:${brand_industry}
|
||||
|
||||
## 诊断数据
|
||||
${diagnosis_data}
|
||||
|
||||
## 已有Schema标记
|
||||
${existing_schemas}
|
||||
|
||||
## 目标Schema类型
|
||||
${schema_type}""",
|
||||
|
||||
instructions="""请根据以上品牌信息和诊断数据,为品牌生成完整的JSON-LD结构化数据。
|
||||
|
||||
生成要求:
|
||||
|
||||
1. 内容填充:
|
||||
- 所有字段必须填充真实、具体的内容,不得留空
|
||||
- 品牌名称、网站等基本信息必须与提供的数据一致
|
||||
- 描述性文本应当专业、准确,体现品牌特色
|
||||
|
||||
2. Schema类型特定要求:
|
||||
- Organization: 包含name, description, url, logo, sameAs(社交媒体链接), contactPoint
|
||||
- Product: 包含name, description, brand, offers, aggregateRating(如有)
|
||||
- FAQPage: 生成3-5个与品牌行业相关的高质量FAQ,问题和答案需自然且信息丰富
|
||||
- Article: 包含headline, author, datePublished, description, image
|
||||
- LocalBusiness: 包含name, address(完整地址结构), geo, telephone, openingHours
|
||||
|
||||
3. 语言要求:
|
||||
- 所有自然语言内容使用与品牌名称相同的语言
|
||||
- 技术字段(如@type, @context)保持英文
|
||||
|
||||
4. 结构完整性:
|
||||
- 必须包含@context和@type
|
||||
- 嵌套对象必须完整,不得省略必要子属性""",
|
||||
|
||||
constraints="""## 约束条件
|
||||
- 严格遵循Schema.org规范,不得使用非标准属性
|
||||
- @context必须为"https://schema.org"
|
||||
- @type必须是Schema.org定义的有效类型
|
||||
- 不得编造不存在的品牌信息(如无实际地址,LocalBusiness的address可使用占位结构)
|
||||
- FAQ的问题必须是用户真实可能搜索的问题
|
||||
- 所有URL字段如无实际值,留空字符串
|
||||
- 不得在JSON-LD中包含HTML标签""",
|
||||
|
||||
output_format="""## 输出格式
|
||||
请以JSON格式输出填充后的JSON-LD:
|
||||
|
||||
```json
|
||||
{
|
||||
"@context": "https://schema.org",
|
||||
"@type": "...",
|
||||
"...": "..."
|
||||
}
|
||||
```
|
||||
|
||||
仅输出JSON-LD对象,不要包含任何解释文字。""",
|
||||
)
|
||||
)
|
||||
|
|
@ -6,7 +6,6 @@ from enum import Enum
|
|||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
"""Agent 类型枚举"""
|
||||
CITATION_DETECTOR = "citation_detector"
|
||||
CONTENT_GENERATOR = "content_generator"
|
||||
DEAI_AGENT = "deai_agent"
|
||||
|
|
@ -14,6 +13,9 @@ class AgentType(str, Enum):
|
|||
RULE_CHECKER = "rule_checker"
|
||||
COMPETITOR_ANALYZER = "competitor_analyzer"
|
||||
PERFORMANCE_TRACKER = "performance_tracker"
|
||||
SCHEMA_ADVISOR = "schema_advisor"
|
||||
MONITOR_AGENT = "monitor_agent"
|
||||
TREND_AGENT = "trend_agent"
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,62 @@
|
|||
import asyncio
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from app.agent_framework.agents.citation_detector import CitationDetectorAgent
|
||||
from app.agent_framework.agents.content_generator_agent import ContentGeneratorAgent
|
||||
from app.agent_framework.agents.deai_agent import DeAIAgent
|
||||
from app.agent_framework.agents.geo_optimizer_agent import GEOOptimizerAgent
|
||||
from app.agent_framework.agents.monitor_agent import MonitorAgent
|
||||
from app.agent_framework.agents.schema_advisor import SchemaAdvisorAgent
|
||||
from app.agent_framework.agents.competitor_analyzer import CompetitorAnalyzerAgent
|
||||
from app.agent_framework.agents.trend_agent import TrendAgent
|
||||
|
||||
AGENTS = {
|
||||
"citation_detector": CitationDetectorAgent,
|
||||
"content_generator": ContentGeneratorAgent,
|
||||
"deai_agent": DeAIAgent,
|
||||
"geo_optimizer": GEOOptimizerAgent,
|
||||
"monitor": MonitorAgent,
|
||||
"schema_advisor": SchemaAdvisorAgent,
|
||||
"competitor_analyzer": CompetitorAnalyzerAgent,
|
||||
"trend_agent": TrendAgent,
|
||||
"all": None,
|
||||
}
|
||||
|
||||
|
||||
async def run_agent(name: str):
|
||||
if name == "all":
|
||||
agents = [cls() for cls in AGENTS.values() if cls is not None]
|
||||
else:
|
||||
cls = AGENTS.get(name)
|
||||
if cls is None:
|
||||
print(f"Unknown agent: {name}")
|
||||
sys.exit(1)
|
||||
agents = [cls()]
|
||||
|
||||
for agent in agents:
|
||||
await agent.start()
|
||||
|
||||
print(f"Agent(s) running: {[a.name for a in agents]}")
|
||||
try:
|
||||
await asyncio.Future()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
for agent in agents:
|
||||
await agent.stop()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run GEO Agent(s)")
|
||||
parser.add_argument("agent", choices=list(AGENTS.keys()), help="Agent name or 'all'")
|
||||
parser.add_argument("--log-level", default="INFO", help="Logging level")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=getattr(logging, args.log_level.upper()))
|
||||
asyncio.run(run_agent(args.agent))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -46,6 +46,7 @@ class QueryResultResponse(BaseModel):
|
|||
brand_context: str | None
|
||||
competitor_contexts: list[str]
|
||||
response_time_ms: int
|
||||
timestamp: str
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
|
@ -61,6 +62,7 @@ class CitationRateResponse(BaseModel):
|
|||
class BatchQueryResponse(BaseModel):
|
||||
results: list[QueryResultResponse]
|
||||
citation_rate: CitationRateResponse
|
||||
avg_response_time_ms: float = 0.0
|
||||
|
||||
|
||||
def _result_to_response(r: AIQueryResult) -> QueryResultResponse:
|
||||
|
|
@ -73,6 +75,7 @@ def _result_to_response(r: AIQueryResult) -> QueryResultResponse:
|
|||
brand_context=r.brand_context,
|
||||
competitor_contexts=r.competitor_contexts,
|
||||
response_time_ms=r.response_time_ms,
|
||||
timestamp=r.timestamp.isoformat(),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -97,9 +100,16 @@ async def _execute_batch(
|
|||
engine_types, query, brand_name, competitor_names
|
||||
)
|
||||
citation_rate = service.calculate_citation_rate(results)
|
||||
response_results = [_result_to_response(r) for r in results]
|
||||
avg_response_time = (
|
||||
sum(r.response_time_ms for r in results) / len(results)
|
||||
if results
|
||||
else 0.0
|
||||
)
|
||||
return BatchQueryResponse(
|
||||
results=[_result_to_response(r) for r in results],
|
||||
results=response_results,
|
||||
citation_rate=CitationRateResponse(**citation_rate),
|
||||
avg_response_time_ms=avg_response_time,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from app.schemas.alert import (
|
|||
AlertSettingListResponse,
|
||||
AlertSettingBulkUpdate,
|
||||
)
|
||||
from app.services.alert_engine import AlertEngine
|
||||
from app.services.alert.alert_engine import AlertEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from pydantic import BaseModel
|
|||
from app.api.deps import get_current_user
|
||||
from app.models.user import User
|
||||
from app.services.api_key_manager import APIKeyManager, KeySource, KeyStatus
|
||||
from app.services.smart_router import ENGINE_COST_PROFILES
|
||||
from app.services.llm.smart_router import ENGINE_COST_PROFILES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,300 @@
|
|||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.attribution_record import AttributionRecord
|
||||
from app.models.brand import Brand
|
||||
from app.models.diagnosis_record import DiagnosisRecord
|
||||
from app.models.user import User
|
||||
from app.services.attribution.attribution_engine import AttributionEngine
|
||||
from app.services.attribution.roi_calculator import ROICalculator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
PLAN_COSTS = {
|
||||
"free": 0.0,
|
||||
"starter": 99.0,
|
||||
"pro": 299.0,
|
||||
"enterprise": 999.0,
|
||||
}
|
||||
|
||||
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
class StartTrackingRequest(BaseModel):
|
||||
brand_id: str
|
||||
content_id: str | None = None
|
||||
|
||||
|
||||
class AttributionResponse(BaseModel):
|
||||
id: str
|
||||
brand_id: str
|
||||
content_id: str | None
|
||||
baseline_score: float
|
||||
current_score: float | None
|
||||
score_delta: float | None
|
||||
status: str
|
||||
roi_percentage: float | None
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ROIReport(BaseModel):
|
||||
brand_id: str
|
||||
brand_name: str
|
||||
subscription_cost: float
|
||||
current_plan: str
|
||||
total_score_delta: float
|
||||
value_generated: float
|
||||
roi_percentage: float
|
||||
break_even_delta: float
|
||||
tracking_records: list[AttributionResponse]
|
||||
ab_comparison: dict | None
|
||||
|
||||
|
||||
class ABComparisonResponse(BaseModel):
|
||||
brand_id: str
|
||||
brand_name: str
|
||||
overall_before: float
|
||||
overall_after: float
|
||||
overall_delta: float
|
||||
dimensions: list[dict]
|
||||
|
||||
|
||||
def _record_to_response(record: AttributionRecord) -> AttributionResponse:
|
||||
return AttributionResponse(
|
||||
id=str(record.id),
|
||||
brand_id=str(record.brand_id),
|
||||
content_id=str(record.content_id) if record.content_id else None,
|
||||
baseline_score=record.baseline_score,
|
||||
current_score=record.current_score,
|
||||
score_delta=record.score_delta,
|
||||
status=record.status,
|
||||
roi_percentage=record.roi_percentage,
|
||||
created_at=record.created_at,
|
||||
)
|
||||
|
||||
|
||||
async def _get_brand_or_404(
|
||||
brand_id: uuid.UUID,
|
||||
current_user: User,
|
||||
db: AsyncSession,
|
||||
) -> Brand:
|
||||
user_uuid = _to_uuid(current_user.id)
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == user_uuid)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
if not brand:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="品牌不存在",
|
||||
)
|
||||
return brand
|
||||
|
||||
|
||||
@router.post("/start", response_model=AttributionResponse)
|
||||
async def start_tracking(
|
||||
body: StartTrackingRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
brand_id = _to_uuid(body.brand_id)
|
||||
brand = await _get_brand_or_404(brand_id, current_user, db)
|
||||
|
||||
content_id = _to_uuid(body.content_id) if body.content_id else None
|
||||
|
||||
engine = AttributionEngine()
|
||||
record = await engine.start_tracking(db, brand.id, content_id, current_user.id)
|
||||
return _record_to_response(record)
|
||||
|
||||
|
||||
@router.get("/brand/{brand_id}")
|
||||
async def get_brand_attribution(
|
||||
brand_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
brand = await _get_brand_or_404(brand_id, current_user, db)
|
||||
|
||||
engine = AttributionEngine()
|
||||
summary = await engine.get_brand_attribution_summary(db, brand.id)
|
||||
summary["records"] = [_record_to_response(r) for r in summary["records"]]
|
||||
return summary
|
||||
|
||||
|
||||
@router.get("/{record_id}", response_model=AttributionResponse)
|
||||
async def get_attribution_record(
|
||||
record_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
stmt = select(AttributionRecord).where(AttributionRecord.id == record_id)
|
||||
result = await db.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="归因记录不存在",
|
||||
)
|
||||
return _record_to_response(record)
|
||||
|
||||
|
||||
@router.post("/{record_id}/check", response_model=AttributionResponse)
|
||||
async def check_attribution(
|
||||
record_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
engine = AttributionEngine()
|
||||
try:
|
||||
record = await engine.check_attribution(db, record_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="归因记录不存在",
|
||||
)
|
||||
return _record_to_response(record)
|
||||
|
||||
|
||||
@router.get("/roi/{brand_id}", response_model=ROIReport)
|
||||
async def get_roi_report(
|
||||
brand_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
brand = await _get_brand_or_404(brand_id, current_user, db)
|
||||
|
||||
engine = AttributionEngine()
|
||||
summary = await engine.get_brand_attribution_summary(db, brand.id)
|
||||
|
||||
user_plan = getattr(current_user, "plan", "free") or "free"
|
||||
subscription_cost = PLAN_COSTS.get(user_plan, 0.0)
|
||||
|
||||
calculator = ROICalculator()
|
||||
roi_data = calculator.calculate_roi(
|
||||
subscription_cost=subscription_cost,
|
||||
score_delta=summary["total_score_delta"],
|
||||
attribution_records=summary["records"],
|
||||
)
|
||||
|
||||
ab_comparison = None
|
||||
baseline_record = (
|
||||
await db.execute(
|
||||
select(DiagnosisRecord)
|
||||
.where(
|
||||
DiagnosisRecord.brand_id == brand.id,
|
||||
DiagnosisRecord.status == "completed",
|
||||
)
|
||||
.order_by(DiagnosisRecord.completed_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
latest_record = (
|
||||
await db.execute(
|
||||
select(DiagnosisRecord)
|
||||
.where(
|
||||
DiagnosisRecord.brand_id == brand.id,
|
||||
DiagnosisRecord.status == "completed",
|
||||
)
|
||||
.order_by(DiagnosisRecord.completed_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if baseline_record and latest_record and baseline_record.id != latest_record.id:
|
||||
before_dims = baseline_record.result_json.get("dimensions", []) if baseline_record.result_json else []
|
||||
after_dims = latest_record.result_json.get("dimensions", []) if latest_record.result_json else []
|
||||
before_map = {d.get("name"): d for d in before_dims}
|
||||
after_map = {d.get("name"): d for d in after_dims}
|
||||
ab_comparison = calculator.generate_ab_comparison(
|
||||
before_score=baseline_record.overall_score or 0,
|
||||
after_score=latest_record.overall_score or 0,
|
||||
before_dimensions=before_map,
|
||||
after_dimensions=after_map,
|
||||
)
|
||||
|
||||
return ROIReport(
|
||||
brand_id=str(brand.id),
|
||||
brand_name=brand.name,
|
||||
subscription_cost=subscription_cost,
|
||||
current_plan=user_plan,
|
||||
total_score_delta=summary["total_score_delta"],
|
||||
value_generated=roi_data["value_generated"],
|
||||
roi_percentage=roi_data["roi_percentage"],
|
||||
break_even_delta=roi_data["break_even_delta"],
|
||||
tracking_records=[_record_to_response(r) for r in summary["records"]],
|
||||
ab_comparison=ab_comparison,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/ab-comparison/{brand_id}", response_model=ABComparisonResponse)
|
||||
async def get_ab_comparison(
|
||||
brand_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
brand = await _get_brand_or_404(brand_id, current_user, db)
|
||||
|
||||
baseline_record = (
|
||||
await db.execute(
|
||||
select(DiagnosisRecord)
|
||||
.where(
|
||||
DiagnosisRecord.brand_id == brand.id,
|
||||
DiagnosisRecord.status == "completed",
|
||||
)
|
||||
.order_by(DiagnosisRecord.completed_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
latest_record = (
|
||||
await db.execute(
|
||||
select(DiagnosisRecord)
|
||||
.where(
|
||||
DiagnosisRecord.brand_id == brand.id,
|
||||
DiagnosisRecord.status == "completed",
|
||||
)
|
||||
.order_by(DiagnosisRecord.completed_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if not baseline_record or not latest_record:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="暂无诊断数据,无法生成A/B对比",
|
||||
)
|
||||
|
||||
calculator = ROICalculator()
|
||||
before_dims = baseline_record.result_json.get("dimensions", []) if baseline_record.result_json else []
|
||||
after_dims = latest_record.result_json.get("dimensions", []) if latest_record.result_json else []
|
||||
before_map = {d.get("name"): d for d in before_dims}
|
||||
after_map = {d.get("name"): d for d in after_dims}
|
||||
comparison = calculator.generate_ab_comparison(
|
||||
before_score=baseline_record.overall_score or 0,
|
||||
after_score=latest_record.overall_score or 0,
|
||||
before_dimensions=before_map,
|
||||
after_dimensions=after_map,
|
||||
)
|
||||
|
||||
return ABComparisonResponse(
|
||||
brand_id=str(brand.id),
|
||||
brand_name=brand.name,
|
||||
overall_before=comparison["overall_before"],
|
||||
overall_after=comparison["overall_after"],
|
||||
overall_delta=comparison["overall_delta"],
|
||||
dimensions=comparison["dimensions"],
|
||||
)
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
"""Brands API endpoints."""
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
|
|
@ -14,9 +15,12 @@ from app.api.scoring import router as scoring_router
|
|||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.models.brand import Brand
|
||||
from app.models.query import Query as QueryModel
|
||||
from app.schemas.brand import BrandCreate, BrandUpdate, BrandResponse, BrandListResponse
|
||||
from app.services.cache import get_cache_service, TTL_BRANDS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Include competitors router under brands
|
||||
|
|
@ -127,6 +131,28 @@ async def update_brand(
|
|||
|
||||
# Update only provided fields
|
||||
update_data = brand_data.model_dump(exclude_unset=True)
|
||||
|
||||
# 检测品牌名称变更,同步更新关联 Query 的 target_brand 和 brand_aliases
|
||||
old_name = brand.name
|
||||
new_name = update_data.get("name")
|
||||
|
||||
if new_name and new_name != old_name:
|
||||
queries_stmt = select(QueryModel).where(
|
||||
QueryModel.user_id == current_user.id,
|
||||
QueryModel.target_brand == old_name,
|
||||
)
|
||||
queries_result = await db.execute(queries_stmt)
|
||||
related_queries = queries_result.scalars().all()
|
||||
|
||||
for query in related_queries:
|
||||
query.target_brand = new_name
|
||||
# 如果 brand_aliases 中包含旧名称,也同步更新
|
||||
if query.brand_aliases and old_name in query.brand_aliases:
|
||||
query.brand_aliases = [new_name if a == old_name else a for a in query.brand_aliases]
|
||||
|
||||
if related_queries:
|
||||
logger.info(f"Brand renamed from '{old_name}' to '{new_name}', synced {len(related_queries)} queries")
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(brand, field, value)
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from app.schemas.citation import (
|
|||
CitationListResponse,
|
||||
CitationStatsResponse,
|
||||
)
|
||||
from app.services.citation import (
|
||||
from app.services.citation.citation import (
|
||||
get_citation_stats,
|
||||
get_citations,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,135 @@
|
|||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.models.brand import Brand
|
||||
from app.models.competitor_insight import CompetitorInsight
|
||||
from app.schemas.competitor_insight import (
|
||||
CompetitorAnalysisRequest,
|
||||
CompetitorInsightResponse,
|
||||
CompetitorInsightList,
|
||||
CompetitorGapSummary,
|
||||
)
|
||||
from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _get_brand_if_owned(
|
||||
brand_id: uuid.UUID,
|
||||
current_user: User,
|
||||
db: AsyncSession,
|
||||
) -> Brand:
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
if not brand:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="品牌不存在",
|
||||
)
|
||||
return brand
|
||||
|
||||
|
||||
@router.post("/analyze", response_model=CompetitorInsightList)
|
||||
async def analyze_competitor(
|
||||
request: CompetitorAnalysisRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await _get_brand_if_owned(request.brand_id, current_user, db)
|
||||
|
||||
service = CompetitorAnalyzerService()
|
||||
try:
|
||||
result = await service.analyze_competitor(
|
||||
brand_id=request.brand_id,
|
||||
analysis_types=request.analysis_types,
|
||||
period_days=request.period_days or 30,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
stmt = (
|
||||
select(CompetitorInsight)
|
||||
.where(CompetitorInsight.brand_id == request.brand_id)
|
||||
.order_by(CompetitorInsight.created_at.desc())
|
||||
)
|
||||
db_result = await db.execute(stmt)
|
||||
insights = list(db_result.scalars().all())
|
||||
|
||||
return {"items": insights, "total": len(insights)}
|
||||
|
||||
|
||||
@router.get("/brand/{brand_id}", response_model=CompetitorInsightList)
|
||||
async def get_brand_insights(
|
||||
brand_id: uuid.UUID,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await _get_brand_if_owned(brand_id, current_user, db)
|
||||
|
||||
count_stmt = select(func.count()).select_from(CompetitorInsight).where(
|
||||
CompetitorInsight.brand_id == brand_id,
|
||||
)
|
||||
count_result = await db.execute(count_stmt)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
stmt = (
|
||||
select(CompetitorInsight)
|
||||
.where(CompetitorInsight.brand_id == brand_id)
|
||||
.order_by(CompetitorInsight.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
insights = list(result.scalars().all())
|
||||
|
||||
return {"items": insights, "total": total}
|
||||
|
||||
|
||||
@router.get("/{insight_id}", response_model=CompetitorInsightResponse)
|
||||
async def get_insight_detail(
|
||||
insight_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
stmt = select(CompetitorInsight).where(CompetitorInsight.id == insight_id)
|
||||
result = await db.execute(stmt)
|
||||
insight = result.scalar_one_or_none()
|
||||
|
||||
if not insight:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="洞察不存在",
|
||||
)
|
||||
|
||||
await _get_brand_if_owned(insight.brand_id, current_user, db)
|
||||
|
||||
return insight
|
||||
|
||||
|
||||
@router.get("/brand/{brand_id}/gap-summary", response_model=list[CompetitorGapSummary])
|
||||
async def get_gap_summary(
|
||||
brand_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
brand = await _get_brand_if_owned(brand_id, current_user, db)
|
||||
|
||||
service = CompetitorAnalyzerService()
|
||||
gap_summaries = await service.calculate_gap_score(db, brand_id, brand.name)
|
||||
|
||||
return gap_summaries
|
||||
|
|
@ -21,6 +21,7 @@ from app.schemas.competitor import (
|
|||
CompetitorRecommendationItem,
|
||||
CompetitorRecommendationResponse,
|
||||
)
|
||||
from app.utils.json_extractor import extract_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -363,7 +364,7 @@ async def _get_llm_recommendations(
|
|||
raise ValueError("API返回空响应")
|
||||
|
||||
# 提取JSON
|
||||
json_str = _extract_json_from_text(content)
|
||||
json_str = extract_json(content)
|
||||
data = json.loads(json_str)
|
||||
|
||||
recommendations = []
|
||||
|
|
@ -393,32 +394,6 @@ async def _get_llm_recommendations(
|
|||
raise
|
||||
|
||||
|
||||
def _extract_json_from_text(text: str) -> str:
|
||||
"""从文本中提取JSON字符串。"""
|
||||
import re
|
||||
|
||||
# 尝试直接解析
|
||||
try:
|
||||
json.loads(text)
|
||||
return text
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试从代码块中提取
|
||||
json_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```'
|
||||
match = re.search(json_pattern, text)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
# 尝试找到第一个{到最后一个}之间的内容
|
||||
first_brace = text.find('{')
|
||||
last_brace = text.rfind('}')
|
||||
if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
|
||||
return text[first_brace:last_brace + 1]
|
||||
|
||||
raise ValueError(f"无法从响应中提取JSON: {text[:200]}")
|
||||
|
||||
|
||||
def _get_industry_label(industry: str | None) -> str:
|
||||
"""获取行业中文标签。"""
|
||||
labels = {
|
||||
|
|
|
|||
|
|
@ -1,18 +1,28 @@
|
|||
"""内容生产API - 串联Agent Pipeline"""
|
||||
"""内容生产API - 串联Agent Pipeline
|
||||
|
||||
业务逻辑已委托给 ContentGenerationService,API 层仅负责:
|
||||
1. 请求解析与参数校验
|
||||
2. 调用服务层
|
||||
3. 格式化响应
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.brand import Brand
|
||||
from app.models.content import Content, ContentVersion
|
||||
from app.models.diagnosis_record import DiagnosisRecord
|
||||
from app.models.user import User
|
||||
from app.services.content.content_generation_service import ContentGenerationService
|
||||
from app.services.llm import LLMError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -29,6 +39,7 @@ class ContentGenerateRequest(BaseModel):
|
|||
brand_description: str = ""
|
||||
run_deai: bool = True
|
||||
run_geo: bool = True
|
||||
use_agent_framework: bool = False
|
||||
|
||||
|
||||
class ContentGenerateResponse(BaseModel):
|
||||
|
|
@ -41,44 +52,6 @@ class ContentGenerateResponse(BaseModel):
|
|||
pipeline_stages: list[dict] = [] # 每个阶段的执行结果摘要
|
||||
|
||||
|
||||
async def _get_knowledge_context(
|
||||
db: AsyncSession,
|
||||
brand_name: str,
|
||||
knowledge_base_ids: list[str],
|
||||
target_keyword: str,
|
||||
) -> str:
|
||||
"""
|
||||
从知识库检索与查询相关的上下文。
|
||||
|
||||
如果有知识库ID,则调用 RAGService.search 获取相关内容;
|
||||
否则返回空字符串,不影响后续流程。
|
||||
"""
|
||||
if not knowledge_base_ids:
|
||||
return ""
|
||||
|
||||
try:
|
||||
from app.services.knowledge.rag_service import RAGService
|
||||
rag_service = RAGService()
|
||||
results = await rag_service.search(
|
||||
session=db,
|
||||
query=f"{brand_name} {target_keyword}" if brand_name else target_keyword,
|
||||
knowledge_base_ids=knowledge_base_ids,
|
||||
top_k=3,
|
||||
)
|
||||
if results:
|
||||
context_parts = []
|
||||
for r in results:
|
||||
content = r.get("content", "")
|
||||
title = r.get("document_title", "")
|
||||
if content:
|
||||
context_parts.append(f"[{title}] {content}")
|
||||
return "\n".join(context_parts)
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.warning(f"知识库检索失败,将不使用知识库上下文: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ContentGenerateResponse)
|
||||
async def generate_content(
|
||||
req: ContentGenerateRequest,
|
||||
|
|
@ -88,115 +61,132 @@ async def generate_content(
|
|||
"""
|
||||
一键生成内容(同步执行Pipeline),结果存入数据库
|
||||
|
||||
流程:ContentGenerator → DeAI → GEOOptimizer
|
||||
流程:ContentGenerator -> DeAI -> GEOOptimizer
|
||||
业务逻辑委托给 ContentGenerationService
|
||||
"""
|
||||
from app.services.llm import LLMError, LLMFactory
|
||||
from app.agent_framework.prompts import (
|
||||
CONTENT_GENERATOR_TEMPLATE,
|
||||
DEAI_TEMPLATE,
|
||||
GEO_OPTIMIZER_TEMPLATE,
|
||||
)
|
||||
|
||||
org_id = getattr(current_user, "organization_id", None)
|
||||
if not org_id:
|
||||
raise HTTPException(status_code=403, detail="用户未关联组织")
|
||||
|
||||
stages = []
|
||||
|
||||
try:
|
||||
provider = LLMFactory.get_default()
|
||||
|
||||
# 获取知识库上下文
|
||||
knowledge_context = await _get_knowledge_context(
|
||||
db, req.brand_name, req.knowledge_base_ids, req.target_keyword
|
||||
service = ContentGenerationService()
|
||||
result = await service.generate_content(
|
||||
keyword=req.target_keyword,
|
||||
brand_name=req.brand_name,
|
||||
platform=req.target_platform,
|
||||
content_style=req.content_style,
|
||||
word_count=req.word_count,
|
||||
knowledge_base_ids=req.knowledge_base_ids,
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
org_id=org_id,
|
||||
run_deai=req.run_deai,
|
||||
run_geo=req.run_geo,
|
||||
use_agent_framework=req.use_agent_framework,
|
||||
)
|
||||
|
||||
# Stage 1: 内容生成
|
||||
gen_variables = {
|
||||
"topic_title": req.target_keyword,
|
||||
"target_keyword": req.target_keyword,
|
||||
"target_platform": req.target_platform,
|
||||
"content_angle": "综合分析",
|
||||
"content_style": req.content_style,
|
||||
"word_count": str(req.word_count),
|
||||
"brand_name": req.brand_name,
|
||||
"knowledge_context": knowledge_context,
|
||||
}
|
||||
messages = CONTENT_GENERATOR_TEMPLATE.render(gen_variables)
|
||||
response = await provider.chat(messages, temperature=0.7, max_tokens=req.word_count * 2)
|
||||
content = response.content
|
||||
stages.append({"stage": "content_generation", "status": "success", "word_count": len(content)})
|
||||
|
||||
# Stage 2: 去AI化(可选)
|
||||
if req.run_deai:
|
||||
deai_variables = {
|
||||
"original_content": content,
|
||||
"target_style": "自然流畅",
|
||||
"preserve_structure": "是",
|
||||
}
|
||||
messages = DEAI_TEMPLATE.render(deai_variables)
|
||||
response = await provider.chat(messages, temperature=0.9, max_tokens=len(content) * 2)
|
||||
content = response.content
|
||||
stages.append({"stage": "deai", "status": "success"})
|
||||
|
||||
# Stage 3: GEO优化(可选)
|
||||
optimized = content
|
||||
seo_score = None
|
||||
if req.run_geo:
|
||||
geo_variables = {
|
||||
"original_content": content,
|
||||
"target_keywords": req.target_keyword,
|
||||
"target_platform": req.target_platform,
|
||||
"optimization_level": "moderate",
|
||||
}
|
||||
messages = GEO_OPTIMIZER_TEMPLATE.render(geo_variables)
|
||||
response = await provider.chat(messages, temperature=0.5, max_tokens=len(content) * 2)
|
||||
optimized = response.content
|
||||
stages.append({"stage": "geo_optimization", "status": "success"})
|
||||
|
||||
# ---- 存入数据库 ----
|
||||
content_obj = Content(
|
||||
organization_id=org_id,
|
||||
title=req.target_keyword,
|
||||
content_type="article",
|
||||
body=optimized,
|
||||
status="draft",
|
||||
target_platforms=[req.target_platform] if req.target_platform else [],
|
||||
keywords=[req.target_keyword],
|
||||
extra_metadata={
|
||||
"original_content": content if content != optimized else None,
|
||||
"pipeline_stages": stages,
|
||||
"seo_score": seo_score,
|
||||
"brand_name": req.brand_name,
|
||||
"content_style": req.content_style,
|
||||
"word_count_target": req.word_count,
|
||||
},
|
||||
created_by=current_user.id,
|
||||
current_version=1,
|
||||
)
|
||||
db.add(content_obj)
|
||||
await db.flush() # get content_obj.id
|
||||
|
||||
# 创建版本记录(初始版本)
|
||||
version = ContentVersion(
|
||||
content_id=content_obj.id,
|
||||
version_number=1,
|
||||
title=req.target_keyword,
|
||||
body=optimized,
|
||||
change_summary="Pipeline自动生成",
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.add(version)
|
||||
await db.commit()
|
||||
await db.refresh(content_obj)
|
||||
|
||||
return ContentGenerateResponse(
|
||||
status="success",
|
||||
content=content,
|
||||
optimized_content=optimized,
|
||||
seo_score=seo_score,
|
||||
content_id=str(content_obj.id),
|
||||
pipeline_stages=stages,
|
||||
content=result["content"],
|
||||
optimized_content=result["optimized_content"],
|
||||
seo_score=result["seo_score"],
|
||||
content_id=result["content_id"],
|
||||
pipeline_stages=result["pipeline_stages"],
|
||||
)
|
||||
|
||||
except LLMError as e:
|
||||
raise HTTPException(status_code=502, detail=f"LLM调用失败: {str(e)}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"内容生成异常: {str(e)}")
|
||||
|
||||
|
||||
class GEOContentGenerateRequest(BaseModel):
|
||||
brand_id: str
|
||||
target_keywords: list[str]
|
||||
platform: str = "通用"
|
||||
content_style: str = "专业严谨"
|
||||
word_count: int = 2000
|
||||
knowledge_base_ids: list[str] = []
|
||||
run_deai: bool = True
|
||||
run_geo: bool = True
|
||||
|
||||
|
||||
class GEOContentGenerateResponse(BaseModel):
|
||||
content_id: Optional[str] = None
|
||||
content: str = ""
|
||||
optimized_content: str = ""
|
||||
seo_score: Optional[int] = None
|
||||
pipeline_stages: list[dict] = []
|
||||
|
||||
|
||||
@router.post("/generate-geo", response_model=GEOContentGenerateResponse, status_code=201)
|
||||
async def generate_geo_content(
|
||||
req: GEOContentGenerateRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
org_id = getattr(current_user, "organization_id", None)
|
||||
if not org_id:
|
||||
raise HTTPException(status_code=403, detail="用户未关联组织")
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
try:
|
||||
brand_uuid = uuid.UUID(req.brand_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid brand_id format: {req.brand_id}")
|
||||
|
||||
brand_stmt = select(Brand).where(Brand.id == brand_uuid)
|
||||
brand_result = await db.execute(brand_stmt)
|
||||
brand = brand_result.scalar_one_or_none()
|
||||
if not brand:
|
||||
raise HTTPException(status_code=404, detail=f"Brand not found: {req.brand_id}")
|
||||
|
||||
diagnosis_context = ""
|
||||
diag_stmt = (
|
||||
select(DiagnosisRecord)
|
||||
.where(DiagnosisRecord.brand_id == brand_uuid, DiagnosisRecord.status == "completed")
|
||||
.order_by(DiagnosisRecord.created_at.desc())
|
||||
)
|
||||
diag_result = await db.execute(diag_stmt)
|
||||
diagnosis = diag_result.scalar_one_or_none()
|
||||
if diagnosis and diagnosis.result_json:
|
||||
result_json = diagnosis.result_json
|
||||
weak_dimensions = []
|
||||
if isinstance(result_json, dict):
|
||||
dimensions = result_json.get("dimensions", {})
|
||||
for dim_name, dim_data in dimensions.items():
|
||||
if isinstance(dim_data, dict) and dim_data.get("score", 100) < 60:
|
||||
weak_dimensions.append(dim_name)
|
||||
if weak_dimensions:
|
||||
diagnosis_context = f"基于诊断结果,以下维度需要重点优化:{', '.join(weak_dimensions)}。请围绕这些维度生成针对性内容。"
|
||||
|
||||
keyword = "、".join(req.target_keywords)
|
||||
if diagnosis_context:
|
||||
keyword = f"{keyword}({diagnosis_context})"
|
||||
|
||||
try:
|
||||
service = ContentGenerationService()
|
||||
result = await service.generate_content(
|
||||
keyword=keyword,
|
||||
brand_name=brand.name,
|
||||
platform=req.platform,
|
||||
content_style=req.content_style,
|
||||
word_count=req.word_count,
|
||||
knowledge_base_ids=req.knowledge_base_ids,
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
org_id=org_id,
|
||||
run_deai=req.run_deai,
|
||||
run_geo=req.run_geo,
|
||||
)
|
||||
|
||||
return GEOContentGenerateResponse(
|
||||
content_id=result["content_id"],
|
||||
content=result["content"],
|
||||
optimized_content=result["optimized_content"],
|
||||
seo_score=result["seo_score"],
|
||||
pipeline_stages=result["pipeline_stages"],
|
||||
)
|
||||
|
||||
except LLMError as e:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
"""Dashboard API endpoints."""
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import select, func, Integer
|
||||
|
|
@ -19,325 +18,15 @@ from app.schemas.dashboard import (
|
|||
PlatformScoreItem,
|
||||
RecentQueryItem,
|
||||
)
|
||||
from app.services.scoring_service import ScoringService, get_health_level
|
||||
from app.services.sentiment_service import get_sentiment_service
|
||||
from app.schemas.scoring import CitationResult
|
||||
from app.services.scoring.scoring_service import get_health_level
|
||||
from app.services.scoring.brand_scoring_data_service import (
|
||||
get_brand_scoring_data_service,
|
||||
REQUIRED_PLATFORMS,
|
||||
)
|
||||
from app.services.cache import get_cache_service, TTL_DASHBOARD
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Required platforms for the dashboard
|
||||
REQUIRED_PLATFORMS = [
|
||||
"wenxin", # 文心一言
|
||||
"kimi", # Kimi
|
||||
"tongyi", # 通义千问
|
||||
"doubao", # 豆包
|
||||
"xinghuo", # 讯飞星火
|
||||
"tiangong", # 天工AI
|
||||
"qingyan", # 智谱清言
|
||||
]
|
||||
|
||||
|
||||
async def _get_brand_score_by_platform(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
brand_id: uuid.UUID,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Calculate brand score by platform.
|
||||
|
||||
Returns a dict mapping platform key to score (0-100).
|
||||
"""
|
||||
brand_stmt = select(Brand).where(
|
||||
Brand.id == brand_id,
|
||||
Brand.user_id == user_id,
|
||||
)
|
||||
brand_result = await db.execute(brand_stmt)
|
||||
brand = brand_result.scalar_one_or_none()
|
||||
|
||||
if not brand:
|
||||
return {platform: 0.0 for platform in REQUIRED_PLATFORMS}
|
||||
|
||||
queries_stmt = select(QueryModel).where(
|
||||
QueryModel.user_id == user_id,
|
||||
QueryModel.target_brand == brand.name,
|
||||
)
|
||||
queries_result = await db.execute(queries_stmt)
|
||||
queries = list(queries_result.scalars().all())
|
||||
|
||||
if not queries:
|
||||
return {platform: 0.0 for platform in REQUIRED_PLATFORMS}
|
||||
|
||||
query_ids = [q.id for q in queries]
|
||||
|
||||
platform_scores = {platform: 0.0 for platform in REQUIRED_PLATFORMS}
|
||||
|
||||
for platform in REQUIRED_PLATFORMS:
|
||||
citation_stmt = select(
|
||||
func.count().label("total"),
|
||||
func.sum(
|
||||
func.cast(
|
||||
func.case((CitationRecord.cited == True, 1), else_=0),
|
||||
Integer
|
||||
)
|
||||
).label("cited")
|
||||
).where(
|
||||
CitationRecord.query_id.in_(query_ids),
|
||||
CitationRecord.platform == platform,
|
||||
)
|
||||
result = await db.execute(citation_stmt)
|
||||
row = result.one()
|
||||
|
||||
total = row.total or 0
|
||||
cited = row.cited or 0
|
||||
|
||||
if total > 0:
|
||||
citation_rate = cited / total
|
||||
platform_scores[platform] = round(citation_rate * 100, 2)
|
||||
else:
|
||||
platform_scores[platform] = 0.0
|
||||
|
||||
return platform_scores
|
||||
|
||||
|
||||
async def _get_competitor_scores_by_platform(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
brand_id: uuid.UUID,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Calculate average competitor score by platform.
|
||||
|
||||
Returns a dict mapping platform key to average competitor score (0-100).
|
||||
"""
|
||||
competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id)
|
||||
competitor_result = await db.execute(competitor_stmt)
|
||||
competitors = list(competitor_result.scalars().all())
|
||||
|
||||
if not competitors:
|
||||
return {platform: 0.0 for platform in REQUIRED_PLATFORMS}
|
||||
|
||||
competitor_names = [c.name for c in competitors]
|
||||
|
||||
competitor_queries_stmt = select(QueryModel).where(
|
||||
QueryModel.user_id == user_id,
|
||||
QueryModel.target_brand.in_(competitor_names),
|
||||
)
|
||||
competitor_queries_result = await db.execute(competitor_queries_stmt)
|
||||
competitor_queries = list(competitor_queries_result.scalars().all())
|
||||
|
||||
if not competitor_queries:
|
||||
return {platform: 0.0 for platform in REQUIRED_PLATFORMS}
|
||||
|
||||
competitor_query_ids = [q.id for q in competitor_queries]
|
||||
|
||||
platform_scores = {platform: 0.0 for platform in REQUIRED_PLATFORMS}
|
||||
|
||||
for platform in REQUIRED_PLATFORMS:
|
||||
citation_stmt = select(
|
||||
func.count().label("total"),
|
||||
func.sum(
|
||||
func.cast(
|
||||
func.case((CitationRecord.cited == True, 1), else_=0),
|
||||
Integer
|
||||
)
|
||||
).label("cited")
|
||||
).where(
|
||||
CitationRecord.query_id.in_(competitor_query_ids),
|
||||
CitationRecord.platform == platform,
|
||||
)
|
||||
result = await db.execute(citation_stmt)
|
||||
row = result.one()
|
||||
|
||||
total = row.total or 0
|
||||
cited = row.cited or 0
|
||||
|
||||
if total > 0:
|
||||
citation_rate = cited / total
|
||||
platform_scores[platform] = round(citation_rate * 100, 2)
|
||||
else:
|
||||
platform_scores[platform] = 0.0
|
||||
|
||||
return platform_scores
|
||||
|
||||
|
||||
async def _calculate_overall_score_v2(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
brand_id: uuid.UUID,
|
||||
) -> tuple[float, float, list[DimensionScoreItem], int, int]:
|
||||
"""
|
||||
Calculate V2 overall score, change from yesterday, dimension scores,
|
||||
and competitor position.
|
||||
|
||||
Returns: (overall_score, change_from_yesterday, dimensions, ahead_count, behind_count)
|
||||
"""
|
||||
brand_stmt = select(Brand).where(
|
||||
Brand.id == brand_id,
|
||||
Brand.user_id == user_id,
|
||||
)
|
||||
brand_result = await db.execute(brand_stmt)
|
||||
brand = brand_result.scalar_one_or_none()
|
||||
|
||||
if not brand:
|
||||
return 0.0, 0.0, [], 0, 0
|
||||
|
||||
# Get queries for this brand
|
||||
queries_stmt = select(QueryModel).where(
|
||||
QueryModel.user_id == user_id,
|
||||
QueryModel.target_brand == brand.name,
|
||||
)
|
||||
queries_result = await db.execute(queries_stmt)
|
||||
queries = list(queries_result.scalars().all())
|
||||
|
||||
if not queries:
|
||||
return 0.0, 0.0, [], 0, 0
|
||||
|
||||
query_ids = [q.id for q in queries]
|
||||
|
||||
# Get all citations
|
||||
citations_stmt = select(CitationRecord).where(
|
||||
CitationRecord.query_id.in_(query_ids),
|
||||
)
|
||||
citations_result = await db.execute(citations_stmt)
|
||||
all_citations = list(citations_result.scalars().all())
|
||||
|
||||
total_queries = len(all_citations)
|
||||
brand_citations = [c for c in all_citations if c.cited]
|
||||
|
||||
# Get competitor data
|
||||
competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id)
|
||||
competitor_result = await db.execute(competitor_stmt)
|
||||
competitors = list(competitor_result.scalars().all())
|
||||
competitor_names = [c.name for c in competitors]
|
||||
|
||||
# Calculate competitor mentions
|
||||
competitor_mentions: dict[str, int] = {}
|
||||
for comp_name in competitor_names:
|
||||
count = sum(
|
||||
1 for c in all_citations
|
||||
if c.cited and c.competitor_brands
|
||||
and comp_name in c.competitor_brands
|
||||
)
|
||||
if count > 0:
|
||||
competitor_mentions[comp_name] = count
|
||||
|
||||
# Sentiment analysis - prefer persisted sentiment field
|
||||
sentiment_service = get_sentiment_service()
|
||||
sentiment_counts = {"positive": 0, "neutral": 0, "negative": 0}
|
||||
for citation in brand_citations:
|
||||
if citation.sentiment and citation.sentiment in ("positive", "neutral", "negative"):
|
||||
sentiment_counts[citation.sentiment] += 1
|
||||
else:
|
||||
content = citation.raw_response or citation.citation_text or ""
|
||||
if content.strip():
|
||||
try:
|
||||
result = await sentiment_service.analyze(
|
||||
brand_name=brand.name,
|
||||
content=content,
|
||||
)
|
||||
sentiment_counts[result.sentiment] += 1
|
||||
except Exception:
|
||||
sentiment_counts["neutral"] += 1
|
||||
else:
|
||||
sentiment_counts["neutral"] += 1
|
||||
|
||||
# Build citation results
|
||||
citation_results = [
|
||||
CitationResult(
|
||||
cited=c.cited,
|
||||
position=c.citation_position,
|
||||
citation_text=c.citation_text,
|
||||
sentiment="neutral", # simplified for dashboard
|
||||
confidence=c.confidence or 0.0,
|
||||
)
|
||||
for c in brand_citations
|
||||
]
|
||||
|
||||
# Extract positions
|
||||
positions = [c.citation_position for c in brand_citations if c.cited]
|
||||
|
||||
# Calculate V2 score
|
||||
scoring_service = ScoringService()
|
||||
v2_result = scoring_service.calculate_v2(
|
||||
mentioned_count=len(brand_citations),
|
||||
total_queries=total_queries,
|
||||
positions=positions,
|
||||
sentiment_counts=sentiment_counts,
|
||||
citations=citation_results,
|
||||
brand_mentions=len(brand_citations),
|
||||
competitor_mentions=competitor_mentions,
|
||||
)
|
||||
|
||||
# Build dimension items
|
||||
dimensions = [
|
||||
DimensionScoreItem(
|
||||
name=v2_result.mention_rate.name,
|
||||
score=round(v2_result.mention_rate.score, 2),
|
||||
max_score=v2_result.mention_rate.max_score,
|
||||
percentage=round(v2_result.mention_rate.percentage, 2),
|
||||
),
|
||||
DimensionScoreItem(
|
||||
name=v2_result.recommendation_rank.name,
|
||||
score=round(v2_result.recommendation_rank.score, 2),
|
||||
max_score=v2_result.recommendation_rank.max_score,
|
||||
percentage=round(v2_result.recommendation_rank.percentage, 2),
|
||||
),
|
||||
DimensionScoreItem(
|
||||
name=v2_result.sentiment_score.name,
|
||||
score=round(v2_result.sentiment_score.score, 2),
|
||||
max_score=v2_result.sentiment_score.max_score,
|
||||
percentage=round(v2_result.sentiment_score.percentage, 2),
|
||||
),
|
||||
DimensionScoreItem(
|
||||
name=v2_result.citation_quality.name,
|
||||
score=round(v2_result.citation_quality.score, 2),
|
||||
max_score=v2_result.citation_quality.max_score,
|
||||
percentage=round(v2_result.citation_quality.percentage, 2),
|
||||
),
|
||||
DimensionScoreItem(
|
||||
name=v2_result.competitive_position.name,
|
||||
score=round(v2_result.competitive_position.score, 2),
|
||||
max_score=v2_result.competitive_position.max_score,
|
||||
percentage=round(v2_result.competitive_position.percentage, 2),
|
||||
),
|
||||
]
|
||||
|
||||
# Calculate competitor ahead/behind
|
||||
ahead_count = sum(
|
||||
1 for count in competitor_mentions.values()
|
||||
if len(brand_citations) > count
|
||||
)
|
||||
behind_count = sum(
|
||||
1 for count in competitor_mentions.values()
|
||||
if len(brand_citations) <= count
|
||||
)
|
||||
|
||||
# Calculate change from yesterday
|
||||
today = datetime.now().date()
|
||||
yesterday = today - timedelta(days=1)
|
||||
|
||||
today_citations = [
|
||||
c for c in all_citations
|
||||
if c.queried_at.date() == today
|
||||
]
|
||||
yesterday_citations = [
|
||||
c for c in all_citations
|
||||
if c.queried_at.date() == yesterday
|
||||
]
|
||||
|
||||
today_cited = sum(1 for c in today_citations if c.cited)
|
||||
today_total = len(today_citations)
|
||||
today_score = (today_cited / today_total * 100) if today_total > 0 else 0.0
|
||||
|
||||
yesterday_cited = sum(1 for c in yesterday_citations if c.cited)
|
||||
yesterday_total = len(yesterday_citations)
|
||||
yesterday_score = (yesterday_cited / yesterday_total * 100) if yesterday_total > 0 else 0.0
|
||||
|
||||
change = round(today_score - yesterday_score, 2)
|
||||
|
||||
return v2_result.overall_score, change, dimensions, ahead_count, behind_count
|
||||
|
||||
|
||||
@router.get("/stats", response_model=DashboardStatsResponse)
|
||||
async def get_dashboard_stats(
|
||||
|
|
@ -357,7 +46,6 @@ async def get_dashboard_stats(
|
|||
- 最近查询记录
|
||||
"""
|
||||
cache = get_cache_service()
|
||||
# 如果 brand_id 尚未确定,先查库取第一个品牌
|
||||
if brand_id is None:
|
||||
brand_stmt = select(Brand).where(Brand.user_id == current_user.id).limit(1)
|
||||
brand_result = await db.execute(brand_stmt)
|
||||
|
|
@ -379,34 +67,67 @@ async def get_dashboard_stats(
|
|||
total_platforms=7,
|
||||
)
|
||||
|
||||
# 尝试从缓存读取(TTL: 2 分钟)
|
||||
cache_key = f"dashboard:stats:{current_user.id}:{brand_id}"
|
||||
cached = await cache.get_json(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
# Get brand name
|
||||
brand_stmt = select(Brand).where(Brand.id == brand_id)
|
||||
brand_result = await db.execute(brand_stmt)
|
||||
brand = brand_result.scalar_one_or_none()
|
||||
brand_name = brand.name if brand else None
|
||||
|
||||
# Calculate V2 overall score and dimensions
|
||||
overall_score, score_change, dimensions, ahead_count, behind_count = (
|
||||
await _calculate_overall_score_v2(db, current_user.id, brand_id)
|
||||
scoring_data_service = get_brand_scoring_data_service()
|
||||
|
||||
scoring_data = await scoring_data_service.get_brand_scoring_data(
|
||||
db, current_user.id, brand
|
||||
)
|
||||
|
||||
# Get platform scores
|
||||
platform_scores_dict = await _get_brand_score_by_platform(
|
||||
overall_score = scoring_data.v2_result.overall_score
|
||||
score_change = scoring_data.change_from_yesterday
|
||||
|
||||
dimensions = [
|
||||
DimensionScoreItem(
|
||||
name=scoring_data.v2_result.mention_rate.name,
|
||||
score=round(scoring_data.v2_result.mention_rate.score, 2),
|
||||
max_score=scoring_data.v2_result.mention_rate.max_score,
|
||||
percentage=round(scoring_data.v2_result.mention_rate.percentage, 2),
|
||||
),
|
||||
DimensionScoreItem(
|
||||
name=scoring_data.v2_result.recommendation_rank.name,
|
||||
score=round(scoring_data.v2_result.recommendation_rank.score, 2),
|
||||
max_score=scoring_data.v2_result.recommendation_rank.max_score,
|
||||
percentage=round(scoring_data.v2_result.recommendation_rank.percentage, 2),
|
||||
),
|
||||
DimensionScoreItem(
|
||||
name=scoring_data.v2_result.sentiment_score.name,
|
||||
score=round(scoring_data.v2_result.sentiment_score.score, 2),
|
||||
max_score=scoring_data.v2_result.sentiment_score.max_score,
|
||||
percentage=round(scoring_data.v2_result.sentiment_score.percentage, 2),
|
||||
),
|
||||
DimensionScoreItem(
|
||||
name=scoring_data.v2_result.citation_quality.name,
|
||||
score=round(scoring_data.v2_result.citation_quality.score, 2),
|
||||
max_score=scoring_data.v2_result.citation_quality.max_score,
|
||||
percentage=round(scoring_data.v2_result.citation_quality.percentage, 2),
|
||||
),
|
||||
DimensionScoreItem(
|
||||
name=scoring_data.v2_result.competitive_position.name,
|
||||
score=round(scoring_data.v2_result.competitive_position.score, 2),
|
||||
max_score=scoring_data.v2_result.competitive_position.max_score,
|
||||
percentage=round(scoring_data.v2_result.competitive_position.percentage, 2),
|
||||
),
|
||||
]
|
||||
|
||||
ahead_count = scoring_data.competitor_data.get("ahead_count", 0)
|
||||
behind_count = scoring_data.competitor_data.get("behind_count", 0)
|
||||
|
||||
platform_scores_dict = scoring_data.platform_scores
|
||||
|
||||
competitor_scores_dict = await scoring_data_service.get_competitor_platform_scores(
|
||||
db, current_user.id, brand_id
|
||||
)
|
||||
|
||||
# Get competitor platform scores
|
||||
competitor_scores_dict = await _get_competitor_scores_by_platform(
|
||||
db, current_user.id, brand_id
|
||||
)
|
||||
|
||||
# Get first competitor name for display
|
||||
competitor_stmt = select(Competitor).where(
|
||||
Competitor.brand_id == brand_id
|
||||
).limit(1)
|
||||
|
|
@ -424,10 +145,8 @@ async def get_dashboard_stats(
|
|||
for platform, score in platform_scores_dict.items()
|
||||
]
|
||||
|
||||
# Count monitored platforms
|
||||
monitored = sum(1 for s in platform_scores_dict.values() if s > 0)
|
||||
|
||||
# Get recent queries (last 10)
|
||||
recent_queries_stmt = (
|
||||
select(QueryModel)
|
||||
.where(QueryModel.user_id == current_user.id)
|
||||
|
|
@ -437,7 +156,6 @@ async def get_dashboard_stats(
|
|||
recent_queries_result = await db.execute(recent_queries_stmt)
|
||||
recent_queries_list = list(recent_queries_result.scalars().all())
|
||||
|
||||
# Get citation count for each query
|
||||
recent_queries = []
|
||||
for query in recent_queries_list:
|
||||
citation_count_stmt = select(
|
||||
|
|
@ -460,7 +178,6 @@ async def get_dashboard_stats(
|
|||
queried_at=query.last_queried_at or query.created_at,
|
||||
))
|
||||
|
||||
# Health level
|
||||
health_level = get_health_level(overall_score)
|
||||
|
||||
response = DashboardStatsResponse(
|
||||
|
|
@ -477,7 +194,6 @@ async def get_dashboard_stats(
|
|||
brand_name=brand_name,
|
||||
)
|
||||
|
||||
# 将结果写入缓存(TTL: 2 分钟)
|
||||
await cache.set_json(
|
||||
cache_key,
|
||||
response.model_dump(mode="json"),
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from app.schemas.detection_task import (
|
|||
DetectionTaskUpdate,
|
||||
DetectionTriggerResponse,
|
||||
)
|
||||
from app.services.detection_scheduler import DetectionSchedulerService, TaskNotFoundError
|
||||
from app.services.detection.detection_scheduler import DetectionSchedulerService, TaskNotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
"""诊断API端点 - 提供SEO和GEO诊断功能"""
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
|
@ -11,13 +13,32 @@ from app.api.deps import get_current_user
|
|||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.models.brand import Brand
|
||||
from app.services.seo_diagnosis import SEODiagnosisService
|
||||
from app.services.geo_diagnosis import GEODiagnosisService, GEODiagnosisInput
|
||||
from app.models.diagnosis_record import DiagnosisRecord
|
||||
from app.schemas.diagnosis import (
|
||||
GEODiagnosisHistoryItem,
|
||||
GEODiagnosisHistoryResponse,
|
||||
GEODiagnosisResponse,
|
||||
GEODiagnosisResultResponse,
|
||||
GEODiagnosisTaskResponse,
|
||||
GEODiagnosisTriggerRequest,
|
||||
)
|
||||
from app.services.diagnosis.data_collector import DataCollectorService
|
||||
from app.services.diagnosis.seo_diagnosis import SEODiagnosisService
|
||||
from app.services.diagnosis.geo_diagnosis import GEODiagnosisService, GEODiagnosisInput
|
||||
from app.utils.health import get_health_level_label
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_FREE_TIER_DIMENSIONS = {"内容可提取性", "E-E-A-T信号", "引用就绪度"}
|
||||
|
||||
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
@router.get("/seo/{brand_id}")
|
||||
async def get_seo_diagnosis(
|
||||
|
|
@ -25,24 +46,14 @@ async def get_seo_diagnosis(
|
|||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取品牌的SEO诊断结果
|
||||
|
||||
返回5维度SEO诊断:
|
||||
- 技术SEO (25分)
|
||||
- 页面SEO (20分)
|
||||
- 内容质量 (20分)
|
||||
- 外链分析 (15分)
|
||||
- 用户体验 (20分)
|
||||
"""
|
||||
brand = await _get_brand_or_404(brand_id, current_user, db)
|
||||
|
||||
|
||||
try:
|
||||
service = SEODiagnosisService()
|
||||
result = service.diagnose()
|
||||
|
||||
|
||||
logger.info(f"SEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={result.overall_score}")
|
||||
|
||||
|
||||
return result.to_dict()
|
||||
except Exception as e:
|
||||
logger.error(f"SEO诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True)
|
||||
|
|
@ -52,39 +63,158 @@ async def get_seo_diagnosis(
|
|||
)
|
||||
|
||||
|
||||
@router.get("/geo/{brand_id}")
|
||||
async def get_geo_diagnosis(
|
||||
@router.post("/geo/{brand_id}", status_code=status.HTTP_202_ACCEPTED)
|
||||
async def trigger_geo_diagnosis(
|
||||
brand_id: uuid.UUID,
|
||||
body: GEODiagnosisTriggerRequest = GEODiagnosisTriggerRequest(),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取品牌的GEO诊断结果
|
||||
|
||||
返回6维度GEO诊断:
|
||||
- 内容可提取性 (20分)
|
||||
- 实体清晰度 (15分)
|
||||
- E-E-A-T信号 (20分)
|
||||
- Schema标记 (15分)
|
||||
- 主题权威 (15分)
|
||||
- 引用就绪度 (15分)
|
||||
"""
|
||||
brand = await _get_brand_or_404(brand_id, current_user, db)
|
||||
|
||||
try:
|
||||
input_data = GEODiagnosisInput()
|
||||
service = GEODiagnosisService()
|
||||
result = service.diagnose(input_data)
|
||||
|
||||
logger.info(f"GEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={result.overall_score}")
|
||||
|
||||
return result.to_dict()
|
||||
except Exception as e:
|
||||
logger.error(f"GEO诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="GEO诊断服务异常,请稍后重试",
|
||||
|
||||
if not body.force_refresh:
|
||||
existing = await _find_recent_completed(db, brand_id, hours=24)
|
||||
if existing:
|
||||
return GEODiagnosisTaskResponse(
|
||||
task_id=str(existing.id),
|
||||
brand_id=str(brand_id),
|
||||
status="completed",
|
||||
)
|
||||
|
||||
record = DiagnosisRecord(
|
||||
brand_id=brand_id,
|
||||
user_id=_to_uuid(current_user.id),
|
||||
diagnosis_type="geo",
|
||||
status="pending",
|
||||
)
|
||||
db.add(record)
|
||||
await db.commit()
|
||||
await db.refresh(record)
|
||||
|
||||
asyncio.create_task(
|
||||
_run_geo_diagnosis(
|
||||
record_id=record.id,
|
||||
brand_id=brand_id,
|
||||
brand_name=brand.name,
|
||||
brand_aliases=brand.aliases or [],
|
||||
website=brand.website,
|
||||
industry=brand.industry,
|
||||
user_id=_to_uuid(current_user.id),
|
||||
)
|
||||
)
|
||||
|
||||
return GEODiagnosisTaskResponse(
|
||||
task_id=str(record.id),
|
||||
brand_id=str(brand_id),
|
||||
status="pending",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/geo/{brand_id}/result")
|
||||
async def get_geo_diagnosis_result(
|
||||
brand_id: uuid.UUID,
|
||||
task_id: str | None = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await _get_brand_or_404(brand_id, current_user, db)
|
||||
|
||||
if task_id:
|
||||
try:
|
||||
tid = uuid.UUID(task_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="无效的task_id",
|
||||
)
|
||||
stmt = select(DiagnosisRecord).where(
|
||||
DiagnosisRecord.id == tid,
|
||||
DiagnosisRecord.brand_id == brand_id,
|
||||
)
|
||||
else:
|
||||
stmt = (
|
||||
select(DiagnosisRecord)
|
||||
.where(
|
||||
DiagnosisRecord.brand_id == brand_id,
|
||||
DiagnosisRecord.diagnosis_type == "geo",
|
||||
)
|
||||
.order_by(DiagnosisRecord.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
|
||||
if not record:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="诊断记录不存在",
|
||||
)
|
||||
|
||||
if record.status == "pending" or record.status == "running":
|
||||
return GEODiagnosisResultResponse(
|
||||
task_id=str(record.id),
|
||||
brand_id=str(brand_id),
|
||||
status=record.status,
|
||||
)
|
||||
|
||||
if record.status == "failed":
|
||||
return GEODiagnosisResultResponse(
|
||||
task_id=str(record.id),
|
||||
brand_id=str(brand_id),
|
||||
status="failed",
|
||||
error=record.error_message,
|
||||
)
|
||||
|
||||
user_plan = getattr(current_user, "plan", None) or "free"
|
||||
is_paid = user_plan not in ("free", None)
|
||||
diagnosis_resp = _build_diagnosis_response(record.result_json, is_paid)
|
||||
|
||||
return GEODiagnosisResultResponse(
|
||||
task_id=str(record.id),
|
||||
brand_id=str(brand_id),
|
||||
status="completed",
|
||||
result=diagnosis_resp,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/geo/{brand_id}/history")
|
||||
async def get_geo_diagnosis_history(
|
||||
brand_id: uuid.UUID,
|
||||
limit: int = 10,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await _get_brand_or_404(brand_id, current_user, db)
|
||||
|
||||
stmt = (
|
||||
select(DiagnosisRecord)
|
||||
.where(
|
||||
DiagnosisRecord.brand_id == brand_id,
|
||||
DiagnosisRecord.diagnosis_type == "geo",
|
||||
DiagnosisRecord.status == "completed",
|
||||
)
|
||||
.order_by(DiagnosisRecord.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
records = result.scalars().all()
|
||||
|
||||
items = [
|
||||
GEODiagnosisHistoryItem(
|
||||
task_id=str(r.id),
|
||||
overall_score=r.overall_score or 0,
|
||||
health_level=r.result_json.get("health_level", "danger") if r.result_json else "danger",
|
||||
created_at=r.created_at.isoformat(),
|
||||
completed_at=r.completed_at.isoformat() if r.completed_at else None,
|
||||
)
|
||||
for r in records
|
||||
]
|
||||
|
||||
return GEODiagnosisHistoryResponse(
|
||||
brand_id=str(brand_id),
|
||||
history=items,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/combined/{brand_id}")
|
||||
|
|
@ -93,29 +223,32 @@ async def get_combined_diagnosis(
|
|||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取品牌的综合诊断结果
|
||||
|
||||
结合SEO和GEO诊断,返回综合评分和详细诊断结果
|
||||
"""
|
||||
brand = await _get_brand_or_404(brand_id, current_user, db)
|
||||
|
||||
|
||||
try:
|
||||
seo_service = SEODiagnosisService()
|
||||
seo_result = seo_service.diagnose()
|
||||
|
||||
|
||||
collector = DataCollectorService(db)
|
||||
collection = await collector.collect(
|
||||
brand_name=brand.name,
|
||||
brand_aliases=brand.aliases or [],
|
||||
website=brand.website,
|
||||
industry=brand.industry,
|
||||
)
|
||||
|
||||
geo_service = GEODiagnosisService()
|
||||
geo_result = geo_service.diagnose(GEODiagnosisInput())
|
||||
|
||||
geo_result = geo_service.diagnose(collection.diagnosis_input)
|
||||
|
||||
combined_score = round((seo_result.overall_score + geo_result.overall_score) / 2, 2)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"综合诊断完成: brand_id={brand_id}, brand={brand.name}, "
|
||||
f"seo_score={seo_result.overall_score}, "
|
||||
f"geo_score={geo_result.overall_score}, "
|
||||
f"combined_score={combined_score}"
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"seo_score": seo_result.overall_score,
|
||||
"geo_score": geo_result.overall_score,
|
||||
|
|
@ -131,20 +264,131 @@ async def get_combined_diagnosis(
|
|||
)
|
||||
|
||||
|
||||
async def _run_geo_diagnosis(
|
||||
record_id: uuid.UUID,
|
||||
brand_id: uuid.UUID,
|
||||
brand_name: str,
|
||||
brand_aliases: list[str],
|
||||
website: str | None,
|
||||
industry: str | None,
|
||||
user_id: uuid.UUID,
|
||||
) -> None:
|
||||
from app.database import AsyncSessionLocal
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
try:
|
||||
stmt = select(DiagnosisRecord).where(DiagnosisRecord.id == record_id)
|
||||
result = await db.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
return
|
||||
|
||||
record.status = "running"
|
||||
await db.commit()
|
||||
|
||||
collector = DataCollectorService(db)
|
||||
collection = await collector.collect(
|
||||
brand_name=brand_name,
|
||||
brand_aliases=brand_aliases,
|
||||
website=website,
|
||||
industry=industry,
|
||||
)
|
||||
|
||||
geo_service = GEODiagnosisService()
|
||||
diagnosis_result = geo_service.diagnose(collection.diagnosis_input)
|
||||
|
||||
record.status = "completed"
|
||||
record.overall_score = diagnosis_result.overall_score
|
||||
record.result_json = diagnosis_result.to_dict()
|
||||
record.completed_at = datetime.now(UTC)
|
||||
record.collection_metadata = collection.metadata
|
||||
if collection.errors:
|
||||
record.collection_metadata["errors"] = collection.errors
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info(
|
||||
f"GEO诊断完成: brand_id={brand_id}, brand={brand_name}, "
|
||||
f"score={diagnosis_result.overall_score}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"GEO诊断任务失败: record_id={record_id}, error={e}", exc_info=True)
|
||||
try:
|
||||
stmt = select(DiagnosisRecord).where(DiagnosisRecord.id == record_id)
|
||||
result = await db.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
if record:
|
||||
record.status = "failed"
|
||||
record.error_message = str(e)
|
||||
await db.commit()
|
||||
except Exception:
|
||||
logger.error(f"更新失败状态也失败: record_id={record_id}")
|
||||
|
||||
|
||||
def _build_diagnosis_response(result_json: dict | None, is_paid: bool) -> GEODiagnosisResponse:
|
||||
if not result_json:
|
||||
return GEODiagnosisResponse(
|
||||
overall_score=0,
|
||||
health_level="danger",
|
||||
health_level_label=get_health_level_label("danger"),
|
||||
dimensions=[],
|
||||
recommendations=[],
|
||||
is_full_report=is_paid,
|
||||
)
|
||||
|
||||
dimensions = result_json.get("dimensions", [])
|
||||
if not is_paid:
|
||||
dimensions = [d for d in dimensions if d.get("name") in _FREE_TIER_DIMENSIONS]
|
||||
|
||||
recommendations = result_json.get("recommendations", [])
|
||||
if not is_paid:
|
||||
recommendations = [r for r in recommendations if r.get("priority") == "P0"]
|
||||
|
||||
return GEODiagnosisResponse(
|
||||
overall_score=result_json.get("overall_score", 0),
|
||||
health_level=result_json.get("health_level", "danger"),
|
||||
health_level_label=result_json.get("health_level_label", get_health_level_label("danger")),
|
||||
dimensions=dimensions,
|
||||
recommendations=recommendations,
|
||||
is_full_report=is_paid,
|
||||
)
|
||||
|
||||
|
||||
async def _find_recent_completed(
|
||||
db: AsyncSession, brand_id: uuid.UUID, hours: int = 24
|
||||
) -> DiagnosisRecord | None:
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff = datetime.now(UTC) - timedelta(hours=hours)
|
||||
stmt = (
|
||||
select(DiagnosisRecord)
|
||||
.where(
|
||||
DiagnosisRecord.brand_id == brand_id,
|
||||
DiagnosisRecord.status == "completed",
|
||||
DiagnosisRecord.completed_at >= cutoff,
|
||||
)
|
||||
.order_by(DiagnosisRecord.completed_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def _get_brand_or_404(
|
||||
brand_id: uuid.UUID,
|
||||
current_user: User,
|
||||
db: AsyncSession,
|
||||
) -> Brand:
|
||||
"""获取品牌或抛出404异常"""
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
|
||||
user_uuid = _to_uuid(current_user.id)
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == user_uuid)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not brand:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="品牌不存在",
|
||||
)
|
||||
|
||||
|
||||
return brand
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import uuid
|
|||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
|
|
@ -25,7 +26,9 @@ from app.schemas.distribution import (
|
|||
)
|
||||
from app.services.distribution.formatter import ContentFormatter
|
||||
from app.services.distribution.platform_rules import PLATFORM_RULES, PlatformRuleEngine
|
||||
from app.services.distribution.publish_engine import PublishEngine
|
||||
from app.services.distribution.publish_strategy import PublishStrategyService
|
||||
from app.services.distribution.publishers.base import PublishResult
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
@ -235,3 +238,72 @@ async def create_schedule(
|
|||
tips=tips,
|
||||
created_at=schedule.created_at.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
|
||||
|
||||
class PublishRequest(BaseModel):
|
||||
content_id: str = Field(min_length=1)
|
||||
platforms: list[str] = Field(min_length=1)
|
||||
|
||||
|
||||
class PublishStatusItem(BaseModel):
|
||||
platform: str
|
||||
status: str
|
||||
article_url: str | None = None
|
||||
published_at: str | None = None
|
||||
|
||||
|
||||
class PublishStatusResponse(BaseModel):
|
||||
platforms: list[PublishStatusItem]
|
||||
|
||||
|
||||
class PublishResponse(BaseModel):
|
||||
results: list[PublishResult]
|
||||
|
||||
|
||||
_publish_engine = PublishEngine()
|
||||
|
||||
|
||||
@router.post("/publish", response_model=PublishResponse)
|
||||
async def publish_content(
|
||||
req: PublishRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
org_id = current_user.organization_id
|
||||
if not org_id:
|
||||
raise HTTPException(status_code=403, detail="用户未关联组织")
|
||||
|
||||
try:
|
||||
results = await _publish_engine.publish_content(
|
||||
content_id=req.content_id,
|
||||
platforms=req.platforms,
|
||||
db=db,
|
||||
user_id=str(current_user.id),
|
||||
org_id=str(org_id),
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
|
||||
|
||||
return PublishResponse(results=results)
|
||||
|
||||
|
||||
@router.get("/publish/{content_id}/status", response_model=PublishStatusResponse)
|
||||
async def get_publish_status(
|
||||
content_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
status_list = await _publish_engine.get_publish_status(
|
||||
content_id=content_id,
|
||||
db=db,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Content not found: {content_id}",
|
||||
)
|
||||
|
||||
return PublishStatusResponse(
|
||||
platforms=[PublishStatusItem(**s) for s in status_list]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,144 @@
|
|||
import hashlib
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.brand import Brand
|
||||
from app.schemas.health_score import (
|
||||
HealthScoreDimension,
|
||||
HealthScoreRecommendation,
|
||||
HealthScoreResponse,
|
||||
)
|
||||
from app.services.cache import get_cache_service
|
||||
from app.services.diagnosis.data_collector import DataCollectorService
|
||||
from app.services.diagnosis.geo_diagnosis import GEODiagnosisService
|
||||
from app.utils.health import get_health_level, get_health_level_label
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_FREE_TIER_DIMENSIONS = {"内容可提取性", "E-E-A-T信号", "引用就绪度"}
|
||||
|
||||
_CACHE_TTL = 86400
|
||||
|
||||
|
||||
def _build_default_response(brand_name: str) -> HealthScoreResponse:
|
||||
return HealthScoreResponse(
|
||||
brand_name=brand_name,
|
||||
overall_score=0.0,
|
||||
health_level="danger",
|
||||
health_level_label=get_health_level_label("danger"),
|
||||
dimensions=[
|
||||
HealthScoreDimension(
|
||||
name=d,
|
||||
score=0.0,
|
||||
max_score=0.0,
|
||||
percentage=0.0,
|
||||
status="fail",
|
||||
)
|
||||
for d in sorted(_FREE_TIER_DIMENSIONS)
|
||||
],
|
||||
recommendations=[],
|
||||
is_full_report=False,
|
||||
cached=False,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/health-score", response_model=HealthScoreResponse)
|
||||
async def get_public_health_score(
|
||||
brand: str = Query(..., min_length=1),
|
||||
competitors: str = Query(default=""),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
cache = get_cache_service()
|
||||
cache_key = f"health_score:{hashlib.md5(brand.lower().encode()).hexdigest()}"
|
||||
|
||||
cached_data = await cache.get_json(cache_key)
|
||||
if cached_data is not None:
|
||||
cached_data["cached"] = True
|
||||
return cached_data
|
||||
|
||||
brand_name = brand.strip()
|
||||
competitor_list = [c.strip() for c in competitors.split(",") if c.strip()][:3]
|
||||
|
||||
brand_aliases: list[str] = []
|
||||
website: str | None = None
|
||||
industry: str | None = None
|
||||
|
||||
stmt = select(Brand).where(Brand.name == brand_name).limit(1)
|
||||
result = await db.execute(stmt)
|
||||
brand_record = result.scalar_one_or_none()
|
||||
|
||||
if brand_record:
|
||||
brand_aliases = brand_record.aliases or []
|
||||
website = brand_record.website
|
||||
industry = brand_record.industry
|
||||
|
||||
try:
|
||||
collector = DataCollectorService(db)
|
||||
collection = await collector.collect(
|
||||
brand_name=brand_name,
|
||||
brand_aliases=brand_aliases,
|
||||
website=website,
|
||||
industry=industry,
|
||||
)
|
||||
|
||||
geo_service = GEODiagnosisService()
|
||||
diagnosis_result = geo_service.diagnose(collection.diagnosis_input)
|
||||
except Exception as e:
|
||||
logger.error(f"健康分数据采集失败: brand={brand_name}, error={e}", exc_info=True)
|
||||
default_resp = _build_default_response(brand_name)
|
||||
await cache.set_json(cache_key, default_resp.model_dump(mode="json"), expire=_CACHE_TTL)
|
||||
return default_resp
|
||||
|
||||
all_dimensions = diagnosis_result.to_dict().get("dimensions", [])
|
||||
free_dimensions = [d for d in all_dimensions if d.get("name") in _FREE_TIER_DIMENSIONS]
|
||||
|
||||
overall_score = round(sum(d.get("score", 0.0) for d in free_dimensions), 2)
|
||||
health_level = get_health_level(overall_score)
|
||||
|
||||
dimension_responses = [
|
||||
HealthScoreDimension(
|
||||
name=d.get("name", ""),
|
||||
score=round(d.get("score", 0.0), 2),
|
||||
max_score=d.get("max_score", 0.0),
|
||||
percentage=round(d.get("percentage", 0.0), 2),
|
||||
status=d.get("status", "fail"),
|
||||
)
|
||||
for d in free_dimensions
|
||||
]
|
||||
|
||||
all_recommendations = diagnosis_result.to_dict().get("recommendations", [])
|
||||
free_recommendations = [
|
||||
r for r in all_recommendations
|
||||
if r.get("priority") == "P0" and r.get("dimension") in _FREE_TIER_DIMENSIONS
|
||||
]
|
||||
|
||||
recommendation_responses = [
|
||||
HealthScoreRecommendation(
|
||||
priority=r.get("priority", "P0"),
|
||||
dimension=r.get("dimension", ""),
|
||||
title=r.get("title", ""),
|
||||
description=r.get("description", ""),
|
||||
)
|
||||
for r in free_recommendations
|
||||
]
|
||||
|
||||
response = HealthScoreResponse(
|
||||
brand_name=brand_name,
|
||||
overall_score=overall_score,
|
||||
health_level=health_level,
|
||||
health_level_label=get_health_level_label(health_level),
|
||||
dimensions=dimension_responses,
|
||||
recommendations=recommendation_responses,
|
||||
is_full_report=False,
|
||||
cached=False,
|
||||
)
|
||||
|
||||
await cache.set_json(cache_key, response.model_dump(mode="json"), expire=_CACHE_TTL)
|
||||
|
||||
return response
|
||||
|
|
@ -0,0 +1,207 @@
|
|||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.models.brand import Brand
|
||||
from app.models.monitoring import MonitoringRecord
|
||||
from app.schemas.monitoring import (
|
||||
MonitoringRecordCreate,
|
||||
MonitoringRecordResponse,
|
||||
MonitoringRecordList,
|
||||
MonitoringChangeReport,
|
||||
MonitoringStatusUpdate,
|
||||
)
|
||||
from app.services.monitoring.monitor_service import MonitorService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _get_brand_with_access(
|
||||
brand_id: uuid.UUID,
|
||||
db: AsyncSession,
|
||||
current_user: User,
|
||||
) -> Brand:
|
||||
stmt = select(Brand).where(
|
||||
Brand.id == brand_id,
|
||||
Brand.user_id == current_user.id,
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
||||
if not brand:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="品牌不存在",
|
||||
)
|
||||
return brand
|
||||
|
||||
|
||||
def _record_to_response(record: MonitoringRecord) -> MonitoringRecordResponse:
|
||||
return MonitoringRecordResponse(
|
||||
id=record.id,
|
||||
brand_id=record.brand_id,
|
||||
content_id=record.content_id,
|
||||
query_keywords=record.query_keywords,
|
||||
platform=record.platform,
|
||||
baseline_citation_count=record.baseline_citation_count,
|
||||
baseline_sentiment=record.baseline_sentiment,
|
||||
baseline_rank=record.baseline_rank,
|
||||
current_citation_count=record.current_citation_count,
|
||||
current_sentiment=record.current_sentiment,
|
||||
current_rank=record.current_rank,
|
||||
change_type=record.change_type,
|
||||
change_details=record.change_details,
|
||||
check_interval_hours=record.check_interval_hours,
|
||||
last_checked_at=record.last_checked_at,
|
||||
next_check_at=record.next_check_at,
|
||||
status=record.status,
|
||||
created_at=record.created_at,
|
||||
updated_at=record.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tasks", response_model=MonitoringRecordResponse)
|
||||
async def create_monitoring_task(
|
||||
request: MonitoringRecordCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await _get_brand_with_access(request.brand_id, db, current_user)
|
||||
|
||||
service = MonitorService()
|
||||
record = await service.create_monitoring_record(
|
||||
db=db,
|
||||
brand_id=request.brand_id,
|
||||
content_id=request.content_id,
|
||||
query_keywords=request.query_keywords,
|
||||
platform=request.platform,
|
||||
check_interval_hours=request.check_interval_hours,
|
||||
)
|
||||
return _record_to_response(record)
|
||||
|
||||
|
||||
@router.get("/brand/{brand_id}", response_model=MonitoringRecordList)
|
||||
async def get_brand_monitoring(
|
||||
brand_id: uuid.UUID,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await _get_brand_with_access(brand_id, db, current_user)
|
||||
|
||||
service = MonitorService()
|
||||
records, total = await service.get_brand_monitoring(
|
||||
db=db,
|
||||
brand_id=brand_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
return MonitoringRecordList(
|
||||
records=[_record_to_response(r) for r in records],
|
||||
total=total,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{record_id}/report", response_model=MonitoringChangeReport)
|
||||
async def get_change_report(
|
||||
record_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
stmt = select(MonitoringRecord).where(MonitoringRecord.id == record_id)
|
||||
result = await db.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
|
||||
if not record:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="监测记录不存在",
|
||||
)
|
||||
|
||||
await _get_brand_with_access(record.brand_id, db, current_user)
|
||||
|
||||
service = MonitorService()
|
||||
report = await service.generate_change_report(db, record_id)
|
||||
if not report:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="变化报告不存在",
|
||||
)
|
||||
|
||||
return MonitoringChangeReport(
|
||||
monitoring_record_id=uuid.UUID(report["monitoring_record_id"]),
|
||||
brand_id=uuid.UUID(report["brand_id"]),
|
||||
change_type=report["change_type"],
|
||||
change_details=report["change_details"],
|
||||
baseline=report["baseline"],
|
||||
current=report["current"],
|
||||
recommendations=report["recommendations"],
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{record_id}/status", response_model=MonitoringRecordResponse)
|
||||
async def update_monitoring_status(
|
||||
record_id: uuid.UUID,
|
||||
status_update: MonitoringStatusUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
valid_statuses = {"active", "paused", "completed"}
|
||||
if status_update.status not in valid_statuses:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"无效的状态值,支持: {', '.join(valid_statuses)}",
|
||||
)
|
||||
|
||||
stmt = select(MonitoringRecord).where(MonitoringRecord.id == record_id)
|
||||
result = await db.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
|
||||
if not record:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="监测记录不存在",
|
||||
)
|
||||
|
||||
await _get_brand_with_access(record.brand_id, db, current_user)
|
||||
|
||||
record.status = status_update.status
|
||||
await db.commit()
|
||||
await db.refresh(record)
|
||||
|
||||
return _record_to_response(record)
|
||||
|
||||
|
||||
@router.post("/{record_id}/check", response_model=MonitoringRecordResponse)
|
||||
async def trigger_manual_check(
|
||||
record_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
stmt = select(MonitoringRecord).where(MonitoringRecord.id == record_id)
|
||||
result = await db.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
|
||||
if not record:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="监测记录不存在",
|
||||
)
|
||||
|
||||
await _get_brand_with_access(record.brand_id, db, current_user)
|
||||
|
||||
service = MonitorService()
|
||||
updated_record = await service.check_and_compare(db, record_id)
|
||||
if not updated_record:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="检测执行失败",
|
||||
)
|
||||
|
||||
return _record_to_response(updated_record)
|
||||
|
|
@ -20,15 +20,24 @@ from app.database import get_db
|
|||
from app.models.user import User
|
||||
from app.models.brand import Brand
|
||||
from app.models.competitor import Competitor
|
||||
from app.models.citation_record import CitationRecord
|
||||
from app.models.query import Query as QueryModel
|
||||
from app.services.scoring_service import ScoringService
|
||||
from app.schemas.brand import BrandCreate, BrandResponse
|
||||
from app.services.diagnosis.data_collector import DataCollectorService
|
||||
from app.services.diagnosis.geo_diagnosis import GEODiagnosisService
|
||||
from app.utils.health import get_health_level, get_health_level_label
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/onboarding", tags=["onboarding"])
|
||||
|
||||
_FREE_TIER_DIMENSIONS = {"内容可提取性", "E-E-A-T信号", "引用就绪度"}
|
||||
|
||||
|
||||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||||
if isinstance(value, uuid.UUID):
|
||||
return value
|
||||
return uuid.UUID(str(value))
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Request / Response schemas
|
||||
|
|
@ -60,23 +69,43 @@ class CompetitorRecommendationSimpleResponse(BaseModel):
|
|||
recommendations: list[CompetitorRecommendationSimple]
|
||||
|
||||
|
||||
class HealthDimensionItem(BaseModel):
|
||||
name: str
|
||||
score: float
|
||||
max_score: float
|
||||
percentage: float
|
||||
status: str
|
||||
|
||||
|
||||
class HealthRecommendationItem(BaseModel):
|
||||
priority: str
|
||||
dimension: str
|
||||
title: str
|
||||
description: str
|
||||
|
||||
|
||||
class HealthReportResponse(BaseModel):
|
||||
"""初始健康评分报告"""
|
||||
brand_id: str
|
||||
brand_name: str
|
||||
overall_score: float
|
||||
health_level: str
|
||||
health_level_label: str
|
||||
platform_scores: dict
|
||||
strengths: list[str]
|
||||
weaknesses: list[str]
|
||||
competitor_scores: list[dict]
|
||||
dimensions: list[HealthDimensionItem] = []
|
||||
recommendations: list[HealthRecommendationItem] = []
|
||||
is_full_report: bool = False
|
||||
|
||||
|
||||
class ActionSuggestion(BaseModel):
|
||||
"""行动建议项"""
|
||||
title: str
|
||||
description: str
|
||||
priority: str # high / medium / low
|
||||
action_type: str # e.g. coverage, keyword, sentiment, platform
|
||||
priority: str
|
||||
action_type: str
|
||||
is_paid_action: bool = False
|
||||
action_button_text: str = ""
|
||||
|
||||
|
||||
class ActionSuggestionsResponse(BaseModel):
|
||||
|
|
@ -105,7 +134,7 @@ async def get_onboarding_status(
|
|||
- completed=True 且 brand_id 有值 → 已完成
|
||||
- completed=False, current_step=1 → 需要创建品牌
|
||||
"""
|
||||
stmt = select(Brand).where(Brand.user_id == current_user.id)
|
||||
stmt = select(Brand).where(Brand.user_id == _to_uuid(current_user.id))
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
||||
|
|
@ -144,7 +173,7 @@ async def create_onboarding_brand(
|
|||
)
|
||||
|
||||
brand = Brand(
|
||||
user_id=current_user.id,
|
||||
user_id=_to_uuid(current_user.id),
|
||||
name=full_brand_data.name,
|
||||
aliases=full_brand_data.aliases,
|
||||
website=full_brand_data.website,
|
||||
|
|
@ -156,6 +185,38 @@ async def create_onboarding_brand(
|
|||
await db.commit()
|
||||
await db.refresh(brand)
|
||||
|
||||
# 自动创建默认查询词(检查 max_queries 限制)
|
||||
try:
|
||||
current_query_count_stmt = select(func.count()).select_from(QueryModel).where(
|
||||
QueryModel.user_id == current_user.id
|
||||
)
|
||||
current_query_count_result = await db.execute(current_query_count_stmt)
|
||||
current_query_count = current_query_count_result.scalar_one()
|
||||
|
||||
max_queries = getattr(current_user, "max_queries", 3) # 默认免费版 3 个
|
||||
|
||||
if current_query_count < max_queries:
|
||||
default_query = QueryModel(
|
||||
user_id=current_user.id,
|
||||
keyword=f"{brand.name} 推荐",
|
||||
target_brand=brand.name,
|
||||
brand_aliases=brand.aliases or [],
|
||||
platforms=brand.platforms or ["wenxin", "kimi"],
|
||||
frequency=brand.frequency or "weekly",
|
||||
status="active",
|
||||
)
|
||||
db.add(default_query)
|
||||
await db.commit()
|
||||
await db.refresh(default_query)
|
||||
logger.info(f"Auto-created default query for brand '{brand.name}'")
|
||||
else:
|
||||
logger.info(
|
||||
f"Skipped auto-creating default query for brand '{brand.name}': "
|
||||
f"query limit reached ({current_query_count}/{max_queries})"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to auto-create default query for brand '{brand.name}': {e}")
|
||||
|
||||
return brand
|
||||
|
||||
|
||||
|
|
@ -171,8 +232,7 @@ async def get_onboarding_competitor_recommendations(
|
|||
复用 brands/competitors 中的推荐逻辑,
|
||||
支持 LLM 智能推荐和规则推荐两种模式。
|
||||
"""
|
||||
# 验证品牌归属
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
||||
|
|
@ -182,7 +242,6 @@ async def get_onboarding_competitor_recommendations(
|
|||
detail="品牌不存在",
|
||||
)
|
||||
|
||||
# 获取已有竞品名称(排除)
|
||||
existing_stmt = select(Competitor.name).where(Competitor.brand_id == brand_id)
|
||||
existing_result = await db.execute(existing_stmt)
|
||||
existing_names = [row[0] for row in existing_result.all()]
|
||||
|
|
@ -228,14 +287,8 @@ async def get_onboarding_health_report(
|
|||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取品牌初始健康评分报告。
|
||||
|
||||
基于 citation_records 表统计品牌的引用数据,
|
||||
如果没有引用数据则返回初始化状态(overall_score: 0)。
|
||||
"""
|
||||
# 验证品牌归属
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
|
||||
user_uuid = _to_uuid(current_user.id)
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == user_uuid)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
||||
|
|
@ -245,7 +298,55 @@ async def get_onboarding_health_report(
|
|||
detail="品牌不存在",
|
||||
)
|
||||
|
||||
# 查询与品牌关联的 queries
|
||||
user_plan = getattr(current_user, "plan", None) or "free"
|
||||
is_paid = user_plan not in ("free", None)
|
||||
|
||||
try:
|
||||
collector = DataCollectorService(db)
|
||||
collection = await collector.collect(
|
||||
brand_name=brand.name,
|
||||
brand_aliases=brand.aliases or [],
|
||||
website=brand.website,
|
||||
industry=brand.industry,
|
||||
)
|
||||
|
||||
geo_service = GEODiagnosisService()
|
||||
diagnosis_result = geo_service.diagnose(collection.diagnosis_input)
|
||||
except Exception as e:
|
||||
logger.error(f"Onboarding健康报告采集失败: brand_id={brand_id}, error={e}", exc_info=True)
|
||||
return HealthReportResponse(
|
||||
brand_id=str(brand.id),
|
||||
brand_name=brand.name,
|
||||
overall_score=0.0,
|
||||
health_level="danger",
|
||||
health_level_label=get_health_level_label("danger"),
|
||||
platform_scores={},
|
||||
strengths=["品牌已创建,等待数据采集"],
|
||||
weaknesses=["数据采集失败,请稍后重试"],
|
||||
competitor_scores=[],
|
||||
dimensions=[
|
||||
HealthDimensionItem(name=d, score=0.0, max_score=0.0, percentage=0.0, status="fail")
|
||||
for d in sorted(_FREE_TIER_DIMENSIONS)
|
||||
],
|
||||
recommendations=[],
|
||||
is_full_report=is_paid,
|
||||
)
|
||||
|
||||
result_dict = diagnosis_result.to_dict()
|
||||
all_dimensions = result_dict.get("dimensions", [])
|
||||
all_recommendations = result_dict.get("recommendations", [])
|
||||
|
||||
if is_paid:
|
||||
filtered_dimensions = all_dimensions
|
||||
filtered_recommendations = all_recommendations
|
||||
else:
|
||||
filtered_dimensions = [d for d in all_dimensions if d.get("name") in _FREE_TIER_DIMENSIONS]
|
||||
filtered_recommendations = [r for r in all_recommendations if r.get("priority") == "P0"]
|
||||
|
||||
overall_score = round(diagnosis_result.overall_score, 2)
|
||||
health_level = get_health_level(overall_score)
|
||||
|
||||
platform_scores: dict[str, float] = {}
|
||||
queries_stmt = select(QueryModel).where(
|
||||
QueryModel.user_id == current_user.id,
|
||||
QueryModel.target_brand == brand.name,
|
||||
|
|
@ -253,138 +354,89 @@ async def get_onboarding_health_report(
|
|||
queries_result = await db.execute(queries_stmt)
|
||||
queries = list(queries_result.scalars().all())
|
||||
|
||||
# 没有查询数据 → 返回初始化状态
|
||||
if not queries:
|
||||
return HealthReportResponse(
|
||||
brand_id=str(brand.id),
|
||||
brand_name=brand.name,
|
||||
overall_score=0.0,
|
||||
platform_scores={},
|
||||
strengths=["品牌已创建,等待数据采集"],
|
||||
weaknesses=["尚无AI平台引用数据,需等待查询执行"],
|
||||
competitor_scores=[],
|
||||
if queries:
|
||||
from app.models.citation_record import CitationRecord
|
||||
query_ids = [q.id for q in queries]
|
||||
citations_stmt = select(CitationRecord).where(
|
||||
CitationRecord.query_id.in_(query_ids),
|
||||
)
|
||||
citations_result = await db.execute(citations_stmt)
|
||||
citations = list(citations_result.scalars().all())
|
||||
|
||||
query_ids = [q.id for q in queries]
|
||||
platforms_seen: dict[str, dict] = {}
|
||||
for c in citations:
|
||||
p = c.platform or "unknown"
|
||||
if p not in platforms_seen:
|
||||
platforms_seen[p] = {"total": 0, "cited": 0}
|
||||
platforms_seen[p]["total"] += 1
|
||||
if c.cited:
|
||||
platforms_seen[p]["cited"] += 1
|
||||
|
||||
# 获取引用记录
|
||||
citations_stmt = select(CitationRecord).where(
|
||||
CitationRecord.query_id.in_(query_ids),
|
||||
)
|
||||
citations_result = await db.execute(citations_stmt)
|
||||
citations = list(citations_result.scalars().all())
|
||||
for p, data in platforms_seen.items():
|
||||
rate = (data["cited"] / data["total"] * 100) if data["total"] > 0 else 0.0
|
||||
platform_scores[p] = round(rate, 2)
|
||||
|
||||
total = len(citations)
|
||||
cited = [c for c in citations if c.cited]
|
||||
|
||||
# 计算各平台评分
|
||||
platform_scores: dict[str, float] = {}
|
||||
platforms_seen: dict[str, dict] = {} # {platform: {total, cited}}
|
||||
|
||||
for c in citations:
|
||||
p = c.platform or "unknown"
|
||||
if p not in platforms_seen:
|
||||
platforms_seen[p] = {"total": 0, "cited": 0}
|
||||
platforms_seen[p]["total"] += 1
|
||||
if c.cited:
|
||||
platforms_seen[p]["cited"] += 1
|
||||
|
||||
for p, data in platforms_seen.items():
|
||||
rate = (data["cited"] / data["total"] * 100) if data["total"] > 0 else 0.0
|
||||
platform_scores[p] = round(rate, 2)
|
||||
|
||||
# 使用 ScoringService 计算 overall_score
|
||||
scoring_service = ScoringService()
|
||||
sentiment_counts = {"positive": 0, "neutral": 0, "negative": 0}
|
||||
for c in cited:
|
||||
sentiment = c.sentiment or "neutral"
|
||||
if sentiment in sentiment_counts:
|
||||
sentiment_counts[sentiment] += 1
|
||||
|
||||
from app.schemas.scoring import CitationResult
|
||||
citation_results = [
|
||||
CitationResult(
|
||||
cited=c.cited,
|
||||
position=c.citation_position,
|
||||
citation_text=c.citation_text,
|
||||
sentiment=c.sentiment or "neutral",
|
||||
confidence=c.confidence or 0.0,
|
||||
)
|
||||
for c in cited
|
||||
]
|
||||
positions = [c.citation_position for c in cited if c.cited]
|
||||
|
||||
# 获取竞品信息
|
||||
competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id)
|
||||
competitor_result = await db.execute(competitor_stmt)
|
||||
competitors = list(competitor_result.scalars().all())
|
||||
competitor_names = [c.name for c in competitors]
|
||||
competitor_mentions: dict[str, int] = {}
|
||||
for comp_name in competitor_names:
|
||||
count = sum(
|
||||
1 for c in citations
|
||||
if c.cited and c.competitor_brands and comp_name in c.competitor_brands
|
||||
)
|
||||
if count > 0:
|
||||
competitor_mentions[comp_name] = count
|
||||
|
||||
v2_result = scoring_service.calculate_v2(
|
||||
mentioned_count=len(cited),
|
||||
total_queries=total,
|
||||
positions=positions,
|
||||
sentiment_counts=sentiment_counts,
|
||||
citations=citation_results,
|
||||
brand_mentions=len(cited),
|
||||
competitor_mentions=competitor_mentions,
|
||||
)
|
||||
|
||||
# 生成 strengths/weaknesses
|
||||
strengths = []
|
||||
weaknesses = []
|
||||
|
||||
if total == 0:
|
||||
strengths.append("品牌已创建")
|
||||
weaknesses.append("尚无引用数据")
|
||||
else:
|
||||
mention_rate = len(cited) / total * 100 if total > 0 else 0
|
||||
if mention_rate >= 50:
|
||||
strengths.append(f"提及率较高 ({round(mention_rate, 1)}%)")
|
||||
else:
|
||||
weaknesses.append(f"提及率偏低 ({round(mention_rate, 1)}%)")
|
||||
for d in filtered_dimensions:
|
||||
pct = d.get("percentage", 0)
|
||||
if pct >= 60:
|
||||
strengths.append(f"{d.get('name', '')}表现良好 ({round(pct, 1)}%)")
|
||||
elif pct > 0:
|
||||
weaknesses.append(f"{d.get('name', '')}有待提升 ({round(pct, 1)}%)")
|
||||
|
||||
for p, score in platform_scores.items():
|
||||
if score >= 60:
|
||||
strengths.append(f"{p} 平台表现良好 ({score}%)")
|
||||
elif score > 0:
|
||||
weaknesses.append(f"{p} 平台覆盖率不足 ({score}%)")
|
||||
if not strengths:
|
||||
strengths.append("已有初步诊断数据")
|
||||
if not weaknesses:
|
||||
weaknesses.append("暂无明显短板")
|
||||
|
||||
if sentiment_counts["positive"] > sentiment_counts["negative"]:
|
||||
strengths.append("情感倾向正面")
|
||||
elif sentiment_counts["negative"] > sentiment_counts["positive"]:
|
||||
weaknesses.append("情感倾向偏负面")
|
||||
competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id)
|
||||
competitor_result = await db.execute(competitor_stmt)
|
||||
competitors = list(competitor_result.scalars().all())
|
||||
|
||||
if not strengths:
|
||||
strengths.append("已有初步引用数据")
|
||||
if not weaknesses:
|
||||
weaknesses.append("暂无明显短板")
|
||||
|
||||
# 竞品评分
|
||||
competitor_scores = []
|
||||
for comp_name, mentions in competitor_mentions.items():
|
||||
comp_score = round(mentions / total * 100, 2) if total > 0 else 0.0
|
||||
for comp in competitors:
|
||||
competitor_scores.append({
|
||||
"name": comp_name,
|
||||
"score": comp_score,
|
||||
"name": comp.name,
|
||||
"score": 0.0,
|
||||
"is_leading": False,
|
||||
})
|
||||
|
||||
dimension_items = [
|
||||
HealthDimensionItem(
|
||||
name=d.get("name", ""),
|
||||
score=round(d.get("score", 0.0), 2),
|
||||
max_score=d.get("max_score", 0.0),
|
||||
percentage=round(d.get("percentage", 0.0), 2),
|
||||
status=d.get("status", "fail"),
|
||||
)
|
||||
for d in filtered_dimensions
|
||||
]
|
||||
|
||||
recommendation_items = [
|
||||
HealthRecommendationItem(
|
||||
priority=r.get("priority", "P0"),
|
||||
dimension=r.get("dimension", ""),
|
||||
title=r.get("title", ""),
|
||||
description=r.get("description", ""),
|
||||
)
|
||||
for r in filtered_recommendations
|
||||
]
|
||||
|
||||
return HealthReportResponse(
|
||||
brand_id=str(brand.id),
|
||||
brand_name=brand.name,
|
||||
overall_score=round(v2_result.overall_score, 2),
|
||||
overall_score=overall_score,
|
||||
health_level=health_level,
|
||||
health_level_label=get_health_level_label(health_level),
|
||||
platform_scores=platform_scores,
|
||||
strengths=strengths,
|
||||
weaknesses=weaknesses,
|
||||
competitor_scores=competitor_scores,
|
||||
dimensions=dimension_items,
|
||||
recommendations=recommendation_items,
|
||||
is_full_report=is_paid,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -394,21 +446,21 @@ async def get_onboarding_action_suggestions(
|
|||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
根据健康报告生成行动建议(基于规则引擎,不需要 LLM)。
|
||||
"""
|
||||
# 先获取健康报告数据(复用逻辑)
|
||||
report = await get_onboarding_health_report(brand_id, current_user, db)
|
||||
|
||||
user_plan = getattr(current_user, "plan", None) or "free"
|
||||
is_paid = user_plan not in ("free", None)
|
||||
|
||||
suggestions = []
|
||||
|
||||
# 规则引擎:基于评分和平台数据生成建议
|
||||
if report.overall_score < 20:
|
||||
suggestions.append(ActionSuggestion(
|
||||
title="提升 AI 平台覆盖率",
|
||||
description=f"当前综合评分仅 {report.overall_score},品牌在AI搜索中几乎未被提及。建议增加查询词覆盖面,让AI平台更频繁地引用品牌。",
|
||||
priority="high",
|
||||
action_type="coverage",
|
||||
is_paid_action=False,
|
||||
action_button_text="设置查询词",
|
||||
))
|
||||
|
||||
if report.overall_score < 50:
|
||||
|
|
@ -417,9 +469,10 @@ async def get_onboarding_action_suggestions(
|
|||
description="品牌在关键查询词下的提及率偏低,建议调整查询关键词策略,聚焦行业核心术语。",
|
||||
priority="high",
|
||||
action_type="keyword",
|
||||
is_paid_action=False,
|
||||
action_button_text="优化关键词",
|
||||
))
|
||||
|
||||
# 平台维度建议
|
||||
for platform, score in report.platform_scores.items():
|
||||
if score < 30:
|
||||
suggestions.append(ActionSuggestion(
|
||||
|
|
@ -427,28 +480,32 @@ async def get_onboarding_action_suggestions(
|
|||
description=f"品牌在 {platform} 平台的引用率仅为 {score}%,需要针对性优化该平台的内容策略。",
|
||||
priority="medium",
|
||||
action_type="platform",
|
||||
is_paid_action=False,
|
||||
action_button_text=f"优化{platform}",
|
||||
))
|
||||
|
||||
# 情感维度建议
|
||||
if "情感倾向偏负面" in report.weaknesses:
|
||||
suggestions.append(ActionSuggestion(
|
||||
title="改善品牌情感倾向",
|
||||
description="AI平台对品牌的情感评价偏负面,建议发布正面品牌内容、优化品牌描述以改善情感得分。",
|
||||
priority="medium",
|
||||
action_type="sentiment",
|
||||
))
|
||||
for dim in report.dimensions:
|
||||
if dim.percentage < 40 and dim.name not in _FREE_TIER_DIMENSIONS:
|
||||
suggestions.append(ActionSuggestion(
|
||||
title=f"提升{dim.name}得分",
|
||||
description=f"{dim.name}当前得分仅 {round(dim.percentage, 1)}%,解锁详细诊断和优化方案。",
|
||||
priority="medium",
|
||||
action_type="dimension",
|
||||
is_paid_action=True,
|
||||
action_button_text="升级解锁",
|
||||
))
|
||||
|
||||
# 竞品对比建议
|
||||
for comp in report.competitor_scores:
|
||||
if comp["score"] > report.overall_score:
|
||||
if comp.get("score", 0) > report.overall_score:
|
||||
suggestions.append(ActionSuggestion(
|
||||
title=f"应对竞品 {comp['name']} 威胁",
|
||||
description=f"竞品 {comp['name']} 评分 ({comp['score']}) 高于本品牌 ({report.overall_score}),建议分析竞品优势领域并制定差异化策略。",
|
||||
priority="high",
|
||||
action_type="competitive",
|
||||
is_paid_action=True,
|
||||
action_button_text="查看竞品分析",
|
||||
))
|
||||
|
||||
# 如果没有引用数据,给出基础建议
|
||||
if report.overall_score == 0:
|
||||
suggestions = [
|
||||
ActionSuggestion(
|
||||
|
|
@ -456,28 +513,45 @@ async def get_onboarding_action_suggestions(
|
|||
description="品牌尚无查询数据,建议首先设置与品牌最相关的核心查询词,让系统开始数据采集。",
|
||||
priority="high",
|
||||
action_type="keyword",
|
||||
is_paid_action=False,
|
||||
action_button_text="设置查询词",
|
||||
),
|
||||
ActionSuggestion(
|
||||
title="添加竞品对比",
|
||||
description="添加主要竞品以便进行对比分析,了解品牌在市场中的定位。",
|
||||
priority="medium",
|
||||
action_type="coverage",
|
||||
is_paid_action=False,
|
||||
action_button_text="添加竞品",
|
||||
),
|
||||
ActionSuggestion(
|
||||
title="完善品牌信息",
|
||||
description="补充品牌别名、网站、行业等详细信息,有助于提升AI平台识别率。",
|
||||
title="解锁完整6维度诊断",
|
||||
description="免费版仅展示3个核心维度,升级后可查看完整6维度诊断报告和深度优化方案。",
|
||||
priority="medium",
|
||||
action_type="brand_info",
|
||||
action_type="upgrade",
|
||||
is_paid_action=True,
|
||||
action_button_text="升级Pro",
|
||||
),
|
||||
]
|
||||
|
||||
# 确保至少有1条建议
|
||||
if not suggestions:
|
||||
suggestions.append(ActionSuggestion(
|
||||
title="持续监测品牌表现",
|
||||
description="品牌表现良好,建议持续监测并保持当前策略。",
|
||||
priority="low",
|
||||
action_type="monitor",
|
||||
is_paid_action=False,
|
||||
action_button_text="查看Dashboard",
|
||||
))
|
||||
|
||||
if not is_paid:
|
||||
suggestions.append(ActionSuggestion(
|
||||
title="升级Pro解锁完整诊断",
|
||||
description="免费版仅展示3个核心维度和P0建议。升级Pro可获取完整6维度诊断、深度竞品分析和AI优化方案。",
|
||||
priority="low",
|
||||
action_type="upgrade",
|
||||
is_paid_action=True,
|
||||
action_button_text="升级Pro",
|
||||
))
|
||||
|
||||
return ActionSuggestionsResponse(suggestions=suggestions)
|
||||
|
|
@ -496,8 +570,7 @@ async def complete_onboarding(
|
|||
User 模型当前没有 onboarding_completed 专用字段,
|
||||
品牌的创建即代表 onboarding 完成。
|
||||
"""
|
||||
# 验证品牌归属
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
|
||||
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
||||
|
|
@ -507,8 +580,6 @@ async def complete_onboarding(
|
|||
detail="品牌不存在",
|
||||
)
|
||||
|
||||
# 品牌已创建即代表 onboarding 完成,无需额外字段更新
|
||||
# 后续如需专用字段,可通过 alembic 迁移添加 user.onboarding_completed
|
||||
logger.info(f"User {current_user.id} completed onboarding with brand {brand_id}")
|
||||
|
||||
return OnboardingCompleteResponse(success=True)
|
||||
|
|
@ -0,0 +1,274 @@
|
|||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.payment_order import PaymentOrder as PaymentOrderModel
|
||||
from app.models.user import User
|
||||
from app.services.payment import get_payment_gateway
|
||||
from app.services.subscription import PLANS, subscribe
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/payments", tags=["支付"])
|
||||
|
||||
|
||||
class CreateOrderRequest(BaseModel):
|
||||
plan: str
|
||||
payment_provider: str = "wechat"
|
||||
|
||||
|
||||
class CreateOrderResponse(BaseModel):
|
||||
order_id: str
|
||||
pay_url: str
|
||||
amount: float
|
||||
currency: str = "CNY"
|
||||
status: str = "pending"
|
||||
|
||||
|
||||
class OrderStatusResponse(BaseModel):
|
||||
order_id: str
|
||||
status: str
|
||||
plan: str
|
||||
amount: float
|
||||
payment_provider: str
|
||||
payment_id: str | None = None
|
||||
created_at: str | None = None
|
||||
paid_at: str | None = None
|
||||
|
||||
|
||||
class RefundRequest(BaseModel):
|
||||
reason: str = ""
|
||||
|
||||
|
||||
@router.post("/orders", response_model=CreateOrderResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_payment_order(
|
||||
request: CreateOrderRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
plan_data = PLANS.get(request.plan)
|
||||
if plan_data is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=f"无效的套餐: {request.plan}",
|
||||
)
|
||||
|
||||
if plan_data["price"] == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="免费套餐无需支付",
|
||||
)
|
||||
|
||||
order_id = uuid.uuid4()
|
||||
amount = plan_data["price"]
|
||||
|
||||
gateway = get_payment_gateway(request.payment_provider)
|
||||
payment_order = await gateway.create_order(
|
||||
order_id=str(order_id),
|
||||
amount=amount,
|
||||
description=f"GEO平台{plan_data['name']}订阅",
|
||||
user_id=current_user.id,
|
||||
plan=request.plan,
|
||||
)
|
||||
|
||||
db_order = PaymentOrderModel(
|
||||
id=order_id,
|
||||
user_id=current_user.id,
|
||||
plan=request.plan,
|
||||
amount=amount,
|
||||
payment_provider=request.payment_provider,
|
||||
status="pending",
|
||||
pay_url=payment_order.pay_url,
|
||||
)
|
||||
db.add(db_order)
|
||||
await db.commit()
|
||||
|
||||
return CreateOrderResponse(
|
||||
order_id=str(order_id),
|
||||
pay_url=payment_order.pay_url,
|
||||
amount=amount,
|
||||
status="pending",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/callback/wechat")
|
||||
async def wechat_pay_callback(request: Request):
|
||||
body = await request.form()
|
||||
request_data = dict(body)
|
||||
|
||||
gateway = get_payment_gateway("wechat")
|
||||
callback = await gateway.verify_callback(request_data)
|
||||
|
||||
return await _handle_payment_callback(request_data, callback, "wechat")
|
||||
|
||||
|
||||
@router.post("/callback/alipay")
|
||||
async def alipay_callback(request: Request):
|
||||
body = await request.form()
|
||||
request_data = dict(body)
|
||||
|
||||
gateway = get_payment_gateway("alipay")
|
||||
callback = await gateway.verify_callback(request_data)
|
||||
|
||||
return await _handle_payment_callback(request_data, callback, "alipay")
|
||||
|
||||
|
||||
async def _handle_payment_callback(request_data: dict, callback, provider: str):
|
||||
from app.database import AsyncSessionLocal
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
try:
|
||||
result = await _process_callback(db, callback, provider)
|
||||
await db.commit()
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"[PaymentCallback] 处理回调异常: {e}", exc_info=True)
|
||||
await db.rollback()
|
||||
if provider == "wechat":
|
||||
return _wechat_fail_response()
|
||||
return "fail"
|
||||
|
||||
|
||||
async def _process_callback(db: AsyncSession, callback, provider: str):
|
||||
stmt = select(PaymentOrderModel).where(
|
||||
PaymentOrderModel.id == uuid.UUID(callback.order_id)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
order = result.scalar_one_or_none()
|
||||
|
||||
if order is None:
|
||||
logger.warning(f"[PaymentCallback] 订单不存在: order_id={callback.order_id}")
|
||||
if provider == "wechat":
|
||||
return _wechat_fail_response()
|
||||
return "fail"
|
||||
|
||||
if callback.status == "success":
|
||||
order.status = "paid"
|
||||
order.payment_id = callback.payment_id
|
||||
order.callback_data = callback.raw_data
|
||||
order.paid_at = datetime.now(timezone.utc)
|
||||
|
||||
await subscribe(db, order.user_id, order.plan)
|
||||
|
||||
logger.info(
|
||||
f"[PaymentCallback] 支付成功: order_id={callback.order_id}, "
|
||||
f"plan={order.plan}, provider={provider}"
|
||||
)
|
||||
else:
|
||||
order.status = "failed"
|
||||
order.callback_data = callback.raw_data
|
||||
|
||||
if provider == "wechat":
|
||||
return _wechat_success_response()
|
||||
return "success"
|
||||
|
||||
|
||||
def _wechat_success_response():
|
||||
from fastapi.responses import Response
|
||||
return Response(content="<xml><return_code><![CDATA[SUCCESS]]></return_code></xml>", media_type="application/xml")
|
||||
|
||||
|
||||
def _wechat_fail_response():
|
||||
from fastapi.responses import Response
|
||||
return Response(content="<xml><return_code><![CDATA[FAIL]]></return_code></xml>", media_type="application/xml")
|
||||
|
||||
|
||||
@router.get("/orders/{order_id}", response_model=OrderStatusResponse)
|
||||
async def query_order_status(
|
||||
order_id: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
oid = uuid.UUID(order_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="无效的订单ID",
|
||||
)
|
||||
|
||||
stmt = select(PaymentOrderModel).where(PaymentOrderModel.id == oid)
|
||||
result = await db.execute(stmt)
|
||||
order = result.scalar_one_or_none()
|
||||
|
||||
if order is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="订单不存在",
|
||||
)
|
||||
|
||||
if order.user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="无权查看此订单",
|
||||
)
|
||||
|
||||
return OrderStatusResponse(
|
||||
order_id=str(order.id),
|
||||
status=order.status,
|
||||
plan=order.plan,
|
||||
amount=order.amount,
|
||||
payment_provider=order.payment_provider,
|
||||
payment_id=order.payment_id,
|
||||
created_at=order.created_at.isoformat() if order.created_at else None,
|
||||
paid_at=order.paid_at.isoformat() if order.paid_at else None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refund/{order_id}")
|
||||
async def refund_order(
|
||||
order_id: str,
|
||||
body: RefundRequest = RefundRequest(),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
user_plan = getattr(current_user, "plan", "free") or "free"
|
||||
if user_plan != "enterprise":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="仅企业管理员可执行退款操作",
|
||||
)
|
||||
|
||||
try:
|
||||
oid = uuid.UUID(order_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="无效的订单ID",
|
||||
)
|
||||
|
||||
stmt = select(PaymentOrderModel).where(PaymentOrderModel.id == oid)
|
||||
result = await db.execute(stmt)
|
||||
order = result.scalar_one_or_none()
|
||||
|
||||
if order is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="订单不存在",
|
||||
)
|
||||
|
||||
if order.status != "paid":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="仅已支付订单可退款",
|
||||
)
|
||||
|
||||
gateway = get_payment_gateway(order.payment_provider)
|
||||
success = await gateway.refund(order_id, order.amount, body.reason)
|
||||
|
||||
if success:
|
||||
order.status = "refunded"
|
||||
await db.commit()
|
||||
return {"message": "退款成功", "order_id": order_id}
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="退款失败",
|
||||
)
|
||||
|
|
@ -1,11 +1,5 @@
|
|||
"""
|
||||
平台健康检查API - 验证各AI平台适配器状态
|
||||
|
||||
端点: GET /api/platforms/health
|
||||
返回: 各平台适配器配置状态和健康信息
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
|
@ -18,7 +12,6 @@ router = APIRouter(prefix="/platforms", tags=["platforms"])
|
|||
|
||||
|
||||
class PlatformHealthStatus:
|
||||
"""平台健康状态"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -37,22 +30,38 @@ class PlatformHealthStatus:
|
|||
self.message = message
|
||||
|
||||
|
||||
_PLATFORM_URLS = {
|
||||
"kimi": "https://kimi.moonshot.cn",
|
||||
"wenxin": "https://yiyan.baidu.com",
|
||||
"doubao": "https://www.doubao.com/",
|
||||
}
|
||||
|
||||
|
||||
def _check_api_key_health(
|
||||
platform_name: str,
|
||||
env_key_name: str,
|
||||
url: str,
|
||||
) -> PlatformHealthStatus:
|
||||
api_key = os.getenv(env_key_name, "")
|
||||
api_key_set = bool(api_key and api_key.strip())
|
||||
configured = api_key_set
|
||||
|
||||
return PlatformHealthStatus(
|
||||
name=platform_name,
|
||||
url=url,
|
||||
configured=configured,
|
||||
api_key_set=api_key_set,
|
||||
status="configured" if configured else "not_configured",
|
||||
message="API Key已配置" if configured else "API Key未配置",
|
||||
)
|
||||
|
||||
|
||||
def check_kimi_health() -> PlatformHealthStatus:
|
||||
"""检查Kimi平台健康状态"""
|
||||
try:
|
||||
from app.workers.platforms.kimi import KimiAdapter
|
||||
|
||||
adapter = KimiAdapter()
|
||||
api_key_set = bool(adapter.api_key and adapter.api_key.strip())
|
||||
configured = adapter.is_configured
|
||||
|
||||
return PlatformHealthStatus(
|
||||
name="kimi",
|
||||
url=adapter.platform_url,
|
||||
configured=configured,
|
||||
api_key_set=api_key_set,
|
||||
status="configured" if configured else "not_configured",
|
||||
message="API Key已配置" if configured else "API Key未配置",
|
||||
return _check_api_key_health(
|
||||
platform_name="kimi",
|
||||
env_key_name="MOONSHOT_API_KEY",
|
||||
url=_PLATFORM_URLS["kimi"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Kimi健康检查失败: {e}")
|
||||
|
|
@ -65,18 +74,16 @@ def check_kimi_health() -> PlatformHealthStatus:
|
|||
|
||||
|
||||
def check_wenxin_health() -> PlatformHealthStatus:
|
||||
"""检查文心平台健康状态"""
|
||||
try:
|
||||
from app.workers.platforms.wenxin import WenxinAdapter
|
||||
|
||||
adapter = WenxinAdapter()
|
||||
api_key_set = bool(adapter.api_key and adapter.api_key.strip())
|
||||
secret_key_set = bool(adapter.secret_key and adapter.secret_key.strip())
|
||||
configured = adapter.is_configured
|
||||
api_key = os.getenv("BAIDU_QIANFAN_API_KEY", "")
|
||||
secret_key = os.getenv("BAIDU_QIANFAN_SECRET_KEY", "")
|
||||
api_key_set = bool(api_key and api_key.strip())
|
||||
secret_key_set = bool(secret_key and secret_key.strip())
|
||||
configured = api_key_set and secret_key_set
|
||||
|
||||
return PlatformHealthStatus(
|
||||
name="wenxin",
|
||||
url=adapter.platform_url,
|
||||
url=_PLATFORM_URLS["wenxin"],
|
||||
configured=configured,
|
||||
api_key_set=api_key_set,
|
||||
status="configured" if configured else "not_configured",
|
||||
|
|
@ -93,21 +100,11 @@ def check_wenxin_health() -> PlatformHealthStatus:
|
|||
|
||||
|
||||
def check_doubao_health() -> PlatformHealthStatus:
|
||||
"""检查豆包平台健康状态"""
|
||||
try:
|
||||
from app.workers.platforms.doubao import DoubaoAdapter
|
||||
|
||||
adapter = DoubaoAdapter()
|
||||
api_key_set = bool(adapter.api_key and adapter.api_key.strip())
|
||||
configured = adapter.is_configured
|
||||
|
||||
return PlatformHealthStatus(
|
||||
name="doubao",
|
||||
url=adapter.platform_url,
|
||||
configured=configured,
|
||||
api_key_set=api_key_set,
|
||||
status="configured" if configured else "not_configured",
|
||||
message="API Key已配置" if configured else "API Key未配置",
|
||||
return _check_api_key_health(
|
||||
platform_name="doubao",
|
||||
env_key_name="DOUBAO_API_KEY",
|
||||
url=_PLATFORM_URLS["doubao"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"豆包健康检查失败: {e}")
|
||||
|
|
@ -120,7 +117,6 @@ def check_doubao_health() -> PlatformHealthStatus:
|
|||
|
||||
|
||||
def check_all_platforms() -> dict:
|
||||
"""检查所有平台健康状态"""
|
||||
platforms = [
|
||||
check_kimi_health(),
|
||||
check_wenxin_health(),
|
||||
|
|
@ -136,29 +132,12 @@ def check_all_platforms() -> dict:
|
|||
|
||||
@router.get("/health")
|
||||
async def get_platform_health():
|
||||
"""
|
||||
获取所有AI平台适配器的健康状态
|
||||
|
||||
返回每个平台的:
|
||||
- name: 平台名称
|
||||
- configured: 是否已配置
|
||||
- url: 平台URL
|
||||
- api_key_set: API Key是否已设置
|
||||
- status: 健康状态 (configured / not_configured / error)
|
||||
- message: 状态消息
|
||||
"""
|
||||
health_info = check_all_platforms()
|
||||
return health_info
|
||||
|
||||
|
||||
@router.get("/health/{platform_name}")
|
||||
async def get_platform_health_by_name(platform_name: str):
|
||||
"""
|
||||
获取指定平台适配器的健康状态
|
||||
|
||||
Args:
|
||||
platform_name: 平台名称 (kimi / wenxin / doubao)
|
||||
"""
|
||||
if platform_name == "kimi":
|
||||
result = vars(check_kimi_health())
|
||||
elif platform_name == "wenxin":
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from app.database import get_db
|
|||
from app.models.user import User
|
||||
from app.schemas.citation import RunNowResponse
|
||||
from app.schemas.query import QueryCreate, QueryListResponse, QueryResponse, QueryUpdate
|
||||
from app.services.citation import trigger_query_now
|
||||
from app.services.citation.citation import trigger_query_now
|
||||
from app.services.query import create_query, delete_query, get_queries, get_query, update_query
|
||||
|
||||
router = APIRouter()
|
||||
|
|
|
|||
|
|
@ -1,23 +1,90 @@
|
|||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from starlette.responses import Response
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.brand import Brand
|
||||
from app.models.user import User
|
||||
from app.services.citation import export_citations_csv, export_citations_pdf
|
||||
from app.schemas.scoring import CitationResult
|
||||
from app.services.citation.citation import export_citations_csv, export_citations_pdf
|
||||
from app.services.scoring.scoring_service import ScoringService, ScoringResultV2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _compute_v2_scores(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
brand_id: uuid.UUID,
|
||||
) -> ScoringResultV2 | None:
|
||||
try:
|
||||
from app.api.scoring import (
|
||||
_get_citations_for_brand,
|
||||
_analyze_sentiments_for_citations,
|
||||
)
|
||||
|
||||
total_queries, brand_citations, _, competitor_mentions = (
|
||||
await _get_citations_for_brand(db, user_id, brand_id)
|
||||
)
|
||||
|
||||
if total_queries == 0:
|
||||
return None
|
||||
|
||||
brand_stmt = select(Brand).where(
|
||||
Brand.id == brand_id, Brand.user_id == user_id
|
||||
)
|
||||
brand_result = await db.execute(brand_stmt)
|
||||
brand = brand_result.scalar_one_or_none()
|
||||
if not brand:
|
||||
return None
|
||||
|
||||
sentiment_counts = await _analyze_sentiments_for_citations(
|
||||
brand_name=brand.name,
|
||||
brand_citations=brand_citations,
|
||||
)
|
||||
|
||||
citation_results = [
|
||||
CitationResult(
|
||||
cited=c.cited,
|
||||
position=c.citation_position,
|
||||
citation_text=c.citation_text,
|
||||
sentiment=c.sentiment or "neutral",
|
||||
confidence=c.confidence or 0.0,
|
||||
)
|
||||
for c in brand_citations
|
||||
]
|
||||
|
||||
positions = [c.citation_position for c in brand_citations if c.cited]
|
||||
|
||||
scoring_service = ScoringService()
|
||||
return scoring_service.calculate_v2(
|
||||
mentioned_count=len(brand_citations),
|
||||
total_queries=total_queries,
|
||||
positions=positions,
|
||||
sentiment_counts=sentiment_counts,
|
||||
citations=citation_results,
|
||||
brand_mentions=len(brand_citations),
|
||||
competitor_mentions=competitor_mentions,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("V2 scoring failed for brand %s", brand_id, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/export/csv")
|
||||
async def export_report(
|
||||
query_id: uuid.UUID = Query(...),
|
||||
brand_id: Optional[uuid.UUID] = Query(None),
|
||||
format: str = Query("csv"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
|
|
@ -29,7 +96,13 @@ async def export_report(
|
|||
)
|
||||
|
||||
try:
|
||||
csv_content = await export_citations_csv(db, current_user.id, query_id)
|
||||
v2_result = None
|
||||
if brand_id is not None:
|
||||
v2_result = await _compute_v2_scores(db, current_user.id, brand_id)
|
||||
|
||||
csv_content = await export_citations_csv(
|
||||
db, current_user.id, query_id, v2_result=v2_result
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -51,11 +124,18 @@ async def export_report(
|
|||
@router.get("/export/pdf")
|
||||
async def export_pdf(
|
||||
query_id: Optional[uuid.UUID] = None,
|
||||
brand_id: Optional[uuid.UUID] = Query(None),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
pdf_bytes = await export_citations_pdf(db, current_user.id, query_id)
|
||||
v2_result = None
|
||||
if brand_id is not None:
|
||||
v2_result = await _compute_v2_scores(db, current_user.id, brand_id)
|
||||
|
||||
pdf_bytes = await export_citations_pdf(
|
||||
db, current_user.id, query_id, v2_result=v2_result
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,248 @@
|
|||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.models.brand import Brand
|
||||
from app.models.schema_suggestion import SchemaSuggestion
|
||||
from app.schemas.schema_suggestion import (
|
||||
SchemaAdviseRequest,
|
||||
SchemaSuggestionResponse,
|
||||
SchemaSuggestionList,
|
||||
SchemaValidationResult,
|
||||
SchemaStatusUpdateRequest,
|
||||
)
|
||||
from app.services.schema.schema_advisor_service import SchemaAdvisorService
|
||||
from app.services.scoring.scoring_service import ScoringService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _get_brand_with_access(
|
||||
brand_id: uuid.UUID,
|
||||
db: AsyncSession,
|
||||
current_user: User,
|
||||
) -> Brand:
|
||||
stmt = select(Brand).where(
|
||||
Brand.id == brand_id,
|
||||
Brand.user_id == current_user.id,
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
if not brand:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="品牌不存在",
|
||||
)
|
||||
return brand
|
||||
|
||||
|
||||
async def _get_brand_diagnosis_data(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
brand: Brand,
|
||||
) -> dict:
|
||||
from app.models.query import Query as QueryModel
|
||||
from app.models.citation_record import CitationRecord
|
||||
from app.models.competitor import Competitor
|
||||
from app.schemas.scoring import CitationResult
|
||||
from app.services.analysis.sentiment_service import get_sentiment_service
|
||||
|
||||
queries_stmt = select(QueryModel).where(
|
||||
QueryModel.user_id == user_id,
|
||||
QueryModel.target_brand == brand.name,
|
||||
)
|
||||
queries_result = await db.execute(queries_stmt)
|
||||
queries = list(queries_result.scalars().all())
|
||||
|
||||
if not queries:
|
||||
scoring_service = ScoringService()
|
||||
empty_result = scoring_service.calculate_v2(
|
||||
mentioned_count=0,
|
||||
total_queries=0,
|
||||
positions=[],
|
||||
sentiment_counts={"positive": 0, "neutral": 0, "negative": 0},
|
||||
citations=[],
|
||||
brand_mentions=0,
|
||||
competitor_mentions={},
|
||||
)
|
||||
return empty_result.to_dict()
|
||||
|
||||
query_ids = [q.id for q in queries]
|
||||
|
||||
citations_stmt = select(CitationRecord).where(
|
||||
CitationRecord.query_id.in_(query_ids),
|
||||
)
|
||||
citations_result = await db.execute(citations_stmt)
|
||||
all_citations = list(citations_result.scalars().all())
|
||||
|
||||
total_queries = len(all_citations)
|
||||
brand_citations = [c for c in all_citations if c.cited]
|
||||
|
||||
competitor_stmt = select(Competitor).where(Competitor.brand_id == brand.id)
|
||||
competitor_result = await db.execute(competitor_stmt)
|
||||
competitors = list(competitor_result.scalars().all())
|
||||
competitor_names = [c.name for c in competitors]
|
||||
|
||||
competitor_mentions: dict[str, int] = {}
|
||||
for comp_name in competitor_names:
|
||||
count = sum(
|
||||
1 for c in all_citations
|
||||
if c.cited and c.competitor_brands
|
||||
and comp_name in c.competitor_brands
|
||||
)
|
||||
if count > 0:
|
||||
competitor_mentions[comp_name] = count
|
||||
|
||||
sentiment_service = get_sentiment_service()
|
||||
sentiment_counts = {"positive": 0, "neutral": 0, "negative": 0}
|
||||
for citation in brand_citations:
|
||||
if citation.sentiment and citation.sentiment in ("positive", "neutral", "negative"):
|
||||
sentiment_counts[citation.sentiment] += 1
|
||||
else:
|
||||
content = citation.raw_response or citation.citation_text or ""
|
||||
if content.strip():
|
||||
try:
|
||||
result = await sentiment_service.analyze(
|
||||
brand_name=brand.name,
|
||||
content=content,
|
||||
)
|
||||
sentiment_counts[result.sentiment] += 1
|
||||
except Exception:
|
||||
sentiment_counts["neutral"] += 1
|
||||
else:
|
||||
sentiment_counts["neutral"] += 1
|
||||
|
||||
citation_results = [
|
||||
CitationResult(
|
||||
cited=c.cited,
|
||||
position=c.citation_position,
|
||||
citation_text=c.citation_text,
|
||||
sentiment="neutral",
|
||||
confidence=c.confidence or 0.0,
|
||||
)
|
||||
for c in brand_citations
|
||||
]
|
||||
|
||||
positions = [c.citation_position for c in brand_citations if c.cited]
|
||||
|
||||
scoring_service = ScoringService()
|
||||
v2_result = scoring_service.calculate_v2(
|
||||
mentioned_count=len(brand_citations),
|
||||
total_queries=total_queries,
|
||||
positions=positions,
|
||||
sentiment_counts=sentiment_counts,
|
||||
citations=citation_results,
|
||||
brand_mentions=len(brand_citations),
|
||||
competitor_mentions=competitor_mentions,
|
||||
)
|
||||
|
||||
return v2_result.to_dict()
|
||||
|
||||
|
||||
@router.post("/advise", response_model=SchemaSuggestionList)
|
||||
async def generate_schema_advise(
|
||||
request: SchemaAdviseRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
brand = await _get_brand_with_access(request.brand_id, db, current_user)
|
||||
|
||||
diagnosis_data = await _get_brand_diagnosis_data(db, current_user.id, brand)
|
||||
|
||||
brand_info = {
|
||||
"name": brand.name,
|
||||
"website": brand.website or "",
|
||||
"industry": brand.industry or "",
|
||||
}
|
||||
|
||||
service = SchemaAdvisorService()
|
||||
suggestions = await service.generate_suggestions(
|
||||
db=db,
|
||||
brand_id=brand.id,
|
||||
diagnosis_data=diagnosis_data,
|
||||
brand_info=brand_info,
|
||||
target_url=request.target_url,
|
||||
focus_dimensions=request.focus_dimensions,
|
||||
)
|
||||
|
||||
return SchemaSuggestionList(
|
||||
suggestions=[SchemaSuggestionResponse.model_validate(s) for s in suggestions],
|
||||
total=len(suggestions),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/brand/{brand_id}", response_model=SchemaSuggestionList)
|
||||
async def get_brand_schema_suggestions(
|
||||
brand_id: uuid.UUID,
|
||||
status_filter: str | None = Query(None, alias="status", description="按状态筛选"),
|
||||
schema_type: str | None = Query(None, description="按Schema类型筛选"),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await _get_brand_with_access(brand_id, db, current_user)
|
||||
|
||||
service = SchemaAdvisorService()
|
||||
suggestions, total = await service.get_suggestions(
|
||||
db=db,
|
||||
brand_id=brand_id,
|
||||
status_filter=status_filter,
|
||||
schema_type=schema_type,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return SchemaSuggestionList(
|
||||
suggestions=[SchemaSuggestionResponse.model_validate(s) for s in suggestions],
|
||||
total=total,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{suggestion_id}", response_model=SchemaSuggestionResponse)
|
||||
async def get_schema_suggestion_detail(
|
||||
suggestion_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
service = SchemaAdvisorService()
|
||||
suggestion = await service.get_suggestion_by_id(db, suggestion_id)
|
||||
if not suggestion:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="建议不存在",
|
||||
)
|
||||
brand = await _get_brand_with_access(suggestion.brand_id, db, current_user)
|
||||
return SchemaSuggestionResponse.model_validate(suggestion)
|
||||
|
||||
|
||||
@router.put("/{suggestion_id}/status", response_model=SchemaSuggestionResponse)
|
||||
async def update_schema_suggestion_status(
|
||||
suggestion_id: uuid.UUID,
|
||||
status_update: SchemaStatusUpdateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
valid_statuses = {"pending", "applied", "dismissed"}
|
||||
if status_update.status not in valid_statuses:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"无效的状态值,支持: {', '.join(valid_statuses)}",
|
||||
)
|
||||
|
||||
service = SchemaAdvisorService()
|
||||
suggestion = await service.get_suggestion_by_id(db, suggestion_id)
|
||||
if not suggestion:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="建议不存在",
|
||||
)
|
||||
await _get_brand_with_access(suggestion.brand_id, db, current_user)
|
||||
|
||||
updated = await service.update_status(db, suggestion_id, status_update.status)
|
||||
return SchemaSuggestionResponse.model_validate(updated)
|
||||
|
|
@ -27,8 +27,8 @@ from app.schemas.scoring import (
|
|||
DimensionCompareItem,
|
||||
CitationResult,
|
||||
)
|
||||
from app.services.scoring_service import ScoringService, get_health_level
|
||||
from app.services.sentiment_service import get_sentiment_service
|
||||
from app.services.scoring.scoring_service import ScoringService, get_health_level
|
||||
from app.services.analysis.sentiment_service import get_sentiment_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -446,7 +446,7 @@ async def get_brand_score(
|
|||
|
||||
# 异步触发告警检测(不影响主流程)
|
||||
try:
|
||||
from app.services.alert_engine import AlertEngine
|
||||
from app.services.alert.alert_engine import AlertEngine
|
||||
alert_engine = AlertEngine(db)
|
||||
|
||||
# 获取当前已有提及的平台集合
|
||||
|
|
|
|||
|
|
@ -0,0 +1,388 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.models.brand import Brand
|
||||
from app.models.geo_plan import GeoPlan, GeoPlanAction
|
||||
from app.schemas.geo_plan import (
|
||||
GeoPlanGenerateRequest,
|
||||
GeoPlanResponse,
|
||||
GeoPlanListResponse,
|
||||
GeoPlanActionResponse,
|
||||
GeoPlanActionUpdateStatus,
|
||||
GeoPlanActionExecuteResponse,
|
||||
)
|
||||
from app.services.scoring.brand_scoring_data_service import get_brand_scoring_data_service
|
||||
from app.services.strategy.geo_plan_generator import generate_geo_plan
|
||||
from app.services.content.content_generation_service import ContentGenerationService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _get_brand_with_access(
|
||||
brand_id: uuid.UUID,
|
||||
db: AsyncSession,
|
||||
current_user: User,
|
||||
) -> Brand:
|
||||
stmt = select(Brand).where(
|
||||
Brand.id == brand_id,
|
||||
Brand.user_id == current_user.id,
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
||||
if not brand:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="品牌不存在",
|
||||
)
|
||||
return brand
|
||||
|
||||
|
||||
async def _get_brand_scoring_data(
|
||||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
brand: Brand,
|
||||
) -> tuple:
|
||||
scoring_data_service = get_brand_scoring_data_service()
|
||||
scoring_data = await scoring_data_service.get_brand_scoring_data(db, user_id, brand)
|
||||
return (
|
||||
scoring_data.v2_result,
|
||||
scoring_data.competitor_data,
|
||||
scoring_data.sentiment_counts,
|
||||
scoring_data.platform_scores,
|
||||
scoring_data.total_queries,
|
||||
scoring_data.mentioned_count,
|
||||
)
|
||||
|
||||
|
||||
def _plan_to_response(plan: GeoPlan) -> GeoPlanResponse:
|
||||
actions = [
|
||||
GeoPlanActionResponse(
|
||||
id=action.id,
|
||||
plan_id=action.plan_id,
|
||||
action_type=action.action_type,
|
||||
title=action.title,
|
||||
description=action.description,
|
||||
reason=action.reason,
|
||||
priority=action.priority,
|
||||
status=action.status,
|
||||
target_keyword=action.target_keyword,
|
||||
target_platform=action.target_platform,
|
||||
content_style=action.content_style,
|
||||
estimated_impact=action.estimated_impact,
|
||||
difficulty=action.difficulty,
|
||||
execution_params=action.execution_params,
|
||||
sort_order=action.sort_order,
|
||||
completed_at=action.completed_at,
|
||||
created_at=action.created_at,
|
||||
)
|
||||
for action in sorted(plan.actions, key=lambda a: a.sort_order)
|
||||
]
|
||||
return GeoPlanResponse(
|
||||
id=plan.id,
|
||||
brand_id=plan.brand_id,
|
||||
title=plan.title,
|
||||
status=plan.status,
|
||||
diagnosis_score=plan.diagnosis_score,
|
||||
target_score=plan.target_score,
|
||||
estimated_weeks=plan.estimated_weeks,
|
||||
plan_data=plan.plan_data,
|
||||
source=plan.source,
|
||||
actions=actions,
|
||||
created_at=plan.created_at,
|
||||
updated_at=plan.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/generate", response_model=GeoPlanResponse)
|
||||
async def generate_geo_plan_endpoint(
|
||||
request: GeoPlanGenerateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
brand = await _get_brand_with_access(request.brand_id, db, current_user)
|
||||
|
||||
(
|
||||
v2_result,
|
||||
competitor_data,
|
||||
sentiment_data,
|
||||
platform_scores,
|
||||
total_queries,
|
||||
mentioned_count,
|
||||
) = await _get_brand_scoring_data(db, current_user.id, brand)
|
||||
|
||||
target_score = request.target_score or 75
|
||||
|
||||
plan_data = await generate_geo_plan(
|
||||
brand_name=brand.name,
|
||||
scoring_result=v2_result,
|
||||
target_score=target_score,
|
||||
total_queries=total_queries,
|
||||
platform_scores=platform_scores,
|
||||
competitor_data=competitor_data,
|
||||
)
|
||||
|
||||
from app.config import settings
|
||||
source = "llm" if (settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY) else "rule"
|
||||
|
||||
organization_id = current_user.id
|
||||
org_stmt = select(func.count()).select_from(
|
||||
select(1).where(True).subquery()
|
||||
)
|
||||
|
||||
db_plan = GeoPlan(
|
||||
organization_id=organization_id,
|
||||
brand_id=brand.id,
|
||||
title=plan_data.title,
|
||||
status="draft",
|
||||
diagnosis_score=int(round(v2_result.overall_score)),
|
||||
target_score=target_score,
|
||||
estimated_weeks=plan_data.estimated_weeks,
|
||||
plan_data={
|
||||
"weekly_plan": plan_data.weekly_plan,
|
||||
},
|
||||
source=source,
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.add(db_plan)
|
||||
await db.flush()
|
||||
|
||||
for idx, action_item in enumerate(plan_data.actions):
|
||||
db_action = GeoPlanAction(
|
||||
plan_id=db_plan.id,
|
||||
action_type=action_item.action_type,
|
||||
title=action_item.title,
|
||||
description=action_item.description,
|
||||
reason=action_item.reason,
|
||||
priority=action_item.priority,
|
||||
status="pending",
|
||||
target_keyword=action_item.target_keyword,
|
||||
target_platform=action_item.target_platform,
|
||||
content_style=action_item.content_style,
|
||||
estimated_impact=action_item.estimated_impact,
|
||||
difficulty=action_item.difficulty,
|
||||
execution_params=action_item.execution_params,
|
||||
sort_order=idx,
|
||||
)
|
||||
db.add(db_action)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_plan)
|
||||
|
||||
stmt = (
|
||||
select(GeoPlan)
|
||||
.options(selectinload(GeoPlanAction.plan))
|
||||
.where(GeoPlan.id == db_plan.id)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
db_plan = result.scalar_one()
|
||||
|
||||
action_stmt = select(GeoPlanAction).where(
|
||||
GeoPlanAction.plan_id == db_plan.id
|
||||
).order_by(GeoPlanAction.sort_order)
|
||||
action_result = await db.execute(action_stmt)
|
||||
db_plan.actions = list(action_result.scalars().all())
|
||||
|
||||
return _plan_to_response(db_plan)
|
||||
|
||||
|
||||
@router.get("/brand/{brand_id}", response_model=GeoPlanListResponse)
|
||||
async def get_brand_plans(
|
||||
brand_id: uuid.UUID,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await _get_brand_with_access(brand_id, db, current_user)
|
||||
|
||||
count_stmt = select(func.count()).select_from(GeoPlan).where(
|
||||
GeoPlan.brand_id == brand_id,
|
||||
)
|
||||
count_result = await db.execute(count_stmt)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
stmt = (
|
||||
select(GeoPlan)
|
||||
.where(GeoPlan.brand_id == brand_id)
|
||||
.order_by(GeoPlan.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
plans = list(result.scalars().all())
|
||||
|
||||
plan_responses = []
|
||||
for plan in plans:
|
||||
action_stmt = select(GeoPlanAction).where(
|
||||
GeoPlanAction.plan_id == plan.id
|
||||
).order_by(GeoPlanAction.sort_order)
|
||||
action_result = await db.execute(action_stmt)
|
||||
plan.actions = list(action_result.scalars().all())
|
||||
plan_responses.append(_plan_to_response(plan))
|
||||
|
||||
return GeoPlanListResponse(plans=plan_responses, total=total)
|
||||
|
||||
|
||||
@router.get("/{plan_id}", response_model=GeoPlanResponse)
|
||||
async def get_plan_detail(
|
||||
plan_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
stmt = select(GeoPlan).where(GeoPlan.id == plan_id)
|
||||
result = await db.execute(stmt)
|
||||
plan = result.scalar_one_or_none()
|
||||
|
||||
if not plan:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="方案不存在",
|
||||
)
|
||||
|
||||
brand = await _get_brand_with_access(plan.brand_id, db, current_user)
|
||||
|
||||
action_stmt = select(GeoPlanAction).where(
|
||||
GeoPlanAction.plan_id == plan.id
|
||||
).order_by(GeoPlanAction.sort_order)
|
||||
action_result = await db.execute(action_stmt)
|
||||
plan.actions = list(action_result.scalars().all())
|
||||
|
||||
return _plan_to_response(plan)
|
||||
|
||||
|
||||
@router.put("/actions/{action_id}/status", response_model=GeoPlanActionResponse)
|
||||
async def update_action_status(
|
||||
action_id: uuid.UUID,
|
||||
status_update: GeoPlanActionUpdateStatus,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
valid_statuses = {"pending", "in_progress", "completed", "skipped"}
|
||||
if status_update.status not in valid_statuses:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"无效的状态值,支持: {', '.join(valid_statuses)}",
|
||||
)
|
||||
|
||||
stmt = select(GeoPlanAction).where(GeoPlanAction.id == action_id)
|
||||
result = await db.execute(stmt)
|
||||
action = result.scalar_one_or_none()
|
||||
|
||||
if not action:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="行动项不存在",
|
||||
)
|
||||
|
||||
plan_stmt = select(GeoPlan).where(GeoPlan.id == action.plan_id)
|
||||
plan_result = await db.execute(plan_stmt)
|
||||
plan = plan_result.scalar_one()
|
||||
|
||||
await _get_brand_with_access(plan.brand_id, db, current_user)
|
||||
|
||||
action.status = status_update.status
|
||||
if status_update.status == "completed":
|
||||
action.completed_at = datetime.now()
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(action)
|
||||
|
||||
return GeoPlanActionResponse(
|
||||
id=action.id,
|
||||
plan_id=action.plan_id,
|
||||
action_type=action.action_type,
|
||||
title=action.title,
|
||||
description=action.description,
|
||||
reason=action.reason,
|
||||
priority=action.priority,
|
||||
status=action.status,
|
||||
target_keyword=action.target_keyword,
|
||||
target_platform=action.target_platform,
|
||||
content_style=action.content_style,
|
||||
estimated_impact=action.estimated_impact,
|
||||
difficulty=action.difficulty,
|
||||
execution_params=action.execution_params,
|
||||
sort_order=action.sort_order,
|
||||
completed_at=action.completed_at,
|
||||
created_at=action.created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/actions/{action_id}/execute", response_model=GeoPlanActionExecuteResponse)
|
||||
async def execute_action(
|
||||
action_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
stmt = select(GeoPlanAction).where(GeoPlanAction.id == action_id)
|
||||
result = await db.execute(stmt)
|
||||
action = result.scalar_one_or_none()
|
||||
|
||||
if not action:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="行动项不存在",
|
||||
)
|
||||
|
||||
plan_stmt = select(GeoPlan).where(GeoPlan.id == action.plan_id)
|
||||
plan_result = await db.execute(plan_stmt)
|
||||
plan = plan_result.scalar_one()
|
||||
|
||||
brand = await _get_brand_with_access(plan.brand_id, db, current_user)
|
||||
|
||||
if action.action_type not in ("content_creation", "content_optimization"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"行动类型 '{action.action_type}' 不支持一键执行,仅支持 content_creation 和 content_optimization",
|
||||
)
|
||||
|
||||
params = action.execution_params or {}
|
||||
keyword = params.get("keyword", action.target_keyword or brand.name)
|
||||
platform = params.get("platform", action.target_platform or "通用")
|
||||
style = params.get("style", action.content_style or "专业严谨")
|
||||
word_count = params.get("word_count", 2000)
|
||||
knowledge_base_ids = params.get("knowledge_base_ids")
|
||||
|
||||
content_service = ContentGenerationService()
|
||||
|
||||
try:
|
||||
gen_result = await content_service.generate_content(
|
||||
keyword=keyword,
|
||||
brand_name=brand.name,
|
||||
platform=platform,
|
||||
content_style=style,
|
||||
word_count=word_count,
|
||||
knowledge_base_ids=knowledge_base_ids,
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
org_id=str(plan.organization_id),
|
||||
run_deai=True,
|
||||
run_geo=True,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"内容生成失败: {str(e)}",
|
||||
)
|
||||
|
||||
action.status = "completed"
|
||||
action.completed_at = datetime.now()
|
||||
await db.commit()
|
||||
await db.refresh(action)
|
||||
|
||||
content_id = gen_result.get("content_id")
|
||||
|
||||
return GeoPlanActionExecuteResponse(
|
||||
action_id=action.id,
|
||||
content_id=content_id,
|
||||
message="内容生成成功" if content_id else "内容生成完成(未持久化)",
|
||||
)
|
||||
|
|
@ -22,9 +22,9 @@ from app.schemas.suggestion import (
|
|||
SuggestionHistoryResponse,
|
||||
)
|
||||
from app.schemas.scoring import CitationResult
|
||||
from app.services.scoring_service import ScoringService
|
||||
from app.services.sentiment_service import get_sentiment_service
|
||||
from app.services.optimization_advisor import (
|
||||
from app.services.scoring.scoring_service import ScoringService
|
||||
from app.services.analysis.sentiment_service import get_sentiment_service
|
||||
from app.services.advisor.optimization_advisor import (
|
||||
generate_suggestions,
|
||||
build_context_from_scoring_result,
|
||||
)
|
||||
|
|
@ -163,7 +163,7 @@ async def _get_brand_scoring_data(
|
|||
)
|
||||
|
||||
# 计算平台评分
|
||||
from app.api.dashboard import REQUIRED_PLATFORMS
|
||||
from app.services.scoring.brand_scoring_data_service import REQUIRED_PLATFORMS
|
||||
platform_scores: dict[str, float] = {}
|
||||
for platform in REQUIRED_PLATFORMS:
|
||||
platform_citations = [c for c in all_citations if c.platform == platform]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,124 @@
|
|||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.models.brand import Brand
|
||||
from app.models.trend_insight import TrendInsight
|
||||
from app.schemas.trend_insight import (
|
||||
TrendInsightRequest,
|
||||
TrendInsightResponse,
|
||||
TrendInsightList,
|
||||
TrendSummary,
|
||||
)
|
||||
from app.services.trend.trend_analyzer_service import TrendAnalyzerService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _get_brand_with_access(
|
||||
brand_id: uuid.UUID,
|
||||
db: AsyncSession,
|
||||
current_user: User,
|
||||
) -> Brand:
|
||||
stmt = select(Brand).where(
|
||||
Brand.id == brand_id,
|
||||
Brand.user_id == current_user.id,
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
|
||||
if not brand:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="品牌不存在",
|
||||
)
|
||||
return brand
|
||||
|
||||
|
||||
@router.post("/insight", response_model=TrendInsightResponse)
|
||||
async def create_trend_insight(
|
||||
request: TrendInsightRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await _get_brand_with_access(request.brand_id, db, current_user)
|
||||
|
||||
service = TrendAnalyzerService(db)
|
||||
result = await service.analyze_trends(
|
||||
brand_id=request.brand_id,
|
||||
days=request.period_days,
|
||||
platforms=request.platforms,
|
||||
keywords=request.keywords,
|
||||
)
|
||||
|
||||
if result.get("status") == "insufficient_data":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=result.get("message", "数据不足"),
|
||||
)
|
||||
|
||||
insight_id = uuid.UUID(result["insight_id"])
|
||||
insight = await service.get_insight_by_id(insight_id)
|
||||
if insight is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="洞察创建失败",
|
||||
)
|
||||
return insight
|
||||
|
||||
|
||||
@router.get("/brand/{brand_id}", response_model=TrendInsightList)
|
||||
async def list_trend_insights(
|
||||
brand_id: uuid.UUID,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await _get_brand_with_access(brand_id, db, current_user)
|
||||
|
||||
service = TrendAnalyzerService(db)
|
||||
items, total = await service.get_insights(
|
||||
brand_id=brand_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
)
|
||||
return TrendInsightList(items=items, total=total)
|
||||
|
||||
|
||||
@router.get("/brand/{brand_id}/summary", response_model=TrendSummary)
|
||||
async def get_trend_summary(
|
||||
brand_id: uuid.UUID,
|
||||
period_days: int = Query(30, ge=7, le=365),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
await _get_brand_with_access(brand_id, db, current_user)
|
||||
|
||||
service = TrendAnalyzerService(db)
|
||||
summary = await service.get_summary(
|
||||
brand_id=brand_id,
|
||||
days=period_days,
|
||||
)
|
||||
return TrendSummary(**summary)
|
||||
|
||||
|
||||
@router.get("/{insight_id}", response_model=TrendInsightResponse)
|
||||
async def get_trend_insight_detail(
|
||||
insight_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
service = TrendAnalyzerService(db)
|
||||
insight = await service.get_insight_by_id(insight_id)
|
||||
if insight is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="洞察不存在",
|
||||
)
|
||||
return insight
|
||||
|
|
@ -64,6 +64,40 @@ class Settings(BaseSettings):
|
|||
# AI平台API调用频率限制(每分钟请求数)
|
||||
API_RATE_LIMIT_RPM: int = 10
|
||||
|
||||
# Payment Gateway Configuration
|
||||
WECHAT_PAY_MCH_ID: str = ""
|
||||
WECHAT_PAY_API_KEY: str = ""
|
||||
WECHAT_PAY_APP_ID: str = ""
|
||||
WECHAT_PAY_CERT_PATH: str = ""
|
||||
WECHAT_PAY_NOTIFY_URL: str = ""
|
||||
ALIPAY_APP_ID: str = ""
|
||||
ALIPAY_PRIVATE_KEY_PATH: str = ""
|
||||
ALIPAY_PUBLIC_KEY_PATH: str = ""
|
||||
ALIPAY_NOTIFY_URL: str = ""
|
||||
PAYMENT_MODE: str = "mock"
|
||||
|
||||
ZHIHU_CLIENT_ID: str = ""
|
||||
ZHIHU_CLIENT_SECRET: str = ""
|
||||
ZHIHU_ACCESS_TOKEN: str = ""
|
||||
TOUTIAO_APP_ID: str = ""
|
||||
TOUTIAO_APP_SECRET: str = ""
|
||||
TOUTIAO_ACCESS_TOKEN: str = ""
|
||||
WECHAT_OFFICIAL_APP_ID: str = ""
|
||||
WECHAT_OFFICIAL_APP_SECRET: str = ""
|
||||
SMTP_HOST: str = ""
|
||||
SMTP_PORT: int = 587
|
||||
SMTP_USER: str = ""
|
||||
SMTP_PASSWORD: str = ""
|
||||
SMTP_FROM_EMAIL: str = "noreply@geo-platform.com"
|
||||
SMTP_FROM_NAME: str = "GEO平台"
|
||||
EMAIL_MODE: str = "mock"
|
||||
SENDGRID_API_KEY: str = ""
|
||||
ALIYUN_MAIL_ACCESS_KEY: str = ""
|
||||
ALIYUN_MAIL_ACCESS_SECRET: str = ""
|
||||
ALIYUN_MAIL_REGION: str = "cn-hangzhou"
|
||||
|
||||
DISTRIBUTION_MODE: str = "mock"
|
||||
|
||||
@field_validator("JWT_SECRET")
|
||||
@classmethod
|
||||
def validate_jwt_secret(cls, v: str) -> str:
|
||||
|
|
|
|||
|
|
@ -43,14 +43,21 @@ from app.api.ai_engines import router as ai_engines_router
|
|||
from app.api.detection import router as detection_router
|
||||
from app.api.api_keys import router as api_keys_router
|
||||
from app.api.usage import router as usage_router
|
||||
from app.api.strategy import router as strategy_router
|
||||
from app.api.competitor_analysis import router as competitor_analysis_router
|
||||
from app.api.trends import router as trends_router
|
||||
from app.api.schema_advisor import router as schema_advisor_router
|
||||
from app.api.monitoring import router as monitoring_router
|
||||
from app.api.health_score import router as health_score_router
|
||||
from app.api.payments import router as payments_router
|
||||
from app.api.attribution import router as attribution_router
|
||||
from app.config import settings
|
||||
from app.database import engine, Base
|
||||
from app.schemas.common import ErrorResponse, ErrorCode
|
||||
from app.middleware.rate_limit import RateLimitMiddleware
|
||||
from app.middleware.logging_middleware import RequestLoggingMiddleware
|
||||
from app.middleware.request_id import RequestIdMiddleware
|
||||
from app.middleware.metrics import MetricsMiddleware
|
||||
from app.monitoring.middleware import MonitoringMiddleware
|
||||
from app.middleware.metrics import MetricsMiddleware, MonitoringMiddleware
|
||||
from app.database import get_db
|
||||
from app.workers.scheduler import query_scheduler
|
||||
|
||||
|
|
@ -59,7 +66,13 @@ from app.workers.scheduler import query_scheduler
|
|||
async def lifespan(app: FastAPI):
|
||||
import app.models
|
||||
|
||||
import app.monitoring
|
||||
import app.middleware.prometheus_metrics
|
||||
from app.middleware.prometheus_metrics import SERVICE_INFO
|
||||
import os
|
||||
SERVICE_INFO.info({
|
||||
"version": "1.0.0",
|
||||
"environment": os.getenv("ENVIRONMENT", "development"),
|
||||
})
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(text("SELECT 1"))
|
||||
|
|
@ -120,10 +133,14 @@ _allow_origins = [origin.strip() for origin in settings.CORS_ORIGINS.split(",")
|
|||
if not _allow_origins:
|
||||
_allow_origins = ["http://localhost:3000"]
|
||||
|
||||
import os
|
||||
|
||||
_is_dev = os.getenv("ENVIRONMENT", "development") == "development"
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=_allow_origins,
|
||||
allow_credentials=True,
|
||||
allow_origins=_allow_origins if not _is_dev else ["*"],
|
||||
allow_credentials=not _is_dev,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
|
@ -174,6 +191,14 @@ app.include_router(ai_engines_router, prefix="/api/v1/ai-engines", tags=["AI引
|
|||
app.include_router(detection_router, prefix="/api/v1/detection", tags=["定时检测任务"])
|
||||
app.include_router(api_keys_router, prefix="/api/v1/api-keys", tags=["API Key管理"])
|
||||
app.include_router(usage_router, prefix="/api/v1/usage", tags=["用量追踪"])
|
||||
app.include_router(strategy_router, prefix="/api/v1/strategy", tags=["GEO方案"])
|
||||
app.include_router(competitor_analysis_router, prefix="/api/v1/competitor", tags=["竞品分析"])
|
||||
app.include_router(schema_advisor_router, prefix="/api/v1/schema", tags=["Schema建议"])
|
||||
app.include_router(trends_router, prefix="/api/v1/trends", tags=["趋势洞察"])
|
||||
app.include_router(monitoring_router, prefix="/api/v1/monitoring", tags=["效果追踪"])
|
||||
app.include_router(health_score_router, prefix="/api/v1/public", tags=["公开API"])
|
||||
app.include_router(payments_router)
|
||||
app.include_router(attribution_router, prefix="/api/v1/attribution", tags=["效果归因"])
|
||||
|
||||
|
||||
@app.get("/health", tags=["可观测性"])
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import time
|
|||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from app.monitoring.metrics import (
|
||||
from app.middleware.prometheus_metrics import (
|
||||
AGENT_EXECUTIONS_TOTAL,
|
||||
AGENT_EXECUTION_DURATION_SECONDS,
|
||||
AGENT_RUNNING_TASKS,
|
||||
|
|
@ -2,7 +2,7 @@
|
|||
import time
|
||||
from typing import Optional
|
||||
|
||||
from app.monitoring.metrics import (
|
||||
from app.middleware.prometheus_metrics import (
|
||||
LLM_REQUESTS_TOTAL,
|
||||
LLM_REQUEST_DURATION_SECONDS,
|
||||
LLM_TOKENS_TOTAL,
|
||||
|
|
@ -1,9 +1,19 @@
|
|||
"""请求指标收集中间件:计时、慢请求告警、响应时间响应头。"""
|
||||
"""请求指标收集中间件:计时、慢请求告警、响应时间响应头、Prometheus指标收集。
|
||||
|
||||
合并自原 middleware/metrics.py(MetricsMiddleware)和 monitoring/middleware.py(MonitoringMiddleware)。
|
||||
"""
|
||||
import time
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from app.middleware.prometheus_metrics import (
|
||||
API_REQUESTS_TOTAL,
|
||||
API_REQUEST_DURATION_SECONDS,
|
||||
API_REQUESTS_IN_PROGRESS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("geo.metrics")
|
||||
|
||||
|
|
@ -11,14 +21,14 @@ logger = logging.getLogger("geo.metrics")
|
|||
SLOW_REQUEST_THRESHOLD = 1.0
|
||||
|
||||
# 跳过指标收集的路径前缀(健康检查等高频低价值路径)
|
||||
_SKIP_PATHS = {"/health", "/ready", "/docs", "/openapi.json", "/favicon.ico"}
|
||||
_SKIP_PATHS = {"/health", "/ready", "/docs", "/openapi.json", "/favicon.ico", "/metrics"}
|
||||
|
||||
|
||||
class MetricsMiddleware(BaseHTTPMiddleware):
|
||||
"""记录每个 HTTP 请求的耗时,并:
|
||||
- 在响应头写入 X-Response-Time
|
||||
- 对超过阈值的慢请求输出 WARNING 日志(携带结构化字段)
|
||||
- 预留 Sentry / Prometheus 集成点(TODO 注释标注)
|
||||
- 预留 Sentry 集成点(TODO 注释标注)
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
|
|
@ -51,10 +61,82 @@ class MetricsMiddleware(BaseHTTPMiddleware):
|
|||
else:
|
||||
logger.debug("Request completed", extra=log_extra)
|
||||
|
||||
# TODO: 集成 Prometheus Counter/Histogram
|
||||
# metrics_registry.http_request_duration.observe(duration, labels={...})
|
||||
|
||||
# TODO: 集成 Sentry 性能监控
|
||||
# if sentry_sdk: sentry_sdk.set_measurement("response_time_ms", duration_ms)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class MonitoringMiddleware(BaseHTTPMiddleware):
|
||||
"""API监控中间件 — 收集 Prometheus 指标。
|
||||
|
||||
- 记录请求总数、耗时分布、活跃请求数
|
||||
- 自动规范化端点标签(替换路径中的ID参数)
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# 跳过排除路径
|
||||
if request.url.path in _SKIP_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
# 提取端点标识(用于指标标签)
|
||||
endpoint = self._get_endpoint_label(request)
|
||||
|
||||
# 增加活跃请求计数
|
||||
API_REQUESTS_IN_PROGRESS.labels(
|
||||
method=request.method,
|
||||
endpoint=endpoint
|
||||
).inc()
|
||||
|
||||
# 记录开始时间
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
# 执行请求
|
||||
response = await call_next(request)
|
||||
status_code = response.status_code
|
||||
except Exception as e:
|
||||
status_code = 500
|
||||
raise
|
||||
finally:
|
||||
# 计算耗时
|
||||
duration = time.perf_counter() - start_time
|
||||
|
||||
# 记录指标
|
||||
API_REQUESTS_TOTAL.labels(
|
||||
method=request.method,
|
||||
endpoint=endpoint,
|
||||
status=str(status_code)
|
||||
).inc()
|
||||
|
||||
API_REQUEST_DURATION_SECONDS.labels(
|
||||
method=request.method,
|
||||
endpoint=endpoint
|
||||
).observe(duration)
|
||||
|
||||
# 减少活跃请求计数
|
||||
API_REQUESTS_IN_PROGRESS.labels(
|
||||
method=request.method,
|
||||
endpoint=endpoint
|
||||
).dec()
|
||||
|
||||
return response
|
||||
|
||||
def _get_endpoint_label(self, request: Request) -> str:
|
||||
"""提取端点标签"""
|
||||
path = request.url.path
|
||||
|
||||
# 规范化路径(替换ID等参数)
|
||||
parts = path.strip("/").split("/")
|
||||
|
||||
# 处理常见模式:/api/v1/resources/{id}
|
||||
if len(parts) >= 4 and parts[0] == "api":
|
||||
resource = parts[2] if len(parts) > 2 else "unknown"
|
||||
action = parts[3] if len(parts) > 3 else "list"
|
||||
|
||||
# 映射到规范标签
|
||||
if action.isdigit():
|
||||
return f"{resource}_detail"
|
||||
return f"{resource}_{action}"
|
||||
|
||||
return "other"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,57 @@
|
|||
from fastapi import Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.services.subscription import PLANS
|
||||
|
||||
|
||||
class SubscriptionEnforcement:
|
||||
@staticmethod
|
||||
def require_plan(*allowed_plans: str):
|
||||
async def _check(current_user: User = Depends(get_current_user)):
|
||||
user_plan = getattr(current_user, "plan", "free") or "free"
|
||||
if user_plan not in allowed_plans:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"message": f"此功能需要 {allowed_plans[0]} 及以上套餐",
|
||||
"required_plan": allowed_plans[0],
|
||||
"current_plan": user_plan,
|
||||
"upgrade_url": "/api/v1/subscriptions/plans",
|
||||
},
|
||||
)
|
||||
return current_user
|
||||
return _check
|
||||
|
||||
@staticmethod
|
||||
def check_quota(resource: str):
|
||||
async def _check(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
user_plan = getattr(current_user, "plan", "free") or "free"
|
||||
plan_config = PLANS.get(user_plan, PLANS["free"])
|
||||
|
||||
if resource == "queries":
|
||||
limit = plan_config.get("max_queries", 3)
|
||||
current_usage = getattr(current_user, "max_queries", limit) or limit
|
||||
remaining = max(0, limit - current_usage)
|
||||
elif resource == "brands":
|
||||
limit = plan_config.get("max_brands", 1)
|
||||
remaining = limit if limit == -1 else max(0, limit)
|
||||
elif resource == "alerts":
|
||||
limit = plan_config.get("max_alerts_per_month", 0)
|
||||
remaining = limit if limit == -1 else max(0, limit)
|
||||
else:
|
||||
remaining = 0
|
||||
|
||||
return {
|
||||
"user_id": current_user.id,
|
||||
"plan": user_plan,
|
||||
"resource": resource,
|
||||
"remaining": remaining,
|
||||
"unlimited": remaining == -1,
|
||||
}
|
||||
return _check
|
||||
|
|
@ -33,6 +33,14 @@ from app.models.alert import Alert
|
|||
from app.models.alert_setting import AlertSetting
|
||||
from app.models.detection_task import DetectionTask
|
||||
from app.models.usage_record import UsageRecord
|
||||
from app.models.geo_plan import GeoPlan, GeoPlanAction
|
||||
from app.models.trend_insight import TrendInsight
|
||||
from app.models.competitor_insight import CompetitorInsight
|
||||
from app.models.schema_suggestion import SchemaSuggestion
|
||||
from app.models.monitoring import MonitoringRecord, ContentBaseline
|
||||
from app.models.diagnosis_record import DiagnosisRecord
|
||||
from app.models.payment_order import PaymentOrder
|
||||
from app.models.attribution_record import AttributionRecord
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
|
|
@ -76,4 +84,14 @@ __all__ = [
|
|||
"AlertSetting",
|
||||
"DetectionTask",
|
||||
"UsageRecord",
|
||||
"GeoPlan",
|
||||
"GeoPlanAction",
|
||||
"CompetitorInsight",
|
||||
"SchemaSuggestion",
|
||||
"TrendInsight",
|
||||
"MonitoringRecord",
|
||||
"ContentBaseline",
|
||||
"DiagnosisRecord",
|
||||
"PaymentOrder",
|
||||
"AttributionRecord",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -74,8 +74,8 @@ class AgentConfig(Base):
|
|||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_by: Mapped[uuid.UUID | None] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
updated_by: Mapped[str | None] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,65 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, Float, Integer, ForeignKey, Index, func, Text
|
||||
from sqlalchemy import Uuid
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base, JSONType
|
||||
|
||||
|
||||
class AttributionRecord(Base):
|
||||
__tablename__ = "attribution_records"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Text,
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
brand_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("brands.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
content_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("contents.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
baseline_score: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
current_score: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
score_delta: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
attribution_window_days: Mapped[int] = mapped_column(
|
||||
Integer, server_default="28", nullable=False,
|
||||
)
|
||||
published_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||
window_end_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(20), server_default="tracking", nullable=False,
|
||||
)
|
||||
attributed_dimensions: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
roi_percentage: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
brand: Mapped["Brand"] = relationship("Brand")
|
||||
content: Mapped["Content | None"] = relationship("Content")
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_attribution_records_brand_id", "brand_id"),
|
||||
Index("idx_attribution_records_user_id", "user_id"),
|
||||
Index("idx_attribution_records_status", "status"),
|
||||
Index("idx_attribution_records_content_id", "content_id"),
|
||||
)
|
||||
|
|
@ -62,7 +62,12 @@ class Brand(Base):
|
|||
"Suggestion", back_populates="brand", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
schema_suggestions: Mapped[list["SchemaSuggestion"]] = relationship(
|
||||
"SchemaSuggestion", back_populates="brand", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
# Import at bottom to avoid circular import
|
||||
from app.models.competitor import Competitor # noqa: E402, F401
|
||||
from app.models.suggestion import Suggestion # noqa: E402, F401
|
||||
from app.models.schema_suggestion import SchemaSuggestion # noqa: E402, F401
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from sqlalchemy import Uuid, JSON
|
|||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
from app.utils.text import sanitize_raw_response
|
||||
|
||||
|
||||
class CitationRecord(Base):
|
||||
|
|
@ -75,3 +76,40 @@ class CitationRecord(Base):
|
|||
Index("idx_citation_records_queried_at", "queried_at"),
|
||||
Index("idx_citation_records_platform", "platform"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_citation_result(
|
||||
cls,
|
||||
query_id: uuid.UUID,
|
||||
platform: str,
|
||||
result: dict,
|
||||
) -> "CitationRecord":
|
||||
"""从引用检测结果字典创建 CitationRecord 实例
|
||||
|
||||
统一处理字段映射、默认值和 raw_response / ai_response_text 的清理。
|
||||
|
||||
Args:
|
||||
query_id: 关联的查询 ID
|
||||
platform: 平台名称
|
||||
result: 引用检测结果字典
|
||||
|
||||
Returns:
|
||||
CitationRecord 实例(未持久化)
|
||||
"""
|
||||
return cls(
|
||||
query_id=query_id,
|
||||
platform=platform,
|
||||
cited=result.get("cited", False),
|
||||
citation_position=result.get("position"),
|
||||
citation_text=result.get("citation_text"),
|
||||
competitor_brands=result.get("competitor_brands", []),
|
||||
raw_response=sanitize_raw_response(result.get("raw_response", "")),
|
||||
confidence=result.get("confidence"),
|
||||
match_type=result.get("match_type"),
|
||||
# 引用源分析字段
|
||||
data_source=result.get("data_source"),
|
||||
source_urls=result.get("source_urls"),
|
||||
source_titles=result.get("source_titles"),
|
||||
citation_contexts=result.get("citation_contexts"),
|
||||
ai_response_text=sanitize_raw_response(result.get("ai_response_text", "")),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,64 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, Float, Integer, ForeignKey, Index, func
|
||||
from sqlalchemy import Uuid
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.types import TypeDecorator, JSON
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class JSONType(TypeDecorator):
|
||||
impl = JSON
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(JSONB())
|
||||
return dialect.type_descriptor(JSON())
|
||||
|
||||
|
||||
class CompetitorInsight(Base):
|
||||
__tablename__ = "competitor_insights"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
brand_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("brands.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
competitor_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
analysis_type: Mapped[str] = mapped_column(
|
||||
String(50), nullable=False,
|
||||
)
|
||||
insight_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
citation_count_brand: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
citation_count_competitor: Mapped[int] = mapped_column(Integer, default=0, nullable=False)
|
||||
sentiment_brand: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
sentiment_competitor: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
platform_breakdown: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
gap_analysis: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
opportunity_areas: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
recommendations: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
confidence: Mapped[str] = mapped_column(String(20), default="medium", nullable=False)
|
||||
period_days: Mapped[int] = mapped_column(Integer, default=30, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_competitor_insights_brand_id", "brand_id"),
|
||||
Index("idx_competitor_insights_analysis_type", "analysis_type"),
|
||||
)
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, Uuid, JSON, Float, Text, ForeignKey, Index, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class DiagnosisRecord(Base):
|
||||
__tablename__ = "diagnosis_records"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
brand_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("brands.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
user_id: Mapped[uuid.UUID] = mapped_column(Uuid(as_uuid=True), nullable=False)
|
||||
diagnosis_type: Mapped[str] = mapped_column(
|
||||
String(20), default="geo", nullable=False
|
||||
)
|
||||
status: Mapped[str] = mapped_column(String(20), default="pending", nullable=False)
|
||||
overall_score: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
result_json: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
collection_metadata: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(), nullable=False
|
||||
)
|
||||
completed_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_diagnosis_records_brand_id", "brand_id"),
|
||||
Index("idx_diagnosis_records_user_id", "user_id"),
|
||||
Index("idx_diagnosis_records_status", "status"),
|
||||
Index("idx_diagnosis_records_created_at", "created_at"),
|
||||
)
|
||||
|
|
@ -0,0 +1,112 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, Integer, Text, ForeignKey, Index, func
|
||||
from sqlalchemy import Uuid
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base, JSONType
|
||||
|
||||
|
||||
class GeoPlan(Base):
|
||||
__tablename__ = "geo_plans"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
organization_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("organizations.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
brand_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("brands.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
title: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(20), server_default="draft", nullable=False,
|
||||
)
|
||||
diagnosis_score: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
target_score: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
estimated_weeks: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
plan_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
source: Mapped[str] = mapped_column(String(20), nullable=False, default="rule")
|
||||
created_by: Mapped[str | None] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
brand: Mapped["Brand"] = relationship("Brand")
|
||||
creator: Mapped["User | None"] = relationship(
|
||||
"User", foreign_keys=[created_by]
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_geo_plans_brand_id", "brand_id"),
|
||||
Index("idx_geo_plans_status", "status"),
|
||||
Index("idx_geo_plans_organization_id", "organization_id"),
|
||||
)
|
||||
|
||||
|
||||
class GeoPlanAction(Base):
|
||||
__tablename__ = "geo_plan_actions"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
plan_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("geo_plans.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
action_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
title: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
reason: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
priority: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(20), server_default="pending", nullable=False,
|
||||
)
|
||||
target_keyword: Mapped[str | None] = mapped_column(String(200), nullable=True)
|
||||
target_platform: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
content_style: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
estimated_impact: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
difficulty: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
execution_params: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
sort_order: Mapped[int] = mapped_column(
|
||||
Integer, server_default="0", nullable=False,
|
||||
)
|
||||
completed_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
plan: Mapped["GeoPlan"] = relationship("GeoPlan")
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_geo_plan_actions_plan_id", "plan_id"),
|
||||
Index("idx_geo_plan_actions_status", "status"),
|
||||
Index("idx_geo_plan_actions_priority", "priority"),
|
||||
)
|
||||
|
|
@ -36,8 +36,8 @@ class KnowledgeBase(Base):
|
|||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
document_count: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(20), server_default="active", nullable=False)
|
||||
created_by: Mapped[uuid.UUID | None] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
created_by: Mapped[str | None] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,100 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, Integer, Float, ForeignKey, Index, func
|
||||
from sqlalchemy import Uuid
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base, JSONType
|
||||
|
||||
|
||||
class MonitoringRecord(Base):
|
||||
__tablename__ = "monitoring_records"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
brand_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("brands.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
content_id: Mapped[str | None] = mapped_column(String(36), nullable=True)
|
||||
query_keywords: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
platform: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
baseline_citation_count: Mapped[int] = mapped_column(
|
||||
Integer, server_default="0", nullable=False,
|
||||
)
|
||||
baseline_sentiment: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
baseline_rank: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
current_citation_count: Mapped[int] = mapped_column(
|
||||
Integer, server_default="0", nullable=False,
|
||||
)
|
||||
current_sentiment: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
current_rank: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
change_type: Mapped[str | None] = mapped_column(String(20), nullable=True)
|
||||
change_details: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
check_interval_hours: Mapped[int] = mapped_column(
|
||||
Integer, server_default="24", nullable=False,
|
||||
)
|
||||
last_checked_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||
next_check_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(20), server_default="active", nullable=False,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
brand: Mapped["Brand"] = relationship("Brand")
|
||||
baselines: Mapped[list["ContentBaseline"]] = relationship(
|
||||
"ContentBaseline", back_populates="monitoring_record", cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_monitoring_records_brand_id", "brand_id"),
|
||||
Index("idx_monitoring_records_status", "status"),
|
||||
Index("idx_monitoring_records_next_check_at", "next_check_at"),
|
||||
)
|
||||
|
||||
|
||||
class ContentBaseline(Base):
|
||||
__tablename__ = "content_baselines"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
monitoring_record_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("monitoring_records.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
brand_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
keyword: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||
platform: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
citation_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
sentiment_score: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
rank_position: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
snapshot_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
recorded_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
monitoring_record: Mapped["MonitoringRecord"] = relationship(
|
||||
"MonitoringRecord", back_populates="baselines",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_content_baselines_monitoring_record_id", "monitoring_record_id"),
|
||||
)
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, Integer, Float, ForeignKey, Index, func
|
||||
from sqlalchemy import Uuid
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base, JSONType
|
||||
|
||||
|
||||
class MonitoringRecord(Base):
|
||||
__tablename__ = "monitoring_records"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
brand_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("brands.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
content_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("contents.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
query_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("queries.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
task_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(20), server_default="pending", nullable=False,
|
||||
)
|
||||
baseline_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
current_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
change_report: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
interval_hours: Mapped[int] = mapped_column(
|
||||
Integer, server_default="24", nullable=False,
|
||||
)
|
||||
last_checked_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||
next_check_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
brand: Mapped["Brand"] = relationship("Brand")
|
||||
user: Mapped["User"] = relationship("User")
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_monitoring_records_user_id", "user_id"),
|
||||
Index("idx_monitoring_records_brand_id", "brand_id"),
|
||||
Index("idx_monitoring_records_status", "status"),
|
||||
Index("idx_monitoring_records_next_check_at", "next_check_at"),
|
||||
)
|
||||
|
||||
|
||||
class ContentBaseline(Base):
|
||||
__tablename__ = "content_baselines"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
brand_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("brands.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
content_id: Mapped[uuid.UUID | None] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("contents.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
citation_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
positive_ratio: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||
avg_rank: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||
platform_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
recorded_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
brand: Mapped["Brand"] = relationship("Brand")
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_content_baselines_brand_id", "brand_id"),
|
||||
Index("idx_content_baselines_content_id", "content_id"),
|
||||
)
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, Float, DateTime, ForeignKey, func, Uuid
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base, JSONType
|
||||
|
||||
|
||||
class PaymentOrder(Base):
|
||||
__tablename__ = "payment_orders"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
plan: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
amount: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
currency: Mapped[str] = mapped_column(String(10), default="CNY")
|
||||
payment_provider: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
payment_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="pending")
|
||||
pay_url: Mapped[str | None] = mapped_column(String(1024), nullable=True)
|
||||
callback_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
paid_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, Text, DateTime, ForeignKey, Index, func, Float
|
||||
from sqlalchemy import Uuid
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base, JSONType
|
||||
|
||||
|
||||
class SchemaSuggestion(Base):
|
||||
__tablename__ = "schema_suggestions"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
brand_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("brands.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
schema_type: Mapped[str] = mapped_column(
|
||||
String(50), nullable=False,
|
||||
)
|
||||
target_url: Mapped[str | None] = mapped_column(
|
||||
String(500), nullable=True,
|
||||
)
|
||||
json_ld_template: Mapped[dict] = mapped_column(
|
||||
JSONType, nullable=False, default=dict,
|
||||
)
|
||||
json_ld_filled: Mapped[dict | None] = mapped_column(
|
||||
JSONType, nullable=True,
|
||||
)
|
||||
priority: Mapped[str] = mapped_column(
|
||||
String(20), nullable=False, default="medium",
|
||||
)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(20), nullable=False, default="pending",
|
||||
)
|
||||
diagnosis_dimensions: Mapped[dict | None] = mapped_column(
|
||||
JSONType, nullable=True,
|
||||
)
|
||||
implementation_difficulty: Mapped[str] = mapped_column(
|
||||
String(20), nullable=False, default="medium",
|
||||
)
|
||||
estimated_impact: Mapped[str | None] = mapped_column(
|
||||
Text, nullable=True,
|
||||
)
|
||||
validation_errors: Mapped[dict | None] = mapped_column(
|
||||
JSONType, nullable=True,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
brand: Mapped["Brand"] = relationship("Brand", back_populates="schema_suggestions")
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_schema_suggestions_brand_id", "brand_id"),
|
||||
Index("idx_schema_suggestions_status", "status"),
|
||||
Index("idx_schema_suggestions_schema_type", "schema_type"),
|
||||
Index("idx_schema_suggestions_brand_status", "brand_id", "status"),
|
||||
)
|
||||
|
||||
|
||||
from app.models.brand import Brand # noqa: E402, F401
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, Integer, Float, ForeignKey, Index, func, Text
|
||||
from sqlalchemy import Uuid, JSON
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class TrendInsight(Base):
|
||||
__tablename__ = "trend_insights"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
brand_id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
ForeignKey("brands.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
trend_type: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
keyword: Mapped[str | None] = mapped_column(String(200), nullable=True)
|
||||
platform: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
||||
period_start: Mapped[datetime] = mapped_column(nullable=False)
|
||||
period_end: Mapped[datetime] = mapped_column(nullable=False)
|
||||
data_points: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||
change_rate: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
absolute_change: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
sentiment_trend: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||
cause_analysis: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
recommendations: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.5)
|
||||
severity: Mapped[str] = mapped_column(
|
||||
String(20), nullable=False, server_default="info",
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_trend_insights_brand_id", "brand_id"),
|
||||
Index("idx_trend_insights_trend_type", "trend_type"),
|
||||
Index("idx_trend_insights_created_at", "created_at"),
|
||||
Index("idx_trend_insights_period_start", "period_start"),
|
||||
)
|
||||
|
|
@ -31,6 +31,8 @@ class User(Base):
|
|||
lockedUntil: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
|
||||
organization_id: Mapped[uuid.UUID | None] = mapped_column(Uuid(as_uuid=True), ForeignKey("organizations.id", ondelete="SET NULL"), nullable=True)
|
||||
role: Mapped[str] = mapped_column(String(20), server_default="owner", nullable=False)
|
||||
plan: Mapped[str] = mapped_column(String(20), server_default="free", nullable=False)
|
||||
max_queries: Mapped[int] = mapped_column(Integer, server_default="5", nullable=False)
|
||||
|
||||
queries: Mapped[list["Query"]] = relationship(
|
||||
"Query", back_populates="user", viewonly=True,
|
||||
|
|
|
|||
|
|
@ -1,13 +0,0 @@
|
|||
"""监控模块"""
|
||||
import os
|
||||
|
||||
from app.monitoring.metrics import *
|
||||
from app.monitoring.middleware import MonitoringMiddleware
|
||||
from app.monitoring.agent_hooks import agent_execution_context, record_agent_execution
|
||||
from app.monitoring.llm_metrics import get_llm_metrics, LLMMetricsWrapper
|
||||
|
||||
# 设置服务信息
|
||||
SERVICE_INFO.info({
|
||||
"version": "1.0.0",
|
||||
"environment": os.getenv("ENVIRONMENT", "development"),
|
||||
})
|
||||
|
|
@ -1,86 +0,0 @@
|
|||
"""监控中间件"""
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.monitoring.metrics import (
|
||||
API_REQUESTS_TOTAL,
|
||||
API_REQUEST_DURATION_SECONDS,
|
||||
API_REQUESTS_IN_PROGRESS,
|
||||
)
|
||||
|
||||
# 需要排除的路径(不记录指标)
|
||||
EXCLUDED_PATHS = {"/health", "/ready", "/metrics", "/docs", "/openapi.json"}
|
||||
|
||||
|
||||
class MonitoringMiddleware(BaseHTTPMiddleware):
|
||||
"""API监控中间件"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# 跳过排除路径
|
||||
if request.url.path in EXCLUDED_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
# 提取端点标识(用于指标标签)
|
||||
endpoint = self._get_endpoint_label(request)
|
||||
|
||||
# 增加活跃请求计数
|
||||
API_REQUESTS_IN_PROGRESS.labels(
|
||||
method=request.method,
|
||||
endpoint=endpoint
|
||||
).inc()
|
||||
|
||||
# 记录开始时间
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
# 执行请求
|
||||
response = await call_next(request)
|
||||
status_code = response.status_code
|
||||
except Exception as e:
|
||||
status_code = 500
|
||||
raise
|
||||
finally:
|
||||
# 计算耗时
|
||||
duration = time.perf_counter() - start_time
|
||||
|
||||
# 记录指标
|
||||
API_REQUESTS_TOTAL.labels(
|
||||
method=request.method,
|
||||
endpoint=endpoint,
|
||||
status=str(status_code)
|
||||
).inc()
|
||||
|
||||
API_REQUEST_DURATION_SECONDS.labels(
|
||||
method=request.method,
|
||||
endpoint=endpoint
|
||||
).observe(duration)
|
||||
|
||||
# 减少活跃请求计数
|
||||
API_REQUESTS_IN_PROGRESS.labels(
|
||||
method=request.method,
|
||||
endpoint=endpoint
|
||||
).dec()
|
||||
|
||||
return response
|
||||
|
||||
def _get_endpoint_label(self, request: Request) -> str:
|
||||
"""提取端点标签"""
|
||||
path = request.url.path
|
||||
|
||||
# 规范化路径(替换ID等参数)
|
||||
parts = path.strip("/").split("/")
|
||||
|
||||
# 处理常见模式:/api/v1/resources/{id}
|
||||
if len(parts) >= 4 and parts[0] == "api":
|
||||
resource = parts[2] if len(parts) > 2 else "unknown"
|
||||
action = parts[3] if len(parts) > 3 else "list"
|
||||
|
||||
# 映射到规范标签
|
||||
if action.isdigit():
|
||||
return f"{resource}_detail"
|
||||
return f"{resource}_{action}"
|
||||
|
||||
return "other"
|
||||
|
|
@ -1,7 +1,31 @@
|
|||
from app.repositories.api_key_repository import APIKeyRepository
|
||||
from app.repositories.usage_repository import UsageRepository
|
||||
from app.repositories.brand_repository import BrandRepository
|
||||
from app.repositories.query_repository import QueryRepository
|
||||
from app.repositories.citation_repository import CitationRepository
|
||||
from app.repositories.content_repository import ContentRepository
|
||||
from app.repositories.knowledge_repository import KnowledgeRepository
|
||||
from app.repositories.alert_repository import AlertRepository
|
||||
from app.repositories.subscription_repository import SubscriptionRepository
|
||||
from app.repositories.organization_repository import OrganizationRepository
|
||||
from app.repositories.user_repository import UserRepository
|
||||
from app.repositories.detection_task_repository import DetectionTaskRepository
|
||||
from app.repositories.suggestion_repository import SuggestionRepository
|
||||
from app.repositories.competitor_repository import CompetitorRepository
|
||||
|
||||
__all__ = [
|
||||
"APIKeyRepository",
|
||||
"UsageRepository",
|
||||
"BrandRepository",
|
||||
"QueryRepository",
|
||||
"CitationRepository",
|
||||
"ContentRepository",
|
||||
"KnowledgeRepository",
|
||||
"AlertRepository",
|
||||
"SubscriptionRepository",
|
||||
"OrganizationRepository",
|
||||
"UserRepository",
|
||||
"DetectionTaskRepository",
|
||||
"SuggestionRepository",
|
||||
"CompetitorRepository",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,77 @@
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.alert import Alert
|
||||
|
||||
|
||||
class AlertRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, id: uuid.UUID) -> Optional[Alert]:
|
||||
result = await self.session.execute(
|
||||
select(Alert).where(Alert.id == id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_by_user(
|
||||
self, user_id: uuid.UUID, *, skip: int = 0, limit: int = 100
|
||||
) -> list[Alert]:
|
||||
result = await self.session.execute(
|
||||
select(Alert)
|
||||
.where(Alert.user_id == user_id)
|
||||
.order_by(Alert.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_by_user(self, user_id: uuid.UUID) -> int:
|
||||
result = await self.session.execute(
|
||||
select(func.count()).select_from(Alert).where(Alert.user_id == user_id)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
async def get_unread_by_user(self, user_id: uuid.UUID) -> list[Alert]:
|
||||
result = await self.session.execute(
|
||||
select(Alert).where(
|
||||
Alert.user_id == user_id,
|
||||
Alert.is_read == False,
|
||||
).order_by(Alert.created_at.desc())
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def mark_as_read(self, alert_id: uuid.UUID) -> bool:
|
||||
instance = await self.get_by_id(alert_id)
|
||||
if instance is None:
|
||||
return False
|
||||
instance.is_read = True
|
||||
await self.session.flush()
|
||||
return True
|
||||
|
||||
async def create(self, **kwargs) -> Alert:
|
||||
instance = Alert(**kwargs)
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def update(self, id: uuid.UUID, **kwargs) -> Optional[Alert]:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return None
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(instance, key):
|
||||
setattr(instance, key, value)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def delete(self, id: uuid.UUID) -> bool:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return False
|
||||
await self.session.delete(instance)
|
||||
await self.session.flush()
|
||||
return True
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.brand import Brand
|
||||
|
||||
|
||||
class BrandRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, id: uuid.UUID) -> Optional[Brand]:
|
||||
result = await self.session.execute(
|
||||
select(Brand).where(Brand.id == id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_by_user(
|
||||
self, user_id: uuid.UUID, *, skip: int = 0, limit: int = 100
|
||||
) -> list[Brand]:
|
||||
result = await self.session.execute(
|
||||
select(Brand)
|
||||
.where(Brand.user_id == user_id)
|
||||
.order_by(Brand.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_by_user(self, user_id: uuid.UUID) -> int:
|
||||
result = await self.session.execute(
|
||||
select(func.count()).select_from(Brand).where(Brand.user_id == user_id)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
async def get_by_name_and_organization(
|
||||
self, name: str, organization_id: uuid.UUID
|
||||
) -> Optional[Brand]:
|
||||
result = await self.session.execute(
|
||||
select(Brand).where(
|
||||
Brand.name == name,
|
||||
Brand.user_id == organization_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, **kwargs) -> Brand:
|
||||
instance = Brand(**kwargs)
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def update(self, id: uuid.UUID, **kwargs) -> Optional[Brand]:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return None
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(instance, key):
|
||||
setattr(instance, key, value)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def delete(self, id: uuid.UUID) -> bool:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return False
|
||||
await self.session.delete(instance)
|
||||
await self.session.flush()
|
||||
return True
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.citation_record import CitationRecord
|
||||
|
||||
|
||||
class CitationRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, id: uuid.UUID) -> Optional[CitationRecord]:
|
||||
result = await self.session.execute(
|
||||
select(CitationRecord).where(CitationRecord.id == id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_by_query(
|
||||
self, query_id: uuid.UUID, *, skip: int = 0, limit: int = 100
|
||||
) -> list[CitationRecord]:
|
||||
result = await self.session.execute(
|
||||
select(CitationRecord)
|
||||
.where(CitationRecord.query_id == query_id)
|
||||
.order_by(CitationRecord.queried_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_by_query(self, query_id: uuid.UUID) -> int:
|
||||
result = await self.session.execute(
|
||||
select(func.count()).select_from(CitationRecord).where(
|
||||
CitationRecord.query_id == query_id
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
async def get_by_query_and_platform(
|
||||
self, query_id: uuid.UUID, platform: str
|
||||
) -> Optional[CitationRecord]:
|
||||
result = await self.session.execute(
|
||||
select(CitationRecord).where(
|
||||
CitationRecord.query_id == query_id,
|
||||
CitationRecord.platform == platform,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def count_cited_by_brand(self, brand_name: str) -> int:
|
||||
result = await self.session.execute(
|
||||
select(func.count())
|
||||
.select_from(CitationRecord)
|
||||
.join(CitationRecord.query)
|
||||
.where(
|
||||
CitationRecord.cited == True,
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
async def create(self, **kwargs) -> CitationRecord:
|
||||
instance = CitationRecord(**kwargs)
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def update(self, id: uuid.UUID, **kwargs) -> Optional[CitationRecord]:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return None
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(instance, key):
|
||||
setattr(instance, key, value)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def delete(self, id: uuid.UUID) -> bool:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return False
|
||||
await self.session.delete(instance)
|
||||
await self.session.flush()
|
||||
return True
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.competitor import Competitor
|
||||
|
||||
|
||||
class CompetitorRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, id: uuid.UUID) -> Optional[Competitor]:
|
||||
result = await self.session.execute(
|
||||
select(Competitor).where(Competitor.id == id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_by_brand(
|
||||
self, brand_id: uuid.UUID, *, skip: int = 0, limit: int = 100
|
||||
) -> list[Competitor]:
|
||||
result = await self.session.execute(
|
||||
select(Competitor)
|
||||
.where(Competitor.brand_id == brand_id)
|
||||
.order_by(Competitor.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_by_brand(self, brand_id: uuid.UUID) -> int:
|
||||
result = await self.session.execute(
|
||||
select(func.count()).select_from(Competitor).where(
|
||||
Competitor.brand_id == brand_id
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
async def get_by_brand(self, brand_name: str) -> list[Competitor]:
|
||||
from app.models.brand import Brand
|
||||
result = await self.session.execute(
|
||||
select(Competitor)
|
||||
.join(Brand, Competitor.brand_id == Brand.id)
|
||||
.where(Brand.name == brand_name)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def create(self, **kwargs) -> Competitor:
|
||||
instance = Competitor(**kwargs)
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def update(self, id: uuid.UUID, **kwargs) -> Optional[Competitor]:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return None
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(instance, key):
|
||||
setattr(instance, key, value)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def delete(self, id: uuid.UUID) -> bool:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return False
|
||||
await self.session.delete(instance)
|
||||
await self.session.flush()
|
||||
return True
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.content import Content
|
||||
|
||||
|
||||
class ContentRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, id: uuid.UUID) -> Optional[Content]:
|
||||
result = await self.session.execute(
|
||||
select(Content).where(Content.id == id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_by_organization(
|
||||
self, organization_id: uuid.UUID, *, skip: int = 0, limit: int = 100
|
||||
) -> list[Content]:
|
||||
result = await self.session.execute(
|
||||
select(Content)
|
||||
.where(Content.organization_id == organization_id)
|
||||
.order_by(Content.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_by_organization(self, organization_id: uuid.UUID) -> int:
|
||||
result = await self.session.execute(
|
||||
select(func.count()).select_from(Content).where(
|
||||
Content.organization_id == organization_id
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
async def get_by_status(self, status: str) -> list[Content]:
|
||||
result = await self.session.execute(
|
||||
select(Content).where(Content.status == status)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def create(self, **kwargs) -> Content:
|
||||
instance = Content(**kwargs)
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def update(self, id: uuid.UUID, **kwargs) -> Optional[Content]:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return None
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(instance, key):
|
||||
setattr(instance, key, value)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def delete(self, id: uuid.UUID) -> bool:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return False
|
||||
await self.session.delete(instance)
|
||||
await self.session.flush()
|
||||
return True
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.detection_task import DetectionTask
|
||||
|
||||
|
||||
class DetectionTaskRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, id: uuid.UUID) -> Optional[DetectionTask]:
|
||||
result = await self.session.execute(
|
||||
select(DetectionTask).where(DetectionTask.id == id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_by_user(
|
||||
self, user_id: uuid.UUID, *, skip: int = 0, limit: int = 100
|
||||
) -> list[DetectionTask]:
|
||||
result = await self.session.execute(
|
||||
select(DetectionTask)
|
||||
.where(DetectionTask.user_id == user_id)
|
||||
.order_by(DetectionTask.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_by_user(self, user_id: uuid.UUID) -> int:
|
||||
result = await self.session.execute(
|
||||
select(func.count()).select_from(DetectionTask).where(
|
||||
DetectionTask.user_id == user_id
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
async def get_active_tasks(self) -> list[DetectionTask]:
|
||||
result = await self.session.execute(
|
||||
select(DetectionTask).where(DetectionTask.is_active == True)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def create(self, **kwargs) -> DetectionTask:
|
||||
instance = DetectionTask(**kwargs)
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def update(self, id: uuid.UUID, **kwargs) -> Optional[DetectionTask]:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return None
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(instance, key):
|
||||
setattr(instance, key, value)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def delete(self, id: uuid.UUID) -> bool:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return False
|
||||
await self.session.delete(instance)
|
||||
await self.session.flush()
|
||||
return True
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.knowledge import KnowledgeBase
|
||||
|
||||
|
||||
class KnowledgeRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, id: uuid.UUID) -> Optional[KnowledgeBase]:
|
||||
result = await self.session.execute(
|
||||
select(KnowledgeBase).where(KnowledgeBase.id == id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_by_organization(
|
||||
self, organization_id: uuid.UUID, *, skip: int = 0, limit: int = 100
|
||||
) -> list[KnowledgeBase]:
|
||||
result = await self.session.execute(
|
||||
select(KnowledgeBase)
|
||||
.where(KnowledgeBase.organization_id == organization_id)
|
||||
.order_by(KnowledgeBase.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_by_organization(self, organization_id: uuid.UUID) -> int:
|
||||
result = await self.session.execute(
|
||||
select(func.count()).select_from(KnowledgeBase).where(
|
||||
KnowledgeBase.organization_id == organization_id
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
async def get_by_organization(self, organization_id: uuid.UUID) -> list[KnowledgeBase]:
|
||||
result = await self.session.execute(
|
||||
select(KnowledgeBase).where(
|
||||
KnowledgeBase.organization_id == organization_id
|
||||
)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def create(self, **kwargs) -> KnowledgeBase:
|
||||
instance = KnowledgeBase(**kwargs)
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def update(self, id: uuid.UUID, **kwargs) -> Optional[KnowledgeBase]:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return None
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(instance, key):
|
||||
setattr(instance, key, value)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def delete(self, id: uuid.UUID) -> bool:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return False
|
||||
await self.session.delete(instance)
|
||||
await self.session.flush()
|
||||
return True
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.organization import Organization
|
||||
|
||||
|
||||
class OrganizationRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, id: uuid.UUID) -> Optional[Organization]:
|
||||
result = await self.session.execute(
|
||||
select(Organization).where(Organization.id == id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_slug(self, slug: str) -> Optional[Organization]:
|
||||
result = await self.session.execute(
|
||||
select(Organization).where(Organization.slug == slug)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_all(self, *, skip: int = 0, limit: int = 100) -> list[Organization]:
|
||||
result = await self.session.execute(
|
||||
select(Organization)
|
||||
.order_by(Organization.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_all(self) -> int:
|
||||
result = await self.session.execute(
|
||||
select(func.count()).select_from(Organization)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
async def get_by_owner(self, user_id: str) -> list[Organization]:
|
||||
from app.models.organization import OrgMember
|
||||
result = await self.session.execute(
|
||||
select(Organization)
|
||||
.join(OrgMember, OrgMember.organization_id == Organization.id)
|
||||
.where(
|
||||
OrgMember.user_id == user_id,
|
||||
OrgMember.role == "owner",
|
||||
)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def create(self, **kwargs) -> Organization:
|
||||
instance = Organization(**kwargs)
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def update(self, id: uuid.UUID, **kwargs) -> Optional[Organization]:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return None
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(instance, key):
|
||||
setattr(instance, key, value)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def delete(self, id: uuid.UUID) -> bool:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return False
|
||||
await self.session.delete(instance)
|
||||
await self.session.flush()
|
||||
return True
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.query import Query
|
||||
|
||||
|
||||
class QueryRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, id: uuid.UUID) -> Optional[Query]:
|
||||
result = await self.session.execute(
|
||||
select(Query).where(Query.id == id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_by_user(
|
||||
self, user_id: str, *, skip: int = 0, limit: int = 100
|
||||
) -> list[Query]:
|
||||
result = await self.session.execute(
|
||||
select(Query)
|
||||
.where(Query.user_id == user_id)
|
||||
.order_by(Query.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_by_user(self, user_id: str) -> int:
|
||||
result = await self.session.execute(
|
||||
select(func.count()).select_from(Query).where(Query.user_id == user_id)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
async def get_by_brand(self, brand_name: str) -> list[Query]:
|
||||
result = await self.session.execute(
|
||||
select(Query).where(Query.target_brand == brand_name)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_active_queries(self) -> list[Query]:
|
||||
result = await self.session.execute(
|
||||
select(Query).where(Query.status == "active")
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def create(self, **kwargs) -> Query:
|
||||
instance = Query(**kwargs)
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def update(self, id: uuid.UUID, **kwargs) -> Optional[Query]:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return None
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(instance, key):
|
||||
setattr(instance, key, value)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def delete(self, id: uuid.UUID) -> bool:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return False
|
||||
await self.session.delete(instance)
|
||||
await self.session.flush()
|
||||
return True
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.subscription import Subscription
|
||||
|
||||
|
||||
class SubscriptionRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, id: uuid.UUID) -> Optional[Subscription]:
|
||||
result = await self.session.execute(
|
||||
select(Subscription).where(Subscription.id == id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_by_user(
|
||||
self, user_id: str, *, skip: int = 0, limit: int = 100
|
||||
) -> list[Subscription]:
|
||||
result = await self.session.execute(
|
||||
select(Subscription)
|
||||
.where(Subscription.user_id == user_id)
|
||||
.order_by(Subscription.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_by_user(self, user_id: str) -> int:
|
||||
result = await self.session.execute(
|
||||
select(func.count()).select_from(Subscription).where(
|
||||
Subscription.user_id == user_id
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
async def get_by_organization(self, organization_id: uuid.UUID) -> list[Subscription]:
|
||||
result = await self.session.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.user_id == organization_id
|
||||
)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def create(self, **kwargs) -> Subscription:
|
||||
instance = Subscription(**kwargs)
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def update(self, id: uuid.UUID, **kwargs) -> Optional[Subscription]:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return None
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(instance, key):
|
||||
setattr(instance, key, value)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def delete(self, id: uuid.UUID) -> bool:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return False
|
||||
await self.session.delete(instance)
|
||||
await self.session.flush()
|
||||
return True
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.suggestion import Suggestion
|
||||
|
||||
|
||||
class SuggestionRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, id: uuid.UUID) -> Optional[Suggestion]:
|
||||
result = await self.session.execute(
|
||||
select(Suggestion).where(Suggestion.id == id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_by_brand(
|
||||
self, brand_id: uuid.UUID, *, skip: int = 0, limit: int = 100
|
||||
) -> list[Suggestion]:
|
||||
result = await self.session.execute(
|
||||
select(Suggestion)
|
||||
.where(Suggestion.brand_id == brand_id)
|
||||
.order_by(Suggestion.generated_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_by_brand(self, brand_id: uuid.UUID) -> int:
|
||||
result = await self.session.execute(
|
||||
select(func.count()).select_from(Suggestion).where(
|
||||
Suggestion.brand_id == brand_id
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
async def get_by_brand(self, brand_name: str) -> list[Suggestion]:
|
||||
from app.models.brand import Brand
|
||||
result = await self.session.execute(
|
||||
select(Suggestion)
|
||||
.join(Brand, Suggestion.brand_id == Brand.id)
|
||||
.where(Brand.name == brand_name)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def create(self, **kwargs) -> Suggestion:
|
||||
instance = Suggestion(**kwargs)
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def update(self, id: uuid.UUID, **kwargs) -> Optional[Suggestion]:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return None
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(instance, key):
|
||||
setattr(instance, key, value)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def delete(self, id: uuid.UUID) -> bool:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return False
|
||||
await self.session.delete(instance)
|
||||
await self.session.flush()
|
||||
return True
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class UserRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, id: str) -> Optional[User]:
|
||||
result = await self.session.execute(
|
||||
select(User).where(User.id == id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_email(self, email: str) -> Optional[User]:
|
||||
result = await self.session.execute(
|
||||
select(User).where(User.email == email)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_all(self, *, skip: int = 0, limit: int = 100) -> list[User]:
|
||||
result = await self.session.execute(
|
||||
select(User)
|
||||
.order_by(User.createdAt.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_all(self) -> int:
|
||||
result = await self.session.execute(
|
||||
select(func.count()).select_from(User)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
async def create(self, **kwargs) -> User:
|
||||
instance = User(**kwargs)
|
||||
self.session.add(instance)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def update(self, id: str, **kwargs) -> Optional[User]:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return None
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(instance, key):
|
||||
setattr(instance, key, value)
|
||||
await self.session.flush()
|
||||
return instance
|
||||
|
||||
async def delete(self, id: str) -> bool:
|
||||
instance = await self.get_by_id(id)
|
||||
if instance is None:
|
||||
return False
|
||||
await self.session.delete(instance)
|
||||
await self.session.flush()
|
||||
return True
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CompetitorAnalysisRequest(BaseModel):
|
||||
brand_id: uuid.UUID = Field(..., description="品牌ID")
|
||||
analysis_types: list[str] | None = Field(
|
||||
None,
|
||||
description="分析类型列表: citation_gap/content_strategy/platform_coverage/query_overlap/differentiation",
|
||||
)
|
||||
period_days: int | None = Field(30, description="分析周期天数")
|
||||
|
||||
|
||||
class CompetitorInsightResponse(BaseModel):
|
||||
id: uuid.UUID = Field(..., description="洞察ID")
|
||||
brand_id: uuid.UUID = Field(..., description="品牌ID")
|
||||
competitor_name: str = Field(..., description="竞品名称")
|
||||
analysis_type: str = Field(..., description="分析类型")
|
||||
insight_data: dict | None = Field(None, description="洞察数据")
|
||||
citation_count_brand: int = Field(0, description="品牌引用次数")
|
||||
citation_count_competitor: int = Field(0, description="竞品引用次数")
|
||||
sentiment_brand: float | None = Field(None, description="品牌情感分数")
|
||||
sentiment_competitor: float | None = Field(None, description="竞品情感分数")
|
||||
platform_breakdown: dict | None = Field(None, description="平台分布")
|
||||
gap_analysis: dict | None = Field(None, description="差距分析")
|
||||
opportunity_areas: dict | None = Field(None, description="机会领域")
|
||||
recommendations: dict | None = Field(None, description="策略建议")
|
||||
confidence: str = Field("medium", description="置信度: high/medium/low")
|
||||
period_days: int = Field(30, description="分析周期天数")
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
updated_at: datetime = Field(..., description="更新时间")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class CompetitorInsightList(BaseModel):
|
||||
items: list[CompetitorInsightResponse] = Field(default_factory=list, description="洞察列表")
|
||||
total: int = Field(0, description="总数")
|
||||
|
||||
|
||||
class CompetitorGapSummary(BaseModel):
|
||||
brand_name: str = Field(..., description="品牌名称")
|
||||
competitor_name: str = Field(..., description="竞品名称")
|
||||
gap_dimensions: list[dict] = Field(default_factory=list, description="差距维度列表")
|
||||
overall_gap_score: float = Field(0.0, description="综合差距分数(0-100)")
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GEODiagnosisTriggerRequest(BaseModel):
|
||||
force_refresh: bool = Field(default=False)
|
||||
|
||||
|
||||
class GEODiagnosisTaskResponse(BaseModel):
|
||||
task_id: str
|
||||
brand_id: str
|
||||
status: str
|
||||
|
||||
|
||||
class GEODimensionItemResponse(BaseModel):
|
||||
name: str
|
||||
status: str
|
||||
description: str
|
||||
suggestion: str
|
||||
score: float
|
||||
max_score: float
|
||||
|
||||
|
||||
class GEODimensionResponse(BaseModel):
|
||||
name: str
|
||||
score: float
|
||||
max_score: float
|
||||
percentage: float
|
||||
status: str
|
||||
items: list[GEODimensionItemResponse]
|
||||
detail: dict
|
||||
|
||||
|
||||
class GEORecommendationResponse(BaseModel):
|
||||
priority: str
|
||||
dimension: str
|
||||
title: str
|
||||
description: str
|
||||
impact: str
|
||||
effort: str
|
||||
|
||||
|
||||
class GEODiagnosisResponse(BaseModel):
|
||||
overall_score: float
|
||||
health_level: str
|
||||
health_level_label: str
|
||||
dimensions: list[GEODimensionResponse]
|
||||
recommendations: list[GEORecommendationResponse]
|
||||
is_full_report: bool = False
|
||||
|
||||
|
||||
class GEODiagnosisResultResponse(BaseModel):
|
||||
task_id: str
|
||||
brand_id: str
|
||||
status: str
|
||||
result: GEODiagnosisResponse | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class GEODiagnosisHistoryItem(BaseModel):
|
||||
task_id: str
|
||||
overall_score: float
|
||||
health_level: str
|
||||
created_at: str
|
||||
completed_at: str | None = None
|
||||
|
||||
|
||||
class GEODiagnosisHistoryResponse(BaseModel):
|
||||
brand_id: str
|
||||
history: list[GEODiagnosisHistoryItem]
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GeoPlanGenerateRequest(BaseModel):
|
||||
brand_id: uuid.UUID = Field(..., description="品牌ID")
|
||||
target_score: int | None = Field(75, description="目标评分")
|
||||
|
||||
|
||||
class GeoPlanActionResponse(BaseModel):
|
||||
id: uuid.UUID = Field(..., description="行动项ID")
|
||||
plan_id: uuid.UUID = Field(..., description="方案ID")
|
||||
action_type: str = Field(..., description="行动类型")
|
||||
title: str = Field(..., description="行动标题")
|
||||
description: str = Field(..., description="详细描述")
|
||||
reason: str = Field(..., description="基于诊断数据的原因")
|
||||
priority: str = Field(..., description="优先级: high/medium/low")
|
||||
status: str = Field(..., description="状态: pending/in_progress/completed/skipped")
|
||||
target_keyword: str | None = Field(None, description="预填关键词")
|
||||
target_platform: str | None = Field(None, description="预填平台")
|
||||
content_style: str | None = Field(None, description="预填风格")
|
||||
estimated_impact: str | None = Field(None, description="预期效果")
|
||||
difficulty: str = Field(..., description="难度: easy/medium/hard")
|
||||
execution_params: dict | None = Field(None, description="一键执行参数")
|
||||
sort_order: int = Field(0, description="排序序号")
|
||||
completed_at: datetime | None = Field(None, description="完成时间")
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class GeoPlanResponse(BaseModel):
|
||||
id: uuid.UUID = Field(..., description="方案ID")
|
||||
brand_id: uuid.UUID = Field(..., description="品牌ID")
|
||||
title: str = Field(..., description="方案标题")
|
||||
status: str = Field(..., description="状态: draft/active/completed/archived")
|
||||
diagnosis_score: int = Field(..., description="诊断评分")
|
||||
target_score: int = Field(..., description="目标评分")
|
||||
estimated_weeks: int = Field(..., description="预计周数")
|
||||
plan_data: dict | None = Field(None, description="方案详细数据")
|
||||
source: str = Field(..., description="生成来源: rule/llm")
|
||||
actions: list[GeoPlanActionResponse] = Field(
|
||||
default_factory=list, description="行动项列表"
|
||||
)
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
updated_at: datetime = Field(..., description="更新时间")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class GeoPlanListResponse(BaseModel):
|
||||
plans: list[GeoPlanResponse] = Field(default_factory=list, description="方案列表")
|
||||
total: int = Field(..., description="总数")
|
||||
|
||||
|
||||
class GeoPlanActionUpdateStatus(BaseModel):
|
||||
status: str = Field(
|
||||
..., description="新状态: pending/in_progress/completed/skipped"
|
||||
)
|
||||
|
||||
|
||||
class GeoPlanActionExecuteResponse(BaseModel):
|
||||
action_id: uuid.UUID = Field(..., description="行动项ID")
|
||||
content_id: str | None = Field(None, description="生成的内容ID")
|
||||
message: str = Field(..., description="执行结果消息")
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class HealthScoreDimension(BaseModel):
|
||||
name: str
|
||||
score: float
|
||||
max_score: float
|
||||
percentage: float
|
||||
status: str
|
||||
|
||||
|
||||
class HealthScoreRecommendation(BaseModel):
|
||||
priority: str
|
||||
dimension: str
|
||||
title: str
|
||||
description: str
|
||||
|
||||
|
||||
class HealthScoreResponse(BaseModel):
|
||||
brand_name: str
|
||||
overall_score: float
|
||||
health_level: str
|
||||
health_level_label: str
|
||||
dimensions: list[HealthScoreDimension]
|
||||
recommendations: list[HealthScoreRecommendation]
|
||||
is_full_report: bool = False
|
||||
cached: bool = False
|
||||
|
||||
|
||||
class HealthScoreRequest(BaseModel):
|
||||
brand: str
|
||||
competitors: list[str] = Field(default_factory=list)
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MonitoringRecordCreate(BaseModel):
|
||||
brand_id: uuid.UUID = Field(..., description="品牌ID")
|
||||
content_id: str | None = Field(None, description="内容ID")
|
||||
query_keywords: str | None = Field(None, description="查询关键词")
|
||||
platform: str | None = Field(None, description="平台")
|
||||
check_interval_hours: int = Field(24, description="检测间隔(小时)")
|
||||
|
||||
|
||||
class MonitoringRecordResponse(BaseModel):
|
||||
id: uuid.UUID = Field(..., description="记录ID")
|
||||
brand_id: uuid.UUID = Field(..., description="品牌ID")
|
||||
content_id: str | None = Field(None, description="内容ID")
|
||||
query_keywords: str | None = Field(None, description="查询关键词")
|
||||
platform: str | None = Field(None, description="平台")
|
||||
baseline_citation_count: int = Field(0, description="基线引用量")
|
||||
baseline_sentiment: float | None = Field(None, description="基线情感分数")
|
||||
baseline_rank: int | None = Field(None, description="基线排名")
|
||||
current_citation_count: int = Field(0, description="当前引用量")
|
||||
current_sentiment: float | None = Field(None, description="当前情感分数")
|
||||
current_rank: int | None = Field(None, description="当前排名")
|
||||
change_type: str | None = Field(None, description="变化类型: positive/negative/neutral")
|
||||
change_details: dict | None = Field(None, description="变化详情")
|
||||
check_interval_hours: int = Field(24, description="检测间隔(小时)")
|
||||
last_checked_at: datetime | None = Field(None, description="上次检测时间")
|
||||
next_check_at: datetime | None = Field(None, description="下次检测时间")
|
||||
status: str = Field("active", description="状态: active/paused/completed")
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
updated_at: datetime = Field(..., description="更新时间")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class MonitoringRecordList(BaseModel):
|
||||
records: list[MonitoringRecordResponse] = Field(default_factory=list, description="监测记录列表")
|
||||
total: int = Field(..., description="总数")
|
||||
|
||||
|
||||
class MonitoringChangeReport(BaseModel):
|
||||
monitoring_record_id: uuid.UUID = Field(..., description="监测记录ID")
|
||||
brand_id: uuid.UUID = Field(..., description="品牌ID")
|
||||
change_type: str = Field(..., description="变化类型: positive/negative/neutral")
|
||||
change_details: dict | None = Field(None, description="变化详情")
|
||||
baseline: dict = Field(default_factory=dict, description="基线数据")
|
||||
current: dict = Field(default_factory=dict, description="当前数据")
|
||||
recommendations: list[str] = Field(default_factory=list, description="建议")
|
||||
|
||||
|
||||
class ContentBaselineResponse(BaseModel):
|
||||
id: uuid.UUID = Field(..., description="基线ID")
|
||||
monitoring_record_id: uuid.UUID = Field(..., description="监测记录ID")
|
||||
brand_name: str = Field(..., description="品牌名称")
|
||||
keyword: str = Field(..., description="关键词")
|
||||
platform: str = Field(..., description="平台")
|
||||
citation_count: int = Field(0, description="引用量")
|
||||
sentiment_score: float | None = Field(None, description="情感分数")
|
||||
rank_position: int | None = Field(None, description="排名位置")
|
||||
snapshot_data: dict | None = Field(None, description="快照数据")
|
||||
recorded_at: datetime = Field(..., description="记录时间")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class MonitoringStatusUpdate(BaseModel):
|
||||
status: str = Field(..., description="新状态: active/paused/completed")
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SchemaAdviseRequest(BaseModel):
|
||||
brand_id: uuid.UUID = Field(..., description="品牌ID")
|
||||
target_url: str | None = Field(None, description="目标页面URL")
|
||||
focus_dimensions: list[str] | None = Field(None, description="聚焦的诊断维度")
|
||||
|
||||
|
||||
class SchemaSuggestionResponse(BaseModel):
|
||||
id: uuid.UUID = Field(..., description="建议ID")
|
||||
brand_id: uuid.UUID = Field(..., description="品牌ID")
|
||||
schema_type: str = Field(..., description="Schema类型: Organization/Product/FAQPage/Article/LocalBusiness")
|
||||
target_url: str | None = Field(None, description="目标页面URL")
|
||||
json_ld_template: dict = Field(..., description="JSON-LD模板")
|
||||
json_ld_filled: dict | None = Field(None, description="填充后的JSON-LD")
|
||||
priority: str = Field(default="medium", description="优先级: high/medium/low")
|
||||
status: str = Field(default="pending", description="状态: pending/applied/dismissed")
|
||||
diagnosis_dimensions: dict | None = Field(None, description="诊断维度数据")
|
||||
implementation_difficulty: str = Field(default="medium", description="实施难度: easy/medium/hard")
|
||||
estimated_impact: str | None = Field(None, description="预期影响描述")
|
||||
validation_errors: dict | None = Field(None, description="验证错误信息")
|
||||
created_at: datetime = Field(..., description="创建时间")
|
||||
updated_at: datetime = Field(..., description="更新时间")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SchemaSuggestionList(BaseModel):
|
||||
suggestions: list[SchemaSuggestionResponse] = Field(default_factory=list, description="建议列表")
|
||||
total: int = Field(..., description="总数")
|
||||
|
||||
|
||||
class SchemaValidationResult(BaseModel):
|
||||
is_valid: bool = Field(..., description="是否有效")
|
||||
errors: list[str] = Field(default_factory=list, description="错误列表")
|
||||
warnings: list[str] = Field(default_factory=list, description="警告列表")
|
||||
|
||||
|
||||
class SchemaStatusUpdateRequest(BaseModel):
|
||||
status: str = Field(..., description="新状态: pending/applied/dismissed")
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TrendInsightRequest(BaseModel):
|
||||
brand_id: uuid.UUID = Field(..., description="品牌ID")
|
||||
period_days: int = Field(30, ge=7, le=365, description="分析周期天数")
|
||||
platforms: list[str] | None = Field(None, description="筛选平台列表")
|
||||
keywords: list[str] | None = Field(None, description="筛选关键词列表")
|
||||
|
||||
|
||||
class TrendInsightResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
brand_id: uuid.UUID
|
||||
trend_type: str
|
||||
keyword: str | None
|
||||
platform: str | None
|
||||
period_start: datetime
|
||||
period_end: datetime
|
||||
data_points: list | None
|
||||
change_rate: float | None
|
||||
absolute_change: int | None
|
||||
sentiment_trend: dict | None
|
||||
cause_analysis: str | None
|
||||
recommendations: list | None
|
||||
confidence: float
|
||||
severity: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class TrendInsightList(BaseModel):
|
||||
items: list[TrendInsightResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class TrendSummary(BaseModel):
|
||||
brand_id: uuid.UUID
|
||||
period_days: int
|
||||
rising_count: int = 0
|
||||
declining_count: int = 0
|
||||
hotspot_count: int = 0
|
||||
top_keywords: list[str] = Field(default_factory=list)
|
||||
platform_highlights: dict = Field(default_factory=dict)
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
from .optimization_advisor import (
|
||||
generate_suggestions,
|
||||
generate_rule_based_suggestions,
|
||||
generate_llm_suggestions,
|
||||
build_context_from_scoring_result,
|
||||
SuggestionItem,
|
||||
BrandAnalysisContext,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"generate_suggestions",
|
||||
"generate_rule_based_suggestions",
|
||||
"generate_llm_suggestions",
|
||||
"build_context_from_scoring_result",
|
||||
"SuggestionItem",
|
||||
"BrandAnalysisContext",
|
||||
]
|
||||
|
|
@ -25,7 +25,8 @@ from dataclasses import dataclass, field
|
|||
from typing import Any
|
||||
|
||||
from app.config import settings
|
||||
from app.services.scoring_service import ScoringResultV2
|
||||
from app.services.scoring.scoring_service import ScoringResultV2
|
||||
from app.utils.json_extractor import extract_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -525,7 +526,7 @@ async def generate_llm_suggestions(
|
|||
raise ValueError("LLM返回空响应")
|
||||
|
||||
# 提取JSON
|
||||
json_str = _extract_json(content)
|
||||
json_str = extract_json(content)
|
||||
result = json.loads(json_str)
|
||||
|
||||
# 解析建议
|
||||
|
|
@ -573,32 +574,6 @@ async def generate_llm_suggestions(
|
|||
return generate_rule_based_suggestions(ctx)
|
||||
|
||||
|
||||
def _extract_json(text: str) -> str:
|
||||
"""从文本中提取JSON"""
|
||||
import re
|
||||
|
||||
# 尝试直接解析
|
||||
try:
|
||||
json.loads(text)
|
||||
return text
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试从代码块中提取
|
||||
json_pattern = r'```(?:json)?\s*([\s\S]*?)\s*```'
|
||||
match = re.search(json_pattern, text)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
# 尝试找到第一个{到最后一个}之间的内容
|
||||
first_brace = text.find('{')
|
||||
last_brace = text.rfind('}')
|
||||
if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
|
||||
return text[first_brace:last_brace + 1]
|
||||
|
||||
raise ValueError(f"无法从响应中提取JSON: {text[:200]}")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 主入口:生成优化建议
|
||||
# ============================================================
|
||||
|
|
@ -6,6 +6,10 @@ from .doubao import DoubaoAdapter
|
|||
from .gemini import GeminiAdapter
|
||||
from .kimi import KimiAdapter
|
||||
from .perplexity import PerplexityAdapter
|
||||
from .platform_bridge import (
|
||||
execute_single_platform,
|
||||
query_platform_raw,
|
||||
)
|
||||
from .qwen import QwenAdapter
|
||||
from .wenxin import WenxinAdapter
|
||||
from .yuanbao import YuanbaoAdapter
|
||||
|
|
@ -25,4 +29,6 @@ __all__ = [
|
|||
"QwenAdapter",
|
||||
"GeminiAdapter",
|
||||
"BatchQueryService",
|
||||
"execute_single_platform",
|
||||
"query_platform_raw",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,295 @@
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from urllib.parse import quote
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import AIEngineAdapter, AIQueryResult, EngineType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PLATFORM_NAME_MAP: dict[str, EngineType] = {
|
||||
"wenxin": EngineType.WENXIN,
|
||||
"kimi": EngineType.KIMI,
|
||||
"doubao": EngineType.DOUBAO,
|
||||
"tongyi": EngineType.QWEN,
|
||||
"deepseek": EngineType.DEEPSEEK,
|
||||
"chatgpt": EngineType.CHATGPT,
|
||||
"perplexity": EngineType.PERPLEXITY,
|
||||
"gemini": EngineType.GEMINI,
|
||||
"yuanbao": EngineType.YUANBAO,
|
||||
}
|
||||
|
||||
_SEARCH_ONLY_PLATFORMS = {"qingyan", "tiangong", "xinghuo"}
|
||||
|
||||
|
||||
def get_engine_type_for_platform(platform_name: str) -> EngineType | None:
|
||||
return _PLATFORM_NAME_MAP.get(platform_name)
|
||||
|
||||
|
||||
def is_search_only_platform(platform_name: str) -> bool:
|
||||
return platform_name in _SEARCH_ONLY_PLATFORMS
|
||||
|
||||
|
||||
async def search_wikipedia(keyword: str, max_chars: int = 2000) -> str:
|
||||
search_url = "https://zh.wikipedia.org/w/api.php"
|
||||
headers = {
|
||||
"User-Agent": "GEO-Citation-Bot/1.0 (contact@example.com)",
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
search_resp = await client.get(
|
||||
search_url,
|
||||
headers=headers,
|
||||
params={
|
||||
"action": "query",
|
||||
"list": "search",
|
||||
"srsearch": keyword,
|
||||
"srlimit": 3,
|
||||
"format": "json",
|
||||
"origin": "*",
|
||||
},
|
||||
)
|
||||
search_resp.raise_for_status()
|
||||
search_data = search_resp.json()
|
||||
|
||||
search_results = search_data.get("query", {}).get("search", [])
|
||||
if not search_results:
|
||||
return ""
|
||||
|
||||
title = search_results[0]["title"]
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
extract_resp = await client.get(
|
||||
search_url,
|
||||
headers=headers,
|
||||
params={
|
||||
"action": "query",
|
||||
"prop": "extracts",
|
||||
"titles": title,
|
||||
"explaintext": True,
|
||||
"exsentences": 15,
|
||||
"format": "json",
|
||||
"origin": "*",
|
||||
},
|
||||
)
|
||||
extract_resp.raise_for_status()
|
||||
extract_data = extract_resp.json()
|
||||
|
||||
pages = extract_data.get("query", {}).get("pages", {})
|
||||
for page in pages.values():
|
||||
extract = page.get("extract", "")
|
||||
if extract:
|
||||
extract = re.sub(r'\[\d+\]', '', extract)
|
||||
extract = re.sub(r'\s+', ' ', extract).strip()
|
||||
return extract[:max_chars]
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _strip_html(raw: str) -> str:
|
||||
raw = raw.replace(" ", " ")
|
||||
raw = raw.replace(""", '"')
|
||||
raw = raw.replace("&", "&")
|
||||
raw = raw.replace("<", "<")
|
||||
raw = raw.replace(">", ">")
|
||||
raw = raw.replace("'", "'")
|
||||
text = re.sub(r"<[^>]+>", "", raw)
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
return text
|
||||
|
||||
|
||||
async def search_duckduckgo(query: str, max_results: int = 5) -> str:
|
||||
url = f"https://html.duckduckgo.com/html/?q={quote(query)}"
|
||||
headers = {
|
||||
"User-Agent": (
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36"
|
||||
),
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
|
||||
"Accept-Language": "zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7",
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
|
||||
resp = await client.get(url, headers=headers)
|
||||
resp.raise_for_status()
|
||||
html = resp.text
|
||||
|
||||
if "web-result" not in html and "result__snippet" not in html and "result__title" not in html:
|
||||
raise RuntimeError("DuckDuckGo 返回了非结果页面")
|
||||
|
||||
results: list[str] = []
|
||||
|
||||
result_blocks = re.findall(
|
||||
r'<div class="result[^"]*"[^>]*>.*?<h[^>]*class="result__title"[^>]*>.*?<a[^>]*>(.*?)</a>.*?</h[^>]*>.*?<a[^>]*class="result__snippet"[^>]*>(.*?)</a>.*?</div>',
|
||||
html,
|
||||
re.DOTALL | re.IGNORECASE,
|
||||
)
|
||||
if result_blocks:
|
||||
for title_raw, snippet_raw in result_blocks[:max_results]:
|
||||
title = _strip_html(title_raw)
|
||||
snippet = _strip_html(snippet_raw)
|
||||
if title or snippet:
|
||||
results.append(f"{title}\n{snippet}")
|
||||
|
||||
if not results:
|
||||
snippets = re.findall(
|
||||
r'<a[^>]*class="result__snippet"[^>]*>(.*?)</a>', html, re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
titles = re.findall(
|
||||
r'<h[^>]*class="result__title"[^>]*>.*?<a[^>]*>(.*?)</a>.*?</h[^>]*>',
|
||||
html,
|
||||
re.DOTALL | re.IGNORECASE,
|
||||
)
|
||||
for i in range(min(len(titles), len(snippets), max_results)):
|
||||
title = _strip_html(titles[i])
|
||||
snippet = _strip_html(snippets[i])
|
||||
if title or snippet:
|
||||
results.append(f"{title}\n{snippet}")
|
||||
|
||||
if results:
|
||||
return "\n\n".join(results)
|
||||
|
||||
raise RuntimeError("DuckDuckGo 未解析到结果")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"DuckDuckGo 搜索失败: {e},回退到 Wikipedia")
|
||||
wiki_text = await search_wikipedia(query, max_chars=2000)
|
||||
if wiki_text:
|
||||
return wiki_text
|
||||
raise RuntimeError(f"所有搜索源均失败: {e}")
|
||||
|
||||
|
||||
async def fetch_search_content(platform_name: str, keyword: str) -> str:
|
||||
logger.info(f"[{platform_name}] 搜索查询: {keyword}")
|
||||
return await search_duckduckgo(keyword, max_results=5)
|
||||
|
||||
|
||||
class SearchOnlyAdapter(AIEngineAdapter):
|
||||
def __init__(self, platform_name: str, **kwargs):
|
||||
self._platform_name = platform_name
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_engine_type(self) -> EngineType:
|
||||
return EngineType.DEEPSEEK
|
||||
|
||||
def _get_env_key(self) -> str | None:
|
||||
return ""
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
brand_name: str,
|
||||
competitor_names: list[str] | None = None,
|
||||
) -> AIQueryResult:
|
||||
start_time = time.perf_counter()
|
||||
content = await fetch_search_content(self._platform_name, query)
|
||||
elapsed_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
|
||||
has_brand, has_comp, brand_ctx, comp_ctx = self._detect_citations(
|
||||
content, brand_name, competitor_names
|
||||
)
|
||||
|
||||
return AIQueryResult(
|
||||
engine_type=self.get_engine_type(),
|
||||
query=query,
|
||||
raw_response=content,
|
||||
citations=[],
|
||||
has_brand_citation=has_brand,
|
||||
has_competitor_citation=has_comp,
|
||||
brand_context=brand_ctx,
|
||||
competitor_contexts=comp_ctx,
|
||||
response_time_ms=elapsed_ms,
|
||||
timestamp=datetime.now(UTC),
|
||||
metadata={"platform_name": self._platform_name, "mode": "search_only"},
|
||||
)
|
||||
|
||||
|
||||
async def query_platform_raw(
|
||||
platform_name: str,
|
||||
keyword: str,
|
||||
brand_name: str = "",
|
||||
competitor_names: list[str] | None = None,
|
||||
) -> str:
|
||||
from .batch_query import _build_adapters
|
||||
|
||||
if is_search_only_platform(platform_name):
|
||||
content = await fetch_search_content(platform_name, keyword)
|
||||
return f"[data_source: search_engine]\n{content}"
|
||||
|
||||
engine_type = get_engine_type_for_platform(platform_name)
|
||||
if engine_type is None:
|
||||
raise ValueError(f"不支持的平台: {platform_name}")
|
||||
|
||||
adapters = _build_adapters()
|
||||
adapter = adapters.get(engine_type.value)
|
||||
if adapter is None:
|
||||
raise ValueError(f"平台 {platform_name} 适配器未注册")
|
||||
|
||||
result = await adapter.query(keyword, brand_name, competitor_names)
|
||||
return f"[data_source: ai_platform]\n{result.raw_response}"
|
||||
|
||||
|
||||
_SUPPORTED_PLATFORMS = {
|
||||
"wenxin", "kimi", "doubao", "tongyi",
|
||||
"qingyan", "tiangong", "xinghuo",
|
||||
}
|
||||
|
||||
|
||||
async def execute_single_platform(
|
||||
keyword: str,
|
||||
platform: str,
|
||||
target_brand: str,
|
||||
brand_aliases: list,
|
||||
) -> dict:
|
||||
if platform not in _SUPPORTED_PLATFORMS:
|
||||
raise ValueError(f"不支持的平台: {platform}")
|
||||
|
||||
from app.workers.citation_extractor import analyze_citations
|
||||
|
||||
search_keyword = f"{keyword} {target_brand}"
|
||||
raw_response = await query_platform_raw(
|
||||
platform_name=platform,
|
||||
keyword=search_keyword,
|
||||
brand_name=target_brand,
|
||||
)
|
||||
|
||||
citation_analysis = analyze_citations(raw_response)
|
||||
|
||||
from app.workers.citation_engine import BrandMatcher, CompetitorDetector
|
||||
|
||||
matcher = BrandMatcher(target_brand=target_brand, brand_aliases=brand_aliases)
|
||||
match_result = matcher.match(citation_analysis.clean_response)
|
||||
|
||||
competitor_detector = CompetitorDetector()
|
||||
competitor_brands = competitor_detector.detect(
|
||||
citation_analysis.clean_response, target_brand
|
||||
)
|
||||
|
||||
source_urls = [
|
||||
c.source_url for c in citation_analysis.citations if c.source_url
|
||||
]
|
||||
source_titles = [
|
||||
c.source_title for c in citation_analysis.citations if c.source_title
|
||||
]
|
||||
citation_contexts = [
|
||||
c.citation_context for c in citation_analysis.citations if c.citation_context
|
||||
]
|
||||
|
||||
return {
|
||||
"cited": match_result["cited"],
|
||||
"confidence": match_result["confidence"],
|
||||
"match_type": match_result["match_type"],
|
||||
"position": match_result["position"],
|
||||
"citation_text": match_result["citation_text"],
|
||||
"competitor_brands": competitor_brands,
|
||||
"raw_response": raw_response,
|
||||
"data_source": citation_analysis.data_source,
|
||||
"source_urls": source_urls,
|
||||
"source_titles": source_titles,
|
||||
"citation_contexts": citation_contexts,
|
||||
"ai_response_text": citation_analysis.clean_response,
|
||||
}
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
from .alert_engine import (
|
||||
AlertEngine,
|
||||
AlertContext,
|
||||
DEFAULT_ALERT_CONFIGS,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AlertEngine",
|
||||
"AlertContext",
|
||||
"DEFAULT_ALERT_CONFIGS",
|
||||
]
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
from .sentiment_service import (
|
||||
SentimentAnalysisService,
|
||||
SentimentResult,
|
||||
SentimentCache,
|
||||
get_sentiment_service,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SentimentAnalysisService",
|
||||
"SentimentResult",
|
||||
"SentimentCache",
|
||||
"get_sentiment_service",
|
||||
]
|
||||
|
|
@ -5,11 +5,11 @@ import asyncio
|
|||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from app.config import settings
|
||||
from app.utils.json_extractor import extract_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -276,31 +276,11 @@ class SentimentAnalysisService:
|
|||
raise RuntimeError("API返回空响应")
|
||||
|
||||
# 提取JSON
|
||||
json_str = self._extract_json(content)
|
||||
return json.loads(json_str)
|
||||
|
||||
def _extract_json(self, text: str) -> str:
|
||||
"""从文本中提取JSON"""
|
||||
# 尝试直接解析
|
||||
try:
|
||||
json.loads(text)
|
||||
return text
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试从代码块中提取
|
||||
json_pattern = r"```(?:json)?\s*([\s\S]*?)\s*```"
|
||||
match = re.search(json_pattern, text)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
# 尝试找到第一个{到最后一个}之间的内容
|
||||
first_brace = text.find("{")
|
||||
last_brace = text.rfind("}")
|
||||
if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
|
||||
return text[first_brace : last_brace + 1]
|
||||
|
||||
raise RuntimeError(f"无法从响应中提取JSON: {text[:200]}")
|
||||
json_str = extract_json(content)
|
||||
except ValueError as e:
|
||||
raise RuntimeError(str(e)) from e
|
||||
return json.loads(json_str)
|
||||
|
||||
def _parse_response(self, response: dict) -> SentimentResult:
|
||||
"""解析API响应"""
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
from app.services.attribution.attribution_engine import AttributionEngine
|
||||
from app.services.attribution.roi_calculator import ROICalculator
|
||||
|
||||
__all__ = ["AttributionEngine", "ROICalculator"]
|
||||
|
|
@ -0,0 +1,150 @@
|
|||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.attribution_record import AttributionRecord
|
||||
from app.models.diagnosis_record import DiagnosisRecord
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AttributionEngine:
|
||||
async def start_tracking(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
brand_id: uuid.UUID,
|
||||
content_id: uuid.UUID | None,
|
||||
user_id: str,
|
||||
) -> AttributionRecord:
|
||||
baseline_score = await self._get_latest_score(db, brand_id)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
record = AttributionRecord(
|
||||
user_id=user_id,
|
||||
brand_id=brand_id,
|
||||
content_id=content_id,
|
||||
baseline_score=baseline_score,
|
||||
published_at=now,
|
||||
window_end_at=now + timedelta(days=28),
|
||||
status="tracking",
|
||||
)
|
||||
db.add(record)
|
||||
await db.commit()
|
||||
await db.refresh(record)
|
||||
return record
|
||||
|
||||
async def check_attribution(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
record_id: uuid.UUID,
|
||||
) -> AttributionRecord:
|
||||
stmt = select(AttributionRecord).where(AttributionRecord.id == record_id)
|
||||
result = await db.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
raise ValueError(f"AttributionRecord {record_id} not found")
|
||||
|
||||
current_score = await self._get_latest_score(db, record.brand_id)
|
||||
record.current_score = current_score
|
||||
record.score_delta = round(current_score - record.baseline_score, 2)
|
||||
|
||||
baseline_dims = await self._get_latest_dimensions(db, record.brand_id, record.published_at)
|
||||
current_dims = await self._get_latest_dimensions(db, record.brand_id, None)
|
||||
if baseline_dims and current_dims:
|
||||
record.attributed_dimensions = self._compute_dimension_deltas(
|
||||
baseline_dims, current_dims
|
||||
)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
if record.window_end_at:
|
||||
window_end = record.window_end_at
|
||||
if window_end.tzinfo is None:
|
||||
window_end = window_end.replace(tzinfo=UTC)
|
||||
if now >= window_end:
|
||||
record.status = "completed"
|
||||
elif record.score_delta and record.score_delta > 0:
|
||||
record.status = "tracking"
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(record)
|
||||
return record
|
||||
|
||||
async def get_brand_attribution_summary(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
brand_id: uuid.UUID,
|
||||
) -> dict:
|
||||
stmt = (
|
||||
select(AttributionRecord)
|
||||
.where(AttributionRecord.brand_id == brand_id)
|
||||
.order_by(AttributionRecord.created_at.desc())
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
records = result.scalars().all()
|
||||
|
||||
total_delta = sum(r.score_delta or 0 for r in records)
|
||||
tracking_count = sum(1 for r in records if r.status == "tracking")
|
||||
completed_count = sum(1 for r in records if r.status == "completed")
|
||||
positive_count = sum(1 for r in records if (r.score_delta or 0) > 0)
|
||||
|
||||
return {
|
||||
"brand_id": str(brand_id),
|
||||
"records": records,
|
||||
"total_score_delta": round(total_delta, 2),
|
||||
"tracking_count": tracking_count,
|
||||
"completed_count": completed_count,
|
||||
"positive_count": positive_count,
|
||||
}
|
||||
|
||||
async def _get_latest_score(self, db: AsyncSession, brand_id: uuid.UUID) -> float:
|
||||
stmt = (
|
||||
select(DiagnosisRecord)
|
||||
.where(
|
||||
DiagnosisRecord.brand_id == brand_id,
|
||||
DiagnosisRecord.status == "completed",
|
||||
)
|
||||
.order_by(DiagnosisRecord.completed_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
if record and record.overall_score is not None:
|
||||
return float(record.overall_score)
|
||||
logger.warning("No completed DiagnosisRecord for brand %s, using 0 as baseline", brand_id)
|
||||
return 0.0
|
||||
|
||||
async def _get_latest_dimensions(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
brand_id: uuid.UUID,
|
||||
before: datetime | None,
|
||||
) -> dict | None:
|
||||
stmt = (
|
||||
select(DiagnosisRecord)
|
||||
.where(
|
||||
DiagnosisRecord.brand_id == brand_id,
|
||||
DiagnosisRecord.status == "completed",
|
||||
)
|
||||
.order_by(DiagnosisRecord.completed_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
if before:
|
||||
stmt = stmt.where(DiagnosisRecord.completed_at < before)
|
||||
result = await db.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
if record and record.result_json:
|
||||
return record.result_json.get("dimensions")
|
||||
return None
|
||||
|
||||
def _compute_dimension_deltas(self, before_dims: list, after_dims: list) -> dict:
|
||||
before_map = {d.get("name"): d.get("score", 0) for d in before_dims}
|
||||
after_map = {d.get("name"): d.get("score", 0) for d in after_dims}
|
||||
deltas = {}
|
||||
for name in after_map:
|
||||
b = before_map.get(name, 0)
|
||||
a = after_map[name]
|
||||
deltas[name] = {"before": b, "after": a, "delta": round(a - b, 2)}
|
||||
return deltas
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
from app.models.attribution_record import AttributionRecord
|
||||
|
||||
|
||||
class ROICalculator:
|
||||
INDUSTRY_AVG_CITATION_VALUE = 50.0
|
||||
|
||||
def calculate_roi(
|
||||
self,
|
||||
subscription_cost: float,
|
||||
score_delta: float,
|
||||
attribution_records: list[AttributionRecord],
|
||||
) -> dict:
|
||||
value_generated = score_delta * self.INDUSTRY_AVG_CITATION_VALUE
|
||||
if subscription_cost > 0:
|
||||
roi_percentage = round(
|
||||
(value_generated - subscription_cost) / subscription_cost * 100, 2
|
||||
)
|
||||
else:
|
||||
roi_percentage = 0.0
|
||||
break_even_delta = self.estimate_break_even(subscription_cost)
|
||||
return {
|
||||
"roi_percentage": roi_percentage,
|
||||
"value_generated": round(value_generated, 2),
|
||||
"cost": subscription_cost,
|
||||
"break_even_delta": round(break_even_delta, 2),
|
||||
}
|
||||
|
||||
def generate_ab_comparison(
|
||||
self,
|
||||
before_score: float,
|
||||
after_score: float,
|
||||
before_dimensions: dict,
|
||||
after_dimensions: dict,
|
||||
) -> dict:
|
||||
overall_delta = round(after_score - before_score, 2)
|
||||
dimensions = []
|
||||
all_names = set(list(before_dimensions.keys()) + list(after_dimensions.keys()))
|
||||
for name in all_names:
|
||||
b = before_dimensions.get(name, {}).get("score", 0)
|
||||
a = after_dimensions.get(name, {}).get("score", 0)
|
||||
delta = round(a - b, 2)
|
||||
dimensions.append({
|
||||
"name": name,
|
||||
"before": b,
|
||||
"after": a,
|
||||
"delta": delta,
|
||||
"improved": delta > 0,
|
||||
})
|
||||
return {
|
||||
"overall_before": before_score,
|
||||
"overall_after": after_score,
|
||||
"overall_delta": overall_delta,
|
||||
"dimensions": dimensions,
|
||||
}
|
||||
|
||||
def estimate_break_even(self, subscription_cost: float) -> float:
|
||||
if self.INDUSTRY_AVG_CITATION_VALUE == 0:
|
||||
return 0.0
|
||||
return subscription_cost / self.INDUSTRY_AVG_CITATION_VALUE
|
||||
|
|
@ -71,6 +71,8 @@ async def register_user(db: AsyncSession, user_data: UserRegister) -> User:
|
|||
email=user_data.email,
|
||||
password=hash_password(user_data.password),
|
||||
username=user_data.name,
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
from .citation import (
|
||||
get_citations,
|
||||
get_citation_stats,
|
||||
trigger_query_now,
|
||||
export_citations_pdf,
|
||||
export_citations_csv,
|
||||
PLATFORM_NAMES,
|
||||
)
|
||||
|
||||
from .citation_pattern import (
|
||||
CitationPatternEngine,
|
||||
CitationPattern,
|
||||
PatternAnalysisReport,
|
||||
ContentStructureAnalyzer,
|
||||
AuthoritySignalAnalyzer,
|
||||
CitationFormatAnalyzer,
|
||||
EnginePreferenceAnalyzer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_citations",
|
||||
"get_citation_stats",
|
||||
"trigger_query_now",
|
||||
"export_citations_pdf",
|
||||
"export_citations_csv",
|
||||
"PLATFORM_NAMES",
|
||||
"CitationPatternEngine",
|
||||
"CitationPattern",
|
||||
"PatternAnalysisReport",
|
||||
"ContentStructureAnalyzer",
|
||||
"AuthoritySignalAnalyzer",
|
||||
"CitationFormatAnalyzer",
|
||||
"EnginePreferenceAnalyzer",
|
||||
]
|
||||
|
|
@ -1,9 +1,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.scoring.scoring_service import ScoringResultV2
|
||||
|
||||
from sqlalchemy import func, select, and_, cast, Integer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -13,7 +19,7 @@ from app.database import AsyncSessionLocal
|
|||
from app.models.citation_record import CitationRecord
|
||||
from app.models.query import Query
|
||||
from app.models.query_task import QueryTask
|
||||
from app.workers.citation_engine import CitationEngine
|
||||
from app.services.ai_engine.platform_bridge import execute_single_platform as _execute_single_platform_bridge
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -304,8 +310,10 @@ async def _execute_query_tasks(
|
|||
brand_aliases: list,
|
||||
user_id: uuid.UUID | None = None,
|
||||
):
|
||||
"""后台执行查询任务"""
|
||||
engine = CitationEngine()
|
||||
"""后台执行查询任务 — 通过 Agent 框架执行,失败时回退到直接引擎"""
|
||||
from app.agent_framework.agents.citation_detector import CitationDetectorAgent
|
||||
|
||||
agent = CitationDetectorAgent()
|
||||
try:
|
||||
async with AsyncSessionLocal() as db:
|
||||
# 验证 query 归属该用户
|
||||
|
|
@ -330,7 +338,8 @@ async def _execute_query_tasks(
|
|||
task.error_message = None
|
||||
await db.commit()
|
||||
|
||||
citation_result = await engine.execute_single_platform(
|
||||
citation_result = await _execute_single_platform_via_agent(
|
||||
agent=agent,
|
||||
keyword=keyword,
|
||||
platform=task.platform,
|
||||
target_brand=target_brand,
|
||||
|
|
@ -338,16 +347,10 @@ async def _execute_query_tasks(
|
|||
)
|
||||
|
||||
if citation_result:
|
||||
record = CitationRecord(
|
||||
record = CitationRecord.from_citation_result(
|
||||
query_id=query_id,
|
||||
platform=task.platform,
|
||||
cited=citation_result.get("cited", False),
|
||||
citation_position=citation_result.get("position"),
|
||||
citation_text=citation_result.get("citation_text"),
|
||||
competitor_brands=citation_result.get("competitor_brands", []),
|
||||
raw_response=citation_result.get("raw_response", ""),
|
||||
confidence=citation_result.get("confidence"),
|
||||
match_type=citation_result.get("match_type"),
|
||||
result=citation_result,
|
||||
)
|
||||
db.add(record)
|
||||
|
||||
|
|
@ -366,7 +369,34 @@ async def _execute_query_tasks(
|
|||
except Exception as e:
|
||||
logger.error(f"查询引擎执行失败: {e}")
|
||||
finally:
|
||||
await engine.close()
|
||||
await agent.close()
|
||||
|
||||
|
||||
async def _execute_single_platform_via_agent(
|
||||
agent,
|
||||
keyword: str,
|
||||
platform: str,
|
||||
target_brand: str,
|
||||
brand_aliases: list,
|
||||
) -> dict:
|
||||
try:
|
||||
return await agent.execute_single_platform_compat(
|
||||
keyword=keyword,
|
||||
platform=platform,
|
||||
target_brand=target_brand,
|
||||
brand_aliases=brand_aliases,
|
||||
)
|
||||
except Exception as agent_err:
|
||||
logger.warning(
|
||||
f"Agent 框架执行单平台检测失败 ({platform}): {agent_err},"
|
||||
"回退到直接引擎"
|
||||
)
|
||||
return await _execute_single_platform_bridge(
|
||||
keyword=keyword,
|
||||
platform=platform,
|
||||
target_brand=target_brand,
|
||||
brand_aliases=brand_aliases,
|
||||
)
|
||||
|
||||
|
||||
PLATFORM_NAMES = {
|
||||
|
|
@ -386,6 +416,7 @@ async def export_citations_pdf(
|
|||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
query_id: uuid.UUID | None = None,
|
||||
v2_result: ScoringResultV2 | None = None,
|
||||
) -> bytes:
|
||||
"""生成PDF格式报告"""
|
||||
import os
|
||||
|
|
@ -505,6 +536,20 @@ async def export_citations_pdf(
|
|||
pdf.cell(col_widths[i], 7, d, border=1)
|
||||
pdf.ln()
|
||||
|
||||
if v2_result is not None:
|
||||
pdf.add_page()
|
||||
pdf.set_font_size(16)
|
||||
pdf.cell(0, 12, "四、V2 品牌可见性评分", new_x="LMARGIN", new_y="NEXT")
|
||||
pdf.set_font_size(11)
|
||||
pdf.cell(0, 8, f"综合评分: {v2_result.overall_score:.2f}/100", new_x="LMARGIN", new_y="NEXT")
|
||||
pdf.cell(0, 8, f"健康等级: {v2_result.health_level}", new_x="LMARGIN", new_y="NEXT")
|
||||
pdf.ln(5)
|
||||
pdf.cell(0, 8, f"提及率: {v2_result.mention_rate.score:.2f}/{v2_result.mention_rate.max_score:.0f} ({v2_result.mention_rate.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT")
|
||||
pdf.cell(0, 8, f"推荐排名: {v2_result.recommendation_rank.score:.2f}/{v2_result.recommendation_rank.max_score:.0f} ({v2_result.recommendation_rank.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT")
|
||||
pdf.cell(0, 8, f"情感倾向: {v2_result.sentiment_score.score:.2f}/{v2_result.sentiment_score.max_score:.0f} ({v2_result.sentiment_score.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT")
|
||||
pdf.cell(0, 8, f"引用质量: {v2_result.citation_quality.score:.2f}/{v2_result.citation_quality.max_score:.0f} ({v2_result.citation_quality.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT")
|
||||
pdf.cell(0, 8, f"竞品对比: {v2_result.competitive_position.score:.2f}/{v2_result.competitive_position.max_score:.0f} ({v2_result.competitive_position.percentage:.1f}%)", new_x="LMARGIN", new_y="NEXT")
|
||||
|
||||
return pdf.output()
|
||||
|
||||
|
||||
|
|
@ -512,6 +557,7 @@ async def export_citations_csv(
|
|||
db: AsyncSession,
|
||||
user_id: uuid.UUID,
|
||||
query_id: uuid.UUID,
|
||||
v2_result: ScoringResultV2 | None = None,
|
||||
) -> str:
|
||||
query = await _verify_query_ownership(db, query_id, user_id)
|
||||
if query is None:
|
||||
|
|
@ -527,7 +573,7 @@ async def export_citations_csv(
|
|||
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output)
|
||||
writer.writerow([
|
||||
headers = [
|
||||
"查询关键词",
|
||||
"目标品牌",
|
||||
"查询日期",
|
||||
|
|
@ -538,7 +584,18 @@ async def export_citations_csv(
|
|||
"匹配置信度",
|
||||
"匹配类型",
|
||||
"竞争品牌",
|
||||
])
|
||||
]
|
||||
if v2_result is not None:
|
||||
headers.extend([
|
||||
"overall_score",
|
||||
"health_level",
|
||||
"mention_rate",
|
||||
"recommendation_rank",
|
||||
"sentiment_score",
|
||||
"citation_quality",
|
||||
"competitive_position",
|
||||
])
|
||||
writer.writerow(headers)
|
||||
|
||||
total_queries = len(records)
|
||||
total_citations = 0
|
||||
|
|
@ -570,7 +627,7 @@ async def export_citations_csv(
|
|||
if record.confidence is not None:
|
||||
confidence_str = f"{record.confidence:.2f}"
|
||||
|
||||
writer.writerow([
|
||||
row = [
|
||||
query.keyword,
|
||||
query.target_brand,
|
||||
date_str,
|
||||
|
|
@ -581,7 +638,18 @@ async def export_citations_csv(
|
|||
confidence_str,
|
||||
match_type_display,
|
||||
", ".join(record.competitor_brands) if record.competitor_brands else "",
|
||||
])
|
||||
]
|
||||
if v2_result is not None:
|
||||
row.extend([
|
||||
round(v2_result.overall_score, 2),
|
||||
v2_result.health_level,
|
||||
round(v2_result.mention_rate.score, 2),
|
||||
round(v2_result.recommendation_rank.score, 2),
|
||||
round(v2_result.sentiment_score.score, 2),
|
||||
round(v2_result.citation_quality.score, 2),
|
||||
round(v2_result.competitive_position.score, 2),
|
||||
])
|
||||
writer.writerow(row)
|
||||
|
||||
# 汇总统计
|
||||
writer.writerow([])
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue