feat: P0-P2功能实现 + GEO workflow分析与规划
P0 紧急修复: - 实现诊断分析页面(SEO+GEO 5+6维度) - 修复E2E测试: dashboard标题统一为'品牌健康中心' - 修复建议模块API路径不一致 - 修复告警模块HTTP方法不匹配(POST→PATCH) P1 功能实现: - 实现监测优化页面(告警列表+配置) - 实现组织管理页面(成员/角色/邀请) - 实现SEO诊断5维度后端服务(63测试) - 实现GEO诊断6维度后端服务(40测试) - 实现诊断API端点(TDD, 6测试) - 前端诊断页面连接真实API P2 功能实现: - 实现告警设置API端点(TDD, 11测试) - 实现套餐额度预警服务(TDD, 37测试) - 实现邮件通知服务(TDD, 30测试) GEO Workflow分析: - 10步闭环流程设计(替代原7步) - 7个缺失节点技术方案 - 4阶段12周开发计划 - 完整技术架构设计
This commit is contained in:
parent
cbedb09383
commit
65e2f3c380
|
|
@ -0,0 +1,16 @@
|
||||||
|
# CodeGraph data files
|
||||||
|
# These are local to each machine and should not be committed
|
||||||
|
|
||||||
|
# Database
|
||||||
|
*.db
|
||||||
|
*.db-wal
|
||||||
|
*.db-shm
|
||||||
|
|
||||||
|
# Cache
|
||||||
|
cache/
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# Hook markers
|
||||||
|
.dirty
|
||||||
|
|
@ -36,6 +36,12 @@ ZHIPU_API_KEY=
|
||||||
# 通义千问 (可选)
|
# 通义千问 (可选)
|
||||||
TONGYI_API_KEY=
|
TONGYI_API_KEY=
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 阿里云百炼(图片生成)
|
||||||
|
# ============================================================
|
||||||
|
# 万相-文生图V1 API Key
|
||||||
|
ALIYUN_DASHSCOPE_API_KEY=
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# LLM Provider 配置
|
# LLM Provider 配置
|
||||||
# ============================================================
|
# ============================================================
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
DATABASE_URL=sqlite+aiosqlite:///./test.db
|
||||||
|
REDIS_URL=redis://localhost:6379
|
||||||
|
ENVIRONMENT=testing
|
||||||
|
LOG_LEVEL=info
|
||||||
|
SECRET_KEY=test-secret-key-for-testing-only
|
||||||
|
CORS_ORIGINS=http://localhost:3000
|
||||||
|
|
@ -0,0 +1,143 @@
|
||||||
|
"""Add knowledge graph tables
|
||||||
|
|
||||||
|
Revision ID: f7a8b9c0de56
|
||||||
|
Revises: e5f7g9h1cd45
|
||||||
|
Create Date: 2026-05-24 12:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
import enum
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "f7a8b9c0de56"
|
||||||
|
down_revision: Union[str, None] = "e5f7g9h1cd45"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Create the knowledge-graph tables and their indexes.

    Two tables are created, in dependency order:

    * ``knowledge_entities`` - one row per extracted entity, attached to a
      knowledge base and optionally to the chunk it was extracted from.
    * ``knowledge_relations`` - a directed edge between two entities.

    The ``entitytype`` / ``relationtype`` enum columns are declared through
    ``sa.Enum`` from the Python enums built below.
    """
    # -- Enum definitions ------------------------------------------------- #
    entity_kinds = enum.Enum(
        'EntityType',
        [
            'ORGANIZATION', 'PRODUCT', 'PERSON', 'LOCATION',
            'TECHNOLOGY', 'BRAND', 'EVENT', 'CONCEPT', 'OTHER'
        ],
    )
    relation_kinds = enum.Enum(
        'RelationType',
        [
            'COMPETES_WITH', 'PARTNERS_WITH', 'ACQUIRES', 'SUBSIDIARY_OF',
            'PRODUCES', 'USES_TECHNOLOGY', 'PART_OF',
            'LOCATED_IN', 'FOUNDED_IN',
            'CEO_OF', 'FOUNDER_OF',
            'RELATED_TO', 'MENTIONED_IN', 'ALSO_KNOWN_AS'
        ],
    )

    # -- Column builders --------------------------------------------------- #
    # Each call returns a brand-new Column: a Column object can only be
    # attached to one table, so the shared definitions are factories.
    def uuid_column(column_name, *constraints, **options):
        return sa.Column(
            column_name,
            postgresql.UUID(as_uuid=True),
            *constraints,
            **options,
        )

    def primary_key_column():
        return uuid_column("id", primary_key=True, nullable=False)

    def entity_reference(column_name):
        return uuid_column(
            column_name,
            sa.ForeignKey("knowledge_entities.id", ondelete="CASCADE"),
            nullable=False,
        )

    def properties_column():
        return sa.Column(
            "properties",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=True,
            server_default="{}",
        )

    def source_chunk_column():
        return uuid_column(
            "source_chunk_id",
            sa.ForeignKey("knowledge_chunks.id", ondelete="SET NULL"),
            nullable=True,
        )

    def confidence_column():
        return sa.Column("confidence", sa.String(20), nullable=True)

    def timestamp_column(column_name):
        return sa.Column(
            column_name,
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        )

    # -- knowledge_entities ------------------------------------------------ #
    op.create_table(
        "knowledge_entities",
        primary_key_column(),
        uuid_column(
            "knowledge_base_id",
            sa.ForeignKey("knowledge_bases.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("name", sa.String(500), nullable=False),
        sa.Column("entity_type", sa.Enum(entity_kinds, name="entitytype"), nullable=False),
        sa.Column("description", sa.Text, nullable=True),
        properties_column(),
        source_chunk_column(),
        confidence_column(),
        timestamp_column("created_at"),
        timestamp_column("updated_at"),
    )
    entity_indexes = (
        ("ix_entities_name", ["name"]),
        ("ix_entities_kb_name", ["knowledge_base_id", "name"]),
        ("ix_entities_kb_type", ["knowledge_base_id", "entity_type"]),
    )
    for index_name, indexed_columns in entity_indexes:
        op.create_index(index_name, "knowledge_entities", indexed_columns)

    # -- knowledge_relations ----------------------------------------------- #
    op.create_table(
        "knowledge_relations",
        primary_key_column(),
        entity_reference("source_entity_id"),
        entity_reference("target_entity_id"),
        sa.Column("relation_type", sa.Enum(relation_kinds, name="relationtype"), nullable=False),
        properties_column(),
        source_chunk_column(),
        confidence_column(),
        timestamp_column("created_at"),
    )
    relation_indexes = (
        ("ix_relations_source", ["source_entity_id"]),
        ("ix_relations_target", ["target_entity_id"]),
        ("ix_relations_type", ["relation_type"]),
    )
    for index_name, indexed_columns in relation_indexes:
        op.create_index(index_name, "knowledge_relations", indexed_columns)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Drop the knowledge-graph tables, then the enum types they used."""
    # Relations hold foreign keys into entities, so they are dropped first.
    for table_name in ("knowledge_relations", "knowledge_entities"):
        op.drop_table(table_name)

    # The enum types are removed explicitly once no table uses them.
    for type_name in ("relationtype", "entitytype"):
        op.execute(f"DROP TYPE IF EXISTS {type_name}")
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
"""Alerts API endpoints - 告警通知接口"""
|
"""Alerts API endpoints - 告警通知接口"""
|
||||||
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
|
@ -9,6 +10,7 @@ from app.api.deps import get_current_user
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.models.alert import Alert
|
from app.models.alert import Alert
|
||||||
from app.models.alert_setting import AlertSetting
|
from app.models.alert_setting import AlertSetting
|
||||||
|
from app.models.brand import Brand
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.schemas.alert import (
|
from app.schemas.alert import (
|
||||||
AlertResponse,
|
AlertResponse,
|
||||||
|
|
@ -24,9 +26,35 @@ from app.schemas.alert import (
|
||||||
)
|
)
|
||||||
from app.services.alert_engine import AlertEngine
|
from app.services.alert_engine import AlertEngine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_brand_ownership(
    brand_id: uuid.UUID,
    current_user: User,
    db: AsyncSession,
) -> Brand:
    """Return the brand ``brand_id`` if it is owned by ``current_user``.

    Args:
        brand_id: Primary key of the brand to look up.
        current_user: The authenticated user; its ``id`` must match the
            brand's ``user_id``.
        db: Async session used for the lookup.

    Returns:
        The matching ``Brand`` row.

    Raises:
        HTTPException: 404 when no brand matches both the id and the owner.
    """
    owned_by_caller = and_(
        Brand.id == brand_id,
        Brand.user_id == current_user.id,
    )
    lookup = await db.execute(select(Brand).where(owned_by_caller))
    brand = lookup.scalar_one_or_none()
    if brand:
        return brand

    # A missing brand and a brand owned by someone else are reported the same way.
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"品牌 {brand_id} 不存在或不属于当前用户",
    )
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# 告警接口
|
# 告警接口
|
||||||
# ============================================================
|
# ============================================================
|
||||||
|
|
@ -207,21 +235,7 @@ async def update_alert_settings(
|
||||||
|
|
||||||
for item in data.settings:
|
for item in data.settings:
|
||||||
# 验证品牌属于当前用户
|
# 验证品牌属于当前用户
|
||||||
from app.models.brand import Brand
|
await verify_brand_ownership(item.brand_id, current_user, db)
|
||||||
brand_stmt = select(Brand).where(
|
|
||||||
and_(
|
|
||||||
Brand.id == item.brand_id,
|
|
||||||
Brand.user_id == current_user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
brand_result = await db.execute(brand_stmt)
|
|
||||||
brand = brand_result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if not brand:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"品牌 {item.brand_id} 不存在或不属于当前用户",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 查找现有设置
|
# 查找现有设置
|
||||||
existing_stmt = select(AlertSetting).where(
|
existing_stmt = select(AlertSetting).where(
|
||||||
|
|
@ -257,6 +271,7 @@ async def update_alert_settings(
|
||||||
for setting in updated_settings:
|
for setting in updated_settings:
|
||||||
await db.refresh(setting)
|
await db.refresh(setting)
|
||||||
|
|
||||||
|
logger.info(f"批量更新告警设置: user={current_user.id}, count={len(updated_settings)}")
|
||||||
return {"items": updated_settings, "total": len(updated_settings)}
|
return {"items": updated_settings, "total": len(updated_settings)}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -292,3 +307,74 @@ async def update_single_setting(
|
||||||
await db.refresh(setting)
|
await db.refresh(setting)
|
||||||
|
|
||||||
return setting
|
return setting
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/settings", response_model=AlertSettingResponse, status_code=status.HTTP_201_CREATED)
async def create_alert_setting(
    data: AlertSettingCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create one alert setting for a brand owned by the caller.

    Raises:
        HTTPException: 404 (from ``verify_brand_ownership``) when the brand is
            missing or not owned by the caller; 400 when a setting with the
            same brand and alert type already exists.
    """
    # Ownership check first: it raises before anything is read or written.
    await verify_brand_ownership(data.brand_id, current_user, db)

    # Reject a second setting for the same (brand, alert type) pair.
    same_brand_and_type = and_(
        AlertSetting.brand_id == data.brand_id,
        AlertSetting.alert_type == data.alert_type,
    )
    duplicate_lookup = await db.execute(
        select(AlertSetting).where(same_brand_and_type)
    )
    if duplicate_lookup.scalar_one_or_none():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"品牌 {data.brand_id} 的告警类型 {data.alert_type} 已存在",
        )

    new_setting = AlertSetting(
        brand_id=data.brand_id,
        user_id=current_user.id,
        alert_type=data.alert_type,
        enabled=data.enabled,
        threshold=data.threshold,
    )
    db.add(new_setting)
    await db.commit()
    # Reload the row after the commit before handing it to the response model.
    await db.refresh(new_setting)

    logger.info(f"创建告警设置: user={current_user.id}, brand={data.brand_id}, type={data.alert_type}")
    return new_setting
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/settings/{setting_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_alert_setting(
    setting_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete an alert setting that belongs to the caller.

    Raises:
        HTTPException: 404 when no setting matches both ``setting_id`` and
            the caller's user id.
    """
    # Filtering on user_id means another user's setting looks like "not found".
    owned_setting = and_(
        AlertSetting.id == setting_id,
        AlertSetting.user_id == current_user.id,
    )
    lookup = await db.execute(select(AlertSetting).where(owned_setting))
    target = lookup.scalar_one_or_none()

    if target:
        await db.delete(target)
        await db.commit()

        logger.info(f"删除告警设置: user={current_user.id}, setting={setting_id}")
        return None

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="告警设置不存在",
    )
|
||||||
|
|
|
||||||
|
|
@ -237,4 +237,165 @@ async def generate_topics(
|
||||||
|
|
||||||
return {"status": "success", "topics": topics}
|
return {"status": "success", "topics": topics}
|
||||||
except LLMError as e:
|
except LLMError as e:
|
||||||
raise HTTPException(status_code=502, detail=str(e))
|
raise HTTPException(status_code=502, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 母题库接口 ====================
|
||||||
|
|
||||||
|
class TopicGenerateRequest(BaseModel):
    """Request body for generating content from a topic template."""

    # Values substituted into the topic template; the generate endpoint checks
    # that every name in the template's ``required_params`` is present here.
    params: dict
    # Target platform; passed to the GEO optimisation step and stored on the
    # created content. Defaults to the generic platform label.
    platform: str = "通用"
    # Writing style label; stored in the created content's metadata.
    style: str = "专业严谨"
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/topics")
async def list_topics():
    """List every topic template as a summary dictionary.

    The summary omits the prompt template and SEO tips; those are returned by
    the single-topic endpoint.
    """
    from app.services.content.topic_templates import list_topic_templates

    summaries = []
    for template in list_topic_templates():
        summaries.append(
            {
                "id": template.id,
                "name": template.name,
                "description": template.description,
                "icon": template.icon,
                "recommended_platforms": template.recommended_platforms,
                # Converted to a list for the JSON response.
                "word_count_range": list(template.word_count_range),
                "required_params": template.required_params,
                "optional_params": template.optional_params,
            }
        )
    return summaries
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/topics/{topic_id}")
async def get_topic(topic_id: str):
    """Return the full definition of one topic template.

    Raises:
        HTTPException: 404 when ``topic_id`` does not resolve to a template.
    """
    from app.services.content.topic_templates import get_topic_template

    found = get_topic_template(topic_id)
    if found:
        return {
            "id": found.id,
            "name": found.name,
            "description": found.description,
            "icon": found.icon,
            "prompt_template": found.prompt_template,
            "seo_tips": found.seo_tips,
            "recommended_platforms": found.recommended_platforms,
            # Converted to a list for the JSON response.
            "word_count_range": list(found.word_count_range),
            "required_params": found.required_params,
            "optional_params": found.optional_params,
        }

    raise HTTPException(status_code=404, detail="Topic not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/topics/{topic_id}/generate")
async def generate_with_topic(
    topic_id: str,
    request: TopicGenerateRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Generate an article from a topic template and store it as a draft.

    Pipeline: render the topic prompt -> first LLM draft -> "de-AI" rewrite ->
    GEO optimisation -> persist a ``Content`` row plus its first
    ``ContentVersion``.

    Args:
        topic_id: Identifier of the topic template to use.
        request: Template parameters plus target platform and style.
        db: Async database session.
        current_user: Authenticated user; must expose ``organization_id``.

    Returns:
        A dict with the topic id, the de-AI'd text (``content``), the
        GEO-optimised text (``optimized_content``), the new content id and
        the template's SEO tips.

    Raises:
        HTTPException: 404 unknown topic; 400 missing required parameter;
            403 user without organisation; 502 LLM failure; 500 anything else.
    """
    from app.services.content.topic_templates import get_topic_template, render_topic_prompt
    from app.services.llm import LLMError, LLMFactory
    from app.agent_framework.prompts import DEAI_TEMPLATE, GEO_OPTIMIZER_TEMPLATE

    # Lower bound for the rewrite passes. Their budget is derived from the
    # length of the previous step's text, which would otherwise yield
    # max_tokens=0 for an empty reply and a truncating budget for short ones.
    min_rewrite_tokens = 256

    def rewrite_budget(text: str) -> int:
        """Token budget for rewriting ``text``: twice its length, floored."""
        return max(len(text) * 2, min_rewrite_tokens)

    template = get_topic_template(topic_id)
    if not template:
        raise HTTPException(status_code=404, detail="Topic not found")

    # Validate required parameters before spending any LLM calls.
    for param in template.required_params:
        if param not in request.params:
            raise HTTPException(
                status_code=400,
                detail=f"Missing required parameter: {param}"
            )

    org_id = getattr(current_user, "organization_id", None)
    if not org_id:
        raise HTTPException(status_code=403, detail="用户未关联组织")

    # Store an empty keyword list rather than [""] when none were supplied.
    raw_keywords = request.params.get("keywords", "")
    keyword_list = [raw_keywords] if raw_keywords else []

    try:
        provider = LLMFactory.get_default()

        # Step 1: first draft from the rendered topic prompt.
        prompt = render_topic_prompt(topic_id, request.params)
        response = await provider.chat(
            [{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=4000
        )
        content = response.content

        # Step 2: "de-AI" rewrite of the draft.
        deai_variables = {
            "original_content": content,
            "target_style": "自然流畅",
            "preserve_structure": "是",
        }
        messages = DEAI_TEMPLATE.render(deai_variables)
        response = await provider.chat(
            messages, temperature=0.9, max_tokens=rewrite_budget(content)
        )
        content = response.content

        # Step 3: GEO optimisation of the rewritten text.
        geo_variables = {
            "original_content": content,
            "target_keywords": raw_keywords,
            "target_platform": request.platform,
            "optimization_level": "moderate",
        }
        messages = GEO_OPTIMIZER_TEMPLATE.render(geo_variables)
        response = await provider.chat(
            messages, temperature=0.5, max_tokens=rewrite_budget(content)
        )
        optimized = response.content

        # Step 4: persist the draft and its first version.
        content_obj = Content(
            organization_id=org_id,
            title=request.params.get("product_name") or request.params.get("topic") or topic_id,
            content_type="article",
            body=optimized,
            status="draft",
            target_platforms=[request.platform],
            keywords=keyword_list,
            extra_metadata={
                "original_content": content,
                "topic_id": topic_id,
                "topic_name": template.name,
                "brand_name": request.params.get("brand_name", ""),
                "content_style": request.style,
            },
            created_by=current_user.id,
            current_version=1,
        )
        db.add(content_obj)
        # Flush so content_obj.id is available for the version row.
        await db.flush()

        version = ContentVersion(
            content_id=content_obj.id,
            version_number=1,
            title=content_obj.title,
            body=optimized,
            change_summary="母题库自动生成",
            created_by=current_user.id,
        )
        db.add(version)
        await db.commit()
        await db.refresh(content_obj)

        return {
            "topic_id": topic_id,
            "content": content,
            "optimized_content": optimized,
            "content_id": str(content_obj.id),
            "seo_tips": template.seo_tips,
        }

    except LLMError as e:
        # Upstream model failure: report as a bad gateway, keeping the cause.
        raise HTTPException(status_code=502, detail=f"LLM调用失败: {str(e)}") from e
    except Exception as e:
        # NOTE(review): the detail echoes the raw exception text to the client;
        # kept for compatibility, but consider logging it server-side instead.
        raise HTTPException(status_code=500, detail=f"内容生成异常: {str(e)}") from e
|
||||||
|
|
@ -0,0 +1,150 @@
|
||||||
|
"""诊断API端点 - 提供SEO和GEO诊断功能"""
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.database import get_db
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.brand import Brand
|
||||||
|
from app.services.seo_diagnosis import SEODiagnosisService
|
||||||
|
from app.services.geo_diagnosis import GEODiagnosisService, GEODiagnosisInput
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/seo/{brand_id}")
async def get_seo_diagnosis(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Return the SEO diagnosis for a brand owned by the caller.

    The diagnosis covers five dimensions:
    - Technical SEO (25 points)
    - On-page SEO (20 points)
    - Content quality (20 points)
    - Backlink analysis (15 points)
    - User experience (20 points)

    Raises:
        HTTPException: 404 when the brand is not found for this user;
            500 when the diagnosis service fails.
    """
    # Resolved outside the try block so a 404 is not rewritten into a 500.
    brand = await _get_brand_or_404(brand_id, current_user, db)

    try:
        # NOTE(review): diagnose() is called without any brand data here.
        report = SEODiagnosisService().diagnose()
        logger.info(f"SEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={report.overall_score}")
        return report.to_dict()
    except Exception as e:
        # Log the full traceback, but give the client only a generic message.
        logger.exception(f"SEO诊断失败: brand_id={brand_id}, error={str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="SEO诊断服务异常,请稍后重试",
        )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/geo/{brand_id}")
async def get_geo_diagnosis(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Return the GEO diagnosis for a brand owned by the caller.

    The diagnosis covers six dimensions:
    - Content extractability (20 points)
    - Entity clarity (15 points)
    - E-E-A-T signals (20 points)
    - Schema markup (15 points)
    - Topical authority (15 points)
    - Citation readiness (15 points)

    Raises:
        HTTPException: 404 when the brand is not found for this user;
            500 when the diagnosis service fails.
    """
    # Resolved outside the try block so a 404 is not rewritten into a 500.
    brand = await _get_brand_or_404(brand_id, current_user, db)

    try:
        # NOTE(review): the input object is built with defaults only.
        diagnosis_input = GEODiagnosisInput()
        report = GEODiagnosisService().diagnose(diagnosis_input)
        logger.info(f"GEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={report.overall_score}")
        return report.to_dict()
    except Exception as e:
        # Log the full traceback, but give the client only a generic message.
        logger.exception(f"GEO诊断失败: brand_id={brand_id}, error={str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="GEO诊断服务异常,请稍后重试",
        )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/combined/{brand_id}")
async def get_combined_diagnosis(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Return the SEO and GEO diagnoses for a brand together with a combined score.

    The combined score is the mean of the two overall scores, rounded to two
    decimal places.

    Raises:
        HTTPException: 404 when the brand is not found for this user;
            500 when either diagnosis fails.
    """
    # Resolved outside the try block so a 404 is not rewritten into a 500.
    brand = await _get_brand_or_404(brand_id, current_user, db)

    try:
        seo_report = SEODiagnosisService().diagnose()
        geo_report = GEODiagnosisService().diagnose(GEODiagnosisInput())

        score_sum = seo_report.overall_score + geo_report.overall_score
        combined_score = round(score_sum / 2, 2)

        logger.info(
            f"综合诊断完成: brand_id={brand_id}, brand={brand.name}, "
            f"seo_score={seo_report.overall_score}, "
            f"geo_score={geo_report.overall_score}, "
            f"combined_score={combined_score}"
        )

        return {
            "seo_score": seo_report.overall_score,
            "geo_score": geo_report.overall_score,
            "combined_score": combined_score,
            "seo_diagnosis": seo_report.to_dict(),
            "geo_diagnosis": geo_report.to_dict(),
        }
    except Exception as e:
        # Log the full traceback, but give the client only a generic message.
        logger.exception(f"综合诊断失败: brand_id={brand_id}, error={str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="综合诊断服务异常,请稍后重试",
        )
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_brand_or_404(
    brand_id: uuid.UUID,
    current_user: User,
    db: AsyncSession,
) -> Brand:
    """Load the caller's brand by id, or raise a 404.

    The query filters on both the brand id and the caller's user id, so a
    brand owned by another user is reported as not found.
    """
    query = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
    brand = (await db.execute(query)).scalar_one_or_none()
    if brand:
        return brand

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="品牌不存在",
    )
|
||||||
|
|
@ -0,0 +1,182 @@
|
||||||
|
"""图片生成 API 路由"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, status
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.services.image_generator import (
|
||||||
|
IMAGE_STYLES,
|
||||||
|
LAYOUT_OPTIONS,
|
||||||
|
ImageGenerator,
|
||||||
|
ImageGenerationError,
|
||||||
|
PLATFORM_IMAGE_SPECS,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/image", tags=["图片生成"])
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateCoverRequest(BaseModel):
    """Request body for generating a cover (or inline) image."""

    # Required: article title the image is generated for.
    title: str = Field(..., description="文章标题")
    # Required: target platform; the endpoint rejects unsupported values.
    platform: str = Field(..., description="目标平台")
    # "cover" or "inline"; any other value is rejected by the endpoint.
    image_type: str = Field(default="cover", description="图片类型: cover/inline")
    # Style key; the endpoint checks it against IMAGE_STYLES.
    style: str = Field(default="modern", description="风格选项")
    # Layout key; the endpoint checks it against LAYOUT_OPTIONS.
    layout: str = Field(default="centered", description="排版选项")
    # Optional caller-supplied prompt, forwarded to the generator as-is.
    custom_prompt: Optional[str] = Field(default=None, description="自定义提示词")
|
||||||
|
|
||||||
|
|
||||||
|
class ImageResultResponse(BaseModel):
    """Response describing one generated image."""

    # All fields are copied from the generator's result object by the
    # generate-cover endpoint.
    url: str
    width: int
    height: int
    prompt: str
    platform: str
    task_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class ImageSpecsResponse(BaseModel):
    """Response carrying the image specifications of one platform."""

    # Platform identifier, echoed from the request path.
    platform: str
    # Specification mapping as returned by ImageGenerator.get_platform_specs.
    specs: dict
|
||||||
|
|
||||||
|
|
||||||
|
class StyleOption(BaseModel):
    """One selectable image style."""

    # Key of the style in IMAGE_STYLES.
    value: str
    # The "name" entry of that style.
    name: str
|
||||||
|
|
||||||
|
|
||||||
|
class LayoutOption(BaseModel):
    """One selectable image layout."""

    # Key of the layout in LAYOUT_OPTIONS.
    value: str
    # The "name" entry of that layout.
    name: str
|
||||||
|
|
||||||
|
|
||||||
|
class ImageConfigResponse(BaseModel):
    """Response listing the available image-generation options."""

    # Supported platform identifiers.
    platforms: list[str]
    # Selectable styles, one entry per IMAGE_STYLES key.
    styles: list[StyleOption]
    # Selectable layouts, one entry per LAYOUT_OPTIONS key.
    layouts: list[LayoutOption]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/generate-cover", response_model=ImageResultResponse)
async def generate_cover(request: GenerateCoverRequest):
    """Generate a cover image for an article.

    The request is validated (platform, style, layout, image type - in that
    order) before the generator is called.

    Args:
        request: Generation parameters.

    Returns:
        ImageResultResponse: URL and metadata of the generated image.

    Raises:
        HTTPException: 400 when a request option is unsupported; 500 when the
            generator raises ``ImageGenerationError``.
    """
    # NOTE(review): this endpoint declares no authentication dependency;
    # confirm whether that is intentional.

    def bad_request(detail: str) -> HTTPException:
        """Build the 400 response shared by every validation failure."""
        return HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=detail,
        )

    platforms = ImageGenerator.get_supported_platforms()
    if request.platform not in platforms:
        raise bad_request(
            f"不支持的平台: {request.platform},支持的平台: {', '.join(platforms)}"
        )

    if request.style not in IMAGE_STYLES:
        raise bad_request(
            f"不支持的风格: {request.style},支持的风格: {', '.join(IMAGE_STYLES.keys())}"
        )

    if request.layout not in LAYOUT_OPTIONS:
        raise bad_request(
            f"不支持的排版: {request.layout},支持的排版: {', '.join(LAYOUT_OPTIONS.keys())}"
        )

    if request.image_type not in ("cover", "inline"):
        raise bad_request("图片类型只支持: cover, inline")

    generator = ImageGenerator()

    try:
        generated = await generator.generate_cover(
            title=request.title,
            platform=request.platform,
            image_type=request.image_type,
            style=request.style,
            layout=request.layout,
            custom_prompt=request.custom_prompt,
        )
        return ImageResultResponse(
            url=generated.url,
            width=generated.width,
            height=generated.height,
            prompt=generated.prompt,
            platform=generated.platform,
            task_id=generated.task_id,
        )
    except ImageGenerationError as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"图片生成失败: {str(e)}",
        )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/platforms", response_model=list[str])
async def get_supported_platforms():
    """Return the platform identifiers the image generator supports."""
    platforms = ImageGenerator.get_supported_platforms()
    return platforms
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/platforms/{platform}/specs", response_model=ImageSpecsResponse)
async def get_platform_specs(platform: str):
    """Return the image specifications of one platform.

    Args:
        platform: Platform identifier.

    Returns:
        ImageSpecsResponse: The platform together with its specifications.

    Raises:
        HTTPException: 404 when the generator has no specs for ``platform``.
    """
    platform_specs = ImageGenerator.get_platform_specs(platform)
    # Only None means "unknown platform".
    if platform_specs is not None:
        return ImageSpecsResponse(platform=platform, specs=platform_specs)

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"不支持的平台: {platform}",
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/config", response_model=ImageConfigResponse)
async def get_image_config():
    """Return the options a client can choose from: platforms, styles, layouts."""
    style_options = []
    for style_key, style_info in IMAGE_STYLES.items():
        style_options.append(StyleOption(value=style_key, name=style_info["name"]))

    layout_options = []
    for layout_key, layout_info in LAYOUT_OPTIONS.items():
        layout_options.append(LayoutOption(value=layout_key, name=layout_info["name"]))

    return ImageConfigResponse(
        platforms=ImageGenerator.get_supported_platforms(),
        styles=style_options,
        layouts=layout_options,
    )
|
||||||
|
|
@ -29,10 +29,15 @@ from app.schemas.knowledge import (
|
||||||
KnowledgeBaseCreate,
|
KnowledgeBaseCreate,
|
||||||
KnowledgeBaseResponse,
|
KnowledgeBaseResponse,
|
||||||
KnowledgeSearchRequest,
|
KnowledgeSearchRequest,
|
||||||
|
RetrieveRequest,
|
||||||
SearchResponse,
|
SearchResponse,
|
||||||
SearchResultItem,
|
SearchResultItem,
|
||||||
|
UpdateDocumentRequest,
|
||||||
)
|
)
|
||||||
from app.services.knowledge import MockEmbedder, RAGService
|
from app.services.knowledge import MockEmbedder, RAGService
|
||||||
|
from app.services.knowledge.enhanced_rag import EnhancedRAG
|
||||||
|
from app.services.knowledge.incremental_index import IncrementalIndexService
|
||||||
|
from app.services.knowledge.chunker import ChunkerFactory
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
@ -499,3 +504,151 @@ async def knowledge_search(
|
||||||
total=len(items),
|
total=len(items),
|
||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/bases/{kb_id}/chunks/preview")
async def preview_chunks(
    kb_id: uuid.UUID,
    text: str,
    strategy: str = "recursive",
    chunk_size: int = 500,
):
    """Preview how a text would be split by a chunking strategy.

    Args:
        kb_id: Knowledge base id from the URL path (not used by the preview).
        text: Text to split.
        strategy: Strategy name handed to ``ChunkerFactory.create``.
        chunk_size: Requested chunk size. The "semantic" strategy uses 1.5x
            this value, truncated to an integer.

    Returns:
        dict: The strategy name, the number of previewed chunks (the preview
        is requested with ``max_chunks=10``), the preview itself and the
        list of strategies reported by ``ChunkerFactory.list_strategies``.

    Raises:
        HTTPException: 422 when ``chunk_size`` is not a positive integer.
    """
    # NOTE(review): unlike the other /bases/{kb_id} endpoints in this file,
    # this one has no auth dependency and never validates kb_id — confirm
    # that an unauthenticated preview is intended.
    if chunk_size <= 0:
        raise HTTPException(
            status_code=422,
            detail="chunk_size must be a positive integer",
        )

    chunker = ChunkerFactory.create(strategy)

    # Only these strategies take a size override; the semantic one is allowed
    # larger chunks. int() keeps the size integral (chunk_size * 1.5 is a float).
    size_overrides = {
        "recursive": chunk_size,
        "semantic": int(chunk_size * 1.5),
        "fixed": chunk_size,
    }
    override = size_overrides.get(strategy)

    if override is None:
        preview = chunker.preview(text, max_chunks=10)
    else:
        # The override is meant to be temporary: restore the previous size so
        # a preview request cannot leak its chunk size into later chunking
        # that shares the same STRATEGY object.
        strategy_config = chunker.STRATEGY
        previous_size = strategy_config.chunk_size
        strategy_config.chunk_size = override
        try:
            preview = chunker.preview(text, max_chunks=10)
        finally:
            strategy_config.chunk_size = previous_size

    return {
        "strategy": strategy,
        "chunk_count": len(preview),
        "preview": preview,
        "strategies": [
            {
                "name": s.name,
                "description": s.description,
                "recommended_size": s.chunk_size,
            }
            for s in ChunkerFactory.list_strategies()
        ],
    }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 增量索引 API
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.post("/bases/{kb_id}/documents/{doc_id}/reindex")
async def reindex_document(
    kb_id: uuid.UUID,
    doc_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Re-index a single document of a knowledge base.

    Args:
        kb_id: Knowledge base id from the URL path.
        doc_id: Document id from the URL path.
        db: Database session.
        current_user: Authenticated user; must belong to an organization.

    Returns:
        Whatever ``IncrementalIndexService.add_document`` returns.

    Raises:
        HTTPException: 404 when the user has no organization.
    """
    organization_id = current_user.organization_id
    if not organization_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    # Looks the knowledge base up for this organization; presumably raises
    # when it is not accessible (helper defined elsewhere in this file).
    await _get_kb(db, kb_id, organization_id)

    indexer = IncrementalIndexService(_rag_service)
    return await indexer.add_document(db, str(kb_id), str(doc_id))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/bases/{kb_id}/documents/{doc_id}/update")
async def update_document_content(
    kb_id: uuid.UUID,
    doc_id: uuid.UUID,
    request: UpdateDocumentRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Update a document's content through the incremental index service.

    Args:
        kb_id: Knowledge base id from the URL path.
        doc_id: Document id from the URL path.
        request: Body carrying the new ``content``.
        db: Database session.
        current_user: Authenticated user; must belong to an organization.

    Returns:
        Whatever ``IncrementalIndexService.update_document`` returns.

    Raises:
        HTTPException: 404 when the user has no organization.
    """
    organization_id = current_user.organization_id
    if not organization_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    await _get_kb(db, kb_id, organization_id)

    # NOTE(review): only doc_id is forwarded; nothing visible here checks that
    # the document belongs to kb_id — verify inside the index service.
    indexer = IncrementalIndexService(_rag_service)
    return await indexer.update_document(db, str(doc_id), request.content)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/bases/{kb_id}/documents/{doc_id}")
async def delete_document_incremental(
    kb_id: uuid.UUID,
    doc_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a document through the incremental index service.

    Args:
        kb_id: Knowledge base id from the URL path.
        doc_id: Document id from the URL path.
        db: Database session.
        current_user: Authenticated user; must belong to an organization.

    Returns:
        Whatever ``IncrementalIndexService.delete_document`` returns.

    Raises:
        HTTPException: 404 when the user has no organization.
    """
    organization_id = current_user.organization_id
    if not organization_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    await _get_kb(db, kb_id, organization_id)

    # NOTE(review): only doc_id is forwarded; nothing visible here checks that
    # the document belongs to kb_id — verify inside the index service.
    indexer = IncrementalIndexService(_rag_service)
    return await indexer.delete_document(db, str(doc_id))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/bases/{kb_id}/rebuild")
async def rebuild_knowledge_base(
    kb_id: uuid.UUID,
    force: bool = False,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Rebuild the index of a knowledge base.

    Args:
        kb_id: Knowledge base id from the URL path.
        force: Passed through unchanged to the index service.
        db: Database session.
        current_user: Authenticated user; must belong to an organization.

    Returns:
        Whatever ``IncrementalIndexService.rebuild_knowledge_base`` returns.

    Raises:
        HTTPException: 404 when the user has no organization.
    """
    organization_id = current_user.organization_id
    if not organization_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    await _get_kb(db, kb_id, organization_id)

    indexer = IncrementalIndexService(_rag_service)
    return await indexer.rebuild_knowledge_base(db, str(kb_id), force)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/bases/{kb_id}/retrieve")
async def enhanced_retrieve(
    kb_id: uuid.UUID,
    request: RetrieveRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Run an enhanced retrieval (optional re-ranking and compression).

    Args:
        kb_id: Knowledge base id from the URL path; the only base searched.
        request: Body with ``query``, ``top_k``, ``use_rerank`` and
            ``use_compression``.
        db: Database session.
        current_user: Authenticated user; must belong to an organization.

    Returns:
        dict: ``{"results": ..., "query": ...}`` echoing the query.

    Raises:
        HTTPException: 404 when the user has no organization.
    """
    organization_id = current_user.organization_id
    if not organization_id:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")

    await _get_kb(db, kb_id, organization_id)

    retriever = EnhancedRAG(_rag_service, _rag_service.embedder)
    # A falsy top_k (None or 0) falls back to 5.
    effective_top_k = request.top_k or 5
    hits = await retriever.retrieve_with_rerank(
        db,
        request.query,
        [str(kb_id)],
        top_k=effective_top_k,
        use_rerank=request.use_rerank,
        use_compression=request.use_compression,
    )
    return {"results": hits, "query": request.query}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,115 @@
|
||||||
|
"""知识图谱API"""
|
||||||
|
from typing import Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_db, get_current_user
|
||||||
|
from app.models.user import User
|
||||||
|
from app.services.knowledge.graph_builder import GraphBuilder
|
||||||
|
from app.services.knowledge.graph_query import GraphQuery
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/knowledge-bases", tags=["知识图谱"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{kb_id}/graph/build")
async def build_graph(
    kb_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Build the knowledge graph for a whole knowledge base (placeholder).

    The intended behavior is to run entity and relation extraction over every
    chunk of the knowledge base. That is not implemented yet, so the endpoint
    only returns a hint pointing at the per-chunk endpoint.

    Args:
        kb_id: Knowledge base id from the URL path (currently unused).
        db: Database session (currently unused).
        current_user: Authenticated user (currently unused beyond auth).

    Returns:
        dict: A single ``message`` entry.
    """
    # TODO: implement batch building; only single-chunk building exists today.
    hint = "Use /graph/build-chunk to build from specific chunk"
    return {"message": hint}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{kb_id}/graph/build-chunk/{chunk_id}")
async def build_graph_from_chunk(
    kb_id: UUID,
    chunk_id: UUID,
    context: Optional[str] = None,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Build graph data from a single chunk.

    Args:
        kb_id: Knowledge base id from the URL path (not used by the body).
        chunk_id: Id of the chunk to process.
        context: Optional extra context forwarded to the builder.
        db: Database session.
        current_user: Authenticated user.

    Returns:
        dict: ``{"status": "success", "stats": ...}`` with the builder's stats.

    Raises:
        HTTPException: 404 carrying the message of a ``ValueError`` raised by
            ``GraphBuilder.build_from_chunk``.
    """
    # NOTE(review): kb_id is never checked against the chunk or the user's
    # organization here — confirm access control happens elsewhere.
    graph_builder = GraphBuilder()

    try:
        build_stats = await graph_builder.build_from_chunk(db, str(chunk_id), context)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    else:
        return {
            "status": "success",
            "stats": build_stats,
        }
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/graph/statistics")
async def get_graph_statistics(
    kb_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return graph statistics for a knowledge base.

    Args:
        kb_id: Knowledge base id from the URL path.
        db: Database session.
        current_user: Authenticated user.

    Returns:
        Whatever ``GraphQuery.get_statistics`` returns for this knowledge base.
    """
    return await GraphQuery().get_statistics(db, str(kb_id))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/graph/entities/search")
async def search_entities(
    kb_id: UUID,
    q: str,
    entity_type: Optional[str] = None,
    limit: int = 20,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Search entities of a knowledge base's graph.

    Args:
        kb_id: Knowledge base id from the URL path.
        q: Search text.
        entity_type: Optional entity-type filter, forwarded as-is.
        limit: Maximum number of results requested (forwarded unchanged).
        db: Database session.
        current_user: Authenticated user.

    Returns:
        dict: ``{"entities": ...}`` with the query's result.
    """
    graph_query = GraphQuery()
    matches = await graph_query.search_entities(db, str(kb_id), q, entity_type, limit)
    return {"entities": matches}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/graph/entities/{entity_id}")
async def get_entity(
    kb_id: UUID,
    entity_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one entity together with its neighbors.

    Args:
        kb_id: Knowledge base id from the URL path (not used by the body).
        entity_id: Id of the entity to fetch.
        db: Database session.
        current_user: Authenticated user.

    Returns:
        The entity returned by ``GraphQuery.get_entity`` with an added
        ``"neighbors"`` entry.

    Raises:
        HTTPException: 404 when the lookup result is falsy.
    """
    # NOTE(review): the entity is looked up by id only; kb_id is not used to
    # scope the lookup — confirm access control happens elsewhere.
    graph_query = GraphQuery()
    entity_key = str(entity_id)

    found = await graph_query.get_entity(db, entity_key)
    if not found:
        raise HTTPException(status_code=404, detail="Entity not found")

    # Attach the neighbors to the entity payload before returning it.
    found["neighbors"] = await graph_query.get_entity_neighbors(db, entity_key)
    return found
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{kb_id}/graph/path")
async def find_path(
    kb_id: UUID,
    source: str,
    target: str,
    max_hops: int = 3,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Find a path between two entities.

    Args:
        kb_id: Knowledge base id from the URL path (not used by the body).
        source: Source entity, forwarded as-is to the query.
        target: Target entity, forwarded as-is to the query.
        max_hops: Upper bound forwarded to the path search.
        db: Database session.
        current_user: Authenticated user.

    Returns:
        dict: ``{"path": ..., "hops": len(path)}``.
    """
    # NOTE(review): len(path) assumes get_entity_path always returns a sized
    # value (never None) — confirm; also confirm whether the length counts
    # hops or nodes.
    route = await GraphQuery().get_entity_path(db, source, target, max_hops)
    hop_count = len(route)
    return {"path": route, "hops": hop_count}
|
||||||
|
|
@ -1,9 +1,23 @@
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
||||||
from sqlalchemy.orm import declarative_base
|
from sqlalchemy.orm import declarative_base
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text, JSON
|
||||||
|
from sqlalchemy.types import TypeDecorator
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class JSONType(TypeDecorator):
    """Portable JSON column type: JSONB on PostgreSQL, plain JSON elsewhere."""

    # Generic fallback implementation; the dialect-specific one is chosen in
    # load_dialect_impl.
    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Return the dialect-level type for this column.

        Args:
            dialect: The SQLAlchemy dialect in use.

        Returns:
            The dialect's descriptor for ``JSONB`` when the dialect name is
            ``"postgresql"``, otherwise its descriptor for ``JSON``.
        """
        concrete_type = JSONB() if dialect.name == "postgresql" else JSON()
        return dialect.type_descriptor(concrete_type)
|
||||||
|
|
||||||
|
|
||||||
engine = create_async_engine(
|
engine = create_async_engine(
|
||||||
settings.DATABASE_URL,
|
settings.DATABASE_URL,
|
||||||
pool_size=10, # 连接池大小
|
pool_size=10, # 连接池大小
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,8 @@ from datetime import datetime, timezone
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request, Depends
|
from fastapi import FastAPI, HTTPException, Request, Depends
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse, Response
|
||||||
|
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
|
@ -31,9 +32,12 @@ from app.api.subscriptions import router as subscription_router
|
||||||
from app.api.alerts import router as alerts_router
|
from app.api.alerts import router as alerts_router
|
||||||
from app.api.dashboard import router as dashboard_router
|
from app.api.dashboard import router as dashboard_router
|
||||||
from app.api.brands import router as brands_router
|
from app.api.brands import router as brands_router
|
||||||
|
from app.api.diagnosis import router as diagnosis_router
|
||||||
from app.api.onboarding import router as onboarding_router
|
from app.api.onboarding import router as onboarding_router
|
||||||
from app.api.platforms import router as platforms_router
|
from app.api.platforms import router as platforms_router
|
||||||
from app.api.platform_rules import router as platform_rules_router
|
from app.api.platform_rules import router as platform_rules_router
|
||||||
|
from app.api.image import router as image_router
|
||||||
|
from app.api.knowledge_graph import router as knowledge_graph_router
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.database import engine, Base
|
from app.database import engine, Base
|
||||||
from app.schemas.common import ErrorResponse, ErrorCode
|
from app.schemas.common import ErrorResponse, ErrorCode
|
||||||
|
|
@ -41,6 +45,7 @@ from app.middleware.rate_limit import RateLimitMiddleware
|
||||||
from app.middleware.logging_middleware import RequestLoggingMiddleware
|
from app.middleware.logging_middleware import RequestLoggingMiddleware
|
||||||
from app.middleware.request_id import RequestIdMiddleware
|
from app.middleware.request_id import RequestIdMiddleware
|
||||||
from app.middleware.metrics import MetricsMiddleware
|
from app.middleware.metrics import MetricsMiddleware
|
||||||
|
from app.monitoring.middleware import MonitoringMiddleware
|
||||||
from app.database import get_db
|
from app.database import get_db
|
||||||
from app.workers.scheduler import query_scheduler
|
from app.workers.scheduler import query_scheduler
|
||||||
|
|
||||||
|
|
@ -49,6 +54,9 @@ from app.workers.scheduler import query_scheduler
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
import app.models
|
import app.models
|
||||||
|
|
||||||
|
# 初始化监控模块
|
||||||
|
import app.monitoring
|
||||||
|
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
|
@ -131,6 +139,7 @@ async def add_security_headers(request, call_next):
|
||||||
app.add_middleware(RequestLoggingMiddleware)
|
app.add_middleware(RequestLoggingMiddleware)
|
||||||
app.add_middleware(RateLimitMiddleware)
|
app.add_middleware(RateLimitMiddleware)
|
||||||
app.add_middleware(MetricsMiddleware)
|
app.add_middleware(MetricsMiddleware)
|
||||||
|
app.add_middleware(MonitoringMiddleware)
|
||||||
app.add_middleware(RequestIdMiddleware)
|
app.add_middleware(RequestIdMiddleware)
|
||||||
|
|
||||||
app.include_router(auth_router, prefix="/api/v1/auth", tags=["认证"])
|
app.include_router(auth_router, prefix="/api/v1/auth", tags=["认证"])
|
||||||
|
|
@ -150,9 +159,12 @@ app.include_router(analytics_router, prefix="/api/v1/analytics", tags=["监测
|
||||||
app.include_router(alerts_router, prefix="/api/v1/alerts", tags=["告警通知"])
|
app.include_router(alerts_router, prefix="/api/v1/alerts", tags=["告警通知"])
|
||||||
app.include_router(dashboard_router, prefix="/api/v1/dashboard", tags=["仪表盘"])
|
app.include_router(dashboard_router, prefix="/api/v1/dashboard", tags=["仪表盘"])
|
||||||
app.include_router(brands_router, prefix="/api/v1/brands", tags=["品牌管理"])
|
app.include_router(brands_router, prefix="/api/v1/brands", tags=["品牌管理"])
|
||||||
|
app.include_router(diagnosis_router, prefix="/api/v1/diagnosis", tags=["诊断服务"])
|
||||||
app.include_router(onboarding_router, prefix="/api/v1")
|
app.include_router(onboarding_router, prefix="/api/v1")
|
||||||
app.include_router(platforms_router, prefix="/api/v1")
|
app.include_router(platforms_router, prefix="/api/v1")
|
||||||
app.include_router(platform_rules_router)
|
app.include_router(platform_rules_router)
|
||||||
|
app.include_router(image_router, prefix="/api/v1")
|
||||||
|
# The knowledge-graph router already declares prefix="/knowledge-bases" on its
# APIRouter, so it is mounted under "/api/v1" only; repeating the segment here
# would expose the routes at /api/v1/knowledge-bases/knowledge-bases/...
app.include_router(knowledge_graph_router, prefix="/api/v1")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health", tags=["可观测性"])
|
@app.get("/health", tags=["可观测性"])
|
||||||
|
|
@ -203,3 +215,90 @@ async def readiness_check(db: AsyncSession = Depends(get_db)):
|
||||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/metrics", tags=["可观测性"])
async def metrics():
    """Prometheus metrics endpoint.

    Returns:
        Response: The output of ``generate_latest()`` served with the
        Prometheus content type.
    """
    exposition = generate_latest()
    return Response(content=exposition, media_type=CONTENT_TYPE_LATEST)
|
||||||
|
|
||||||
|
|
||||||
|
# ---- 详细健康检查端点 ----
|
||||||
|
from app.services.health_checker import HealthChecker
|
||||||
|
from app.services.app_state import app_state
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health/detailed", tags=["可观测性"])
async def detailed_health(
    db: AsyncSession = Depends(get_db),
):
    """Detailed health check.

    Returns the health report produced by ``HealthChecker.check_all`` with an
    extra ``"app"`` entry holding ``app_state.get_info()``.

    As described by the original author, the report covers these components:
    - database: database connection
    - redis: Redis cache
    - llm_providers: LLM service providers
    - storage: file storage

    and uses these statuses:
    - healthy: every component is fine
    - degraded: some components fail but the service can still respond
    - unhealthy: a core component fails
    """
    report = await HealthChecker(db, settings.REDIS_URL).check_all()

    # Add application-level information next to the dependency checks.
    report["app"] = app_state.get_info()
    return report
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health/liveness", tags=["可观测性"])
async def liveness():
    """Liveness probe (intended for a Kubernetes livenessProbe).

    Performs no dependency checks: it answers as long as the process can
    serve the request.

    Returns:
        dict: ``{"status": "alive"}``.
    """
    return dict(status="alive")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health/readiness", tags=["可观测性"])
async def readiness(db: AsyncSession = Depends(get_db)):
    """Readiness probe (intended for a Kubernetes readinessProbe).

    Only the core dependencies are checked: the database and Redis.

    Args:
        db: Database session handed to the health checker.

    Returns:
        dict: ``{"status": "ready", "checks": {...}}`` when both checks are
        healthy.

    Raises:
        HTTPException: 503 with ``{"status": "not_ready", "checks": {...}}``
            as detail when either check is unhealthy.
    """
    health_checker = HealthChecker(db, settings.REDIS_URL)

    database_check = await health_checker.check_database()
    redis_check = await health_checker.check_redis()

    check_summary = {
        "database": database_check.healthy,
        "redis": redis_check.healthy,
    }

    # Guard clause: report 503 as soon as one core dependency is down.
    if not (database_check.healthy and redis_check.healthy):
        raise HTTPException(
            status_code=503,
            detail={"status": "not_ready", "checks": check_summary},
        )

    return {"status": "ready", "checks": check_summary}
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,12 @@ from app.models.knowledge import (
|
||||||
KnowledgeChunk,
|
KnowledgeChunk,
|
||||||
KnowledgeSearchLog,
|
KnowledgeSearchLog,
|
||||||
)
|
)
|
||||||
|
from app.models.knowledge_graph import (
|
||||||
|
KnowledgeEntity,
|
||||||
|
KnowledgeRelation,
|
||||||
|
EntityType,
|
||||||
|
RelationType,
|
||||||
|
)
|
||||||
from app.models.analytics import PublishRecord, ContentMetrics, OptimizationInsight
|
from app.models.analytics import PublishRecord, ContentMetrics, OptimizationInsight
|
||||||
from app.models.distribution import DistributionSchedule
|
from app.models.distribution import DistributionSchedule
|
||||||
# 缺失的模型导入 - 重构后遗留
|
# 缺失的模型导入 - 重构后遗留
|
||||||
|
|
@ -48,6 +54,10 @@ __all__ = [
|
||||||
"KnowledgeDocument",
|
"KnowledgeDocument",
|
||||||
"KnowledgeChunk",
|
"KnowledgeChunk",
|
||||||
"KnowledgeSearchLog",
|
"KnowledgeSearchLog",
|
||||||
|
"KnowledgeEntity",
|
||||||
|
"KnowledgeRelation",
|
||||||
|
"EntityType",
|
||||||
|
"RelationType",
|
||||||
"PublishRecord",
|
"PublishRecord",
|
||||||
"ContentMetrics",
|
"ContentMetrics",
|
||||||
"OptimizationInsight",
|
"OptimizationInsight",
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,9 @@ from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import String, Integer, ForeignKey, Index, func, Text
|
from sqlalchemy import String, Integer, ForeignKey, Index, func, Text
|
||||||
from sqlalchemy import Uuid
|
from sqlalchemy import Uuid
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from app.database import Base
|
from app.database import Base, JSONType
|
||||||
|
|
||||||
|
|
||||||
class AgentRegistry(Base):
|
class AgentRegistry(Base):
|
||||||
|
|
@ -24,7 +23,7 @@ class AgentRegistry(Base):
|
||||||
version: Mapped[str | None] = mapped_column(String(20), nullable=True)
|
version: Mapped[str | None] = mapped_column(String(20), nullable=True)
|
||||||
endpoint: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
endpoint: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
status: Mapped[str] = mapped_column(String(20), server_default="offline", nullable=False)
|
status: Mapped[str] = mapped_column(String(20), server_default="offline", nullable=False)
|
||||||
capabilities: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
capabilities: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||||
last_heartbeat: Mapped[datetime | None] = mapped_column(nullable=True)
|
last_heartbeat: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
server_default=func.now(),
|
server_default=func.now(),
|
||||||
|
|
@ -68,7 +67,7 @@ class AgentConfig(Base):
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
config_key: Mapped[str] = mapped_column(String(100), nullable=False)
|
config_key: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
config_value: Mapped[dict] = mapped_column(JSONB, nullable=False)
|
config_value: Mapped[dict] = mapped_column(JSONType, nullable=False)
|
||||||
description: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
description: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
server_default=func.now(),
|
server_default=func.now(),
|
||||||
|
|
@ -111,8 +110,8 @@ class AgentTask(Base):
|
||||||
task_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
task_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False)
|
status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False)
|
||||||
priority: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
|
priority: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
|
||||||
input_data: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
input_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||||
output_data: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
output_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||||
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
created_by: Mapped[uuid.UUID | None] = mapped_column(
|
created_by: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
Uuid(as_uuid=True),
|
Uuid(as_uuid=True),
|
||||||
|
|
@ -184,7 +183,7 @@ class AgentTaskLog(Base):
|
||||||
)
|
)
|
||||||
log_level: Mapped[str] = mapped_column(String(10), nullable=False)
|
log_level: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||||
message: Mapped[str] = mapped_column(Text, nullable=False)
|
message: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True)
|
extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONType, nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
server_default=func.now(),
|
server_default=func.now(),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,9 @@ from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import String, Integer, Boolean, ForeignKey, Index, func, Text
|
from sqlalchemy import String, Integer, Boolean, ForeignKey, Index, func, Text
|
||||||
from sqlalchemy import Uuid
|
from sqlalchemy import Uuid
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from app.database import Base
|
from app.database import Base, JSONType
|
||||||
|
|
||||||
|
|
||||||
class BrandKnowledge(Base):
|
class BrandKnowledge(Base):
|
||||||
|
|
@ -25,7 +24,7 @@ class BrandKnowledge(Base):
|
||||||
category: Mapped[str] = mapped_column(String(50), nullable=False)
|
category: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
title: Mapped[str] = mapped_column(String(200), nullable=False)
|
title: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||||
content: Mapped[str] = mapped_column(Text, nullable=False)
|
content: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True)
|
extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONType, nullable=True)
|
||||||
is_active: Mapped[bool] = mapped_column(Boolean, server_default="true", nullable=False)
|
is_active: Mapped[bool] = mapped_column(Boolean, server_default="true", nullable=False)
|
||||||
created_by: Mapped[uuid.UUID | None] = mapped_column(
|
created_by: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
Uuid(as_uuid=True),
|
Uuid(as_uuid=True),
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,9 @@ from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import String, Integer, ForeignKey, Index, func, Text
|
from sqlalchemy import String, Integer, ForeignKey, Index, func, Text
|
||||||
from sqlalchemy import Uuid
|
from sqlalchemy import Uuid
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from app.database import Base
|
from app.database import Base, JSONType
|
||||||
|
|
||||||
|
|
||||||
class Content(Base):
|
class Content(Base):
|
||||||
|
|
@ -31,9 +30,9 @@ class Content(Base):
|
||||||
content_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
content_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
body: Mapped[str | None] = mapped_column(Text, nullable=True)
|
body: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
status: Mapped[str] = mapped_column(String(20), server_default="draft", nullable=False)
|
status: Mapped[str] = mapped_column(String(20), server_default="draft", nullable=False)
|
||||||
target_platforms: Mapped[list | None] = mapped_column(JSONB, nullable=True)
|
target_platforms: Mapped[list | None] = mapped_column(JSONType, nullable=True)
|
||||||
keywords: Mapped[list | None] = mapped_column(JSONB, nullable=True)
|
keywords: Mapped[list | None] = mapped_column(JSONType, nullable=True)
|
||||||
extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True)
|
extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONType, nullable=True)
|
||||||
created_by: Mapped[uuid.UUID | None] = mapped_column(
|
created_by: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
Uuid(as_uuid=True),
|
Uuid(as_uuid=True),
|
||||||
ForeignKey("users.id", ondelete="SET NULL"),
|
ForeignKey("users.id", ondelete="SET NULL"),
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,9 @@ from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import String, Integer, ForeignKey, Index, func, Text
|
from sqlalchemy import String, Integer, ForeignKey, Index, func, Text
|
||||||
from sqlalchemy import Uuid
|
from sqlalchemy import Uuid
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from app.database import Base
|
from app.database import Base, JSONType
|
||||||
|
|
||||||
|
|
||||||
class DistributionSchedule(Base):
|
class DistributionSchedule(Base):
|
||||||
|
|
@ -29,9 +28,9 @@ class DistributionSchedule(Base):
|
||||||
ForeignKey("contents.id", ondelete="SET NULL"),
|
ForeignKey("contents.id", ondelete="SET NULL"),
|
||||||
nullable=True,
|
nullable=True,
|
||||||
)
|
)
|
||||||
platforms: Mapped[list | None] = mapped_column(JSONB, nullable=True)
|
platforms: Mapped[list | None] = mapped_column(JSONType, nullable=True)
|
||||||
"""[{platform, platform_name, scheduled_time, status}]"""
|
"""[{platform, platform_name, scheduled_time, status}]"""
|
||||||
tips: Mapped[list | None] = mapped_column(JSONB, nullable=True)
|
tips: Mapped[list | None] = mapped_column(JSONType, nullable=True)
|
||||||
status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False)
|
status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False)
|
||||||
created_by: Mapped[uuid.UUID | None] = mapped_column(
|
created_by: Mapped[uuid.UUID | None] = mapped_column(
|
||||||
Uuid(as_uuid=True),
|
Uuid(as_uuid=True),
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,9 @@ from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import String, Integer, ForeignKey, Index, func, Text
|
from sqlalchemy import String, Integer, ForeignKey, Index, func, Text
|
||||||
from sqlalchemy import Uuid
|
from sqlalchemy import Uuid
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from app.database import Base
|
from app.database import Base, JSONType
|
||||||
|
|
||||||
# pgvector Vector type - imported conditionally
|
# pgvector Vector type - imported conditionally
|
||||||
try:
|
try:
|
||||||
|
|
@ -92,7 +91,7 @@ class KnowledgeDocument(Base):
|
||||||
status: Mapped[str] = mapped_column(String(20), server_default="processing", nullable=False) # "processing" / "ready" / "failed"
|
status: Mapped[str] = mapped_column(String(20), server_default="processing", nullable=False) # "processing" / "ready" / "failed"
|
||||||
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
# mapped_column("metadata") to avoid SQLAlchemy reserved keyword conflict
|
# mapped_column("metadata") to avoid SQLAlchemy reserved keyword conflict
|
||||||
extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True)
|
extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONType, nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
server_default=func.now(),
|
server_default=func.now(),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
|
|
@ -152,7 +151,7 @@ class KnowledgeChunk(Base):
|
||||||
chunk_index: Mapped[int] = mapped_column(Integer, nullable=False)
|
chunk_index: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
token_count: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
|
token_count: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
|
||||||
# mapped_column("metadata") to avoid SQLAlchemy reserved keyword conflict
|
# mapped_column("metadata") to avoid SQLAlchemy reserved keyword conflict
|
||||||
extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True)
|
extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONType, nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
server_default=func.now(),
|
server_default=func.now(),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
|
|
@ -189,7 +188,7 @@ class KnowledgeSearchLog(Base):
|
||||||
nullable=True,
|
nullable=True,
|
||||||
)
|
)
|
||||||
query: Mapped[str] = mapped_column(Text, nullable=False)
|
query: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
knowledge_base_ids: Mapped[list | None] = mapped_column(JSONB, nullable=True)
|
knowledge_base_ids: Mapped[list | None] = mapped_column(JSONType, nullable=True)
|
||||||
results_count: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
|
results_count: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
|
||||||
latency_ms: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
|
latency_ms: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,175 @@
|
||||||
|
"""知识图谱数据模型"""
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
import enum
|
||||||
|
|
||||||
|
from sqlalchemy import String, Text, DateTime, ForeignKey, Index, Enum, func
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.database import Base, JSONType
|
||||||
|
|
||||||
|
|
||||||
|
class EntityType(str, enum.Enum):
    """Closed set of categories a knowledge-graph entity can belong to.

    Members are ``str`` instances and every value equals its member name.
    """

    ORGANIZATION = "ORGANIZATION"  # organization / company
    PRODUCT = "PRODUCT"            # product
    PERSON = "PERSON"              # person
    LOCATION = "LOCATION"          # place
    TECHNOLOGY = "TECHNOLOGY"      # technology
    BRAND = "BRAND"                # brand
    EVENT = "EVENT"                # event
    CONCEPT = "CONCEPT"            # concept
    OTHER = "OTHER"                # anything else
|
||||||
|
|
||||||
|
|
||||||
|
class RelationType(str, enum.Enum):
    """Closed set of edge kinds that can link two knowledge-graph entities.

    Members are ``str`` instances and every value equals its member name.
    """

    # --- Organization-to-organization ---
    COMPETES_WITH = "COMPETES_WITH"      # competitor
    PARTNERS_WITH = "PARTNERS_WITH"      # partner
    ACQUIRES = "ACQUIRES"                # acquisition
    SUBSIDIARY_OF = "SUBSIDIARY_OF"      # subsidiary

    # --- Product ---
    PRODUCES = "PRODUCES"                # produces
    USES_TECHNOLOGY = "USES_TECHNOLOGY"  # uses a technology
    PART_OF = "PART_OF"                  # belongs to (product line)

    # --- Location ---
    LOCATED_IN = "LOCATED_IN"            # located in
    FOUNDED_IN = "FOUNDED_IN"            # founded in

    # --- Person ---
    CEO_OF = "CEO_OF"                    # CEO
    FOUNDER_OF = "FOUNDER_OF"            # founder

    # --- Generic ---
    RELATED_TO = "RELATED_TO"            # related
    MENTIONED_IN = "MENTIONED_IN"        # mentioned in
    ALSO_KNOWN_AS = "ALSO_KNOWN_AS"      # alias
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeEntity(Base):
    """ORM model for one entity (node) of a knowledge base's graph.

    Rows live in ``knowledge_entities``. The foreign keys declare that
    deleting the owning knowledge base cascades to the entity, while deleting
    the source chunk only sets ``source_chunk_id`` to NULL.
    """

    __tablename__ = "knowledge_entities"

    # Composite indexes for lookups by name or by type within a knowledge base.
    __table_args__ = (
        Index("ix_entities_kb_name", "knowledge_base_id", "name"),
        Index("ix_entities_kb_type", "knowledge_base_id", "entity_type"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    knowledge_base_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("knowledge_bases.id", ondelete="CASCADE"),
        nullable=False,
    )

    # Entity information
    name: Mapped[str] = mapped_column(String(500), nullable=False, index=True)
    entity_type: Mapped[EntityType] = mapped_column(
        Enum(EntityType), nullable=False, index=True
    )
    description: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Free-form extension attributes (JSON); a fresh dict is the Python-side default.
    properties: Mapped[dict | None] = mapped_column(JSONType, default=dict)

    # Provenance
    source_chunk_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("knowledge_chunks.id", ondelete="SET NULL"),
        nullable=True,
    )
    # Confidence as a label: high / medium / low
    confidence: Mapped[str | None] = mapped_column(String(20), nullable=True)

    # Bookkeeping timestamps
    created_at: Mapped[datetime] = mapped_column(
        server_default=func.now(), nullable=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        server_default=func.now(), onupdate=func.now(), nullable=False
    )

    # Relations where this entity is the source; removed together with the entity.
    outgoing_relations: Mapped[list["KnowledgeRelation"]] = relationship(
        "KnowledgeRelation",
        foreign_keys="KnowledgeRelation.source_entity_id",
        back_populates="source_entity",
        cascade="all, delete-orphan",
    )
    # Relations where this entity is the target; removed together with the entity.
    incoming_relations: Mapped[list["KnowledgeRelation"]] = relationship(
        "KnowledgeRelation",
        foreign_keys="KnowledgeRelation.target_entity_id",
        back_populates="target_entity",
        cascade="all, delete-orphan",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeRelation(Base):
    """ORM model for one directed relation (edge) between two entities.

    Rows live in ``knowledge_relations``. Both endpoint foreign keys cascade
    on delete; deleting the source chunk only sets ``source_chunk_id`` to NULL.
    """

    __tablename__ = "knowledge_relations"

    # Explicit single-column indexes on both endpoints and on the relation type.
    __table_args__ = (
        Index("ix_relations_source", "source_entity_id"),
        Index("ix_relations_target", "target_entity_id"),
        Index("ix_relations_type", "relation_type"),
    )

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )

    # The two endpoints of the relation
    source_entity_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("knowledge_entities.id", ondelete="CASCADE"),
        nullable=False,
    )
    target_entity_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("knowledge_entities.id", ondelete="CASCADE"),
        nullable=False,
    )

    # Relation information
    relation_type: Mapped[RelationType] = mapped_column(
        Enum(RelationType), nullable=False, index=True
    )

    # Free-form extension attributes (JSON); a fresh dict is the Python-side default.
    properties: Mapped[dict | None] = mapped_column(JSONType, default=dict)

    # Provenance
    source_chunk_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("knowledge_chunks.id", ondelete="SET NULL"),
        nullable=True,
    )
    confidence: Mapped[str | None] = mapped_column(String(20), nullable=True)

    # Bookkeeping timestamp
    created_at: Mapped[datetime] = mapped_column(
        server_default=func.now(), nullable=False
    )

    # Endpoint objects, mirrored by KnowledgeEntity.outgoing/incoming_relations.
    source_entity: Mapped["KnowledgeEntity"] = relationship(
        "KnowledgeEntity",
        foreign_keys=[source_entity_id],
        back_populates="outgoing_relations",
    )
    target_entity: Mapped["KnowledgeEntity"] = relationship(
        "KnowledgeEntity",
        foreign_keys=[target_entity_id],
        back_populates="incoming_relations",
    )
|
||||||
|
|
@ -3,10 +3,9 @@ from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import String, Integer, ForeignKey, Index, func, Text
|
from sqlalchemy import String, Integer, ForeignKey, Index, func, Text
|
||||||
from sqlalchemy import Uuid
|
from sqlalchemy import Uuid
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from app.database import Base
|
from app.database import Base, JSONType
|
||||||
|
|
||||||
|
|
||||||
class LifecycleProject(Base):
|
class LifecycleProject(Base):
|
||||||
|
|
@ -23,7 +22,7 @@ class LifecycleProject(Base):
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
brand_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
brand_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
brand_aliases: Mapped[list] = mapped_column(JSONB, server_default="[]", nullable=False)
|
brand_aliases: Mapped[list] = mapped_column(JSONType, server_default="[]", nullable=False)
|
||||||
current_stage: Mapped[int] = mapped_column(Integer, server_default="1", nullable=False)
|
current_stage: Mapped[int] = mapped_column(Integer, server_default="1", nullable=False)
|
||||||
status: Mapped[str] = mapped_column(String(20), server_default="active", nullable=False)
|
status: Mapped[str] = mapped_column(String(20), server_default="active", nullable=False)
|
||||||
created_by: Mapped[uuid.UUID] = mapped_column(
|
created_by: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
|
@ -77,7 +76,7 @@ class ProjectStage(Base):
|
||||||
started_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
started_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||||
completed_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
completed_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||||
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
metrics: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
metrics: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
project: Mapped["LifecycleProject"] = relationship(
|
project: Mapped["LifecycleProject"] = relationship(
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,9 @@ from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import String, Boolean, ForeignKey, Index, func, Text
|
from sqlalchemy import String, Boolean, ForeignKey, Index, func, Text
|
||||||
from sqlalchemy import Uuid
|
from sqlalchemy import Uuid
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from app.database import Base
|
from app.database import Base, JSONType
|
||||||
|
|
||||||
|
|
||||||
class PlatformRule(Base):
|
class PlatformRule(Base):
|
||||||
|
|
@ -21,7 +20,7 @@ class PlatformRule(Base):
|
||||||
rule_category: Mapped[str] = mapped_column(String(50), nullable=False)
|
rule_category: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
rule_name: Mapped[str] = mapped_column(String(200), nullable=False)
|
rule_name: Mapped[str] = mapped_column(String(200), nullable=False)
|
||||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
check_criteria: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
|
check_criteria: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||||
severity: Mapped[str] = mapped_column(String(20), nullable=False)
|
severity: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||||
is_active: Mapped[bool] = mapped_column(Boolean, server_default="true", nullable=False)
|
is_active: Mapped[bool] = mapped_column(Boolean, server_default="true", nullable=False)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
"""监控模块"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
from app.monitoring.metrics import *
|
||||||
|
from app.monitoring.middleware import MonitoringMiddleware
|
||||||
|
from app.monitoring.agent_hooks import agent_execution_context, record_agent_execution
|
||||||
|
from app.monitoring.llm_metrics import get_llm_metrics, LLMMetricsWrapper
|
||||||
|
|
||||||
|
# Publish static service metadata on the Info metric at import time.
# The version is read from APP_VERSION with a "1.0.0" fallback rather than
# being hard-coded, so the metric cannot drift from the deployed version.
SERVICE_INFO.info({
    "version": os.getenv("APP_VERSION", "1.0.0"),
    "environment": os.getenv("ENVIRONMENT", "development"),
})
|
||||||
|
|
@ -0,0 +1,57 @@
|
||||||
|
"""Agent执行指标钩子"""
|
||||||
|
import time
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.monitoring.metrics import (
|
||||||
|
AGENT_EXECUTIONS_TOTAL,
|
||||||
|
AGENT_EXECUTION_DURATION_SECONDS,
|
||||||
|
AGENT_RUNNING_TASKS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def agent_execution_context(agent_name: str):
    """Record Prometheus metrics around a single agent execution.

    On entry the running-tasks gauge for ``agent_name`` is incremented. On
    exit the execution counter (labelled with the outcome) and the duration
    histogram are updated, and the gauge is decremented again.

    Args:
        agent_name: Value used for the ``agent_name`` label on every metric.

    Yields:
        Nothing; the managed body is the agent run being measured.

    Raises:
        Whatever the managed body raises; it is re-raised unchanged.
    """
    running_gauge = AGENT_RUNNING_TASKS.labels(agent_name=agent_name)
    running_gauge.inc()

    start_time = time.perf_counter()
    # Start pessimistic: only a body that runs to completion is a success.
    # This also classifies BaseException exits (e.g. asyncio.CancelledError,
    # which is not an Exception subclass) as failures instead of successes.
    status = "failure"

    try:
        yield
        status = "success"
    finally:
        duration = time.perf_counter() - start_time
        try:
            AGENT_EXECUTIONS_TOTAL.labels(
                agent_name=agent_name,
                status=status,
            ).inc()
            AGENT_EXECUTION_DURATION_SECONDS.labels(
                agent_name=agent_name,
            ).observe(duration)
        finally:
            # Always release the slot, even if recording a metric fails.
            running_gauge.dec()
|
||||||
|
|
||||||
|
|
||||||
|
def record_agent_execution(agent_name: str, status: str, duration: float):
    """Record one finished agent execution from values supplied by the caller.

    Args:
        agent_name: Value for the ``agent_name`` label.
        status: Value for the ``status`` label of the execution counter.
        duration: Execution time passed to the duration histogram.
    """
    outcome_counter = AGENT_EXECUTIONS_TOTAL.labels(agent_name=agent_name, status=status)
    outcome_counter.inc()

    duration_histogram = AGENT_EXECUTION_DURATION_SECONDS.labels(agent_name=agent_name)
    duration_histogram.observe(duration)
|
||||||
|
|
@ -0,0 +1,102 @@
|
||||||
|
"""LLM调用指标包装"""
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.monitoring.metrics import (
|
||||||
|
LLM_REQUESTS_TOTAL,
|
||||||
|
LLM_REQUEST_DURATION_SECONDS,
|
||||||
|
LLM_TOKENS_TOTAL,
|
||||||
|
LLM_COST_ESTIMATED,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Estimated LLM price table in USD per token, keyed by (provider, model).
LLM_COST_PER_TOKEN = {
    # OpenAI
    ("openai", "gpt-4o"): {"prompt": 5e-6, "completion": 1.5e-5},
    ("openai", "gpt-4o-mini"): {"prompt": 1.5e-7, "completion": 6e-7},
    ("openai", "gpt-4-turbo"): {"prompt": 1e-5, "completion": 3e-5},
    # DeepSeek
    ("deepseek", "deepseek-chat"): {"prompt": 1.4e-7, "completion": 2.8e-7},
    ("deepseek", "deepseek-coder"): {"prompt": 1.4e-7, "completion": 2.8e-7},
}


class LLMMetricsWrapper:
    """Records Prometheus metrics for calls to one (provider, model) pair."""

    def __init__(self, provider: str, model: str):
        self.provider = provider
        self.model = model

    def record_request(
        self,
        status: str,
        duration: float,
        prompt_tokens: Optional[int] = None,
        completion_tokens: Optional[int] = None,
    ):
        """Record one LLM request.

        Updates the request counter and duration histogram, adds each token
        count that is not ``None`` to the token counter, and adds the
        estimated cost to the cost metric when that estimate is positive.
        """
        pair_labels = {"provider": self.provider, "model": self.model}

        # Request count and latency
        LLM_REQUESTS_TOTAL.labels(status=status, **pair_labels).inc()
        LLM_REQUEST_DURATION_SECONDS.labels(**pair_labels).observe(duration)

        # Token usage, one counter update per token type that was reported
        for token_type, amount in (
            ("prompt", prompt_tokens),
            ("completion", completion_tokens),
        ):
            if amount is not None:
                LLM_TOKENS_TOTAL.labels(token_type=token_type, **pair_labels).inc(amount)

        # Estimated spend
        estimated = self._estimate_cost(prompt_tokens, completion_tokens)
        if estimated > 0:
            LLM_COST_ESTIMATED.labels(**pair_labels).inc(estimated)

    def _estimate_cost(
        self,
        prompt_tokens: Optional[int],
        completion_tokens: Optional[int],
    ) -> float:
        """Return the estimated USD cost of a request.

        Returns ``0.0`` when the (provider, model) pair has no entry in
        ``LLM_COST_PER_TOKEN``; missing or zero token counts contribute nothing.
        """
        rates = LLM_COST_PER_TOKEN.get((self.provider, self.model))
        if not rates:
            return 0.0

        prompt_cost = prompt_tokens * rates["prompt"] if prompt_tokens else 0.0
        completion_cost = (
            completion_tokens * rates["completion"] if completion_tokens else 0.0
        )
        return prompt_cost + completion_cost
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level registry of metric wrappers, keyed by "provider:model".
_llm_metrics_cache: dict[str, LLMMetricsWrapper] = {}


def get_llm_metrics(provider: str, model: str) -> LLMMetricsWrapper:
    """Return the cached wrapper for ``provider``/``model``, creating it on first use."""
    cache_key = f"{provider}:{model}"
    try:
        return _llm_metrics_cache[cache_key]
    except KeyError:
        wrapper = LLMMetricsWrapper(provider, model)
        _llm_metrics_cache[cache_key] = wrapper
        return wrapper
|
||||||
|
|
@ -0,0 +1,97 @@
|
||||||
|
"""Prometheus指标定义"""
|
||||||
|
from prometheus_client import Counter, Histogram, Gauge, Info
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# API layer
# ---------------------------------------------------------------------------
API_REQUESTS_TOTAL = Counter(
    name="geo_api_requests_total",
    documentation="Total API requests",
    labelnames=["method", "endpoint", "status"],
)

API_REQUEST_DURATION_SECONDS = Histogram(
    name="geo_api_request_duration_seconds",
    documentation="API request duration in seconds",
    labelnames=["method", "endpoint"],
    buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0),
)

API_REQUESTS_IN_PROGRESS = Gauge(
    name="geo_api_requests_in_progress",
    documentation="Number of requests currently being processed",
    labelnames=["method", "endpoint"],
)

# ---------------------------------------------------------------------------
# Agent layer
# ---------------------------------------------------------------------------
AGENT_EXECUTIONS_TOTAL = Counter(
    name="geo_agent_executions_total",
    documentation="Total agent executions",
    labelnames=["agent_name", "status"],
)

AGENT_EXECUTION_DURATION_SECONDS = Histogram(
    name="geo_agent_execution_duration_seconds",
    documentation="Agent execution duration in seconds",
    labelnames=["agent_name"],
    buckets=(0.1, 0.5, 1.0, 5.0, 10.0, 30.0, 60.0, 120.0),
)

AGENT_RUNNING_TASKS = Gauge(
    name="geo_agent_running_tasks",
    documentation="Number of tasks currently running",
    labelnames=["agent_name"],
)

# ---------------------------------------------------------------------------
# LLM layer
# ---------------------------------------------------------------------------
LLM_REQUESTS_TOTAL = Counter(
    name="geo_llm_requests_total",
    documentation="Total LLM requests",
    labelnames=["provider", "model", "status"],
)

LLM_REQUEST_DURATION_SECONDS = Histogram(
    name="geo_llm_request_duration_seconds",
    documentation="LLM request duration in seconds",
    labelnames=["provider", "model"],
    buckets=(0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0),
)

LLM_TOKENS_TOTAL = Counter(
    name="geo_llm_tokens_total",
    documentation="Total LLM tokens used",
    labelnames=["provider", "model", "token_type"],
)

LLM_COST_ESTIMATED = Gauge(
    name="geo_llm_cost_estimated",
    documentation="Estimated LLM cost in USD",
    labelnames=["provider", "model"],
)

# ---------------------------------------------------------------------------
# Business layer
# ---------------------------------------------------------------------------
BRAND_COUNT = Gauge(
    name="geo_brands_total",
    documentation="Total number of brands",
)

QUERY_COUNT_TOTAL = Counter(
    name="geo_queries_total",
    documentation="Total number of queries executed",
    labelnames=["platform", "status"],
)

CONTENT_GENERATED_TOTAL = Counter(
    name="geo_content_generated_total",
    documentation="Total content generated",
)

CITATION_DETECTED_TOTAL = Counter(
    name="geo_citations_detected_total",
    documentation="Total citations detected",
    labelnames=["platform"],
)

# ---------------------------------------------------------------------------
# Service information
# ---------------------------------------------------------------------------
SERVICE_INFO = Info(
    name="geo_service",
    documentation="GEO service information",
)
|
||||||
|
|
@ -0,0 +1,86 @@
|
||||||
|
"""监控中间件"""
|
||||||
|
import time
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
|
from app.monitoring.metrics import (
|
||||||
|
API_REQUESTS_TOTAL,
|
||||||
|
API_REQUEST_DURATION_SECONDS,
|
||||||
|
API_REQUESTS_IN_PROGRESS,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Paths that are passed straight through without recording any metrics.
EXCLUDED_PATHS = {"/health", "/ready", "/metrics", "/docs", "/openapi.json"}


class MonitoringMiddleware(BaseHTTPMiddleware):
    """Middleware that records request count, latency and in-flight gauge per endpoint."""

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Forward the request and record its metrics.

        Requests whose path is in ``EXCLUDED_PATHS`` are forwarded untouched.
        For all others the in-progress gauge is incremented for the duration
        of the call, and the request counter and duration histogram are
        updated when it finishes. Any error from ``call_next`` propagates
        unchanged and is counted with status ``"500"``.
        """
        if request.url.path in EXCLUDED_PATHS:
            return await call_next(request)

        endpoint = self._get_endpoint_label(request)
        in_progress = API_REQUESTS_IN_PROGRESS.labels(
            method=request.method,
            endpoint=endpoint,
        )
        in_progress.inc()

        start_time = time.perf_counter()
        # Default to 500 so every exit path has a status to report. Without
        # this, a BaseException that is not an Exception (for example
        # asyncio.CancelledError) would hit an unbound name in `finally` and
        # mask the real error with an UnboundLocalError.
        status_code = 500

        try:
            response = await call_next(request)
            status_code = response.status_code
            return response
        finally:
            duration = time.perf_counter() - start_time
            try:
                API_REQUESTS_TOTAL.labels(
                    method=request.method,
                    endpoint=endpoint,
                    status=str(status_code),
                ).inc()
                API_REQUEST_DURATION_SECONDS.labels(
                    method=request.method,
                    endpoint=endpoint,
                ).observe(duration)
            finally:
                # Always release the in-flight slot, even if recording fails.
                in_progress.dec()

    def _get_endpoint_label(self, request: Request) -> str:
        """Map the request path to a bounded endpoint label.

        Paths of the form ``/api/<version>/<resource>/<segment>[/...]`` become
        ``"<resource>_detail"`` when ``<segment>`` looks like an identifier and
        ``"<resource>_<segment>"`` otherwise. Every other path is ``"other"``.
        """
        parts = request.url.path.strip("/").split("/")

        if len(parts) < 4 or parts[0] != "api":
            return "other"

        resource, action = parts[2], parts[3]
        if self._looks_like_id(action):
            return f"{resource}_detail"
        return f"{resource}_{action}"

    @staticmethod
    def _looks_like_id(segment: str) -> bool:
        """Return True when a path segment is a numeric or UUID identifier.

        UUIDs are collapsed as well as digits because the models use UUID
        primary keys; letting raw ids through would create one label value
        per record.
        """
        import uuid  # local import: keeps this block self-contained

        if segment.isdigit():
            return True
        try:
            uuid.UUID(segment)
        except ValueError:
            return False
        return True
|
||||||
|
|
@ -74,3 +74,27 @@ class ChunkPreview(BaseModel):
|
||||||
token_count: int
|
token_count: int
|
||||||
|
|
||||||
model_config = {"from_attributes": True}
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- 增量索引 Schemas ----------
|
||||||
|
|
||||||
|
class UpdateDocumentRequest(BaseModel):
    """Request body for updating a document's content."""

    # New text content for the document (required).
    content: str
|
||||||
|
|
||||||
|
|
||||||
|
class RetrieveRequest(BaseModel):
    """Request body for an enhanced retrieval query."""

    # Search text; required and must be at least one character long.
    query: str = Field(..., min_length=1)
    # Number of results to return, between 1 and 50 (default 5).
    top_k: Optional[int] = Field(5, ge=1, le=50)
    # Reranking flag, on by default.
    use_rerank: Optional[bool] = True
    # Compression flag, off by default.
    use_compression: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
|
class RebuildResponse(BaseModel):
    """Response body summarizing an index rebuild."""

    total: int          # total count reported by the rebuild
    processed: int      # count reported as processed
    skipped: int        # count reported as skipped
    failed: int         # count reported as failed
    errors: list[dict]  # one dict per reported error
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,54 @@
|
||||||
|
"""应用状态管理"""
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class AppState:
    """Snapshot of process-level application facts captured at construction."""

    def __init__(self):
        self.start_time = time.time()
        self.platform = platform.system()
        self.python_version = platform.python_version()
        self.version = os.getenv("APP_VERSION", "1.0.0")
        self.environment = os.getenv("ENVIRONMENT", "development")

    def get_uptime_seconds(self) -> float:
        """Return the seconds elapsed since this object was created."""
        now = time.time()
        return now - self.start_time

    def get_uptime_formatted(self) -> str:
        """Return the uptime as days/hours/minutes/seconds text.

        Day, hour and minute parts are omitted when zero; the seconds part is
        always present.
        """
        elapsed = self.get_uptime_seconds()

        larger_units = (
            (int(elapsed // 86400), "天"),
            (int((elapsed % 86400) // 3600), "小时"),
            (int((elapsed % 3600) // 60), "分钟"),
        )
        pieces = [f"{count}{label}" for count, label in larger_units if count > 0]
        pieces.append(f"{int(elapsed % 60)}秒")

        return "".join(pieces)

    def get_info(self) -> dict:
        """Return version, environment, platform and uptime as a dict."""
        info = {
            "version": self.version,
            "environment": self.environment,
            "platform": self.platform,
            "python_version": self.python_version,
        }
        info["uptime_seconds"] = round(self.get_uptime_seconds(), 2)
        info["uptime_formatted"] = self.get_uptime_formatted()
        return info


# Shared module-level instance
app_state = AppState()
|
||||||
|
|
@ -0,0 +1,362 @@
|
||||||
|
"""GEO内容8大母题库定义"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
@dataclass
class TopicTemplate:
    """Definition of one content topic template."""

    # Identifier, e.g. "product_comparison"
    id: str
    # Display name
    name: str
    # Short description
    description: str
    # Emoji icon
    icon: str
    # Prompt text containing {placeholder} fields
    prompt_template: str
    # SEO tips
    seo_tips: list[str]
    # Recommended publishing platforms
    recommended_platforms: list[str]
    # Recommended word-count range
    word_count_range: tuple[int, int]
    # Names of required parameters
    required_params: list[str]
    # Names of optional parameters
    optional_params: list[str]
|
||||||
|
|
||||||
|
TOPIC_TEMPLATES = {
|
||||||
|
# 1. 产品对比
|
||||||
|
"product_comparison": TopicTemplate(
|
||||||
|
id="product_comparison",
|
||||||
|
name="产品对比",
|
||||||
|
description="品牌与竞品的功能、价格、体验对比",
|
||||||
|
icon="⚖️",
|
||||||
|
prompt_template="""
|
||||||
|
请基于以下信息生成一篇产品对比文章:
|
||||||
|
|
||||||
|
【你的品牌】{brand_name}
|
||||||
|
【对比品牌】{competitor_name}
|
||||||
|
【对比维度】{comparison_dimensions}
|
||||||
|
【目标关键词】{keywords}
|
||||||
|
【内容风格】{content_style}
|
||||||
|
|
||||||
|
要求:
|
||||||
|
1. 客观中立,不刻意贬低竞品
|
||||||
|
2. 突出自身优势但有理有据
|
||||||
|
3. 适合AI平台引用(结构清晰、数据支撑)
|
||||||
|
4. 包含对比表格或图表
|
||||||
|
5. 字数:约{word_count}字
|
||||||
|
|
||||||
|
文章结构:
|
||||||
|
1. 引言:简要介绍两个产品
|
||||||
|
2. 外观设计对比
|
||||||
|
3. 功能特性对比
|
||||||
|
4. 性价比对比
|
||||||
|
5. 适用场景分析
|
||||||
|
6. 结论:总结优劣,给出建议
|
||||||
|
""",
|
||||||
|
seo_tips=[
|
||||||
|
"标题包含品牌名+对比+关键词",
|
||||||
|
"使用对比表格增强可读性",
|
||||||
|
"自然嵌入目标关键词",
|
||||||
|
],
|
||||||
|
recommended_platforms=["知乎", "百家号", "公众号"],
|
||||||
|
word_count_range=(1500, 3000),
|
||||||
|
required_params=["brand_name", "competitor_name", "comparison_dimensions", "keywords"],
|
||||||
|
optional_params=["content_style", "word_count"],
|
||||||
|
),
|
||||||
|
|
||||||
|
# 2. 使用指南
|
||||||
|
"how_to_guide": TopicTemplate(
|
||||||
|
id="how_to_guide",
|
||||||
|
name="使用指南",
|
||||||
|
description="产品使用方法教程,帮助用户解决问题",
|
||||||
|
icon="📖",
|
||||||
|
prompt_template="""
|
||||||
|
请基于以下信息生成一篇使用指南文章:
|
||||||
|
|
||||||
|
【产品名称】{product_name}
|
||||||
|
【核心功能】{core_features}
|
||||||
|
【目标用户】{target_audience}
|
||||||
|
【目标关键词】{keywords}
|
||||||
|
【内容风格】{content_style}
|
||||||
|
|
||||||
|
要求:
|
||||||
|
1. 步骤清晰,操作性强
|
||||||
|
2. 图文并茂(描述需要的图片类型)
|
||||||
|
3. 适合搜索引擎收录
|
||||||
|
4. 适合AI平台引用
|
||||||
|
5. 字数:约{word_count}字
|
||||||
|
|
||||||
|
文章结构:
|
||||||
|
1. 引言:介绍产品价值
|
||||||
|
2. 基础设置/准备工作
|
||||||
|
3. 核心功能使用步骤(分点详述)
|
||||||
|
4. 进阶技巧/常见问题
|
||||||
|
5. 总结与进阶学习建议
|
||||||
|
""",
|
||||||
|
seo_tips=[
|
||||||
|
"标题包含产品名+使用方法/教程",
|
||||||
|
"使用有序列表标注步骤",
|
||||||
|
"包含FAQ解答常见问题",
|
||||||
|
],
|
||||||
|
recommended_platforms=["知乎", "小红书", "简书"],
|
||||||
|
word_count_range=(1000, 2500),
|
||||||
|
required_params=["product_name", "core_features", "keywords"],
|
||||||
|
optional_params=["target_audience", "content_style", "word_count"],
|
||||||
|
),
|
||||||
|
|
||||||
|
# 3. 行业趋势
|
||||||
|
"industry_trends": TopicTemplate(
|
||||||
|
id="industry_trends",
|
||||||
|
name="行业趋势",
|
||||||
|
description="行业动态、发展方向和未来预测分析",
|
||||||
|
icon="📈",
|
||||||
|
prompt_template="""
|
||||||
|
请基于以下信息生成一篇行业趋势分析文章:
|
||||||
|
|
||||||
|
【行业名称】{industry_name}
|
||||||
|
【品牌视角】{brand_perspective}
|
||||||
|
【分析维度】{analysis_dimensions}
|
||||||
|
【目标关键词】{keywords}
|
||||||
|
【内容风格】{content_style}
|
||||||
|
|
||||||
|
要求:
|
||||||
|
1. 数据支撑,引用权威来源
|
||||||
|
2. 趋势分析有理有据
|
||||||
|
3. 适合AI平台引用(观点鲜明、数据翔实)
|
||||||
|
4. 字数:约{word_count}字
|
||||||
|
|
||||||
|
文章结构:
|
||||||
|
1. 引言:行业现状概述
|
||||||
|
2. 核心趋势分析(3-5个趋势)
|
||||||
|
3. 趋势背后的驱动因素
|
||||||
|
4. 品牌如何把握趋势
|
||||||
|
5. 展望与建议
|
||||||
|
""",
|
||||||
|
seo_tips=[
|
||||||
|
"标题包含行业名+趋势/预测",
|
||||||
|
"引用数据增加可信度",
|
||||||
|
"使用时间线展示演变",
|
||||||
|
],
|
||||||
|
recommended_platforms=["知乎", "百家号", "公众号"],
|
||||||
|
word_count_range=(2000, 4000),
|
||||||
|
required_params=["industry_name", "brand_perspective", "keywords"],
|
||||||
|
optional_params=["analysis_dimensions", "content_style", "word_count"],
|
||||||
|
),
|
||||||
|
|
||||||
|
# 4. 专家观点
|
||||||
|
"expert_opinion": TopicTemplate(
|
||||||
|
id="expert_opinion",
|
||||||
|
name="专家观点",
|
||||||
|
description="行业专家见解和专业分析",
|
||||||
|
icon="💡",
|
||||||
|
prompt_template="""
|
||||||
|
请基于以下信息生成一篇专家观点文章:
|
||||||
|
|
||||||
|
【主题】{topic}
|
||||||
|
【专家身份】{expert_identity}
|
||||||
|
【核心观点】{core_opinion}
|
||||||
|
【目标关键词】{keywords}
|
||||||
|
【内容风格】{content_style}
|
||||||
|
|
||||||
|
要求:
|
||||||
|
1. 观点鲜明,论证充分
|
||||||
|
2. 结合实际案例
|
||||||
|
3. 适合AI平台引用
|
||||||
|
4. 字数:约{word_count}字
|
||||||
|
|
||||||
|
文章结构:
|
||||||
|
1. 引言:抛出核心观点
|
||||||
|
2. 观点详细阐述
|
||||||
|
3. 案例支撑
|
||||||
|
4. 对行业的影响
|
||||||
|
5. 结论与建议
|
||||||
|
""",
|
||||||
|
seo_tips=[
|
||||||
|
"标题体现专家身份+观点",
|
||||||
|
"使用引用格式突出核心观点",
|
||||||
|
"增加互动性问题",
|
||||||
|
],
|
||||||
|
recommended_platforms=["知乎", "公众号", "微博"],
|
||||||
|
word_count_range=(1500, 3000),
|
||||||
|
required_params=["topic", "expert_identity", "core_opinion", "keywords"],
|
||||||
|
optional_params=["content_style", "word_count"],
|
||||||
|
),
|
||||||
|
|
||||||
|
# 5. 案例研究
|
||||||
|
"case_study": TopicTemplate(
|
||||||
|
id="case_study",
|
||||||
|
name="案例研究",
|
||||||
|
description="成功案例深度分析",
|
||||||
|
icon="🏆",
|
||||||
|
prompt_template="""
|
||||||
|
请基于以下信息生成一篇案例研究文章:
|
||||||
|
|
||||||
|
【案例名称】{case_name}
|
||||||
|
【企业背景】{company_background}
|
||||||
|
【核心成果】{core_results}
|
||||||
|
【成功要素】{success_factors}
|
||||||
|
【目标关键词】{keywords}
|
||||||
|
【内容风格】{content_style}
|
||||||
|
|
||||||
|
要求:
|
||||||
|
1. 故事性强,有代入感
|
||||||
|
2. 数据支撑成果
|
||||||
|
3. 可复制的方法论
|
||||||
|
4. 适合AI平台引用
|
||||||
|
5. 字数:约{word_count}字
|
||||||
|
|
||||||
|
文章结构:
|
||||||
|
1. 案例背景介绍
|
||||||
|
2. 面临的挑战
|
||||||
|
3. 解决方案详述
|
||||||
|
4. 成果数据展示
|
||||||
|
5. 成功经验总结
|
||||||
|
6. 可复制建议
|
||||||
|
""",
|
||||||
|
seo_tips=[
|
||||||
|
"标题包含行业+案例+成果",
|
||||||
|
"使用数据图表展示成果",
|
||||||
|
"突出可复制的关键步骤",
|
||||||
|
],
|
||||||
|
recommended_platforms=["知乎", "百家号", "公众号"],
|
||||||
|
word_count_range=(2000, 4000),
|
||||||
|
required_params=["case_name", "company_background", "core_results", "keywords"],
|
||||||
|
optional_params=["success_factors", "content_style", "word_count"],
|
||||||
|
),
|
||||||
|
|
||||||
|
# 6. 数据报告
|
||||||
|
"data_report": TopicTemplate(
|
||||||
|
id="data_report",
|
||||||
|
name="数据报告",
|
||||||
|
description="行业数据和分析报告",
|
||||||
|
icon="📊",
|
||||||
|
prompt_template="""
|
||||||
|
请基于以下信息生成一篇数据报告文章:
|
||||||
|
|
||||||
|
【报告主题】{report_topic}
|
||||||
|
【数据范围】{data_scope}
|
||||||
|
【核心发现】{core_findings}
|
||||||
|
【目标关键词】{keywords}
|
||||||
|
【内容风格】{content_style}
|
||||||
|
|
||||||
|
要求:
|
||||||
|
1. 数据翔实,来源可靠
|
||||||
|
2. 图表丰富,直观易懂
|
||||||
|
3. 深度分析,洞察本质
|
||||||
|
4. 适合AI平台引用
|
||||||
|
5. 字数:约{word_count}字
|
||||||
|
|
||||||
|
文章结构:
|
||||||
|
1. 研究背景与方法
|
||||||
|
2. 核心数据发现
|
||||||
|
3. 深度分析
|
||||||
|
4. 趋势解读
|
||||||
|
5. 建议与展望
|
||||||
|
""",
|
||||||
|
seo_tips=[
|
||||||
|
"标题包含数据+报告+关键词",
|
||||||
|
"使用图表增强说服力",
|
||||||
|
"包含数据来源说明",
|
||||||
|
],
|
||||||
|
recommended_platforms=["知乎", "百家号", "公众号"],
|
||||||
|
word_count_range=(3000, 5000),
|
||||||
|
required_params=["report_topic", "data_scope", "core_findings", "keywords"],
|
||||||
|
optional_params=["content_style", "word_count"],
|
||||||
|
),
|
||||||
|
|
||||||
|
# 7. 问题解答
|
||||||
|
"faq": TopicTemplate(
|
||||||
|
id="faq",
|
||||||
|
name="问题解答",
|
||||||
|
description="常见问题解答",
|
||||||
|
icon="❓",
|
||||||
|
prompt_template="""
|
||||||
|
请基于以下信息生成一篇FAQ文章:
|
||||||
|
|
||||||
|
【主题】{topic}
|
||||||
|
【目标用户】{target_audience}
|
||||||
|
【问题列表】{questions}
|
||||||
|
【目标关键词】{keywords}
|
||||||
|
【内容风格】{content_style}
|
||||||
|
|
||||||
|
要求:
|
||||||
|
1. 问题覆盖面广,击中痛点
|
||||||
|
2. 回答简洁明了,实操性强
|
||||||
|
3. 适合搜索引擎长尾词
|
||||||
|
4. 适合AI平台引用
|
||||||
|
5. 字数:约{word_count}字
|
||||||
|
|
||||||
|
文章结构:
|
||||||
|
1. 引言:为什么需要了解这些
|
||||||
|
2. 核心问题解答(每个问题独立成段)
|
||||||
|
3. 延伸问题
|
||||||
|
4. 总结:下一步建议
|
||||||
|
""",
|
||||||
|
seo_tips=[
|
||||||
|
"标题包含主题+常见问题/FAQ",
|
||||||
|
"每个问题使用H2标签",
|
||||||
|
"覆盖长尾搜索词",
|
||||||
|
],
|
||||||
|
recommended_platforms=["知乎", "简书", "小程序"],
|
||||||
|
word_count_range=(1000, 2000),
|
||||||
|
required_params=["topic", "questions", "keywords"],
|
||||||
|
optional_params=["target_audience", "content_style", "word_count"],
|
||||||
|
),
|
||||||
|
|
||||||
|
# 8. 评测报告
|
||||||
|
"review": TopicTemplate(
|
||||||
|
id="review",
|
||||||
|
name="评测报告",
|
||||||
|
description="产品深度评测:优缺点全面分析",
|
||||||
|
icon="🔬",
|
||||||
|
prompt_template="""
|
||||||
|
请基于以下信息生成一篇评测报告:
|
||||||
|
|
||||||
|
【产品名称】{product_name}
|
||||||
|
【评测维度】{review_dimensions}
|
||||||
|
【竞品参照】{competitor_reference}
|
||||||
|
【目标关键词】{keywords}
|
||||||
|
【内容风格】{content_style}
|
||||||
|
|
||||||
|
要求:
|
||||||
|
1. 客观公正,优点不夸大
|
||||||
|
2. 缺点实事求是
|
||||||
|
3. 有具体数据和场景支撑
|
||||||
|
4. 适合AI平台引用
|
||||||
|
5. 字数:约{word_count}字
|
||||||
|
|
||||||
|
文章结构:
|
||||||
|
1. 评测背景与方法
|
||||||
|
2. 外观设计评价
|
||||||
|
3. 核心功能评测
|
||||||
|
4. 性能测试数据
|
||||||
|
5. 优缺点总结
|
||||||
|
6. 适合人群分析
|
||||||
|
7. 购买建议
|
||||||
|
""",
|
||||||
|
seo_tips=[
|
||||||
|
"标题包含产品名+评测/测评",
|
||||||
|
"使用评分表格直观展示",
|
||||||
|
"包含具体测试数据",
|
||||||
|
],
|
||||||
|
recommended_platforms=["知乎", "小红书", "百家号"],
|
||||||
|
word_count_range=(2000, 4000),
|
||||||
|
required_params=["product_name", "review_dimensions", "keywords"],
|
||||||
|
optional_params=["competitor_reference", "content_style", "word_count"],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_topic_template(topic_id: str) -> Optional[TopicTemplate]:
    """Look up a topic template by its id.

    Args:
        topic_id: Key into ``TOPIC_TEMPLATES``.

    Returns:
        The matching template, or ``None`` when the id is not registered.
    """
    try:
        return TOPIC_TEMPLATES[topic_id]
    except KeyError:
        return None
|
||||||
|
|
||||||
|
def list_topic_templates() -> list[TopicTemplate]:
    """Return every registered topic template as a new list.

    The list follows the iteration order of ``TOPIC_TEMPLATES``.
    """
    return [*TOPIC_TEMPLATES.values()]
|
||||||
|
|
||||||
|
def render_topic_prompt(topic_id: str, params: dict) -> str:
    """Render the prompt text of a topic template.

    ``content_style`` and ``word_count`` fall back to defaults when the caller
    does not supply them. The caller's ``params`` dict is left untouched; the
    defaults are merged into a private copy.

    Args:
        topic_id: Id of the template to render.
        params: Values for the ``{placeholder}`` fields of the template.

    Returns:
        The template's ``prompt_template`` with all placeholders filled in.

    Raises:
        ValueError: If ``topic_id`` does not name a known template.
        KeyError: If the template uses a placeholder that is missing from
            ``params`` (propagated from ``str.format``).
    """
    template = get_topic_template(topic_id)
    if template is None:
        raise ValueError(f"Unknown topic: {topic_id}")

    # Defaults first, caller-supplied values win. Building a new dict avoids
    # the side effect of writing the defaults into the caller's mapping.
    values = {"content_style": "专业严谨", "word_count": 2000, **params}

    return template.prompt_template.format(**values)
|
||||||
|
|
@ -106,6 +106,11 @@ PLATFORM_RULES: dict[str, dict] = {
|
||||||
"不得使用AI水文(内容需有信息增量)",
|
"不得使用AI水文(内容需有信息增量)",
|
||||||
"专业背书内容需有据可查",
|
"专业背书内容需有据可查",
|
||||||
],
|
],
|
||||||
|
"image_rules": {
|
||||||
|
"cover": {"width": 690, "height": 280, "ratio": "2.5:1"},
|
||||||
|
"inline": {"width": 500, "height": 375, "ratio": "4:3"},
|
||||||
|
"max_size_mb": 5,
|
||||||
|
},
|
||||||
"seo_tips": [
|
"seo_tips": [
|
||||||
"回答开头直接给出结论",
|
"回答开头直接给出结论",
|
||||||
"使用数据和案例支撑",
|
"使用数据和案例支撑",
|
||||||
|
|
@ -186,6 +191,11 @@ PLATFORM_RULES: dict[str, dict] = {
|
||||||
"正文不含外部链接(仅支持公众号链接和小程序)",
|
"正文不含外部链接(仅支持公众号链接和小程序)",
|
||||||
"不得包含未经授权的商标/品牌名称",
|
"不得包含未经授权的商标/品牌名称",
|
||||||
],
|
],
|
||||||
|
"image_rules": {
|
||||||
|
"cover": {"width": 900, "height": 383, "ratio": "2.35:1"},
|
||||||
|
"inline": {"width": 800, "height": 600, "ratio": "4:3"},
|
||||||
|
"max_size_mb": 10,
|
||||||
|
},
|
||||||
"seo_tips": [
|
"seo_tips": [
|
||||||
"首段包含核心关键词",
|
"首段包含核心关键词",
|
||||||
"使用小标题分段(适配搜一搜)",
|
"使用小标题分段(适配搜一搜)",
|
||||||
|
|
@ -262,6 +272,11 @@ PLATFORM_RULES: dict[str, dict] = {
|
||||||
"正文需含至少1张配图",
|
"正文需含至少1张配图",
|
||||||
"不得搬运/洗稿",
|
"不得搬运/洗稿",
|
||||||
],
|
],
|
||||||
|
"image_rules": {
|
||||||
|
"cover": {"width": 600, "height": 400, "ratio": "3:2"},
|
||||||
|
"inline": {"width": 800, "height": 600, "ratio": "4:3"},
|
||||||
|
"max_size_mb": 5,
|
||||||
|
},
|
||||||
"seo_tips": [
|
"seo_tips": [
|
||||||
"标题包含百度搜索热词",
|
"标题包含百度搜索热词",
|
||||||
"文章结构化(H2小标题)",
|
"文章结构化(H2小标题)",
|
||||||
|
|
@ -341,6 +356,11 @@ PLATFORM_RULES: dict[str, dict] = {
|
||||||
"首发原创优先推荐",
|
"首发原创优先推荐",
|
||||||
"配图清晰不模糊",
|
"配图清晰不模糊",
|
||||||
],
|
],
|
||||||
|
"image_rules": {
|
||||||
|
"cover": {"width": 1024, "height": 678, "ratio": "1.5:1"},
|
||||||
|
"inline": {"width": 800, "height": 600, "ratio": "4:3"},
|
||||||
|
"max_size_mb": 5,
|
||||||
|
},
|
||||||
"seo_tips": [
|
"seo_tips": [
|
||||||
"标题含核心关键词",
|
"标题含核心关键词",
|
||||||
"文章1500字以上推荐更高",
|
"文章1500字以上推荐更高",
|
||||||
|
|
@ -417,6 +437,11 @@ PLATFORM_RULES: dict[str, dict] = {
|
||||||
"话题标签有助于曝光",
|
"话题标签有助于曝光",
|
||||||
"配图有助于转发",
|
"配图有助于转发",
|
||||||
],
|
],
|
||||||
|
"image_rules": {
|
||||||
|
"cover": {"width": 980, "height": 560, "ratio": "1.75:1"},
|
||||||
|
"inline": {"width": 800, "height": 600, "ratio": "4:3"},
|
||||||
|
"max_size_mb": 5,
|
||||||
|
},
|
||||||
"seo_tips": [
|
"seo_tips": [
|
||||||
"热门话题可增加曝光",
|
"热门话题可增加曝光",
|
||||||
"短句更易阅读",
|
"短句更易阅读",
|
||||||
|
|
@ -493,6 +518,11 @@ PLATFORM_RULES: dict[str, dict] = {
|
||||||
"不得出现其他平台引流信息",
|
"不得出现其他平台引流信息",
|
||||||
"图片不含水印",
|
"图片不含水印",
|
||||||
],
|
],
|
||||||
|
"image_rules": {
|
||||||
|
"cover": {"width": 1080, "height": 1080, "ratio": "1:1"},
|
||||||
|
"inline": {"width": 1242, "height": 1660, "ratio": "3:4"},
|
||||||
|
"max_size_mb": 10,
|
||||||
|
},
|
||||||
"seo_tips": [
|
"seo_tips": [
|
||||||
"标题含数字更吸引点击",
|
"标题含数字更吸引点击",
|
||||||
"正文用短句+emoji分段",
|
"正文用短句+emoji分段",
|
||||||
|
|
@ -572,6 +602,11 @@ PLATFORM_RULES: dict[str, dict] = {
|
||||||
"封面和标题很重要",
|
"封面和标题很重要",
|
||||||
"互动有助于推荐",
|
"互动有助于推荐",
|
||||||
],
|
],
|
||||||
|
"image_rules": {
|
||||||
|
"cover": {"width": 1920, "height": 1080, "ratio": "16:9"},
|
||||||
|
"inline": {"width": 800, "height": 600, "ratio": "4:3"},
|
||||||
|
"max_size_mb": 5,
|
||||||
|
},
|
||||||
"seo_tips": [
|
"seo_tips": [
|
||||||
"标题包含关键词",
|
"标题包含关键词",
|
||||||
"封面吸引人",
|
"封面吸引人",
|
||||||
|
|
@ -646,6 +681,11 @@ PLATFORM_RULES: dict[str, dict] = {
|
||||||
"文艺风格更受欢迎",
|
"文艺风格更受欢迎",
|
||||||
"配图有助于阅读",
|
"配图有助于阅读",
|
||||||
],
|
],
|
||||||
|
"image_rules": {
|
||||||
|
"cover": {"width": 800, "height": 600, "ratio": "4:3"},
|
||||||
|
"inline": {"width": 800, "height": 600, "ratio": "4:3"},
|
||||||
|
"max_size_mb": 5,
|
||||||
|
},
|
||||||
"seo_tips": [
|
"seo_tips": [
|
||||||
"标题包含关键词",
|
"标题包含关键词",
|
||||||
"合理使用专题",
|
"合理使用专题",
|
||||||
|
|
@ -721,6 +761,11 @@ PLATFORM_RULES: dict[str, dict] = {
|
||||||
"鼓励原创技术文章",
|
"鼓励原创技术文章",
|
||||||
"禁止低质量搬运",
|
"禁止低质量搬运",
|
||||||
],
|
],
|
||||||
|
"image_rules": {
|
||||||
|
"cover": {"width": 1024, "height": 768, "ratio": "4:3"},
|
||||||
|
"inline": {"width": 800, "height": 600, "ratio": "4:3"},
|
||||||
|
"max_size_mb": 5,
|
||||||
|
},
|
||||||
"seo_tips": [
|
"seo_tips": [
|
||||||
"标题包含技术关键词",
|
"标题包含技术关键词",
|
||||||
"代码块有助于阅读",
|
"代码块有助于阅读",
|
||||||
|
|
@ -796,6 +841,11 @@ PLATFORM_RULES: dict[str, dict] = {
|
||||||
"话题标签2-5个",
|
"话题标签2-5个",
|
||||||
"文案简短有吸引力",
|
"文案简短有吸引力",
|
||||||
],
|
],
|
||||||
|
"image_rules": {
|
||||||
|
"cover": {"width": 1080, "height": 1920, "ratio": "9:16"},
|
||||||
|
"inline": {"width": 1080, "height": 1920, "ratio": "9:16"},
|
||||||
|
"max_size_mb": 10,
|
||||||
|
},
|
||||||
"seo_tips": [
|
"seo_tips": [
|
||||||
"前3秒决定完播率",
|
"前3秒决定完播率",
|
||||||
"标题含热点关键词",
|
"标题含热点关键词",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,382 @@
|
||||||
|
"""
|
||||||
|
邮件通知服务
|
||||||
|
|
||||||
|
支持发送告警通知、额度预警等邮件。
|
||||||
|
|
||||||
|
功能:
|
||||||
|
- 邮件模板引擎: 变量替换渲染邮件内容
|
||||||
|
- 邮件内容生成: 告警通知、额度预警邮件生成
|
||||||
|
- 邮件发送: 支持真实SMTP和模拟模式
|
||||||
|
- 邮件队列管理: 批量添加和发送
|
||||||
|
- 错误处理和重试: 自动重试机制
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import smtplib
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from email.mime.multipart import MIMEMultipart
|
||||||
|
from email.mime.text import MIMEText
|
||||||
|
from email.mime.base import MIMEBase
|
||||||
|
from email import encoders
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Built-in email templates, keyed by template name.
# Each entry carries a subject, an HTML body and a plain-text body. The
# ``{name}`` placeholders are substituted by ``EmailService.render_template``.
EMAIL_TEMPLATES = {
    # Alert notification. The plain-text body carries fewer fields than the
    # HTML body (no description, no timestamp).
    "alert_notification": {
        "subject": "[GEO平台] 告警通知:{alert_type}",
        "body_html": """
<h2>告警通知</h2>
<p>品牌:<strong>{brand_name}</strong></p>
<p>告警类型:{alert_type}</p>
<p>严重程度:{severity}</p>
<p>详情:{description}</p>
<p>时间:{timestamp}</p>
""",
        "body_text": "告警通知 - 品牌:{brand_name}, 类型:{alert_type}, severity:{severity}"
    },
    # Quota warning. The plain-text body only reports the quota type and the
    # usage percentage.
    "quota_warning": {
        "subject": "[GEO平台] 额度预警:{quota_type}",
        "body_html": """
<h2>额度预警</h2>
<p>您的{quota_type}使用量已达到{usage_percentage}%</p>
<p>已使用:{used} / 总额度:{limit}</p>
<p>建议操作:{recommended_action}</p>
""",
        "body_text": "额度预警 - {quota_type}使用量:{usage_percentage}%"
    }
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class EmailMessage:
    """A single outgoing email, ready to be queued or sent."""

    # Recipient address.
    to: str
    # Rendered subject line.
    subject: str
    # Rendered HTML body.
    body_html: str
    # Rendered plain-text body.
    body_text: str
    # Attachments as dicts with "filename" and "content" keys
    # (the shape appended by EmailService.add_attachment).
    attachments: list[dict] = field(default_factory=list)
    # Free-form data attached to the message; render_template stores the
    # template variables here.
    metadata: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class EmailSendResult:
    """Outcome of one EmailService.send_email call."""

    # True when the message was sent (or accepted in simulate mode).
    success: bool
    # Locally generated id ("sim_..." or "smtp_..." prefix); None on failure.
    message_id: str | None
    # Text of the last error; None on success.
    error: str | None
    # Number of failed attempts that preceded this result.
    retry_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class EmailService:
    """Email notification service.

    Renders the templates in ``EMAIL_TEMPLATES``, builds alert and
    quota-warning messages, sends them over SMTP (or only pretends to when
    ``simulate_mode`` is on) and keeps a simple in-memory send queue.
    """

    # Whole-string pattern for a plausible email address.
    _EMAIL_RE = re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
    # One ``{name}`` placeholder; used by the lenient fallback of _safe_format.
    _PLACEHOLDER_RE = re.compile(r"\{([^{}]+)\}")

    def __init__(
        self,
        simulate_mode: bool = True,
        smtp_host: str = "localhost",
        smtp_port: int = 587,
        smtp_user: str = "",
        smtp_password: str = "",
        max_retries: int = 3,
    ):
        """Store the delivery settings and start with an empty queue.

        Args:
            simulate_mode: When True, ``send_email`` never opens a connection
                and reports success with a ``sim_`` message id.
            smtp_host: SMTP server host.
            smtp_port: SMTP server port.
            smtp_user: Login name; also used as the ``From`` header.
            smtp_password: Login password.
            max_retries: Number of retries after the first failed attempt.
        """
        self.simulate_mode = simulate_mode
        self.smtp_host = smtp_host
        self.smtp_port = smtp_port
        self.smtp_user = smtp_user
        self.smtp_password = smtp_password
        self.max_retries = max_retries
        self._queue: list[EmailMessage] = []

    def validate_email(self, email: str) -> bool:
        """Check whether ``email`` looks like a valid address.

        The whole string must match; in particular a trailing newline is
        rejected, so the value is safe to place into a message header.

        Args:
            email: Address to check.

        Returns:
            True if the address is well-formed, False otherwise (including
            for an empty value).
        """
        if not email:
            return False
        # fullmatch instead of match + "$": "$" also matches just before a
        # trailing "\n", which would let "a@b.com\n" through.
        return bool(self._EMAIL_RE.fullmatch(email))

    def _safe_format(self, template: str, variables: dict[str, Any]) -> str:
        """Format ``template``, keeping placeholders that have no value.

        Args:
            template: Template string with ``{name}`` placeholders.
            variables: Values for the placeholders.

        Returns:
            The formatted string. Placeholders whose name is missing from
            ``variables`` are left in place verbatim.
        """
        try:
            return template.format(**variables)
        except KeyError:
            # Lenient fallback: substitute known names in a single pass so
            # that a substituted value containing "{other}" is not expanded
            # a second time.
            def _substitute(match: re.Match) -> str:
                name = match.group(1)
                if name in variables:
                    return str(variables[name])
                return match.group(0)

            return self._PLACEHOLDER_RE.sub(_substitute, template)

    def render_template(
        self,
        template_name: str,
        to: str,
        variables: dict[str, Any],
    ) -> EmailMessage:
        """Render one of the built-in templates into an ``EmailMessage``.

        Args:
            template_name: Key into ``EMAIL_TEMPLATES``.
            to: Recipient address.
            variables: Template variables.

        Returns:
            The rendered message; a copy of ``variables`` is stored as its
            metadata.

        Raises:
            ValueError: If ``template_name`` is not a known template.
        """
        if template_name not in EMAIL_TEMPLATES:
            raise ValueError(f"模板不存在: {template_name}")

        template = EMAIL_TEMPLATES[template_name]

        return EmailMessage(
            to=to,
            subject=self._safe_format(template["subject"], variables),
            body_html=self._safe_format(template["body_html"], variables),
            body_text=self._safe_format(template["body_text"], variables),
            # Copy so later changes to the caller's dict do not leak into the
            # message (and vice versa).
            metadata=dict(variables),
        )

    def generate_alert_email(
        self,
        to: str,
        alert_type: str,
        brand_name: str,
        severity: str,
        description: str,
        timestamp: str,
    ) -> EmailMessage:
        """Build an alert notification email.

        Args:
            to: Recipient address.
            alert_type: Alert type.
            brand_name: Brand name.
            severity: Severity label.
            description: Alert details.
            timestamp: Time of the alert, already formatted as text.

        Returns:
            The message rendered from the ``alert_notification`` template.
        """
        variables = {
            "alert_type": alert_type,
            "brand_name": brand_name,
            "severity": severity,
            "description": description,
            "timestamp": timestamp,
        }
        return self.render_template("alert_notification", to, variables)

    def generate_quota_warning_email(
        self,
        to: str,
        quota_type: str,
        usage_percentage: int,
        used: int,
        limit: int,
        recommended_action: str,
    ) -> EmailMessage:
        """Build a quota warning email.

        Args:
            to: Recipient address.
            quota_type: Quota type.
            usage_percentage: Usage as a percentage.
            used: Amount already used.
            limit: Total quota.
            recommended_action: Suggested next step for the recipient.

        Returns:
            The message rendered from the ``quota_warning`` template.
        """
        variables = {
            "quota_type": quota_type,
            "usage_percentage": usage_percentage,
            "used": used,
            "limit": limit,
            "recommended_action": recommended_action,
        }
        return self.render_template("quota_warning", to, variables)

    def add_attachment(self, msg: EmailMessage, filename: str, content: bytes) -> None:
        """Append an attachment to ``msg``.

        Args:
            msg: Message to extend.
            filename: Attachment file name.
            content: Raw attachment bytes.
        """
        msg.attachments.append({
            "filename": filename,
            "content": content,
        })

    def _create_mime_message(self, msg: EmailMessage) -> MIMEMultipart:
        """Convert an ``EmailMessage`` into a MIME message.

        The two bodies are packed as ``multipart/alternative`` (plain text
        first, HTML last) so mail clients render exactly one of them. When
        there are attachments, that part is wrapped in a ``multipart/mixed``
        container together with the attachments.

        Args:
            msg: Message to convert.

        Returns:
            The MIME message with ``From``, ``To`` and ``Subject`` set.
        """
        body = MIMEMultipart("alternative")
        # Order matters: clients prefer the last alternative they can display.
        body.attach(MIMEText(msg.body_text, "plain", "utf-8"))
        body.attach(MIMEText(msg.body_html, "html", "utf-8"))

        if msg.attachments:
            mime_msg = MIMEMultipart("mixed")
            mime_msg.attach(body)
        else:
            mime_msg = body

        mime_msg["From"] = self.smtp_user
        mime_msg["To"] = msg.to
        mime_msg["Subject"] = msg.subject

        for attachment in msg.attachments:
            part = MIMEBase("application", "octet-stream")
            part.set_payload(attachment["content"])
            encoders.encode_base64(part)
            # Passing filename as a parameter lets the email package quote
            # and encode it (quotes, non-ASCII names) correctly.
            part.add_header(
                "Content-Disposition",
                "attachment",
                filename=attachment["filename"],
            )
            mime_msg.attach(part)

        return mime_msg

    def _deliver(self, msg: EmailMessage) -> None:
        """Send ``msg`` over one SMTP connection and always release it."""
        server = smtplib.SMTP(self.smtp_host, self.smtp_port)
        try:
            server.starttls()
            server.login(self.smtp_user, self.smtp_password)
            server.send_message(self._create_mime_message(msg))
        except Exception:
            # Drop the socket so a failed attempt does not leak a connection.
            server.close()
            raise

        try:
            server.quit()
        except (smtplib.SMTPException, OSError):
            # The message was already accepted; a failing QUIT must not be
            # reported as a send failure (that would trigger a duplicate).
            server.close()

    def send_email(self, msg: EmailMessage) -> EmailSendResult:
        """Send one email, retrying on failure.

        Invalid recipient addresses are rejected without any attempt. In
        simulate mode no connection is made. Otherwise up to
        ``max_retries + 1`` attempts are made, sleeping ``attempt`` seconds
        between them.

        Args:
            msg: Message to send.

        Returns:
            The send result; failures are reported through the result, not
            raised.
        """
        if not self.validate_email(msg.to):
            logger.warning(f"无效的邮箱地址: {msg.to}")
            return EmailSendResult(
                success=False,
                message_id=None,
                error=f"无效的邮箱地址: {msg.to}",
                retry_count=0,
            )

        if self.simulate_mode:
            message_id = f"sim_{uuid.uuid4().hex[:8]}"
            logger.info(f"[模拟模式] 邮件发送成功: {msg.to}, ID: {message_id}")
            return EmailSendResult(
                success=True,
                message_id=message_id,
                error=None,
                retry_count=0,
            )

        retry_count = 0
        last_error = None

        while retry_count <= self.max_retries:
            try:
                logger.info(f"尝试发送邮件到 {msg.to} (尝试 {retry_count + 1}/{self.max_retries + 1})")
                self._deliver(msg)
            except Exception as e:
                # Boundary handler: every failure is logged and counted.
                if isinstance(e, smtplib.SMTPException):
                    last_error = f"SMTP错误: {str(e)}"
                    logger.error(f"SMTP错误 (尝试 {retry_count + 1}/{self.max_retries + 1}): {e}")
                else:
                    last_error = str(e)
                    logger.error(f"邮件发送失败 (尝试 {retry_count + 1}/{self.max_retries + 1}): {e}")
                retry_count += 1
                if retry_count <= self.max_retries:
                    # Linear back-off: 1s, 2s, 3s, ...
                    time.sleep(1 * retry_count)
            else:
                message_id = f"smtp_{uuid.uuid4().hex[:8]}"
                logger.info(f"邮件发送成功: {msg.to}, ID: {message_id}")
                return EmailSendResult(
                    success=True,
                    message_id=message_id,
                    error=None,
                    retry_count=retry_count,
                )

        logger.error(f"邮件发送最终失败,已重试 {retry_count} 次: {msg.to}, 错误: {last_error}")
        return EmailSendResult(
            success=False,
            message_id=None,
            error=last_error,
            retry_count=retry_count,
        )

    def add_to_queue(self, msg: EmailMessage) -> None:
        """Append ``msg`` to the send queue.

        Args:
            msg: Message to enqueue.
        """
        self._queue.append(msg)
        logger.debug(f"邮件已添加到队列: {msg.to}, 队列长度: {len(self._queue)}")

    def get_queue(self) -> list[EmailMessage]:
        """Return a shallow copy of the queued messages."""
        return self._queue.copy()

    def clear_queue(self) -> None:
        """Remove every message from the queue."""
        count = len(self._queue)
        self._queue.clear()
        logger.info(f"队列已清空,移除了 {count} 封邮件")

    def send_queue(self) -> list[EmailSendResult]:
        """Send every queued message and empty the queue.

        The queue is emptied before sending, so messages that fail are not
        re-queued; inspect the returned results to handle them.

        Returns:
            One result per queued message, in queue order.
        """
        messages = self._queue.copy()
        self._queue.clear()

        logger.info(f"开始发送队列中的 {len(messages)} 封邮件")

        results = [self.send_email(msg) for msg in messages]

        success_count = sum(1 for r in results if r.success)
        logger.info(f"队列发送完成: 成功 {success_count}/{len(results)}")

        return results
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,186 @@
|
||||||
|
"""详细健康检查服务"""
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class HealthCheckResult:
    """Result of a single health check."""

    # Identifier of the check, e.g. "database" or "redis".
    name: str
    # Whether the check passed.
    healthy: bool
    # Elapsed time of the check in milliseconds; None when not measured.
    latency_ms: Optional[float] = None
    # Human-readable summary or error text.
    message: Optional[str] = None
    # Extra structured data specific to the check.
    details: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
class HealthChecker:
    """Runs health checks against the database, Redis, LLM providers and storage."""

    def __init__(
        self,
        db: AsyncSession,
        redis_url: str,
        storage_path: str = "/data/documents",
    ):
        """Store the handles the checks operate on.

        Args:
            db: Async SQLAlchemy session used for the database probe.
            redis_url: URL passed to ``redis.asyncio.from_url``.
            storage_path: Directory probed by ``check_storage``.
        """
        self.db = db
        self.redis_url = redis_url
        self.storage_path = storage_path

    async def check_database(self) -> HealthCheckResult:
        """Probe the database with ``SELECT 1`` and time the round trip."""
        start = time.perf_counter()
        try:
            await self.db.execute(text("SELECT 1"))
            latency = (time.perf_counter() - start) * 1000

            return HealthCheckResult(
                name="database",
                healthy=True,
                latency_ms=round(latency, 2),
                message="Connection OK",
            )
        except Exception as e:
            latency = (time.perf_counter() - start) * 1000
            return HealthCheckResult(
                name="database",
                healthy=False,
                latency_ms=round(latency, 2),
                message=f"Connection failed: {str(e)}",
            )

    async def check_redis(self) -> HealthCheckResult:
        """Probe Redis with ``PING`` and time the round trip."""
        start = time.perf_counter()
        try:
            client = aioredis.from_url(
                self.redis_url,
                socket_connect_timeout=2,
            )
            try:
                await client.ping()
            finally:
                # Release the client even when the ping fails; otherwise
                # every failed check leaves a client behind.
                await client.aclose()

            latency = (time.perf_counter() - start) * 1000
            return HealthCheckResult(
                name="redis",
                healthy=True,
                latency_ms=round(latency, 2),
                message="Connection OK",
            )
        except Exception as e:
            latency = (time.perf_counter() - start) * 1000
            return HealthCheckResult(
                name="redis",
                healthy=False,
                latency_ms=round(latency, 2),
                message=f"Connection failed: {str(e)}",
            )

    async def check_llm_providers(self) -> HealthCheckResult:
        """Check that every LLM provider can be constructed.

        The default provider is probed first, then every provider reported by
        ``LLMFactory.list_providers()`` that has not been probed yet. A
        provider counts as healthy when ``LLMFactory.create`` does not raise;
        this method itself sends no request to the provider.
        """
        from app.config import settings
        from app.services.llm.factory import LLMFactory

        def _probe(name: str) -> dict:
            # One provider entry: healthy when the factory can build it.
            try:
                LLMFactory.create(name)
            except Exception as e:
                return {
                    "healthy": False,
                    "error": str(e),
                }
            return {
                "healthy": True,
                "available": True,
            }

        default_name = getattr(settings, 'DEFAULT_LLM_PROVIDER', 'openai')
        providers = {default_name: _probe(default_name)}

        for name in LLMFactory.list_providers():
            if name not in providers:
                providers[name] = _probe(name)

        all_healthy = all(entry["healthy"] for entry in providers.values())

        return HealthCheckResult(
            name="llm_providers",
            healthy=all_healthy,
            message="All providers healthy" if all_healthy else "Some providers unhealthy",
            details={"providers": providers},
        )

    async def check_storage(self) -> HealthCheckResult:
        """Check that the storage directory is writable.

        When the directory exists, a marker file is written and removed. When
        it does not exist, the check still reports healthy.
        """
        import os

        storage_path = self.storage_path

        try:
            if os.path.exists(storage_path):
                # Verify write access with a throw-away marker file.
                test_file = os.path.join(storage_path, ".health_check")
                with open(test_file, "w") as f:
                    f.write("ok")
                os.remove(test_file)

                return HealthCheckResult(
                    name="storage",
                    healthy=True,
                    message=f"Storage path {storage_path} is writable",
                    details={"path": storage_path},
                )
            else:
                # NOTE(review): nothing is created here although the details
                # say "created": True - confirm what consumers expect.
                return HealthCheckResult(
                    name="storage",
                    healthy=True,
                    message=f"Storage path {storage_path} does not exist (will be created)",
                    details={"path": storage_path, "created": True},
                )
        except Exception as e:
            return HealthCheckResult(
                name="storage",
                healthy=False,
                message=f"Storage check failed: {str(e)}",
            )

    async def check_all(self) -> dict:
        """Run every check concurrently and summarise the results.

        Returns:
            A dict with ``status`` ("healthy" when all checks pass, otherwise
            "degraded"), ``timestamp`` (``time.time()``) and ``checks`` (one
            entry per check, keyed by check name).
        """
        import asyncio

        results = await asyncio.gather(
            self.check_database(),
            self.check_redis(),
            self.check_llm_providers(),
            self.check_storage(),
        )

        all_healthy = all(r.healthy for r in results)

        return {
            "status": "healthy" if all_healthy else "degraded",
            "timestamp": time.time(),
            "checks": {
                r.name: {
                    "healthy": r.healthy,
                    "latency_ms": r.latency_ms,
                    "message": r.message,
                    "details": r.details,
                }
                for r in results
            },
        }
|
||||||
|
|
@ -0,0 +1,312 @@
|
||||||
|
"""阿里云百炼图片生成服务"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
# Image dimensions per target platform.
# Every platform defines a "cover" spec; only some also define "inline".
PLATFORM_IMAGE_SPECS = {
    "zhihu": {
        "cover": {"width": 690, "height": 280, "ratio": "2.5:1"},
        "inline": {"width": 500, "height": 375, "ratio": "4:3"},
    },
    "wechat": {
        "cover": {"width": 900, "height": 383, "ratio": "2.35:1"},
        "inline": {"width": 800, "height": 600, "ratio": "4:3"},
    },
    "xiaohongshu": {
        "cover": {"width": 1080, "height": 1080, "ratio": "1:1"},  # square
        "inline": {"width": 1242, "height": 1660, "ratio": "3:4"},  # portrait
    },
    "toutiao": {
        "cover": {"width": 1024, "height": 678, "ratio": "1.5:1"},
    },
    "baijiahao": {
        "cover": {"width": 600, "height": 400, "ratio": "3:2"},
    },
    "weibo": {
        "cover": {"width": 980, "height": 560, "ratio": "1.75:1"},
    },
    "bilibili": {
        "cover": {"width": 1920, "height": 1080, "ratio": "16:9"},
    },
    "jianshu": {
        "cover": {"width": 800, "height": 600, "ratio": "4:3"},
    },
    "juejin": {
        "cover": {"width": 1024, "height": 768, "ratio": "4:3"},
    },
    "douyin": {
        "cover": {"width": 1080, "height": 1920, "ratio": "9:16"},  # portrait short-video cover
    },
}
|
||||||
|
|
||||||
|
# Visual style options: display name plus the English prompt fragment that is
# inserted into the generation prompt.
IMAGE_STYLES = {
    "modern": {"name": "现代简约", "prompt": "modern minimalist style, clean design, professional"},
    "tech": {"name": "科技感", "prompt": "tech style, futuristic, digital, blue tones"},
    "elegant": {"name": "优雅商务", "prompt": "elegant business style, sophisticated, premium"},
    "creative": {"name": "创意活力", "prompt": "creative vibrant style, colorful, dynamic"},
    "minimal": {"name": "极简主义", "prompt": "ultra minimal, white space, typography focus"},
}
|
||||||
|
|
||||||
|
# Layout options: display name plus the English prompt fragment that is
# inserted into the generation prompt.
LAYOUT_OPTIONS = {
    "centered": {"name": "居中排版", "prompt": "centered composition, text in middle"},
    "left_text": {"name": "左文右图", "prompt": "left side text, right side visual"},
    "top_text": {"name": "上文下图", "prompt": "text on top, visual below"},
    "text_overlay": {"name": "文字叠加", "prompt": "text overlay on background image"},
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ImageResult:
    """Result of an image generation request."""

    # URL of the generated image.
    url: str
    # Width taken from the platform size spec used for the request.
    width: int
    # Height taken from the platform size spec used for the request.
    height: int
    # Prompt that was submitted.
    prompt: str
    # Target platform the image was requested for.
    platform: str
    # Id of the generation task.
    task_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGenerationError(Exception):
    """Raised when image generation fails."""
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGenerator:
|
||||||
|
"""阿里云百炼图片生成服务(万相-文生图V1)"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.api_key = os.getenv("ALIYUN_DASHSCOPE_API_KEY")
|
||||||
|
self.base_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text2image/image-synthesis"
|
||||||
|
self.timeout = 120.0 # 异步任务等待超时时间(秒)
|
||||||
|
|
||||||
|
async def generate_cover(
    self,
    title: str,
    platform: str,
    image_type: str = "cover",
    style: str = "modern",
    layout: str = "centered",
    custom_prompt: str = None,
) -> ImageResult:
    """Generate an image for an article.

    Args:
        title: Article title.
        platform: Target platform; unknown platforms use the "zhihu" specs.
        image_type: "cover" or "inline"; falls back to the platform's cover
            spec when the platform has no spec of that type.
        style: Key into ``IMAGE_STYLES``.
        layout: Key into ``LAYOUT_OPTIONS``.
        custom_prompt: When given (and non-empty), used verbatim instead of
            the prompt built from title, style and layout.

    Returns:
        ImageResult describing the generated image.
    """
    # Resolve the target dimensions.
    platform_specs = PLATFORM_IMAGE_SPECS.get(platform, PLATFORM_IMAGE_SPECS["zhihu"])
    dimensions = platform_specs.get(image_type, platform_specs["cover"])

    # A caller-supplied prompt takes precedence over the generated one.
    final_prompt = custom_prompt or self._build_prompt(title, platform, style, layout)

    # Submit the task, then wait for its result.
    task_id = await self._create_task(final_prompt, dimensions)
    outcome = await self._wait_for_result(task_id)

    return ImageResult(
        url=outcome["image_url"],
        width=dimensions["width"],
        height=dimensions["height"],
        prompt=final_prompt,
        platform=platform,
        task_id=task_id,
    )
|
||||||
|
|
||||||
|
def _build_prompt(self, title: str, platform: str, style: str, layout: str) -> str:
|
||||||
|
"""构建AI提示词
|
||||||
|
|
||||||
|
Args:
|
||||||
|
title: 文章标题
|
||||||
|
platform: 目标平台
|
||||||
|
style: 风格选项
|
||||||
|
layout: 排版选项
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 构造的英文提示词
|
||||||
|
"""
|
||||||
|
style_prompt = IMAGE_STYLES.get(style, IMAGE_STYLES["modern"])["prompt"]
|
||||||
|
layout_prompt = LAYOUT_OPTIONS.get(layout, LAYOUT_OPTIONS["centered"])["prompt"]
|
||||||
|
|
||||||
|
# 平台特定要求
|
||||||
|
platform_notes = {
|
||||||
|
"xiaohongshu": "warm tones, lifestyle, lifestyle photography",
|
||||||
|
"wechat": "professional, clean, suitable for WeChat article cover",
|
||||||
|
"zhihu": "intellectual, professional, suitable for long-form content",
|
||||||
|
"toutiao": "eye-catching, clear hierarchy, news style",
|
||||||
|
"baijiahao": "clear, professional, suitable for news media",
|
||||||
|
"weibo": "social media friendly, engaging",
|
||||||
|
"bilibili": "anime style friendly, vibrant, suitable for video platform",
|
||||||
|
"jianshu": "literary, elegant, clean layout",
|
||||||
|
"juejin": "tech blog style, developer friendly",
|
||||||
|
"douyin": "vertical video style, eye-catching, short form content",
|
||||||
|
}
|
||||||
|
platform_note = platform_notes.get(platform, "")
|
||||||
|
|
||||||
|
return f"{title}, {style_prompt}, {layout_prompt}, {platform_note}, high quality, 4K"
|
||||||
|
|
||||||
|
async def _create_task(self, prompt: str, size_spec: dict) -> str:
|
||||||
|
"""创建异步任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: 英文提示词
|
||||||
|
size_spec: 尺寸规格
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 任务ID
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImageGenerationError: API调用失败时
|
||||||
|
"""
|
||||||
|
if not self.api_key:
|
||||||
|
raise ImageGenerationError("ALIYUN_DASHSCOPE_API_KEY 未设置")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": "wanx-v1", # 万相-文生图V1
|
||||||
|
"input": {
|
||||||
|
"prompt": prompt,
|
||||||
|
},
|
||||||
|
"parameters": {
|
||||||
|
"size": f"{size_spec['width']}x{size_spec['height']}",
|
||||||
|
"n": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
try:
|
||||||
|
response = await client.post(
|
||||||
|
self.base_url,
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if data.get("code"):
|
||||||
|
raise ImageGenerationError(f"API错误: {data.get('message', data.get('code'))}")
|
||||||
|
|
||||||
|
# 返回任务ID
|
||||||
|
task_id = data.get("output", {}).get("task_id")
|
||||||
|
if not task_id:
|
||||||
|
raise ImageGenerationError("未获取到任务ID")
|
||||||
|
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise ImageGenerationError(f"HTTP错误: {e.response.status_code}")
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
raise ImageGenerationError(f"请求错误: {e}")
|
||||||
|
|
||||||
|
async def _wait_for_result(self, task_id: str) -> dict:
|
||||||
|
"""轮询等待结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含 image_url 的结果
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImageGenerationError: 任务失败或超时
|
||||||
|
"""
|
||||||
|
status_url = f"{self.base_url}/task/{task_id}"
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||||
|
import asyncio
|
||||||
|
start_time = asyncio.get_event_loop().time()
|
||||||
|
poll_interval = 2.0 # 轮询间隔(秒)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
response = await client.get(status_url, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
status = data.get("output", {}).get("task_status")
|
||||||
|
|
||||||
|
if status == "succeeded":
|
||||||
|
# 任务成功
|
||||||
|
images = data.get("output", {}).get("results", [])
|
||||||
|
if images:
|
||||||
|
return {"image_url": images[0].get("url")}
|
||||||
|
raise ImageGenerationError("未获取到生成图片URL")
|
||||||
|
|
||||||
|
elif status == "failed":
|
||||||
|
error_msg = data.get("output", {}).get("message", "任务失败")
|
||||||
|
raise ImageGenerationError(f"图片生成失败: {error_msg}")
|
||||||
|
|
||||||
|
elif status == "pending":
|
||||||
|
# 等待中,继续轮询
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 未知状态,继续等待
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 检查超时
|
||||||
|
elapsed = asyncio.get_event_loop().time() - start_time
|
||||||
|
if elapsed > self.timeout:
|
||||||
|
raise ImageGenerationError("图片生成超时")
|
||||||
|
|
||||||
|
# 等待后继续轮询
|
||||||
|
await asyncio.sleep(poll_interval)
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise ImageGenerationError(f"HTTP错误: {e.response.status_code}")
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
raise ImageGenerationError(f"请求错误: {e}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_platform_specs(platform: str) -> Optional[dict]:
|
||||||
|
"""获取平台的图片规格
|
||||||
|
|
||||||
|
Args:
|
||||||
|
platform: 平台标识
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[dict]: 平台图片规格,如果没有则返回None
|
||||||
|
"""
|
||||||
|
return PLATFORM_IMAGE_SPECS.get(platform)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_supported_platforms() -> list[str]:
|
||||||
|
"""获取支持的平台列表"""
|
||||||
|
return list(PLATFORM_IMAGE_SPECS.keys())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_styles() -> dict:
|
||||||
|
"""获取所有风格选项"""
|
||||||
|
return IMAGE_STYLES
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_layouts() -> dict:
|
||||||
|
"""获取所有排版选项"""
|
||||||
|
return LAYOUT_OPTIONS
|
||||||
|
|
@ -2,5 +2,21 @@ from .rag_service import RAGService
|
||||||
from .chunker import RecursiveChunker
|
from .chunker import RecursiveChunker
|
||||||
from .embedder import EmbeddingService, OpenAIEmbedder, MockEmbedder
|
from .embedder import EmbeddingService, OpenAIEmbedder, MockEmbedder
|
||||||
from .retriever import HybridRetriever
|
from .retriever import HybridRetriever
|
||||||
|
from .entity_extractor import EntityExtractor, ExtractionResult, ExtractedEntity, ExtractedRelation
|
||||||
|
from .graph_builder import GraphBuilder
|
||||||
|
from .graph_query import GraphQuery
|
||||||
|
|
||||||
__all__ = ["RAGService", "RecursiveChunker", "EmbeddingService", "OpenAIEmbedder", "MockEmbedder", "HybridRetriever"]
|
__all__ = [
|
||||||
|
"RAGService",
|
||||||
|
"RecursiveChunker",
|
||||||
|
"EmbeddingService",
|
||||||
|
"OpenAIEmbedder",
|
||||||
|
"MockEmbedder",
|
||||||
|
"HybridRetriever",
|
||||||
|
"EntityExtractor",
|
||||||
|
"ExtractionResult",
|
||||||
|
"ExtractedEntity",
|
||||||
|
"ExtractedRelation",
|
||||||
|
"GraphBuilder",
|
||||||
|
"GraphQuery",
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,178 +1,212 @@
|
||||||
"""
|
"""
|
||||||
RecursiveChunker: 递归语义分块器
|
分块策略 - 支持多种分块方式
|
||||||
按优先级分隔符(段落→句子→词)将文档切割为适合embedding的块。
|
|
||||||
"""
|
"""
|
||||||
import re
|
import re
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
@dataclass
class ChunkStrategy:
    """Configuration record describing one chunking strategy."""

    name: str  # strategy identifier
    description: str  # human-readable summary of the strategy
    chunk_size: int  # chunk length, counted in characters
    chunk_overlap: int  # number of characters shared between adjacent chunks
    min_chunk_size: int  # lower size bound; exact use is up to each chunker
|
||||||
|
|
||||||
class RecursiveChunker:
|
class BaseChunker(ABC):
|
||||||
"""递归语义分块器"""
|
"""分块器基类"""
|
||||||
|
|
||||||
def __init__(
|
STRATEGY: ChunkStrategy = None
|
||||||
self,
|
|
||||||
chunk_size: int = 512,
|
@abstractmethod
|
||||||
chunk_overlap: int = 50,
|
|
||||||
min_chunk_size: int = 100,
|
|
||||||
):
|
|
||||||
self.chunk_size = chunk_size
|
|
||||||
self.chunk_overlap = chunk_overlap
|
|
||||||
self.min_chunk_size = min_chunk_size
|
|
||||||
# 分隔符优先级:段落 > 句子 > 词
|
|
||||||
self.separators = ["\n\n", "\n", "。", ".", "!", "!", "?", "?", ";", ";", " "]
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Public API
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
|
def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
|
||||||
"""
|
"""执行分块"""
|
||||||
将文本递归分块。
|
pass
|
||||||
|
|
||||||
|
def preview(self, text: str, max_chunks: int = 5) -> list[str]:
    """Return display-sized previews of the first chunks of *text*.

    Runs ``self.chunk(text)`` and keeps at most ``max_chunks`` results. Content
    longer than 200 characters is cut to 200 characters and suffixed with
    "..."; shorter content is returned unchanged.
    """
    limit = 200
    previews = []
    for piece in self.chunk(text)[:max_chunks]:
        body = piece["content"]
        if len(body) > limit:
            body = body[:limit] + "..."
        previews.append(body)
    return previews
|
||||||
|
|
||||||
Returns:
|
class RecursiveChunker(BaseChunker):
|
||||||
list of dicts:
|
"""递归分块器(现有实现)"""
|
||||||
{
|
|
||||||
"content": str,
|
STRATEGY = ChunkStrategy(
|
||||||
"chunk_index": int,
|
name="recursive",
|
||||||
"token_count": int,
|
description="优先按段落分割,过长时按句子分割",
|
||||||
"metadata": dict,
|
chunk_size=500,
|
||||||
}
|
chunk_overlap=50,
|
||||||
"""
|
min_chunk_size=50,
|
||||||
if not text or not text.strip():
|
)
|
||||||
return []
|
|
||||||
|
# 分割模式(按优先级)
|
||||||
raw_chunks = self._split_recursive(text.strip(), self.separators)
|
SEPARATORS = [
|
||||||
|
r"\n\n+", # 双换行(段落)
|
||||||
# 合并过短的块 & 添加重叠
|
r"\n", # 单换行
|
||||||
merged = self._merge_small_chunks(raw_chunks)
|
r"[。!?!?]\s*", # 句子结束
|
||||||
result = []
|
r"[,,;;]\s*", # 分句
|
||||||
for idx, content in enumerate(merged):
|
r"\s+", # 空格
|
||||||
result.append(
|
]
|
||||||
{
|
|
||||||
"content": content,
|
def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
|
||||||
"chunk_index": idx,
|
|
||||||
"token_count": self._estimate_tokens(content),
|
|
||||||
"metadata": metadata or {},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Internal helpers
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def _split_recursive(self, text: str, separators: list[str]) -> list[str]:
|
|
||||||
"""
|
|
||||||
递归分割:尝试当前分隔符,若块太大则用下一级分隔符继续。
|
|
||||||
"""
|
|
||||||
if not separators:
|
|
||||||
# 最后手段:按字符强制截断
|
|
||||||
return self._hard_split(text)
|
|
||||||
|
|
||||||
sep = separators[0]
|
|
||||||
remaining_seps = separators[1:]
|
|
||||||
|
|
||||||
# 先用当前分隔符分割
|
|
||||||
splits = text.split(sep)
|
|
||||||
# 去掉空串,但保留分隔符语义(拼回去)
|
|
||||||
splits = [s for s in splits if s.strip()]
|
|
||||||
|
|
||||||
if len(splits) <= 1:
|
|
||||||
# 该分隔符无法分割,尝试下一级
|
|
||||||
return self._split_recursive(text, remaining_seps)
|
|
||||||
|
|
||||||
chunks: list[str] = []
|
|
||||||
current_buffer = ""
|
|
||||||
|
|
||||||
for piece in splits:
|
|
||||||
candidate = (current_buffer + sep + piece).strip() if current_buffer else piece.strip()
|
|
||||||
if self._estimate_tokens(candidate) <= self.chunk_size:
|
|
||||||
current_buffer = candidate
|
|
||||||
else:
|
|
||||||
# 当前 buffer 达到 chunk_size
|
|
||||||
if current_buffer:
|
|
||||||
# buffer 本身是否太大?若是,递归细分
|
|
||||||
if self._estimate_tokens(current_buffer) > self.chunk_size:
|
|
||||||
chunks.extend(self._split_recursive(current_buffer, remaining_seps))
|
|
||||||
else:
|
|
||||||
chunks.append(current_buffer)
|
|
||||||
# piece 单独处理
|
|
||||||
if self._estimate_tokens(piece) > self.chunk_size:
|
|
||||||
chunks.extend(self._split_recursive(piece, remaining_seps))
|
|
||||||
current_buffer = ""
|
|
||||||
else:
|
|
||||||
current_buffer = piece.strip()
|
|
||||||
|
|
||||||
if current_buffer:
|
|
||||||
if self._estimate_tokens(current_buffer) > self.chunk_size:
|
|
||||||
chunks.extend(self._split_recursive(current_buffer, remaining_seps))
|
|
||||||
else:
|
|
||||||
chunks.append(current_buffer)
|
|
||||||
|
|
||||||
return [c for c in chunks if c.strip()]
|
|
||||||
|
|
||||||
def _merge_small_chunks(self, chunks: list[str]) -> list[str]:
|
|
||||||
"""
|
|
||||||
合并过短的块(< min_chunk_size token),并在相邻块间加入重叠文本。
|
|
||||||
"""
|
|
||||||
if not chunks:
|
|
||||||
return []
|
|
||||||
|
|
||||||
merged: list[str] = []
|
|
||||||
buffer = chunks[0]
|
|
||||||
|
|
||||||
for chunk in chunks[1:]:
|
|
||||||
if self._estimate_tokens(buffer) < self.min_chunk_size:
|
|
||||||
buffer = buffer + "\n" + chunk
|
|
||||||
else:
|
|
||||||
merged.append(buffer)
|
|
||||||
# 添加重叠:取上一块末尾若干 token 作为前缀
|
|
||||||
overlap_text = self._get_overlap_prefix(buffer)
|
|
||||||
buffer = (overlap_text + "\n" + chunk).strip() if overlap_text else chunk
|
|
||||||
|
|
||||||
merged.append(buffer)
|
|
||||||
return [c for c in merged if c.strip()]
|
|
||||||
|
|
||||||
def _get_overlap_prefix(self, text: str) -> str:
|
|
||||||
"""截取文本末尾作为下一块的重叠前缀(按 token 估算)。"""
|
|
||||||
if self.chunk_overlap <= 0:
|
|
||||||
return ""
|
|
||||||
# 简单实现:按字符比例截取
|
|
||||||
words = text.split()
|
|
||||||
if not words:
|
|
||||||
return ""
|
|
||||||
# 估算每个词约 1.5 token(中英混合)
|
|
||||||
token_per_word = 1.5
|
|
||||||
overlap_words = int(self.chunk_overlap / token_per_word)
|
|
||||||
overlap_words = max(1, min(overlap_words, len(words)))
|
|
||||||
return " ".join(words[-overlap_words:])
|
|
||||||
|
|
||||||
def _hard_split(self, text: str) -> list[str]:
|
|
||||||
"""按字符强制截断(最后手段)。"""
|
|
||||||
# 粗略:1 token ≈ 2 字符(中文)
|
|
||||||
char_limit = self.chunk_size * 2
|
|
||||||
chunks = []
|
chunks = []
|
||||||
start = 0
|
metadata = metadata or {}
|
||||||
while start < len(text):
|
|
||||||
end = start + char_limit
|
# 按段落分割
|
||||||
chunks.append(text[start:end])
|
segments = re.split(r"\n\n+", text)
|
||||||
start = end - self.chunk_overlap * 2 # 加入字符级重叠
|
|
||||||
if start <= 0:
|
current_chunk = ""
|
||||||
start = end
|
for segment in segments:
|
||||||
|
if len(current_chunk) + len(segment) <= self.STRATEGY.chunk_size:
|
||||||
|
current_chunk += segment + "\n\n"
|
||||||
|
else:
|
||||||
|
# 当前块足够大,保存
|
||||||
|
if len(current_chunk.strip()) >= self.STRATEGY.min_chunk_size:
|
||||||
|
chunks.append({
|
||||||
|
"content": current_chunk.strip(),
|
||||||
|
"chunk_index": len(chunks),
|
||||||
|
"metadata": metadata,
|
||||||
|
})
|
||||||
|
|
||||||
|
# 处理过长段落
|
||||||
|
if len(segment) > self.STRATEGY.chunk_size:
|
||||||
|
current_chunk = segment
|
||||||
|
else:
|
||||||
|
# 保留重叠
|
||||||
|
overlap = current_chunk[-self.STRATEGY.chunk_overlap:]
|
||||||
|
current_chunk = overlap + segment + "\n\n"
|
||||||
|
|
||||||
|
# 处理最后一个块
|
||||||
|
if len(current_chunk.strip()) >= self.STRATEGY.min_chunk_size:
|
||||||
|
chunks.append({
|
||||||
|
"content": current_chunk.strip(),
|
||||||
|
"chunk_index": len(chunks),
|
||||||
|
"metadata": metadata,
|
||||||
|
})
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
def _estimate_tokens(self, text: str) -> int:
|
class SemanticChunker(BaseChunker):
|
||||||
"""
|
"""语义分块器 - 按语义边界分割"""
|
||||||
估算 token 数。
|
|
||||||
规则:中文字符每字计 1 token,英文单词计 1.3 token(BPE 碎片系数)。
|
STRATEGY = ChunkStrategy(
|
||||||
"""
|
name="semantic",
|
||||||
if not text:
|
description="根据语义边界(标题、段落)自动分块",
|
||||||
return 0
|
chunk_size=800,
|
||||||
|
chunk_overlap=100,
|
||||||
|
min_chunk_size=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 语义边界模式
|
||||||
|
SEMANTIC_PATTERNS = [
|
||||||
|
(r"^#{1,6}\s+(.+)$", "heading"), # Markdown标题
|
||||||
|
(r"^【(.+?)】\s*$", "heading"), # 中文标题
|
||||||
|
(r"^第[一二三四五六七八九十百]+[章节条]", "heading"), # 章节标题
|
||||||
|
(r"^(\d+\.)+\s+", "heading"), # 数字编号
|
||||||
|
]
|
||||||
|
|
||||||
|
def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
    """Split *text* into chunks at heading lines and at the size limit.

    A line matching any pattern in ``SEMANTIC_PATTERNS`` starts a new chunk.
    A chunk is also closed as soon as the buffered text reaches
    ``STRATEGY.chunk_size`` characters. The last ``STRATEGY.chunk_overlap``
    characters of a closed chunk are carried into the next one.

    Args:
        text: Text to split.
        metadata: Optional dict attached (by reference) to every chunk.

    Returns:
        List of dicts with keys "content", "chunk_index", "section"
        (the heading the chunk's text falls under, or None before the first
        heading) and "metadata".
    """
    metadata = metadata or {}
    size_limit = self.STRATEGY.chunk_size
    overlap_len = self.STRATEGY.chunk_overlap

    chunks: list[dict] = []

    def emit(content: str, heading: Optional[str]) -> str:
        """Store *content* as a chunk and return the overlap to carry over."""
        chunks.append({
            "content": content.strip(),
            "chunk_index": len(chunks),
            "section": heading,
            "metadata": metadata,
        })
        # Guard the slice: content[-0:] would return the whole string.
        return content[-overlap_len:] if overlap_len > 0 else ""

    buffer = ""
    section = None  # heading the buffered text belongs to
    # False while the buffer holds nothing but overlap carried over from the
    # previous chunk; such a buffer must never be emitted on its own.
    has_new_text = False

    for line in text.split("\n"):
        stripped = line.strip()
        is_heading = any(
            re.match(pattern, stripped) for pattern, _ in self.SEMANTIC_PATTERNS
        )

        if is_heading:
            if has_new_text:
                # Close the previous chunk under ITS heading before switching.
                buffer = emit(buffer, section)
                has_new_text = False
            section = stripped

        buffer += line + "\n"
        if stripped:
            has_new_text = True

        if has_new_text and len(buffer) >= size_limit:
            buffer = emit(buffer, section)
            has_new_text = False

    if has_new_text:
        emit(buffer, section)

    return chunks
|
||||||
|
|
||||||
# 中文字符计数
|
class FixedLengthChunker(BaseChunker):
|
||||||
chinese_chars = len(re.findall(r"[\u4e00-\u9fff\u3400-\u4dbf]", text))
|
"""固定长度分块器"""
|
||||||
# 去掉中文后,计算英文单词数
|
|
||||||
non_chinese = re.sub(r"[\u4e00-\u9fff\u3400-\u4dbf]", " ", text)
|
STRATEGY = ChunkStrategy(
|
||||||
english_words = len(non_chinese.split())
|
name="fixed",
|
||||||
|
description="按固定长度强制分块",
|
||||||
|
chunk_size=300,
|
||||||
|
chunk_overlap=30,
|
||||||
|
min_chunk_size=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
    """Cut *text* into fixed-size, overlapping character windows.

    Whitespace runs are first collapsed to a single space. Windows are
    ``STRATEGY.chunk_size`` characters long and start every
    ``chunk_size - chunk_overlap`` characters. A window is kept only if its
    stripped length is at least ``STRATEGY.min_chunk_size``.

    Returns:
        List of dicts with keys "content", "chunk_index" and "metadata"
        (the same metadata dict is attached to every chunk).
    """
    shared_meta = metadata or {}
    cfg = self.STRATEGY
    stride = cfg.chunk_size - cfg.chunk_overlap

    flattened = re.sub(r"\s+", " ", text)
    windows = (
        flattened[start:start + cfg.chunk_size].strip()
        for start in range(0, len(flattened), stride)
    )
    kept = [window for window in windows if len(window) >= cfg.min_chunk_size]

    return [
        {"content": window, "chunk_index": position, "metadata": shared_meta}
        for position, window in enumerate(kept)
    ]
|
||||||
|
|
||||||
return int(chinese_chars + english_words * 1.3)
|
class ChunkerFactory:
    """Registry that maps strategy names to chunker classes."""

    STRATEGIES = {
        "recursive": RecursiveChunker,
        "semantic": SemanticChunker,
        "fixed": FixedLengthChunker,
    }

    @classmethod
    def create(cls, strategy: str = "recursive") -> BaseChunker:
        """Instantiate the chunker registered under *strategy*.

        Names that are not registered yield a ``RecursiveChunker``.
        """
        if strategy in cls.STRATEGIES:
            return cls.STRATEGIES[strategy]()
        return RecursiveChunker()

    @classmethod
    def list_strategies(cls) -> list[ChunkStrategy]:
        """Return the ``STRATEGY`` config of every registered chunker class."""
        configs = []
        for chunker_cls in cls.STRATEGIES.values():
            configs.append(chunker_cls.STRATEGY)
        return configs
|
||||||
|
|
@ -0,0 +1,149 @@
|
||||||
|
"""
|
||||||
|
增强版RAG服务 - 包含重排序和上下文压缩
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.services.llm.factory import LLMFactory
|
||||||
|
|
||||||
|
|
||||||
|
class EnhancedRAG:
    """Retrieval wrapper adding LLM re-ranking and context compression."""

    # First integer or decimal number in an LLM reply ("1", "0.85", ".7", "1.0").
    _SCORE_PATTERN = re.compile(r"\d+(?:\.\d+)?|\.\d+")

    def __init__(self, rag_service, embedder):
        self.rag = rag_service
        self.embedder = embedder
        self.llm = LLMFactory.create()

    async def retrieve_with_rerank(
        self,
        session,
        query: str,
        kb_ids: list[str],
        top_k: int = 5,
        use_rerank: bool = True,
        use_compression: bool = False,
    ) -> list[dict]:
        """Retrieve chunks for *query*, optionally re-ranked and compressed.

        Steps:
            1. Initial search (4x ``top_k`` candidates when re-ranking).
            2. Optional LLM re-ranking.
            3. Optional context compression of the final ``top_k`` items.

        Returns:
            At most ``top_k`` candidate dicts.
        """
        # Step 1: initial retrieval with a widened candidate set for re-ranking.
        initial_k = top_k * 4 if use_rerank else top_k
        candidates = await self.rag.search(session, query, kb_ids, top_k=initial_k)
        if not candidates:
            return []

        # Step 2: optional re-ranking.
        if use_rerank and len(candidates) > top_k:
            candidates = await self._rerank(query, candidates, top_k)

        # Trim before compressing so no LLM call is spent on a candidate that
        # would be discarded anyway; the first top_k results are unaffected.
        candidates = candidates[:top_k]

        # Step 3: optional compression.
        if use_compression:
            candidates = await self._compress(candidates, query)

        return candidates[:top_k]

    @staticmethod
    def _parse_relevance(response: str, default: float = 0.5) -> float:
        """Extract a relevance score in [0, 1] from an LLM reply.

        Returns *default* when the reply contains no number; numbers outside
        the range are clamped.
        """
        found = EnhancedRAG._SCORE_PATTERN.search(response or "")
        if not found:
            return default
        return min(1.0, max(0.0, float(found.group())))

    async def _rerank(
        self,
        query: str,
        candidates: list[dict],
        top_k: int,
    ) -> list[dict]:
        """Score every candidate with the LLM and sort by that score.

        Each candidate dict gets a "relevance_score" key (mutated in place).
        If scoring a candidate fails, its existing "score" (or 0.5) is used.
        The full sorted list is returned; the caller trims it to ``top_k``.
        """
        reranked = []

        for item in candidates:
            content = item.get("content", "")[:500]  # cap the prompt size

            prompt = f"""评估以下查询与文档片段的相关性。

查询:{query}

文档片段:
{content}

请只返回一个0到1之间的小数,表示相关性分数。0表示完全不相关,1表示完全相关。只返回数字:"""

            try:
                response = await self.llm.generate(prompt)
                relevance = self._parse_relevance(response)
            except Exception:
                # Best effort: a failed LLM call must not break retrieval.
                relevance = item.get("score", 0.5)

            item["relevance_score"] = relevance
            reranked.append(item)

        reranked.sort(key=lambda x: x["relevance_score"], reverse=True)
        return reranked

    async def _compress(
        self,
        candidates: list[dict],
        query: str,
        max_context_tokens: int = 2000,
    ) -> list[dict]:
        """Keep candidates within a token budget, compressing the overflow.

        Candidates are taken in order while the estimated total stays within
        ``max_context_tokens``. The first candidate that would exceed the
        budget is compressed with the LLM ("compressed_content") and included;
        all candidates after it are dropped.
        """
        compressed = []
        current_tokens = 0

        for chunk in candidates:
            content = chunk.get("content", "")

            # Rough estimate: about 2 characters per token.
            est_tokens = len(content) // 2

            if current_tokens + est_tokens <= max_context_tokens:
                chunk["compressed"] = False
                compressed.append(chunk)
                current_tokens += est_tokens
            else:
                chunk["compressed"] = True
                chunk["compressed_content"] = await self._compress_chunk(content, query)
                compressed.append(chunk)
                break

        return compressed

    async def _compress_chunk(
        self,
        content: str,
        query: str,
    ) -> str:
        """Ask the LLM for the parts of *content* relevant to *query*.

        On failure, falls back to the first 500 characters of the original
        text ("..." is appended only when text was actually cut).
        """
        prompt = f"""从以下文本中提取与问题最相关的部分,保持原文的表达方式,不要总结或改写。

问题:{query}

原文:
{content}

直接返回提取的内容(不要解释):"""

        try:
            result = await self.llm.generate(prompt)
            return result.strip()
        except Exception:
            # Best effort: fall back to a plain truncation of the original.
            return content[:500] + "..." if len(content) > 500 else content
|
||||||
|
|
@ -0,0 +1,168 @@
|
||||||
|
"""实体和关系抽取服务"""
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from app.services.llm.factory import LLMFactory
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ExtractedEntity:
    """An entity extracted from text."""

    name: str
    entity_type: str
    description: Optional[str] = None
    properties: dict = field(default_factory=dict)


@dataclass
class ExtractedRelation:
    """A relation between two extracted entities, referenced by name."""

    source_entity: str
    target_entity: str
    relation_type: str
    properties: dict = field(default_factory=dict)


@dataclass
class ExtractionResult:
    """Entities and relations produced by one extraction call."""

    entities: list[ExtractedEntity] = field(default_factory=list)
    relations: list[ExtractedRelation] = field(default_factory=list)


class EntityExtractor:
    """LLM-based entity and relation extraction."""

    # Entity types offered to the LLM in the prompt.
    ENTITY_TYPES = [
        "ORGANIZATION",  # company / organisation
        "PRODUCT",  # product
        "PERSON",  # person
        "LOCATION",  # place
        "TECHNOLOGY",  # technology
        "BRAND",  # brand
        "EVENT",  # event
        "CONCEPT",  # concept
    ]

    # Relation types offered to the LLM in the prompt.
    RELATION_TYPES = [
        "COMPETES_WITH",  # competitor
        "PARTNERS_WITH",  # partner
        "PRODUCES",  # produces
        "USES_TECHNOLOGY",  # uses a technology
        "LOCATED_IN",  # located in
        "FOUNDED_IN",  # founded in
        "CEO_OF",  # CEO of
        "FOUNDER_OF",  # founder of
        "RELATED_TO",  # related to
        "PART_OF",  # part of
    ]

    def __init__(self):
        self.llm = LLMFactory.create()

    async def extract(self, text: str, context: Optional[str] = None) -> ExtractionResult:
        """Extract entities and relations from *text*.

        Args:
            text: Text to process.
            context: Optional extra context added to the prompt
                (e.g. brand name, industry).

        Returns:
            ExtractionResult parsed from the LLM reply (empty when the reply
            contains no parsable JSON object).
        """
        prompt = self._build_extraction_prompt(text, context)
        response = await self.llm.generate(prompt)
        return self._parse_response(response)

    def _build_extraction_prompt(self, text: str, context: Optional[str] = None) -> str:
        """Build the extraction prompt listing the allowed types and *text*."""
        entity_types = "\n".join([f"- {t}" for t in self.ENTITY_TYPES])
        relation_types = "\n".join([f"- {t}" for t in self.RELATION_TYPES])

        context_hint = f"\n\n附加上下文:{context}" if context else ""

        return f"""从以下文本中抽取知识图谱的实体和关系。

要求:
1. 实体必须从文本中明确提及,不能臆造
2. 关系必须有文本依据,不能臆造
3. 每个实体和关系都要有置信度说明(high/medium/low)

实体类型:
{entity_types}

关系类型:
{relation_types}

{context_hint}

文本内容:
{text}

请以JSON格式返回结果:
{{
  "entities": [
    {{
      "name": "实体名称",
      "entity_type": "实体类型",
      "description": "实体描述(可选)",
      "confidence": "high/medium/low"
    }}
  ],
  "relations": [
    {{
      "source_entity": "源实体名称",
      "target_entity": "目标实体名称",
      "relation_type": "关系类型",
      "confidence": "high/medium/low"
    }}
  ]
}}

只返回JSON,不要有其他内容:"""

    def _parse_response(self, response: str) -> ExtractionResult:
        """Parse the LLM reply into an ExtractionResult.

        The span from the first "{" to the last "}" is decoded as JSON.
        LLM output is untrusted: list items that are not objects, or that
        lack a required field, are skipped instead of aborting the whole
        extraction. A missing "confidence" defaults to "medium".
        """
        json_match = re.search(r'\{[\s\S]*\}', response or "")
        if not json_match:
            return ExtractionResult(entities=[], relations=[])

        try:
            data = json.loads(json_match.group())
        except json.JSONDecodeError:
            return ExtractionResult(entities=[], relations=[])
        if not isinstance(data, dict):
            return ExtractionResult(entities=[], relations=[])

        entities = []
        # "or []" also covers an explicit JSON null.
        for raw in data.get("entities") or []:
            if not isinstance(raw, dict):
                continue
            name = raw.get("name")
            entity_type = raw.get("entity_type")
            if not name or not entity_type:
                continue
            entities.append(
                ExtractedEntity(
                    name=name,
                    entity_type=entity_type,
                    description=raw.get("description"),
                    properties={"confidence": raw.get("confidence", "medium")},
                )
            )

        relations = []
        for raw in data.get("relations") or []:
            if not isinstance(raw, dict):
                continue
            source = raw.get("source_entity")
            target = raw.get("target_entity")
            relation_type = raw.get("relation_type")
            if not source or not target or not relation_type:
                continue
            relations.append(
                ExtractedRelation(
                    source_entity=source,
                    target_entity=target,
                    relation_type=relation_type,
                    properties={"confidence": raw.get("confidence", "medium")},
                )
            )

        return ExtractionResult(entities=entities, relations=relations)
|
||||||
|
|
@ -0,0 +1,187 @@
|
||||||
|
"""知识图谱构建服务"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.knowledge_graph import (
|
||||||
|
KnowledgeEntity,
|
||||||
|
KnowledgeRelation,
|
||||||
|
EntityType,
|
||||||
|
RelationType,
|
||||||
|
)
|
||||||
|
from app.models.knowledge import KnowledgeChunk, KnowledgeDocument
|
||||||
|
from app.services.knowledge.entity_extractor import EntityExtractor, ExtractionResult
|
||||||
|
|
||||||
|
|
||||||
|
class GraphBuilder:
|
||||||
|
"""知识图谱构建服务"""
|
||||||
|
|
||||||
|
def __init__(self):
    """Create the builder together with the extractor it delegates to."""
    self.extractor = EntityExtractor()
|
||||||
|
|
||||||
|
async def build_from_chunk(
    self,
    session: AsyncSession,
    chunk_id: str,
    context: Optional[str] = None,
) -> dict:
    """Extract entities/relations from one chunk and store them.

    Args:
        session: Database session.
        chunk_id: Id of the ``KnowledgeChunk`` to process.
        context: Optional context passed to the extractor (e.g. brand name).

    Returns:
        The statistics dict produced by ``_store_extraction``.

    Raises:
        ValueError: If no chunk with ``chunk_id`` exists.
    """
    source_chunk = await session.get(KnowledgeChunk, chunk_id)
    if not source_chunk:
        raise ValueError(f"Chunk not found: {chunk_id}")

    extraction = await self.extractor.extract(source_chunk.content, context)
    return await self._store_extraction(session, chunk_id, extraction)
|
||||||
|
|
||||||
|
async def _store_extraction(
    self,
    session: AsyncSession,
    chunk_id: str,
    result: ExtractionResult,
) -> dict:
    """Persist an extraction result and commit the session.

    Entities are stored first; relations are then stored only when both of
    their endpoint names were among the entities of this same result.

    Returns:
        Counters: "entities_created", "entities_existing",
        "relations_created", "relations_existing".
    """
    stats = dict.fromkeys(
        (
            "entities_created",
            "entities_existing",
            "relations_created",
            "relations_existing",
        ),
        0,
    )

    # Entity name -> stored entity id, used to resolve relation endpoints.
    ids_by_name = {}

    for entity in result.entities:
        record, is_new = await self._get_or_create_entity(session, chunk_id, entity)
        ids_by_name[entity.name] = record.id
        stats["entities_created" if is_new else "entities_existing"] += 1

    for relation in result.relations:
        source_id = ids_by_name.get(relation.source_entity)
        target_id = ids_by_name.get(relation.target_entity)
        if source_id and target_id:
            is_new = await self._create_relation(
                session, chunk_id, source_id, target_id, relation
            )
            stats["relations_created" if is_new else "relations_existing"] += 1
        # Relations with an unresolved endpoint are skipped.

    await session.commit()
    return stats
|
||||||
|
|
||||||
|
async def _get_or_create_entity(
    self,
    session: AsyncSession,
    chunk_id: str,
    extracted_entity,
) -> tuple:
    """Fetch the stored entity matching an extracted one, creating it if absent.

    Entities are matched by name within the knowledge base that owns the
    chunk. A newly created entity is flushed so that its ID is available to
    the caller.

    Args:
        session: Database session.
        chunk_id: ID of the chunk the entity was extracted from.
        extracted_entity: Extracted entity exposing ``name``, ``entity_type``,
            ``description`` and ``properties``.

    Returns:
        ``(entity, created)`` where ``created`` is True only when a new row
        was added.
    """
    kb_id = await self._get_chunk_kb_id(session, chunk_id)

    lookup = select(KnowledgeEntity).where(
        KnowledgeEntity.knowledge_base_id == kb_id,
        KnowledgeEntity.name == extracted_entity.name,
    )
    found = (await session.execute(lookup)).scalar_one_or_none()
    if found:
        return (found, False)

    props = extracted_entity.properties
    new_entity = KnowledgeEntity(
        knowledge_base_id=kb_id,
        name=extracted_entity.name,
        entity_type=EntityType(extracted_entity.entity_type),
        description=extracted_entity.description,
        properties=props or {},
        source_chunk_id=chunk_id,
        # Confidence, when present, is carried inside the properties mapping.
        confidence=props.get("confidence") if props else None,
    )
    session.add(new_entity)
    await session.flush()

    return (new_entity, True)
|
||||||
|
|
||||||
|
async def _create_relation(
    self,
    session: AsyncSession,
    chunk_id: str,
    source_id: str,
    target_id: str,
    extracted_relation,
) -> bool:
    """Add a relation between two entities unless an identical one exists.

    A relation counts as existing when a row with the same source entity,
    target entity and relation type is already stored.

    Args:
        session: Database session.
        chunk_id: ID of the chunk the relation was extracted from.
        source_id: ID of the source entity.
        target_id: ID of the target entity.
        extracted_relation: Extracted relation exposing ``relation_type`` and
            ``properties``.

    Returns:
        True when a new relation was added to the session, False when a
        matching relation was already present.
    """
    rel_type = RelationType(extracted_relation.relation_type)

    duplicate_query = select(KnowledgeRelation).where(
        KnowledgeRelation.source_entity_id == source_id,
        KnowledgeRelation.target_entity_id == target_id,
        KnowledgeRelation.relation_type == rel_type,
    )
    already_there = (await session.execute(duplicate_query)).scalar_one_or_none()
    if already_there:
        return False

    props = extracted_relation.properties
    session.add(
        KnowledgeRelation(
            source_entity_id=source_id,
            target_entity_id=target_id,
            relation_type=rel_type,
            properties=props or {},
            source_chunk_id=chunk_id,
            # Confidence, when present, is carried inside the properties mapping.
            confidence=props.get("confidence") if props else None,
        )
    )
    return True
|
||||||
|
|
||||||
|
async def _get_chunk_kb_id(self, session: AsyncSession, chunk_id: str) -> str:
    """Return the ID of the knowledge base that owns a chunk.

    The knowledge base is resolved through the chunk's parent document.

    Args:
        session: Database session.
        chunk_id: ID of the chunk to resolve.

    Returns:
        The ``knowledge_base_id`` of the chunk's document.

    Raises:
        ValueError: If the chunk does not exist, or if the document it
            references does not exist.
    """
    chunk = await session.get(KnowledgeChunk, chunk_id)
    if not chunk:
        raise ValueError(f"Chunk not found: {chunk_id}")

    # session.get() returns None for a missing row; fail with a clear error
    # instead of an AttributeError on None.
    doc = await session.get(KnowledgeDocument, chunk.document_id)
    if doc is None:
        raise ValueError(
            f"Document not found for chunk {chunk_id}: {chunk.document_id}"
        )

    return doc.knowledge_base_id
|
||||||
|
|
@ -0,0 +1,249 @@
|
||||||
|
"""知识图谱查询服务"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
from app.models.knowledge_graph import (
|
||||||
|
KnowledgeEntity,
|
||||||
|
KnowledgeRelation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphQuery:
    """Read-only query service for the knowledge graph.

    The service holds no state; every method receives the session to run
    its queries against.
    """

    async def get_entity(
        self,
        session: AsyncSession,
        entity_id: str,
    ) -> Optional[dict]:
        """Return the entity with the given ID as a dict, or None if missing."""
        entity = await session.get(KnowledgeEntity, entity_id)
        if not entity:
            return None

        return self._entity_to_dict(entity)

    async def search_entities(
        self,
        session: AsyncSession,
        kb_id: str,
        query: str,
        entity_type: Optional[str] = None,
        limit: int = 20,
    ) -> list[dict]:
        """Search entities of one knowledge base by name substring.

        Args:
            session: Database session.
            kb_id: Knowledge base to search in.
            query: Substring matched case-insensitively against entity names.
            entity_type: Optional entity type filter.
            limit: Maximum number of entities returned.

        Returns:
            Matching entities as dicts.
        """
        stmt = select(KnowledgeEntity).where(
            KnowledgeEntity.knowledge_base_id == kb_id,
            KnowledgeEntity.name.ilike(f"%{query}%"),
        )

        if entity_type:
            stmt = stmt.where(KnowledgeEntity.entity_type == entity_type)

        stmt = stmt.limit(limit)
        result = await session.execute(stmt)

        return [self._entity_to_dict(e) for e in result.scalars()]

    async def get_entity_neighbors(
        self,
        session: AsyncSession,
        entity_id: str,
        max_depth: int = 1,
    ) -> Optional[dict]:
        """Return an entity together with its directly connected entities.

        Args:
            session: Database session.
            entity_id: ID of the entity to inspect.
            max_depth: Currently unused; only direct neighbors are returned.

        Returns:
            None when the entity does not exist; otherwise a dict with the
            keys ``entity``, ``incoming`` (relations pointing at the entity)
            and ``outgoing`` (relations starting at the entity). Each list
            item holds the ``relation`` and the ``entity`` on its other end.
        """
        entity = await session.get(KnowledgeEntity, entity_id)
        if not entity:
            return None

        neighbors = {
            "entity": self._entity_to_dict(entity),
            "incoming": [],
            "outgoing": [],
        }

        # Incoming edges: join on the source entity of relations targeting us.
        incoming_stmt = (
            select(KnowledgeRelation, KnowledgeEntity)
            .join(KnowledgeEntity, KnowledgeRelation.source_entity_id == KnowledgeEntity.id)
            .where(KnowledgeRelation.target_entity_id == entity_id)
        )
        incoming_result = await session.execute(incoming_stmt)
        for rel, source_entity in incoming_result:
            neighbors["incoming"].append({
                "relation": self._relation_to_dict(rel),
                "entity": self._entity_to_dict(source_entity),
            })

        # Outgoing edges: join on the target entity of relations we start.
        outgoing_stmt = (
            select(KnowledgeRelation, KnowledgeEntity)
            .join(KnowledgeEntity, KnowledgeRelation.target_entity_id == KnowledgeEntity.id)
            .where(KnowledgeRelation.source_entity_id == entity_id)
        )
        outgoing_result = await session.execute(outgoing_stmt)
        for rel, target_entity in outgoing_result:
            neighbors["outgoing"].append({
                "relation": self._relation_to_dict(rel),
                "entity": self._entity_to_dict(target_entity),
            })

        return neighbors

    async def get_entity_path(
        self,
        session: AsyncSession,
        source_name: str,
        target_name: str,
        max_hops: int = 3,
        kb_id: Optional[str] = None,
    ) -> list[dict]:
        """Find a path between two entities, following relations forward.

        A breadth-first search is run from the source entity along outgoing
        relations only, so the first path found has the fewest hops.

        Args:
            session: Database session.
            source_name: Name of the start entity.
            target_name: Name of the end entity.
            max_hops: Maximum number of relations in the returned path.
            kb_id: Optional knowledge base to resolve the names in. Without
                it, the first entity found with each name is used.

        Returns:
            Path steps as dicts with ``from``, ``relation`` and ``to``
            (entity names); an empty list when either entity is unknown,
            when no path exists within ``max_hops``, or when source and
            target are the same entity.
        """
        from collections import deque

        source_entity = await self._find_entity_by_name(session, source_name, kb_id)
        target_entity = await self._find_entity_by_name(session, target_name, kb_id)

        if not source_entity or not target_entity:
            return []

        source_id = str(source_entity.id)
        target_id = str(target_entity.id)

        visited = {source_id}
        # deque gives O(1) pops from the left, unlike list.pop(0).
        queue = deque([(source_id, [])])

        while queue:
            current_id, path = queue.popleft()

            if current_id == target_id:
                return await self._format_path(path, session)

            if len(path) >= max_hops:
                continue

            neighbors_stmt = (
                select(KnowledgeRelation, KnowledgeEntity)
                .join(KnowledgeEntity, KnowledgeRelation.target_entity_id == KnowledgeEntity.id)
                .where(KnowledgeRelation.source_entity_id == current_id)
            )
            neighbors_result = await session.execute(neighbors_stmt)

            for rel, neighbor in neighbors_result:
                neighbor_id = str(neighbor.id)
                if neighbor_id not in visited:
                    visited.add(neighbor_id)
                    new_path = path + [{
                        "from": current_id,
                        "relation": rel.relation_type.value,
                        "to": neighbor_id,
                    }]
                    queue.append((neighbor_id, new_path))

        return []

    async def get_statistics(
        self,
        session: AsyncSession,
        kb_id: str,
    ) -> dict:
        """Return entity and relation statistics for one knowledge base.

        Relations are attributed to the knowledge base through their source
        entity.

        Returns:
            A dict with ``entity_count``, ``relation_count``,
            ``entity_type_distribution`` and ``relation_type_distribution``
            (type value -> count).
        """
        entity_count_stmt = select(func.count()).where(
            KnowledgeEntity.knowledge_base_id == kb_id
        )
        entity_count_result = await session.execute(entity_count_stmt)
        entity_count = entity_count_result.scalar() or 0

        # Subquery of entity IDs in this knowledge base, reused below.
        kb_entities = select(KnowledgeEntity.id).where(
            KnowledgeEntity.knowledge_base_id == kb_id
        )
        relation_count_stmt = select(func.count()).where(
            KnowledgeRelation.source_entity_id.in_(kb_entities)
        )
        relation_count_result = await session.execute(relation_count_stmt)
        relation_count = relation_count_result.scalar() or 0

        type_dist_stmt = (
            select(
                KnowledgeEntity.entity_type,
                func.count()
            )
            .where(KnowledgeEntity.knowledge_base_id == kb_id)
            .group_by(KnowledgeEntity.entity_type)
        )
        type_dist_result = await session.execute(type_dist_stmt)
        entity_type_dist = {self._enum_label(k): v for k, v in type_dist_result}

        rel_type_dist_stmt = (
            select(
                KnowledgeRelation.relation_type,
                func.count()
            )
            .where(KnowledgeRelation.source_entity_id.in_(kb_entities))
            .group_by(KnowledgeRelation.relation_type)
        )
        rel_type_dist_result = await session.execute(rel_type_dist_stmt)
        relation_type_dist = {self._enum_label(k): v for k, v in rel_type_dist_result}

        return {
            "entity_count": entity_count,
            "relation_count": relation_count,
            "entity_type_distribution": entity_type_dist,
            "relation_type_distribution": relation_type_dist,
        }

    async def _find_entity_by_name(
        self,
        session: AsyncSession,
        name: str,
        kb_id: Optional[str] = None,
    ) -> Optional[KnowledgeEntity]:
        """Return one entity with the given name, optionally within one KB.

        Uses a non-raising first-match lookup because the same name may
        exist in several knowledge bases.
        """
        stmt = select(KnowledgeEntity).where(KnowledgeEntity.name == name)
        if kb_id is not None:
            stmt = stmt.where(KnowledgeEntity.knowledge_base_id == kb_id)
        result = await session.execute(stmt.limit(1))
        return result.scalars().first()

    @staticmethod
    def _enum_label(key) -> str:
        """Return the ``.value`` of an enum member, or ``str(key)`` otherwise.

        Keeps distribution keys consistent with the ``.value`` strings used
        by ``_entity_to_dict`` and ``_relation_to_dict``.
        """
        return str(getattr(key, "value", key))

    def _entity_to_dict(self, entity: KnowledgeEntity) -> dict:
        """Convert an entity row into a plain dict."""
        return {
            "id": str(entity.id),
            "name": entity.name,
            "entity_type": entity.entity_type.value,
            "description": entity.description,
            "properties": entity.properties,
            "confidence": entity.confidence,
        }

    def _relation_to_dict(self, relation: KnowledgeRelation) -> dict:
        """Convert a relation row into a plain dict."""
        return {
            "id": str(relation.id),
            "source_id": str(relation.source_entity_id),
            "target_id": str(relation.target_entity_id),
            "relation_type": relation.relation_type.value,
            "properties": relation.properties,
            "confidence": relation.confidence,
        }

    async def _format_path(self, path: list, session: AsyncSession) -> list[dict]:
        """Replace the entity IDs in path steps with entity names.

        Falls back to the raw ID when an entity can no longer be loaded.
        """
        formatted = []
        for step in path:
            from_entity = await session.get(KnowledgeEntity, step["from"])
            to_entity = await session.get(KnowledgeEntity, step["to"])

            formatted.append({
                "from": from_entity.name if from_entity else step["from"],
                "relation": step["relation"],
                "to": to_entity.name if to_entity else step["to"],
            })

        return formatted
|
||||||
|
|
@ -0,0 +1,242 @@
|
||||||
|
"""
|
||||||
|
增量索引服务
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import delete, func, select, update
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.knowledge import KnowledgeChunk, KnowledgeDocument
|
||||||
|
|
||||||
|
|
||||||
|
class IncrementalIndexService:
    """Incremental indexing service supporting document add, update and delete."""

    def __init__(self, rag_service):
        # RAG service whose ``ingest_document`` performs the actual indexing.
        self.rag = rag_service

    async def add_document(
        self,
        session: AsyncSession,
        kb_id: str,
        document_id: str,
    ) -> dict:
        """Index a single document without rebuilding the whole index.

        Returns:
            ``{"action": "skip", "reason": "already_indexed"}`` when the
            document's status is already ``"ready"``; otherwise an
            ``"indexed"`` result carrying the document ID and chunk count.
        """
        current = await self._get_document_status(session, document_id)

        if not current or current.get("status") != "ready":
            produced = await self.rag.ingest_document(session, document_id)
            return {
                "action": "indexed",
                "document_id": document_id,
                "chunk_count": produced,
            }

        return {"action": "skip", "reason": "already_indexed"}

    async def update_document(
        self,
        session: AsyncSession,
        document_id: str,
        new_content: str,
    ) -> dict:
        """Re-index a document only when its content actually changed.

        The SHA-256 of the new content is compared with the stored hash.
        When they are equal nothing is done; otherwise the old chunks are
        deleted, the document row is updated and the document is re-ingested.

        Returns:
            A ``"skip"`` result when the content is unchanged, otherwise an
            ``"updated"`` result with the deleted and new chunk counts.
        """
        digest = hashlib.sha256(new_content.encode()).hexdigest()
        previous = await self._get_content_hash(session, document_id)

        if digest == previous:
            return {"action": "skip", "reason": "content_unchanged"}

        removed = await self._delete_document_chunks(session, document_id)
        await self._update_document_content(session, document_id, new_content, digest)
        added = await self.rag.ingest_document(session, document_id)

        return {
            "action": "updated",
            "document_id": document_id,
            "deleted_chunks": removed,
            "new_chunks": added,
        }

    async def delete_document(
        self,
        session: AsyncSession,
        document_id: str,
    ) -> dict:
        """Delete a document together with all of its chunks.

        Returns:
            A ``"deleted"`` result with the number of chunks removed.
        """
        removed = await self._delete_document_chunks(session, document_id)
        await self._delete_document(session, document_id)

        return {
            "action": "deleted",
            "document_id": document_id,
            "deleted_chunks": removed,
        }

    async def rebuild_knowledge_base(
        self,
        session: AsyncSession,
        kb_id: str,
        force: bool = False,
    ) -> dict:
        """Re-index every document of a knowledge base.

        Args:
            session: Database session.
            kb_id: Knowledge base whose documents are rebuilt.
            force: Rebuild documents even when their status is ``"ready"``.

        Returns:
            Counters ``total``, ``processed``, ``skipped`` and ``failed``,
            plus ``errors``: one ``{"document_id", "error"}`` entry per
            failed document. A failure does not stop the remaining documents.
        """
        rows = await session.execute(
            select(KnowledgeDocument).where(
                KnowledgeDocument.knowledge_base_id == kb_id
            )
        )
        docs = rows.scalars().all()

        summary = {
            "total": len(docs),
            "processed": 0,
            "skipped": 0,
            "failed": 0,
            "errors": [],
        }

        for doc in docs:
            try:
                if doc.status == "ready" and not force:
                    summary["skipped"] += 1
                else:
                    doc_key = str(doc.id)
                    # Drop stale chunks before ingesting the document again.
                    await self._delete_document_chunks(session, doc_key)
                    await self.rag.ingest_document(session, doc_key)
                    summary["processed"] += 1
            except Exception as exc:
                summary["failed"] += 1
                summary["errors"].append({
                    "document_id": str(doc.id),
                    "error": str(exc),
                })

        return summary

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    async def _get_document_status(
        self,
        session: AsyncSession,
        document_id: str,
    ) -> Optional[dict]:
        """Return ``status`` and ``content_hash`` of a document, or None if missing."""
        query = select(KnowledgeDocument).where(
            KnowledgeDocument.id == document_id
        )
        doc = (await session.execute(query)).scalar_one_or_none()

        if doc:
            return {
                "status": doc.status,
                # Tolerates document models without a content_hash attribute.
                "content_hash": getattr(doc, "content_hash", None),
            }
        return None

    async def _get_content_hash(
        self,
        session: AsyncSession,
        document_id: str,
    ) -> Optional[str]:
        """Return the stored content hash of a document, or None."""
        info = await self._get_document_status(session, document_id)
        if not info:
            return None
        return info.get("content_hash")

    async def _delete_document_chunks(
        self,
        session: AsyncSession,
        document_id: str,
    ) -> int:
        """Delete all chunks of a document and return how many were counted."""
        # Count first so the number of affected chunks can be reported.
        counted = await session.execute(
            select(func.count()).where(
                KnowledgeChunk.document_id == document_id
            )
        )
        total = counted.scalar() or 0

        await session.execute(
            delete(KnowledgeChunk).where(
                KnowledgeChunk.document_id == document_id
            )
        )

        return total

    async def _update_document_content(
        self,
        session: AsyncSession,
        document_id: str,
        content: str,
        content_hash: str,
    ):
        """Store new content and hash and reset the status to ``"pending"``."""
        await session.execute(
            update(KnowledgeDocument)
            .where(KnowledgeDocument.id == document_id)
            .values(
                content=content,
                content_hash=content_hash,
                status="pending",
            )
        )

    async def _delete_document(
        self,
        session: AsyncSession,
        document_id: str,
    ):
        """Delete the document row if it exists."""
        query = select(KnowledgeDocument).where(
            KnowledgeDocument.id == document_id
        )
        doc = (await session.execute(query)).scalar_one_or_none()

        if doc:
            await session.delete(doc)
|
||||||
|
|
@ -0,0 +1,165 @@
|
||||||
|
"""文档解析器 - 支持多种格式"""
|
||||||
|
import io
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
@dataclass
class ParsedDocument:
    """Outcome of parsing one document."""

    # Title chosen by the parser.
    title: str
    # Extracted text content.
    content: str
    # Parser-specific metadata.
    metadata: dict
|
||||||
|
|
||||||
|
class BaseParser(ABC):
    """Abstract base class for all document parsers."""

    @abstractmethod
    async def parse(self, content: bytes) -> ParsedDocument:
        """Parse raw file bytes into a ``ParsedDocument``."""
|
||||||
|
|
||||||
|
class PDFParser(BaseParser):
    """PDF parser backed by PyMuPDF."""

    async def parse(self, content: bytes) -> ParsedDocument:
        """Extract text and metadata from PDF bytes.

        Pages with non-blank text are emitted in order, each prefixed with a
        page marker. Metadata holds ``author``, ``title`` and ``subject``
        when the PDF carries metadata, plus ``has_toc`` / ``toc_items`` when
        a table of contents exists.

        Args:
            content: Raw bytes of the PDF file.

        Returns:
            The parsed document; the title falls back to a default when the
            PDF has no title.
        """
        import fitz

        # State the type explicitly: the stream carries no file name for
        # PyMuPDF to infer it from.
        doc = fitz.open(stream=content, filetype="pdf")
        try:
            text_parts = []
            metadata = {}

            if doc.metadata:
                metadata = {
                    "author": doc.metadata.get("author", ""),
                    "title": doc.metadata.get("title", ""),
                    "subject": doc.metadata.get("subject", ""),
                }

            for page_num, page in enumerate(doc):
                text = page.get_text()
                if text.strip():
                    text_parts.append(f"[第{page_num + 1}页]\n{text}")

            toc = doc.get_toc()
            if toc:
                metadata["has_toc"] = True
                metadata["toc_items"] = len(toc)
        finally:
            # Release the document even when extraction fails.
            doc.close()

        return ParsedDocument(
            title=metadata.get("title", "未命名文档") or "未命名文档",
            content="\n\n".join(text_parts),
            metadata=metadata,
        )
|
||||||
|
|
||||||
|
class DocxParser(BaseParser):
    """Word (.docx) parser backed by python-docx."""

    async def parse(self, content: bytes) -> ParsedDocument:
        """Extract paragraphs, tables and core properties from .docx bytes.

        Non-empty paragraphs come first, followed by each non-empty table
        rendered as ``cell | cell`` rows under a table marker.

        Args:
            content: Raw bytes of the .docx file.

        Returns:
            The parsed document; the title falls back to a default when the
            core properties carry none.
        """
        from docx import Document

        doc = Document(io.BytesIO(content))
        paragraphs = []

        core_props = doc.core_properties
        created = getattr(core_props, "created", None)
        modified = getattr(core_props, "modified", None)
        metadata = {
            "author": getattr(core_props, "author", "") or "",
            "title": getattr(core_props, "title", "") or "",
            "subject": getattr(core_props, "subject", "") or "",
            # Test before str(): str(None) would be the truthy literal "None".
            "created": str(created) if created else "",
            "modified": str(modified) if modified else "",
        }

        for para in doc.paragraphs:
            text = para.text.strip()
            if text:
                paragraphs.append(text)

        for i, table in enumerate(doc.tables):
            table_text = []
            for row in table.rows:
                cells = [cell.text.strip() for cell in row.cells]
                # Skip rows in which every cell is empty.
                if any(cells):
                    table_text.append(" | ".join(cells))
            if table_text:
                paragraphs.append(f"[表格{i+1}]\n" + "\n".join(table_text))

        return ParsedDocument(
            title=metadata.get("title", "未命名文档") or "未命名文档",
            content="\n\n".join(paragraphs),
            metadata=metadata,
        )
|
||||||
|
|
||||||
|
class MarkdownParser(BaseParser):
    """Markdown parser."""

    async def parse(self, content: bytes) -> ParsedDocument:
        """Decode Markdown bytes and take the title from the first ``# `` heading.

        Args:
            content: Raw UTF-8 bytes of the file, with or without a BOM.

        Returns:
            The parsed document with the full text as content; the title
            falls back to a default when there is no level-1 heading.

        Raises:
            UnicodeDecodeError: If the bytes are not valid UTF-8.
        """
        # "utf-8-sig" drops a leading BOM, which would otherwise hide a
        # heading on the first line and leak into the content.
        text = content.decode("utf-8-sig")

        title = "未命名文档"
        for line in text.split("\n"):
            line = line.strip()
            if line.startswith("# "):
                title = line[2:].strip()
                break

        return ParsedDocument(
            title=title,
            content=text,
            metadata={"format": "markdown"},
        )
|
||||||
|
|
||||||
|
class TextParser(BaseParser):
    """Plain-text parser."""

    async def parse(self, content: bytes) -> ParsedDocument:
        """Decode plain text and derive a title from its first non-blank line.

        Args:
            content: Raw UTF-8 bytes of the file, with or without a BOM.

        Returns:
            The parsed document with the full text as content. The title is
            the first non-blank line, stripped and cut to 50 characters, or
            a default when the text has no non-blank line.

        Raises:
            UnicodeDecodeError: If the bytes are not valid UTF-8.
        """
        # "utf-8-sig" drops a leading BOM so it cannot end up in the title.
        text = content.decode("utf-8-sig")

        # str.split never returns an empty list, so look for real content
        # instead of testing the list for emptiness.
        first_line = next(
            (line.strip() for line in text.splitlines() if line.strip()),
            "",
        )
        title = first_line[:50] or "未命名文档"

        return ParsedDocument(
            title=title,
            content=text,
            metadata={"format": "text"},
        )
|
||||||
|
|
||||||
|
class ParserFactory:
    """Factory mapping file extensions to parser classes."""

    PARSERS = {
        ".pdf": PDFParser,
        ".docx": DocxParser,
        ".md": MarkdownParser,
        ".txt": TextParser,
        # HTML is handled by the Markdown parser.
        ".html": MarkdownParser,
    }

    @classmethod
    def create(cls, file_extension: str) -> BaseParser:
        """Instantiate the parser registered for an extension.

        Args:
            file_extension: Extension including the leading dot; matched
                case-insensitively.

        Raises:
            ValueError: If no parser is registered for the extension.
        """
        key = file_extension.lower()
        if key in cls.PARSERS:
            return cls.PARSERS[key]()
        raise ValueError(f"Unsupported format: {file_extension}")

    @classmethod
    def supported_formats(cls) -> list[str]:
        """Return the registered extensions."""
        return [*cls.PARSERS]
|
||||||
|
|
@ -0,0 +1,96 @@
|
||||||
|
"""文档上传服务"""
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import shortuuid
|
||||||
|
|
||||||
|
from app.services.knowledge.parsers import ParserFactory, ParsedDocument
|
||||||
|
|
||||||
|
@dataclass
class UploadResult:
    """Outcome of uploading and parsing one document."""

    # Generated ID of the document.
    document_id: str
    # Title produced by the parser.
    title: str
    # Parsed text content.
    content: str
    # SHA-256 hex digest of the parsed content.
    content_hash: str
    # Size of the uploaded file in bytes.
    file_size: int
    # File extension including the leading dot.
    file_type: str
    # Parser metadata plus the original filename and stored path.
    metadata: dict
|
||||||
|
|
||||||
|
class DocumentUploader:
    """Validates, parses and stores uploaded documents."""

    SUPPORTED_FORMATS = {".pdf", ".docx", ".md", ".txt", ".html"}
    MAX_SIZE_MB = 50

    def __init__(self, storage_path: str = "/data/documents"):
        # Root directory under which uploaded files are stored, per KB.
        self.storage_path = storage_path

    async def upload(
        self,
        file_content: bytes,
        filename: str,
        kb_id: str,
    ) -> UploadResult:
        """Validate, parse and store an uploaded file.

        Args:
            file_content: Raw bytes of the uploaded file.
            filename: Original file name; its extension selects the parser.
            kb_id: Knowledge base the document belongs to; used as the
                storage subdirectory.

        Returns:
            The parsed document with its generated ID, content hash and
            storage metadata.

        Raises:
            ValueError: If the file has no extension, the format is not
                supported, the file exceeds ``MAX_SIZE_MB``, or ``kb_id``
                resolves outside the storage root.
        """
        ext = self._get_extension(filename)
        if ext not in self.SUPPORTED_FORMATS:
            raise ValueError(f"Unsupported format: {ext}")

        size_mb = len(file_content) / (1024 * 1024)
        if size_mb > self.MAX_SIZE_MB:
            raise ValueError(f"File too large: {size_mb:.1f}MB > {self.MAX_SIZE_MB}MB")

        parser = ParserFactory.create(ext)
        parsed = await parser.parse(file_content)

        doc_id = shortuuid.uuid()
        content_hash = hashlib.sha256(parsed.content.encode()).hexdigest()

        file_path = self._save_file(file_content, kb_id, doc_id, ext)

        return UploadResult(
            document_id=doc_id,
            title=parsed.title,
            content=parsed.content,
            content_hash=content_hash,
            file_size=len(file_content),
            file_type=ext,
            metadata={
                **parsed.metadata,
                "original_filename": filename,
                "stored_path": file_path,
            },
        )

    def _get_extension(self, filename: str) -> str:
        """Return the lower-cased extension of a file name, with leading dot.

        Raises:
            ValueError: If the file name contains no dot.
        """
        if "." not in filename:
            raise ValueError("File has no extension")
        return "." + filename.rsplit(".", 1)[1].lower()

    def _save_file(
        self,
        content: bytes,
        kb_id: str,
        doc_id: str,
        ext: str
    ) -> str:
        """Write the file under ``<storage_path>/<kb_id>/`` and return its path.

        Raises:
            ValueError: If ``kb_id`` resolves to a directory outside the
                storage root (e.g. it contains ``..`` or is absolute).
        """
        dir_path = os.path.join(self.storage_path, kb_id)

        # Reject a kb_id that would escape the storage root.
        root = os.path.realpath(self.storage_path)
        if os.path.commonpath([root, os.path.realpath(dir_path)]) != root:
            raise ValueError(f"Invalid knowledge base id: {kb_id}")

        os.makedirs(dir_path, exist_ok=True)

        file_path = os.path.join(dir_path, f"{doc_id}{ext}")
        with open(file_path, "wb") as f:
            f.write(content)

        return file_path
|
||||||
|
|
@ -1,12 +1,14 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from typing import AsyncIterator
|
from typing import AsyncIterator
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from .base import LLMError, LLMProvider, LLMResponse
|
from .base import LLMError, LLMProvider, LLMResponse
|
||||||
from .rate_limiter import get_rate_limiter
|
from .rate_limiter import get_rate_limiter
|
||||||
|
from app.monitoring.llm_metrics import get_llm_metrics
|
||||||
|
|
||||||
_DEFAULT_MODEL = "deepseek-chat"
|
_DEFAULT_MODEL = "deepseek-chat"
|
||||||
_DEFAULT_MAX_CONTEXT = 64_000
|
_DEFAULT_MAX_CONTEXT = 64_000
|
||||||
|
|
@ -75,21 +77,40 @@ class DeepSeekProvider(LLMProvider):
|
||||||
if stop:
|
if stop:
|
||||||
payload["stop"] = stop
|
payload["stop"] = stop
|
||||||
|
|
||||||
data = await self._request_with_retry(payload, stream=False)
|
start_time = time.perf_counter()
|
||||||
|
metrics = get_llm_metrics(self.provider_name, self._model)
|
||||||
|
|
||||||
choice = data["choices"][0]
|
try:
|
||||||
content = choice["message"]["content"]
|
data = await self._request_with_retry(payload, stream=False)
|
||||||
usage = data.get("usage", {})
|
|
||||||
|
|
||||||
return LLMResponse(
|
choice = data["choices"][0]
|
||||||
content=content,
|
content = choice["message"]["content"]
|
||||||
model=data.get("model", self._model),
|
usage = data.get("usage", {})
|
||||||
usage={
|
|
||||||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
duration = time.perf_counter() - start_time
|
||||||
"completion_tokens": usage.get("completion_tokens", 0),
|
metrics.record_request(
|
||||||
"total_tokens": usage.get("total_tokens", 0),
|
status="success",
|
||||||
},
|
duration=duration,
|
||||||
)
|
prompt_tokens=usage.get("prompt_tokens"),
|
||||||
|
completion_tokens=usage.get("completion_tokens"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content=content,
|
||||||
|
model=data.get("model", self._model),
|
||||||
|
usage={
|
||||||
|
"prompt_tokens": usage.get("prompt_tokens", 0),
|
||||||
|
"completion_tokens": usage.get("completion_tokens", 0),
|
||||||
|
"total_tokens": usage.get("total_tokens", 0),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
duration = time.perf_counter() - start_time
|
||||||
|
metrics.record_request(
|
||||||
|
status="error",
|
||||||
|
duration=duration,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
async def chat_stream(
|
async def chat_stream(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,14 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from typing import AsyncIterator
|
from typing import AsyncIterator
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from .base import LLMError, LLMProvider, LLMResponse
|
from .base import LLMError, LLMProvider, LLMResponse
|
||||||
from .rate_limiter import get_rate_limiter
|
from .rate_limiter import get_rate_limiter
|
||||||
|
from app.monitoring.llm_metrics import get_llm_metrics
|
||||||
|
|
||||||
# 支持的模型及其上下文长度(百炼 Coding Plan + OpenAI)
|
# 支持的模型及其上下文长度(百炼 Coding Plan + OpenAI)
|
||||||
_OPENAI_MODELS: dict[str, int] = {
|
_OPENAI_MODELS: dict[str, int] = {
|
||||||
|
|
@ -90,21 +92,40 @@ class OpenAIProvider(LLMProvider):
|
||||||
if stop:
|
if stop:
|
||||||
payload["stop"] = stop
|
payload["stop"] = stop
|
||||||
|
|
||||||
data = await self._request_with_retry(payload, stream=False)
|
start_time = time.perf_counter()
|
||||||
|
metrics = get_llm_metrics(self.provider_name, self._model)
|
||||||
|
|
||||||
choice = data["choices"][0]
|
try:
|
||||||
content = choice["message"]["content"]
|
data = await self._request_with_retry(payload, stream=False)
|
||||||
usage = data.get("usage", {})
|
|
||||||
|
|
||||||
return LLMResponse(
|
choice = data["choices"][0]
|
||||||
content=content,
|
content = choice["message"]["content"]
|
||||||
model=data.get("model", self._model),
|
usage = data.get("usage", {})
|
||||||
usage={
|
|
||||||
"prompt_tokens": usage.get("prompt_tokens", 0),
|
duration = time.perf_counter() - start_time
|
||||||
"completion_tokens": usage.get("completion_tokens", 0),
|
metrics.record_request(
|
||||||
"total_tokens": usage.get("total_tokens", 0),
|
status="success",
|
||||||
},
|
duration=duration,
|
||||||
)
|
prompt_tokens=usage.get("prompt_tokens"),
|
||||||
|
completion_tokens=usage.get("completion_tokens"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content=content,
|
||||||
|
model=data.get("model", self._model),
|
||||||
|
usage={
|
||||||
|
"prompt_tokens": usage.get("prompt_tokens", 0),
|
||||||
|
"completion_tokens": usage.get("completion_tokens", 0),
|
||||||
|
"total_tokens": usage.get("total_tokens", 0),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
duration = time.perf_counter() - start_time
|
||||||
|
metrics.record_request(
|
||||||
|
status="error",
|
||||||
|
duration=duration,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
async def chat_stream(
|
async def chat_stream(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,324 @@
|
||||||
|
"""
|
||||||
|
套餐额度预警服务
|
||||||
|
|
||||||
|
监控用户的API调用额度、查询次数等使用情况,在接近限制时触发预警。
|
||||||
|
|
||||||
|
功能:
|
||||||
|
- 额度使用查询: 获取用户当前额度使用情况
|
||||||
|
- 使用率计算: 计算额度使用百分比
|
||||||
|
- 预警阈值检查: 80%警告、90%严重、100%限制
|
||||||
|
- 预警消息生成: 生成带建议操作的预警消息
|
||||||
|
- 额度重置: 支持额度重置功能
|
||||||
|
|
||||||
|
额度类型:
|
||||||
|
- api_calls: API调用次数
|
||||||
|
- queries: 查询次数
|
||||||
|
- content_generation: 内容生成次数
|
||||||
|
- storage: 存储空间(MB)
|
||||||
|
|
||||||
|
预警阈值:
|
||||||
|
- warning: 80% - 警告
|
||||||
|
- critical: 90% - 严重
|
||||||
|
- exhausted: 100% - 耗尽
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaType(str, Enum):
    """Kinds of plan quota that can be metered.

    Members are also ``str`` instances, so they compare equal to their
    raw string values.
    """

    # Number of API calls.
    API_CALLS = "api_calls"
    # Number of queries.
    QUERIES = "queries"
    # Number of content generations.
    CONTENT_GENERATION = "content_generation"
    # Storage space, in MB.
    STORAGE = "storage"
|
||||||
|
|
||||||
|
|
||||||
|
# Per-plan ceiling for every quota type; -1 marks an unlimited quota.
QUOTA_LIMITS = {
    "free": {"api_calls": 1_000, "queries": 50, "content_generation": 10, "storage": 100},
    "basic": {"api_calls": 10_000, "queries": 500, "content_generation": 100, "storage": 1_000},
    "pro": {"api_calls": 100_000, "queries": 5_000, "content_generation": 1_000, "storage": 10_000},
    "unlimited": {"api_calls": -1, "queries": -1, "content_generation": -1, "storage": -1},
}

# Usage fractions (0-1) associated with each warning level.
WARNING_THRESHOLDS = {"warning": 0.80, "critical": 0.90, "exhausted": 1.00}

# Message templates per warning level; formatted with `quota_type` and `percentage`.
WARNING_MESSAGES = {
    "warning": "{quota_type}额度已使用{percentage:.0f}%",
    "critical": "{quota_type}额度已使用{percentage:.0f}%,即将耗尽",
    "exhausted": "{quota_type}额度已使用{percentage:.0f}%,已耗尽",
}

# Recommended user action per warning level.
RECOMMENDED_ACTIONS = {
    "warning": "请关注使用情况,考虑升级套餐",
    "critical": "额度即将耗尽,建议立即升级套餐",
    "exhausted": "额度已耗尽,请升级套餐或等待重置",
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class QuotaUsage:
    """Snapshot of how much of one quota has been consumed."""

    # Quota type, stored as its string value.
    quota_type: str
    # Amount consumed so far.
    used: int
    # Quota ceiling; -1 means unlimited.
    limit: int
    # Consumption as a percentage (0-100); 0 when the quota is unlimited.
    usage_percentage: float
    # One of ok / warning / critical / exhausted / unlimited.
    status: str
    # Amount still available; -1 means unlimited.
    remaining: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class QuotaWarning:
    """A warning raised for one quota that is close to or at its limit."""

    # Quota type, stored as its string value.
    quota_type: str
    # Warning level.
    status: str
    # Consumption as a percentage.
    usage_percentage: float
    # Human-readable warning text.
    message: str
    # Suggested follow-up action for the user.
    recommended_action: str
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaService:
    """Plan quota warning service.

    Stateless helpers that look up plan limits, compute usage percentages,
    classify usage into a status, and build warnings for quotas that are
    close to or at their limit.
    """

    def calculate_usage_percentage(self, used: int, limit: int) -> float:
        """Return consumption as a percentage of the limit.

        Args:
            used: Amount consumed so far.
            limit: Quota ceiling; -1 means unlimited.

        Returns:
            ``used / limit * 100``. Returns 0.0 when the limit is -1
            (unlimited) or 0 (avoids division by zero).
        """
        if limit in (-1, 0):
            return 0.0
        return (used / limit) * 100.0

    def get_quota_status(self, usage_percentage: float, limit: int = 0) -> str:
        """Classify a usage percentage into a status string.

        Args:
            usage_percentage: Consumption as a percentage.
            limit: Quota ceiling; -1 means unlimited.

        Returns:
            "unlimited" when limit is -1, otherwise "exhausted" (>= 100),
            "critical" (>= 90), "warning" (>= 80) or "ok".
        """
        if limit == -1:
            return "unlimited"
        if usage_percentage >= 100.0:
            return "exhausted"
        if usage_percentage >= 90.0:
            return "critical"
        if usage_percentage >= 80.0:
            return "warning"
        return "ok"

    def get_quota_limit(self, plan: str, quota_type: QuotaType) -> int:
        """Look up the limit of a quota type for a plan.

        Args:
            plan: Plan name (a key of QUOTA_LIMITS).
            quota_type: Quota type to look up.

        Returns:
            The configured limit (-1 means unlimited), or 0 when the plan
            or the quota type is not configured.
        """
        return QUOTA_LIMITS.get(plan, {}).get(quota_type.value, 0)

    def get_remaining(self, used: int, limit: int) -> int:
        """Return how much of the quota is still available.

        Args:
            used: Amount consumed so far.
            limit: Quota ceiling; -1 means unlimited.

        Returns:
            -1 when the quota is unlimited, otherwise ``limit - used``
            clamped at 0.
        """
        if limit == -1:
            return -1
        # Clamp at zero: an over-used quota must never yield a negative
        # value, since -1 is reserved as the "unlimited" sentinel.
        return max(limit - used, 0)

    def _format_warning_message(self, quota_type: str, status: str, percentage: float) -> str:
        """Render the warning text for a quota.

        Args:
            quota_type: Quota type string inserted into the message.
            status: Warning level used to pick the template.
            percentage: Usage percentage inserted into the message.

        Returns:
            The formatted message; a generic template is used when the
            status has no entry in WARNING_MESSAGES.
        """
        template = WARNING_MESSAGES.get(status, "{quota_type}额度使用情况")
        return template.format(quota_type=quota_type, percentage=percentage)

    def _get_recommended_action(self, status: str) -> str:
        """Return the recommended action for a warning level.

        Args:
            status: Warning level.

        Returns:
            The matching entry of RECOMMENDED_ACTIONS, or a generic
            fallback when the status is not configured.
        """
        return RECOMMENDED_ACTIONS.get(status, "请关注使用情况")

    def generate_warning(
        self,
        quota_type: QuotaType,
        status: str,
        usage_percentage: float,
    ) -> QuotaWarning:
        """Build a QuotaWarning for one quota.

        Args:
            quota_type: Quota type the warning is about.
            status: Warning level.
            usage_percentage: Usage percentage.

        Returns:
            A QuotaWarning carrying the formatted message and the
            recommended action.
        """
        type_name = quota_type.value
        return QuotaWarning(
            quota_type=type_name,
            status=status,
            usage_percentage=usage_percentage,
            message=self._format_warning_message(type_name, status, usage_percentage),
            recommended_action=self._get_recommended_action(status),
        )

    def check_quota(
        self,
        plan: str,
        quota_type: QuotaType,
        used: int,
    ) -> QuotaUsage:
        """Evaluate the usage of one quota for a plan.

        Args:
            plan: Plan name.
            quota_type: Quota type to evaluate.
            used: Amount consumed so far.

        Returns:
            A QuotaUsage with limit, percentage, status and remaining
            amount filled in.
        """
        limit = self.get_quota_limit(plan, quota_type)
        percentage = self.calculate_usage_percentage(used, limit)
        return QuotaUsage(
            quota_type=quota_type.value,
            used=used,
            limit=limit,
            usage_percentage=percentage,
            status=self.get_quota_status(percentage, limit),
            remaining=self.get_remaining(used, limit),
        )

    def reset_quota(self, used: int, reset_to: int = 0) -> int:
        """Return the usage value after a reset.

        Args:
            used: Current usage; ignored, kept for interface compatibility.
            reset_to: Value to reset to, 0 by default.

        Returns:
            ``reset_to``.
        """
        return reset_to

    def get_all_quota_usage(
        self,
        plan: str,
        api_calls_used: int = 0,
        queries_used: int = 0,
        content_generation_used: int = 0,
        storage_used: int = 0,
    ) -> list[QuotaUsage]:
        """Evaluate every quota type for a plan.

        Args:
            plan: Plan name.
            api_calls_used: API calls consumed.
            queries_used: Queries consumed.
            content_generation_used: Content generations consumed.
            storage_used: Storage consumed, in MB.

        Returns:
            One QuotaUsage per quota type, in the order API calls,
            queries, content generation, storage.
        """
        consumption = (
            (QuotaType.API_CALLS, api_calls_used),
            (QuotaType.QUERIES, queries_used),
            (QuotaType.CONTENT_GENERATION, content_generation_used),
            (QuotaType.STORAGE, storage_used),
        )
        return [self.check_quota(plan, quota_type, used) for quota_type, used in consumption]

    def get_warnings(self, usage_list: list[QuotaUsage]) -> list[QuotaWarning]:
        """Build warnings for the usages that need one.

        Args:
            usage_list: Usage snapshots to inspect.

        Returns:
            A QuotaWarning for each entry whose status is warning,
            critical or exhausted, in input order.
        """
        alerting = ("warning", "critical", "exhausted")
        return [
            self.generate_warning(
                quota_type=QuotaType(usage.quota_type),
                status=usage.status,
                usage_percentage=usage.usage_percentage,
            )
            for usage in usage_list
            if usage.status in alerting
        ]
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -40,3 +40,11 @@ aiosqlite
|
||||||
|
|
||||||
# PDF生成
|
# PDF生成
|
||||||
fpdf2>=2.7
|
fpdf2>=2.7
|
||||||
|
|
||||||
|
# 监控
|
||||||
|
prometheus-client>=0.19.0
|
||||||
|
|
||||||
|
# 文档解析
|
||||||
|
PyMuPDF>=1.23.0
|
||||||
|
python-docx>=1.1.0
|
||||||
|
shortuuid>=1.0.0
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,227 @@
|
||||||
|
"""Agent基类测试"""
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from app.agent_framework.base import BaseAgent
|
||||||
|
from app.agent_framework.protocol import (
|
||||||
|
AgentCapability,
|
||||||
|
AgentStatus,
|
||||||
|
TaskMessage,
|
||||||
|
TaskResult,
|
||||||
|
TaskStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConcreteTestAgent(BaseAgent):
    """Minimal concrete BaseAgent used as a test double."""

    def __init__(self):
        identity = {
            "name": "concrete_test_agent",
            "agent_type": "test_type",
            "version": "1.0.0",
        }
        super().__init__(**identity)
        # Bookkeeping so tests can see whether and with what execute() ran.
        self._execute_task_data = None
        self._execute_called = False

    def get_capabilities(self) -> AgentCapability:
        """Describe this agent: one supported task type, concurrency of 3."""
        spec = dict(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["test_task"],
            max_concurrency=3,
            description="测试用Agent",
        )
        return AgentCapability(**spec)

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Record the call and return a COMPLETED result for the task."""
        self._execute_called, self._execute_task_data = True, task
        outcome = TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=TaskStatus.COMPLETED,
            output_data={"result": "success"},
            error_message=None,
            started_at=datetime.now(timezone.utc),
            completed_at=datetime.now(timezone.utc),
            metrics={"test": True},
        )
        return outcome
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseAgent:
    """Tests for the BaseAgent base class, exercised through ConcreteTestAgent."""

    def test_agent_initialization(self):
        """A new agent exposes its identity and starts OFFLINE."""
        agent = ConcreteTestAgent()

        assert agent.name == "concrete_test_agent"
        assert agent.agent_type == "test_type"
        assert agent.version == "1.0.0"
        assert agent.status == AgentStatus.OFFLINE

    def test_get_capabilities(self):
        """get_capabilities returns an AgentCapability describing the agent."""
        agent = ConcreteTestAgent()
        capability = agent.get_capabilities()

        assert isinstance(capability, AgentCapability)
        assert capability.agent_name == "concrete_test_agent"
        assert "test_task" in capability.supported_tasks

    def test_agent_status_transitions(self):
        """The status property reflects changes to the internal state."""
        agent = ConcreteTestAgent()

        # Initial state.
        assert agent.status == AgentStatus.OFFLINE

        # Simulate state changes by writing the private field directly.
        agent._status = AgentStatus.ONLINE
        assert agent.status == AgentStatus.ONLINE

        agent._status = AgentStatus.BUSY
        assert agent.status == AgentStatus.BUSY

    def test_agent_running_tasks_tracking(self):
        """The running-task set supports add/discard bookkeeping."""
        agent = ConcreteTestAgent()

        assert len(agent._running_tasks) == 0

        # Simulate adding tasks.
        agent._running_tasks.add("task-1")
        agent._running_tasks.add("task-2")
        assert len(agent._running_tasks) == 2

        # Simulate removing a task.
        agent._running_tasks.discard("task-1")
        assert len(agent._running_tasks) == 1
        assert "task-1" not in agent._running_tasks
        assert "task-2" in agent._running_tasks

    def test_agent_semaphore_initialization(self):
        """A semaphore sized from max_concurrency can be attached to the agent."""
        import asyncio

        agent = ConcreteTestAgent()
        max_concurrency = agent.get_capabilities().max_concurrency

        agent._semaphore = asyncio.Semaphore(max_concurrency)

        assert agent._semaphore is not None
        # _value is the semaphore's private counter of free slots.
        assert agent._semaphore._value == max_concurrency

    def test_is_idle_property(self):
        """Idle and busy states are observable through the status property.

        The original test set the states but asserted nothing, so it could
        never fail.
        """
        agent = ConcreteTestAgent()

        agent._status = AgentStatus.OFFLINE
        assert agent.status == AgentStatus.OFFLINE

        agent._status = AgentStatus.BUSY
        assert agent.status == AgentStatus.BUSY
        assert agent.status != AgentStatus.OFFLINE
        # NOTE(review): BaseAgent's idle API (is_idle?) is not visible from
        # this file; add a direct assertion on it once confirmed.
|
||||||
|
|
||||||
|
|
||||||
|
class TestTaskMessage:
    """Tests for TaskMessage construction and (de)serialization."""

    def test_task_message_creation(self):
        """Constructor arguments are exposed as attributes."""
        generated_id = str(uuid.uuid4())
        payload = {"key": "value"}
        task = TaskMessage(
            task_id=generated_id,
            agent_name="test_agent",
            task_type="test_task",
            priority=5,
            input_data=payload,
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )

        assert task.task_id is not None
        assert task.agent_name == "test_agent"
        assert task.task_type == "test_task"
        assert task.priority == 5
        assert task.input_data == {"key": "value"}

    def test_task_message_to_dict(self):
        """to_dict carries the identifying fields over."""
        serialized = TaskMessage(
            task_id="test-uuid",
            agent_name="test_agent",
            task_type="test_task",
            priority=5,
            input_data={"key": "value"},
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        ).to_dict()

        expected = {
            "task_id": "test-uuid",
            "agent_name": "test_agent",
            "task_type": "test_task",
        }
        for field_name, value in expected.items():
            assert serialized[field_name] == value

    def test_task_message_from_dict(self):
        """from_dict rebuilds a message from a plain dict."""
        raw = dict(
            task_id="test-uuid",
            agent_name="test_agent",
            task_type="test_task",
            priority=5,
            input_data={"key": "value"},
            callback_url=None,
            created_at=datetime.now(timezone.utc).isoformat(),
        )

        restored = TaskMessage.from_dict(raw)

        assert restored.task_id == "test-uuid"
        assert restored.agent_name == "test_agent"
|
||||||
|
|
||||||
|
|
||||||
|
class TestTaskResult:
    """Tests for TaskResult construction and serialization."""

    def test_task_result_creation(self):
        """Constructor arguments are exposed as attributes."""
        fields = dict(
            task_id="test-uuid",
            agent_name="test_agent",
            status=TaskStatus.COMPLETED,
            output_data={"result": "success"},
            error_message=None,
            started_at=datetime.now(timezone.utc),
            completed_at=datetime.now(timezone.utc),
            metrics={"elapsed": 1.5},
        )
        result = TaskResult(**fields)

        assert result.task_id == "test-uuid"
        assert result.agent_name == "test_agent"
        assert result.status == TaskStatus.COMPLETED

    def test_task_result_to_dict(self):
        """to_dict carries id, status and output data over."""
        fields = dict(
            task_id="test-uuid",
            agent_name="test_agent",
            status=TaskStatus.COMPLETED,
            output_data={"result": "success"},
            error_message=None,
            started_at=datetime.now(timezone.utc),
            completed_at=datetime.now(timezone.utc),
            metrics={"elapsed": 1.5},
        )
        serialized = TaskResult(**fields).to_dict()

        assert serialized["task_id"] == "test-uuid"
        assert serialized["status"] == TaskStatus.COMPLETED
        assert serialized["output_data"] == {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
@ -0,0 +1,214 @@
|
||||||
|
"""任务分发器测试"""
|
||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from app.agent_framework.dispatcher import TaskDispatcher
|
||||||
|
from app.agent_framework.registry import AgentRegistry
|
||||||
|
from app.agent_framework.protocol import (
|
||||||
|
AgentCapability,
|
||||||
|
TaskMessage,
|
||||||
|
)
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
def is_database_available():
    """Probe the configured database synchronously.

    Strips the async driver suffix from settings.DATABASE_URL, opens a
    synchronous connection and runs ``SELECT 1``.

    Returns:
        True when the probe query succeeds; False on any failure,
        including a missing driver or import error.
    """
    try:
        from sqlalchemy import create_engine, text
        from app.config import settings

        # Build a synchronous URL from the async one for the probe.
        sync_url = settings.DATABASE_URL.replace("+aiosqlite", "").replace("+asyncpg", "")
        # Non-sqlite URLs get a 1-second connect timeout, as before.
        connect_args = {} if "sqlite" in sync_url else {"connect_timeout": 1}
        engine = create_engine(sync_url, connect_args=connect_args)
        try:
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
        finally:
            # Release pooled connections even when the probe fails.
            engine.dispose()
        return True
    except Exception:
        return False
|
||||||
|
|
||||||
|
|
||||||
|
# 检查服务是否可用
|
||||||
|
# Cached probe result; None until the first check_db() call.
_db_available = None


def check_db():
    """Return whether the database is reachable, probing only on first call."""
    global _db_available
    if _db_available is not None:
        return _db_available
    try:
        _db_available = is_database_available()
    except Exception:
        _db_available = False
    return _db_available
|
||||||
|
|
||||||
|
|
||||||
|
def is_redis_available():
    """Probe the configured Redis server with a PING.

    Returns:
        True when the ping succeeds; False on any failure, including a
        missing ``redis`` package.
    """
    try:
        # Imported inside the try so a missing package reports "unavailable"
        # instead of raising.
        import redis

        client = redis.Redis.from_url(
            settings.REDIS_URL,
            # Short timeouts keep the probe from hanging on an unreachable host.
            socket_connect_timeout=1,
            socket_timeout=1,
        )
        try:
            client.ping()
        finally:
            # Release the probe connection rather than leaving it to GC.
            client.connection_pool.disconnect()
        return True
    except Exception:
        return False
|
||||||
|
|
||||||
|
|
||||||
|
# Cached probe result; None until the first check_redis() call.
_redis_available = None


def check_redis():
    """Return whether Redis is reachable, probing only on first call."""
    global _redis_available
    if _redis_available is not None:
        return _redis_available
    try:
        _redis_available = is_redis_available()
    except Exception:
        _redis_available = False
    return _redis_available
|
||||||
|
|
||||||
|
|
||||||
|
class TestTaskDispatcher:
    """Tests for TaskDispatcher.

    Cases that need a live database or Redis skip themselves when
    check_db() / check_redis() report the service as unavailable.
    """

    @pytest.fixture
    def dispatcher(self):
        """Create a dispatcher pointed at the configured Redis URL."""
        return TaskDispatcher(settings.REDIS_URL)

    @pytest.mark.asyncio
    async def test_dispatcher_initialization(self, dispatcher):
        """A fresh dispatcher stores the URL and holds no Redis connection."""
        assert dispatcher is not None
        assert dispatcher._redis_url == settings.REDIS_URL
        assert dispatcher._redis is None

    @pytest.mark.asyncio
    async def test_get_task_status_not_found(self, dispatcher):
        """Querying an unknown task id raises TaskNotFoundError."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        from app.agent_framework.exceptions import TaskNotFoundError

        non_existent_id = str(uuid.uuid4())
        with pytest.raises(TaskNotFoundError):
            await dispatcher.get_task_status(non_existent_id)

    @pytest.mark.asyncio
    async def test_dispatch_without_agent(self, dispatcher):
        """Dispatching to an unregistered agent raises TaskDispatchError."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        from app.agent_framework.exceptions import TaskDispatchError

        task = TaskMessage(
            task_id=str(uuid.uuid4()),
            agent_name="non_existent_agent",
            task_type="test_task",
            priority=5,
            input_data={},
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )

        with pytest.raises(TaskDispatchError):
            await dispatcher.dispatch(task)

    @pytest.mark.asyncio
    async def test_cancel_task_not_found(self, dispatcher):
        """Cancelling an unknown task id raises TaskNotFoundError."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        from app.agent_framework.exceptions import TaskNotFoundError

        non_existent_id = str(uuid.uuid4())
        with pytest.raises(TaskNotFoundError):
            await dispatcher.cancel_task(non_existent_id)

    @pytest.mark.asyncio
    async def test_handle_progress(self, dispatcher):
        """Progress reports for an unknown task are handled without raising."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        from app.agent_framework.protocol import TaskProgress

        # Progress object that refers to no real task or agent.
        progress = TaskProgress(
            task_id=str(uuid.uuid4()),
            agent_name="non_existent",
            progress=0.5,
            message="测试进度",
            updated_at=datetime.now(timezone.utc),
        )

        # Must not raise.
        await dispatcher.handle_progress(progress)

    @pytest.mark.asyncio
    async def test_close_dispatcher(self, dispatcher):
        """close() drops the Redis connection opened by _get_redis()."""
        if not check_redis():
            pytest.skip("Redis不可用,跳过此测试")

        # Open the Redis connection first.
        await dispatcher._get_redis()
        assert dispatcher._redis is not None

        await dispatcher.close()
        assert dispatcher._redis is None

    @pytest.mark.asyncio
    async def test_dispatch_and_query_flow(self, dispatcher):
        """Register an agent, dispatch a task to it, then clean up."""
        if not check_db() or not check_redis():
            pytest.skip("数据库或Redis不可用,跳过此测试")

        # 1. Register a uniquely named test agent.
        registry = AgentRegistry()
        agent_name = f"test_dispatch_agent_{uuid.uuid4().hex[:8]}"
        capability = AgentCapability(
            agent_name=agent_name,
            agent_type="test_type",
            version="1.0.0",
            supported_tasks=["test_task"],
            max_concurrency=3,
            description="测试Agent",
        )
        await registry.register(capability, endpoint=f"agent:{agent_name}")

        try:
            # 2. Try to dispatch a task to it.
            task = TaskMessage(
                task_id=str(uuid.uuid4()),
                agent_name=agent_name,
                task_type="test_task",
                priority=5,
                input_data={"test": "data"},
                callback_url=None,
                created_at=datetime.now(timezone.utc),
            )

            try:
                task_id = await dispatcher.dispatch(task)
            except Exception:
                # The agent is registered but may not be online, so a
                # dispatch failure is tolerated here.
                pass
            else:
                # Checked outside the except scope so a failing assertion
                # is reported instead of being swallowed.
                assert task_id is not None
        finally:
            # Always unregister so a failure does not leak the test agent.
            await registry.unregister(agent_name)

    @pytest.mark.asyncio
    async def test_retry_failed_tasks_empty(self, dispatcher):
        """retry_failed_tasks does not raise when called."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        result = await dispatcher.retry_failed_tasks(max_retries=3)
        # NOTE(review): the return type is not visible from this file, so
        # None, an empty list, or an int are all accepted.
        assert result is None or result == [] or isinstance(result, int)
|
||||||
|
|
@ -0,0 +1,228 @@
|
||||||
|
"""Agent注册表测试"""
|
||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from app.agent_framework.registry import AgentRegistry
|
||||||
|
from app.agent_framework.protocol import AgentCapability, AgentStatus
|
||||||
|
|
||||||
|
|
||||||
|
def is_database_available():
    """Probe the configured database synchronously.

    Strips the async driver suffix from settings.DATABASE_URL, opens a
    synchronous connection and runs ``SELECT 1``.

    Returns:
        True when the probe query succeeds; False on any failure,
        including a missing driver or import error.
    """
    try:
        from sqlalchemy import create_engine, text
        from app.config import settings

        # Build a synchronous URL from the async one for the probe.
        sync_url = settings.DATABASE_URL.replace("+aiosqlite", "").replace("+asyncpg", "")
        # Non-sqlite URLs get a 1-second connect timeout, as before.
        connect_args = {} if "sqlite" in sync_url else {"connect_timeout": 1}
        engine = create_engine(sync_url, connect_args=connect_args)
        try:
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
        finally:
            # Release pooled connections even when the probe fails.
            engine.dispose()
        return True
    except Exception:
        return False
|
||||||
|
|
||||||
|
|
||||||
|
# 检查服务是否可用
|
||||||
|
# Cached probe result; None until the first check_db() call.
_db_available = None


def check_db():
    """Return whether the database is reachable, probing only on first call."""
    global _db_available
    if _db_available is not None:
        return _db_available
    try:
        _db_available = is_database_available()
    except Exception:
        _db_available = False
    return _db_available
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentRegistry:
    """Tests for AgentRegistry against a live database.

    Every case skips itself when check_db() reports the database as
    unavailable. Agents registered by a test are unregistered in a
    ``finally`` block so a failing assertion does not leak them.
    """

    @staticmethod
    def _new_agent_name():
        """Return a unique agent name so test runs do not collide."""
        return f"test_agent_{uuid.uuid4().hex[:8]}"

    @staticmethod
    def _capability(
        agent_name,
        supported_tasks,
        *,
        version="1.0.0",
        max_concurrency=3,
        description="测试Agent",
    ):
        """Build a citation_detector AgentCapability for the given agent name."""
        return AgentCapability(
            agent_name=agent_name,
            agent_type="citation_detector",
            version=version,
            supported_tasks=supported_tasks,
            max_concurrency=max_concurrency,
            description=description,
        )

    @pytest.mark.asyncio
    async def test_register_agent(self):
        """register() returns a non-empty agent id."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        registry = AgentRegistry()
        agent_name = self._new_agent_name()
        capability = self._capability(
            agent_name, ["citation_detect", "citation_detect_single"]
        )

        agent_id = await registry.register(capability, endpoint=f"agent:{agent_name}")
        try:
            assert agent_id is not None
            assert len(agent_id) > 0
        finally:
            await registry.unregister(agent_name)

    @pytest.mark.asyncio
    async def test_get_registered_agent(self):
        """get_agent() returns the stored record of a registered agent."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        registry = AgentRegistry()
        agent_name = self._new_agent_name()
        capability = self._capability(agent_name, ["citation_detect"])

        await registry.register(capability, endpoint=f"agent:{agent_name}")
        try:
            retrieved = await registry.get_agent(agent_name)

            assert retrieved is not None
            assert retrieved["name"] == agent_name
            assert retrieved["agent_type"] == "citation_detector"
        finally:
            await registry.unregister(agent_name)

    @pytest.mark.asyncio
    async def test_list_all_agents(self):
        """list_agents() returns a list."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        registry = AgentRegistry()
        agents = await registry.list_agents()

        assert isinstance(agents, list)

    @pytest.mark.asyncio
    async def test_unregister_agent(self):
        """After unregister() the agent is gone or marked OFFLINE."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        registry = AgentRegistry()
        agent_name = self._new_agent_name()
        capability = self._capability(agent_name, ["citation_detect"])

        await registry.register(capability, endpoint=f"agent:{agent_name}")
        await registry.unregister(agent_name)

        retrieved = await registry.get_agent(agent_name)
        # Either outcome is accepted: record removed, or kept as OFFLINE.
        assert retrieved is None or retrieved["status"] == AgentStatus.OFFLINE

    @pytest.mark.asyncio
    async def test_update_heartbeat(self):
        """update_heartbeat() leaves a last_heartbeat value on the record."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        registry = AgentRegistry()
        agent_name = self._new_agent_name()
        capability = self._capability(agent_name, ["citation_detect"])

        await registry.register(capability, endpoint=f"agent:{agent_name}")
        try:
            await registry.update_heartbeat(agent_name)

            agent = await registry.get_agent(agent_name)
            assert agent is not None
            assert agent["last_heartbeat"] is not None
        finally:
            await registry.unregister(agent_name)

    @pytest.mark.asyncio
    async def test_get_available_agent(self):
        """get_available_agent() runs and returns None or an agent name."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        registry = AgentRegistry()
        agent_name = self._new_agent_name()
        capability = self._capability(
            agent_name, ["citation_detect", "citation_detect_single"]
        )

        await registry.register(capability, endpoint=f"agent:{agent_name}")
        try:
            available = await registry.get_available_agent("citation_detect")
            # May be None because the registered agent is not ONLINE; this
            # only checks that the call completes with a sane type.
            assert available is None or isinstance(available, str)
        finally:
            await registry.unregister(agent_name)

    @pytest.mark.asyncio
    async def test_agent_reregistration(self):
        """Registering the same name twice succeeds instead of raising."""
        if not check_db():
            pytest.skip("数据库不可用,跳过此测试")

        registry = AgentRegistry()
        agent_name = self._new_agent_name()
        endpoint = f"agent:{agent_name}"
        capability1 = self._capability(
            agent_name, ["citation_detect"], description="测试AgentV1"
        )
        capability2 = self._capability(
            agent_name,
            ["citation_detect", "new_task"],
            version="2.0.0",
            max_concurrency=5,
            description="测试AgentV2",
        )

        await registry.register(capability1, endpoint=endpoint)
        try:
            # First registration is visible.
            agent_data = await registry.get_agent(agent_name)
            assert agent_data is not None

            # Re-registering the same name must succeed.
            agent_id2 = await registry.register(capability2, endpoint=endpoint)
            assert agent_id2 is not None
        finally:
            await registry.unregister(agent_name)
|
||||||
|
|
@ -0,0 +1,224 @@
|
||||||
|
"""Agent集成测试"""
|
||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from app.agent_framework.agents.citation_detector import CitationDetectorAgent
|
||||||
|
from app.agent_framework.agents.content_generator_agent import ContentGeneratorAgent
|
||||||
|
from app.agent_framework.protocol import (
|
||||||
|
TaskMessage,
|
||||||
|
TaskStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCitationDetectorAgent:
    """Tests for the citation detection agent."""

    @staticmethod
    def _make_task(task_type, input_data=None):
        """Build a TaskMessage addressed to the citation detector.

        Args:
            task_type: Value for the message's ``task_type`` field.
            input_data: Payload dict; an empty dict is used when omitted.
        """
        return TaskMessage(
            task_id=str(uuid.uuid4()),
            agent_name="citation_detector",
            task_type=task_type,
            priority=5,
            input_data={} if input_data is None else input_data,
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )

    def test_agent_initialization(self):
        """A freshly constructed agent exposes the expected name, type and version."""
        agent = CitationDetectorAgent()

        assert agent.name == "citation_detector"
        assert agent.agent_type.value == "citation_detector"
        assert agent.version == "1.0.0"

    def test_get_capabilities(self):
        """The capability descriptor lists both detection task types."""
        agent = CitationDetectorAgent()
        capability = agent.get_capabilities()

        assert capability.agent_name == "citation_detector"
        assert "citation_detect" in capability.supported_tasks
        assert "citation_detect_single" in capability.supported_tasks
        assert capability.max_concurrency == 3

    @pytest.mark.asyncio
    async def test_execute_with_invalid_task_type(self):
        """An unknown task type yields a FAILED result with an explanatory message."""
        agent = CitationDetectorAgent()
        task = self._make_task("invalid_task_type")

        result = await agent.execute(task)

        assert result.status == TaskStatus.FAILED
        assert result.error_message is not None
        assert "Unsupported task type" in result.error_message

    @pytest.mark.asyncio
    async def test_execute_single_detect_missing_params(self):
        """Single-platform detection fails when the input payload is empty."""
        agent = CitationDetectorAgent()
        # Empty payload: keyword, platform and target_brand are all missing.
        task = self._make_task("citation_detect_single")

        result = await agent.execute(task)

        assert result.status == TaskStatus.FAILED
        assert result.error_message is not None

    @pytest.mark.asyncio
    async def test_execute_full_detect_missing_query_id(self):
        """Full detection fails when the input payload carries no query_id."""
        agent = CitationDetectorAgent()
        # Empty payload: query_id is missing.
        task = self._make_task("citation_detect")

        result = await agent.execute(task)

        assert result.status == TaskStatus.FAILED
        # Guard first so a missing message fails as an assertion instead of
        # raising TypeError from the substring checks below.
        assert result.error_message is not None
        assert "query_id" in result.error_message or "must contain" in result.error_message

    def test_compatibility_methods_exist(self):
        """The backward-compatibility entry points exist and are callable."""
        agent = CitationDetectorAgent()

        assert hasattr(agent, 'execute_query_compat')
        assert hasattr(agent, 'execute_single_platform_compat')
        assert callable(agent.execute_query_compat)
        assert callable(agent.execute_single_platform_compat)
|
||||||
|
|
||||||
|
|
||||||
|
class TestContentGeneratorAgent:
    """Tests for the content generation agent."""

    @staticmethod
    def _make_task(task_type, input_data=None):
        """Build a TaskMessage addressed to the content generator.

        Args:
            task_type: Value for the message's ``task_type`` field.
            input_data: Payload dict; an empty dict is used when omitted.
        """
        return TaskMessage(
            task_id=str(uuid.uuid4()),
            agent_name="content_generator",
            task_type=task_type,
            priority=5,
            input_data={} if input_data is None else input_data,
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )

    def test_agent_initialization(self):
        """A freshly constructed agent exposes the expected name, type and version."""
        agent = ContentGeneratorAgent()

        assert agent.name == "content_generator"
        assert agent.agent_type.value == "content_generator"
        assert agent.version == "1.0.0"

    def test_get_capabilities(self):
        """The capability descriptor lists topic and article generation."""
        agent = ContentGeneratorAgent()
        capability = agent.get_capabilities()

        assert capability.agent_name == "content_generator"
        assert "generate_topics" in capability.supported_tasks
        assert "generate_article" in capability.supported_tasks
        assert capability.max_concurrency == 2

    @pytest.mark.asyncio
    async def test_execute_with_invalid_task_type(self):
        """An unknown task type yields a FAILED result with an explanatory message."""
        agent = ContentGeneratorAgent()
        task = self._make_task("invalid_task_type")

        result = await agent.execute(task)

        assert result.status == TaskStatus.FAILED
        # Guard first so a missing message fails as an assertion instead of
        # raising TypeError from the substring check below.
        assert result.error_message is not None
        assert "Unsupported task type" in result.error_message

    @pytest.mark.asyncio
    async def test_generate_topics_missing_keyword(self):
        """Topic generation with an empty payload still returns a terminal result."""
        agent = ContentGeneratorAgent()
        # Empty payload: target_keyword is missing.
        task = self._make_task("generate_topics")

        # NOTE(review): no LLM or knowledge base is mocked here, so the agent
        # may attempt a real call; this test only checks that execute()
        # returns instead of raising.
        result = await agent.execute(task)

        # Either terminal status is accepted (missing parameter or LLM failure
        # would both produce FAILED).
        assert result is not None
        assert result.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]

    @pytest.mark.asyncio
    async def test_generate_article_missing_keyword(self):
        """Article generation with an empty payload still returns a terminal result."""
        agent = ContentGeneratorAgent()
        # Empty payload: target_keyword is missing.
        task = self._make_task("generate_article")

        result = await agent.execute(task)

        assert result is not None
        # The missing parameter may cause a failure; either terminal status is accepted.
        assert result.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]

    def test_extract_json_method(self):
        """``_extract_json`` keeps the payload for plain and markdown-fenced JSON."""
        agent = ContentGeneratorAgent()

        # Plain JSON text.
        json_text = '{"title": "测试标题", "reason": "测试原因"}'
        extracted = agent._extract_json(json_text)
        assert "title" in extracted

        # JSON wrapped in a markdown code fence.
        md_text = '```json\n{"title": "测试"}\n```'
        extracted = agent._extract_json(md_text)
        assert "title" in extracted
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentProtocol:
    """Checks on the string values carried by the agent protocol enums."""

    def test_agent_type_enum_values(self):
        """AgentType members map to their expected string values."""
        from app.agent_framework.protocol import AgentType

        detector_value = AgentType.CITATION_DETECTOR.value
        generator_value = AgentType.CONTENT_GENERATOR.value

        assert detector_value == "citation_detector"
        assert generator_value == "content_generator"

    def test_task_status_enum_values(self):
        """TaskStatus members map to their expected string values."""
        expectations = (
            (TaskStatus.PENDING, "pending"),
            (TaskStatus.RUNNING, "running"),
            (TaskStatus.COMPLETED, "completed"),
            (TaskStatus.FAILED, "failed"),
            (TaskStatus.CANCELLED, "cancelled"),
        )
        for member, text in expectations:
            assert member.value == text

    def test_agent_status_enum_values(self):
        """AgentStatus members map to their expected string values."""
        from app.agent_framework.protocol import AgentStatus

        expectations = (
            (AgentStatus.ONLINE, "online"),
            (AgentStatus.OFFLINE, "offline"),
            (AgentStatus.BUSY, "busy"),
        )
        for member, text in expectations:
            assert member.value == text
|
||||||
|
|
@ -0,0 +1,329 @@
|
||||||
|
"""告警设置API测试"""
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
from app.main import app
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.brand import Brand
|
||||||
|
from app.models.alert_setting import AlertSetting
|
||||||
|
from app.api.deps import get_current_user, get_db
|
||||||
|
from app.services.auth import hash_password, create_access_token
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Fixtures ====================
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with the ``Base`` tables created.

    The engine is disposed after the consuming test has finished.
    """
    db_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )

    # Create the schema for every model registered on Base.metadata.
    async with db_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield db_engine

    await db_engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an AsyncSession bound to the test engine.

    The session is closed by the ``async with`` block when the test ends.
    """
    session_options = {
        "class_": AsyncSession,
        "expire_on_commit": False,
        "autoflush": False,
        "autocommit": False,
    }
    make_session = async_sessionmaker(async_engine, **session_options)

    async with make_session() as db_session:
        yield db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def test_user(async_session):
    """Insert and return an active, e-mail-verified user on the free plan."""
    user_fields = {
        "id": uuid.uuid4(),
        "email": "test@example.com",
        "password_hash": hash_password("Test@123456"),
        "name": "Test User",
        "plan": "free",
        "max_queries": 5,
        "is_active": True,
        "email_verified": True,
    }
    account = User(**user_fields)

    async_session.add(account)
    await async_session.commit()
    # Reload the row so the returned object reflects the committed state.
    await async_session.refresh(account)
    return account
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Insert and return an active brand owned by ``test_user``."""
    brand_fields = {
        "id": uuid.uuid4(),
        "user_id": test_user.id,
        "name": "Test Brand",
        "aliases": ["TestBrand", "TB"],
        "website": "https://testbrand.com",
        "industry": "technology",
        "platforms": ["wenxin", "kimi"],
        "frequency": "weekly",
        "status": "active",
    }
    record = Brand(**brand_fields)

    async_session.add(record)
    await async_session.commit()
    # Reload the row so the returned object reflects the committed state.
    await async_session.refresh(record)
    return record
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def test_alert_setting(async_session, test_user, test_brand):
    """Insert and return an enabled ``score_drop`` alert setting (threshold 5.0)."""
    setting_fields = {
        "id": uuid.uuid4(),
        "brand_id": test_brand.id,
        "user_id": test_user.id,
        "alert_type": "score_drop",
        "enabled": True,
        "threshold": 5.0,
    }
    record = AlertSetting(**setting_fields)

    async_session.add(record)
    await async_session.commit()
    # Reload the row so the returned object reflects the committed state.
    await async_session.refresh(record)
    return record
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an HTTP client wired to the app with DB and auth dependencies overridden.

    ``get_db`` yields the test session and ``get_current_user`` returns
    ``test_user``. All overrides are cleared after the test.
    """

    async def _provide_session():
        yield async_session

    async def _provide_user():
        return test_user

    app.dependency_overrides[get_db] = _provide_session
    app.dependency_overrides[get_current_user] = _provide_user

    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as http_client:
        yield http_client

    app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def auth_headers(test_user):
    """Return an Authorization header carrying an access token for ``test_user``."""
    subject = str(test_user.id)
    bearer_token = create_access_token(data={"sub": subject})
    return {"Authorization": f"Bearer {bearer_token}"}
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 测试类 ====================
|
||||||
|
|
||||||
|
class TestAlertSettingsAPI:
    """CRUD tests for the ``/api/v1/alerts/settings`` endpoints."""

    @pytest.mark.asyncio
    async def test_get_alert_settings_success(self, async_client, test_alert_setting):
        """Listing settings returns 200 with a non-empty ``items``/``total`` payload."""
        resp = await async_client.get("/api/v1/alerts/settings")
        assert resp.status_code == 200

        payload = resp.json()
        assert "items" in payload
        assert "total" in payload
        assert payload["total"] >= 1
        assert len(payload["items"]) >= 1

        # Every listed setting is expected to expose these fields.
        head = payload["items"][0]
        for field in ("id", "brand_id", "alert_type", "enabled", "threshold"):
            assert field in head

    @pytest.mark.asyncio
    async def test_get_alert_settings_by_brand(self, async_client, test_brand, test_alert_setting):
        """Filtering by ``brand_id`` only returns settings of that brand."""
        url = f"/api/v1/alerts/settings?brand_id={test_brand.id}"
        resp = await async_client.get(url)
        assert resp.status_code == 200

        payload = resp.json()
        assert payload["total"] >= 1

        wanted_brand = str(test_brand.id)
        for entry in payload["items"]:
            assert entry["brand_id"] == wanted_brand

    @pytest.mark.asyncio
    async def test_update_alert_settings_success(self, async_client, test_brand):
        """PUT with a settings list returns 200 and echoes the new values."""
        single_setting = {
            "brand_id": str(test_brand.id),
            "alert_type": "score_drop",
            "enabled": True,
            "threshold": 20.0,
        }
        body = {"settings": [single_setting]}

        resp = await async_client.put("/api/v1/alerts/settings", json=body)
        assert resp.status_code == 200

        payload = resp.json()
        assert "items" in payload
        assert len(payload["items"]) >= 1

        result = payload["items"][0]
        assert result["alert_type"] == "score_drop"
        assert result["threshold"] == 20.0
        assert result["enabled"] is True

    @pytest.mark.asyncio
    async def test_create_alert_setting_success(self, async_client, test_brand):
        """POST creates a setting for an existing brand and returns 201."""
        body = {
            "brand_id": str(test_brand.id),
            "alert_type": "negative_sentiment",
            "enabled": True,
            "threshold": 1.0,
        }

        resp = await async_client.post("/api/v1/alerts/settings", json=body)
        assert resp.status_code == 201

        created = resp.json()
        assert created["alert_type"] == "negative_sentiment"
        assert created["threshold"] == 1.0
        assert created["enabled"] is True
        assert "id" in created

    @pytest.mark.asyncio
    async def test_delete_alert_setting_success(self, async_client, test_alert_setting):
        """DELETE returns 204 and the setting no longer appears in the listing."""
        target = f"/api/v1/alerts/settings/{test_alert_setting.id}"
        resp = await async_client.delete(target)
        assert resp.status_code == 204

        # List again and confirm the deleted id is gone.
        listing = await async_client.get("/api/v1/alerts/settings")
        remaining_ids = [entry["id"] for entry in listing.json()["items"]]
        assert str(test_alert_setting.id) not in remaining_ids

    @pytest.mark.asyncio
    async def test_delete_alert_setting_not_found(self, async_client):
        """Deleting an unknown setting id returns 404."""
        missing_id = uuid.uuid4()
        resp = await async_client.delete(f"/api/v1/alerts/settings/{missing_id}")
        assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
class TestAlertSettingsValidation:
    """Validation and authorization tests for the alert settings endpoints."""

    @pytest.mark.asyncio
    async def test_brand_not_found_returns_404(self, async_client):
        """Creating a setting for an unknown brand returns 404."""
        non_existent_brand_id = uuid.uuid4()
        create_data = {
            "brand_id": str(non_existent_brand_id),
            "alert_type": "score_drop",
            "enabled": True,
            "threshold": 5.0,
        }

        response = await async_client.post(
            "/api/v1/alerts/settings", json=create_data
        )

        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_unauthorized_returns_401(self, async_session):
        """A request with an invalid bearer token returns 401.

        Only ``get_db`` is overridden here, so the application's real
        authentication dependency handles the token.
        """

        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        # Always remove the override, even if the assertion fails, so it
        # cannot leak into later tests.
        try:
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                headers = {"Authorization": "Bearer invalid_token"}

                response = await client.get(
                    "/api/v1/alerts/settings",
                    headers=headers
                )
                assert response.status_code == 401
        finally:
            app.dependency_overrides.clear()

    @pytest.mark.asyncio
    async def test_invalid_alert_type_returns_422(self, async_client, test_brand):
        """An unknown ``alert_type`` is rejected with 422."""
        create_data = {
            "brand_id": str(test_brand.id),
            "alert_type": "invalid_type",
            "enabled": True,
            "threshold": 5.0,
        }

        response = await async_client.post(
            "/api/v1/alerts/settings", json=create_data
        )

        assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_negative_threshold_returns_422(self, async_client, test_brand):
        """A negative ``threshold`` is rejected with 422."""
        create_data = {
            "brand_id": str(test_brand.id),
            "alert_type": "score_drop",
            "enabled": True,
            "threshold": -10.0,
        }

        response = await async_client.post(
            "/api/v1/alerts/settings", json=create_data
        )

        assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_threshold_over_100_returns_422(self, async_client, test_brand):
        """A ``threshold`` above 100 is rejected with 422."""
        create_data = {
            "brand_id": str(test_brand.id),
            "alert_type": "score_drop",
            "enabled": True,
            "threshold": 150.0,
        }

        response = await async_client.post(
            "/api/v1/alerts/settings", json=create_data
        )

        assert response.status_code == 422
|
||||||
|
|
@ -0,0 +1,451 @@
|
||||||
|
"""认证API测试"""
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
from app.main import app
|
||||||
|
from app.models.user import User
|
||||||
|
from app.api.deps import get_current_user, get_db
|
||||||
|
from app.services.auth import hash_password, create_access_token
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Fixtures ====================
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with the ``Base`` tables created.

    The engine is disposed after the consuming test has finished.
    """
    engine_kwargs = {
        "connect_args": {"check_same_thread": False},
        "poolclass": StaticPool,
    }
    test_engine = create_async_engine("sqlite+aiosqlite:///:memory:", **engine_kwargs)

    # Create the schema for every model registered on Base.metadata.
    async with test_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield test_engine

    await test_engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an AsyncSession bound to the test engine.

    The session is closed by the ``async with`` block when the test ends.
    """
    factory = async_sessionmaker(
        async_engine,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
        class_=AsyncSession,
    )

    async with factory() as db_session:
        yield db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def test_user(async_session):
    """Insert and return an active, e-mail-verified user on the free plan.

    The stored password hash corresponds to ``Test@123456``.
    """
    hashed = hash_password("Test@123456")
    account = User(
        id=uuid.uuid4(),
        email="test@example.com",
        password_hash=hashed,
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )

    async_session.add(account)
    await async_session.commit()
    # Reload the row so the returned object reflects the committed state.
    await async_session.refresh(account)
    return account
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an HTTP client wired to the app with DB and auth dependencies overridden.

    ``get_db`` yields the test session and ``get_current_user`` returns
    ``test_user``. All overrides are cleared after the test.
    """

    async def _session_dependency():
        yield async_session

    async def _user_dependency():
        return test_user

    app.dependency_overrides[get_db] = _session_dependency
    app.dependency_overrides[get_current_user] = _user_dependency

    asgi_transport = ASGITransport(app=app)
    async with AsyncClient(transport=asgi_transport, base_url="http://test") as http_client:
        yield http_client

    app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def auth_headers(test_user):
    """Return an Authorization header carrying an access token for ``test_user``."""
    claims = {"sub": str(test_user.id)}
    access_token = create_access_token(data=claims)
    return {"Authorization": f"Bearer {access_token}"}
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 测试类 ====================
|
||||||
|
|
||||||
|
class TestAuthAPI:
|
||||||
|
"""认证API测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_register_success(self, async_session):
    """Registering a fresh e-mail address returns 201 with the new user's data."""

    async def override_get_db():
        yield async_session

    app.dependency_overrides[get_db] = override_get_db
    # Always remove the override, even if an assertion fails, so it cannot
    # leak into later tests.
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/v1/auth/register",
                json={
                    # Random address so the e-mail is unique per run.
                    "email": f"test_{uuid.uuid4()}@example.com",
                    "name": "Test User",
                    "password": "Test@123456"
                }
            )

        assert response.status_code == 201
        data = response.json()
        assert "id" in data
        assert data["email"] is not None
        assert data["name"] == "Test User"
    finally:
        app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_register_duplicate_email(self, async_session):
    """Registering the same e-mail address twice fails with 400 on the second try."""
    email = f"test_{uuid.uuid4()}@example.com"

    async def override_get_db():
        yield async_session

    app.dependency_overrides[get_db] = override_get_db
    # Always remove the override, even if an assertion fails, so it cannot
    # leak into later tests.
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            # First registration succeeds.
            response1 = await client.post(
                "/api/v1/auth/register",
                json={
                    "email": email,
                    "name": "Test User 1",
                    "password": "Test@123456"
                }
            )
            assert response1.status_code == 201

            # Second registration with the same e-mail is rejected.
            response2 = await client.post(
                "/api/v1/auth/register",
                json={
                    "email": email,
                    "name": "Test User 2",
                    "password": "Test@123456"
                }
            )
            assert response2.status_code == 400
    finally:
        app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_login_success(self, async_session):
    """Logging in with correct credentials returns 200 with tokens and user data."""
    email = f"test_{uuid.uuid4()}@example.com"
    password = "Test@123456"

    # Seed the user directly in the database before logging in.
    user = User(
        id=uuid.uuid4(),
        email=email,
        password_hash=hash_password(password),
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    async_session.add(user)
    await async_session.commit()

    async def override_get_db():
        yield async_session

    app.dependency_overrides[get_db] = override_get_db
    # Always remove the override, even if an assertion fails, so it cannot
    # leak into later tests.
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/v1/auth/login",
                json={
                    "email": email,
                    "password": password
                }
            )

        assert response.status_code == 200
        data = response.json()
        assert "access_token" in data
        assert "refresh_token" in data
        assert data["token_type"] == "bearer"
        assert "user" in data
    finally:
        app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_login_wrong_password(self, async_session):
    """Logging in with a wrong password returns 401."""
    email = f"test_{uuid.uuid4()}@example.com"

    # Seed a user whose real password differs from the one sent below.
    user = User(
        id=uuid.uuid4(),
        email=email,
        password_hash=hash_password("Correct@123456"),
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    async_session.add(user)
    await async_session.commit()

    async def override_get_db():
        yield async_session

    app.dependency_overrides[get_db] = override_get_db
    # Always remove the override, even if the assertion fails, so it cannot
    # leak into later tests.
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/v1/auth/login",
                json={
                    "email": email,
                    "password": "WrongPassword"
                }
            )

        assert response.status_code == 401
    finally:
        app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_login_nonexistent_user(self, async_session):
    """Logging in as an unknown user is rejected (401, or 429 when rate limited)."""

    async def override_get_db():
        yield async_session

    app.dependency_overrides[get_db] = override_get_db
    # Always remove the override, even if the assertion fails, so it cannot
    # leak into later tests.
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/v1/auth/login",
                json={
                    "email": "nonexistent@example.com",
                    "password": "Test@123456"
                }
            )

        # The API uses one error message for "unknown user" and "wrong
        # password". Either 401 (authentication failed) or 429 (rate limit)
        # is accepted as a valid response.
        assert response.status_code in [401, 429]
    finally:
        app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_current_user(self, async_client, test_user):
    """``/auth/me`` returns the id and e-mail of the authenticated user."""
    resp = await async_client.get("/api/v1/auth/me")
    assert resp.status_code == 200

    profile = resp.json()
    expected_id = str(test_user.id)
    assert profile["id"] == expected_id
    assert profile["email"] == test_user.email
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_change_password_success(self, async_client, async_session, test_user):
    """Changing the password with the correct old password returns 200.

    NOTE(review): requesting ``async_client`` also installs the
    ``get_current_user`` override, so the bearer token sent below is
    presumably not what authenticates the request — confirm intent.
    """

    async def override_get_db():
        yield async_session

    app.dependency_overrides[get_db] = override_get_db
    # Always remove the overrides, even if an assertion fails, so they
    # cannot leak into later tests.
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            # Send an access token issued for the test user.
            token = create_access_token(data={"sub": str(test_user.id)})
            headers = {"Authorization": f"Bearer {token}"}

            response = await client.put(
                "/api/v1/auth/change-password",
                headers=headers,
                json={
                    # Matches the password hashed by the test_user fixture.
                    "old_password": "Test@123456",
                    "new_password": "NewPass@123456"
                }
            )

        assert response.status_code == 200
        data = response.json()
        assert data["message"] == "密码修改成功"
    finally:
        app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_change_password_wrong_old(self, async_client, test_user):
    """Changing the password with a wrong old password returns 400."""
    bearer = create_access_token(data={"sub": str(test_user.id)})
    body = {
        "old_password": "WrongOldPass",
        "new_password": "NewPass@123456",
    }

    resp = await async_client.put(
        "/api/v1/auth/change-password",
        headers={"Authorization": f"Bearer {bearer}"},
        json=body,
    )

    assert resp.status_code == 400
    detail = resp.json()["detail"]
    assert "旧密码错误" in detail
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_update_profile(self, async_client, test_user):
    """Updating the profile returns 200 and echoes the new name."""
    bearer = create_access_token(data={"sub": str(test_user.id)})
    body = {
        "name": "Updated Name",
        "avatar_url": "https://example.com/avatar.png",
    }

    resp = await async_client.put(
        "/api/v1/auth/profile",
        headers={"Authorization": f"Bearer {bearer}"},
        json=body,
    )

    assert resp.status_code == 200
    profile = resp.json()
    assert profile["name"] == "Updated Name"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_refresh_token(self, async_session, test_user):
    """Exchanging a valid refresh token returns 200 with a new token pair."""
    from app.services.auth import create_refresh_token

    async def override_get_db():
        yield async_session

    app.dependency_overrides[get_db] = override_get_db
    # Always remove the override, even if an assertion fails, so it cannot
    # leak into later tests.
    try:
        # Issue a refresh token for the test user.
        refresh_token = create_refresh_token(data={"sub": str(test_user.id)})

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/v1/auth/refresh",
                json={
                    "refresh_token": refresh_token
                }
            )

        assert response.status_code == 200
        data = response.json()
        assert "access_token" in data
        assert "refresh_token" in data
    finally:
        app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_forgot_password(self, async_session):
    """A forgot-password request returns 200 with a ``message`` field.

    The ``get_db`` override is cleared in a ``finally`` block so that a
    failing assertion cannot leak it into later tests.
    """
    async def override_get_db():
        yield async_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/v1/auth/forgot-password",
                json={
                    "email": "test@example.com"
                }
            )

        # Success is returned whether or not the email exists, to prevent
        # user enumeration.
        assert response.status_code == 200
        assert "message" in response.json()
    finally:
        app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_resend_verification(self, async_session):
    """A resend-verification request returns 200 with a ``message`` field.

    The ``get_db`` override is cleared in a ``finally`` block so that a
    failing assertion cannot leak it into later tests.
    """
    async def override_get_db():
        yield async_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/v1/auth/resend-verification",
                json={
                    "email": "test@example.com"
                }
            )

        assert response.status_code == 200
        assert "message" in response.json()
    finally:
        app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_unauthorized_access(self, async_session):
    """A request to a protected endpoint without a token is rejected with 401.

    The ``get_db`` override is cleared in a ``finally`` block so that a
    failing assertion cannot leak it into later tests.
    """
    async def override_get_db():
        yield async_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            # Call the protected endpoint without any Authorization header.
            response = await client.get("/api/v1/auth/me")

        assert response.status_code == 401
    finally:
        app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_invalid_token(self, async_session):
    """A request carrying a malformed bearer token is rejected with 401.

    The ``get_db`` override is cleared in a ``finally`` block so that a
    failing assertion cannot leak it into later tests.
    """
    async def override_get_db():
        yield async_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get(
                "/api/v1/auth/me",
                headers={"Authorization": "Bearer invalid_token"}
            )

        assert response.status_code == 401
    finally:
        app.dependency_overrides.clear()
|
||||||
|
|
@ -0,0 +1,321 @@
|
||||||
|
"""品牌API测试"""
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
from app.main import app
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.brand import Brand
|
||||||
|
from app.api.deps import get_current_user, get_db
|
||||||
|
from app.services.auth import hash_password, create_access_token
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Fixtures ====================
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with all model tables created.

    The engine is disposed once the consuming test has finished.
    """
    db_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    async with db_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
    yield db_engine
    await db_engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield a single ``AsyncSession`` bound to the test engine."""
    session_factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with session_factory() as db_session:
        yield db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified user on the free plan."""
    account = User(
        id=uuid.uuid4(),
        name="Test User",
        email="test@example.com",
        password_hash=hash_password("Test@123456"),
        plan="free",
        max_queries=5,
        email_verified=True,
        is_active=True,
    )
    async_session.add(account)
    await async_session.commit()
    await async_session.refresh(account)
    return account
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an active brand owned by ``test_user``."""
    owned_brand = Brand(
        id=uuid.uuid4(),
        user_id=test_user.id,
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        industry="technology",
        website="https://testbrand.com",
        platforms=["wenxin", "kimi"],
        status="active",
        frequency="weekly",
    )
    async_session.add(owned_brand)
    await async_session.commit()
    await async_session.refresh(owned_brand)
    return owned_brand
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an HTTP client wired to the app with DB and auth overridden.

    ``get_db`` yields the test session and ``get_current_user`` returns
    ``test_user``; all overrides are cleared once the client is closed.
    """

    async def _provide_session():
        yield async_session

    async def _provide_user():
        return test_user

    app.dependency_overrides.update(
        {
            get_db: _provide_session,
            get_current_user: _provide_user,
        }
    )

    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as http_client:
        yield http_client

    app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def auth_headers(test_user):
    """Return an ``Authorization`` header holding an access token for ``test_user``."""
    access_token = create_access_token(data={"sub": str(test_user.id)})
    headers = {"Authorization": f"Bearer {access_token}"}
    return headers
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 测试类 ====================
|
||||||
|
|
||||||
|
class TestBrandsAPI:
    """CRUD and authorization tests for the ``/api/v1/brands/`` endpoints."""

    @pytest.mark.asyncio
    async def test_list_brands_empty(self, async_client):
        """Listing brands for a user with none returns an empty page."""
        response = await async_client.get("/api/v1/brands/")

        assert response.status_code == 200
        data = response.json()
        assert data["items"] == []
        assert data["total"] == 0

    @pytest.mark.asyncio
    async def test_create_brand(self, async_client, async_session, test_user):
        """Creating a brand with every field echoes the data back with 201."""
        brand_data = {
            "name": "华为",
            "aliases": ["Huawei", "HW"],
            "website": "https://www.huawei.com",
            "industry": "technology",
            "platforms": ["wenxin", "kimi", "doubao"],
            "frequency": "daily",
        }
        response = await async_client.post("/api/v1/brands/", json=brand_data)

        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "华为"
        assert data["aliases"] == ["Huawei", "HW"]
        assert data["website"] == "https://www.huawei.com"
        assert data["industry"] == "technology"
        assert data["platforms"] == ["wenxin", "kimi", "doubao"]
        assert data["frequency"] == "daily"
        assert data["status"] == "active"
        assert data["user_id"] == str(test_user.id)
        assert "id" in data

    @pytest.mark.asyncio
    async def test_create_brand_minimal(self, async_client):
        """Creating a brand with only a name fills in the default fields."""
        brand_data = {
            "name": "minimal_brand",
        }
        response = await async_client.post("/api/v1/brands/", json=brand_data)

        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "minimal_brand"
        assert data["aliases"] == []
        assert data["platforms"] == ["wenxin", "kimi"]  # default value

    @pytest.mark.asyncio
    async def test_list_brands(self, async_client, async_session, test_user):
        """Listing returns every brand owned by the current user."""
        # Seed several brands directly through the session.
        async_session.add_all(
            [
                Brand(
                    user_id=test_user.id,
                    name=f"Brand {i}",
                    platforms=["wenxin"],
                )
                for i in range(3)
            ]
        )
        await async_session.commit()

        response = await async_client.get("/api/v1/brands/")

        assert response.status_code == 200
        data = response.json()
        assert len(data["items"]) == 3
        assert data["total"] == 3

    @pytest.mark.asyncio
    async def test_get_brand_by_id(self, async_client, test_brand):
        """Fetching a brand by id returns its stored fields."""
        response = await async_client.get(f"/api/v1/brands/{test_brand.id}/")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == str(test_brand.id)
        assert data["name"] == "Test Brand"
        assert data["aliases"] == ["TestBrand", "TB"]

    @pytest.mark.asyncio
    async def test_get_brand_not_found(self, async_client):
        """Fetching an unknown brand id returns 404 with the not-found detail."""
        non_existent_id = uuid.uuid4()
        response = await async_client.get(f"/api/v1/brands/{non_existent_id}/")

        assert response.status_code == 404
        assert "品牌不存在" in response.json()["detail"]

    @pytest.mark.asyncio
    async def test_update_brand(self, async_client, test_brand):
        """Updating aliases and frequency leaves the name untouched."""
        # Note: the BrandUpdate schema does not allow updating the name field.
        update_data = {
            "aliases": ["Updated", "Alias"],
            "frequency": "daily",
        }
        response = await async_client.put(
            f"/api/v1/brands/{test_brand.id}/", json=update_data
        )

        assert response.status_code == 200
        data = response.json()
        assert data["aliases"] == ["Updated", "Alias"]
        assert data["frequency"] == "daily"
        assert data["name"] == "Test Brand"  # name stays unchanged

    @pytest.mark.asyncio
    async def test_update_brand_partial(self, async_client, test_brand):
        """Updating a single field leaves every other field as it was."""
        update_data = {
            "frequency": "monthly",
        }
        response = await async_client.put(
            f"/api/v1/brands/{test_brand.id}/", json=update_data
        )

        assert response.status_code == 200
        data = response.json()
        # Only frequency changes; the remaining fields keep their values.
        assert data["frequency"] == "monthly"
        assert data["name"] == "Test Brand"
        assert data["aliases"] == ["TestBrand", "TB"]

    @pytest.mark.asyncio
    async def test_update_brand_not_found(self, async_client):
        """Updating an unknown brand id returns 404."""
        non_existent_id = uuid.uuid4()
        # Send a field the update endpoint accepts (see test_update_brand) so
        # the request can only fail on the missing brand, not on its payload.
        response = await async_client.put(
            f"/api/v1/brands/{non_existent_id}/", json={"frequency": "daily"}
        )

        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_delete_brand(self, async_client, test_brand):
        """Deleting a brand returns 204 and the brand can no longer be fetched."""
        response = await async_client.delete(f"/api/v1/brands/{test_brand.id}/")

        assert response.status_code == 204

        # Verify the brand is gone.
        get_response = await async_client.get(f"/api/v1/brands/{test_brand.id}/")
        assert get_response.status_code == 404

    @pytest.mark.asyncio
    async def test_delete_brand_not_found(self, async_client):
        """Deleting an unknown brand id returns 404."""
        non_existent_id = uuid.uuid4()
        response = await async_client.delete(f"/api/v1/brands/{non_existent_id}/")

        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_brand_unauthorized_access(self, async_session):
        """The brands API rejects an invalid bearer token with 401.

        Only ``get_db`` is overridden here, so the real authentication
        dependency runs. The override is cleared in a ``finally`` block so a
        failing assertion cannot leak it into later tests.
        """
        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        try:
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                # Use an invalid token.
                headers = {"Authorization": "Bearer invalid_token"}

                # Try to list brands.
                response = await client.get(
                    "/api/v1/brands/",
                    headers=headers
                )
                # An invalid token should yield 401.
                assert response.status_code == 401
        finally:
            app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class TestBrandsValidation:
    """Input-validation tests for brand creation."""

    @pytest.mark.asyncio
    async def test_create_brand_name_too_short(self, async_client):
        """A one-character brand name is rejected with a validation error."""
        # Minimum name length is 2, so a single character must fail.
        payload = {"name": "A"}

        response = await async_client.post("/api/v1/brands/", json=payload)

        assert response.status_code == 422  # validation error

    @pytest.mark.asyncio
    async def test_create_brand_invalid_frequency(self, async_client):
        """An unknown frequency value must not crash the API."""
        payload = {"name": "Valid Brand", "frequency": "invalid_frequency"}

        response = await async_client.post("/api/v1/brands/", json=payload)

        # BrandCreate does not strictly validate frequency, although it should
        # be a valid value; this mainly checks that the API does not crash.
        assert response.status_code in (201, 422)
|
||||||
|
|
@ -0,0 +1,459 @@
|
||||||
|
"""内容API测试"""
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
from app.main import app
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.brand import Brand
|
||||||
|
from app.api.deps import get_current_user, get_db
|
||||||
|
from app.services.auth import hash_password, create_access_token
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Fixtures ====================
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with all model tables created.

    The engine is disposed once the consuming test has finished.
    """
    db_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    async with db_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
    yield db_engine
    await db_engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield a single ``AsyncSession`` bound to the test engine."""
    session_factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with session_factory() as db_session:
        yield db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, verified user that has an organization id."""
    account = User(
        id=uuid.uuid4(),
        name="Test User",
        email="test@example.com",
        password_hash=hash_password("Test@123456"),
        plan="free",
        max_queries=5,
        email_verified=True,
        is_active=True,
        organization_id=uuid.uuid4(),  # the content API needs an organization_id
    )
    async_session.add(account)
    await async_session.commit()
    await async_session.refresh(account)
    return account
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an active brand owned by ``test_user``."""
    owned_brand = Brand(
        id=uuid.uuid4(),
        user_id=test_user.id,
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        industry="technology",
        website="https://testbrand.com",
        platforms=["wenxin", "kimi"],
        status="active",
        frequency="weekly",
    )
    async_session.add(owned_brand)
    await async_session.commit()
    await async_session.refresh(owned_brand)
    return owned_brand
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an HTTP client wired to the app with DB and auth overridden.

    ``get_db`` yields the test session and ``get_current_user`` returns
    ``test_user``; all overrides are cleared once the client is closed.
    """

    async def _provide_session():
        yield async_session

    async def _provide_user():
        return test_user

    app.dependency_overrides.update(
        {
            get_db: _provide_session,
            get_current_user: _provide_user,
        }
    )

    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as http_client:
        yield http_client

    app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def auth_headers(test_user):
    """Return an ``Authorization`` header holding an access token for ``test_user``."""
    access_token = create_access_token(data={"sub": str(test_user.id)})
    headers = {"Authorization": f"Bearer {access_token}"}
    return headers
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 测试类 ====================
|
||||||
|
|
||||||
|
class TestContentAPI:
    """CRUD, publish and filter tests for the ``/api/v1/contents/`` endpoints."""

    @pytest.mark.asyncio
    async def test_list_contents_empty(self, async_client):
        """Listing contents when none exist returns an empty list."""
        response = await async_client.get("/api/v1/contents/")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        assert len(data) == 0

    @pytest.mark.asyncio
    async def test_create_content(self, async_client, test_user):
        """Creating content returns 201 and the new item starts as a draft."""
        content_data = {
            "title": "测试文章标题",
            "body": "这是测试文章的内容",
            "content_type": "article",
            "tags": ["测试", "文章"],
        }
        response = await async_client.post("/api/v1/contents/", json=content_data)

        assert response.status_code == 201
        data = response.json()
        assert data["title"] == "测试文章标题"
        assert data["body"] == "这是测试文章的内容"
        assert data["content_type"] == "article"
        assert data["status"] == "draft"
        assert "id" in data

    @pytest.mark.asyncio
    async def test_get_content_by_id(self, async_client, test_user):
        """Fetching content by id returns the item that was created."""
        # Create the content first.
        create_response = await async_client.post(
            "/api/v1/contents/",
            json={
                "title": "测试内容",
                "body": "测试内容正文",
                "content_type": "article",
            }
        )
        content_id = create_response.json()["id"]

        # Fetch it back.
        response = await async_client.get(f"/api/v1/contents/{content_id}")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == content_id
        assert data["title"] == "测试内容"

    @pytest.mark.asyncio
    async def test_get_content_not_found(self, async_client):
        """Fetching an unknown content id returns 404."""
        non_existent_id = uuid.uuid4()
        response = await async_client.get(f"/api/v1/contents/{non_existent_id}")

        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_update_content(self, async_client, test_user):
        """Updating title and body returns the new values."""
        # Create the content first.
        create_response = await async_client.post(
            "/api/v1/contents/",
            json={
                "title": "原始标题",
                "body": "原始内容",
                "content_type": "article",
            }
        )
        content_id = create_response.json()["id"]

        # Update it.
        update_data = {
            "title": "更新后的标题",
            "body": "更新后的内容",
            "status": "published",
        }
        response = await async_client.put(
            f"/api/v1/contents/{content_id}", json=update_data
        )

        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "更新后的标题"
        assert data["body"] == "更新后的内容"

    @pytest.mark.asyncio
    async def test_update_content_partial(self, async_client, test_user):
        """Updating only the title leaves the body unchanged."""
        # Create the content first.
        create_response = await async_client.post(
            "/api/v1/contents/",
            json={
                "title": "原始标题",
                "body": "原始内容",
            }
        )
        content_id = create_response.json()["id"]

        # Update the title only.
        response = await async_client.put(
            f"/api/v1/contents/{content_id}",
            json={"title": "新标题"}
        )

        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "新标题"
        assert data["body"] == "原始内容"  # unchanged

    @pytest.mark.asyncio
    async def test_delete_content(self, async_client, test_user):
        """Deleting content returns 204 and the item can no longer be fetched."""
        # Create the content first.
        create_response = await async_client.post(
            "/api/v1/contents/",
            json={
                "title": "待删除内容",
                "body": "内容正文",
            }
        )
        content_id = create_response.json()["id"]

        # Delete it.
        response = await async_client.delete(f"/api/v1/contents/{content_id}")
        assert response.status_code == 204

        # Verify it is gone.
        get_response = await async_client.get(f"/api/v1/contents/{content_id}")
        assert get_response.status_code == 404

    @pytest.mark.asyncio
    async def test_publish_content(self, async_client, test_user):
        """Publishing sets the status to ``published`` and fills ``published_at``."""
        # Create the content first.
        create_response = await async_client.post(
            "/api/v1/contents/",
            json={
                "title": "待发布内容",
                "body": "内容正文",
            }
        )
        content_id = create_response.json()["id"]

        # Publish it.
        response = await async_client.post(f"/api/v1/contents/{content_id}/publish")

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "published"
        assert data["published_at"] is not None

    @pytest.mark.asyncio
    async def test_list_contents_with_filter(self, async_client, test_user):
        """Filtering by ``content_type`` returns only items of that type."""
        # Create several items of different types.
        for i, content_type in enumerate(["article", "post", "article"]):
            await async_client.post(
                "/api/v1/contents/",
                json={
                    "title": f"内容 {i}",
                    "body": "正文",
                    "content_type": content_type,
                }
            )

        # Filter by type.
        response = await async_client.get("/api/v1/contents/?content_type=article")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # Without this check the loop below would pass vacuously on an empty
        # result, hiding a filter that drops every row.
        assert data, "expected the article filter to return the created articles"
        for item in data:
            assert item["content_type"] == "article"

    @pytest.mark.asyncio
    async def test_list_contents_with_status_filter(self, async_client, test_user):
        """Filtering by ``status=draft`` returns only draft items."""
        # Create one draft and one item that is then published.
        await async_client.post(
            "/api/v1/contents/",
            json={"title": "草稿", "body": "正文", "content_type": "article"}
        )

        published_response = await async_client.post(
            "/api/v1/contents/",
            json={"title": "已发布", "body": "正文", "content_type": "article"}
        )
        await async_client.post(f"/api/v1/contents/{published_response.json()['id']}/publish")

        # Fetch drafts only.
        response = await async_client.get("/api/v1/contents/?status=draft")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # Guard against a vacuous pass: the draft created above must be listed.
        assert data, "expected the draft filter to return the created draft"
        for item in data:
            assert item["status"] == "draft"
|
||||||
|
|
||||||
|
|
||||||
|
class TestContentGenerationAPI:
    """Tests for the topic-library endpoints under ``/api/v1/content/topics``."""

    @pytest.mark.asyncio
    async def test_list_topics(self, async_client):
        """The topic list is non-empty and each entry has the expected keys."""
        response = await async_client.get("/api/v1/content/topics")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # There should be several topics.
        assert len(data) > 0

        # Check the structure of each topic.
        for topic in data:
            assert "id" in topic
            assert "name" in topic
            assert "description" in topic

    @pytest.mark.asyncio
    async def test_get_topic_detail(self, async_client):
        """Fetching a listed topic by id returns its detail with a prompt template."""
        # Take the first available topic id from the list endpoint.
        topics_response = await async_client.get("/api/v1/content/topics")
        assert topics_response.status_code == 200
        topics = topics_response.json()
        # Fail loudly rather than silently skipping every assertion when the
        # list comes back empty.
        assert topics, "expected at least one topic to be available"

        topic_id = topics[0]["id"]
        response = await async_client.get(f"/api/v1/content/topics/{topic_id}")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == topic_id
        assert "prompt_template" in data

    @pytest.mark.asyncio
    async def test_get_topic_not_found(self, async_client):
        """Fetching an unknown topic id returns 404."""
        response = await async_client.get("/api/v1/content/topics/nonexistent_topic")

        assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
class TestBrandKnowledgeAPI:
    """Tests for the brand knowledge endpoints under ``/api/v1/contents/knowledge/``."""

    @pytest.mark.asyncio
    async def test_list_brand_knowledge_empty(self, async_client):
        """Listing knowledge entries when none exist returns an empty list."""
        response = await async_client.get("/api/v1/contents/knowledge/")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        assert len(data) == 0

    @pytest.mark.asyncio
    async def test_create_brand_knowledge(self, async_client, test_user):
        """Creating a knowledge entry returns 201 and echoes its fields."""
        knowledge_data = {
            "category": "product",
            "title": "产品介绍",
            "body": "这是产品的详细介绍",
            "source": "官网",
        }
        response = await async_client.post(
            "/api/v1/contents/knowledge/", json=knowledge_data
        )

        assert response.status_code == 201
        data = response.json()
        assert data["title"] == "产品介绍"
        assert data["category"] == "product"
        assert data["body"] == "这是产品的详细介绍"
        assert data["source"] == "官网"
        assert "id" in data

    @pytest.mark.asyncio
    async def test_list_brand_knowledge_with_category(self, async_client, test_user):
        """Filtering by ``category`` returns only entries of that category."""
        # Create entries in different categories.
        categories = ["product", "technology", "product"]
        for i, cat in enumerate(categories):
            await async_client.post(
                "/api/v1/contents/knowledge/",
                json={"category": cat, "title": f"知识 {i}", "body": "正文"}
            )

        # Filter by the product category.
        response = await async_client.get("/api/v1/contents/knowledge/?category=product")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # Without this check the loop below would pass vacuously on an empty
        # result, hiding a filter that drops every row.
        assert data, "expected the product filter to return the created entries"
        for item in data:
            assert item["category"] == "product"

    @pytest.mark.asyncio
    async def test_update_brand_knowledge(self, async_client, test_user):
        """Updating a knowledge entry returns the new title and body."""
        # Create the entry first.
        create_response = await async_client.post(
            "/api/v1/contents/knowledge/",
            json={"category": "test", "title": "原始标题", "body": "原始内容"}
        )
        knowledge_id = create_response.json()["id"]

        # Update it.
        response = await async_client.put(
            f"/api/v1/contents/knowledge/{knowledge_id}",
            json={"title": "新标题", "body": "新内容"}
        )

        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "新标题"
        assert data["body"] == "新内容"

    @pytest.mark.asyncio
    async def test_delete_brand_knowledge(self, async_client, test_user):
        """Deleting a knowledge entry returns 204 and removes it from the list."""
        # Create the entry first.
        create_response = await async_client.post(
            "/api/v1/contents/knowledge/",
            json={"category": "test", "title": "待删除", "body": "正文"}
        )
        knowledge_id = create_response.json()["id"]

        # Delete it.
        response = await async_client.delete(f"/api/v1/contents/knowledge/{knowledge_id}")

        assert response.status_code == 204

        # Confirm the deletion by listing the remaining entries.
        list_response = await async_client.get("/api/v1/contents/knowledge/")
        knowledge_ids = [item["id"] for item in list_response.json()]
        assert knowledge_id not in knowledge_ids
|
||||||
|
|
@ -0,0 +1,221 @@
|
||||||
|
"""诊断API测试"""
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
from app.main import app
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.brand import Brand
|
||||||
|
from app.api.deps import get_current_user, get_db
|
||||||
|
from app.services.auth import hash_password
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with all model tables created.

    The engine is disposed once the consuming test has finished.
    """
    db_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    async with db_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
    yield db_engine
    await db_engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the test engine.

    The session factory disables expire-on-commit, autoflush and autocommit.
    """
    factory_options = {
        "class_": AsyncSession,
        "expire_on_commit": False,
        "autoflush": False,
        "autocommit": False,
    }
    make_session = async_sessionmaker(async_engine, **factory_options)
    async with make_session() as db_session:
        yield db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified user on the free plan."""
    account = User(
        id=uuid.uuid4(),
        email="test@example.com",
        password_hash=hash_password("Test@123456"),
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    async_session.add(account)
    await async_session.commit()
    # Reload the row so the returned object reflects what was stored.
    await async_session.refresh(account)
    return account
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an active brand owned by ``test_user``."""
    brand_fields = {
        "id": uuid.uuid4(),
        "user_id": test_user.id,
        "name": "Test Brand",
        "aliases": ["TestBrand", "TB"],
        "website": "https://testbrand.com",
        "industry": "technology",
        "platforms": ["wenxin", "kimi"],
        "frequency": "weekly",
        "status": "active",
    }
    record = Brand(**brand_fields)
    async_session.add(record)
    await async_session.commit()
    # Reload the row so the returned object reflects what was stored.
    await async_session.refresh(record)
    return record
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an HTTP client wired to the app with DB and auth overridden.

    ``get_db`` is replaced so requests use the test session, and
    ``get_current_user`` is replaced so every request is authenticated as
    ``test_user``. The overrides are cleared when the fixture is torn down.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user

    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # Always drop the overrides, even if the client context raises, so
        # they cannot leak into tests that expect the real dependencies.
        app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiagnosisAPI:
    """Tests for the SEO, GEO and combined diagnosis API endpoints."""

    @pytest.mark.asyncio
    async def test_seo_diagnosis_success(self, async_client, test_brand):
        """The SEO diagnosis endpoint returns 200 with the expected top-level keys."""
        response = await async_client.get(f"/api/v1/diagnosis/seo/{test_brand.id}")

        assert response.status_code == 200
        data = response.json()
        assert "overall_score" in data
        assert "health_level" in data
        assert "dimensions" in data
        assert "recommendations" in data
        assert isinstance(data["overall_score"], (int, float))
        assert isinstance(data["dimensions"], list)
        assert isinstance(data["recommendations"], list)

    @pytest.mark.asyncio
    async def test_geo_diagnosis_success(self, async_client, test_brand):
        """The GEO diagnosis endpoint returns 200 with the expected top-level keys."""
        response = await async_client.get(f"/api/v1/diagnosis/geo/{test_brand.id}")

        assert response.status_code == 200
        data = response.json()
        assert "overall_score" in data
        assert "health_level" in data
        assert "dimensions" in data
        assert "recommendations" in data
        assert isinstance(data["overall_score"], (int, float))
        assert isinstance(data["dimensions"], list)
        assert isinstance(data["recommendations"], list)

    @pytest.mark.asyncio
    async def test_combined_diagnosis_success(self, async_client, test_brand):
        """The combined endpoint returns 200 with SEO, GEO and combined scores."""
        response = await async_client.get(f"/api/v1/diagnosis/combined/{test_brand.id}")

        assert response.status_code == 200
        data = response.json()
        assert "seo_score" in data
        assert "geo_score" in data
        assert "combined_score" in data
        assert "seo_diagnosis" in data
        assert "geo_diagnosis" in data
        assert isinstance(data["seo_score"], (int, float))
        assert isinstance(data["geo_score"], (int, float))
        assert isinstance(data["combined_score"], (int, float))

    @pytest.mark.asyncio
    async def test_diagnosis_brand_not_found(self, async_client):
        """All three endpoints return 404 for a brand id that does not exist."""
        non_existent_id = uuid.uuid4()

        seo_response = await async_client.get(f"/api/v1/diagnosis/seo/{non_existent_id}")
        assert seo_response.status_code == 404

        geo_response = await async_client.get(f"/api/v1/diagnosis/geo/{non_existent_id}")
        assert geo_response.status_code == 404

        combined_response = await async_client.get(f"/api/v1/diagnosis/combined/{non_existent_id}")
        assert combined_response.status_code == 404

    @pytest.mark.asyncio
    async def test_diagnosis_unauthorized_access(self, async_session):
        """All three endpoints return 401 when called with an invalid bearer token.

        Only ``get_db`` is overridden here; ``get_current_user`` is left alone
        so the request goes through the application's own authentication.
        """

        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db

        try:
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                headers = {"Authorization": "Bearer invalid_token"}

                seo_response = await client.get(f"/api/v1/diagnosis/seo/{uuid.uuid4()}", headers=headers)
                assert seo_response.status_code == 401

                geo_response = await client.get(f"/api/v1/diagnosis/geo/{uuid.uuid4()}", headers=headers)
                assert geo_response.status_code == 401

                combined_response = await client.get(f"/api/v1/diagnosis/combined/{uuid.uuid4()}", headers=headers)
                assert combined_response.status_code == 401
        finally:
            # Clear the override even when an assertion above fails; otherwise
            # it would leak into whichever test runs next.
            app.dependency_overrides.clear()

    @pytest.mark.asyncio
    async def test_diagnosis_result_format(self, async_client, test_brand):
        """The SEO diagnosis payload has the expected nested structure and ranges."""
        response = await async_client.get(f"/api/v1/diagnosis/seo/{test_brand.id}")

        assert response.status_code == 200
        data = response.json()

        assert 0 <= data["overall_score"] <= 100
        assert data["health_level"] in ["excellent", "good", "pass", "danger"]

        # Every dimension carries its own score summary and a list of items.
        for dimension in data["dimensions"]:
            assert "name" in dimension
            assert "score" in dimension
            assert "max_score" in dimension
            assert "percentage" in dimension
            assert "status" in dimension
            assert "items" in dimension
            assert isinstance(dimension["items"], list)

            for item in dimension["items"]:
                assert "name" in item
                assert "status" in item
                assert "description" in item
                assert "suggestion" in item
                assert "score" in item

        for rec in data["recommendations"]:
            assert "priority" in rec
            assert "dimension" in rec
            assert "description" in rec
|
||||||
|
|
@ -0,0 +1,204 @@
|
||||||
|
"""健康检查API测试"""
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
from app.main import app
|
||||||
|
from app.api.deps import get_db
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Fixtures ====================
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with every model table created.

    The engine is disposed once the consuming test has finished.
    """
    memory_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    # Create the schema declared on Base before handing the engine out.
    async with memory_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
    yield memory_engine
    await memory_engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the test engine.

    The session factory disables expire-on-commit, autoflush and autocommit.
    """
    factory_options = {
        "class_": AsyncSession,
        "expire_on_commit": False,
        "autoflush": False,
        "autocommit": False,
    }
    make_session = async_sessionmaker(async_engine, **factory_options)
    async with make_session() as db_session:
        yield db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def async_client(async_session):
    """Yield an HTTP client wired to the app with ``get_db`` overridden.

    Requests made through this client use the test session. The override is
    cleared when the fixture is torn down.
    """

    async def override_get_db():
        yield async_session

    app.dependency_overrides[get_db] = override_get_db

    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # Always drop the override, even if the client context raises, so it
        # cannot leak into tests that expect the real dependency.
        app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 测试类 ====================
|
||||||
|
|
||||||
|
class TestHealthAPI:
    """Tests for the health, liveness, readiness and metrics endpoints."""

    @pytest.mark.asyncio
    async def test_basic_health(self):
        """``/health`` returns 200 with a healthy status and a timestamp."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get("/health")

            assert response.status_code == 200
            data = response.json()
            assert data["status"] == "healthy"
            assert "timestamp" in data

    @pytest.mark.asyncio
    async def test_liveness(self):
        """``/health/liveness`` returns 200 with status ``alive``."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get("/health/liveness")

            assert response.status_code == 200
            data = response.json()
            assert data["status"] == "alive"

    @pytest.mark.asyncio
    async def test_ready_endpoint(self):
        """``/ready`` reports a status plus database and redis checks."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get("/ready")

            # Either 200 or 503 is accepted: it depends on dependency state.
            assert response.status_code in [200, 503]
            data = response.json()
            assert "status" in data
            assert "checks" in data
            assert "database" in data["checks"]
            assert "redis" in data["checks"]

    @pytest.mark.asyncio
    async def test_metrics(self):
        """``/metrics`` returns 200 with a ``text/plain`` content type."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get("/metrics")

            assert response.status_code == 200
            # Use .get() with a default so a missing header produces a normal
            # assertion failure instead of a KeyError.
            assert "text/plain" in response.headers.get("content-type", "")

    @pytest.mark.asyncio
    async def test_detailed_health(self, async_client):
        """``/health/detailed`` returns 200 with ``checks`` and ``app`` sections."""
        response = await async_client.get("/health/detailed")

        assert response.status_code == 200
        data = response.json()
        # Verify the shape of the response.
        assert "checks" in data
        assert "app" in data

    @pytest.mark.asyncio
    async def test_readiness_probe(self, async_client):
        """``/health/readiness`` returns 200 or 503 with status and checks."""
        response = await async_client.get("/health/readiness")

        # 200 is expected when healthy, 503 when unhealthy.
        assert response.status_code in [200, 503]
        data = response.json()
        assert "status" in data
        assert "checks" in data
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthEndpointsStructure:
    """Structural tests for the health endpoint responses."""

    @pytest.mark.asyncio
    async def test_health_endpoint_returns_json(self):
        """``/health`` responds with a JSON content type."""
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get("/health")

            assert response.status_code == 200
            assert "application/json" in response.headers.get("content-type", "")

    @pytest.mark.asyncio
    async def test_detailed_health_has_required_fields(self, async_client):
        """``/health/detailed`` includes an ``app`` section that is a mapping."""
        response = await async_client.get("/health/detailed")

        assert response.status_code == 200
        data = response.json()

        # Require the field outright: guarding the check with
        # ``if "app" in data`` would let the test pass when it is missing.
        assert "app" in data
        assert isinstance(data["app"], dict)

    @pytest.mark.asyncio
    async def test_ready_checks_database_and_redis(self, async_client):
        """``/ready`` lists both a database check and a redis check."""
        response = await async_client.get("/ready")

        data = response.json()
        checks = data.get("checks", {})

        assert "database" in checks
        assert "redis" in checks
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthCheckIndependence:
    """Health endpoints must be reachable without any authentication."""

    @pytest.mark.asyncio
    async def test_health_no_auth_required(self):
        """``/health`` answers 200 when no token is sent."""
        asgi = ASGITransport(app=app)
        async with AsyncClient(transport=asgi, base_url="http://test") as anonymous:
            # No Authorization header is attached to this request.
            reply = await anonymous.get("/health")

            assert reply.status_code == 200

    @pytest.mark.asyncio
    async def test_liveness_no_auth_required(self):
        """``/health/liveness`` answers 200 when no token is sent."""
        asgi = ASGITransport(app=app)
        async with AsyncClient(transport=asgi, base_url="http://test") as anonymous:
            reply = await anonymous.get("/health/liveness")

            assert reply.status_code == 200

    @pytest.mark.asyncio
    async def test_metrics_no_auth_required(self):
        """``/metrics`` answers 200 when no token is sent."""
        asgi = ASGITransport(app=app)
        async with AsyncClient(transport=asgi, base_url="http://test") as anonymous:
            reply = await anonymous.get("/metrics")

            assert reply.status_code == 200
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
"""内容生成Pipeline测试包"""
|
||||||
|
|
@ -0,0 +1,207 @@
|
||||||
|
"""内容生成Pipeline测试"""
|
||||||
|
import pytest
|
||||||
|
from app.services.content.content_pipeline import ContentPipeline, PipelineResponse
|
||||||
|
|
||||||
|
|
||||||
|
class TestContentPipeline:
    """Tests for ``ContentPipeline.run`` and ``ContentPipeline.validate_only``."""

    @pytest.mark.asyncio
    async def test_pipeline_complete_run(self):
        """A run with all optimisations and formats yields stages and outputs."""
        payload = {
            "content": "这是一篇关于华为手机评测的深度文章,"
                       "详细介绍了华为Mate系列的特点和性能。",
            "title": "华为手机评测",
            "platform": "zhihu",
            "optimize_for": ["validation", "sensitive", "seo"],
            "output_formats": ["html", "markdown", "plain"],
            "keyword": "华为手机",
        }

        outcome = await ContentPipeline().run(payload)

        # Check the return type and that at least one stage was recorded.
        assert isinstance(outcome, PipelineResponse)
        assert outcome.stages is not None
        assert len(outcome.stages) > 0

        # Check that some output was produced.
        assert outcome.outputs is not None
        assert outcome.outputs.html is not None or outcome.outputs.markdown is not None

    @pytest.mark.asyncio
    async def test_pipeline_validation_only(self):
        """Requesting only validation produces a stage named ``validation``."""
        payload = {
            "content": "这是一篇测试文章内容",
            "title": "测试标题",
            "platform": "zhihu",
            "optimize_for": ["validation"],
        }

        outcome = await ContentPipeline().run(payload)

        assert outcome.stages is not None
        # Locate the first stage called "validation", if there is one.
        validation_stage = next(
            (stage for stage in outcome.stages if stage.name == "validation"),
            None,
        )
        assert validation_stage is not None

    @pytest.mark.asyncio
    async def test_pipeline_sensitive_filter_only(self):
        """Requesting only the sensitive-word filter records at least one stage."""
        payload = {
            "content": "这是一篇正常内容的文章",
            "title": "测试标题",
            "platform": "zhihu",
            "optimize_for": ["sensitive"],
        }

        outcome = await ContentPipeline().run(payload)

        assert outcome.stages is not None
        assert len(outcome.stages) >= 1

    @pytest.mark.asyncio
    async def test_pipeline_seo_only(self):
        """Requesting only SEO optimisation still returns a stage list."""
        payload = {
            "content": "华为手机是知名的手机品牌",
            "title": "华为手机评测",
            "platform": "zhihu",
            "optimize_for": ["seo"],
            "keyword": "华为手机",
        }

        outcome = await ContentPipeline().run(payload)

        assert outcome.stages is not None

    @pytest.mark.asyncio
    async def test_pipeline_with_validation_fail(self):
        """A title rejected for the wechat platform marks validation as failed."""
        payload = {
            "content": "内容",
            "title": "这个标题太长了超过了三十个字符的限制了哈哈哈啊",
            "platform": "wechat",
            "optimize_for": ["validation"],
        }

        outcome = await ContentPipeline().run(payload)

        # The validation stage must exist and report a failure.
        validation_stage = next(
            (stage for stage in outcome.stages if stage.name == "validation"),
            None,
        )
        assert validation_stage is not None
        assert validation_stage.passed is False

    @pytest.mark.asyncio
    async def test_pipeline_stage_timing(self):
        """Every stage has a name and a non-negative ``duration``."""
        payload = {
            "content": "这是一篇测试文章内容",
            "title": "测试标题",
            "platform": "zhihu",
            "optimize_for": ["validation", "sensitive", "seo"],
        }

        outcome = await ContentPipeline().run(payload)

        for stage in outcome.stages:
            assert stage.name is not None
            assert hasattr(stage, 'duration')
            assert stage.duration >= 0

    @pytest.mark.asyncio
    async def test_pipeline_multi_platform(self):
        """The same pipeline instance can be run once per supported platform."""
        pipeline = ContentPipeline()
        platforms = ["zhihu", "wechat", "xiaohongshu"]
        collected = []

        for platform in platforms:
            outcome = await pipeline.run({
                "content": "<p>测试内容</p>",
                "title": "测试标题",
                "platform": platform,
                "optimize_for": ["validation", "sensitive"],
            })
            collected.append(outcome)
            assert outcome.stages is not None

        # One result per platform must have been collected.
        assert len(collected) == len(platforms)

    @pytest.mark.asyncio
    async def test_pipeline_output_formats(self):
        """Requesting several output formats yields at least HTML output."""
        payload = {
            "content": "测试内容",
            "title": "测试标题",
            "platform": "zhihu",
            "optimize_for": [],
            "output_formats": ["html", "markdown", "plain"],
        }

        outcome = await ContentPipeline().run(payload)

        assert outcome.outputs is not None
        # HTML is expected to be generated by default.
        assert outcome.outputs.html is not None

    @pytest.mark.asyncio
    async def test_pipeline_empty_optimize_for(self):
        """An empty ``optimize_for`` list still produces outputs."""
        payload = {
            "content": "测试内容",
            "title": "测试标题",
            "platform": "zhihu",
            "optimize_for": [],
            "output_formats": ["html"],
        }

        outcome = await ContentPipeline().run(payload)

        # No optimisation stages were requested, but outputs must still exist.
        assert outcome.outputs is not None

    @pytest.mark.asyncio
    async def test_pipeline_error_handling(self):
        """An unknown platform is either reported in the result or raises ValueError."""
        pipeline = ContentPipeline()

        try:
            outcome = await pipeline.run({
                "content": "内容",
                "title": "标题",
                "platform": "invalid_platform",
            })
            # When no exception is raised, the result must carry an error or stages.
            assert outcome.error is not None or len(outcome.stages) > 0
        except ValueError as exc:
            assert "不支持的平台" in str(exc)

    @pytest.mark.asyncio
    async def test_pipeline_validate_only_method(self):
        """``validate_only`` returns an object exposing the validation attributes."""
        verdict = await ContentPipeline().validate_only(
            content="测试内容",
            title="测试标题",
            platform="zhihu",
        )

        assert verdict is not None
        for attribute in ('is_valid', 'score', 'issues', 'passed'):
            assert hasattr(verdict, attribute)
|
||||||
|
|
@ -0,0 +1,265 @@
|
||||||
|
"""HTML生成器测试"""
|
||||||
|
import pytest
|
||||||
|
from app.services.content.html_generator import HTMLGenerator
|
||||||
|
|
||||||
|
|
||||||
|
class TestHTMLGenerator:
    """Tests for ``HTMLGenerator.generate``, ``to_markdown`` and ``to_plain``."""

    def test_generate_basic_html(self):
        """``generate`` returns a string for simple paragraph content."""
        generator = HTMLGenerator()
        html = generator.generate(
            content="<p>这是测试内容</p>",
            platform="zhihu"
        )

        assert html is not None
        assert isinstance(html, str)

    def test_generate_for_zhihu(self):
        """Zhihu output keeps the heading text."""
        generator = HTMLGenerator()
        html = generator.generate(
            content="<h1>华为手机评测</h1><p>这是一篇关于华为手机的详细评测文章。</p>",
            platform="zhihu"
        )

        assert html is not None
        assert "华为手机评测" in html

    def test_generate_for_wechat(self):
        """``generate`` returns a result for the wechat platform."""
        generator = HTMLGenerator()
        html = generator.generate(
            content="<p>华为手机非常好用</p>",
            platform="wechat"
        )

        assert html is not None

    def test_generate_for_xiaohongshu(self):
        """``generate`` returns a result for the xiaohongshu platform."""
        generator = HTMLGenerator()
        html = generator.generate(
            content="<p>种草笔记内容</p>",
            platform="xiaohongshu"
        )

        assert html is not None

    def test_filter_banned_tags(self):
        """Script tags are stripped while normal content is kept."""
        generator = HTMLGenerator()
        html = generator.generate(
            content="<script>alert('xss')</script><p>正常内容</p>",
            platform="zhihu"
        )

        # The script tag must be removed.
        assert "<script>" not in html
        assert "正常内容" in html

    def test_filter_banned_tags_wechat_external_links(self):
        """External link targets are removed from wechat output."""
        generator = HTMLGenerator()
        html = generator.generate(
            content="<a href='http://baidu.com'>外部链接</a><p>内容</p>",
            platform="wechat"
        )

        # The external URL must not survive in wechat output.
        assert "http://baidu.com" not in html

    def test_filter_banned_tags_wechat_preserves_internal(self):
        """Content with an mp.weixin.qq.com link still produces wechat output."""
        generator = HTMLGenerator()
        html = generator.generate(
            content="<a href='https://mp.weixin.qq.com/s/test'>内部链接</a><p>内容</p>",
            platform="wechat"
        )

        # Only checks that a result is returned; the link itself is not inspected.
        assert html is not None

    def test_to_markdown_h1_conversion(self):
        """An ``<h1>`` becomes a level-1 Markdown heading."""
        generator = HTMLGenerator()
        md = generator.to_markdown("<h1>标题</h1>")

        assert "# 标题" in md

    def test_to_markdown_h2_conversion(self):
        """An ``<h2>`` becomes a level-2 Markdown heading."""
        generator = HTMLGenerator()
        md = generator.to_markdown("<h2>二级标题</h2>")

        assert "## 二级标题" in md

    def test_to_markdown_h3_conversion(self):
        """An ``<h3>`` becomes a level-3 Markdown heading."""
        generator = HTMLGenerator()
        md = generator.to_markdown("<h3>三级标题</h3>")

        assert "### 三级标题" in md

    def test_to_markdown_paragraph_conversion(self):
        """Paragraph text is preserved in Markdown output."""
        generator = HTMLGenerator()
        md = generator.to_markdown("<p>段落内容</p>")

        assert "段落内容" in md

    def test_to_markdown_br_conversion(self):
        """Text on both sides of a ``<br>`` is preserved in Markdown output."""
        generator = HTMLGenerator()
        md = generator.to_markdown("第一行<br>第二行")

        assert "第一行" in md
        assert "第二行" in md

    def test_to_markdown_list_conversion(self):
        """An ``<li>`` becomes a dash-prefixed Markdown list item."""
        generator = HTMLGenerator()
        md = generator.to_markdown("<li>列表项</li>")

        assert "- 列表项" in md

    def test_to_markdown_code_inline(self):
        """A ``<code>`` element becomes backtick-wrapped inline code."""
        generator = HTMLGenerator()
        md = generator.to_markdown("<code>代码</code>")

        assert "`代码`" in md

    def test_to_markdown_blockquote(self):
        """A ``<blockquote>`` becomes a ``>``-prefixed Markdown quote."""
        generator = HTMLGenerator()
        md = generator.to_markdown("<blockquote>引用内容</blockquote>")

        assert "> 引用内容" in md

    def test_to_markdown_pre_block(self):
        """A ``<pre>`` element becomes a fenced code block."""
        generator = HTMLGenerator()
        md = generator.to_markdown("<pre>代码块</pre>")

        assert "```" in md

    def test_to_markdown_strips_residual_tags(self):
        """Tags without a Markdown equivalent are stripped from the output."""
        generator = HTMLGenerator()
        md = generator.to_markdown("<div>内容</div>")

        # The div tag must be removed.
        assert "<div>" not in md

    def test_to_plain_text_basic(self):
        """Heading and paragraph text both survive conversion to plain text."""
        generator = HTMLGenerator()
        plain = generator.to_plain("<h1>标题</h1><p>段落</p>")

        assert "标题" in plain
        assert "段落" in plain

    def test_to_plain_text_removes_tags(self):
        """Plain-text output contains no tag markup."""
        generator = HTMLGenerator()
        plain = generator.to_plain("<script>alert(1)</script><p>内容</p>")

        assert "<script>" not in plain
        assert "<p>" not in plain

    def test_to_plain_text_decodes_html_entities(self):
        """HTML entities are decoded to their literal characters."""
        generator = HTMLGenerator()
        # The input must contain escaped entities (including &quot;) so that
        # decoding is what produces the characters asserted below.
        plain = generator.to_plain("&lt;&gt;&amp;&quot;")

        assert "<" in plain
        assert ">" in plain
        assert "&" in plain
        assert '"' in plain

    def test_to_plain_text_removes_extra_spaces(self):
        """Runs of consecutive spaces are collapsed in plain-text output."""
        generator = HTMLGenerator()
        # The input deliberately contains runs of two and three spaces.
        plain = generator.to_plain("内容  多个   空格")

        # No run of two or more spaces may remain.
        assert "  " not in plain

    def test_to_plain_text_removes_extra_newlines(self):
        """Three or more consecutive newlines are collapsed in plain-text output."""
        generator = HTMLGenerator()
        plain = generator.to_plain("内容\n\n\n换行")

        # There must be no run of more than two consecutive newlines.
        assert "\n\n\n" not in plain

    def test_generate_format_html(self):
        """``format="html"`` returns a result."""
        generator = HTMLGenerator()
        html = generator.generate(
            content="<p>内容</p>",
            platform="zhihu",
            format="html"
        )

        assert html is not None

    def test_generate_format_markdown(self):
        """``format="markdown"`` returns Markdown-converted content."""
        generator = HTMLGenerator()
        result = generator.generate(
            content="<h1>标题</h1>",
            platform="zhihu",
            format="markdown"
        )

        assert "# 标题" in result

    def test_generate_format_plain(self):
        """``format="plain"`` returns text with tags removed."""
        generator = HTMLGenerator()
        result = generator.generate(
            content="<p>内容</p>",
            platform="zhihu",
            format="plain"
        )

        assert "内容" in result
        assert "<p>" not in result

    def test_generate_invalid_platform(self):
        """An unknown platform still yields a non-None result."""
        generator = HTMLGenerator()
        html = generator.generate(
            content="<p>内容</p>",
            platform="invalid_platform"
        )

        # Only checks that something is returned for an unknown platform.
        assert html is not None

    def test_generate_with_empty_content(self):
        """Empty content generates an empty string."""
        generator = HTMLGenerator()
        html = generator.generate(
            content="",
            platform="zhihu"
        )

        assert html == ""

    def test_to_markdown_empty_content(self):
        """Empty content converts to an empty Markdown string."""
        generator = HTMLGenerator()
        md = generator.to_markdown("")

        assert md == ""

    def test_to_plain_empty_content(self):
        """Empty content converts to an empty plain-text string."""
        generator = HTMLGenerator()
        plain = generator.to_plain("")

        assert plain == ""
|
||||||
|
|
@ -0,0 +1,218 @@
|
||||||
|
"""规则校验器测试"""
|
||||||
|
import pytest
|
||||||
|
from app.services.content.rule_validator import (
|
||||||
|
RuleValidator,
|
||||||
|
ValidationIssue,
|
||||||
|
ValidationResult,
|
||||||
|
AI_Pattern
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRuleValidator:
    """Tests for ``RuleValidator``: platform rules, AI-pattern detection, tips."""

    def test_validate_title_length_pass(self):
        """A compliant title yields a valid result with a 'compliant' entry."""
        outcome = RuleValidator().validate(
            content="这是一篇关于AI医疗的深度分析文章...",
            title="AI医疗的发展趋势与未来展望",
            platform="zhihu",
        )
        assert outcome.is_valid is True
        markers = ("标题长度合规", "合规")
        assert any(marker in entry for entry in outcome.passed for marker in markers)

    def test_validate_title_length_fail(self):
        """An over-long title invalidates the result with a high-severity issue."""
        # NOTE(review): this title is 25 characters long although its text
        # claims to exceed thirty -- confirm it really exceeds the wechat
        # title limit enforced by RuleValidator.
        outcome = RuleValidator().validate(
            content="内容",
            title="这个标题太长了超过了三十个字符的限制了哈哈哈哈哈哈",
            platform="wechat",
        )
        assert outcome.is_valid is False
        severe = [issue for issue in outcome.issues if issue.severity == "high"]
        assert any("超过" in issue.message for issue in severe)

    def test_validate_content_length_pass(self):
        """Content of 1500 characters on zhihu scores at least 80."""
        outcome = RuleValidator().validate(
            content="A" * 1500,
            title="测试标题",
            platform="zhihu",
        )
        assert outcome.score >= 80

    def test_validate_content_length_fail(self):
        """Content of 30000 characters on wechat raises a high-severity issue."""
        outcome = RuleValidator().validate(
            content="A" * 30000,
            title="测试标题",
            platform="wechat",
        )
        severe = [issue for issue in outcome.issues if issue.severity == "high"]
        assert any("超过" in issue.message for issue in severe)

    def test_validate_zhihu_specific_rules(self):
        """A short professional answer on zhihu still gets a positive score."""
        outcome = RuleValidator().validate(
            content="这是一个专业回答",
            title="专业回答",
            platform="zhihu",
        )
        assert outcome.score > 0

    def test_validate_wechat_inducing_content(self):
        """Share-inducing wording on wechat is reported as an issue."""
        outcome = RuleValidator().validate(
            content="转发本文领取红包",
            title="限时优惠",
            platform="wechat",
        )
        # At least one issue message must mention inducement.
        messages = [issue.message for issue in outcome.issues]
        assert any("诱导" in text for text in messages)

    def test_validate_wechat_marketing_words(self):
        """Marketing wording on wechat is reported as an issue."""
        outcome = RuleValidator().validate(
            content="点击购买,限时优惠",
            title="优惠信息",
            platform="wechat",
        )
        # At least one issue message must mention marketing.
        messages = [issue.message for issue in outcome.issues]
        assert any("营销" in text for text in messages)

    def test_validate_xiaohongshu_cross_platform(self):
        """Cross-platform traffic diversion on xiaohongshu is reported."""
        outcome = RuleValidator().validate(
            content="微信公众号搜索xxx获取更多内容",
            title="种草笔记",
            platform="xiaohongshu",
        )
        # At least one issue message must mention traffic diversion.
        messages = [issue.message for issue in outcome.issues]
        assert any("引流" in text for text in messages)

    def test_validate_xiaohongshu_content_length(self):
        """Too-short xiaohongshu content yields an issue mentioning '300'."""
        outcome = RuleValidator().validate(
            content="A" * 100,
            title="笔记",
            platform="xiaohongshu",
        )
        messages = [issue.message for issue in outcome.issues]
        assert any("300" in text for text in messages)

    def test_detect_ai_patterns_banned_words(self):
        """Banned connective words are detected in zhihu content."""
        hits = RuleValidator().detect_ai_patterns(
            content="首先,其次,最后,总而言之,总之,总之",
            platform="zhihu",
        )
        assert len(hits) >= 1
        assert any(
            word in hit.pattern for hit in hits for word in ("首先", "总之")
        )

    def test_detect_ai_patterns_banned_structures(self):
        """An enumerated 'first / second / third' structure is detected."""
        hits = RuleValidator().detect_ai_patterns(
            content="第一,观点一。第二,观点二。第三,观点三。",
            platform="zhihu",
        )
        assert len(hits) >= 1

    def test_detect_ai_patterns_no_banned(self):
        """Ordinary content returns a list (only the type is checked here)."""
        hits = RuleValidator().detect_ai_patterns(
            content="这是一篇正常的人类写作内容",
            platform="zhihu",
        )
        assert isinstance(hits, list)

    def test_detect_ai_patterns_invalid_platform(self):
        """An unknown platform yields an empty list of patterns."""
        hits = RuleValidator().detect_ai_patterns(
            content="内容",
            platform="invalid_platform",
        )
        assert hits == []

    def test_get_optimization_tips(self):
        """Optimization tips for zhihu are returned as a list."""
        assert isinstance(RuleValidator().get_optimization_tips("zhihu"), list)

    def test_get_optimization_tips_invalid_platform(self):
        """An unknown platform yields an empty list of tips."""
        assert RuleValidator().get_optimization_tips("invalid_platform") == []

    def test_validation_result_structure(self):
        """``validate`` returns a fully typed ``ValidationResult``."""
        outcome = RuleValidator().validate(
            content="测试内容",
            title="测试标题",
            platform="zhihu",
        )

        assert isinstance(outcome, ValidationResult)
        assert isinstance(outcome.is_valid, bool)
        assert isinstance(outcome.score, int)
        assert isinstance(outcome.issues, list)
        assert isinstance(outcome.passed, list)

        allowed_severities = ("high", "medium", "low")
        for entry in outcome.issues:
            assert isinstance(entry, ValidationIssue)
            assert entry.severity in allowed_severities
            assert isinstance(entry.message, str)
            assert isinstance(entry.category, str)

    def test_validation_ai_pattern_structure(self):
        """Every detected pattern is a well-formed ``AI_Pattern``."""
        hits = RuleValidator().detect_ai_patterns(
            content="首先,其次,最后",
            platform="zhihu",
        )

        allowed_types = ("banned_word", "banned_structure")
        allowed_severities = ("medium", "high")
        # The loop is vacuous when nothing is detected; only shape is checked.
        for hit in hits:
            assert isinstance(hit, AI_Pattern)
            assert isinstance(hit.pattern, str)
            assert hit.type in allowed_types
            assert hit.severity in allowed_severities

    def test_validate_platform_rules(self):
        """The validation result exposes ``passed`` as a list."""
        outcome = RuleValidator().validate(
            content="华为手机非常好用",
            title="华为手机评测",
            platform="zhihu",
        )

        assert hasattr(outcome, 'passed')
        assert isinstance(outcome.passed, list)

    def test_validation_score_calculation(self):
        """Without high-severity issues the score is at least 80."""
        outcome = RuleValidator().validate(
            content="A" * 100,
            title="正常标题",
            platform="zhihu",
        )

        # The score is only checked when no high-severity issue was raised.
        has_severe = any(issue.severity == "high" for issue in outcome.issues)
        if not has_severe:
            assert outcome.score >= 80
|
||||||
|
|
@ -0,0 +1,192 @@
|
||||||
|
"""敏感词过滤器测试"""
|
||||||
|
import pytest
|
||||||
|
from app.services.content.sensitive_filter import (
|
||||||
|
SensitiveFilter,
|
||||||
|
FoundWord,
|
||||||
|
FilterResult
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSensitiveFilter:
    """Tests for ``SensitiveFilter``: detection, masking, and result shape."""

    def test_filter_politics_words(self):
        """Political terms are found, categorised, and masked."""
        outcome = SensitiveFilter().filter(
            content="这是一个关于台湾问题的分析",
            platform="zhihu",
        )

        assert isinstance(outcome, FilterResult)
        # Either the term is gone or a masked run is present.
        masked_text = outcome.filtered_content
        assert "台湾" not in masked_text or "**" in masked_text
        assert len(outcome.found_words) >= 1
        assert outcome.found_words[0].category == "politics"

    def test_filter_medical_words(self):
        """Filtering medical wording returns ``found_words`` as a list."""
        outcome = SensitiveFilter().filter(
            content="这个药品效果很好",
            platform="wechat",
        )

        assert outcome.found_words is not None
        assert isinstance(outcome.found_words, list)

    def test_filter_finance_words(self):
        """Filtering financial wording returns ``found_words`` as a list."""
        outcome = SensitiveFilter().filter(
            content="年化收益率10%",
            platform="zhihu",
        )

        assert outcome.found_words is not None
        assert isinstance(outcome.found_words, list)

    def test_filter_no_sensitive_words(self):
        """Clean content passes through unchanged with nothing recorded."""
        clean_text = "这是一个正常的产品介绍"
        outcome = SensitiveFilter().filter(content=clean_text, platform="zhihu")

        assert outcome.filtered_content == "这是一个正常的产品介绍"
        assert len(outcome.found_words) == 0
        assert len(outcome.replacements) == 0

    def test_filter_multiple_categories(self):
        """At least one category is reported for mixed-topic content."""
        outcome = SensitiveFilter().filter(
            content="这是内容包含政治和医疗敏感词的内容",
            platform="wechat",
        )

        distinct_categories = {found.category for found in outcome.found_words}
        assert len(distinct_categories) >= 1

    def test_add_custom_words(self):
        """Custom words registered at runtime are masked in later calls."""
        word_filter = SensitiveFilter()
        word_filter.add_custom_words("custom", ["自定义敏感词1", "自定义敏感词2"])

        outcome = word_filter.filter(
            content="这是一段包含自定义敏感词1的内容",
            platform="zhihu",
        )

        # The custom word must no longer appear verbatim.
        assert "自定义敏感词1" not in outcome.filtered_content

    def test_filter_result_structure(self):
        """``filter`` returns a fully typed ``FilterResult``."""
        outcome = SensitiveFilter().filter(content="测试内容", platform="zhihu")

        assert isinstance(outcome, FilterResult)
        assert isinstance(outcome.filtered_content, str)
        assert isinstance(outcome.found_words, list)
        assert isinstance(outcome.replacements, dict)

    def test_found_word_structure(self):
        """The first found word, when any, is a well-formed ``FoundWord``."""
        outcome = SensitiveFilter().filter(
            content="台湾是中华人民共和国不可分割的一部分",
            platform="zhihu",
        )

        # Only inspected when something was found; otherwise nothing is checked.
        if outcome.found_words:
            first = outcome.found_words[0]
            assert isinstance(first, FoundWord)
            assert isinstance(first.word, str)
            assert isinstance(first.category, str)
            assert isinstance(first.position, int)
            assert isinstance(first.replacement, str)

    def test_filter_replacement_mapping(self):
        """Each replacement is a same-length run of ``*`` characters."""
        outcome = SensitiveFilter().filter(
            content="这是一个关于台湾和西藏的内容",
            platform="zhihu",
        )

        # The mapping must contain the words that were replaced.
        assert len(outcome.replacements) >= 1
        for raw, masked in outcome.replacements.items():
            assert len(raw) == len(masked)
            assert masked == "*" * len(raw)

    def test_filter_various_platforms(self):
        """Filtering works for every supported platform identifier."""
        word_filter = SensitiveFilter()

        for platform_name in ("zhihu", "wechat", "xiaohongshu", "douyin"):
            outcome = word_filter.filter(content="测试内容", platform=platform_name)
            assert isinstance(outcome, FilterResult)
            assert outcome.filtered_content is not None

    def test_filter_empty_content(self):
        """Empty content stays empty and records no words."""
        outcome = SensitiveFilter().filter(content="", platform="zhihu")

        assert outcome.filtered_content == ""
        assert len(outcome.found_words) == 0

    def test_filter_replacement_char(self):
        """A custom ``replacement_char`` is used for masking."""
        word_filter = SensitiveFilter()
        word_filter.replacement_char = "#"

        outcome = word_filter.filter(
            content="台湾是中华人民共和国的一部分",
            platform="zhihu",
        )

        # Any recorded replacement must be built from '#'.
        assert all(
            masked == "#" * len(raw)
            for raw, masked in outcome.replacements.items()
        )

    def test_filter_with_zhihu_platform(self):
        """Ordinary zhihu content is returned unmodified."""
        outcome = SensitiveFilter().filter(
            content="这是一篇关于华为手机的专业评测",
            platform="zhihu",
        )

        assert outcome.filtered_content == "这是一篇关于华为手机的专业评测"

    def test_filter_found_words_positions(self):
        """Found words carry a non-negative position and a visible mask."""
        outcome = SensitiveFilter().filter(
            content="台湾位于中国大陆东南沿海",
            platform="zhihu",
        )

        for found in outcome.found_words:
            assert found.position >= 0
            assert found.replacement in outcome.filtered_content
|
||||||
|
|
@ -0,0 +1,211 @@
|
||||||
|
"""SEO优化器测试"""
|
||||||
|
import pytest
|
||||||
|
from app.services.content.seo_optimizer import SEOOptimizer, OptimizationResult
|
||||||
|
|
||||||
|
|
||||||
|
class TestSEOOptimizer:
    """Tests for ``SEOOptimizer``: keyword density and optimisation output."""

    def test_get_keyword_density_basic(self):
        """Density is a positive float when the keyword fills the content."""
        value = SEOOptimizer().get_keyword_density("华为手机华为手机华为手机", "华为手机")

        assert value > 0
        assert isinstance(value, float)

    def test_get_keyword_density_zero_content(self):
        """Empty content has a density of 0.0."""
        assert SEOOptimizer().get_keyword_density("", "华为") == 0.0

    def test_get_keyword_density_zero_keyword(self):
        """An empty keyword has a density of 0.0."""
        assert SEOOptimizer().get_keyword_density("内容", "") == 0.0

    def test_get_keyword_density_no_match(self):
        """A keyword absent from the content has a density of 0.0."""
        article = "这是一篇关于苹果的文章"
        assert SEOOptimizer().get_keyword_density(article, "华为") == 0.0

    def test_optimize_with_keyword(self):
        """Optimising with a keyword returns a typed ``OptimizationResult``."""
        outcome = SEOOptimizer().optimize(
            content="华为手机是知名的手机品牌,华为手机性能出色。",
            title="华为手机评测",
            platform="zhihu",
            keyword="华为手机",
        )

        assert isinstance(outcome, OptimizationResult)
        assert isinstance(outcome.density, float)
        assert isinstance(outcome.suggestions, list)
        assert isinstance(outcome.tips, list)

    def test_optimize_without_keyword(self):
        """Optimising without a keyword reports zero density."""
        outcome = SEOOptimizer().optimize(
            content="这是一篇关于手机评测的文章",
            title="手机评测",
            platform="zhihu",
        )

        assert outcome.optimized_content is not None
        assert outcome.density == 0.0

    def test_optimize_density_too_low(self):
        """A density below 1.0 must come with at least one suggestion."""
        outcome = SEOOptimizer().optimize(
            content="这是一篇简短的内容",
            title="标题",
            platform="zhihu",
            keyword="华为手机",
        )

        # Suggestions are only required when the density is actually low.
        if outcome.density < 1.0:
            assert len(outcome.suggestions) >= 1

    def test_optimize_density_too_high(self):
        """A density above 3.0 must produce an 'exceeds' suggestion."""
        outcome = SEOOptimizer().optimize(
            content="华为华为华为华为华为华为华为华为华为华为",
            title="华为",
            platform="zhihu",
            keyword="华为",
        )

        # The suggestion is only required when the density is actually high.
        if outcome.density > 3.0:
            assert any("超过" in hint for hint in outcome.suggestions)

    def test_optimize_keyword_in_title(self):
        """Suggestions are a list when the keyword appears in the title."""
        outcome = SEOOptimizer().optimize(
            content="手机内容",
            title="华为手机评测",
            platform="zhihu",
            keyword="华为手机",
        )

        assert isinstance(outcome.suggestions, list)

    def test_optimize_keyword_in_first_paragraph(self):
        """Suggestions are a list when the keyword opens the content."""
        outcome = SEOOptimizer().optimize(
            content="华为手机是本文的主题...",
            title="评测",
            platform="zhihu",
            keyword="华为手机",
        )

        assert isinstance(outcome.suggestions, list)

    def test_optimize_different_platforms(self):
        """Optimisation returns a result for every listed platform."""
        seo = SEOOptimizer()

        for platform_name in ("zhihu", "wechat", "xiaohongshu", "baijiahao"):
            outcome = seo.optimize(
                content="测试内容",
                title="测试标题",
                platform=platform_name,
                keyword="测试",
            )
            assert isinstance(outcome, OptimizationResult)

    def test_optimization_result_structure(self):
        """Every field of ``OptimizationResult`` has the expected type."""
        outcome = SEOOptimizer().optimize(
            content="华为手机是知名品牌",
            title="华为评测",
            platform="zhihu",
            keyword="华为",
        )

        assert isinstance(outcome, OptimizationResult)
        assert isinstance(outcome.optimized_content, str)
        assert isinstance(outcome.density, float)
        assert isinstance(outcome.suggestions, list)
        assert isinstance(outcome.tips, list)

    def test_optimize_returns_suggestions(self):
        """``suggestions`` is always a list."""
        outcome = SEOOptimizer().optimize(
            content="内容过短",
            title="标题",
            platform="zhihu",
            keyword="关键词",
        )

        assert isinstance(outcome.suggestions, list)

    def test_optimize_returns_tips(self):
        """``tips`` is always a list."""
        outcome = SEOOptimizer().optimize(
            content="测试内容",
            title="标题",
            platform="zhihu",
        )

        assert isinstance(outcome.tips, list)

    def test_density_calculation_accuracy(self):
        """Density matches keyword characters / total characters * 100."""
        seo = SEOOptimizer()
        # "华为" is 2 characters and occurs twice in the 6-character content,
        # so the expected density is 4 / 6 * 100.
        article = "华为手机华为"
        term = "华为"
        occurrences = article.count(term)
        expected = (len(term) * occurrences) / len(article) * 100

        measured = seo.get_keyword_density(article, term)

        assert abs(measured - round(expected, 2)) < 0.1

    def test_optimize_empty_content(self):
        """Empty content yields zero density even with a keyword."""
        outcome = SEOOptimizer().optimize(
            content="",
            title="标题",
            platform="zhihu",
            keyword="关键词",
        )

        assert outcome.density == 0.0

    def test_optimize_invalid_platform(self):
        """An unknown platform still returns list-typed suggestions and tips."""
        outcome = SEOOptimizer().optimize(
            content="测试内容",
            title="标题",
            platform="invalid_platform",
            keyword="关键词",
        )

        # Only the types are checked; emptiness is not asserted here.
        assert isinstance(outcome.suggestions, list)
        assert isinstance(outcome.tips, list)
|
||||||
|
|
@ -0,0 +1,496 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch, MagicMock
|
||||||
|
from app.services.email_service import (
|
||||||
|
EmailService,
|
||||||
|
EmailMessage,
|
||||||
|
EmailSendResult,
|
||||||
|
EMAIL_TEMPLATES,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmailMessage:
    """Tests for the ``EmailMessage`` data structure."""

    def test_email_message_creation(self):
        """Required fields are stored; attachments and metadata default empty."""
        message = EmailMessage(
            to="test@example.com",
            subject="测试邮件",
            body_html="<h1>测试</h1>",
            body_text="测试邮件",
        )
        assert message.to == "test@example.com"
        assert message.subject == "测试邮件"
        assert message.body_html == "<h1>测试</h1>"
        assert message.body_text == "测试邮件"
        assert message.attachments == []
        assert message.metadata == {}

    def test_email_message_with_attachments(self):
        """Attachments passed to the constructor are kept as given."""
        pdf = {"filename": "report.pdf", "content": b"data"}
        message = EmailMessage(
            to="test@example.com",
            subject="带附件",
            body_html="<p>内容</p>",
            body_text="内容",
            attachments=[pdf],
        )
        assert len(message.attachments) == 1
        assert message.attachments[0]["filename"] == "report.pdf"

    def test_email_message_with_metadata(self):
        """Metadata passed to the constructor is retrievable by key."""
        extra = {"brand_name": "test_brand", "alert_type": "error"}
        message = EmailMessage(
            to="test@example.com",
            subject="测试",
            body_html="<p>内容</p>",
            body_text="内容",
            metadata=extra,
        )
        assert message.metadata["brand_name"] == "test_brand"
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmailSendResult:
    """Tests for the ``EmailSendResult`` data structure."""

    def test_email_send_result_success(self):
        """A successful result keeps its message id and has no error."""
        outcome = EmailSendResult(
            success=True,
            message_id="msg_123",
            error=None,
            retry_count=0,
        )
        assert outcome.success is True
        assert outcome.message_id == "msg_123"
        assert outcome.error is None
        assert outcome.retry_count == 0

    def test_email_send_result_failure(self):
        """A failed result keeps its error text and retry count."""
        outcome = EmailSendResult(
            success=False,
            message_id=None,
            error="SMTP连接失败",
            retry_count=3,
        )
        assert outcome.success is False
        assert outcome.message_id is None
        assert outcome.error == "SMTP连接失败"
        assert outcome.retry_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmailTemplateRendering:
    """Tests for ``EmailService.render_template``."""

    @pytest.fixture
    def email_service(self):
        """Provide an ``EmailService`` built with default arguments."""
        return EmailService()

    def test_render_alert_notification_template(self, email_service):
        """The alert template fills the subject and HTML body from variables."""
        context = {
            "alert_type": "系统错误",
            "brand_name": "测试品牌",
            "severity": "高",
            "description": "数据库连接失败",
            "timestamp": "2024-01-01 12:00:00",
        }
        rendered = email_service.render_template(
            "alert_notification", "admin@example.com", context
        )

        assert rendered.to == "admin@example.com"
        assert "[GEO平台] 告警通知:系统错误" in rendered.subject
        for fragment in ("测试品牌", "系统错误", "高"):
            assert fragment in rendered.body_html

    def test_render_quota_warning_template(self, email_service):
        """The quota template fills the subject and usage numbers."""
        context = {
            "quota_type": "API调用",
            "usage_percentage": 85,
            "used": 850,
            "limit": 1000,
            "recommended_action": "请升级套餐",
        }
        rendered = email_service.render_template(
            "quota_warning", "user@example.com", context
        )

        assert rendered.to == "user@example.com"
        assert "[GEO平台] 额度预警:API调用" in rendered.subject
        for fragment in ("85%", "850", "1000"):
            assert fragment in rendered.body_html

    def test_render_template_missing_variables(self, email_service):
        """Rendering with a partial variable set still returns a message."""
        context = {"alert_type": "系统错误"}
        rendered = email_service.render_template(
            "alert_notification", "admin@example.com", context
        )

        assert rendered is not None
        assert rendered.to == "admin@example.com"

    def test_render_template_invalid_template(self, email_service):
        """An unknown template name raises ``ValueError``."""
        with pytest.raises(ValueError, match="模板不存在"):
            email_service.render_template(
                "invalid_template",
                "admin@example.com",
                {"key": "value"},
            )

    def test_render_template_variable_substitution(self, email_service):
        """Every supplied variable value appears in the rendered HTML."""
        context = {
            "brand_name": "品牌A",
            "alert_type": "告警B",
            "severity": "严重",
            "description": "描述C",
            "timestamp": "时间D",
        }
        rendered = email_service.render_template(
            "alert_notification", "test@example.com", context
        )

        for fragment in ("品牌A", "告警B", "严重", "描述C", "时间D"):
            assert fragment in rendered.body_html
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmailGeneration:
    """Tests for the email-generation helpers on ``EmailService``."""

    @pytest.fixture
    def email_service(self):
        """Provide an ``EmailService`` built with default arguments."""
        return EmailService()

    def test_generate_alert_notification_email(self, email_service):
        """The alert email carries the alert data in subject and bodies."""
        alert_fields = {
            "to": "admin@example.com",
            "alert_type": "数据库告警",
            "brand_name": "测试品牌",
            "severity": "高",
            "description": "数据库CPU使用率超过90%",
            "timestamp": "2024-01-01 12:00:00",
        }
        message = email_service.generate_alert_email(**alert_fields)

        assert message.to == "admin@example.com"
        assert "数据库告警" in message.subject
        assert "测试品牌" in message.body_html
        assert "高" in message.body_html
        assert message.body_text != ""

    def test_generate_quota_warning_email(self, email_service):
        """The quota email carries the usage numbers in the HTML body."""
        quota_fields = {
            "to": "user@example.com",
            "quota_type": "API调用",
            "usage_percentage": 85,
            "used": 850,
            "limit": 1000,
            "recommended_action": "建议升级套餐",
        }
        message = email_service.generate_quota_warning_email(**quota_fields)

        assert message.to == "user@example.com"
        assert "API调用" in message.subject
        for fragment in ("85%", "850", "1000"):
            assert fragment in message.body_html

    def test_generate_email_has_both_formats(self, email_service):
        """Generated emails have an HTML body and a markup-free text body."""
        message = email_service.generate_alert_email(
            to="test@example.com",
            alert_type="测试",
            brand_name="品牌",
            severity="中",
            description="描述",
            timestamp="时间",
        )

        assert message.body_html != ""
        assert message.body_text != ""
        # The HTML body contains tags; the text body must not.
        assert "<" in message.body_html
        assert "<" not in message.body_text
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmailSending:
    """Tests for ``EmailService.send_email`` in simulated and SMTP modes."""

    @pytest.fixture
    def email_service(self):
        """Provide an ``EmailService`` running in simulate mode."""
        return EmailService(simulate_mode=True)

    @staticmethod
    def _smtp_service(**extra):
        """Build a non-simulated service with the shared SMTP test settings."""
        return EmailService(
            simulate_mode=False,
            smtp_host="smtp.example.com",
            smtp_port=587,
            smtp_user="user@example.com",
            smtp_password="password",
            **extra,
        )

    @staticmethod
    def _message(recipient, html="<p>内容</p>", text="内容"):
        """Build a minimal message addressed to ``recipient``."""
        return EmailMessage(
            to=recipient,
            subject="测试",
            body_html=html,
            body_text=text,
        )

    def test_send_email_simulate_mode(self, email_service):
        """Simulate mode reports success with a message id and no error."""
        outgoing = self._message("test@example.com", html="<p>测试</p>", text="测试")
        outcome = email_service.send_email(outgoing)

        assert outcome.success is True
        assert outcome.message_id is not None
        assert outcome.error is None

    @patch("smtplib.SMTP")
    def test_send_email_real_smtp(self, mock_smtp):
        """SMTP mode connects, upgrades to TLS, and logs in (SMTP is mocked)."""
        fake_server = MagicMock()
        mock_smtp.return_value = fake_server

        service = self._smtp_service()
        outcome = service.send_email(self._message("recipient@example.com"))

        assert outcome.success is True
        mock_smtp.assert_called_once_with("smtp.example.com", 587)
        fake_server.starttls.assert_called_once()
        fake_server.login.assert_called_once_with("user@example.com", "password")

    @patch("smtplib.SMTP")
    def test_send_email_smtp_failure(self, mock_smtp):
        """A failing SMTP constructor is reported in the result's error."""
        mock_smtp.side_effect = Exception("连接失败")

        service = self._smtp_service()
        outcome = service.send_email(self._message("test@example.com"))

        assert outcome.success is False
        assert outcome.error is not None
        assert "连接失败" in outcome.error

    @patch("smtplib.SMTP")
    def test_send_email_with_retry(self, mock_smtp):
        """With ``max_retries=3`` a first-attempt success records no retries."""
        # NOTE(review): no failure is injected, so this only covers the
        # no-retry path -- confirm whether a failing-then-succeeding case
        # should be added.
        fake_server = MagicMock()
        mock_smtp.return_value = fake_server

        service = self._smtp_service(max_retries=3)
        outcome = service.send_email(self._message("test@example.com"))

        assert outcome.success is True
        assert outcome.retry_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmailQueue:
    """Tests for the queueing API of ``EmailService``."""

    @pytest.fixture
    def email_service(self):
        """Provide an ``EmailService`` running in simulate mode."""
        return EmailService(simulate_mode=True)

    @staticmethod
    def _numbered_messages(count):
        """Build ``count`` messages addressed to user0, user1, ..."""
        return [
            EmailMessage(
                to=f"user{i}@example.com",
                subject=f"测试{i}",
                body_html="<p>内容</p>",
                body_text="内容",
            )
            for i in range(count)
        ]

    def test_add_to_queue(self, email_service):
        """A queued message is visible through ``get_queue``."""
        outgoing = EmailMessage(
            to="test@example.com",
            subject="测试",
            body_html="<p>内容</p>",
            body_text="内容",
        )

        email_service.add_to_queue(outgoing)

        assert len(email_service.get_queue()) == 1
        assert email_service.get_queue()[0].to == "test@example.com"

    def test_add_multiple_to_queue(self, email_service):
        """Five queued messages give a queue of length five."""
        for outgoing in self._numbered_messages(5):
            email_service.add_to_queue(outgoing)

        assert len(email_service.get_queue()) == 5

    def test_send_queue(self, email_service):
        """Sending the queue returns one success per message and empties it."""
        for outgoing in self._numbered_messages(3):
            email_service.add_to_queue(outgoing)

        outcomes = email_service.send_queue()

        assert len(outcomes) == 3
        assert all(outcome.success for outcome in outcomes)
        assert len(email_service.get_queue()) == 0

    def test_send_queue_empty(self, email_service):
        """Sending an empty queue returns an empty list."""
        assert email_service.send_queue() == []

    def test_clear_queue(self, email_service):
        """``clear_queue`` removes every queued message."""
        outgoing = EmailMessage(
            to="test@example.com",
            subject="测试",
            body_html="<p>内容</p>",
            body_text="内容",
        )
        email_service.add_to_queue(outgoing)
        assert len(email_service.get_queue()) == 1

        email_service.clear_queue()

        assert len(email_service.get_queue()) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmailAttachments:
    """Tests for attaching files to an email message."""

    @pytest.fixture
    def email_service(self):
        """Provide an EmailService constructed with simulate_mode=True."""
        service = EmailService(simulate_mode=True)
        return service

    def test_add_attachment_to_message(self, email_service):
        """One attachment is recorded with its filename and raw content."""
        fields = {
            "to": "test@example.com",
            "subject": "带附件",
            "body_html": "<p>内容</p>",
            "body_text": "内容",
        }
        message = EmailMessage(**fields)

        email_service.add_attachment(message, "report.pdf", b"PDF content")

        assert len(message.attachments) == 1
        attached = message.attachments[0]
        assert attached["filename"] == "report.pdf"
        assert message.attachments[0]["content"] == b"PDF content"

    def test_add_multiple_attachments(self, email_service):
        """Two attachments are kept in the order they were added."""
        fields = {
            "to": "test@example.com",
            "subject": "多附件",
            "body_html": "<p>内容</p>",
            "body_text": "内容",
        }
        message = EmailMessage(**fields)

        for filename, payload in (("file1.pdf", b"content1"), ("file2.xlsx", b"content2")):
            email_service.add_attachment(message, filename, payload)

        assert len(message.attachments) == 2
        assert message.attachments[0]["filename"] == "file1.pdf"
        assert message.attachments[1]["filename"] == "file2.xlsx"
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmailValidation:
    """Tests for email address validation."""

    @pytest.fixture
    def email_service(self):
        """Provide an EmailService built with its default arguments."""
        service = EmailService()
        return service

    def test_validate_valid_email(self, email_service):
        """Well-formed addresses are accepted."""
        accepted = (
            "test@example.com",
            "user.name@domain.org",
            "user+tag@example.com",
        )
        for address in accepted:
            assert email_service.validate_email(address) is True

    def test_validate_invalid_email(self, email_service):
        """Malformed addresses, including the empty string, are rejected."""
        rejected = (
            "invalid",
            "invalid@",
            "@example.com",
            "test@",
            "",
        )
        for address in rejected:
            assert email_service.validate_email(address) is False

    def test_send_to_invalid_email(self, email_service):
        """Sending to a malformed recipient reports failure with an error set."""
        fields = {
            "to": "invalid_email",
            "subject": "测试",
            "body_html": "<p>内容</p>",
            "body_text": "内容",
        }
        message = EmailMessage(**fields)

        outcome = email_service.send_email(message)

        assert outcome.success is False
        assert outcome.error is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmailTemplatesConstants:
    """Tests for the EMAIL_TEMPLATES constant."""

    def test_email_templates_exist(self):
        """Both expected template names are registered."""
        for template_name in ("alert_notification", "quota_warning"):
            assert template_name in EMAIL_TEMPLATES

    def test_alert_notification_template_structure(self):
        """The alert template has all three parts and its placeholders."""
        alert_template = EMAIL_TEMPLATES["alert_notification"]

        for part in ("subject", "body_html", "body_text"):
            assert part in alert_template

        subject_text = alert_template["subject"]
        html_text = alert_template["body_html"]
        assert "{alert_type}" in subject_text
        assert "{brand_name}" in html_text

    def test_quota_warning_template_structure(self):
        """The quota-warning template has all three parts and its placeholders."""
        quota_template = EMAIL_TEMPLATES["quota_warning"]

        for part in ("subject", "body_html", "body_text"):
            assert part in quota_template

        subject_text = quota_template["subject"]
        html_text = quota_template["body_html"]
        assert "{quota_type}" in subject_text
        assert "{usage_percentage}" in html_text
|
||||||
|
|
@ -0,0 +1,605 @@
|
||||||
|
"""
|
||||||
|
GEO诊断服务单元测试
|
||||||
|
|
||||||
|
测试6大维度诊断逻辑、评分算法、推荐生成和服务类
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from app.services.geo_diagnosis import (
|
||||||
|
GEODiagnosisService,
|
||||||
|
GEODiagnosisInput,
|
||||||
|
diagnose_content_extractability,
|
||||||
|
diagnose_entity_clarity,
|
||||||
|
diagnose_eeat_signals,
|
||||||
|
diagnose_schema_markup,
|
||||||
|
diagnose_topic_authority,
|
||||||
|
diagnose_citation_readiness,
|
||||||
|
generate_recommendations,
|
||||||
|
get_health_level,
|
||||||
|
get_health_level_label,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestContentExtractability:
    """Tests for the content-extractability dimension."""

    def test_all_pass(self):
        """All checks enabled with a 10-day-old update give the full 20 points."""
        flags = dict.fromkeys(
            (
                "has_direct_answer",
                "has_qa_headings",
                "has_structured_data",
                "has_internal_links",
                "has_freshness_info",
            ),
            True,
        )
        outcome = diagnose_content_extractability(update_days_ago=10, **flags)

        assert outcome.score == 20.0
        assert outcome.max_score == 20.0
        assert outcome.status == "pass"
        assert outcome.percentage == 100.0

    def test_all_fail(self):
        """All checks disabled give a zero score and a warning status."""
        flags = dict.fromkeys(
            (
                "has_direct_answer",
                "has_qa_headings",
                "has_structured_data",
                "has_internal_links",
                "has_freshness_info",
            ),
            False,
        )
        outcome = diagnose_content_extractability(**flags)

        assert outcome.score == 0.0
        assert outcome.status == "warning"

    def test_partial_pass(self):
        """Only the direct-answer and Q&A-heading checks contribute to the score."""
        flags = {
            "has_direct_answer": True,
            "has_qa_headings": True,
            "has_structured_data": False,
            "has_internal_links": False,
            "has_freshness_info": False,
        }
        outcome = diagnose_content_extractability(**flags)

        # 6 + 5
        assert outcome.score == 11.0

    def test_freshness_recent(self):
        """A 10-day-old update passes the freshness item with 2.0 points."""
        outcome = diagnose_content_extractability(
            has_freshness_info=True,
            update_days_ago=10,
        )
        freshness = next(item for item in outcome.items if item.name == "内容新鲜度")

        assert freshness.score == 2.0
        assert freshness.status == "pass"

    def test_freshness_old(self):
        """A 100-day-old update scores 0.5 on the freshness item with a warning."""
        outcome = diagnose_content_extractability(
            has_freshness_info=True,
            update_days_ago=100,
        )
        freshness = next(item for item in outcome.items if item.name == "内容新鲜度")

        assert freshness.score == 0.5
        assert freshness.status == "warning"
|
||||||
|
|
||||||
|
|
||||||
|
class TestEntityClarity:
    """Tests for the entity-clarity dimension."""

    def test_all_pass(self):
        """All four checks enabled give the full 15 points."""
        flags = dict.fromkeys(
            (
                "has_brand_definition",
                "has_target_audience",
                "has_unique_value",
                "has_industry_classification",
            ),
            True,
        )
        outcome = diagnose_entity_clarity(**flags)

        assert outcome.score == 15.0
        assert outcome.max_score == 15.0
        assert outcome.status == "pass"

    def test_all_fail(self):
        """All four checks disabled give a zero score and a warning status."""
        flags = dict.fromkeys(
            (
                "has_brand_definition",
                "has_target_audience",
                "has_unique_value",
                "has_industry_classification",
            ),
            False,
        )
        outcome = diagnose_entity_clarity(**flags)

        assert outcome.score == 0.0
        assert outcome.status == "warning"

    def test_partial_pass(self):
        """Only the brand-definition and target-audience checks contribute."""
        flags = {
            "has_brand_definition": True,
            "has_target_audience": True,
            "has_unique_value": False,
            "has_industry_classification": False,
        }
        outcome = diagnose_entity_clarity(**flags)

        # 5 + 4
        assert outcome.score == 9.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestEEATSignals:
    """Tests for the E-E-A-T signals dimension."""

    def test_all_pass(self):
        """All signals present at their strongest values give the full 20 points."""
        signals = {
            "has_author_bio": True,
            "author_credentials_complete": 1.0,
            "has_certifications": True,
            "certification_count": 5,
            "has_data_sources": True,
            "authoritative_source_ratio": 1.0,
            "has_expert_endorsements": True,
            "endorsement_count": 5,
        }
        outcome = diagnose_eeat_signals(**signals)

        assert outcome.score == 20.0
        assert outcome.max_score == 20.0
        assert outcome.status == "pass"

    def test_all_fail(self):
        """All signals absent give a zero score and a warning status."""
        signals = dict.fromkeys(
            (
                "has_author_bio",
                "has_certifications",
                "has_data_sources",
                "has_expert_endorsements",
            ),
            False,
        )
        outcome = diagnose_eeat_signals(**signals)

        assert outcome.score == 0.0
        assert outcome.status == "warning"

    def test_author_partial(self):
        """Half-complete author credentials score 3.0 with a warning status."""
        outcome = diagnose_eeat_signals(
            has_author_bio=True,
            author_credentials_complete=0.5,
        )
        author = next(item for item in outcome.items if item.name == "作者资质")

        # 0.5 * 6.0
        assert author.score == 3.0
        assert author.status == "warning"

    def test_certification_tiers(self):
        """The certification item scores 5.0, 4.0 and 2.5 for 5, 3 and 1 certifications."""
        # Original tier notes: 5 or more, 3-4, 1-2.
        tiers = ((5, 5.0), (3, 4.0), (1, 2.5))
        for count, expected_score in tiers:
            outcome = diagnose_eeat_signals(
                has_certifications=True,
                certification_count=count,
            )
            certification = next(
                item for item in outcome.items if item.name == "专业认证"
            )
            assert certification.score == expected_score
|
||||||
|
|
||||||
|
|
||||||
|
class TestSchemaMarkup:
    """Tests for the schema-markup dimension."""

    def test_all_pass(self):
        """All six schema types present give the full 15 points."""
        schemas = dict.fromkeys(
            (
                "has_organization",
                "has_product",
                "has_article",
                "has_faq",
                "has_howto",
                "has_breadcrumb",
            ),
            True,
        )
        outcome = diagnose_schema_markup(**schemas)

        assert outcome.score == 15.0
        assert outcome.max_score == 15.0
        assert outcome.status == "pass"

    def test_all_fail(self):
        """All six schema types absent give a zero score and a warning status."""
        schemas = dict.fromkeys(
            (
                "has_organization",
                "has_product",
                "has_article",
                "has_faq",
                "has_howto",
                "has_breadcrumb",
            ),
            False,
        )
        outcome = diagnose_schema_markup(**schemas)

        assert outcome.score == 0.0
        assert outcome.status == "warning"

    def test_p0_only(self):
        """Only the organization, product and article schemas still pass with 10 points."""
        schemas = dict.fromkeys(
            ("has_organization", "has_product", "has_article"),
            True,
        )
        outcome = diagnose_schema_markup(**schemas)

        # 4 + 3 + 3
        assert outcome.score == 10.0
        assert outcome.status == "pass"

    def test_schema_count(self):
        """Two enabled schema types are reported as schema_count == 2 in the detail."""
        schemas = dict.fromkeys(("has_organization", "has_product"), True)
        outcome = diagnose_schema_markup(**schemas)

        counted = outcome.detail["schema_count"]
        assert counted == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestTopicAuthority:
    """Tests for the topic-authority dimension."""

    def test_all_pass(self):
        """All four ratios at 1.0 give the full 15 points."""
        ratios = dict.fromkeys(
            (
                "content_depth_score",
                "topic_coverage_ratio",
                "entity_consistency_score",
                "cluster_completeness",
            ),
            1.0,
        )
        outcome = diagnose_topic_authority(**ratios)

        assert outcome.score == 15.0
        assert outcome.max_score == 15.0
        assert outcome.status == "pass"

    def test_all_fail(self):
        """All four ratios at 0.0 give a zero score and a warning status."""
        ratios = dict.fromkeys(
            (
                "content_depth_score",
                "topic_coverage_ratio",
                "entity_consistency_score",
                "cluster_completeness",
            ),
            0.0,
        )
        outcome = diagnose_topic_authority(**ratios)

        assert outcome.score == 0.0
        assert outcome.status == "warning"

    def test_partial_scores(self):
        """Mixed ratios produce the expected weighted sum of about 9.3."""
        ratios = {
            "content_depth_score": 0.8,
            "topic_coverage_ratio": 0.5,
            "entity_consistency_score": 0.7,
            "cluster_completeness": 0.4,
        }
        outcome = diagnose_topic_authority(**ratios)

        # 0.8*5 + 0.5*4 + 0.7*3 + 0.4*3 = 4 + 2 + 2.1 + 1.2 = 9.3
        expected_total = pytest.approx(9.3, rel=0.01)
        assert outcome.score == expected_total
|
||||||
|
|
||||||
|
|
||||||
|
class TestCitationReadiness:
    """Tests for the citation-readiness dimension."""

    def test_all_pass(self):
        """Strong citation metrics with no competitor gap give the full 15 points."""
        metrics = {
            "answer_ownership_rate": 0.6,
            "citation_accuracy": 1.0,
            "ai_sov": 0.35,
            "competitor_gap": 0.0,
        }
        outcome = diagnose_citation_readiness(**metrics)

        assert outcome.score == 15.0
        assert outcome.max_score == 15.0
        assert outcome.status == "pass"

    def test_all_fail(self):
        """Zeroed metrics with a 0.6 competitor gap give a zero score and a warning."""
        metrics = {
            "answer_ownership_rate": 0.0,
            "citation_accuracy": 0.0,
            "ai_sov": 0.0,
            "competitor_gap": 0.6,
        }
        outcome = diagnose_citation_readiness(**metrics)

        assert outcome.score == 0.0
        assert outcome.status == "warning"

    def test_aor_tiers(self):
        """The AOR item scores 5.0, 3.5 and 2.0 at rates 0.5, 0.3 and 0.1."""
        # Original tier notes: >= 50%, 30-49%, 10-29%.
        tiers = ((0.5, 5.0), (0.3, 3.5), (0.1, 2.0))
        for rate, expected_score in tiers:
            outcome = diagnose_citation_readiness(answer_ownership_rate=rate)
            aor = next(
                item for item in outcome.items if item.name == "引用频率 (AOR)"
            )
            assert aor.score == expected_score

    def test_competitor_gap_tiers(self):
        """The competitor item scores 3.0 and 2.0 at gaps 0.05 and 0.15."""
        # Original tier notes: <= 10pp, 10-20pp.
        tiers = ((0.05, 3.0), (0.15, 2.0))
        for gap, expected_score in tiers:
            outcome = diagnose_citation_readiness(competitor_gap=gap)
            comparison = next(
                item for item in outcome.items if item.name == "竞品对比"
            )
            assert comparison.score == expected_score
|
||||||
|
|
||||||
|
|
||||||
|
class TestRecommendations:
    """Tests for recommendation generation."""

    def test_generate_from_fail_items(self):
        """Two failed checks produce at least two recommendations, two or more being P0."""
        dimension = diagnose_content_extractability(
            has_direct_answer=False,
            has_qa_headings=False,
        )
        suggestions = generate_recommendations([dimension])

        assert len(suggestions) >= 2
        # Only checks that P0 entries exist; not every entry has to be P0.
        p0_total = sum(1 for entry in suggestions if entry.priority == "P0")
        assert p0_total >= 2

    def test_generate_from_warning_items(self):
        """A dimension with weaker items produces at least one P1 recommendation."""
        flags = {
            "has_direct_answer": True,
            "has_qa_headings": True,
            "has_structured_data": True,
            "has_internal_links": False,
            "has_freshness_info": True,
            "update_days_ago": 50,
        }
        dimension = diagnose_content_extractability(**flags)
        suggestions = generate_recommendations([dimension])

        p1_total = sum(1 for entry in suggestions if entry.priority == "P1")
        assert p1_total >= 1

    def test_priority_ordering(self):
        """Recommendations come back sorted by their priority label."""
        flags = {
            "has_direct_answer": False,
            "has_qa_headings": False,
            "has_structured_data": True,
            "has_internal_links": False,
            "has_freshness_info": True,
            "update_days_ago": 50,
        }
        dimension = diagnose_content_extractability(**flags)
        suggestions = generate_recommendations([dimension])

        observed = [entry.priority for entry in suggestions]
        assert observed == sorted(observed)

    def test_empty_dimensions(self):
        """An empty dimension list produces no recommendations."""
        suggestions = generate_recommendations([])

        assert len(suggestions) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthLevel:
    """Tests for mapping scores to health levels and levels to labels."""

    def test_excellent(self):
        """Scores of 85 and 80 are rated excellent."""
        for score in (85, 80):
            assert get_health_level(score) == "excellent"

    def test_good(self):
        """Scores of 70 and 60 are rated good."""
        for score in (70, 60):
            assert get_health_level(score) == "good"

    def test_pass(self):
        """Scores of 50 and 40 are rated pass."""
        for score in (50, 40):
            assert get_health_level(score) == "pass"

    def test_danger(self):
        """Scores of 30 and 0 are rated danger."""
        for score in (30, 0):
            assert get_health_level(score) == "danger"

    def test_labels(self):
        """Each health level maps to its expected display label."""
        expected_labels = {
            "excellent": "优秀",
            "good": "良好",
            "pass": "及格",
            "danger": "危险",
        }
        for level, label in expected_labels.items():
            assert get_health_level_label(level) == label
|
||||||
|
|
||||||
|
|
||||||
|
class TestGEODiagnosisService:
    """Tests for the GEODiagnosisService class."""

    @staticmethod
    def _strong_profile(**overrides):
        """Build GEODiagnosisInput keyword arguments with every check enabled.

        The defaults are the strongest values used by the tests in this class;
        ``overrides`` replaces individual entries.
        """
        profile = {
            # Content extractability
            "has_direct_answer": True,
            "has_qa_headings": True,
            "has_structured_data": True,
            "has_internal_links": True,
            "has_freshness_info": True,
            "update_days_ago": 10,
            # Entity clarity
            "has_brand_definition": True,
            "has_target_audience": True,
            "has_unique_value": True,
            "has_industry_classification": True,
            # E-E-A-T
            "has_author_bio": True,
            "author_credentials_complete": 1.0,
            "has_certifications": True,
            "certification_count": 5,
            "has_data_sources": True,
            "authoritative_source_ratio": 1.0,
            "has_expert_endorsements": True,
            "endorsement_count": 5,
            # Schema
            "has_organization": True,
            "has_product": True,
            "has_article": True,
            "has_faq": True,
            "has_howto": True,
            "has_breadcrumb": True,
            # Topic authority
            "content_depth_score": 1.0,
            "topic_coverage_ratio": 1.0,
            "entity_consistency_score": 1.0,
            "cluster_completeness": 1.0,
            # Citation readiness
            "answer_ownership_rate": 0.6,
            "citation_accuracy": 1.0,
            "ai_sov": 0.35,
            "competitor_gap": 0.0,
        }
        profile.update(overrides)
        return profile

    @pytest.fixture
    def service(self):
        """Provide a GEODiagnosisService built with its default arguments."""
        return GEODiagnosisService()

    def test_full_diagnosis_all_pass(self, service):
        """The strongest input scores 100, is rated excellent and needs no recommendations."""
        payload = GEODiagnosisInput(**self._strong_profile())
        report = service.diagnose(payload)

        assert report.overall_score == 100.0
        assert report.health_level == "excellent"
        assert len(report.dimensions) == 6
        assert len(report.recommendations) == 0

    def test_full_diagnosis_all_fail(self, service):
        """A default input scores below 10, is rated danger and yields recommendations."""
        payload = GEODiagnosisInput()
        report = service.diagnose(payload)

        # Per the original author, some items earn a few points by default,
        # so the total is not required to be exactly zero.
        assert report.overall_score < 10.0
        assert report.health_level == "danger"
        assert len(report.dimensions) == 6
        assert len(report.recommendations) > 0

    def test_diagnose_from_dict(self, service):
        """diagnose_from_dict accepts a partial dict and reports six dimensions."""
        raw_input = {
            "has_direct_answer": True,
            "has_brand_definition": True,
            "has_author_bio": True,
            "author_credentials_complete": 0.8,
            "has_organization": True,
            "content_depth_score": 0.8,
            "answer_ownership_rate": 0.5,
        }
        report = service.diagnose_from_dict(raw_input)

        assert report.overall_score > 0
        assert len(report.dimensions) == 6

    def test_result_to_dict(self, service):
        """to_dict exposes the expected top-level keys and six dimensions."""
        payload = GEODiagnosisInput(
            has_direct_answer=True,
            has_brand_definition=True,
        )
        report = service.diagnose(payload)
        serialized = report.to_dict()

        expected_keys = (
            "overall_score",
            "health_level",
            "health_level_label",
            "dimensions",
            "recommendations",
        )
        for key in expected_keys:
            assert key in serialized
        assert len(serialized["dimensions"]) == 6

    def test_score_boundaries(self, service):
        """The overall score stays within 0 and 100 for weak and strong inputs."""
        # Lowest score
        weakest = service.diagnose(GEODiagnosisInput())
        assert weakest.overall_score >= 0.0

        # Highest score
        strong_kwargs = self._strong_profile(
            citation_accuracy=0.95,
            competitor_gap=0.05,
        )
        strongest = service.diagnose(GEODiagnosisInput(**strong_kwargs))
        assert strongest.overall_score <= 100.0

    def test_health_levels(self, service):
        """A slightly weakened but still strong input is rated excellent."""
        # excellent (>= 80)
        near_top_kwargs = self._strong_profile(
            author_credentials_complete=0.9,
            certification_count=4,
            authoritative_source_ratio=0.9,
            endorsement_count=4,
            content_depth_score=0.9,
            topic_coverage_ratio=0.9,
            entity_consistency_score=0.9,
            cluster_completeness=0.8,
            citation_accuracy=0.95,
            competitor_gap=0.05,
        )
        report = service.diagnose(GEODiagnosisInput(**near_top_kwargs))

        assert report.health_level == "excellent"

    def test_dimension_scores_sum(self, service):
        """The overall score matches the sum of the per-dimension scores."""
        payload = GEODiagnosisInput(
            has_direct_answer=True,
            has_brand_definition=True,
        )
        report = service.diagnose(payload)

        # The per-dimension scores should add up to the overall score.
        summed = sum(dimension.score for dimension in report.dimensions)
        assert report.overall_score == pytest.approx(summed, rel=0.01)

    def test_recommendations_generated(self, service):
        """A default input yields recommendations whose priorities are P0, P1 or P2."""
        payload = GEODiagnosisInput()
        report = service.diagnose(payload)

        # With every check failing there should be recommendations.
        assert len(report.recommendations) > 0
        allowed = ("P0", "P1", "P2")
        assert all(entry.priority in allowed for entry in report.recommendations)
|
||||||
|
|
@ -0,0 +1,332 @@
|
||||||
|
import pytest
|
||||||
|
from app.services.quota_service import (
|
||||||
|
QuotaService,
|
||||||
|
QuotaUsage,
|
||||||
|
QuotaWarning,
|
||||||
|
QuotaType,
|
||||||
|
QUOTA_LIMITS,
|
||||||
|
WARNING_THRESHOLDS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuotaUsage:
    """Tests for the QuotaUsage data structure."""

    def test_quota_usage_creation(self):
        """A QuotaUsage keeps every field it was constructed with."""
        fields = {
            "quota_type": "api_calls",
            "used": 800,
            "limit": 1000,
            "usage_percentage": 80.0,
            "status": "warning",
            "remaining": 200,
        }
        record = QuotaUsage(**fields)

        assert record.quota_type == "api_calls"
        assert record.used == 800
        assert record.limit == 1000
        assert record.usage_percentage == 80.0
        assert record.status == "warning"
        assert record.remaining == 200

    def test_quota_usage_unlimited(self):
        """A QuotaUsage built with limit -1 keeps its unlimited markers."""
        fields = {
            "quota_type": "api_calls",
            "used": 5000,
            "limit": -1,
            "usage_percentage": 0.0,
            "status": "unlimited",
            "remaining": -1,
        }
        record = QuotaUsage(**fields)

        assert record.limit == -1
        assert record.status == "unlimited"
        assert record.remaining == -1
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuotaWarning:
    """Tests for the QuotaWarning data structure."""

    def test_quota_warning_creation(self):
        """A QuotaWarning keeps the fields it was constructed with."""
        fields = {
            "quota_type": "api_calls",
            "status": "warning",
            "usage_percentage": 80.0,
            "message": "API调用额度已使用80%",
            "recommended_action": "请关注使用情况,考虑升级套餐",
        }
        alert = QuotaWarning(**fields)

        assert alert.quota_type == "api_calls"
        assert alert.status == "warning"
        assert alert.usage_percentage == 80.0
        assert "80%" in alert.message
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuotaService:
|
||||||
|
"""额度预警服务测试"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def quota_service(self):
|
||||||
|
"""创建额度服务实例"""
|
||||||
|
return QuotaService()
|
||||||
|
|
||||||
|
def test_calculate_usage_percentage_basic(self, quota_service):
|
||||||
|
"""测试基础使用率计算"""
|
||||||
|
percentage = quota_service.calculate_usage_percentage(used=800, limit=1000)
|
||||||
|
assert percentage == 80.0
|
||||||
|
|
||||||
|
def test_calculate_usage_percentage_zero(self, quota_service):
|
||||||
|
"""测试0%使用率"""
|
||||||
|
percentage = quota_service.calculate_usage_percentage(used=0, limit=1000)
|
||||||
|
assert percentage == 0.0
|
||||||
|
|
||||||
|
def test_calculate_usage_percentage_full(self, quota_service):
|
||||||
|
"""测试100%使用率"""
|
||||||
|
percentage = quota_service.calculate_usage_percentage(used=1000, limit=1000)
|
||||||
|
assert percentage == 100.0
|
||||||
|
|
||||||
|
def test_calculate_usage_percentage_half(self, quota_service):
|
||||||
|
"""测试50%使用率"""
|
||||||
|
percentage = quota_service.calculate_usage_percentage(used=500, limit=1000)
|
||||||
|
assert percentage == 50.0
|
||||||
|
|
||||||
|
def test_calculate_usage_percentage_unlimited(self, quota_service):
|
||||||
|
"""测试无限额度使用率"""
|
||||||
|
percentage = quota_service.calculate_usage_percentage(used=5000, limit=-1)
|
||||||
|
assert percentage == 0.0
|
||||||
|
|
||||||
|
def test_get_quota_status_ok(self, quota_service):
|
||||||
|
"""测试正常状态(低于80%)"""
|
||||||
|
status = quota_service.get_quota_status(usage_percentage=50.0)
|
||||||
|
assert status == "ok"
|
||||||
|
|
||||||
|
def test_get_quota_status_warning(self, quota_service):
|
||||||
|
"""测试警告状态(80%)"""
|
||||||
|
status = quota_service.get_quota_status(usage_percentage=80.0)
|
||||||
|
assert status == "warning"
|
||||||
|
|
||||||
|
def test_get_quota_status_critical(self, quota_service):
|
||||||
|
"""测试严重状态(90%)"""
|
||||||
|
status = quota_service.get_quota_status(usage_percentage=90.0)
|
||||||
|
assert status == "critical"
|
||||||
|
|
||||||
|
def test_get_quota_status_exhausted(self, quota_service):
|
||||||
|
"""测试耗尽状态(100%)"""
|
||||||
|
status = quota_service.get_quota_status(usage_percentage=100.0)
|
||||||
|
assert status == "exhausted"
|
||||||
|
|
||||||
|
def test_get_quota_status_unlimited(self, quota_service):
|
||||||
|
"""测试无限状态"""
|
||||||
|
status = quota_service.get_quota_status(usage_percentage=0.0, limit=-1)
|
||||||
|
assert status == "unlimited"
|
||||||
|
|
||||||
|
def test_get_quota_limit_free_plan(self, quota_service):
|
||||||
|
"""测试免费套餐额度"""
|
||||||
|
limit = quota_service.get_quota_limit("free", QuotaType.API_CALLS)
|
||||||
|
assert limit == 1000
|
||||||
|
|
||||||
|
def test_get_quota_limit_basic_plan(self, quota_service):
|
||||||
|
"""测试基础套餐额度"""
|
||||||
|
limit = quota_service.get_quota_limit("basic", QuotaType.QUERIES)
|
||||||
|
assert limit == 500
|
||||||
|
|
||||||
|
def test_get_quota_limit_pro_plan(self, quota_service):
|
||||||
|
"""测试专业套餐额度"""
|
||||||
|
limit = quota_service.get_quota_limit("pro", QuotaType.CONTENT_GENERATION)
|
||||||
|
assert limit == 1000
|
||||||
|
|
||||||
|
def test_get_quota_limit_unlimited_plan(self, quota_service):
|
||||||
|
"""测试无限套餐额度"""
|
||||||
|
limit = quota_service.get_quota_limit("unlimited", QuotaType.API_CALLS)
|
||||||
|
assert limit == -1
|
||||||
|
|
||||||
|
def test_get_remaining_quota(self, quota_service):
|
||||||
|
"""测试剩余额度计算"""
|
||||||
|
remaining = quota_service.get_remaining(used=800, limit=1000)
|
||||||
|
assert remaining == 200
|
||||||
|
|
||||||
|
def test_get_remaining_quota_exhausted(self, quota_service):
|
||||||
|
"""测试额度耗尽"""
|
||||||
|
remaining = quota_service.get_remaining(used=1000, limit=1000)
|
||||||
|
assert remaining == 0
|
||||||
|
|
||||||
|
def test_get_remaining_quota_unlimited(self, quota_service):
|
||||||
|
"""测试无限额度"""
|
||||||
|
remaining = quota_service.get_remaining(used=5000, limit=-1)
|
||||||
|
assert remaining == -1
|
||||||
|
|
||||||
|
def test_get_remaining_quota_overuse(self, quota_service):
|
||||||
|
"""测试超额使用"""
|
||||||
|
remaining = quota_service.get_remaining(used=1200, limit=1000)
|
||||||
|
assert remaining == -200
|
||||||
|
|
||||||
|
def test_generate_warning_message_warning(self, quota_service):
|
||||||
|
"""测试生成警告消息"""
|
||||||
|
warning = quota_service.generate_warning(
|
||||||
|
quota_type=QuotaType.API_CALLS,
|
||||||
|
status="warning",
|
||||||
|
usage_percentage=80.0,
|
||||||
|
)
|
||||||
|
assert warning.quota_type == QuotaType.API_CALLS
|
||||||
|
assert warning.status == "warning"
|
||||||
|
assert warning.usage_percentage == 80.0
|
||||||
|
assert "80%" in warning.message
|
||||||
|
assert warning.recommended_action != ""
|
||||||
|
|
||||||
|
def test_generate_warning_message_critical(self, quota_service):
    """A "critical" alert reports that status and mentions 90% in its message."""
    request = {
        "quota_type": QuotaType.QUERIES,
        "status": "critical",
        "usage_percentage": 90.0,
    }
    alert = quota_service.generate_warning(**request)

    assert alert.status == "critical"
    assert "90%" in alert.message
|
||||||
|
|
||||||
|
def test_generate_warning_message_exhausted(self, quota_service):
    """An "exhausted" alert reports that status and mentions 100% in its message."""
    alert = quota_service.generate_warning(
        quota_type=QuotaType.CONTENT_GENERATION,
        status="exhausted",
        usage_percentage=100.0,
    )

    assert alert.status == "exhausted"
    assert "100%" in alert.message
|
||||||
|
|
||||||
|
def test_check_quota_free_plan(self, quota_service):
    """Free plan, 800 API calls used: limit 1000, 200 left, 80% used, "warning"."""
    snapshot = quota_service.check_quota(
        plan="free", quota_type=QuotaType.API_CALLS, used=800
    )

    assert snapshot.quota_type == QuotaType.API_CALLS
    assert snapshot.limit == 1000
    assert snapshot.used == 800
    assert snapshot.remaining == 200
    assert snapshot.usage_percentage == 80.0
    assert snapshot.status == "warning"
|
||||||
|
|
||||||
|
def test_check_quota_pro_plan_warning(self, quota_service):
    """Pro plan, 4500 queries used: limit 5000, 90% used, status "critical"."""
    request = {"plan": "pro", "quota_type": QuotaType.QUERIES, "used": 4500}
    snapshot = quota_service.check_quota(**request)

    assert snapshot.limit == 5000
    assert snapshot.usage_percentage == 90.0
    assert snapshot.status == "critical"
|
||||||
|
|
||||||
|
def test_check_quota_unlimited_plan(self, quota_service):
    """Unlimited plan: limit and remaining are -1, usage is 0%, status "unlimited"."""
    report = quota_service.check_quota(
        plan="unlimited", quota_type=QuotaType.API_CALLS, used=100000
    )

    assert report.limit == -1
    assert report.status == "unlimited"
    assert report.usage_percentage == 0.0
    assert report.remaining == -1
|
||||||
|
|
||||||
|
def test_check_quota_exhausted(self, quota_service):
    """Free plan with all 50 queries used: 0 left, 100% used, status "exhausted"."""
    report = quota_service.check_quota(
        plan="free", quota_type=QuotaType.QUERIES, used=50
    )

    assert report.limit == 50
    assert report.used == 50
    assert report.remaining == 0
    assert report.usage_percentage == 100.0
    assert report.status == "exhausted"
|
||||||
|
|
||||||
|
def test_check_quota_ok(self, quota_service):
    """Basic plan, 500 storage used: limit 1000, 50% used, status "ok"."""
    request = {"plan": "basic", "quota_type": QuotaType.STORAGE, "used": 500}
    report = quota_service.check_quota(**request)

    assert report.limit == 1000
    assert report.usage_percentage == 50.0
    assert report.status == "ok"
|
||||||
|
|
||||||
|
def test_reset_quota_usage(self, quota_service):
    """reset_quota without an explicit target returns 0."""
    before_reset = quota_service.check_quota(
        plan="free", quota_type=QuotaType.API_CALLS, used=800
    )
    assert before_reset.used == 800

    after_reset = quota_service.reset_quota(used=800)
    assert after_reset == 0
|
||||||
|
|
||||||
|
def test_reset_quota_to_specific_value(self, quota_service):
    """reset_quota with reset_to=100 returns 100."""
    assert quota_service.reset_quota(used=500, reset_to=100) == 100
|
||||||
|
|
||||||
|
def test_get_all_quota_usage(self, quota_service):
    """get_all_quota_usage returns four entries carrying the supplied usage."""
    consumption = {
        "api_calls_used": 800,
        "queries_used": 40,
        "content_generation_used": 8,
        "storage_used": 50,
    }
    entries = quota_service.get_all_quota_usage(plan="free", **consumption)

    assert len(entries) == 4
    assert any(
        entry.quota_type == QuotaType.API_CALLS and entry.used == 800
        for entry in entries
    )
    assert any(
        entry.quota_type == QuotaType.QUERIES and entry.used == 40
        for entry in entries
    )
|
||||||
|
|
||||||
|
def test_get_warnings_from_usage(self, quota_service):
    """Warnings built from a free-plan usage snapshot contain at least one alert."""
    entries = quota_service.get_all_quota_usage(
        plan="free",
        api_calls_used=800,
        queries_used=45,
        content_generation_used=5,
        storage_used=20,
    )
    alerts = quota_service.get_warnings(entries)

    assert len(alerts) >= 1
    alert_levels = ["warning", "critical", "exhausted"]
    assert any(alert.status in alert_levels for alert in alerts)
|
||||||
|
|
||||||
|
def test_boundary_value_0_percent(self, quota_service):
    """Boundary case: zero usage gives 0% and status "ok"."""
    request = {"plan": "free", "quota_type": QuotaType.API_CALLS, "used": 0}
    report = quota_service.check_quota(**request)

    assert report.usage_percentage == 0.0
    assert report.status == "ok"
|
||||||
|
|
||||||
|
def test_boundary_value_100_percent(self, quota_service):
    """Boundary case: usage equal to the limit gives 100% and status "exhausted"."""
    report = quota_service.check_quota(
        plan="free", quota_type=QuotaType.API_CALLS, used=1000
    )

    assert report.usage_percentage == 100.0
    assert report.status == "exhausted"
|
||||||
|
|
||||||
|
def test_warning_thresholds_constants(self):
    """WARNING_THRESHOLDS maps warning/critical/exhausted to 0.80/0.90/1.00."""
    expected_ratios = (("warning", 0.80), ("critical", 0.90), ("exhausted", 1.00))
    for level, ratio in expected_ratios:
        assert WARNING_THRESHOLDS[level] == ratio
|
||||||
|
|
||||||
|
def test_quota_limits_constants(self):
    """QUOTA_LIMITS holds the expected per-plan limits for the sampled keys."""
    expected_limits = (
        ("free", "api_calls", 1000),
        ("free", "queries", 50),
        ("basic", "api_calls", 10000),
        ("pro", "api_calls", 100000),
        ("unlimited", "api_calls", -1),
    )
    for plan, quota_key, limit in expected_limits:
        assert QUOTA_LIMITS[plan][quota_key] == limit
|
||||||
|
|
@ -0,0 +1,844 @@
|
||||||
|
"""
|
||||||
|
SEO诊断服务单元测试
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from app.services.seo_diagnosis import (
|
||||||
|
SEODiagnosisService,
|
||||||
|
SEODiagnosisResult,
|
||||||
|
SEODimensionScore,
|
||||||
|
DiagnosisItem,
|
||||||
|
SEORecommendation,
|
||||||
|
DiagnosisStatus,
|
||||||
|
DimensionName,
|
||||||
|
TechnicalSEOData,
|
||||||
|
OnPageSEOData,
|
||||||
|
ContentQualityData,
|
||||||
|
BacklinkData,
|
||||||
|
UserExperienceData,
|
||||||
|
diagnose_technical_seo,
|
||||||
|
diagnose_on_page_seo,
|
||||||
|
diagnose_content_quality,
|
||||||
|
diagnose_backlinks,
|
||||||
|
diagnose_user_experience,
|
||||||
|
generate_recommendations,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiagnosisStatus:
    """Tests for the DiagnosisStatus enumeration."""

    def test_status_values(self):
        """Each status member compares equal to its lowercase string value."""
        expected_values = (
            (DiagnosisStatus.PASS, "pass"),
            (DiagnosisStatus.WARNING, "warning"),
            (DiagnosisStatus.FAIL, "fail"),
        )
        for member, text in expected_values:
            assert member == text
|
||||||
|
|
||||||
|
|
||||||
|
class TestDimensionName:
    """Tests for the DimensionName enumeration."""

    def test_dimension_names(self):
        """Each dimension member compares equal to its display name."""
        expected_names = (
            (DimensionName.TECHNICAL_SEO, "技术SEO"),
            (DimensionName.ON_PAGE_SEO, "页面SEO"),
            (DimensionName.CONTENT_QUALITY, "内容质量"),
            (DimensionName.BACKLINK_ANALYSIS, "外链分析"),
            (DimensionName.USER_EXPERIENCE, "用户体验"),
        )
        for member, label in expected_names:
            assert member == label
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiagnosisItem:
    """Tests for the DiagnosisItem data structure."""

    def test_create_item(self):
        """A constructed item exposes the name, status and score it was given."""
        fields = {
            "name": "测试项",
            "status": DiagnosisStatus.PASS,
            "description": "测试描述",
            "suggestion": "测试建议",
            "score": 1.0,
        }
        entry = DiagnosisItem(**fields)

        assert entry.name == "测试项"
        assert entry.status == DiagnosisStatus.PASS
        assert entry.score == 1.0

    def test_item_with_details(self):
        """The optional details mapping is stored on the item."""
        entry = DiagnosisItem(
            name="测试项",
            status=DiagnosisStatus.WARNING,
            description="测试描述",
            suggestion="测试建议",
            details={"key": "value"},
        )

        assert entry.details == {"key": "value"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestSEODimensionScore:
    """Tests for the SEODimensionScore data structure."""

    @staticmethod
    def _item(label, state):
        """Build a DiagnosisItem with empty description and suggestion."""
        return DiagnosisItem(name=label, status=state, description="", suggestion="")

    @staticmethod
    def _dimension(points, ceiling, entries):
        """Build a dimension that is always constructed with status PASS."""
        return SEODimensionScore(
            name="测试维度",
            score=points,
            max_score=ceiling,
            items=entries,
            status=DiagnosisStatus.PASS,
        )

    def test_create_dimension_score(self):
        """Score and max score are stored; 20 of 25 gives 80 percent."""
        dimension = self._dimension(20.0, 25.0, [])

        assert dimension.score == 20.0
        assert dimension.max_score == 25.0
        assert dimension.percentage == 80.0

    def test_percentage_calculation(self):
        """15 of 25 points gives a percentage of 60."""
        dimension = self._dimension(15.0, 25.0, [])

        assert dimension.percentage == 60.0

    def test_status_calculation_all_pass(self):
        """With only passing items the dimension status is PASS."""
        entries = [
            self._item("项1", DiagnosisStatus.PASS),
            self._item("项2", DiagnosisStatus.PASS),
        ]
        dimension = self._dimension(10.0, 10.0, entries)

        assert dimension.status == DiagnosisStatus.PASS

    def test_status_calculation_with_warnings(self):
        """With warning items the status becomes WARNING despite the PASS input."""
        entries = [
            self._item("项1", DiagnosisStatus.PASS),
            self._item("项2", DiagnosisStatus.WARNING),
            self._item("项3", DiagnosisStatus.WARNING),
        ]
        dimension = self._dimension(7.0, 10.0, entries)

        assert dimension.status == DiagnosisStatus.WARNING

    def test_status_calculation_with_fails(self):
        """A single failing item among four yields WARNING."""
        entries = [
            self._item("项1", DiagnosisStatus.FAIL),
            self._item("项2", DiagnosisStatus.PASS),
            self._item("项3", DiagnosisStatus.PASS),
            self._item("项4", DiagnosisStatus.PASS),
        ]
        dimension = self._dimension(7.0, 10.0, entries)

        # 1 FAIL out of 4 items is 25%, which does not exceed 30%, but because
        # a FAIL is present the status is WARNING.
        assert dimension.status == DiagnosisStatus.WARNING

    def test_status_calculation_many_fails(self):
        """Three failing items out of four yield FAIL."""
        entries = [
            self._item("项1", DiagnosisStatus.FAIL),
            self._item("项2", DiagnosisStatus.FAIL),
            self._item("项3", DiagnosisStatus.FAIL),
            self._item("项4", DiagnosisStatus.PASS),
        ]
        dimension = self._dimension(5.0, 10.0, entries)

        assert dimension.status == DiagnosisStatus.FAIL
|
||||||
|
|
||||||
|
|
||||||
|
class TestSEODiagnosisResult:
    """Tests for the SEODiagnosisResult data structure."""

    @staticmethod
    def _result(points):
        """Build a result with the given score and no dimensions or recommendations."""
        return SEODiagnosisResult(
            overall_score=points,
            dimensions=[],
            recommendations=[],
        )

    def test_create_result(self):
        """A score of 75 is stored and classified as "good"."""
        outcome = self._result(75.0)

        assert outcome.overall_score == 75.0
        assert outcome.health_level == "good"

    def test_health_level_excellent(self):
        """A score of 85 is classified as "excellent"."""
        assert self._result(85.0).health_level == "excellent"

    def test_health_level_good(self):
        """A score of 70 is classified as "good"."""
        assert self._result(70.0).health_level == "good"

    def test_health_level_pass(self):
        """A score of 50 is classified as "pass"."""
        assert self._result(50.0).health_level == "pass"

    def test_health_level_danger(self):
        """A score of 30 is classified as "danger"."""
        assert self._result(30.0).health_level == "danger"

    def test_score_clamping(self):
        """Scores of 150 and -10 come back as 100 and 0 respectively."""
        too_high = self._result(150.0)
        assert too_high.overall_score == 100.0

        too_low = self._result(-10.0)
        assert too_low.overall_score == 0.0

    def test_to_dict(self):
        """to_dict exposes the score, health level, its label and both lists."""
        payload = self._result(75.0).to_dict()

        assert payload["overall_score"] == 75.0
        assert payload["health_level"] == "good"
        assert payload["health_level_label"] == "良好"
        assert "dimensions" in payload
        assert "recommendations" in payload
|
||||||
|
|
||||||
|
|
||||||
|
class TestTechnicalSEODiagnosis:
    """Tests for diagnose_technical_seo."""

    def test_perfect_technical_seo(self):
        """Fully healthy technical data earns the maximum score and PASS."""
        data = TechnicalSEOData(
            is_indexed=True,
            crawl_errors=0,
            lcp_seconds=2.0,
            fid_ms=50.0,
            cls_score=0.05,
            has_robots_txt=True,
            robots_txt_blocks_important=False,
            has_sitemap=True,
            sitemap_valid=True,
            url_structure_normalized=True,
        )
        result = diagnose_technical_seo(data)
        assert result.score == result.max_score
        assert result.status == DiagnosisStatus.PASS

    def test_indexed_fail(self):
        """A site that is not indexed produces a failing "索引状态" item."""
        data = TechnicalSEOData(is_indexed=False)
        result = diagnose_technical_seo(data)
        assert any(
            item.status == DiagnosisStatus.FAIL
            for item in result.items
            if item.name == "索引状态"
        )

    def test_crawl_errors_warning(self):
        """Three crawl errors put the "爬取错误" item in WARNING."""
        data = TechnicalSEOData(crawl_errors=3)
        result = diagnose_technical_seo(data)
        crawl_item = next(item for item in result.items if item.name == "爬取错误")
        assert crawl_item.status == DiagnosisStatus.WARNING

    def test_crawl_errors_fail(self):
        """Ten crawl errors put the "爬取错误" item in FAIL."""
        data = TechnicalSEOData(crawl_errors=10)
        result = diagnose_technical_seo(data)
        crawl_item = next(item for item in result.items if item.name == "爬取错误")
        assert crawl_item.status == DiagnosisStatus.FAIL

    def test_core_web_vitals_pass(self):
        """Good LCP/FID/CLS values make every Core Web Vitals item PASS."""
        data = TechnicalSEOData(
            lcp_seconds=2.0,
            fid_ms=80.0,
            cls_score=0.05,
        )
        result = diagnose_technical_seo(data)
        cwv_items = [item for item in result.items if item.name in ["LCP", "FID", "CLS"]]
        # all() is True for an empty iterable, so make sure the filter actually
        # matched something before trusting the check below.
        assert cwv_items, "no Core Web Vitals items were produced"
        assert all(item.status == DiagnosisStatus.PASS for item in cwv_items)

    def test_core_web_vitals_warning(self):
        """Borderline LCP/FID/CLS values put at least one vitals item in WARNING."""
        data = TechnicalSEOData(
            lcp_seconds=3.0,
            fid_ms=200.0,
            cls_score=0.15,
        )
        result = diagnose_technical_seo(data)
        cwv_items = [item for item in result.items if item.name in ["LCP", "FID", "CLS"]]
        assert any(item.status == DiagnosisStatus.WARNING for item in cwv_items)

    def test_core_web_vitals_fail(self):
        """Poor LCP/FID/CLS values make every Core Web Vitals item FAIL."""
        data = TechnicalSEOData(
            lcp_seconds=5.0,
            fid_ms=400.0,
            cls_score=0.3,
        )
        result = diagnose_technical_seo(data)
        cwv_items = [item for item in result.items if item.name in ["LCP", "FID", "CLS"]]
        # Guard against a vacuously true all() on an empty list.
        assert cwv_items, "no Core Web Vitals items were produced"
        assert all(item.status == DiagnosisStatus.FAIL for item in cwv_items)

    def test_robots_txt_blocks_important(self):
        """A robots.txt that blocks important pages makes the robots item FAIL."""
        data = TechnicalSEOData(
            has_robots_txt=True,
            robots_txt_blocks_important=True,
        )
        result = diagnose_technical_seo(data)
        robots_item = next(item for item in result.items if item.name == "robots.txt")
        assert robots_item.status == DiagnosisStatus.FAIL

    def test_missing_sitemap(self):
        """A missing sitemap makes the "sitemap" item FAIL."""
        data = TechnicalSEOData(has_sitemap=False)
        result = diagnose_technical_seo(data)
        sitemap_item = next(item for item in result.items if item.name == "sitemap")
        assert sitemap_item.status == DiagnosisStatus.FAIL
|
||||||
|
|
||||||
|
|
||||||
|
class TestOnPageSEODiagnosis:
    """Tests for diagnose_on_page_seo."""

    @staticmethod
    def _item_named(dimension, item_name):
        """Return the first item of *dimension* whose name equals *item_name*."""
        return next(entry for entry in dimension.items if entry.name == item_name)

    def test_perfect_on_page_seo(self):
        """Fully healthy on-page data earns the maximum score and PASS."""
        page = OnPageSEOData(
            has_title=True,
            title_length=50,
            title_keyword_stuffing=False,
            has_meta_description=True,
            meta_description_length=140,
            h1_count=1,
            h_structure_valid=True,
            keyword_density=2.0,
            internal_links=10,
            broken_internal_links=0,
            images_without_alt=0,
            total_images=5,
        )
        report = diagnose_on_page_seo(page)

        assert report.score == report.max_score
        assert report.status == DiagnosisStatus.PASS

    def test_missing_title(self):
        """A page without a title makes the "Title标签" item FAIL."""
        report = diagnose_on_page_seo(OnPageSEOData(has_title=False))
        title_entry = self._item_named(report, "Title标签")

        assert title_entry.status == DiagnosisStatus.FAIL

    def test_title_too_long(self):
        """A title length of 80 puts the "Title标签" item in WARNING."""
        report = diagnose_on_page_seo(OnPageSEOData(title_length=80))
        title_entry = self._item_named(report, "Title标签")

        assert title_entry.status == DiagnosisStatus.WARNING

    def test_keyword_stuffing(self):
        """Keyword stuffing in the title puts the "Title标签" item in WARNING."""
        report = diagnose_on_page_seo(OnPageSEOData(title_keyword_stuffing=True))
        title_entry = self._item_named(report, "Title标签")

        assert title_entry.status == DiagnosisStatus.WARNING

    def test_multiple_h1(self):
        """Three H1 headings put the "H标签结构" item in WARNING."""
        report = diagnose_on_page_seo(OnPageSEOData(h1_count=3))
        heading_entry = self._item_named(report, "H标签结构")

        assert heading_entry.status == DiagnosisStatus.WARNING

    def test_broken_links_warning(self):
        """Two broken internal links put the "内链结构" item in WARNING."""
        report = diagnose_on_page_seo(OnPageSEOData(broken_internal_links=2))
        links_entry = self._item_named(report, "内链结构")

        assert links_entry.status == DiagnosisStatus.WARNING

    def test_broken_links_fail(self):
        """Ten broken internal links put the "内链结构" item in FAIL."""
        report = diagnose_on_page_seo(OnPageSEOData(broken_internal_links=10))
        links_entry = self._item_named(report, "内链结构")

        assert links_entry.status == DiagnosisStatus.FAIL

    def test_images_without_alt(self):
        """Three of five images lacking alt text make the alt-text item FAIL."""
        page = OnPageSEOData(images_without_alt=3, total_images=5)
        report = diagnose_on_page_seo(page)
        alt_entry = self._item_named(report, "图片Alt文本")

        assert alt_entry.status == DiagnosisStatus.FAIL
|
||||||
|
|
||||||
|
|
||||||
|
class TestContentQualityDiagnosis:
    """Tests for diagnose_content_quality."""

    def test_perfect_content_quality(self):
        """Fully healthy content data earns the maximum score and PASS."""
        content = ContentQualityData(
            readability_score=80.0,
            word_count=2000,
            topic_coverage=0.9,
            has_author_info=True,
            has_publication_date=True,
            last_updated_days=10,
            has_citations=True,
            citation_authority=0.9,
            duplicate_content_ratio=0.02,
            has_expert_review=True,
        )
        report = diagnose_content_quality(content)

        assert report.score == report.max_score
        assert report.status == DiagnosisStatus.PASS

    def test_low_readability(self):
        """A readability score of 40 makes the "可读性" item FAIL."""
        report = diagnose_content_quality(ContentQualityData(readability_score=40.0))
        matches = filter(lambda entry: entry.name == "可读性", report.items)

        assert next(matches).status == DiagnosisStatus.FAIL

    def test_shallow_content(self):
        """500 words with 0.4 topic coverage make the "信息深度" item FAIL."""
        content = ContentQualityData(word_count=500, topic_coverage=0.4)
        report = diagnose_content_quality(content)
        matches = filter(lambda entry: entry.name == "信息深度", report.items)

        assert next(matches).status == DiagnosisStatus.FAIL

    def test_missing_author(self):
        """Missing author information puts the "作者资质" item in WARNING."""
        report = diagnose_content_quality(ContentQualityData(has_author_info=False))
        matches = filter(lambda entry: entry.name == "作者资质", report.items)

        assert next(matches).status == DiagnosisStatus.WARNING

    def test_stale_content(self):
        """Content last updated 200 days ago makes the "内容新鲜度" item FAIL."""
        report = diagnose_content_quality(ContentQualityData(last_updated_days=200))
        matches = filter(lambda entry: entry.name == "内容新鲜度", report.items)

        assert next(matches).status == DiagnosisStatus.FAIL

    def test_high_duplicate_ratio(self):
        """A duplicate-content ratio of 0.5 makes the "重复内容" item FAIL."""
        report = diagnose_content_quality(
            ContentQualityData(duplicate_content_ratio=0.5)
        )
        matches = filter(lambda entry: entry.name == "重复内容", report.items)

        assert next(matches).status == DiagnosisStatus.FAIL
|
||||||
|
|
||||||
|
|
||||||
|
class TestBacklinkDiagnosis:
    """Tests for diagnose_backlinks."""

    @staticmethod
    def _diagnosed_item(profile, item_name):
        """Diagnose *profile* and return the first item named *item_name*."""
        report = diagnose_backlinks(profile)
        return next(entry for entry in report.items if entry.name == item_name)

    def test_perfect_backlinks(self):
        """A fully healthy backlink profile earns the maximum score and PASS."""
        profile = BacklinkData(
            total_backlinks=200,
            referring_domains=50,
            high_authority_links=20,
            toxic_links=0,
            nofollow_ratio=0.3,
            anchor_text_diversity=0.9,
            exact_match_anchor_ratio=0.1,
        )
        report = diagnose_backlinks(profile)

        assert report.score == report.max_score
        assert report.status == DiagnosisStatus.PASS

    def test_few_referring_domains(self):
        """Only five referring domains make the "引用域名" item FAIL."""
        domains_entry = self._diagnosed_item(
            BacklinkData(referring_domains=5), "引用域名"
        )

        assert domains_entry.status == DiagnosisStatus.FAIL

    def test_toxic_links_warning(self):
        """3 toxic links out of 100 put the "毒性链接" item in WARNING."""
        profile = BacklinkData(total_backlinks=100, toxic_links=3)
        toxic_entry = self._diagnosed_item(profile, "毒性链接")

        assert toxic_entry.status == DiagnosisStatus.WARNING

    def test_toxic_links_fail(self):
        """10 toxic links out of 50 put the "毒性链接" item in FAIL."""
        profile = BacklinkData(total_backlinks=50, toxic_links=10)
        toxic_entry = self._diagnosed_item(profile, "毒性链接")

        assert toxic_entry.status == DiagnosisStatus.FAIL

    def test_low_anchor_diversity(self):
        """Low anchor diversity with many exact matches makes "锚文本分布" FAIL."""
        profile = BacklinkData(
            anchor_text_diversity=0.3,
            exact_match_anchor_ratio=0.6,
        )
        anchor_entry = self._diagnosed_item(profile, "锚文本分布")

        assert anchor_entry.status == DiagnosisStatus.FAIL
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserExperienceDiagnosis:
    """Tests for diagnose_user_experience."""

    @staticmethod
    def _status_of(signals, item_name):
        """Diagnose *signals* and return the status of the first item named *item_name*."""
        report = diagnose_user_experience(signals)
        matched = next(entry for entry in report.items if entry.name == item_name)
        return matched.status

    def test_perfect_ux(self):
        """Fully healthy UX data earns the maximum score and PASS."""
        signals = UserExperienceData(
            is_mobile_friendly=True,
            mobile_viewport_set=True,
            page_load_time=1.5,
            has_https=True,
            has_breadcrumbs=True,
            conversion_path_clear=True,
            has_cta=True,
            form_usability=0.95,
            has_search=True,
        )
        report = diagnose_user_experience(signals)

        assert report.score == report.max_score
        assert report.status == DiagnosisStatus.PASS

    def test_not_mobile_friendly(self):
        """A site that is not mobile friendly makes the "移动适配" item FAIL."""
        signals = UserExperienceData(is_mobile_friendly=False)

        assert self._status_of(signals, "移动适配") == DiagnosisStatus.FAIL

    def test_slow_page_load(self):
        """A page load time of 5.0 makes the "页面速度" item FAIL."""
        signals = UserExperienceData(page_load_time=5.0)

        assert self._status_of(signals, "页面速度") == DiagnosisStatus.FAIL

    def test_missing_https(self):
        """A site without HTTPS makes the "HTTPS" item FAIL."""
        signals = UserExperienceData(has_https=False)

        assert self._status_of(signals, "HTTPS") == DiagnosisStatus.FAIL

    def test_missing_cta(self):
        """A missing call to action puts the "CTA" item in WARNING."""
        signals = UserExperienceData(has_cta=False)

        assert self._status_of(signals, "CTA") == DiagnosisStatus.WARNING
|
||||||
|
|
||||||
|
|
||||||
|
class TestRecommendations:
    """Tests for generate_recommendations."""

    def test_generate_recommendations(self):
        """A failing and a warning item yield two recommendations, high first."""
        failing = DiagnosisItem(
            name="失败项",
            status=DiagnosisStatus.FAIL,
            description="描述",
            suggestion="修复建议",
        )
        warned = DiagnosisItem(
            name="警告项",
            status=DiagnosisStatus.WARNING,
            description="描述",
            suggestion="优化建议",
        )
        passing = DiagnosisItem(
            name="通过项",
            status=DiagnosisStatus.PASS,
            description="描述",
            suggestion="保持",
        )
        dimension = SEODimensionScore(
            name="测试维度",
            score=10.0,
            max_score=20.0,
            items=[failing, warned, passing],
            status=DiagnosisStatus.WARNING,
        )
        diagnosis = SEODiagnosisResult(
            overall_score=60.0,
            dimensions=[dimension],
            recommendations=[],
        )

        advice = generate_recommendations(diagnosis)

        assert len(advice) == 2
        assert advice[0].priority == "high"
        assert advice[1].priority == "medium"

    def test_recommendations_sorted_by_priority(self):
        """The high-priority recommendation precedes the medium one."""
        warned = DiagnosisItem(
            name="警告项",
            status=DiagnosisStatus.WARNING,
            description="",
            suggestion="",
        )
        # The warning item sits in the first dimension and the failing item in
        # the second, so input order alone would put "medium" first.
        first_dimension = SEODimensionScore(
            name="维度1",
            score=5.0,
            max_score=10.0,
            items=[warned],
            status=DiagnosisStatus.WARNING,
        )
        failing = DiagnosisItem(
            name="失败项",
            status=DiagnosisStatus.FAIL,
            description="",
            suggestion="",
        )
        second_dimension = SEODimensionScore(
            name="维度2",
            score=5.0,
            max_score=10.0,
            items=[failing],
            status=DiagnosisStatus.WARNING,
        )
        diagnosis = SEODiagnosisResult(
            overall_score=50.0,
            dimensions=[first_dimension, second_dimension],
            recommendations=[],
        )

        advice = generate_recommendations(diagnosis)

        assert advice[0].priority == "high"
        assert advice[1].priority == "medium"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSEODiagnosisService:
|
||||||
|
"""SEO诊断服务测试"""
|
||||||
|
|
||||||
|
@pytest.fixture
def service(self):
    """Build the SEODiagnosisService instance used by the tests in this class."""
    diagnosis_service = SEODiagnosisService()
    return diagnosis_service
|
||||||
|
|
||||||
|
def test_full_diagnosis_with_defaults(self, service):
    """A diagnosis with no arguments returns a well-formed result."""
    outcome = service.diagnose()

    assert isinstance(outcome, SEODiagnosisResult)
    assert 0 <= outcome.overall_score <= 100
    assert len(outcome.dimensions) == 5
    assert isinstance(outcome.recommendations, list)
|
||||||
|
|
||||||
|
def test_diagnosis_returns_all_dimensions(self, service):
    """A default diagnosis reports all five named dimensions."""
    outcome = service.diagnose()
    reported = [dimension.name for dimension in outcome.dimensions]

    for expected in ("技术SEO", "页面SEO", "内容质量", "外链分析", "用户体验"):
        assert expected in reported
|
||||||
|
|
||||||
|
def test_diagnosis_with_custom_data(self, service):
    """Unindexed technical data with crawl errors keeps the score below 100."""
    unhealthy = TechnicalSEOData(is_indexed=False, crawl_errors=10)
    outcome = service.diagnose(technical_data=unhealthy)

    assert outcome.overall_score < 100
|
||||||
|
|
||||||
|
def test_diagnose_technical_only(self, service):
    """diagnose_technical_only returns the technical SEO dimension score."""
    dimension = service.diagnose_technical_only()

    assert isinstance(dimension, SEODimensionScore)
    assert dimension.name == DimensionName.TECHNICAL_SEO
|
||||||
|
|
||||||
|
def test_diagnose_on_page_only(self, service):
    """diagnose_on_page_only returns the on-page SEO dimension score."""
    dimension = service.diagnose_on_page_only()

    assert isinstance(dimension, SEODimensionScore)
    assert dimension.name == DimensionName.ON_PAGE_SEO
|
||||||
|
|
||||||
|
def test_diagnose_content_only(self, service):
|
||||||
|
"""测试仅内容质量诊断"""
|
||||||
|
result = service.diagnose_content_only()
|
||||||
|
assert isinstance(result, SEODimensionScore)
|
||||||
|
assert result.name == DimensionName.CONTENT_QUALITY
|
||||||
|
|
||||||
|
def test_diagnose_backlinks_only(self, service):
|
||||||
|
"""测试仅外链分析"""
|
||||||
|
result = service.diagnose_backlinks_only()
|
||||||
|
assert isinstance(result, SEODimensionScore)
|
||||||
|
assert result.name == DimensionName.BACKLINK_ANALYSIS
|
||||||
|
|
||||||
|
def test_diagnose_ux_only(self, service):
|
||||||
|
"""测试仅用户体验诊断"""
|
||||||
|
result = service.diagnose_ux_only()
|
||||||
|
assert isinstance(result, SEODimensionScore)
|
||||||
|
assert result.name == DimensionName.USER_EXPERIENCE
|
||||||
|
|
||||||
|
def test_diagnosis_with_poor_data(self, service):
|
||||||
|
"""测试使用差数据的诊断"""
|
||||||
|
technical_data = TechnicalSEOData(
|
||||||
|
is_indexed=False,
|
||||||
|
crawl_errors=20,
|
||||||
|
lcp_seconds=6.0,
|
||||||
|
fid_ms=500.0,
|
||||||
|
cls_score=0.4,
|
||||||
|
has_robots_txt=False,
|
||||||
|
has_sitemap=False,
|
||||||
|
)
|
||||||
|
on_page_data = OnPageSEOData(
|
||||||
|
has_title=False,
|
||||||
|
has_meta_description=False,
|
||||||
|
h1_count=0,
|
||||||
|
broken_internal_links=10,
|
||||||
|
)
|
||||||
|
content_data = ContentQualityData(
|
||||||
|
readability_score=30.0,
|
||||||
|
word_count=200,
|
||||||
|
last_updated_days=365,
|
||||||
|
duplicate_content_ratio=0.6,
|
||||||
|
)
|
||||||
|
backlink_data = BacklinkData(
|
||||||
|
referring_domains=2,
|
||||||
|
toxic_links=20,
|
||||||
|
anchor_text_diversity=0.2,
|
||||||
|
)
|
||||||
|
ux_data = UserExperienceData(
|
||||||
|
is_mobile_friendly=False,
|
||||||
|
page_load_time=8.0,
|
||||||
|
has_https=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = service.diagnose(
|
||||||
|
technical_data=technical_data,
|
||||||
|
on_page_data=on_page_data,
|
||||||
|
content_data=content_data,
|
||||||
|
backlink_data=backlink_data,
|
||||||
|
ux_data=ux_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.overall_score < 30
|
||||||
|
assert result.health_level == "danger"
|
||||||
|
assert len(result.recommendations) > 0
|
||||||
|
assert any(rec.priority == "high" for rec in result.recommendations)
|
||||||
|
|
||||||
|
def test_diagnosis_with_excellent_data(self, service):
|
||||||
|
"""测试使用优秀数据的诊断"""
|
||||||
|
technical_data = TechnicalSEOData(
|
||||||
|
is_indexed=True,
|
||||||
|
crawl_errors=0,
|
||||||
|
lcp_seconds=1.5,
|
||||||
|
fid_ms=50.0,
|
||||||
|
cls_score=0.03,
|
||||||
|
)
|
||||||
|
on_page_data = OnPageSEOData(
|
||||||
|
title_length=45,
|
||||||
|
meta_description_length=140,
|
||||||
|
keyword_density=2.0,
|
||||||
|
)
|
||||||
|
content_data = ContentQualityData(
|
||||||
|
readability_score=85.0,
|
||||||
|
word_count=2500,
|
||||||
|
topic_coverage=0.95,
|
||||||
|
has_expert_review=True,
|
||||||
|
last_updated_days=5,
|
||||||
|
)
|
||||||
|
backlink_data = BacklinkData(
|
||||||
|
referring_domains=50,
|
||||||
|
high_authority_links=20,
|
||||||
|
toxic_links=0,
|
||||||
|
anchor_text_diversity=0.9,
|
||||||
|
)
|
||||||
|
ux_data = UserExperienceData(
|
||||||
|
page_load_time=1.2,
|
||||||
|
form_usability=0.95,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = service.diagnose(
|
||||||
|
technical_data=technical_data,
|
||||||
|
on_page_data=on_page_data,
|
||||||
|
content_data=content_data,
|
||||||
|
backlink_data=backlink_data,
|
||||||
|
ux_data=ux_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.overall_score >= 80
|
||||||
|
assert result.health_level == "excellent"
|
||||||
|
|
||||||
|
def test_result_to_dict_format(self, service):
|
||||||
|
"""测试结果字典格式"""
|
||||||
|
result = service.diagnose()
|
||||||
|
d = result.to_dict()
|
||||||
|
|
||||||
|
assert "overall_score" in d
|
||||||
|
assert "health_level" in d
|
||||||
|
assert "health_level_label" in d
|
||||||
|
assert "dimensions" in d
|
||||||
|
assert "recommendations" in d
|
||||||
|
assert isinstance(d["dimensions"], list)
|
||||||
|
assert isinstance(d["recommendations"], list)
|
||||||
|
|
||||||
|
if d["dimensions"]:
|
||||||
|
dim = d["dimensions"][0]
|
||||||
|
assert "name" in dim
|
||||||
|
assert "score" in dim
|
||||||
|
assert "max_score" in dim
|
||||||
|
assert "percentage" in dim
|
||||||
|
assert "status" in dim
|
||||||
|
assert "items" in dim
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
version = 1
|
||||||
|
revision = 3
|
||||||
|
requires-python = ">=3.14"
|
||||||
|
|
@ -1,156 +0,0 @@
|
||||||
# GEO 平台 - 项目总览
|
|
||||||
|
|
||||||
## 项目定位
|
|
||||||
|
|
||||||
**GEO(Generative Engine Optimization)平台** 是一款面向企业级客户的 **AI 搜索引擎品牌曝光度优化 SaaS 平台**。
|
|
||||||
|
|
||||||
随着生成式 AI 搜索引擎(如 ChatGPT、Perplexity、Google SGE、Kimi 等)的崛起,传统 SEO 已无法满足品牌在 AI 生成答案中的可见性需求。GEO 平台通过系统化的诊断、策略制定、内容生产、分发执行和监测优化,帮助企业在 AI 搜索引擎中获得更高的品牌引用率和曝光度。
|
|
||||||
|
|
||||||
## 核心业务价值
|
|
||||||
|
|
||||||
- **AI 引用检测**:自动扫描主流 AI 平台对品牌的引用情况,识别引用、未引用和竞品引用场景
|
|
||||||
- **策略智能生成**:基于诊断结果,自动生成 GEO 优化策略和内容生产计划
|
|
||||||
- **内容自动化生产**:利用 AI Agent 框架自动生成符合 GEO 优化标准的内容资产
|
|
||||||
- **多渠道分发执行**:将内容分发至目标平台,并跟踪分发效果
|
|
||||||
- **持续监测优化**:建立监测闭环,持续追踪品牌在 AI 搜索引擎中的表现变化
|
|
||||||
|
|
||||||
## 业务生命周期(5 个阶段)
|
|
||||||
|
|
||||||
GEO 平台的业务运营遵循完整的 5 阶段生命周期:
|
|
||||||
|
|
||||||
| 阶段 | 名称 | 核心目标 | 关键动作 |
|
|
||||||
|------|------|----------|----------|
|
|
||||||
| Stage 1 | 诊断分析 | 了解品牌在 AI 搜索中的现状 | 查询执行、引用检测、竞品分析 |
|
|
||||||
| Stage 2 | 策略制定 | 制定 GEO 优化策略 | 策略生成、规则制定、计划排期 |
|
|
||||||
| Stage 3 | 内容生产 | 生成符合 GEO 标准的内容 | 内容生成、质量检查、素材准备 |
|
|
||||||
| Stage 4 | 分发执行 | 将内容分发至目标渠道 | 渠道管理、内容发布、效果追踪 |
|
|
||||||
| Stage 5 | 监测优化 | 持续监测并优化效果 | 性能追踪、报告生成、策略迭代 |
|
|
||||||
|
|
||||||
## 双运营模式
|
|
||||||
|
|
||||||
GEO 平台支持两种运营模式,满足不同客户的需求:
|
|
||||||
|
|
||||||
### 模式一:自主订阅(SaaS 自助)
|
|
||||||
|
|
||||||
- **目标客户**:中小企业、有自有运营团队的品牌方
|
|
||||||
- **核心特征**:
|
|
||||||
- 客户自主注册订阅
|
|
||||||
- 通过平台自助完成全生命周期操作
|
|
||||||
- 按订阅等级(Basic / Pro / Enterprise)享受不同功能权限
|
|
||||||
- 支持多用户团队协作
|
|
||||||
- **收费模式**:月度/年度订阅制
|
|
||||||
|
|
||||||
### 模式二:代理运营(全托管服务)
|
|
||||||
|
|
||||||
- **目标客户**:大型企业、需要专业团队代运营的品牌方
|
|
||||||
- **核心特征**:
|
|
||||||
- 由专业运营团队代替客户执行 GEO 优化
|
|
||||||
- 客户通过代理模式接入,无需直接操作系统
|
|
||||||
- 运营团队使用平台全部功能为客户服务
|
|
||||||
- 提供定制化策略和专属服务
|
|
||||||
- **收费模式**:项目制或月度服务费
|
|
||||||
|
|
||||||
### 双模式权限设计
|
|
||||||
|
|
||||||
| 功能模块 | 自主订阅 | 代理运营 |
|
|
||||||
|----------|----------|----------|
|
|
||||||
| 查询管理 | 有(按订阅等级限额) | 有(无限制) |
|
|
||||||
| 引用检测 | 有 | 有 |
|
|
||||||
| 策略生成 | 有 | 有 |
|
|
||||||
| 内容生产 | 有(AI 生成限额) | 有 |
|
|
||||||
| 分发执行 | 有 | 有 |
|
|
||||||
| 监测报告 | 有 | 有 |
|
|
||||||
| 多用户管理 | Enterprise 支持 | 支持 |
|
|
||||||
| 代理客户管理 | 无 | 有 |
|
|
||||||
| 白标报告 | 无 | 有 |
|
|
||||||
|
|
||||||
## 技术栈概要
|
|
||||||
|
|
||||||
### 前端
|
|
||||||
|
|
||||||
- **框架**:Next.js 15 + React 19 + TypeScript
|
|
||||||
- **样式**:Tailwind CSS 4 + shadcn/ui
|
|
||||||
- **状态管理**:React Server Components + SWR
|
|
||||||
- **认证**:NextAuth.js v5
|
|
||||||
- **图表**:Recharts
|
|
||||||
|
|
||||||
### 后端
|
|
||||||
|
|
||||||
- **框架**:FastAPI + Python 3.12
|
|
||||||
- **数据库**:PostgreSQL + SQLAlchemy 2.0
|
|
||||||
- **ORM**:SQLAlchemy 2.0 + Alembic 迁移
|
|
||||||
- **异步**:Celery + Redis 任务队列
|
|
||||||
- **认证**:JWT + OAuth2
|
|
||||||
|
|
||||||
### AI Agent 框架
|
|
||||||
|
|
||||||
- **架构**:模块化 Agent 注册机制
|
|
||||||
- **通信**:基于消息队列的异步通信协议
|
|
||||||
- **核心 Agent**:CitationDetector / ContentGenerator / RuleChecker / CompetitorAnalyzer / PerformanceTracker
|
|
||||||
|
|
||||||
### 基础设施
|
|
||||||
|
|
||||||
- **容器化**:Docker + Docker Compose
|
|
||||||
- **部署**:云服务器 + Nginx 反向代理
|
|
||||||
- **监控**:日志系统 + 健康检查
|
|
||||||
|
|
||||||
## 快速开始指引
|
|
||||||
|
|
||||||
### 环境准备
|
|
||||||
|
|
||||||
1. 确保已安装 Docker 和 Docker Compose
|
|
||||||
2. 克隆项目仓库到本地
|
|
||||||
3. 复制 `.env.example` 为 `.env` 并配置必要的环境变量
|
|
||||||
|
|
||||||
### 启动开发环境
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 启动全部服务(前端 + 后端 + 数据库 + Redis)
|
|
||||||
docker-compose up -d
|
|
||||||
|
|
||||||
# 后端数据库迁移
|
|
||||||
cd backend
|
|
||||||
alembic upgrade head
|
|
||||||
|
|
||||||
# 启动前端开发服务器
|
|
||||||
cd ../frontend
|
|
||||||
npm run dev
|
|
||||||
```
|
|
||||||
|
|
||||||
### 默认访问地址
|
|
||||||
|
|
||||||
- 前端应用:`http://localhost:3000`
|
|
||||||
- 后端 API:`http://localhost:8000`
|
|
||||||
- API 文档:`http://localhost:8000/docs`
|
|
||||||
|
|
||||||
### 初始账号
|
|
||||||
|
|
||||||
系统初始化后,可通过注册功能创建首个管理员账号,或联系运维团队获取默认管理员凭据。
|
|
||||||
|
|
||||||
## 文档导航
|
|
||||||
|
|
||||||
| 目录 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| `00-project/` | 项目概述、架构设计、技术栈说明 |
|
|
||||||
| `01-requirements/` | 业务需求、功能清单、用户故事 |
|
|
||||||
| `02-design/` | UI 设计、数据库设计、Agent 框架设计 |
|
|
||||||
| `03-development/` | 开发规范、TDD 流程、模块指南 |
|
|
||||||
| `04-testing/` | 测试策略、测试计划、测试报告 |
|
|
||||||
| `05-deployment/` | 部署指南、Docker 配置、环境配置 |
|
|
||||||
| `06-progress/` | 迭代计划、任务看板、周报 |
|
|
||||||
|
|
||||||
## 项目里程碑
|
|
||||||
|
|
||||||
| 阶段 | 目标 | 状态 |
|
|
||||||
|------|------|------|
|
|
||||||
| Phase 0 | 基础设施搭建 + 文档体系建立 | 进行中 |
|
|
||||||
| Phase 1 | Stage 1 诊断分析(查询+检测)| 规划中 |
|
|
||||||
| Phase 2 | Stage 2 策略制定 + Stage 3 内容生产 | 规划中 |
|
|
||||||
| Phase 3 | Stage 4 分发执行 + Stage 5 监测优化 | 规划中 |
|
|
||||||
| Phase 4 | 代理运营模式 + 管理后台 | 规划中 |
|
|
||||||
| Phase 5 | 性能优化 + 稳定性提升 | 规划中 |
|
|
||||||
| Phase 6 | 生产环境部署 + 正式发布 | 规划中 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
*GEO 平台 - 让品牌在 AI 时代被看见。*
|
|
||||||
|
|
@ -1,336 +0,0 @@
|
||||||
# GEO 平台 - 系统架构设计
|
|
||||||
|
|
||||||
## 整体架构
|
|
||||||
|
|
||||||
GEO 平台采用分层架构设计,由以下四层组成:
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────────────────────────────────────────────────────────┐
|
|
||||||
│ Presentation Layer │
|
|
||||||
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────────┐ │
|
|
||||||
│ │ Next.js │ │ React 19 │ │ Tailwind CSS + shadcn │ │
|
|
||||||
│ │ App Router │ │ Components │ │ UI Components │ │
|
|
||||||
│ └─────────────┘ └─────────────┘ └─────────────────────────┘ │
|
|
||||||
├─────────────────────────────────────────────────────────────────┤
|
|
||||||
│ API Gateway Layer │
|
|
||||||
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────────┐ │
|
|
||||||
│ │ FastAPI │ │ REST API │ │ JWT / OAuth2 │ │
|
|
||||||
│ │ Routers │ │ Endpoints │ │ Authentication │ │
|
|
||||||
│ └─────────────┘ └─────────────┘ └─────────────────────────┘ │
|
|
||||||
├─────────────────────────────────────────────────────────────────┤
|
|
||||||
│ Service & Agent Layer │
|
|
||||||
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────────────┐ │
|
|
||||||
│ │ Citation │ │ Content │ │ Rule │ │ Competitor │ │
|
|
||||||
│ │ Detector │ │Generator │ │ Checker │ │ Analyzer │ │
|
|
||||||
│ └──────────┘ └──────────┘ └──────────┘ └──────────────────┘ │
|
|
||||||
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────────────┐ │
|
|
||||||
│ │Performance│ │ Task │ │ Query │ │ Report │ │
|
|
||||||
│ │ Tracker │ │ Scheduler│ │ Engine │ │ Generator │ │
|
|
||||||
│ └──────────┘ └──────────┘ └──────────┘ └──────────────────┘ │
|
|
||||||
├─────────────────────────────────────────────────────────────────┤
|
|
||||||
│ Infrastructure Layer │
|
|
||||||
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌────────────────┐ │
|
|
||||||
│ │PostgreSQL│ │ Redis │ │ Celery │ │ Docker │ │
|
|
||||||
│ │ (Data) │ │ (Cache) │ │ (Queue) │ │ (Container) │ │
|
|
||||||
│ └──────────┘ └──────────┘ └──────────┘ └────────────────┘ │
|
|
||||||
└─────────────────────────────────────────────────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### 四层架构说明
|
|
||||||
|
|
||||||
#### 1. 表现层(Presentation Layer)
|
|
||||||
|
|
||||||
- **Next.js App Router**:基于文件系统的路由,支持服务端渲染(SSR)和静态生成(SSG)
|
|
||||||
- **React 19**:最新 React 特性,包括 Server Components、Actions、Form 状态管理
|
|
||||||
- **Tailwind CSS + shadcn/ui**:原子化 CSS + 可复用无头组件库,支持快速构建一致化 UI
|
|
||||||
|
|
||||||
**关键设计决策**:
|
|
||||||
- 采用 BFF(Backend for Frontend)模式,前端通过 API Client 与后端通信
|
|
||||||
- 认证状态使用 NextAuth.js 管理,支持 Credentials + OAuth 双模式
|
|
||||||
- 数据获取优先使用 Server Components 减少客户端 JavaScript 体积
|
|
||||||
|
|
||||||
#### 2. API 网关层(API Gateway Layer)
|
|
||||||
|
|
||||||
- **FastAPI**:高性能 Python Web 框架,原生支持异步和自动 API 文档
|
|
||||||
- **RESTful API**:标准的 REST 接口设计,返回 JSON 格式数据
|
|
||||||
- **认证授权**:JWT Token + OAuth2 密码流,支持角色权限控制
|
|
||||||
|
|
||||||
**模块划分**:
|
|
||||||
| 模块 | 职责 | 对应文件 |
|
|
||||||
|------|------|----------|
|
|
||||||
| Auth API | 用户注册、登录、Token 刷新、权限验证 | `api/auth.py` |
|
|
||||||
| Query API | 查询创建、执行、结果获取 | `api/queries.py` |
|
|
||||||
| Citation API | 引用记录管理、统计分析 | `api/citations.py` |
|
|
||||||
| Report API | 报告生成、导出、查看 | `api/reports.py` |
|
|
||||||
| Admin API | 用户管理、系统配置、订阅管理 | `api/admin.py` |
|
|
||||||
| Subscription API | 订阅计划、支付、限额管理 | `api/subscriptions.py` |
|
|
||||||
|
|
||||||
#### 3. 服务与 Agent 层(Service & Agent Layer)
|
|
||||||
|
|
||||||
本层是 GEO 平台的核心业务逻辑层,包含传统 Service 和 AI Agent 两大类组件。
|
|
||||||
|
|
||||||
**传统 Service**:
|
|
||||||
| Service | 职责 |
|
|
||||||
|---------|------|
|
|
||||||
| QueryService | 查询生命周期管理、AI 平台调用、结果聚合 |
|
|
||||||
| CitationService | 引用检测算法执行、引用记录存储、置信度评估 |
|
|
||||||
| ReportService | 报告模板渲染、数据聚合、多种格式导出 |
|
|
||||||
| UserService | 用户 CRUD、角色权限、订阅状态管理 |
|
|
||||||
| SubscriptionService | 订阅计划管理、配额控制、升级降级 |
|
|
||||||
|
|
||||||
**AI Agent 集群**:详见 `docs/02-design/agent-framework.md`
|
|
||||||
|
|
||||||
#### 4. 基础设施层(Infrastructure Layer)
|
|
||||||
|
|
||||||
| 组件 | 用途 | 技术选型 |
|
|
||||||
|------|------|----------|
|
|
||||||
| PostgreSQL | 主数据库,存储业务数据 | PostgreSQL 15+ |
|
|
||||||
| Redis | 缓存、Celery Broker、会话存储 | Redis 7+ |
|
|
||||||
| Celery | 异步任务队列、定时任务调度 | Celery 5+ |
|
|
||||||
| Docker | 应用容器化部署 | Docker + Compose |
|
|
||||||
|
|
||||||
## 模块解耦原则
|
|
||||||
|
|
||||||
### 1. 领域驱动设计(DDD)
|
|
||||||
|
|
||||||
- **领域模型独立**:核心业务模型(Query、Citation、User、Subscription)不依赖任何框架
|
|
||||||
- **边界上下文清晰**:每个模块维护自己的数据模型和业务规则
|
|
||||||
- **防腐层(ACL)**:外部 AI 平台适配器通过防腐层与核心领域隔离
|
|
||||||
|
|
||||||
### 2. 依赖倒置原则
|
|
||||||
|
|
||||||
```
|
|
||||||
API Layer → Service Layer → Repository Layer → Database
|
|
||||||
↑ ↑ ↑
|
|
||||||
依赖接口 依赖接口 依赖接口
|
|
||||||
```
|
|
||||||
|
|
||||||
- Service 层定义接口,Repository 层实现接口
|
|
||||||
- 便于单元测试时 Mock 依赖
|
|
||||||
- 支持未来数据库迁移或替换
|
|
||||||
|
|
||||||
### 3. 配置与代码分离
|
|
||||||
|
|
||||||
- 所有环境相关配置通过环境变量注入
|
|
||||||
- 应用配置集中管理在 `app/config.py`
|
|
||||||
- 不同环境(dev / staging / prod)使用不同的 `.env` 文件
|
|
||||||
|
|
||||||
### 4. 异步解耦
|
|
||||||
|
|
||||||
- 耗时操作(AI 查询执行、报告生成)通过 Celery 异步处理
|
|
||||||
- 前端通过轮询或 WebSocket 获取任务状态更新
|
|
||||||
- 任务执行状态持久化,支持断点续传和失败重试
|
|
||||||
|
|
||||||
## AI Agent 解耦通信协议设计
|
|
||||||
|
|
||||||
### 通信模型
|
|
||||||
|
|
||||||
GEO 平台的 AI Agent 之间采用 **基于消息队列的异步通信协议**,确保 Agent 之间松耦合、可独立部署和扩展。
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────┐ ┌─────────────┐ ┌─────────────────┐
|
|
||||||
│ Agent A │────▶│ Message │────▶│ Agent B │
|
|
||||||
│ (生产者) │ │ Queue │ │ (消费者) │
|
|
||||||
└─────────────┘ └─────────────┘ └─────────────────┘
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
┌─────────────┐
|
|
||||||
│ Registry │
|
|
||||||
│ (注册中心) │
|
|
||||||
└─────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### 消息格式
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"message_id": "uuid",
|
|
||||||
"timestamp": "2026-01-01T00:00:00Z",
|
|
||||||
"sender": "agent_name",
|
|
||||||
"recipient": "agent_name_or_broadcast",
|
|
||||||
"message_type": "task_request | task_result | status_update | heartbeat",
|
|
||||||
"payload": {
|
|
||||||
"task_id": "uuid",
|
|
||||||
"task_type": "detect_citations | generate_content | check_rules | analyze_competitors | track_performance",
|
|
||||||
"parameters": {},
|
|
||||||
"data": {},
|
|
||||||
"priority": 1
|
|
||||||
},
|
|
||||||
"correlation_id": "uuid",
|
|
||||||
"reply_to": "agent_name"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 通信模式
|
|
||||||
|
|
||||||
| 模式 | 说明 | 适用场景 |
|
|
||||||
|------|------|----------|
|
|
||||||
| 点对点 | Agent A 直接发送任务给 Agent B | 明确的一对一任务分配 |
|
|
||||||
| 广播 | Agent 向所有订阅者发送消息 | 状态更新、配置变更通知 |
|
|
||||||
| 发布/订阅 | Agent 向特定主题发布消息 | 事件驱动架构 |
|
|
||||||
| 请求/响应 | 同步等待返回结果 | 需要即时反馈的短任务 |
|
|
||||||
|
|
||||||
### Agent 注册与发现
|
|
||||||
|
|
||||||
- **注册中心**:Redis 作为轻量级注册中心
|
|
||||||
- **注册机制**:Agent 启动时向注册中心注册自己的能力和状态
|
|
||||||
- **健康检查**:定期发送心跳,失效 Agent 自动从注册中心移除
|
|
||||||
- **负载均衡**:任务分发时根据 Agent 负载状态选择执行节点
|
|
||||||
|
|
||||||
### 错误处理与重试
|
|
||||||
|
|
||||||
- 消息消费失败且重试次数耗尽后,自动进入死信队列(DLQ)
|
|
||||||
- 支持配置重试次数和重试间隔(指数退避)
|
|
||||||
- 失败任务可手动重新触发或自动补偿
|
|
||||||
|
|
||||||
## 双模式权限设计
|
|
||||||
|
|
||||||
### 用户角色模型
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────┐
|
|
||||||
│ User │
|
|
||||||
└──────┬──────┘
|
|
||||||
│
|
|
||||||
┌───────────────┼───────────────┐
|
|
||||||
▼ ▼ ▼
|
|
||||||
┌────────────┐ ┌────────────┐ ┌────────────┐
|
|
||||||
│ EndUser │ │ Admin │ │ Agent │
|
|
||||||
│ (终端用户) │ │ (管理员) │ │ (运营人员) │
|
|
||||||
└─────┬──────┘ └────────────┘ └─────┬──────┘
|
|
||||||
│ │
|
|
||||||
▼ ▼
|
|
||||||
┌──────────────┐ ┌──────────────┐
|
|
||||||
│ Self-Service │ │ Proxy-Service│
|
|
||||||
│ 自主订阅 │ │ 代理运营 │
|
|
||||||
└──────────────┘ └──────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### 权限矩阵
|
|
||||||
|
|
||||||
| 功能 | EndUser (Basic) | EndUser (Pro) | EndUser (Enterprise) | Admin | Agent |
|
|
||||||
|------|----------------|---------------|---------------------|-------|-------|
|
|
||||||
| 创建查询 | 10/月 | 50/月 | 无限制 | 无限制 | 无限制 |
|
|
||||||
| 引用检测 | 有 | 有 | 有 | 有 | 有 |
|
|
||||||
| 策略生成 | 基础 | 高级 | 定制化 | 全部 | 全部 |
|
|
||||||
| 内容生成 | 5/月 | 20/月 | 无限制 | 无限制 | 无限制 |
|
|
||||||
| 报告导出 | PDF | PDF + Excel | 全部格式 | 全部格式 | 全部格式 |
|
|
||||||
| 团队管理 | 无 | 无 | 有 | 有 | 有 |
|
|
||||||
| 代理客户管理 | 无 | 无 | 无 | 有 | 有 |
|
|
||||||
| 白标报告 | 无 | 无 | 无 | 有 | 有 |
|
|
||||||
| 系统配置 | 无 | 无 | 无 | 有 | 无 |
|
|
||||||
| 用户管理 | 无 | 无 | 无 | 有 | 无 |
|
|
||||||
|
|
||||||
### 权限实现
|
|
||||||
|
|
||||||
- **基于角色的访问控制(RBAC)**:用户通过角色获得一组权限
|
|
||||||
- **基于属性的访问控制(ABAC)**:结合用户订阅等级、团队归属等属性进行细粒度控制
|
|
||||||
- **中间件校验**:FastAPI 依赖注入实现统一的权限校验中间件
|
|
||||||
- **前端联动**:前端路由和组件根据用户权限动态渲染
|
|
||||||
|
|
||||||
### API 权限控制示例
|
|
||||||
|
|
||||||
```python
|
|
||||||
# FastAPI 依赖注入实现权限控制
|
|
||||||
async def require_admin(current_user: User = Depends(get_current_user)):
|
|
||||||
if not current_user.is_admin:
|
|
||||||
raise HTTPException(status_code=403, detail="需要管理员权限")
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
def require_subscription(min_tier: str):
|
|
||||||
async def checker(current_user: User = Depends(get_current_user)):
|
|
||||||
if not current_user.subscription.meets(min_tier):
|
|
||||||
raise HTTPException(status_code=403, detail="订阅等级不足")
|
|
||||||
return current_user
|
|
||||||
return checker
|
|
||||||
|
|
||||||
# 路由使用
|
|
||||||
@router.post("/queries")
|
|
||||||
async def create_query(
|
|
||||||
data: QueryCreate,
|
|
||||||
user: User = Depends(require_subscription("pro"))
|
|
||||||
):
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
## 数据流架构
|
|
||||||
|
|
||||||
### 查询执行数据流
|
|
||||||
|
|
||||||
```
|
|
||||||
用户创建查询 ──▶ API 接收 ──▶ 验证权限/配额 ──▶ 创建查询记录
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
提交 Celery 任务
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
┌─────────────────┐
|
|
||||||
│ Worker 执行查询 │
|
|
||||||
│ - 调用 AI 平台 │
|
|
||||||
│ - 聚合响应结果 │
|
|
||||||
└────────┬────────┘
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
┌─────────────────┐
|
|
||||||
│ CitationDetector │
|
|
||||||
│ - 解析引用内容 │
|
|
||||||
│ - 评估置信度 │
|
|
||||||
└────────┬────────┘
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
保存引用记录 ──▶ 更新查询状态
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
通知前端完成
|
|
||||||
```
|
|
||||||
|
|
||||||
### 报告生成数据流
|
|
||||||
|
|
||||||
```
|
|
||||||
用户请求报告 ──▶ API 接收 ──▶ 收集相关数据
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
ReportService 聚合数据
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
应用报告模板
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
生成 PDF/Excel/HTML
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
返回下载链接
|
|
||||||
```
|
|
||||||
|
|
||||||
## 扩展性设计
|
|
||||||
|
|
||||||
### 水平扩展
|
|
||||||
|
|
||||||
- **无状态服务**:API 服务无状态设计,可通过负载均衡水平扩展
|
|
||||||
- **数据库读写分离**:未来支持主从复制,读操作分散到从库
|
|
||||||
- **缓存层**:Redis 缓存热点数据,减少数据库压力
|
|
||||||
|
|
||||||
### AI 平台适配器扩展
|
|
||||||
|
|
||||||
GEO 平台需要对接多种 AI 搜索引擎平台,适配器设计遵循开闭原则:
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────────────────────────────────┐
|
|
||||||
│ PlatformAdapter (抽象基类) │
|
|
||||||
│ - execute_query() │
|
|
||||||
│ - parse_response() │
|
|
||||||
│ - validate_config() │
|
|
||||||
└─────────────────────────────────────────┘
|
|
||||||
│ │ │
|
|
||||||
▼ ▼ ▼
|
|
||||||
┌──────────┐ ┌──────────┐ ┌──────────┐
|
|
||||||
│ ChatGPT │ │ Kimi │ │ Perplex │
|
|
||||||
│ Adapter │ │ Adapter │ │ Adapter │
|
|
||||||
└──────────┘ └──────────┘ └──────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
新增 AI 平台只需实现 `PlatformAdapter` 接口,无需修改现有代码。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
*本文档描述 GEO 平台的整体系统架构设计,详细模块设计请参考各子系统设计文档。*
|
|
||||||
|
|
@ -1,44 +0,0 @@
|
||||||
# GEO 平台 - 更新日志
|
|
||||||
|
|
||||||
## 概述
|
|
||||||
|
|
||||||
本文档记录 GEO 平台的所有版本更新内容,按时间倒序排列。
|
|
||||||
|
|
||||||
> **TODO**: 本文档为占位文件,待补充完整内容。
|
|
||||||
|
|
||||||
## 版本规范
|
|
||||||
|
|
||||||
采用 [语义化版本](https://semver.org/lang/zh-CN/) 规范:`主版本号.次版本号.修订号`
|
|
||||||
|
|
||||||
- **主版本号**:不兼容的 API 修改
|
|
||||||
- **次版本号**:向下兼容的功能性新增
|
|
||||||
- **修订号**:向下兼容的问题修正
|
|
||||||
|
|
||||||
## 更新记录
|
|
||||||
|
|
||||||
### [Unreleased]
|
|
||||||
|
|
||||||
#### 新增
|
|
||||||
- [ ] 项目文档体系建立
|
|
||||||
- [ ] 基础架构搭建
|
|
||||||
|
|
||||||
#### 变更
|
|
||||||
- 无
|
|
||||||
|
|
||||||
#### 修复
|
|
||||||
- 无
|
|
||||||
|
|
||||||
### [0.1.0] - 待发布
|
|
||||||
|
|
||||||
#### 新增
|
|
||||||
- [ ] TODO: 待填充 Phase 1 功能清单
|
|
||||||
|
|
||||||
#### 变更
|
|
||||||
- 无
|
|
||||||
|
|
||||||
#### 修复
|
|
||||||
- 无
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
*本文档待补充,每次版本发布时更新。*
|
|
||||||
|
|
@ -1,70 +0,0 @@
|
||||||
# GEO 平台 - 技术栈说明
|
|
||||||
|
|
||||||
## 概述
|
|
||||||
|
|
||||||
本文档详细说明 GEO 平台采用的技术栈及各技术选型的理由。
|
|
||||||
|
|
||||||
> **TODO**: 本文档为占位文件,待补充完整内容。
|
|
||||||
|
|
||||||
## 前端技术栈
|
|
||||||
|
|
||||||
### 待补充内容
|
|
||||||
|
|
||||||
- [ ] Next.js 15 核心特性与选型理由
|
|
||||||
- [ ] React 19 新特性应用
|
|
||||||
- [ ] TypeScript 类型系统实践
|
|
||||||
- [ ] Tailwind CSS 4 原子化样式方案
|
|
||||||
- [ ] shadcn/ui 组件库定制
|
|
||||||
- [ ] NextAuth.js v5 认证方案
|
|
||||||
- [ ] SWR 数据获取策略
|
|
||||||
- [ ] Recharts 图表实现
|
|
||||||
|
|
||||||
## 后端技术栈
|
|
||||||
|
|
||||||
### 待补充内容
|
|
||||||
|
|
||||||
- [ ] FastAPI 异步框架特性
|
|
||||||
- [ ] Python 3.12 类型提示最佳实践
|
|
||||||
- [ ] SQLAlchemy 2.0 ORM 使用
|
|
||||||
- [ ] Alembic 数据库迁移管理
|
|
||||||
- [ ] Celery 异步任务队列
|
|
||||||
- [ ] Redis 缓存与消息代理
|
|
||||||
- [ ] JWT + OAuth2 认证实现
|
|
||||||
- [ ] Pydantic 数据校验
|
|
||||||
|
|
||||||
## AI Agent 技术栈
|
|
||||||
|
|
||||||
### 待补充内容
|
|
||||||
|
|
||||||
- [ ] Agent 框架核心库选型
|
|
||||||
- [ ] LLM 模型选型与调用
|
|
||||||
- [ ] 提示工程(Prompt Engineering)框架
|
|
||||||
- [ ] 向量数据库(可选)
|
|
||||||
- [ ] 模型评估与监控
|
|
||||||
|
|
||||||
## 基础设施技术栈
|
|
||||||
|
|
||||||
### 待补充内容
|
|
||||||
|
|
||||||
- [ ] Docker 容器化方案
|
|
||||||
- [ ] Docker Compose 服务编排
|
|
||||||
- [ ] PostgreSQL 数据库配置
|
|
||||||
- [ ] Redis 集群配置(生产环境)
|
|
||||||
- [ ] Nginx 反向代理与负载均衡
|
|
||||||
- [ ] 日志收集方案
|
|
||||||
- [ ] 监控告警系统
|
|
||||||
|
|
||||||
## 开发工具链
|
|
||||||
|
|
||||||
### 待补充内容
|
|
||||||
|
|
||||||
- [ ] 代码格式化(Black / Prettier)
|
|
||||||
- [ ] 代码检查(Ruff / ESLint)
|
|
||||||
- [ ] 类型检查(mypy)
|
|
||||||
- [ ] 测试框架(pytest / Jest / Playwright)
|
|
||||||
- [ ] Git 工作流与 Hook
|
|
||||||
- [ ] CI/CD 工具
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
*本文档待补充,请参考 `README.md` 中的技术栈概要获取基本信息。*
|
|
||||||
|
|
@ -1,318 +0,0 @@
|
||||||
# GEO 业务生命周期定义
|
|
||||||
|
|
||||||
## 概述
|
|
||||||
|
|
||||||
GEO(Generative Engine Optimization)业务的运营遵循一个完整的 5 阶段生命周期。每个阶段都有明确的核心动作、输入输出和对应的功能模块。本生命周期同时适用于**自主订阅**和**代理运营**两种模式。
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐
|
|
||||||
│ Stage 1 │───▶│ Stage 2 │───▶│ Stage 3 │───▶│ Stage 4 │───▶│ Stage 5 │
|
|
||||||
│ 诊断分析 │ │ 策略制定 │ │ 内容生产 │ │ 分发执行 │ │ 监测优化 │
|
|
||||||
└──────────┘ └──────────┘ └──────────┘ └──────────┘ └──────────┘
|
|
||||||
│ │ │ │ │
|
|
||||||
└───────────────┴───────────────┴───────────────┴───────────────┘
|
|
||||||
持续迭代优化
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stage 1:诊断分析
|
|
||||||
|
|
||||||
### 目标
|
|
||||||
|
|
||||||
全面了解品牌在主流 AI 搜索引擎中的现状,识别品牌被引用的情况、未被引用的机会点以及竞品的引用表现。
|
|
||||||
|
|
||||||
### 核心动作
|
|
||||||
|
|
||||||
| 动作 | 说明 | 执行方式 |
|
|
||||||
|------|------|----------|
|
|
||||||
| 品牌查询执行 | 向多个 AI 平台发送与品牌相关的查询 | 系统自动执行 |
|
|
||||||
| 响应结果收集 | 收集各 AI 平台返回的生成内容 | 系统自动收集 |
|
|
||||||
| 引用检测分析 | 分析生成内容中是否引用品牌、引用方式和位置 | AI Agent 自动分析 |
|
|
||||||
| 竞品对比分析 | 分析竞品在相同查询下的引用情况 | AI Agent 自动分析 |
|
|
||||||
| 诊断报告生成 | 汇总分析结果,生成诊断报告 | 系统自动生成 |
|
|
||||||
|
|
||||||
### 输入
|
|
||||||
|
|
||||||
- 品牌名称、品牌关键词
|
|
||||||
- 目标 AI 平台列表(ChatGPT、Kimi、Perplexity 等)
|
|
||||||
- 核心查询模板(行业问题、产品对比、品牌认知等)
|
|
||||||
- 主要竞品列表
|
|
||||||
|
|
||||||
### 输出
|
|
||||||
|
|
||||||
- 各平台查询响应原文
|
|
||||||
- 引用检测结果(引用/未引用/竞品引用)
|
|
||||||
- 引用置信度评分
|
|
||||||
- 竞品引用对比数据
|
|
||||||
- 诊断分析报告(PDF/HTML)
|
|
||||||
|
|
||||||
### 对应功能模块
|
|
||||||
|
|
||||||
- **查询管理系统**:查询创建、执行、状态跟踪
|
|
||||||
- **引用检测引擎**:AI 响应解析、引用识别、置信度评估
|
|
||||||
- **竞品分析模块**:竞品引用数据对比、SWOT 分析
|
|
||||||
- **诊断报告模块**:数据聚合、报告模板渲染、多格式导出
|
|
||||||
|
|
||||||
### 关键指标
|
|
||||||
|
|
||||||
| 指标 | 说明 | 计算方式 |
|
|
||||||
|------|------|----------|
|
|
||||||
| 引用率 | 品牌被引用的查询占比 | 被引用查询数 / 总查询数 |
|
|
||||||
| 平均置信度 | 引用检测的平均置信度分数 | 所有引用记录置信度之和 / 引用记录数 |
|
|
||||||
| 竞品引用率 | 竞品被引用的查询占比 | 竞品被引用查询数 / 总查询数 |
|
|
||||||
| 平台覆盖率 | 已检测的 AI 平台数量占比 | 已检测平台数 / 目标平台数 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stage 2:策略制定
|
|
||||||
|
|
||||||
### 目标
|
|
||||||
|
|
||||||
基于 Stage 1 的诊断结果,制定针对性的 GEO 优化策略,明确优化方向、内容生产计划和执行时间表。
|
|
||||||
|
|
||||||
### 核心动作
|
|
||||||
|
|
||||||
| 动作 | 说明 | 执行方式 |
|
|
||||||
|------|------|----------|
|
|
||||||
| 诊断结果解读 | 分析诊断报告,识别关键问题和机会点 | AI Agent + 人工确认 |
|
|
||||||
| 策略自动生成 | 基于诊断结果生成 GEO 优化策略 | AI Agent 自动生成 |
|
|
||||||
| 规则库制定 | 制定内容生成和优化的规则约束 | 系统模板 + 人工调整 |
|
|
||||||
| 内容计划排期 | 制定内容生产的主题、数量和排期 | 系统自动排期 |
|
|
||||||
| 目标设定 | 设定可量化的优化目标 | 人工设定 + 系统建议 |
|
|
||||||
|
|
||||||
### 输入
|
|
||||||
|
|
||||||
- Stage 1 诊断分析报告
|
|
||||||
- 品牌定位与核心信息
|
|
||||||
- 目标受众画像
|
|
||||||
- 预算与资源约束
|
|
||||||
- 行业基准数据
|
|
||||||
|
|
||||||
### 输出
|
|
||||||
|
|
||||||
- GEO 优化策略文档
|
|
||||||
- 内容生产计划表
|
|
||||||
- 规则库配置
|
|
||||||
- 优化目标与 KPI
|
|
||||||
- 策略执行时间表
|
|
||||||
|
|
||||||
### 对应功能模块
|
|
||||||
|
|
||||||
- **策略生成引擎**:基于诊断数据自动生成优化策略
|
|
||||||
- **规则管理系统**:规则创建、版本管理、生效控制
|
|
||||||
- **计划排期系统**:内容主题规划、发布时间排期
|
|
||||||
- **目标管理系统**:KPI 设定、目标追踪、偏差预警
|
|
||||||
|
|
||||||
### 策略类型
|
|
||||||
|
|
||||||
| 策略类型 | 适用场景 | 预期效果 |
|
|
||||||
|----------|----------|----------|
|
|
||||||
| 内容补充策略 | 品牌未被引用的查询场景 | 提升引用覆盖率 |
|
|
||||||
| 内容优化策略 | 品牌被引用但质量不高 | 提升引用质量和排名 |
|
|
||||||
| 竞品对抗策略 | 竞品引用率高于品牌 | 抢夺竞品引用份额 |
|
|
||||||
| 平台专攻策略 | 特定平台表现不佳 | 提升该平台表现 |
|
|
||||||
| 全面优化策略 | 整体表现均需提升 | 系统性提升各指标 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stage 3:内容生产
|
|
||||||
|
|
||||||
### 目标
|
|
||||||
|
|
||||||
根据 Stage 2 制定的策略和计划,生产符合 GEO 优化标准的高质量内容资产。
|
|
||||||
|
|
||||||
### 核心动作
|
|
||||||
|
|
||||||
| 动作 | 说明 | 执行方式 |
|
|
||||||
|------|------|----------|
|
|
||||||
| 内容主题生成 | 根据策略和计划生成具体的内容主题 | AI Agent 自动生成 |
|
|
||||||
| 内容草稿生成 | 生成符合 GEO 标准的文章/问答/说明 | AI Agent 自动生成 |
|
|
||||||
| 质量规则检查 | 检查内容是否符合规则库要求 | RuleChecker Agent |
|
|
||||||
| 人工审核确认 | 关键内容经人工审核后发布 | 人工操作 |
|
|
||||||
| 素材资源准备 | 准备配图、数据图表等辅助素材 | 系统生成 + 人工上传 |
|
|
||||||
|
|
||||||
### 输入
|
|
||||||
|
|
||||||
- Stage 2 策略文档和内容计划
|
|
||||||
- 规则库配置
|
|
||||||
- 品牌素材库(Logo、产品介绍、案例等)
|
|
||||||
- 行业知识和参考资料
|
|
||||||
|
|
||||||
### 输出
|
|
||||||
|
|
||||||
- GEO 优化内容资产(文章、问答、白皮书等)
|
|
||||||
- 内容质量评分
|
|
||||||
- 内容元数据(关键词、目标平台、适用查询等)
|
|
||||||
- 素材资源包
|
|
||||||
|
|
||||||
### 对应功能模块
|
|
||||||
|
|
||||||
- **内容生成引擎**:基于主题和规则自动生成内容
|
|
||||||
- **规则检查引擎**:内容质量校验、规则合规性检查
|
|
||||||
- **素材管理系统**:素材上传、管理、智能匹配
|
|
||||||
- **内容版本管理**:内容草稿、版本对比、发布控制
|
|
||||||
|
|
||||||
### 内容类型
|
|
||||||
|
|
||||||
| 内容类型 | 说明 | 适用平台 |
|
|
||||||
|----------|------|----------|
|
|
||||||
| 问答型内容 | 针对常见问题的结构化回答 | ChatGPT、Kimi |
|
|
||||||
| 文章型内容 | 深度行业文章或品牌故事 | Perplexity、搜索引擎 |
|
|
||||||
| 数据型内容 | 研究报告、数据统计 | 所有平台 |
|
|
||||||
| 对比型内容 | 产品与竞品的客观对比 | 所有平台 |
|
|
||||||
| 指南型内容 | 操作指南、使用教程 | 所有平台 |
|
|
||||||
|
|
||||||
### GEO 内容质量标准
|
|
||||||
|
|
||||||
- **权威性**:内容需体现品牌专业度和行业权威性
|
|
||||||
- **结构化**:使用清晰的标题层级和结构化格式
|
|
||||||
- **数据支撑**:包含统计数据、案例研究等可信来源
|
|
||||||
- **关键词优化**:自然融入目标关键词和长尾词
|
|
||||||
- **平台适配**:根据不同 AI 平台的偏好调整内容形式
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stage 4:分发执行
|
|
||||||
|
|
||||||
### 目标
|
|
||||||
|
|
||||||
将 Stage 3 生产的内容分发至目标渠道,执行 GEO 优化策略,并追踪分发效果。
|
|
||||||
|
|
||||||
### 核心动作
|
|
||||||
|
|
||||||
| 动作 | 说明 | 执行方式 |
|
|
||||||
|------|------|----------|
|
|
||||||
| 渠道配置管理 | 管理内容分发的目标渠道和平台 | 系统配置 |
|
|
||||||
| 内容发布执行 | 将内容发布至目标渠道 | 系统自动/半自动 |
|
|
||||||
| 分发效果追踪 | 追踪内容发布后的表现数据 | 系统自动追踪 |
|
|
||||||
| 链接建设 | 建立品牌内容与权威来源的链接关系 | 系统辅助 |
|
|
||||||
| 社媒同步 | 将内容同步至社交媒体扩大影响 | 系统对接 |
|
|
||||||
|
|
||||||
### 输入
|
|
||||||
|
|
||||||
- Stage 3 生产的内容资产
|
|
||||||
- 渠道配置信息
|
|
||||||
- 发布时间表
|
|
||||||
- 分发策略参数
|
|
||||||
|
|
||||||
### 输出
|
|
||||||
|
|
||||||
- 内容发布记录
|
|
||||||
- 分发效果数据(曝光、点击、引用)
|
|
||||||
- 渠道表现报告
|
|
||||||
- 链接建设报告
|
|
||||||
|
|
||||||
### 对应功能模块
|
|
||||||
|
|
||||||
- **渠道管理系统**:渠道配置、状态监控、接入管理
|
|
||||||
- **内容分发引擎**:发布执行、状态同步、失败重试
|
|
||||||
- **效果追踪系统**:UTM 追踪、回源分析、数据收集
|
|
||||||
- **链接管理系统**:链接生成、追踪、失效检测
|
|
||||||
|
|
||||||
### 分发渠道
|
|
||||||
|
|
||||||
| 渠道类型 | 具体渠道 | 内容形式 |
|
|
||||||
|----------|----------|----------|
|
|
||||||
| 自有媒体 | 官网博客、知识库、帮助中心 | 文章、指南 |
|
|
||||||
| 第三方平台 | 知乎、百家号、今日头条 | 文章、问答 |
|
|
||||||
| 社媒平台 | 微信公众号、LinkedIn、Twitter | 短文、图文 |
|
|
||||||
| 行业平台 | 行业论坛、垂直社区 | 专业文章 |
|
|
||||||
| 权威来源 | 维基百科、行业报告库 | 数据、引用源 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stage 5:监测优化
|
|
||||||
|
|
||||||
### 目标
|
|
||||||
|
|
||||||
持续监测品牌在 AI 搜索引擎中的表现变化,评估优化效果,并基于数据反馈迭代优化策略。
|
|
||||||
|
|
||||||
### 核心动作
|
|
||||||
|
|
||||||
| 动作 | 说明 | 执行方式 |
|
|
||||||
|------|------|----------|
|
|
||||||
| 定期查询执行 | 按设定周期重新执行品牌查询 | 系统自动执行 |
|
|
||||||
| 引用变化追踪 | 对比历史数据,追踪引用变化趋势 | 系统自动追踪 |
|
|
||||||
| 性能指标计算 | 计算引用率、排名、覆盖率等 KPI | 系统自动计算 |
|
|
||||||
| 效果报告生成 | 生成周期性效果报告 | 系统自动生成 |
|
|
||||||
| 策略迭代优化 | 基于效果数据调整优化策略 | AI Agent + 人工决策 |
|
|
||||||
|
|
||||||
### 输入
|
|
||||||
|
|
||||||
- Stage 1-4 的历史数据和执行记录
|
|
||||||
- 设定的 KPI 目标和基准值
|
|
||||||
- 定期查询执行结果
|
|
||||||
- 竞品最新表现数据
|
|
||||||
|
|
||||||
### 输出
|
|
||||||
|
|
||||||
- 实时监测仪表盘
|
|
||||||
- 周期性效果报告(周报/月报/季报)
|
|
||||||
- 趋势分析图表
|
|
||||||
- 策略调整建议
|
|
||||||
- 告警通知
|
|
||||||
|
|
||||||
### 对应功能模块
|
|
||||||
|
|
||||||
- **性能追踪系统**:KPI 计算、趋势分析、告警触发
|
|
||||||
- **报告生成系统**:报告模板、自动渲染、多格式导出
|
|
||||||
- **仪表盘系统**:可视化图表、实时数据、下钻分析
|
|
||||||
- **告警系统**:阈值设置、多渠道通知、告警升级
|
|
||||||
|
|
||||||
### 监测周期
|
|
||||||
|
|
||||||
| 周期 | 监测内容 | 报告类型 |
|
|
||||||
|------|----------|----------|
|
|
||||||
| 实时 | 查询执行状态、系统健康 | 系统监控 |
|
|
||||||
| 每日 | 引用变化、异常检测 | 日报 |
|
|
||||||
| 每周 | 综合表现、趋势分析 | 周报 |
|
|
||||||
| 每月 | 完整效果评估、策略建议 | 月报 |
|
|
||||||
| 每季 | 深度分析、竞品对比、策略迭代 | 季报 |
|
|
||||||
|
|
||||||
### 核心 KPI 体系
|
|
||||||
|
|
||||||
| KPI | 说明 | 目标值 |
|
|
||||||
|-----|------|--------|
|
|
||||||
| 引用率提升 | 品牌引用率的增长幅度 | 月度 ≥ 5% |
|
|
||||||
| 平均排名 | 品牌在引用中的平均位置 | 争取前 3 |
|
|
||||||
| 平台覆盖率 | 已优化平台的引用占比 | ≥ 80% |
|
|
||||||
| 内容转化率 | 内容分发后的引用转化 | ≥ 30% |
|
|
||||||
| 竞品相对优势 | 品牌引用率 vs 竞品引用率 | 保持领先 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 生命周期迭代闭环
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────────────────────────┐
|
|
||||||
│ Stage 5 监测优化 │
|
|
||||||
│ - 追踪效果、生成报告、策略迭代 │
|
|
||||||
└───────────────┬─────────────────┘
|
|
||||||
│
|
|
||||||
┌───────────────▼─────────────────┐
|
|
||||||
│ Stage 1 诊断分析 │
|
|
||||||
│ - 查询执行、引用检测、竞品分析 │
|
|
||||||
└───────────────┬─────────────────┘
|
|
||||||
│
|
|
||||||
┌───────────────▼─────────────────┐
|
|
||||||
│ Stage 2 策略制定 │
|
|
||||||
│ - 策略生成、规则制定、计划排期 │
|
|
||||||
└───────────────┬─────────────────┘
|
|
||||||
│
|
|
||||||
┌───────────────▼─────────────────┐
|
|
||||||
│ Stage 3 内容生产 │
|
|
||||||
│ - 内容生成、质量检查、素材准备 │
|
|
||||||
└───────────────┬─────────────────┘
|
|
||||||
│
|
|
||||||
┌───────────────▼─────────────────┐
|
|
||||||
│ Stage 4 分发执行 │
|
|
||||||
│ - 渠道管理、内容发布、效果追踪 │
|
|
||||||
└─────────────────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
GEO 业务生命周期是一个持续迭代的闭环系统。每次完成 Stage 5 后,诊断分析(Stage 1)将在新的基准上重新执行,形成数据驱动的持续优化循环。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
*本文档定义了 GEO 平台的完整业务生命周期,各阶段的详细功能设计请参考功能清单文档。*
|
|
||||||
|
|
@ -1,337 +0,0 @@
|
||||||
# GEO 平台 - 完整功能清单
|
|
||||||
|
|
||||||
## 概述
|
|
||||||
|
|
||||||
本文档按 GEO 业务生命周期的 5 个阶段(Stage 1-5)+ 通用模块,列出 GEO 平台的所有功能项。每个功能项包含功能名称、功能描述、适用模式和优先级。
|
|
||||||
|
|
||||||
**优先级说明**:
|
|
||||||
- **P0(Critical)**:MVP 核心功能,必须实现
|
|
||||||
- **P1(High)**:重要功能,尽快实现
|
|
||||||
- **P2(Medium)**:增强功能,后续迭代
|
|
||||||
- **P3(Low)**:优化功能,长期规划
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stage 1:诊断分析
|
|
||||||
|
|
||||||
### 1.1 查询管理
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S1-F01 | 查询模板管理 | 创建、编辑、删除品牌查询模板,支持变量替换 | 有 | 有 | P0 |
|
|
||||||
| S1-F02 | 批量查询创建 | 基于模板批量生成查询,支持 Excel/CSV 导入 | 有 | 有 | P0 |
|
|
||||||
| S1-F03 | 多平台查询执行 | 向多个 AI 平台(ChatGPT/Kimi/Perplexity)并行发送查询 | 有 | 有 | P0 |
|
|
||||||
| S1-F04 | 查询状态跟踪 | 实时跟踪查询执行状态(排队中/执行中/已完成/失败) | 有 | 有 | P0 |
|
|
||||||
| S1-F05 | 查询结果查看 | 查看各平台返回的原始响应内容 | 有 | 有 | P0 |
|
|
||||||
| S1-F06 | 查询历史管理 | 查看历史查询记录,支持筛选、搜索、分页 | 有 | 有 | P0 |
|
|
||||||
| S1-F07 | 查询重试机制 | 对失败的查询自动/手动重试 | 有 | 有 | P0 |
|
|
||||||
| S1-F08 | 查询调度配置 | 配置查询的执行时间和频率(即时/定时/周期) | 有 | 有 | P1 |
|
|
||||||
| S1-F09 | 查询结果导出 | 导出查询结果为 JSON/Excel/PDF 格式 | 有 | 有 | P1 |
|
|
||||||
| S1-F10 | 查询队列管理 | 查看和管理查询执行队列,支持优先级调整 | 无 | 有 | P2 |
|
|
||||||
|
|
||||||
### 1.2 引用检测
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S1-F11 | 引用自动识别 | 自动识别 AI 响应中是否包含品牌引用 | 有 | 有 | P0 |
|
|
||||||
| S1-F12 | 引用内容提取 | 提取引用片段的具体内容和上下文 | 有 | 有 | P0 |
|
|
||||||
| S1-F13 | 引用位置标记 | 标记引用在响应中的具体位置 | 有 | 有 | P0 |
|
|
||||||
| S1-F14 | 引用置信度评估 | 评估引用检测结果的置信度(高/中/低) | 有 | 有 | P0 |
|
|
||||||
| S1-F15 | 引用类型分类 | 分类引用类型(直接引用/间接引用/未引用/竞品引用) | 有 | 有 | P0 |
|
|
||||||
| S1-F16 | 引用记录管理 | 查看和管理所有引用检测记录 | 有 | 有 | P0 |
|
|
||||||
| S1-F17 | 引用统计分析 | 按平台、时间、类型等维度统计引用数据 | 有 | 有 | P1 |
|
|
||||||
| S1-F18 | 引用趋势分析 | 分析引用率随时间的变化趋势 | 有 | 有 | P1 |
|
|
||||||
| S1-F19 | 引用对比分析 | 对比不同查询或不同时间点的引用差异 | 有 | 有 | P1 |
|
|
||||||
| S1-F20 | 引用详情查看 | 查看单次引用的完整上下文和关联查询 | 有 | 有 | P0 |
|
|
||||||
|
|
||||||
### 1.3 竞品分析
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S1-F21 | 竞品列表管理 | 管理竞品品牌列表,支持增删改查 | 有 | 有 | P0 |
|
|
||||||
| S1-F22 | 竞品查询执行 | 对竞品执行相同的查询,获取引用数据 | 有 | 有 | P0 |
|
|
||||||
| S1-F23 | 竞品引用对比 | 对比品牌和竞品在相同查询下的引用表现 | 有 | 有 | P0 |
|
|
||||||
| S1-F24 | 竞品引用排名 | 生成品牌在竞品中的引用排名 | 有 | 有 | P1 |
|
|
||||||
| S1-F25 | 竞品 SWOT 分析 | 基于引用数据生成竞品的 SWOT 分析 | 有 | 有 | P1 |
|
|
||||||
| S1-F26 | 竞品动态监测 | 监测竞品引用表现的变化动态 | 有 | 有 | P2 |
|
|
||||||
| S1-F27 | 竞品策略推测 | 基于竞品引用模式推测其 GEO 策略 | 有 | 有 | P3 |
|
|
||||||
|
|
||||||
### 1.4 诊断报告
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S1-F28 | 诊断报告生成 | 基于查询和引用数据自动生成诊断报告 | 有 | 有 | P0 |
|
|
||||||
| S1-F29 | 报告模板管理 | 管理诊断报告的模板和样式 | 有 | 有 | P1 |
|
|
||||||
| S1-F30 | 报告多格式导出 | 支持 PDF/HTML/Excel 格式导出 | 有 | 有 | P0 |
|
|
||||||
| S1-F31 | 报告在线查看 | 在线查看诊断报告的渲染效果 | 有 | 有 | P0 |
|
|
||||||
| S1-F32 | 报告分享功能 | 生成分享链接或发送邮件分享报告 | 有 | 有 | P2 |
|
|
||||||
| S1-F33 | 白标报告(代理) | 生成去除 GEO 品牌标识的白标报告 | 无 | 有 | P1 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stage 2:策略制定
|
|
||||||
|
|
||||||
### 2.1 策略生成
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S2-F01 | 策略自动推荐 | 基于诊断结果自动推荐 GEO 优化策略 | 有 | 有 | P0 |
|
|
||||||
| S2-F02 | 策略编辑确认 | 编辑和确认 AI 生成的优化策略 | 有 | 有 | P0 |
|
|
||||||
| S2-F03 | 策略版本管理 | 管理策略的历史版本,支持对比和回滚 | 有 | 有 | P1 |
|
|
||||||
| S2-F04 | 策略库管理 | 管理策略模板库,支持复用和参考 | 有 | 有 | P1 |
|
|
||||||
| S2-F05 | 策略效果预测 | 基于历史数据预测策略预期效果 | 有 | 有 | P2 |
|
|
||||||
|
|
||||||
### 2.2 规则管理
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S2-F06 | 规则库创建 | 创建内容生成的规则约束(关键词、格式、风格等) | 有 | 有 | P0 |
|
|
||||||
| S2-F07 | 规则库编辑 | 编辑和修改已有规则库配置 | 有 | 有 | P0 |
|
|
||||||
| S2-F08 | 规则库版本 | 管理规则库的版本历史 | 有 | 有 | P1 |
|
|
||||||
| S2-F09 | 规则库启用/禁用 | 控制规则库的生效状态 | 有 | 有 | P0 |
|
|
||||||
| S2-F10 | 规则库导入/导出 | 支持规则库配置的导入和导出 | 有 | 有 | P2 |
|
|
||||||
|
|
||||||
### 2.3 计划排期
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S2-F11 | 内容主题规划 | 基于策略生成内容生产的主题列表 | 有 | 有 | P0 |
|
|
||||||
| S2-F12 | 排期日历视图 | 以日历形式展示内容发布计划 | 有 | 有 | P0 |
|
|
||||||
| S2-F13 | 排期调整 | 拖拽调整内容发布的时间和顺序 | 有 | 有 | P1 |
|
|
||||||
| S2-F14 | 排期冲突检测 | 自动检测排期中的时间和资源冲突 | 有 | 有 | P1 |
|
|
||||||
| S2-F15 | 计划与实际对比 | 对比计划排期和实际执行情况 | 有 | 有 | P2 |
|
|
||||||
|
|
||||||
### 2.4 目标管理
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S2-F16 | KPI 设定 | 设定可量化的优化目标(引用率、排名等) | 有 | 有 | P0 |
|
|
||||||
| S2-F17 | 目标追踪 | 实时追踪目标达成进度 | 有 | 有 | P1 |
|
|
||||||
| S2-F18 | 偏差预警 | 当实际表现偏离目标时触发预警 | 有 | 有 | P1 |
|
|
||||||
| S2-F19 | 目标调整 | 根据实际情况调整目标值 | 有 | 有 | P2 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stage 3:内容生产
|
|
||||||
|
|
||||||
### 3.1 内容生成
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S3-F01 | 内容自动生成 | 基于主题和规则自动生成 GEO 优化内容 | 有 | 有 | P0 |
|
|
||||||
| S3-F02 | 生成参数配置 | 配置内容生成的参数(长度、风格、语气等) | 有 | 有 | P0 |
|
|
||||||
| S3-F03 | 多版本生成 | 同一主题生成多个版本供选择 | 有 | 有 | P1 |
|
|
||||||
| S3-F04 | 内容编辑 | 对生成的内容进行在线编辑 | 有 | 有 | P0 |
|
|
||||||
| S3-F05 | 内容预览 | 预览内容的最终呈现效果 | 有 | 有 | P0 |
|
|
||||||
| S3-F06 | 内容保存草稿 | 将内容保存为草稿状态 | 有 | 有 | P0 |
|
|
||||||
| S3-F07 | 内容类型选择 | 选择生成内容的类型(文章/问答/指南等) | 有 | 有 | P0 |
|
|
||||||
| S3-F08 | 批量内容生成 | 批量生成多个主题的内容 | 有 | 有 | P1 |
|
|
||||||
|
|
||||||
### 3.2 质量检查
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S3-F09 | 规则合规检查 | 检查内容是否符合规则库要求 | 有 | 有 | P0 |
|
|
||||||
| S3-F10 | 质量评分 | 对内容进行质量评分(可读性、SEO、GEO 等维度) | 有 | 有 | P0 |
|
|
||||||
| S3-F11 | 问题高亮 | 高亮显示内容中的问题和改进建议 | 有 | 有 | P0 |
|
|
||||||
| S3-F12 | 一键修复 | 对检测到的问题提供一键修复建议 | 有 | 有 | P1 |
|
|
||||||
| S3-F13 | 人工审核标记 | 标记需要人工审核的内容 | 有 | 有 | P0 |
|
|
||||||
| S3-F14 | 审核工作流 | 定义内容审核的审批流程 | 无 | 有 | P1 |
|
|
||||||
|
|
||||||
### 3.3 素材管理
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S3-F15 | 素材上传 | 上传图片、视频、文档等素材 | 有 | 有 | P0 |
|
|
||||||
| S3-F16 | 素材库管理 | 管理素材的分类、标签、搜索 | 有 | 有 | P0 |
|
|
||||||
| S3-F17 | 智能配图 | 基于内容主题推荐或生成配图 | 有 | 有 | P2 |
|
|
||||||
| S3-F18 | 素材版权检查 | 检查素材的版权和使用权 | 有 | 有 | P3 |
|
|
||||||
|
|
||||||
### 3.4 内容版本
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S3-F19 | 版本历史 | 查看内容的所有历史版本 | 有 | 有 | P1 |
|
|
||||||
| S3-F20 | 版本对比 | 对比不同版本的内容差异 | 有 | 有 | P1 |
|
|
||||||
| S3-F21 | 版本回滚 | 恢复到内容的某个历史版本 | 有 | 有 | P1 |
|
|
||||||
| S3-F22 | 发布控制 | 控制内容的发布状态和可见性 | 有 | 有 | P0 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stage 4:分发执行
|
|
||||||
|
|
||||||
### 4.1 渠道管理
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S4-F01 | 渠道配置 | 配置内容分发的目标渠道信息 | 有 | 有 | P0 |
|
|
||||||
| S4-F02 | 渠道状态监控 | 监控各渠道的连接和可用状态 | 有 | 有 | P0 |
|
|
||||||
| S4-F03 | 渠道分组管理 | 对渠道进行分组管理 | 有 | 有 | P1 |
|
|
||||||
| S4-F04 | 渠道授权管理 | 管理渠道 API 授权和 Token | 有 | 有 | P0 |
|
|
||||||
| S4-F05 | 渠道接入扩展 | 支持添加新的分发渠道 | 有 | 有 | P2 |
|
|
||||||
|
|
||||||
### 4.2 内容发布
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S4-F06 | 单内容发布 | 将单条内容发布至指定渠道 | 有 | 有 | P0 |
|
|
||||||
| S4-F07 | 批量发布 | 批量发布多条内容至多个渠道 | 有 | 有 | P1 |
|
|
||||||
| S4-F08 | 定时发布 | 设定未来的发布时间自动执行 | 有 | 有 | P1 |
|
|
||||||
| S4-F09 | 发布预览 | 预览内容在各渠道的最终呈现 | 有 | 有 | P0 |
|
|
||||||
| S4-F10 | 发布撤回 | 撤回已发布的内容 | 有 | 有 | P1 |
|
|
||||||
| S4-F11 | 发布状态追踪 | 追踪内容发布的实时状态 | 有 | 有 | P0 |
|
|
||||||
|
|
||||||
### 4.3 效果追踪
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S4-F12 | UTM 参数生成 | 为分发内容生成 UTM 追踪参数 | 有 | 有 | P1 |
|
|
||||||
| S4-F13 | 点击数据收集 | 收集内容的点击和访问数据 | 有 | 有 | P1 |
|
|
||||||
| S4-F14 | 回源分析 | 分析内容分发后的回源和转化 | 有 | 有 | P2 |
|
|
||||||
| S4-F15 | 渠道效果对比 | 对比不同渠道的分发效果 | 有 | 有 | P1 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stage 5:监测优化
|
|
||||||
|
|
||||||
### 5.1 性能监测
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S5-F01 | KPI 仪表盘 | 展示核心 KPI 的实时数据仪表盘 | 有 | 有 | P0 |
|
|
||||||
| S5-F02 | 引用率趋势 | 展示品牌引用率的历史趋势图表 | 有 | 有 | P0 |
|
|
||||||
| S5-F03 | 平台表现对比 | 对比品牌在各 AI 平台的表现 | 有 | 有 | P0 |
|
|
||||||
| S5-F04 | 排名变化追踪 | 追踪品牌在引用中的排名变化 | 有 | 有 | P1 |
|
|
||||||
| S5-F05 | 竞品动态对比 | 实时对比品牌和竞品的最新表现 | 有 | 有 | P1 |
|
|
||||||
| S5-F06 | 异常检测 | 自动检测数据中的异常波动 | 有 | 有 | P1 |
|
|
||||||
| S5-F07 | 告警通知 | 数据异常或目标达成时发送通知 | 有 | 有 | P1 |
|
|
||||||
| S5-F08 | 数据下钻 | 从汇总数据下钻到详细记录 | 有 | 有 | P1 |
|
|
||||||
|
|
||||||
### 5.2 报告生成
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S5-F09 | 日报生成 | 自动生成每日效果摘要报告 | 有 | 有 | P1 |
|
|
||||||
| S5-F10 | 周报生成 | 自动生成周度综合分析报告 | 有 | 有 | P0 |
|
|
||||||
| S5-F11 | 月报生成 | 自动生成月度深度分析报告 | 有 | 有 | P0 |
|
|
||||||
| S5-F12 | 季报生成 | 自动生成季度战略评估报告 | 有 | 有 | P1 |
|
|
||||||
| S5-F13 | 自定义报告 | 自定义报告的时间范围和数据维度 | 有 | 有 | P1 |
|
|
||||||
| S5-F14 | 报告自动发送 | 定时自动发送报告至指定邮箱 | 有 | 有 | P2 |
|
|
||||||
| S5-F15 | 白标报告 | 生成无 GEO 标识的品牌化报告 | 无 | 有 | P1 |
|
|
||||||
|
|
||||||
### 5.3 策略迭代
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| S5-F16 | 效果归因分析 | 分析优化动作与效果变化的关联 | 有 | 有 | P2 |
|
|
||||||
| S5-F17 | 策略调整建议 | 基于效果数据自动生成策略调整建议 | 有 | 有 | P1 |
|
|
||||||
| S5-F18 | A/B 测试管理 | 管理不同策略的 A/B 测试 | 有 | 有 | P2 |
|
|
||||||
| S5-F19 | 优化方案记录 | 记录每次优化的方案和结果 | 有 | 有 | P1 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 通用模块
|
|
||||||
|
|
||||||
### 6.1 用户与认证
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| G-F01 | 用户注册 | 邮箱/手机号注册账号 | 有 | 有 | P0 |
|
|
||||||
| G-F02 | 用户登录 | 邮箱密码/验证码登录 | 有 | 有 | P0 |
|
|
||||||
| G-F03 | 密码找回 | 通过邮箱/手机找回密码 | 有 | 有 | P0 |
|
|
||||||
| G-F04 | 个人中心 | 管理个人信息、头像、联系方式 | 有 | 有 | P0 |
|
|
||||||
| G-F05 | 修改密码 | 修改登录密码 | 有 | 有 | P0 |
|
|
||||||
| G-F06 | OAuth 登录 | 支持 Google/微信等第三方登录 | 有 | 有 | P2 |
|
|
||||||
| G-F07 | 单点登录(SSO) | 企业级 SSO 集成 | 无 | 有 | P3 |
|
|
||||||
|
|
||||||
### 6.2 团队管理(Enterprise / 代理)
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| G-F08 | 团队成员管理 | 邀请、添加、移除团队成员 | Enterprise | 有 | P1 |
|
|
||||||
| G-F09 | 角色权限分配 | 为团队成员分配角色和权限 | Enterprise | 有 | P1 |
|
|
||||||
| G-F10 | 团队资源管理 | 管理团队共享的查询、内容、报告 | Enterprise | 有 | P2 |
|
|
||||||
| G-F11 | 操作日志 | 查看团队成员的操作记录 | Enterprise | 有 | P2 |
|
|
||||||
|
|
||||||
### 6.3 代理运营(代理模式专用)
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| G-F12 | 客户管理 | 管理代理运营的客户列表 | 无 | 有 | P0 |
|
|
||||||
| G-F13 | 客户项目分配 | 为客户分配专属的查询和优化项目 | 无 | 有 | P0 |
|
|
||||||
| G-F14 | 客户数据隔离 | 确保客户间数据完全隔离 | 无 | 有 | P0 |
|
|
||||||
| G-F15 | 白标配置 | 配置客户品牌标识(Logo、颜色等) | 无 | 有 | P1 |
|
|
||||||
| G-F16 | 客户报告中心 | 为客户生成和发送专属报告 | 无 | 有 | P1 |
|
|
||||||
| G-F17 | 客户账单管理 | 管理客户的账单和结算 | 无 | 有 | P2 |
|
|
||||||
|
|
||||||
### 6.4 订阅与计费
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| G-F18 | 订阅计划查看 | 查看当前订阅计划的功能和限额 | 有 | 有 | P0 |
|
|
||||||
| G-F19 | 订阅升级/降级 | 升级或降级订阅等级 | 有 | 有 | P0 |
|
|
||||||
| G-F20 | 配额使用查询 | 查看当前周期的配额使用情况 | 有 | 有 | P0 |
|
|
||||||
| G-F21 | 支付管理 | 管理支付方式和账单历史 | 有 | 有 | P0 |
|
|
||||||
| G-F22 | 发票管理 | 申请和下载发票 | 有 | 有 | P1 |
|
|
||||||
| G-F23 | 代理计费 | 代理运营模式的计费结算 | 无 | 有 | P2 |
|
|
||||||
|
|
||||||
### 6.5 系统管理(Admin)
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| G-F24 | 用户管理 | 管理全平台用户(查看、禁用、删除) | 无 | 无 | P0 |
|
|
||||||
| G-F25 | 订阅计划管理 | 配置和管理订阅计划 | 无 | 无 | P1 |
|
|
||||||
| G-F26 | 系统配置 | 配置系统参数和全局设置 | 无 | 无 | P1 |
|
|
||||||
| G-F27 | 平台适配器管理 | 管理 AI 平台适配器的接入和配置 | 无 | 无 | P1 |
|
|
||||||
| G-F28 | 系统监控 | 查看系统运行状态和性能指标 | 无 | 无 | P1 |
|
|
||||||
| G-F29 | 日志查看 | 查看系统日志和操作日志 | 无 | 无 | P2 |
|
|
||||||
| G-F30 | 公告管理 | 发布全平台公告和通知 | 无 | 无 | P3 |
|
|
||||||
|
|
||||||
### 6.6 通知与消息
|
|
||||||
|
|
||||||
| ID | 功能名称 | 功能描述 | 自主订阅 | 代理运营 | 优先级 |
|
|
||||||
|----|----------|----------|----------|----------|--------|
|
|
||||||
| G-F31 | 站内消息 | 接收系统通知和任务完成提醒 | 有 | 有 | P0 |
|
|
||||||
| G-F32 | 邮件通知 | 通过邮件接收重要通知 | 有 | 有 | P1 |
|
|
||||||
| G-F33 | 通知设置 | 配置通知的类型和接收方式 | 有 | 有 | P1 |
|
|
||||||
| G-F34 | 消息中心 | 查看历史通知和消息 | 有 | 有 | P0 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 功能优先级汇总
|
|
||||||
|
|
||||||
### P0(MVP 核心功能)- 必须在第一阶段实现
|
|
||||||
|
|
||||||
**Stage 1**:查询模板管理、批量查询、多平台执行、状态跟踪、结果查看、历史管理、查询重试、引用识别、内容提取、位置标记、置信度评估、类型分类、记录管理、引用详情查看、竞品列表、竞品查询、竞品对比、诊断报告生成、多格式导出、在线查看
|
|
||||||
|
|
||||||
**Stage 2**:策略推荐、策略编辑、规则库创建/编辑、主题规划、日历视图、KPI 设定
|
|
||||||
|
|
||||||
**Stage 3**:内容自动生成、参数配置、内容编辑、内容预览、保存草稿、类型选择、规则检查、质量评分、问题高亮、人工审核、素材上传/管理、发布控制
|
|
||||||
|
|
||||||
**Stage 4**:渠道配置、状态监控、授权管理、单内容发布、发布预览、状态追踪
|
|
||||||
|
|
||||||
**Stage 5**:KPI 仪表盘、引用率趋势、平台对比、周报/月报生成
|
|
||||||
|
|
||||||
**通用**:用户注册/登录/找回密码、个人中心、修改密码、客户管理、客户项目分配、客户数据隔离、订阅查看、升级降级、配额查询、支付管理、平台用户管理、站内消息、消息中心
|
|
||||||
|
|
||||||
### P1(重要功能)- 第二阶段实现
|
|
||||||
|
|
||||||
**Stage 1**:查询调度、结果导出、队列管理、统计分析、趋势分析、对比分析、报告模板、分享功能、白标报告、竞品排名、SWOT 分析
|
|
||||||
|
|
||||||
**Stage 2**:策略版本、策略库、规则版本、启用禁用、排期调整、冲突检测、目标追踪、偏差预警
|
|
||||||
|
|
||||||
**Stage 3**:多版本生成、批量生成、一键修复、审核工作流、版本历史/对比/回滚
|
|
||||||
|
|
||||||
**Stage 4**:渠道分组、批量发布、定时发布、发布撤回、UTM 生成、渠道对比
|
|
||||||
|
|
||||||
**Stage 5**:排名追踪、竞品动态、异常检测、告警通知、数据下钻、日报、季报、自定义报告、策略调整建议、优化方案记录
|
|
||||||
|
|
||||||
**通用**:团队成员管理、角色权限、白标配置、客户报告、发票管理、订阅计划管理、系统配置、适配器管理、邮件通知、通知设置
|
|
||||||
|
|
||||||
### P2/P3(增强/优化功能)- 后续迭代规划
|
|
||||||
|
|
||||||
包括:竞品动态监测、策略推测、效果预测、计划对比、智能配图、版权检查、渠道扩展、回源分析、报告自动发送、效果归因、A/B 测试、团队资源、操作日志、客户账单、代理计费、系统监控、日志查看、公告管理、SSO 等。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
*本文档为 GEO 平台的完整功能清单,具体实现时可根据迭代计划分阶段落地。*
|
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
# 项目概览
|
||||||
|
|
||||||
|
本目录包含GEO平台的项目概述、架构设计、技术栈说明和变更日志。
|
||||||
|
|
||||||
|
## 目录内容
|
||||||
|
|
||||||
|
- [README](./README.md) - 项目简介
|
||||||
|
- [系统架构](./architecture.md) - 系统架构设计
|
||||||
|
- [技术栈](./tech-stack.md) - 技术栈说明
|
||||||
|
- [变更日志](./changelog.md) - 版本变更记录
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
*GEO平台 - 让品牌在AI时代被看见。*
|
||||||
|
|
@ -0,0 +1,121 @@
|
||||||
|
# 系统架构
|
||||||
|
|
||||||
|
## 整体架构
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────────┐
|
||||||
|
│ 前端 (Next.js) │
|
||||||
|
├─────────────────────────────────────────────────────────────┤
|
||||||
|
│ Dashboard │ 诊断中心 │ 内容管理 │ 知识库 │ 监控面板 │
|
||||||
|
└─────────────────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
┌─────────────────────────────────────────────────────────────┐
|
||||||
|
│ API网关 (FastAPI) │
|
||||||
|
├─────────────────────────────────────────────────────────────┤
|
||||||
|
│ Auth │ Brands │ Diagnosis │ Content │ Knowledge │ Monitoring│
|
||||||
|
└─────────────────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
┌────────────────────┼────────────────────┐
|
||||||
|
▼ ▼ ▼
|
||||||
|
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||||
|
│ PostgreSQL │ │ Redis │ │ LLM Providers │
|
||||||
|
│ (数据存储) │ │ (缓存/队列) │ │ (AI服务) │
|
||||||
|
└─────────────────┘ └─────────────────┘ └─────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
## 诊断架构
|
||||||
|
|
||||||
|
### SEO诊断
|
||||||
|
|
||||||
|
```
|
||||||
|
网站URL输入 → 爬虫抓取 → 技术分析 → 内容分析 → 外链分析 → 生成SEO诊断报告
|
||||||
|
```
|
||||||
|
|
||||||
|
**诊断维度:**
|
||||||
|
- 技术SEO(索引、爬取、Core Web Vitals)
|
||||||
|
- 页面SEO(Title/Meta、H标签、内链)
|
||||||
|
- 内容质量(E-E-A-T、新鲜度、重复内容)
|
||||||
|
- 外链分析(质量、毒性、锚文本)
|
||||||
|
- 用户体验(移动适配、页面速度)
|
||||||
|
|
||||||
|
### GEO诊断
|
||||||
|
|
||||||
|
```
|
||||||
|
品牌信息输入 → 内容可提取性检测 → 实体清晰度检测 → E-E-A-T信号检测
|
||||||
|
→ Schema标记检测 → 主题权威检测 → AI平台引用检测 → 生成GEO诊断报告
|
||||||
|
```
|
||||||
|
|
||||||
|
**诊断维度:**
|
||||||
|
- 内容可提取性(直接回答块、问答式标题、列表表格)
|
||||||
|
- 实体清晰度(品牌定义、目标受众、差异化价值)
|
||||||
|
- E-E-A-T信号(作者资质、专业认证、数据来源)
|
||||||
|
- Schema标记(Organization、Product、FAQPage等)
|
||||||
|
- 主题权威(内容深度、话题覆盖度、实体信号一致性)
|
||||||
|
- 引用就绪度(引用频率、引用质量、AI声量占比)
|
||||||
|
|
||||||
|
## Agent Framework
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────────┐
|
||||||
|
│ Agent Framework │
|
||||||
|
├─────────────────────────────────────────────────────────────┤
|
||||||
|
│ │
|
||||||
|
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
|
||||||
|
│ │ Dispatcher │→│ Registry │→│ Monitor │ │
|
||||||
|
│ └─────────────┘ └─────────────┘ └─────────────┘ │
|
||||||
|
│ │
|
||||||
|
│ ┌─────────────────────────────────────────────────┐ │
|
||||||
|
│ │ Agents │ │
|
||||||
|
│ ├───────────┬───────────┬───────────┬────────────┤ │
|
||||||
|
│ │ Citation │ Content │ DeAI │ GEO │ │
|
||||||
|
│ │ Detector │ Generator │ Agent │ Optimizer │ │
|
||||||
|
│ └───────────┴───────────┴───────────┴────────────┘ │
|
||||||
|
│ │
|
||||||
|
│ ┌─────────────────────────────────────────────────┐ │
|
||||||
|
│ │ Services │ │
|
||||||
|
│ ├───────────┬───────────┬───────────┬────────────┤ │
|
||||||
|
│ │ RuleValid │ SEOOptim │ Sensitive │ HTMLGen │ │
|
||||||
|
│ └───────────┴───────────┴───────────┴────────────┘ │
|
||||||
|
│ │
|
||||||
|
└─────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
**Agent职责说明:**
|
||||||
|
|
||||||
|
| Agent | 职责 | 输入 | 输出 |
|
||||||
|
|-------|------|------|------|
|
||||||
|
| CitationDetector | 引用检测 | AI平台响应、品牌名称 | 引用检测结果 |
|
||||||
|
| ContentGenerator | 内容生成 | 主题、规则库、品牌素材 | GEO优化内容 |
|
||||||
|
| DeAIAgent | 去AI化 | AI生成内容 | 自然化内容 |
|
||||||
|
| GEOOptimizer | GEO优化 | 原始内容、关键词策略 | 优化后内容 |
|
||||||
|
|
||||||
|
**注意:** 当前项目中的`SEOOptimizer`实际执行的是GEO优化(内容结构化、实体优化),而非传统SEO优化(技术SEO)。建议在后续版本中明确区分:
|
||||||
|
- **SEOOptimizer** → 技术SEO优化(网站技术层面)
|
||||||
|
- **GEOOptimizer** → 内容实体优化(AI引用层面)
|
||||||
|
|
||||||
|
## 内容生成Pipeline
|
||||||
|
|
||||||
|
```
|
||||||
|
用户输入 → 母题选择 → 内容生成 → 去AI化 → GEO优化 → HTML生成 → 输出
|
||||||
|
```
|
||||||
|
|
||||||
|
**注意:** 原Pipeline中的"SEO优化"实际是GEO优化(内容结构化),不是传统SEO优化。
|
||||||
|
|
||||||
|
## 知识库系统
|
||||||
|
|
||||||
|
```
|
||||||
|
文档上传 → 文本分块 → 向量化 → RAG检索 → LLM增强生成
|
||||||
|
```
|
||||||
|
|
||||||
|
## 监控体系
|
||||||
|
|
||||||
|
```
|
||||||
|
API请求 → Prometheus指标 → Grafana可视化
|
||||||
|
↓
|
||||||
|
健康检查 → 告警通知
|
||||||
|
```
|
||||||
|
|
||||||
|
## 数据库设计
|
||||||
|
|
||||||
|
核心表:users, organizations, brands, competitors, queries, citations, alerts, contents, knowledge_bases, knowledge_entities, knowledge_relations
|
||||||
|
|
@ -0,0 +1,71 @@
|
||||||
|
# 变更日志
|
||||||
|
|
||||||
|
## v2.0.0 (当前版本)
|
||||||
|
|
||||||
|
### 新增功能
|
||||||
|
- Agent Framework架构
|
||||||
|
- 内容生成Pipeline
|
||||||
|
- 知识库RAG系统
|
||||||
|
- 知识图谱
|
||||||
|
- 平台规则中心
|
||||||
|
- 监控模块
|
||||||
|
- GEO母题库
|
||||||
|
- 图片生成服务
|
||||||
|
|
||||||
|
### 核心模块
|
||||||
|
- [x] 诊断分析 (100%)
|
||||||
|
- [x] 竞品分析 (100%)
|
||||||
|
- [x] 内容生产 (100%)
|
||||||
|
- [x] 知识库 (100%)
|
||||||
|
- [x] 监控 (100%)
|
||||||
|
|
||||||
|
### Agent Framework
|
||||||
|
- [x] TaskDispatcher - 任务分发器
|
||||||
|
- [x] AgentRegistry - 注册中心
|
||||||
|
- [x] CitationDetector - 引用检测
|
||||||
|
- [x] ContentGenerator - 内容生成
|
||||||
|
- [x] DeAIAgent - 去AI化
|
||||||
|
- [x] GEOOptimizer - GEO优化
|
||||||
|
- [x] PipelineEngine - Pipeline编排引擎
|
||||||
|
|
||||||
|
### 内容Pipeline
|
||||||
|
- [x] RuleValidator - 规则校验
|
||||||
|
- [x] SensitiveFilter - 敏感词过滤
|
||||||
|
- [x] SEOOptimizer - SEO优化(当前实际执行的是GEO内容结构化优化,而非传统技术SEO)
|
||||||
|
- [x] HTMLGenerator - HTML生成
|
||||||
|
|
||||||
|
### 知识库
|
||||||
|
- [x] RAG服务
|
||||||
|
- [x] 文档解析器 (PDF/DOCX/Markdown/TXT)
|
||||||
|
- [x] 分块策略 (Recursive/Semantic/Fixed)
|
||||||
|
- [x] 嵌入服务
|
||||||
|
- [x] 混合检索器
|
||||||
|
- [x] 增强检索 (重排序/上下文压缩)
|
||||||
|
|
||||||
|
### 知识图谱
|
||||||
|
- [x] 实体抽取
|
||||||
|
- [x] 关系抽取
|
||||||
|
- [x] 图谱构建
|
||||||
|
- [x] 图谱查询
|
||||||
|
|
||||||
|
### 平台规则
|
||||||
|
- [x] 10个平台规则 (知乎/微信/百家号/头条/微博/小红书/B站/简书/掘金/抖音)
|
||||||
|
- [x] AI敏感度配置
|
||||||
|
- [x] 敏感词过滤
|
||||||
|
- [x] HTML规则
|
||||||
|
|
||||||
|
### 测试覆盖
|
||||||
|
- 后端单元测试: ~480个
|
||||||
|
- 前端测试: 75个
|
||||||
|
- E2E测试: 65个
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## v1.0.0 (初始版本)
|
||||||
|
|
||||||
|
### 初始版本
|
||||||
|
- 项目初始化
|
||||||
|
- 基础架构搭建
|
||||||
|
- 基础查询功能
|
||||||
|
- 引用检测
|
||||||
|
- 基础Dashboard
|
||||||
|
|
@ -0,0 +1,64 @@
|
||||||
|
# 技术栈
|
||||||
|
|
||||||
|
## 前端技术栈
|
||||||
|
|
||||||
|
| 组件 | 技术 | 版本 |
|
||||||
|
|------|------|------|
|
||||||
|
| 框架 | Next.js | 14+ |
|
||||||
|
| UI库 | React | 18+ |
|
||||||
|
| 语言 | TypeScript | 5.x |
|
||||||
|
| 样式 | Tailwind CSS | 4.x |
|
||||||
|
| 组件库 | shadcn/ui | - |
|
||||||
|
| 图表 | Recharts | - |
|
||||||
|
| 状态管理 | SWR | - |
|
||||||
|
| 认证 | NextAuth.js | v5 |
|
||||||
|
|
||||||
|
## 后端技术栈
|
||||||
|
|
||||||
|
| 组件 | 技术 | 版本 |
|
||||||
|
|------|------|------|
|
||||||
|
| 框架 | FastAPI | 0.109+ |
|
||||||
|
| 语言 | Python | 3.12+ |
|
||||||
|
| ORM | SQLAlchemy | 2.0+ |
|
||||||
|
| 数据库 | PostgreSQL | 15+ |
|
||||||
|
| 缓存 | Redis | 7+ |
|
||||||
|
| 任务队列 | Celery | 5+ |
|
||||||
|
| 认证 | JWT + OAuth2 | - |
|
||||||
|
|
||||||
|
## AI Agent框架
|
||||||
|
|
||||||
|
| 组件 | 技术 |
|
||||||
|
|------|------|
|
||||||
|
| Agent基础 | 自研模块化框架 |
|
||||||
|
| 消息队列 | Redis Pub/Sub |
|
||||||
|
| 注册中心 | Redis Hash |
|
||||||
|
| 任务分发 | Dispatcher + Registry |
|
||||||
|
|
||||||
|
## 基础设施
|
||||||
|
|
||||||
|
| 组件 | 技术 |
|
||||||
|
|------|------|
|
||||||
|
| 容器化 | Docker + Docker Compose |
|
||||||
|
| 反向代理 | Nginx |
|
||||||
|
| 监控 | Prometheus + Grafana |
|
||||||
|
|
||||||
|
## 项目目录结构
|
||||||
|
|
||||||
|
```
|
||||||
|
geo/
|
||||||
|
├── backend/ # FastAPI 后端
|
||||||
|
│ ├── app/
|
||||||
|
│ │ ├── api/ # API路由
|
||||||
|
│ │ ├── agent_framework/ # Agent框架
|
||||||
|
│ │ ├── models/ # 数据模型
|
||||||
|
│ │ ├── schemas/ # Pydantic模型
|
||||||
|
│ │ ├── services/ # 业务逻辑
|
||||||
|
│ │ ├── workers/ # 异步任务
|
||||||
|
│ │ └── monitoring/ # 监控模块
|
||||||
|
│ └── requirements.txt
|
||||||
|
├── frontend/ # Next.js 前端
|
||||||
|
│ ├── app/ # 页面
|
||||||
|
│ ├── components/ # 组件
|
||||||
|
│ └── lib/ # 工具函数
|
||||||
|
└── docs/ # 文档
|
||||||
|
```
|
||||||
|
|
@ -1,502 +0,0 @@
|
||||||
# GEO 平台 - AI Agent 框架设计
|
|
||||||
|
|
||||||
## 概述
|
|
||||||
|
|
||||||
GEO 平台的 AI Agent 框架是系统的核心智能层,负责执行各种需要 AI 能力的业务任务。框架采用模块化、可插拔的设计,支持 Agent 的动态注册、任务分发和状态管理。
|
|
||||||
|
|
||||||
本文档定义 Agent 框架的整体架构、核心组件和通信机制。
|
|
||||||
|
|
||||||
## Agent 列表
|
|
||||||
|
|
||||||
### CitationDetector(引用检测 Agent)
|
|
||||||
|
|
||||||
| 属性 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| **职责** | 解析 AI 平台响应,识别其中是否包含品牌引用 |
|
|
||||||
| **输入** | AI 平台原始响应文本、品牌名称列表、竞品名称列表 |
|
|
||||||
| **输出** | 引用检测结果(引用/未引用/竞品引用)、引用片段、置信度评分 |
|
|
||||||
| **核心能力** | 自然语言理解、实体识别、引用关系判定 |
|
|
||||||
| **触发方式** | 查询任务完成后自动触发 |
|
|
||||||
| **优先级** | P0 |
|
|
||||||
|
|
||||||
**处理流程**:
|
|
||||||
```
|
|
||||||
接收响应文本 ──▶ 文本预处理 ──▶ 品牌实体识别 ──▶ 引用关系判定
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
竞品引用检测 ──▶ 置信度评估
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
结果格式化 ──▶ 上报结果
|
|
||||||
```
|
|
||||||
|
|
||||||
**引用类型判定**:
|
|
||||||
| 类型 | 判定标准 | 置信度权重 |
|
|
||||||
|------|----------|------------|
|
|
||||||
| `direct_quote` | 品牌名被直接提及并作为信息来源 | 高 |
|
|
||||||
| `indirect_reference` | 品牌信息被提及但未明确标注来源 | 中 |
|
|
||||||
| `no_reference` | 品牌完全未被提及 | 确定 |
|
|
||||||
| `competitor_reference` | 竞品被引用而品牌未被引用 | 高 |
|
|
||||||
|
|
||||||
### ContentGenerator(内容生成 Agent)
|
|
||||||
|
|
||||||
| 属性 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| **职责** | 根据策略和规则生成 GEO 优化的内容资产 |
|
|
||||||
| **输入** | 内容主题、规则库配置、品牌素材、参考数据 |
|
|
||||||
| **输出** | GEO 优化内容(文章/问答/指南等)、内容元数据 |
|
|
||||||
| **核心能力** | 创意写作、SEO/GEO 优化、多风格适配 |
|
|
||||||
| **触发方式** | 用户手动触发或按排期自动触发 |
|
|
||||||
| **优先级** | P0 |
|
|
||||||
|
|
||||||
**处理流程**:
|
|
||||||
```
|
|
||||||
接收生成任务 ──▶ 需求分析 ──▶ 素材收集 ──▶ 大纲生成
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
内容撰写 ──▶ GEO 优化处理
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
自检校验 ──▶ 输出结果
|
|
||||||
```
|
|
||||||
|
|
||||||
**内容类型支持**:
|
|
||||||
| 类型 | 说明 | 适用场景 |
|
|
||||||
|------|------|----------|
|
|
||||||
| `article` | 长篇文章(800-2000 字) | 深度行业内容 |
|
|
||||||
| `qa` | 问答对 | 常见品牌问题 |
|
|
||||||
| `guide` | 操作指南 | 产品使用场景 |
|
|
||||||
| `comparison` | 对比分析 | 竞品对抗 |
|
|
||||||
| `data_report` | 数据报告 | 行业研究 |
|
|
||||||
|
|
||||||
### RuleChecker(规则检查 Agent)
|
|
||||||
|
|
||||||
| 属性 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| **职责** | 检查内容是否符合规则库中的约束条件 |
|
|
||||||
| **输入** | 待检查内容、规则库配置 |
|
|
||||||
| **输出** | 检查结果(通过/不通过)、问题列表、改进建议 |
|
|
||||||
| **核心能力** | 规则引擎、内容质量评估、语义分析 |
|
|
||||||
| **触发方式** | ContentGenerator 生成后自动触发,或手动触发 |
|
|
||||||
| **优先级** | P0 |
|
|
||||||
|
|
||||||
**检查维度**:
|
|
||||||
| 维度 | 说明 | 示例规则 |
|
|
||||||
|------|------|----------|
|
|
||||||
| `keyword_coverage` | 关键词覆盖度 | 必须包含核心关键词 |
|
|
||||||
| `readability` | 可读性评分 | Flesch Reading Ease(弗莱施易读度)分数 |
|
|
||||||
| `tone_consistency` | 语气一致性 | 保持专业、客观语气 |
|
|
||||||
| `length_compliance` | 长度合规 | 符合目标平台偏好长度 |
|
|
||||||
| `fact_accuracy` | 事实准确性 | 品牌信息准确无误 |
|
|
||||||
| `geo_optimization` | GEO 优化度 | 符合 GEO 最佳实践 |
|
|
||||||
|
|
||||||
### CompetitorAnalyzer(竞品分析 Agent)
|
|
||||||
|
|
||||||
| 属性 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| **职责** | 分析竞品的 GEO 表现和策略 |
|
|
||||||
| **输入** | 竞品名称、查询结果、引用数据 |
|
|
||||||
| **输出** | 竞品分析报告、SWOT 分析、策略建议 |
|
|
||||||
| **核心能力** | 竞品数据对比、模式识别、策略推断 |
|
|
||||||
| **触发方式** | 诊断分析流程中自动触发 |
|
|
||||||
| **优先级** | P1 |
|
|
||||||
|
|
||||||
**分析维度**:
|
|
||||||
| 维度 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| `citation_share` | 竞品引用份额占比 |
|
|
||||||
| `platform_presence` | 各平台出现频率 |
|
|
||||||
| `content_pattern` | 竞品被引用的内容模式 |
|
|
||||||
| `strength_weakness` | 竞品的优劣势分析 |
|
|
||||||
| `strategy_inference` | 推测竞品的 GEO 策略 |
|
|
||||||
|
|
||||||
### PerformanceTracker(性能追踪 Agent)
|
|
||||||
|
|
||||||
| 属性 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| **职责** | 持续追踪品牌 GEO 表现,计算 KPI,检测异常 |
|
|
||||||
| **输入** | 历史引用数据、当前查询结果、KPI 配置 |
|
|
||||||
| **输出** | KPI 计算结果、趋势分析、异常告警、优化建议 |
|
|
||||||
| **核心能力** | 时间序列分析、趋势预测、异常检测 |
|
|
||||||
| **触发方式** | 按设定周期自动执行 |
|
|
||||||
| **优先级** | P1 |
|
|
||||||
|
|
||||||
**追踪指标**:
|
|
||||||
| 指标 | 计算方式 | 更新频率 |
|
|
||||||
|------|----------|----------|
|
|
||||||
| `citation_rate` | 被引用查询数 / 总查询数 | 每日 |
|
|
||||||
| `avg_confidence` | 引用记录置信度平均值 | 每次查询 |
|
|
||||||
| `platform_coverage` | 已覆盖平台数 / 目标平台数 | 每周 |
|
|
||||||
| `ranking_position` | 品牌在引用中的平均排名 | 每周 |
|
|
||||||
| `competitor_gap` | 品牌引用率 - 竞品平均引用率 | 每周 |
|
|
||||||
|
|
||||||
## 通信协议
|
|
||||||
|
|
||||||
### Agent 间通信模型
|
|
||||||
|
|
||||||
GEO 平台 Agent 之间采用 **基于 Redis 消息队列的异步通信协议**。
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
|
|
||||||
│ Agent A │◀────▶│ Redis │◀────▶│ Agent B │
|
|
||||||
│ (Producer) │ │ Queue │ │ (Consumer) │
|
|
||||||
└─────────────┘ └──────┬──────┘ └─────────────┘
|
|
||||||
│
|
|
||||||
┌──────▼──────┐
|
|
||||||
│ Registry │
|
|
||||||
│ (Hash Map) │
|
|
||||||
└─────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### 消息格式规范
|
|
||||||
|
|
||||||
```python
|
|
||||||
class AgentMessage:
|
|
||||||
"""Agent 间通信消息标准格式"""
|
|
||||||
|
|
||||||
message_id: str # UUID,消息唯一标识
|
|
||||||
timestamp: datetime # ISO 8601 格式时间戳
|
|
||||||
sender: str # 发送方 Agent 名称
|
|
||||||
recipient: str # 接收方 Agent 名称,或 "broadcast"
|
|
||||||
message_type: str # 消息类型
|
|
||||||
correlation_id: str # 关联 ID,用于追踪请求-响应链路
|
|
||||||
reply_to: str # 回复目标 Agent
|
|
||||||
payload: dict # 消息负载数据
|
|
||||||
ttl: int # 消息存活时间(秒),默认 300
|
|
||||||
```
|
|
||||||
|
|
||||||
**消息类型(message_type)**:
|
|
||||||
| 类型 | 说明 | 使用场景 |
|
|
||||||
|------|------|----------|
|
|
||||||
| `task_request` | 任务请求 | Agent A 请求 Agent B 执行任务 |
|
|
||||||
| `task_result` | 任务结果 | Agent B 返回任务执行结果 |
|
|
||||||
| `status_update` | 状态更新 | Agent 上报自身状态变化 |
|
|
||||||
| `heartbeat` | 心跳 | Agent 定期发送健康信号 |
|
|
||||||
| `config_update` | 配置更新 | 通知 Agent 配置变更 |
|
|
||||||
| `error_report` | 错误报告 | Agent 报告执行错误 |
|
|
||||||
|
|
||||||
### 消息负载(payload)规范
|
|
||||||
|
|
||||||
```python
|
|
||||||
class TaskPayload:
|
|
||||||
"""任务请求/结果负载"""
|
|
||||||
|
|
||||||
task_id: str # 任务唯一标识
|
|
||||||
task_type: str # 任务类型
|
|
||||||
parameters: dict # 任务参数
|
|
||||||
data: dict # 输入/输出数据
|
|
||||||
priority: int # 优先级 1-10,默认 5
|
|
||||||
deadline: datetime # 截止时间(可选)
|
|
||||||
retry_count: int # 已重试次数
|
|
||||||
max_retries: int # 最大重试次数
|
|
||||||
```
|
|
||||||
|
|
||||||
### 通信模式
|
|
||||||
|
|
||||||
#### 1. 点对点模式(Point-to-Point)
|
|
||||||
|
|
||||||
```
|
|
||||||
Agent A ──task_request──▶ Queue ──▶ Agent B
|
|
||||||
Agent A ◀──task_result─── Queue ◀── Agent B
|
|
||||||
```
|
|
||||||
|
|
||||||
- 每个 Agent 有专属的输入队列
|
|
||||||
- 任务按轮询方式分配
|
|
||||||
- 适用于明确的任务分配场景
|
|
||||||
|
|
||||||
#### 2. 发布/订阅模式(Pub/Sub)
|
|
||||||
|
|
||||||
```
|
|
||||||
Agent A ──config_update──▶ Topic
|
|
||||||
│
|
|
||||||
┌───────────────┼───────────────┐
|
|
||||||
▼ ▼ ▼
|
|
||||||
Agent B Agent C Agent D
|
|
||||||
```
|
|
||||||
|
|
||||||
- 使用 Redis Pub/Sub 频道
|
|
||||||
- 适用于配置变更、状态广播
|
|
||||||
|
|
||||||
#### 3. 工作流模式(Workflow)
|
|
||||||
|
|
||||||
```
|
|
||||||
┌────────────┐ ┌────────────┐ ┌────────────┐
|
|
||||||
│ QueryTask │───▶│ Citation │───▶│ Report │
|
|
||||||
│ Completed │ │ Detector │ │ Generator │
|
|
||||||
└────────────┘ └────────────┘ └────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
- 基于事件驱动的工作流编排
|
|
||||||
- 使用 correlation_id 串联整个工作流
|
|
||||||
- 支持并行和串行执行
|
|
||||||
|
|
||||||
## 注册机制
|
|
||||||
|
|
||||||
### Agent 注册流程
|
|
||||||
|
|
||||||
```
|
|
||||||
Agent 启动
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
读取配置(能力声明、资源需求、版本信息)
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
连接 Redis
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
向 Registry 注册
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
创建专属输入队列
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
启动心跳定时器
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
开始监听任务
|
|
||||||
```
|
|
||||||
|
|
||||||
### Registry 数据结构
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Redis Hash: agent:registry:{agent_name}
|
|
||||||
{
|
|
||||||
"name": "CitationDetector",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"capabilities": "[\"citation_detect\", \"entity_recognition\"]",
|
|
||||||
"status": "online", # online / busy / offline
|
|
||||||
"queue_name": "agent:citation_detector",
|
|
||||||
"registered_at": "2026-01-01T00:00:00Z",
|
|
||||||
"last_heartbeat": "2026-01-01T00:01:00Z",
|
|
||||||
"current_task": "task_123", # 当前执行任务 ID,空闲时为 null
|
|
||||||
"concurrency": 5, # 最大并发数
|
|
||||||
"load": 0.6 # 当前负载 0-1
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 服务发现
|
|
||||||
|
|
||||||
```python
|
|
||||||
# 获取所有在线 Agent(示意伪代码:HGETALL 不支持通配符,实际需先用 SCAN 匹配 agent:registry:* 再逐个 HGETALL)
|
|
||||||
agents = redis.hgetall("agent:registry:*")
|
|
||||||
online_agents = [a for a in agents if a.status == "online"]
|
|
||||||
|
|
||||||
# 按能力筛选 Agent
|
|
||||||
capable_agents = [a for a in online_agents if "citation_detect" in a.capabilities]
|
|
||||||
|
|
||||||
# 选择负载最低的 Agent
|
|
||||||
selected = min(capable_agents, key=lambda a: a.load)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 任务分发
|
|
||||||
|
|
||||||
### 分发策略
|
|
||||||
|
|
||||||
| 策略 | 说明 | 适用场景 |
|
|
||||||
|------|------|----------|
|
|
||||||
| `round_robin` | 轮询分配 | 同类型 Agent 负载均衡 |
|
|
||||||
| `least_load` | 最低负载优先 | 不同 Agent 性能差异时 |
|
|
||||||
| `capability_match` | 能力匹配 | 需要特定能力的任务 |
|
|
||||||
| `priority_queue` | 优先级队列 | 任务有明确优先级差异 |
|
|
||||||
| `sticky` | 粘性分配 | 需要上下文连续性的任务 |
|
|
||||||
|
|
||||||
### 任务分发流程
|
|
||||||
|
|
||||||
```
|
|
||||||
任务提交
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
解析任务类型 ──▶ 确定所需能力
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
查询 Registry ──▶ 筛选可用 Agent
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
应用分发策略 ──▶ 选择目标 Agent
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
序列化消息 ──▶ 写入目标队列
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
更新任务状态 ──▶ 等待结果
|
|
||||||
```
|
|
||||||
|
|
||||||
### 任务状态机
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────┐ 提交任务 ┌──────────┐
|
|
||||||
│ PENDING │───────────────▶│ QUEUED │
|
|
||||||
└──────────┘ └─────┬────┘
|
|
||||||
│ 入队
|
|
||||||
▼
|
|
||||||
┌──────────┐
|
|
||||||
开始执行│ RUNNING │
|
|
||||||
┌───────┤ │
|
|
||||||
│ └─────┬────┘
|
|
||||||
│ │
|
|
||||||
执行成功 执行失败
|
|
||||||
│ │
|
|
||||||
▼ ▼
|
|
||||||
┌──────────┐ ┌──────────┐
|
|
||||||
│COMPLETED │ │ FAILED │
|
|
||||||
└──────────┘ └─────┬────┘
|
|
||||||
│
|
|
||||||
重试次数 < 最大重试
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
┌──────────┐
|
|
||||||
│ RETRY │
|
|
||||||
└────┬─────┘
|
|
||||||
│ 重新入队
|
|
||||||
└────▶ QUEUED
|
|
||||||
```
|
|
||||||
|
|
||||||
## 状态上报
|
|
||||||
|
|
||||||
### 上报内容
|
|
||||||
|
|
||||||
每个 Agent 需要定期上报以下状态信息:
|
|
||||||
|
|
||||||
| 字段 | 说明 | 上报频率 |
|
|
||||||
|------|------|----------|
|
|
||||||
| `status` | 当前状态 | 每次变化 + 心跳 |
|
|
||||||
| `current_task` | 正在执行的任务 | 每次变化 |
|
|
||||||
| `load` | 当前负载 0-1 | 心跳时 |
|
|
||||||
| `queue_depth` | 待处理任务数 | 心跳时 |
|
|
||||||
| `processed_count` | 已处理任务总数 | 心跳时 |
|
|
||||||
| `error_count` | 错误次数 | 心跳时 |
|
|
||||||
| `avg_process_time` | 平均处理时间 | 心跳时 |
|
|
||||||
|
|
||||||
### 心跳机制
|
|
||||||
|
|
||||||
```python
|
|
||||||
# 心跳间隔:30 秒
|
|
||||||
HEARTBEAT_INTERVAL = 30
|
|
||||||
|
|
||||||
# 心跳超时判定:90 秒(3 个心跳周期)
|
|
||||||
HEARTBEAT_TIMEOUT = 90
|
|
||||||
|
|
||||||
# 心跳消息格式
|
|
||||||
heartbeat_message = {
|
|
||||||
"message_type": "heartbeat",
|
|
||||||
"sender": "CitationDetector",
|
|
||||||
"timestamp": "2026-01-01T00:01:00Z",
|
|
||||||
"payload": {
|
|
||||||
"status": "online",
|
|
||||||
"load": 0.4,
|
|
||||||
"queue_depth": 2,
|
|
||||||
"processed_count": 150,
|
|
||||||
"error_count": 3,
|
|
||||||
"avg_process_time": 2.5
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 状态监控
|
|
||||||
|
|
||||||
- Registry 定期检查 Agent 心跳超时
|
|
||||||
- 超时 Agent 自动标记为 `offline`
|
|
||||||
- 正在执行的任务自动重新分发
|
|
||||||
- 管理员可通过 Dashboard 查看 Agent 状态
|
|
||||||
|
|
||||||
## 配置管理
|
|
||||||
|
|
||||||
### 配置层级
|
|
||||||
|
|
||||||
| 层级 | 说明 | 优先级 |
|
|
||||||
|------|------|--------|
|
|
||||||
| `system_default` | 系统默认配置 | 最低 |
|
|
||||||
| `tenant_config` | 租户级配置(覆盖默认) | 中 |
|
|
||||||
| `agent_override` | Agent 本地配置(覆盖租户) | 最高 |
|
|
||||||
|
|
||||||
### 配置项
|
|
||||||
|
|
||||||
```python
|
|
||||||
class AgentConfig:
|
|
||||||
"""Agent 配置项"""
|
|
||||||
|
|
||||||
# 通用配置
|
|
||||||
max_concurrency: int = 5 # 最大并发数
|
|
||||||
task_timeout: int = 300 # 任务超时时间(秒)
|
|
||||||
retry_policy: str = "exponential" # 重试策略
|
|
||||||
max_retries: int = 3 # 最大重试次数
|
|
||||||
|
|
||||||
# CitationDetector 专用
|
|
||||||
confidence_threshold: float = 0.7 # 置信度阈值
|
|
||||||
enable_semantic_analysis: bool = True
|
|
||||||
|
|
||||||
# ContentGenerator 专用
|
|
||||||
default_content_length: int = 1500
|
|
||||||
content_styles: List[str] = ["professional", "casual", "technical"]
|
|
||||||
|
|
||||||
# RuleChecker 专用
|
|
||||||
strict_mode: bool = False
|
|
||||||
custom_rules: List[dict] = []
|
|
||||||
```
|
|
||||||
|
|
||||||
### 动态配置更新
|
|
||||||
|
|
||||||
```
|
|
||||||
管理员更新配置
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
写入 Redis 配置存储
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
发布 config_update 消息
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
各 Agent 接收消息
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
Agent 热重载配置(无需重启)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 管理功能需求
|
|
||||||
|
|
||||||
### Agent Dashboard
|
|
||||||
|
|
||||||
管理员需要通过 Dashboard 监控和管理所有 Agent:
|
|
||||||
|
|
||||||
| 功能 | 说明 | 优先级 |
|
|
||||||
|------|------|--------|
|
|
||||||
| Agent 列表 | 查看所有已注册 Agent 的状态 | P0 |
|
|
||||||
| 实时状态看板 | 展示 Agent 在线状态、负载、任务数 | P0 |
|
|
||||||
| 任务队列监控 | 查看各队列的待处理任务数 | P0 |
|
|
||||||
| 任务执行日志 | 查看任务执行的详细日志 | P1 |
|
|
||||||
| 错误告警 | Agent 异常时发送告警通知 | P1 |
|
|
||||||
| 配置管理 | 在线修改 Agent 配置 | P1 |
|
|
||||||
| Agent 启停 | 远程启动或停止 Agent | P2 |
|
|
||||||
| 版本管理 | 查看和升级 Agent 版本 | P2 |
|
|
||||||
| 性能统计 | 查看 Agent 的性能指标和趋势 | P2 |
|
|
||||||
|
|
||||||
### Agent 日志规范
|
|
||||||
|
|
||||||
```python
|
|
||||||
# 日志格式
|
|
||||||
{
|
|
||||||
"timestamp": "2026-01-01T00:00:00Z",
|
|
||||||
"level": "INFO", # DEBUG / INFO / WARNING / ERROR
|
|
||||||
"agent": "CitationDetector",
|
|
||||||
"task_id": "task_123",
|
|
||||||
"message": "开始执行引用检测",
|
|
||||||
"metadata": {
|
|
||||||
"query_id": "q_456",
|
|
||||||
"brand": "ExampleBrand"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 错误处理
|
|
||||||
|
|
||||||
| 错误类型 | 处理方式 | 告警级别 |
|
|
||||||
|----------|----------|----------|
|
|
||||||
| 任务执行失败 | 自动重试,达到最大重试后标记失败 | Warning |
|
|
||||||
| Agent 心跳超时 | 标记离线,任务重新分发 | Error |
|
|
||||||
| 队列积压 | 触发弹性扩容或告警 | Warning |
|
|
||||||
| Agent 崩溃 | 自动重启(配合 supervisor) | Critical |
|
|
||||||
| 配置错误 | 拒绝加载,保持上次有效配置 | Error |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
*本文档定义 GEO 平台的 AI Agent 框架设计,各 Agent 的具体实现算法和模型选型在模块指南中详细说明。*
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
# GEO 平台 - 组件库设计
|
|
||||||
|
|
||||||
## 概述
|
|
||||||
|
|
||||||
本文档定义 GEO 平台的 UI 组件库设计方案,基于 shadcn/ui 进行扩展和定制。
|
|
||||||
|
|
||||||
> **TODO**: 本文档为占位文件,待补充完整内容。
|
|
||||||
|
|
||||||
## 待补充内容
|
|
||||||
|
|
||||||
- [ ] 组件库架构设计
|
|
||||||
- [ ] 基础组件清单(Button、Input、Card 等)
|
|
||||||
- [ ] 业务组件清单(QueryCard、CitationBadge、ReportPreview 等)
|
|
||||||
- [ ] 组件 Props 接口定义
|
|
||||||
- [ ] 组件使用示例
|
|
||||||
- [ ] 组件主题定制方案
|
|
||||||
- [ ] 组件开发规范
|
|
||||||
- [ ] 组件文档生成方案(Storybook)
|
|
||||||
|
|
||||||
## 参考
|
|
||||||
|
|
||||||
- [UI 风格指南](./ui-style-guide.md)
|
|
||||||
- [前端系统架构](../.qoder/repowiki/zh/content/前端系统架构/UI组件库.md)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
*本文档待补充。*
|
|
||||||
|
|
@ -1,31 +0,0 @@
|
||||||
# GEO 平台 - 数据库 Schema 设计
|
|
||||||
|
|
||||||
## 概述
|
|
||||||
|
|
||||||
本文档定义 GEO 平台的完整数据库 Schema 设计,包括表结构、字段定义、索引设计和关系图。
|
|
||||||
|
|
||||||
> **TODO**: 本文档为占位文件,待补充完整内容。
|
|
||||||
|
|
||||||
## 待补充内容
|
|
||||||
|
|
||||||
- [ ] 实体关系图(ER Diagram)
|
|
||||||
- [ ] 用户相关表(users、user_profiles、teams 等)
|
|
||||||
- [ ] 查询相关表(queries、query_tasks、query_results 等)
|
|
||||||
- [ ] 引用检测相关表(citation_records、citation_analysis 等)
|
|
||||||
- [ ] 订阅相关表(subscriptions、plans、payments 等)
|
|
||||||
- [ ] 内容相关表(contents、content_versions、rules 等)
|
|
||||||
- [ ] 报告相关表(reports、report_templates 等)
|
|
||||||
- [ ] 渠道相关表(channels、channel_configs 等)
|
|
||||||
- [ ] 代理运营相关表(clients、projects 等)
|
|
||||||
- [ ] 索引设计说明
|
|
||||||
- [ ] 数据迁移策略
|
|
||||||
|
|
||||||
## 参考
|
|
||||||
|
|
||||||
- [数据模型设计](../.qoder/repowiki/zh/content/后端系统架构/数据模型设计.md)
|
|
||||||
- [数据库设计](../.qoder/repowiki/zh/content/数据库设计/数据库设计.md)
|
|
||||||
- [表结构设计](../.qoder/repowiki/zh/content/数据库设计/表结构设计.md)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
*本文档待补充。*
|
|
||||||
|
|
@ -1,443 +0,0 @@
|
||||||
# GEO 平台 - UI 风格指南
|
|
||||||
|
|
||||||
## 概述
|
|
||||||
|
|
||||||
本文档定义 GEO 平台的前端 UI 风格规范,确保全平台视觉一致性。所有前端开发人员、UI 设计师和产品经理应遵循本指南进行界面设计和开发。
|
|
||||||
|
|
||||||
GEO 平台的 UI 设计遵循以下核心原则:
|
|
||||||
- **专业可信**:体现企业级 SaaS 平台的专业感和可信度
|
|
||||||
- **简洁高效**:减少视觉噪音,让用户聚焦于数据和操作
|
|
||||||
- **层次分明**:通过清晰的信息层级引导用户注意力
|
|
||||||
- **响应式适配**:适配桌面端、平板和移动端设备
|
|
||||||
|
|
||||||
## 风格特征总结
|
|
||||||
|
|
||||||
### 设计语言
|
|
||||||
|
|
||||||
| 特征 | 描述 | 应用场景 |
|
|
||||||
|------|------|----------|
|
|
||||||
| 扁平化设计 | 去除冗余装饰,强调内容和功能 | 全局 |
|
|
||||||
| 卡片化布局 | 使用卡片容器组织信息模块 | 仪表盘、列表页 |
|
|
||||||
| 留白呼吸感 | 充足的内边距和外边距 | 全局 |
|
|
||||||
| 微妙的层次感 | 通过阴影和边框区分层级 | 卡片、弹窗、下拉 |
|
|
||||||
| 数据可视化 | 丰富的图表和数据展示 | 仪表盘、报告页 |
|
|
||||||
| 状态色彩编码 | 使用颜色直观表达状态 | 状态标签、进度条 |
|
|
||||||
|
|
||||||
### 整体氛围
|
|
||||||
|
|
||||||
- **主色调**:深蓝/靛蓝色系,传达专业、科技、可信
|
|
||||||
- **辅助色**:翠绿色表示成功/正向,琥珀色表示警告,红色表示错误
|
|
||||||
- **中性色**:丰富的灰度层级用于文字和边框
|
|
||||||
- **背景**:浅色背景为主,深色模式作为可选项
|
|
||||||
|
|
||||||
## Design Tokens
|
|
||||||
|
|
||||||
### 颜色体系
|
|
||||||
|
|
||||||
#### 主色调(Primary)
|
|
||||||
|
|
||||||
| Token | 色值 | 用途 |
|
|
||||||
|-------|------|------|
|
|
||||||
| `--color-primary-50` | `#EEF2FF` | 极浅背景、hover 状态 |
|
|
||||||
| `--color-primary-100` | `#E0E7FF` | 轻色背景、选中状态 |
|
|
||||||
| `--color-primary-200` | `#C7D2FE` | 边框高亮 |
|
|
||||||
| `--color-primary-300` | `#A5B4FC` | 禁用状态主色 |
|
|
||||||
| `--color-primary-400` | `#818CF8` | 次要强调 |
|
|
||||||
| `--color-primary-500` | `#6366F1` | **主色** - 按钮、链接、图标 |
|
|
||||||
| `--color-primary-600` | `#4F46E5` | 按钮 hover、深色调 |
|
|
||||||
| `--color-primary-700` | `#4338CA` | 按钮 active |
|
|
||||||
| `--color-primary-800` | `#3730A3` | 深色背景文字 |
|
|
||||||
| `--color-primary-900` | `#312E81` | 标题、深色元素 |
|
|
||||||
|
|
||||||
#### 语义色(Semantic)
|
|
||||||
|
|
||||||
| Token | 色值 | 用途 |
|
|
||||||
|-------|------|------|
|
|
||||||
| `--color-success-50` | `#F0FDF4` | 成功状态轻背景 |
|
|
||||||
| `--color-success-500` | `#22C55E` | **成功** - 成功提示、正向指标 |
|
|
||||||
| `--color-success-600` | `#16A34A` | 成功 hover |
|
|
||||||
| `--color-warning-50` | `#FFFBEB` | 警告状态轻背景 |
|
|
||||||
| `--color-warning-500` | `#F59E0B` | **警告** - 警告提示、待处理 |
|
|
||||||
| `--color-warning-600` | `#D97706` | 警告 hover |
|
|
||||||
| `--color-error-50` | `#FEF2F2` | 错误状态轻背景 |
|
|
||||||
| `--color-error-500` | `#EF4444` | **错误** - 错误提示、失败状态 |
|
|
||||||
| `--color-error-600` | `#DC2626` | 错误 hover |
|
|
||||||
| `--color-info-50` | `#EFF6FF` | 信息状态轻背景 |
|
|
||||||
| `--color-info-500` | `#3B82F6` | **信息** - 信息提示 |
|
|
||||||
| `--color-info-600` | `#2563EB` | 信息 hover、信息标签文字 |
|
|
||||||
|
|
||||||
#### 中性色(Neutral)
|
|
||||||
|
|
||||||
| Token | 色值 | 用途 |
|
|
||||||
|-------|------|------|
|
|
||||||
| `--color-gray-50` | `#F9FAFB` | 页面背景、hover 背景 |
|
|
||||||
| `--color-gray-100` | `#F3F4F6` | 卡片背景、分隔背景 |
|
|
||||||
| `--color-gray-200` | `#E5E7EB` | 边框、分隔线 |
|
|
||||||
| `--color-gray-300` | `#D1D5DB` | 禁用边框、占位符 |
|
|
||||||
| `--color-gray-400` | `#9CA3AF` | 次要文字、图标 |
|
|
||||||
| `--color-gray-500` | `#6B7280` | 辅助文字 |
|
|
||||||
| `--color-gray-600` | `#4B5563` | 正文文字 |
|
|
||||||
| `--color-gray-700` | `#374151` | 标题文字 |
|
|
||||||
| `--color-gray-800` | `#1F2937` | 深色标题 |
|
|
||||||
| `--color-gray-900` | `#111827` | 主标题、深色文字 |
|
|
||||||
|
|
||||||
### 圆角(Border Radius)
|
|
||||||
|
|
||||||
| Token | 值 | 用途 |
|
|
||||||
|-------|-----|------|
|
|
||||||
| `--radius-none` | `0` | 直角(表格、特殊场景) |
|
|
||||||
| `--radius-sm` | `4px` | 小元素(标签、徽章) |
|
|
||||||
| `--radius-md` | `6px` | 默认(按钮、输入框) |
|
|
||||||
| `--radius-lg` | `8px` | 卡片、容器 |
|
|
||||||
| `--radius-xl` | `12px` | 大卡片、模态框 |
|
|
||||||
| `--radius-2xl` | `16px` | 页面级容器 |
|
|
||||||
| `--radius-full` | `9999px` | 圆形(头像、状态点) |
|
|
||||||
|
|
||||||
### 阴影(Shadow)
|
|
||||||
|
|
||||||
| Token | 值 | 用途 |
|
|
||||||
|-------|-----|------|
|
|
||||||
| `--shadow-sm` | `0 1px 2px 0 rgba(0,0,0,0.05)` | 轻微提升(按钮 hover) |
|
|
||||||
| `--shadow-md` | `0 4px 6px -1px rgba(0,0,0,0.1), 0 2px 4px -2px rgba(0,0,0,0.1)` | 卡片默认 |
|
|
||||||
| `--shadow-lg` | `0 10px 15px -3px rgba(0,0,0,0.1), 0 4px 6px -4px rgba(0,0,0,0.1)` | 下拉菜单、浮层 |
|
|
||||||
| `--shadow-xl` | `0 20px 25px -5px rgba(0,0,0,0.1), 0 8px 10px -6px rgba(0,0,0,0.1)` | 模态框、抽屉 |
|
|
||||||
| `--shadow-inner` | `inset 0 2px 4px 0 rgba(0,0,0,0.05)` | 内嵌效果 |
|
|
||||||
|
|
||||||
### 字体(Typography)
|
|
||||||
|
|
||||||
#### 字体族
|
|
||||||
|
|
||||||
| Token | 值 | 用途 |
|
|
||||||
|-------|-----|------|
|
|
||||||
| `--font-sans` | `Inter, "PingFang SC", "Microsoft YaHei", sans-serif` | 正文、UI 文字 |
|
|
||||||
| `--font-mono` | `"JetBrains Mono", "Fira Code", monospace` | 代码、数据值 |
|
|
||||||
|
|
||||||
#### 字号层级
|
|
||||||
|
|
||||||
| Token | 大小 | 字重 | 行高 | 用途 |
|
|
||||||
|-------|------|------|------|------|
|
|
||||||
| `--text-xs` | `12px` | 400 | `16px` | 辅助说明、时间戳 |
|
|
||||||
| `--text-sm` | `14px` | 400 | `20px` | 次要文字、描述 |
|
|
||||||
| `--text-base` | `16px` | 400 | `24px` | **正文默认** |
|
|
||||||
| `--text-lg` | `18px` | 500 | `28px` | 小标题、强调文字 |
|
|
||||||
| `--text-xl` | `20px` | 600 | `28px` | 卡片标题 |
|
|
||||||
| `--text-2xl` | `24px` | 600 | `32px` | 页面小标题 |
|
|
||||||
| `--text-3xl` | `30px` | 700 | `36px` | 页面标题 |
|
|
||||||
| `--text-4xl` | `36px` | 700 | `40px` | 大标题、数字展示 |
|
|
||||||
|
|
||||||
### 间距(Spacing)
|
|
||||||
|
|
||||||
| Token | 值 | 用途 |
|
|
||||||
|-------|-----|------|
|
|
||||||
| `--space-1` | `4px` | 极小间距 |
|
|
||||||
| `--space-2` | `8px` | 紧凑间距(图标与文字) |
|
|
||||||
| `--space-3` | `12px` | 小间距 |
|
|
||||||
| `--space-4` | `16px` | **默认间距**(内边距基础) |
|
|
||||||
| `--space-5` | `20px` | 中等间距 |
|
|
||||||
| `--space-6` | `24px` | 卡片内边距 |
|
|
||||||
| `--space-8` | `32px` | 区块间距 |
|
|
||||||
| `--space-10` | `40px` | 大区块间距 |
|
|
||||||
| `--space-12` | `48px` | 页面级间距 |
|
|
||||||
| `--space-16` | `64px` | 大页面间距 |
|
|
||||||
|
|
||||||
## 组件改造计划
|
|
||||||
|
|
||||||
GEO 平台基于 shadcn/ui 组件库进行构建,以下是对核心组件的定制化规范。
|
|
||||||
|
|
||||||
### Button 按钮
|
|
||||||
|
|
||||||
| 变体 | 背景 | 文字 | 边框 | Hover | 用途 |
|
|
||||||
|------|------|------|------|-------|------|
|
|
||||||
| `default` | `primary-500` | white | none | `primary-600` | 主要操作 |
|
|
||||||
| `secondary` | `gray-100` | `gray-900` | none | `gray-200` | 次要操作 |
|
|
||||||
| `outline` | transparent | `gray-700` | `gray-200` | `gray-50` | 辅助操作 |
|
|
||||||
| `ghost` | transparent | `gray-700` | none | `gray-100` | 极简操作 |
|
|
||||||
| `destructive` | `error-500` | white | none | `error-600` | 危险操作 |
|
|
||||||
| `link` | transparent | `primary-500` | none | underline | 链接样式 |
|
|
||||||
|
|
||||||
**尺寸规范**:
|
|
||||||
- `sm`:`h-8 px-3 text-sm`
|
|
||||||
- `md`(默认):`h-10 px-4 text-base`
|
|
||||||
- `lg`:`h-12 px-6 text-lg`
|
|
||||||
|
|
||||||
### Card 卡片
|
|
||||||
|
|
||||||
```
|
|
||||||
默认样式:
|
|
||||||
- 背景:white
|
|
||||||
- 圆角:`radius-lg` (8px)
|
|
||||||
- 阴影:`shadow-md`
|
|
||||||
- 内边距:`space-6` (24px)
|
|
||||||
- 边框:`1px solid gray-200`
|
|
||||||
|
|
||||||
Hover 状态:
|
|
||||||
- 阴影提升:`shadow-lg`
|
|
||||||
- 过渡:`transition-shadow duration-200`
|
|
||||||
```
|
|
||||||
|
|
||||||
**卡片变体**:
|
|
||||||
| 变体 | 说明 |
|
|
||||||
|------|------|
|
|
||||||
| `default` | 标准卡片,用于信息展示 |
|
|
||||||
| `stats` | 统计卡片,大号数字 + 描述 |
|
|
||||||
| `interactive` | 可交互卡片,hover 有明显反馈 |
|
|
||||||
| `flat` | 扁平卡片,无阴影,用于列表项 |
|
|
||||||
|
|
||||||
### Input 输入框
|
|
||||||
|
|
||||||
```
|
|
||||||
默认样式:
|
|
||||||
- 高度:`h-10`
|
|
||||||
- 圆角:`radius-md` (6px)
|
|
||||||
- 边框:`1px solid gray-300`
|
|
||||||
- 内边距:`px-3`
|
|
||||||
- 字体:`text-base`
|
|
||||||
- 背景:white
|
|
||||||
|
|
||||||
Focus 状态:
|
|
||||||
- 边框:`primary-500`
|
|
||||||
- 阴影:`0 0 0 2px primary-100`
|
|
||||||
- 过渡:`transition-all duration-200`
|
|
||||||
|
|
||||||
错误状态:
|
|
||||||
- 边框:`error-500`
|
|
||||||
- 阴影:`0 0 0 2px error-50`
|
|
||||||
```
|
|
||||||
|
|
||||||
### Badge 标签
|
|
||||||
|
|
||||||
| 变体 | 背景 | 文字 | 用途 |
|
|
||||||
|------|------|------|------|
|
|
||||||
| `default` | `primary-100` | `primary-700` | 默认标签 |
|
|
||||||
| `secondary` | `gray-100` | `gray-700` | 次要标签 |
|
|
||||||
| `success` | `success-50` | `success-600` | 成功状态 |
|
|
||||||
| `warning` | `warning-50` | `warning-600` | 警告状态 |
|
|
||||||
| `error` | `error-50` | `error-600` | 错误状态 |
|
|
||||||
| `info` | `info-50` | `info-600` | 信息状态 |
|
|
||||||
|
|
||||||
### Table 表格
|
|
||||||
|
|
||||||
```
|
|
||||||
默认样式:
|
|
||||||
- 表头背景:`gray-50`
|
|
||||||
- 表头文字:`gray-700` font-medium text-sm
|
|
||||||
- 行高:`h-12`
|
|
||||||
- 行边框:`border-b border-gray-200`
|
|
||||||
- 行 hover:`bg-gray-50`
|
|
||||||
- 单元格内边距:`px-4 py-3`
|
|
||||||
- 空状态:居中显示插图 + 提示文字
|
|
||||||
```
|
|
||||||
|
|
||||||
**特殊行样式**:
|
|
||||||
| 状态 | 样式 |
|
|
||||||
|------|------|
|
|
||||||
| 选中行 | `bg-primary-50` |
|
|
||||||
| 禁用行 | `opacity-50` |
|
|
||||||
| 高亮行 | `bg-warning-50` |
|
|
||||||
|
|
||||||
### Modal / Dialog 模态框
|
|
||||||
|
|
||||||
```
|
|
||||||
默认样式:
|
|
||||||
- 遮罩:`bg-black/50`
|
|
||||||
- 容器背景:white
|
|
||||||
- 圆角:`radius-xl` (12px)
|
|
||||||
- 阴影:`shadow-xl`
|
|
||||||
- 最大宽度:`max-w-lg` (512px)
|
|
||||||
- 内边距:`p-6`
|
|
||||||
- 动画:fade-in + scale-in
|
|
||||||
|
|
||||||
头部:
|
|
||||||
- 标题:`text-lg font-semibold`
|
|
||||||
- 关闭按钮:右上角 ghost 按钮
|
|
||||||
|
|
||||||
底部操作区:
|
|
||||||
- 居右排列
|
|
||||||
- 主按钮在右,次按钮在左
|
|
||||||
- 间距:`gap-3`
|
|
||||||
```
|
|
||||||
|
|
||||||
### Toast 通知
|
|
||||||
|
|
||||||
| 类型 | 图标 | 背景 | 边框 |
|
|
||||||
|------|------|------|------|
|
|
||||||
| `success` | CheckCircle | white | `success-200` |
|
|
||||||
| `warning` | AlertTriangle | white | `warning-200` |
|
|
||||||
| `error` | XCircle | white | `error-200` |
|
|
||||||
| `info` | Info | white | `info-200` |
|
|
||||||
|
|
||||||
```
|
|
||||||
默认样式:
|
|
||||||
- 位置:右上角固定
|
|
||||||
- 最大宽度:`400px`
|
|
||||||
- 圆角:`radius-lg`
|
|
||||||
- 阴影:`shadow-lg`
|
|
||||||
- 内边距:`p-4`
|
|
||||||
- 自动关闭:`5秒`
|
|
||||||
- 动画:slide-in-right + fade-out
|
|
||||||
```
|
|
||||||
|
|
||||||
### Navigation 导航
|
|
||||||
|
|
||||||
**侧边栏导航**:
|
|
||||||
```
|
|
||||||
默认样式:
|
|
||||||
- 宽度:`w-64` (256px)
|
|
||||||
- 背景:`gray-900`
|
|
||||||
- 文字:`gray-300`
|
|
||||||
|
|
||||||
选中状态:
|
|
||||||
- 背景:`primary-600`
|
|
||||||
- 文字:white
|
|
||||||
- 左侧边框指示:`3px solid primary-400`
|
|
||||||
|
|
||||||
Hover 状态:
|
|
||||||
- 背景:`gray-800`
|
|
||||||
- 文字:white
|
|
||||||
```
|
|
||||||
|
|
||||||
**顶部导航**:
|
|
||||||
```
|
|
||||||
默认样式:
|
|
||||||
- 高度:`h-16`
|
|
||||||
- 背景:white
|
|
||||||
- 边框:`border-b border-gray-200`
|
|
||||||
- 阴影:`shadow-sm`
|
|
||||||
```
|
|
||||||
|
|
||||||
## 交互规范
|
|
||||||
|
|
||||||
### 过渡动画
|
|
||||||
|
|
||||||
| 场景 | 时长 | 缓动函数 | 属性 |
|
|
||||||
|------|------|----------|------|
|
|
||||||
| 按钮 hover | `150ms` | `ease-in-out` | `background-color, border-color, box-shadow` |
|
|
||||||
| 卡片 hover | `200ms` | `ease-out` | `box-shadow, transform` |
|
|
||||||
| 模态框出现 | `200ms` | `cubic-bezier(0.16, 1, 0.3, 1)` | `opacity, transform` |
|
|
||||||
| 下拉菜单 | `150ms` | `ease-out` | `opacity, transform` |
|
|
||||||
| Toast 出现/消失 | `300ms` | `ease-in-out` | `opacity, transform` |
|
|
||||||
| 页面切换 | `200ms` | `ease-in-out` | `opacity` |
|
|
||||||
| 数据加载骨架屏 | `pulse 2s infinite` | - | `opacity` |
|
|
||||||
|
|
||||||
### 加载状态
|
|
||||||
|
|
||||||
| 场景 | 样式 |
|
|
||||||
|------|------|
|
|
||||||
| 页面加载 | 全屏骨架屏 + Logo |
|
|
||||||
| 数据加载 | 卡片/表格骨架屏 |
|
|
||||||
| 按钮加载 | 按钮内 Spinner,禁用点击 |
|
|
||||||
| 无限滚动 | 底部 Spinner |
|
|
||||||
| 文件上传 | 进度条 + 百分比 |
|
|
||||||
|
|
||||||
### 空状态
|
|
||||||
|
|
||||||
```
|
|
||||||
统一空状态样式:
|
|
||||||
- 居中布局
|
|
||||||
- 插图/图标:`primary-200` 色调
|
|
||||||
- 标题:`text-lg font-medium text-gray-700`
|
|
||||||
- 描述:`text-sm text-gray-500`
|
|
||||||
- 操作按钮(可选):`primary` 按钮
|
|
||||||
```
|
|
||||||
|
|
||||||
### 错误状态
|
|
||||||
|
|
||||||
```
|
|
||||||
统一错误状态样式:
|
|
||||||
- 错误图标:`error-500` 色的 AlertTriangle
|
|
||||||
- 标题:`text-lg font-medium text-gray-900`
|
|
||||||
- 描述:`text-sm text-gray-600`
|
|
||||||
- 重试按钮:`outline` 按钮
|
|
||||||
- 错误详情(可选):可展开的折叠面板
|
|
||||||
```
|
|
||||||
|
|
||||||
### 表单校验
|
|
||||||
|
|
||||||
```
|
|
||||||
校验提示样式:
|
|
||||||
- 错误信息:`text-xs text-error-500 mt-1`
|
|
||||||
- 错误图标:输入框右侧 `error-500` 图标
|
|
||||||
- 成功图标:输入框右侧 `success-500` Check 图标
|
|
||||||
- 实时校验:输入后 `300ms` 延迟触发
|
|
||||||
```
|
|
||||||
|
|
||||||
## 布局规范
|
|
||||||
|
|
||||||
### 页面结构
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────────────────────────────────────┐
|
|
||||||
│ 顶部导航栏 (h-16, fixed) │
|
|
||||||
├──────────┬──────────────────────────────────┤
|
|
||||||
│ │ 页面标题区 │
|
|
||||||
│ 侧边栏 │ ┌────────────────────────────┐ │
|
|
||||||
│ (w-64) │ │ 内容区 │ │
|
|
||||||
│ fixed │ │ ┌──────┐ ┌──────┐ ┌──────┐ │ │
|
|
||||||
│ │ │ │ 卡片 │ │ 卡片 │ │ 卡片 │ │ │
|
|
||||||
│ │ │ └──────┘ └──────┘ └──────┘ │ │
|
|
||||||
│ │ │ ┌────────────────────────┐ │ │
|
|
||||||
│ │ │ │ 表格/列表 │ │ │
|
|
||||||
│ │ │ └────────────────────────┘ │ │
|
|
||||||
│ │ └────────────────────────────┘ │
|
|
||||||
└──────────┴──────────────────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### 响应式断点
|
|
||||||
|
|
||||||
| 断点 | 宽度 | 布局调整 |
|
|
||||||
|------|------|----------|
|
|
||||||
| `sm` | `640px` | 移动端基础适配 |
|
|
||||||
| `md` | `768px` | 侧边栏可折叠 |
|
|
||||||
| `lg` | `1024px` | 完整桌面布局 |
|
|
||||||
| `xl` | `1280px` | 宽屏扩展 |
|
|
||||||
| `2xl` | `1536px` | 超宽屏 |
|
|
||||||
|
|
||||||
### 网格系统
|
|
||||||
|
|
||||||
- 采用 CSS Grid + Flexbox 混合布局
|
|
||||||
- 仪表盘卡片网格:`grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4`
|
|
||||||
- 卡片间距:`gap-6`
|
|
||||||
- 页面最大宽度:`max-w-7xl` (1280px),居中显示
|
|
||||||
|
|
||||||
## 图表规范
|
|
||||||
|
|
||||||
GEO 平台大量使用数据可视化图表,基于 Recharts 实现。
|
|
||||||
|
|
||||||
### 配色方案
|
|
||||||
|
|
||||||
```
|
|
||||||
图表主色系:
|
|
||||||
- 系列 1:`#6366F1` (primary-500)
|
|
||||||
- 系列 2:`#22C55E` (success-500)
|
|
||||||
- 系列 3:`#F59E0B` (warning-500)
|
|
||||||
- 系列 4:`#EF4444` (error-500)
|
|
||||||
- 系列 5:`#8B5CF6` (violet-500)
|
|
||||||
- 系列 6:`#06B6D4` (cyan-500)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 图表样式
|
|
||||||
|
|
||||||
```
|
|
||||||
通用样式:
|
|
||||||
- 背景:透明
|
|
||||||
- 网格线:`gray-200`,虚线
|
|
||||||
- 坐标轴文字:`gray-500`,text-xs
|
|
||||||
- 图例:底部居中
|
|
||||||
- Tooltip:白色背景,圆角,阴影
|
|
||||||
|
|
||||||
柱状图:
|
|
||||||
- 圆角:`radius-sm` 顶部
|
|
||||||
- 间距:`barGap=4`
|
|
||||||
|
|
||||||
折线图:
|
|
||||||
- 线条宽度:`2px`
|
|
||||||
- 数据点:`r=4`,hover `r=6`
|
|
||||||
- 填充区域:`fillOpacity=0.1`
|
|
||||||
|
|
||||||
饼图/环形图:
|
|
||||||
- 内半径:`60%`
|
|
||||||
- 标签:外部引导线 + 百分比
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
*本文档定义 GEO 平台的 UI 风格规范,所有前端实现应严格遵循以上 Design Tokens 和组件规范。*
|
|
||||||
|
|
@ -0,0 +1,24 @@
|
||||||
|
# 模块说明
|
||||||
|
|
||||||
|
本目录包含GEO平台各核心模块的详细说明。
|
||||||
|
|
||||||
|
## 目录内容
|
||||||
|
|
||||||
|
- [诊断模块](./diagnosis.md) - SEO诊断和GEO诊断 ✅ 已完成
|
||||||
|
- [Agent框架](./agent-framework.md) - AI Agent框架设计
|
||||||
|
- [内容Pipeline](./content-pipeline.md) - 内容生成流水线
|
||||||
|
- [知识库](./knowledge-base.md) - RAG知识库系统
|
||||||
|
- [知识图谱](./knowledge-graph.md) - 知识图谱构建
|
||||||
|
- [平台规则](./platform-rules.md) - 平台规则中心
|
||||||
|
- [监控模块](./monitoring.md) - 系统监控
|
||||||
|
- [母题库](./topic-templates.md) - 母题模板管理
|
||||||
|
|
||||||
|
## 诊断模块API端点概览
|
||||||
|
|
||||||
|
| 端点 | 描述 | 状态 |
|
||||||
|
|------|------|------|
|
||||||
|
| `GET /api/v1/diagnosis/seo/{brand_id}` | 获取品牌的SEO诊断结果 | ✅ 已完成 |
|
||||||
|
| `GET /api/v1/diagnosis/geo/{brand_id}` | 获取品牌的GEO诊断结果 | ✅ 已完成 |
|
||||||
|
| `GET /api/v1/diagnosis/combined/{brand_id}` | 获取品牌的SEO+GEO综合诊断结果 | ✅ 已完成 |
|
||||||
|
|
||||||
|
详细说明请查看 [诊断模块文档](./diagnosis.md)。
|
||||||
|
|
@ -0,0 +1,167 @@
|
||||||
|
# AI Agent框架
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
GEO平台的AI Agent框架是系统的核心智能层,采用模块化、可插拔的设计,支持Agent的动态注册、任务分发和状态管理。
|
||||||
|
|
||||||
|
## 架构
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────────┐
|
||||||
|
│ Agent Framework │
|
||||||
|
├─────────────────────────────────────────────────────────────┤
|
||||||
|
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────┐ │
|
||||||
|
│ │ Dispatcher │→│ Registry │→│ PipelineEngine │ │
|
||||||
|
│ └─────────────┘ └─────────────┘ └─────────────────┘ │
|
||||||
|
│ │
|
||||||
|
│ ┌─────────────────────────────────────────────────┐ │
|
||||||
|
│ │ Agents │ │
|
||||||
|
│ ├───────────┬───────────┬───────────┬────────────┤ │
|
||||||
|
│ │ Citation │ Content │ DeAI │ GEO │ │
|
||||||
|
│ │ Detector │ Generator │ Agent │ Optimizer │ │
|
||||||
|
│ └───────────┴───────────┴───────────┴────────────┘ │
|
||||||
|
└─────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
## 核心组件
|
||||||
|
|
||||||
|
### Dispatcher (调度器)
|
||||||
|
|
||||||
|
负责接收任务请求并分发到合适的Agent。
|
||||||
|
|
||||||
|
- 位置:`backend/app/agent_framework/dispatcher.py`
|
||||||
|
- 职责:任务路由、负载均衡、优先级处理、结果汇总
|
||||||
|
|
||||||
|
### Registry (注册中心)
|
||||||
|
|
||||||
|
管理所有Agent的注册信息和状态。
|
||||||
|
|
||||||
|
- 位置:`backend/app/agent_framework/registry.py`
|
||||||
|
- 存储:PostgreSQL数据库 (AgentRegistryModel)
|
||||||
|
- 功能:Agent注册、健康检查、服务发现、心跳更新
|
||||||
|
|
||||||
|
### PipelineEngine (编排引擎)
|
||||||
|
|
||||||
|
编排多阶段Agent任务链的执行。
|
||||||
|
|
||||||
|
- 位置:`backend/app/agent_framework/pipeline/engine.py`
|
||||||
|
- 职责:DAG执行、变量传递、超时控制、条件执行
|
||||||
|
|
||||||
|
## Agent列表
|
||||||
|
|
||||||
|
### CitationDetector (引用检测Agent)
|
||||||
|
|
||||||
|
| 属性 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| 文件 | `backend/app/agent_framework/agents/citation_detector.py` |
|
||||||
|
| 职责 | 解析AI平台响应,识别品牌引用情况 |
|
||||||
|
| 输入 | AI平台原始响应、品牌名称、竞品名称 |
|
||||||
|
| 输出 | 引用检测结果(引用/未引用/竞品引用) |
|
||||||
|
|
||||||
|
**引用类型**:
|
||||||
|
- `direct_quote` - 品牌名直接提及
|
||||||
|
- `indirect_reference` - 品牌信息被提及
|
||||||
|
- `no_reference` - 品牌未被提及
|
||||||
|
- `competitor_reference` - 竞品被引用
|
||||||
|
|
||||||
|
### ContentGenerator (内容生成Agent)
|
||||||
|
|
||||||
|
| 属性 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| 文件 | `backend/app/agent_framework/agents/content_generator_agent.py` |
|
||||||
|
| 职责 | 生成GEO优化的内容资产 |
|
||||||
|
| 输入 | 内容主题、规则库配置、品牌素材 |
|
||||||
|
| 输出 | GEO优化内容(文章/问答/指南等) |
|
||||||
|
|
||||||
|
**支持内容类型**:
|
||||||
|
- `article` - 长篇文章(800-2000字)
|
||||||
|
- `qa` - 问答对
|
||||||
|
- `guide` - 操作指南
|
||||||
|
- `comparison` - 对比分析
|
||||||
|
- `data_report` - 数据报告
|
||||||
|
|
||||||
|
### DeAI Agent (去AI化Agent)
|
||||||
|
|
||||||
|
| 属性 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| 文件 | `backend/app/agent_framework/agents/deai_agent.py` |
|
||||||
|
| 职责 | 去除AI生成内容的痕迹 |
|
||||||
|
| 输入 | AI生成的内容 |
|
||||||
|
| 输出 | 自然化处理后的内容 |
|
||||||
|
|
||||||
|
### GEOOptimizer (GEO优化Agent)
|
||||||
|
|
||||||
|
| 属性 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| 文件 | `backend/app/agent_framework/agents/geo_optimizer_agent.py` |
|
||||||
|
| 职责 | 对内容进行GEO优化 |
|
||||||
|
| 输入 | 原始内容、关键词策略 |
|
||||||
|
| 输出 | 优化后的内容 |
|
||||||
|
|
||||||
|
## 通信协议
|
||||||
|
|
||||||
|
基于数据库+Redis Queue的异步通信:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class TaskMessage:
|
||||||
|
task_id: str # UUID
|
||||||
|
agent_name: str # 目标Agent名称
|
||||||
|
task_type: str # 任务类型
|
||||||
|
priority: int # 优先级(0-9,9最高)
|
||||||
|
input_data: dict # 输入参数
|
||||||
|
callback_url: str # 回调URL
|
||||||
|
created_at: datetime # 创建时间
|
||||||
|
timeout_seconds: int # 超时时间
|
||||||
|
```
|
||||||
|
|
||||||
|
### 通信流程
|
||||||
|
|
||||||
|
```
|
||||||
|
1. 外部请求 → Dispatcher.dispatch(TaskMessage) → Redis Queue
|
||||||
|
2. Redis Queue → Agent消费 → 执行任务
|
||||||
|
3. Agent执行完成 → TaskResult → Dispatcher.handle_result
|
||||||
|
4. Dispatcher更新数据库 → 触发回调(如有)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 注册机制
|
||||||
|
|
||||||
|
Agent启动时向Registry注册:
|
||||||
|
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"name": "CitationDetector",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"status": "online", # online/busy/offline
|
||||||
|
"endpoint": "http://...",
|
||||||
|
"capabilities": {...},
|
||||||
|
"last_heartbeat": "2024-01-01T00:00:00Z"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 心跳机制
|
||||||
|
|
||||||
|
- 心跳超时阈值:90秒
|
||||||
|
- Registry.check_health() 定期检查并标记超时Agent为OFFLINE
|
||||||
|
|
||||||
|
## 任务状态机
|
||||||
|
|
||||||
|
```
|
||||||
|
PENDING → RUNNING → COMPLETED
|
||||||
|
↓
|
||||||
|
FAILED → RETRY → RUNNING
|
||||||
|
↓
|
||||||
|
CANCELLED
|
||||||
|
```
|
||||||
|
|
||||||
|
## Prompt模板
|
||||||
|
|
||||||
|
Agent使用统一的Prompt模板系统:
|
||||||
|
|
||||||
|
| 文件 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `prompts/base_template.py` | 基础模板结构 |
|
||||||
|
| `prompts/content_generator.py` | 内容生成模板 |
|
||||||
|
| `prompts/deai_agent.py` | 去AI化模板 |
|
||||||
|
| `prompts/geo_optimizer.py` | GEO优化模板 |
|
||||||
|
| `prompts/rule_checker.py` | 规则检查模板 |
|
||||||
|
| `prompts/topic_selector.py` | 母题选择模板 |
|
||||||
|
|
@ -0,0 +1,109 @@
|
||||||
|
# Agent间通信协议
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
本文档详细描述Agent Framework中各Agent之间的通信协议和数据结构。
|
||||||
|
|
||||||
|
## 消息类型
|
||||||
|
|
||||||
|
### TaskMessage (任务消息)
|
||||||
|
|
||||||
|
从调度器发往Agent的任务消息。
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| task_id | str | 是 | UUID格式的任务ID |
|
||||||
|
| agent_name | str | 是 | 目标Agent名称 |
|
||||||
|
| task_type | str | 是 | 任务类型 |
|
||||||
|
| priority | int | 否 | 优先级(0-9,9最高) |
|
||||||
|
| input_data | dict | 是 | 输入参数 |
|
||||||
|
| callback_url | str | 否 | 回调URL |
|
||||||
|
| created_at | datetime | 是 | 创建时间 |
|
||||||
|
| timeout_seconds | int | 否 | 超时时间(默认300秒) |
|
||||||
|
|
||||||
|
### TaskResult (任务结果)
|
||||||
|
|
||||||
|
从Agent返回的结果消息。
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| task_id | str | 是 | 对应的任务ID |
|
||||||
|
| agent_name | str | 是 | 执行任务的Agent名称 |
|
||||||
|
| status | str | 是 | 任务状态 (completed/failed/cancelled) |
|
||||||
|
| output_data | dict | 否 | 输出数据 |
|
||||||
|
| error_message | str | 否 | 错误信息 |
|
||||||
|
| started_at | datetime | 是 | 开始时间 |
|
||||||
|
| completed_at | datetime | 是 | 完成时间 |
|
||||||
|
| metrics | dict | 否 | 执行指标(耗时、token消耗等) |
|
||||||
|
|
||||||
|
### TaskProgress (进度上报)
|
||||||
|
|
||||||
|
Agent执行过程中上报的进度信息。
|
||||||
|
|
||||||
|
| 字段 | 类型 | 必填 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| task_id | str | 是 | 任务ID |
|
||||||
|
| agent_name | str | 是 | Agent名称 |
|
||||||
|
| progress | float | 是 | 进度 (0.0-1.0) |
|
||||||
|
| message | str | 是 | 进度描述 |
|
||||||
|
| updated_at | datetime | 是 | 更新时间 |
|
||||||
|
|
||||||
|
## 状态枚举
|
||||||
|
|
||||||
|
### TaskStatus (任务状态)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class TaskStatus(str, Enum):
|
||||||
|
PENDING = "pending" # 等待执行
|
||||||
|
RUNNING = "running" # 执行中
|
||||||
|
COMPLETED = "completed" # 已完成
|
||||||
|
FAILED = "failed" # 执行失败
|
||||||
|
CANCELLED = "cancelled" # 已取消
|
||||||
|
```
|
||||||
|
|
||||||
|
### AgentStatus (Agent状态)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class AgentStatus(str, Enum):
|
||||||
|
ONLINE = "online" # 在线
|
||||||
|
OFFLINE = "offline" # 离线
|
||||||
|
BUSY = "busy" # 忙碌中
|
||||||
|
```
|
||||||
|
|
||||||
|
## 通信流程
|
||||||
|
|
||||||
|
```
|
||||||
|
1. 外部请求 → Dispatcher.dispatch(TaskMessage) → Redis Queue
|
||||||
|
2. Redis Queue → Agent消费 → 执行任务
|
||||||
|
3. Agent执行完成 → TaskResult → Dispatcher.handle_result
|
||||||
|
4. Dispatcher更新数据库 → 触发回调(如有)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Agent类型
|
||||||
|
|
||||||
|
| 类型 | 说明 | 职责 |
|
||||||
|
|------|------|------|
|
||||||
|
| CITATION_DETECTOR | 引用检测Agent | 解析AI平台响应,识别品牌引用 |
|
||||||
|
| CONTENT_GENERATOR | 内容生成Agent | 生成GEO优化内容 |
|
||||||
|
| DEAI_AGENT | 去AI化Agent | 去除AI生成痕迹 |
|
||||||
|
| GEO_OPTIMIZER | GEO优化Agent | SEO和GEO优化 |
|
||||||
|
| RULE_CHECKER | 规则检查Agent | 内容合规审核 |
|
||||||
|
| COMPETITOR_ANALYZER | 竞品分析Agent | 竞品数据收集分析 |
|
||||||
|
| PERFORMANCE_TRACKER | 性能追踪Agent | 追踪内容表现 |
|
||||||
|
|
||||||
|
## 任务类型
|
||||||
|
|
||||||
|
| Agent | task_type | 说明 |
|
||||||
|
|-------|-----------|------|
|
||||||
|
| CitationDetector | citation_detect | 检测品牌引用 |
|
||||||
|
| ContentGenerator | generate | 生成内容 |
|
||||||
|
| DeAIAgent | humanize | 去AI化处理 |
|
||||||
|
| GEOOptimizer | optimize | GEO优化 |
|
||||||
|
|
||||||
|
## 心跳机制
|
||||||
|
|
||||||
|
Agent通过定时更新心跳来维持在线状态:
|
||||||
|
|
||||||
|
- 心跳超时阈值:90秒
|
||||||
|
- 超时后Agent被标记为OFFLINE
|
||||||
|
- Registry.check_health() 定期检查所有Agent状态
|
||||||
|
|
@ -0,0 +1,87 @@
|
||||||
|
# 内容生成Pipeline
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
内容生成Pipeline是GEO平台的核心功能之一,负责将用户输入转化为符合GEO优化标准的最终内容。
|
||||||
|
|
||||||
|
## Pipeline流程
|
||||||
|
|
||||||
|
```
|
||||||
|
用户输入 → 母题选择 → 内容生成 → 去AI化 → SEO优化 → HTML生成 → 输出
|
||||||
|
```
|
||||||
|
|
||||||
|
## 各阶段说明
|
||||||
|
|
||||||
|
### 1. 用户输入
|
||||||
|
|
||||||
|
接收用户的内容生成请求,包括:
|
||||||
|
- 品牌信息
|
||||||
|
- 目标平台
|
||||||
|
- 内容主题
|
||||||
|
- 关键词策略
|
||||||
|
|
||||||
|
### 2. 母题选择
|
||||||
|
|
||||||
|
从母题库中选择合适的模板:
|
||||||
|
|
||||||
|
- 位置:`backend/app/agent_framework/pipeline/loader.py`
|
||||||
|
- 功能:根据内容类型和行业匹配母题
|
||||||
|
|
||||||
|
### 3. 内容生成
|
||||||
|
|
||||||
|
调用ContentGenerator Agent生成初稿:
|
||||||
|
|
||||||
|
- 使用LLM生成内容
|
||||||
|
- 应用品牌风格指南
|
||||||
|
- 遵守规则库约束
|
||||||
|
|
||||||
|
### 4. 去AI化
|
||||||
|
|
||||||
|
使用DeAI Agent处理内容:
|
||||||
|
|
||||||
|
- 重写机械化的句式
|
||||||
|
- 增加语言多样性
|
||||||
|
- 保持语义一致性
|
||||||
|
|
||||||
|
### 5. SEO优化
|
||||||
|
|
||||||
|
使用GEOOptimizer Agent优化:
|
||||||
|
|
||||||
|
- 关键词密度调整
|
||||||
|
- 语义相关性提升
|
||||||
|
- 结构化数据添加
|
||||||
|
|
||||||
|
### 6. HTML生成
|
||||||
|
|
||||||
|
将内容转换为HTML格式:
|
||||||
|
|
||||||
|
- 响应式设计
|
||||||
|
- SEO友好的标签结构
|
||||||
|
- 平台适配
|
||||||
|
|
||||||
|
## Pipeline引擎
|
||||||
|
|
||||||
|
位置:`backend/app/agent_framework/pipeline/engine.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
class PipelineEngine:
|
||||||
|
"""内容生成Pipeline引擎"""
|
||||||
|
|
||||||
|
async def run(self, request: ContentRequest) -> ContentResult:
|
||||||
|
# 1. 加载母题
|
||||||
|
template = await self.loader.load(request.template_id)
|
||||||
|
|
||||||
|
# 2. 生成内容
|
||||||
|
draft = await self.generator.generate(request, template)
|
||||||
|
|
||||||
|
# 3. 去AI化
|
||||||
|
naturalized = await self.deai.process(draft)
|
||||||
|
|
||||||
|
# 4. SEO优化
|
||||||
|
optimized = await self.optimizer.optimize(naturalized)
|
||||||
|
|
||||||
|
# 5. HTML生成
|
||||||
|
html = await self.html_generator.generate(optimized)
|
||||||
|
|
||||||
|
return ContentResult(html=html, metadata={})
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,431 @@
|
||||||
|
# 诊断模块
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
诊断模块是GEO平台的核心功能,提供**传统SEO诊断**和**GEO诊断**两种能力,帮助用户全面了解网站和品牌在AI时代的搜索可见性。
|
||||||
|
|
||||||
|
## SEO诊断 vs GEO诊断
|
||||||
|
|
||||||
|
### 核心区别
|
||||||
|
|
||||||
|
| 维度 | SEO诊断 | GEO诊断 |
|
||||||
|
|------|---------|---------|
|
||||||
|
| **优化目标** | 网页排名 | 品牌被AI引用 |
|
||||||
|
| **诊断对象** | 网站 | 品牌实体+内容 |
|
||||||
|
| **成功指标** | 排名、流量、点击率 | 引用频率、AI声量占比 |
|
||||||
|
| **用户路径** | 点击链接访问网站 | AI直接推荐品牌 |
|
||||||
|
| **见效周期** | 3-6个月 | 2周-1个月 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## SEO诊断
|
||||||
|
|
||||||
|
传统SEO诊断评估网站对搜索引擎的优化程度,包含以下维度:
|
||||||
|
|
||||||
|
### 1. 技术SEO诊断
|
||||||
|
|
||||||
|
| 诊断项 | 说明 | 工具/方法 |
|
||||||
|
|--------|------|----------|
|
||||||
|
| 索引状态 | 网站是否被搜索引擎正确索引 | site:domain.com、Search Console |
|
||||||
|
| 爬取错误 | 404、5xx、重定向链 | 爬虫工具 |
|
||||||
|
| Core Web Vitals | LCP<2.5s、INP<200ms、CLS<0.1 | PageSpeed Insights |
|
||||||
|
| URL结构 | 规范化、重复URL | 爬虫分析 |
|
||||||
|
| robots.txt | 是否阻止重要页面 | 文件检查 |
|
||||||
|
| sitemap | 站点地图完整性 | XML验证 |
|
||||||
|
|
||||||
|
### 2. 页面SEO诊断
|
||||||
|
|
||||||
|
| 诊断项 | 说明 |
|
||||||
|
|--------|------|
|
||||||
|
| Title/Meta标签 | 是否完整、是否关键词堆砌 |
|
||||||
|
| H标签结构 | 层级是否清晰 |
|
||||||
|
| 关键词密度 | 是否合理分布 |
|
||||||
|
| 内链结构 | 是否有死链、锚文本是否相关 |
|
||||||
|
| 图片Alt | 是否添加描述性Alt文本 |
|
||||||
|
|
||||||
|
### 3. 内容质量诊断
|
||||||
|
|
||||||
|
| 诊断项 | 说明 |
|
||||||
|
|--------|------|
|
||||||
|
| 可读性 | 内容是否易于理解 |
|
||||||
|
| 信息深度 | 是否全面覆盖主题 |
|
||||||
|
| E-E-A-T | 经验、专业性、权威性、可信度 |
|
||||||
|
| 内容新鲜度 | 是否定期更新 |
|
||||||
|
| 重复内容 | 是否有大量重复页面 |
|
||||||
|
|
||||||
|
### 4. 外链分析
|
||||||
|
|
||||||
|
| 诊断项 | 说明 |
|
||||||
|
|--------|------|
|
||||||
|
| 反向链接质量 | 链接来源权威性 |
|
||||||
|
| 毒性信号 | 是否有垃圾链接 |
|
||||||
|
| 锚文本分布 | 是否自然多样 |
|
||||||
|
|
||||||
|
### 5. 用户体验诊断
|
||||||
|
|
||||||
|
| 诊断项 | 说明 |
|
||||||
|
|--------|------|
|
||||||
|
| 移动适配 | 移动端显示是否正常 |
|
||||||
|
| 页面速度 | 加载时间是否达标 |
|
||||||
|
| 转化路径 | 用户操作是否顺畅 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## GEO诊断
|
||||||
|
|
||||||
|
GEO诊断评估品牌在AI生成式引擎中的被引用能力,包含以下维度:
|
||||||
|
|
||||||
|
### 1. 内容可提取性诊断
|
||||||
|
|
||||||
|
AI需要能够轻松提取和理解内容:
|
||||||
|
|
||||||
|
| 诊断项 | 说明 | 优先级 |
|
||||||
|
|--------|------|--------|
|
||||||
|
| 直接回答块 | 页面首段是否有简洁明确的答案 | P0 |
|
||||||
|
| 问答式标题 | H2/H3是否采用问题形式 | P0 |
|
||||||
|
| 列表和表格 | 是否使用结构化数据展示 | P0 |
|
||||||
|
| 内链到子意图页 | 是否链接到相关深度内容 | P1 |
|
||||||
|
| 内容新鲜度 | 是否有更新日期和作者信息 | P1 |
|
||||||
|
|
||||||
|
### 2. 实体清晰度诊断
|
||||||
|
|
||||||
|
AI需要能够理解品牌是什么:
|
||||||
|
|
||||||
|
| 诊断项 | 说明 | 验证标准 |
|
||||||
|
|--------|------|----------|
|
||||||
|
| 品牌定义 | 是否清晰说明品牌做什么 | AI理解准确率≥95% |
|
||||||
|
| 目标受众 | 是否明确服务谁 | 实体识别准确率≥90% |
|
||||||
|
| 差异化价值 | 为什么选择这个品牌 | 独特性评分≥80 |
|
||||||
|
| 行业分类 | 品牌属于什么行业 | 分类准确率≥95% |
|
||||||
|
|
||||||
|
### 3. E-E-A-T信号诊断
|
||||||
|
|
||||||
|
AI需要验证品牌的可信度:
|
||||||
|
|
||||||
|
| 诊断项 | 说明 | 验证标准 |
|
||||||
|
|--------|------|----------|
|
||||||
|
| 作者资质 | 内容作者是否有专业背景 | 作者简介完整度≥90% |
|
||||||
|
| 专业认证 | 是否有行业认证/奖项 | 认证展示率≥80% |
|
||||||
|
| 数据来源 | 是否引用可靠数据 | 引用权威源≥70% |
|
||||||
|
| 专家背书 | 是否有行业专家认可 | 背书数量≥3 |
|
||||||
|
|
||||||
|
### 4. Schema标记诊断
|
||||||
|
|
||||||
|
结构化数据帮助AI理解内容:
|
||||||
|
|
||||||
|
| Schema类型 | 适用场景 | 优先级 | 实施难度 |
|
||||||
|
|-----------|---------|--------|---------|
|
||||||
|
| Organization | 企业主页 | P0必须 | ⭐ 简单 |
|
||||||
|
| Product | 产品页 | P0必须 | ⭐⭐ 中等 |
|
||||||
|
| Article/BlogPosting | 博客文章 | P0必须 | ⭐ 简单 |
|
||||||
|
| FAQPage | 常见问题 | P1推荐 | ⭐ 简单 |
|
||||||
|
| HowTo | 操作指南 | P1推荐 | ⭐⭐ 中等 |
|
||||||
|
| BreadcrumbList | 导航结构 | P1推荐 | ⭐ 简单 |
|
||||||
|
| Review/Rating | 评价评分 | P2可选 | ⭐⭐ 中等 |
|
||||||
|
|
||||||
|
### 5. 主题权威诊断
|
||||||
|
|
||||||
|
AI需要验证品牌在特定领域的权威性:
|
||||||
|
|
||||||
|
| 诊断项 | 说明 | 验证标准 |
|
||||||
|
|--------|------|----------|
|
||||||
|
| 内容深度 | 是否全面覆盖主题 | 内容质量QScore≥4.6/5 |
|
||||||
|
| 话题覆盖度 | 是否覆盖相关子话题 | 话题覆盖率≥80% |
|
||||||
|
| 实体信号一致性 | 各页面实体信号是否一致 | 一致性评分≥85% |
|
||||||
|
| 内链网络 | 是否形成主题内容集群 | 集群完整度≥70% |
|
||||||
|
|
||||||
|
### 6. 引用就绪度诊断
|
||||||
|
|
||||||
|
评估品牌在AI回答中被引用的可能性:
|
||||||
|
|
||||||
|
| 诊断项 | 说明 | 验证标准 |
|
||||||
|
|--------|------|----------|
|
||||||
|
| 引用频率 | 品牌在AI回答中被提及的频率 | AOR(Answer Ownership Rate)≥50% |
|
||||||
|
| 引用质量 | 引用内容是否准确完整 | 引用准确率≥90% |
|
||||||
|
| AI声量占比 | 品牌在AI回答中的占比 | AI SOV≥30% |
|
||||||
|
| 竞品对比 | 与竞品在AI回答中的表现 | 差距≤10个百分点(pp) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 诊断流程
|
||||||
|
|
||||||
|
### SEO诊断流程
|
||||||
|
|
||||||
|
```
|
||||||
|
1. 网站爬取 → 2. 技术SEO分析 → 3. 页面SEO分析 → 4. 内容质量分析 → 5. 外链分析 → 6. 用户体验分析 → 7. 生成报告
|
||||||
|
```
|
||||||
|
|
||||||
|
### GEO诊断流程
|
||||||
|
|
||||||
|
```
|
||||||
|
1. 品牌信息输入 → 2. 内容可提取性检测 → 3. 实体清晰度检测 → 4. E-E-A-T信号检测
|
||||||
|
→ 5. Schema标记检测 → 6. 主题权威检测 → 7. AI平台引用检测 → 8. 生成诊断报告
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 诊断报告输出
|
||||||
|
|
||||||
|
### SEO诊断报告
|
||||||
|
|
||||||
|
包含以下内容:
|
||||||
|
- 技术SEO评分
|
||||||
|
- 页面SEO评分
|
||||||
|
- 内容质量评分
|
||||||
|
- 外链质量评分
|
||||||
|
- 用户体验评分
|
||||||
|
- 综合评分
|
||||||
|
- 优先修复建议
|
||||||
|
|
||||||
|
### GEO诊断报告
|
||||||
|
|
||||||
|
包含以下内容:
|
||||||
|
- 内容可提取性评分
|
||||||
|
- 实体清晰度评分
|
||||||
|
- E-E-A-T信号评分
|
||||||
|
- Schema标记完整性
|
||||||
|
- 主题权威评分
|
||||||
|
- AI平台引用率
|
||||||
|
- 综合评分
|
||||||
|
- 优先优化建议
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 技术实现
|
||||||
|
|
||||||
|
### API端点
|
||||||
|
|
||||||
|
| 端点 | 方法 | 描述 | 状态 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| `GET /api/v1/diagnosis/seo/{brand_id}` | GET | 获取品牌的SEO诊断结果 | ✅ 已完成 |
|
||||||
|
| `GET /api/v1/diagnosis/geo/{brand_id}` | GET | 获取品牌的GEO诊断结果 | ✅ 已完成 |
|
||||||
|
| `GET /api/v1/diagnosis/combined/{brand_id}` | GET | 获取品牌的SEO+GEO综合诊断结果 | ✅ 已完成 |
|
||||||
|
|
||||||
|
### API响应示例
|
||||||
|
|
||||||
|
#### SEO诊断响应
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": {
|
||||||
|
"brand_id": "brand_123",
|
||||||
|
"diagnosis_type": "seo",
|
||||||
|
"overall_score": 78,
|
||||||
|
"dimensions": {
|
||||||
|
"technical_seo": {
|
||||||
|
"score": 85,
|
||||||
|
"status": "good",
|
||||||
|
"issues": []
|
||||||
|
},
|
||||||
|
"page_seo": {
|
||||||
|
"score": 72,
|
||||||
|
"status": "needs_improvement",
|
||||||
|
"issues": [
|
||||||
|
{
|
||||||
|
"type": "missing_meta_description",
|
||||||
|
"severity": "medium",
|
||||||
|
"description": "部分页面缺少Meta描述",
|
||||||
|
"affected_pages": 12
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"content_quality": {
|
||||||
|
"score": 80,
|
||||||
|
"status": "good",
|
||||||
|
"issues": []
|
||||||
|
},
|
||||||
|
"backlinks": {
|
||||||
|
"score": 65,
|
||||||
|
"status": "needs_improvement",
|
||||||
|
"issues": []
|
||||||
|
},
|
||||||
|
"user_experience": {
|
||||||
|
"score": 88,
|
||||||
|
"status": "excellent",
|
||||||
|
"issues": []
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"recommendations": [
|
||||||
|
{
|
||||||
|
"priority": "high",
|
||||||
|
"action": "补充缺失的Meta描述",
|
||||||
|
"impact": "提升页面相关性评分"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"diagnosed_at": "2024-01-15T10:30:00Z"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### GEO诊断响应
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": {
|
||||||
|
"brand_id": "brand_123",
|
||||||
|
"diagnosis_type": "geo",
|
||||||
|
"overall_score": 72,
|
||||||
|
"dimensions": {
|
||||||
|
"content_extractability": {
|
||||||
|
"score": 75,
|
||||||
|
"status": "good",
|
||||||
|
"metrics": {
|
||||||
|
"direct_answer_blocks": 85,
|
||||||
|
"qa_headings": 70,
|
||||||
|
"structured_data": 80
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"entity_clarity": {
|
||||||
|
"score": 80,
|
||||||
|
"status": "good",
|
||||||
|
"metrics": {
|
||||||
|
"brand_definition_accuracy": 95,
|
||||||
|
"target_audience_clarity": 85,
|
||||||
|
"differentiation_score": 75
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"eeat_signals": {
|
||||||
|
"score": 68,
|
||||||
|
"status": "needs_improvement",
|
||||||
|
"metrics": {
|
||||||
|
"author_credentials": 60,
|
||||||
|
"certifications": 70,
|
||||||
|
"data_sources": 75,
|
||||||
|
"expert_endorsements": 50
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"schema_markup": {
|
||||||
|
"score": 70,
|
||||||
|
"status": "good",
|
||||||
|
"coverage": {
|
||||||
|
"organization": true,
|
||||||
|
"product": true,
|
||||||
|
"article": true,
|
||||||
|
"faq": false,
|
||||||
|
"howto": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"topic_authority": {
|
||||||
|
"score": 74,
|
||||||
|
"status": "good",
|
||||||
|
"metrics": {
|
||||||
|
"content_depth": 80,
|
||||||
|
"topic_coverage": 72,
|
||||||
|
"entity_consistency": 85,
|
||||||
|
"internal_linking": 65
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"citation_readiness": {
|
||||||
|
"score": 65,
|
||||||
|
"status": "needs_improvement",
|
||||||
|
"metrics": {
|
||||||
|
"citation_frequency": 60,
|
||||||
|
"citation_accuracy": 75,
|
||||||
|
"ai_sov": 55,
|
||||||
|
"competitor_gap": -8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"recommendations": [
|
||||||
|
{
|
||||||
|
"priority": "high",
|
||||||
|
"action": "增强E-E-A-T信号",
|
||||||
|
"impact": "提升AI对品牌可信度的评估"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"diagnosed_at": "2024-01-15T10:30:00Z"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 综合诊断响应
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": {
|
||||||
|
"brand_id": "brand_123",
|
||||||
|
"diagnosis_type": "combined",
|
||||||
|
"seo_score": 78,
|
||||||
|
"geo_score": 72,
|
||||||
|
"combined_score": 75,
|
||||||
|
"seo_summary": {
|
||||||
|
"strengths": ["技术SEO基础良好", "用户体验优秀"],
|
||||||
|
"weaknesses": ["外链质量需提升", "部分页面Meta描述缺失"]
|
||||||
|
},
|
||||||
|
"geo_summary": {
|
||||||
|
"strengths": ["实体清晰度高", "内容结构化良好"],
|
||||||
|
"weaknesses": ["E-E-A-T信号不足", "AI引用率偏低"]
|
||||||
|
},
|
||||||
|
"priority_actions": [
|
||||||
|
{
|
||||||
|
"type": "seo",
|
||||||
|
"action": "建设高质量外链",
|
||||||
|
"expected_impact": "提升域名权威性"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "geo",
|
||||||
|
"action": "添加专家背书和认证",
|
||||||
|
"expected_impact": "提升AI引用率"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"diagnosed_at": "2024-01-15T10:30:00Z"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 后端实现
|
||||||
|
|
||||||
|
| 组件 | 文件 | 职责 |
|
||||||
|
|------|------|------|
|
||||||
|
| SEO诊断服务 | `backend/app/services/seo_diagnosis.py` | 执行SEO诊断分析 |
|
||||||
|
| GEO诊断服务 | `backend/app/services/geo_diagnosis.py` | 执行GEO诊断分析 |
|
||||||
|
| 引用检测Agent | `backend/app/agent_framework/agents/citation_detector.py` | 检测AI平台引用情况 |
|
||||||
|
| 诊断报告生成 | `backend/app/services/diagnosis_report.py` | 生成诊断报告 |
|
||||||
|
| 诊断路由 | `backend/app/api/routes/diagnosis.py` | API端点定义 |
|
||||||
|
|
||||||
|
### 前端实现
|
||||||
|
|
||||||
|
| 页面 | 路径 | 说明 | 状态 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| SEO诊断 | `/dashboard/seo-diagnosis` | SEO诊断入口和报告展示 | ✅ 已完成 |
|
||||||
|
| GEO诊断 | `/dashboard/geo-diagnosis` | GEO诊断入口和报告展示 | ✅ 已完成 |
|
||||||
|
| 综合诊断 | `/dashboard/diagnosis` | SEO+GEO综合诊断报告 | ✅ 已完成 |
|
||||||
|
|
||||||
|
前端实现包括:
|
||||||
|
- 诊断任务触发和进度展示
|
||||||
|
- 多维度评分可视化(雷达图、进度条)
|
||||||
|
- 问题列表和修复建议
|
||||||
|
- 历史诊断记录对比
|
||||||
|
- 诊断报告导出功能(PDF/Excel 导出列于“后续优化”Phase 5)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 改进建议
|
||||||
|
|
||||||
|
### 当前问题
|
||||||
|
|
||||||
|
| 问题 | 说明 | 优先级 | 状态 |
|
||||||
|
|------|------|--------|------|
|
||||||
|
| 诊断定义不完整 | 仅实现引用检测,缺少完整GEO诊断 | P0 | ✅ 已解决 |
|
||||||
|
| SEO诊断缺失 | 未实现传统SEO诊断能力 | P1 | ✅ 已解决 |
|
||||||
|
| 前后端断裂 | 诊断页面是占位页面 | P0 | ✅ 已解决 |
|
||||||
|
|
||||||
|
### 改进计划
|
||||||
|
|
||||||
|
| 阶段 | 目标 | 时间 | 状态 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| Phase 1 | 完善GEO诊断(补充6大维度) | 2周 | ✅ 已完成 |
|
||||||
|
| Phase 2 | 新增SEO诊断能力 | 2周 | ✅ 已完成 |
|
||||||
|
| Phase 3 | 整合SEO+GEO诊断报告 | 1周 | ✅ 已完成 |
|
||||||
|
| Phase 4 | 前端诊断页面实现 | 1周 | ✅ 已完成 |
|
||||||
|
|
||||||
|
### 后续优化
|
||||||
|
|
||||||
|
| 阶段 | 目标 | 预计时间 |
|
||||||
|
|------|------|----------|
|
||||||
|
| Phase 5 | 诊断报告导出(PDF/Excel) | 1周 |
|
||||||
|
| Phase 6 | 历史诊断趋势分析 | 2周 |
|
||||||
|
| Phase 7 | 竞品诊断对比功能 | 2周 |
|
||||||
|
|
@ -0,0 +1,242 @@
|
||||||
|
# 知识图谱实体类型定义
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
本文档描述知识图谱中使用的实体类型和关系类型。
|
||||||
|
|
||||||
|
## 实体类型 (EntityType)
|
||||||
|
|
||||||
|
位置:`backend/app/models/knowledge_graph.py`
|
||||||
|
|
||||||
|
### 类型定义
|
||||||
|
|
||||||
|
| 类型 | 枚举值 | 说明 | 示例 |
|
||||||
|
|------|--------|------|------|
|
||||||
|
| ORGANIZATION | organization | 公司/组织 | 腾讯、阿里巴巴 |
|
||||||
|
| PRODUCT | product | 产品 | 微信、王者荣耀 |
|
||||||
|
| PERSON | person | 人物 | 马化腾、张小龙 |
|
||||||
|
| LOCATION | location | 地点 | 深圳、广州 |
|
||||||
|
| TECHNOLOGY | technology | 技术 | AI、区块链 |
|
||||||
|
| BRAND | brand | 品牌 | 微信支付、腾讯云 |
|
||||||
|
| EVENT | event | 事件 | 2020腾讯年会 |
|
||||||
|
| CONCEPT | concept | 概念 | 数字化转型 |
|
||||||
|
|
||||||
|
### 属性说明
|
||||||
|
|
||||||
|
```python
|
||||||
|
class Entity:
|
||||||
|
id: str # UUID
|
||||||
|
name: str # 实体名称
|
||||||
|
entity_type: str # 实体类型
|
||||||
|
description: str # 描述
|
||||||
|
properties: dict # 自定义属性
|
||||||
|
confidence: str # 置信度 (high/medium/low)
|
||||||
|
source_chunk_id: str # 来源Chunk ID
|
||||||
|
```
|
||||||
|
|
||||||
|
## 关系类型 (RelationType)
|
||||||
|
|
||||||
|
位置:`backend/app/models/knowledge_graph.py`
|
||||||
|
|
||||||
|
### 类型定义
|
||||||
|
|
||||||
|
| 类型 | 枚举值 | 说明 | 示例 |
|
||||||
|
|------|--------|------|------|
|
||||||
|
| COMPETES_WITH | competes_with | 竞争对手 | 微信 ↔ 支付宝 |
|
||||||
|
| PARTNERS_WITH | partners_with | 合作伙伴 | 腾讯 ↔ 京东 |
|
||||||
|
| PRODUCES | produces | 生产 | 苹果 → iPhone |
|
||||||
|
| USES_TECHNOLOGY | uses_technology | 使用技术 | 微信 → AI |
|
||||||
|
| LOCATED_IN | located_in | 位于 | 腾讯 → 深圳 |
|
||||||
|
| FOUNDED_IN | founded_in | 成立于 | 腾讯 → 1998 |
|
||||||
|
| CEO_OF | ceo_of | CEO | 马化腾 → 腾讯 |
|
||||||
|
| FOUNDER_OF | founder_of | 创始人 | 马化腾 → 腾讯 |
|
||||||
|
| RELATED_TO | related_to | 相关 | AI ↔ 机器学习 |
|
||||||
|
| PART_OF | part_of | 属于 | iPhone → 苹果 |
|
||||||
|
|
||||||
|
### 属性说明
|
||||||
|
|
||||||
|
```python
|
||||||
|
class Relation:
|
||||||
|
id: str # UUID
|
||||||
|
source_entity_id: str # 源实体ID
|
||||||
|
target_entity_id: str # 目标实体ID
|
||||||
|
relation_type: str # 关系类型
|
||||||
|
properties: dict # 自定义属性
|
||||||
|
confidence: str # 置信度 (high/medium/low)
|
||||||
|
source_chunk_id: str # 来源Chunk ID
|
||||||
|
```
|
||||||
|
|
||||||
|
## 抽取流程
|
||||||
|
|
||||||
|
位置:`backend/app/services/knowledge/entity_extractor.py`
|
||||||
|
|
||||||
|
### EntityExtractor
|
||||||
|
|
||||||
|
```python
|
||||||
|
class EntityExtractor:
|
||||||
|
ENTITY_TYPES = [
|
||||||
|
"ORGANIZATION",
|
||||||
|
"PRODUCT",
|
||||||
|
"PERSON",
|
||||||
|
"LOCATION",
|
||||||
|
"TECHNOLOGY",
|
||||||
|
"BRAND",
|
||||||
|
"EVENT",
|
||||||
|
"CONCEPT",
|
||||||
|
]
|
||||||
|
|
||||||
|
RELATION_TYPES = [
|
||||||
|
"COMPETES_WITH",
|
||||||
|
"PARTNERS_WITH",
|
||||||
|
"PRODUCES",
|
||||||
|
"USES_TECHNOLOGY",
|
||||||
|
"LOCATED_IN",
|
||||||
|
"FOUNDED_IN",
|
||||||
|
"CEO_OF",
|
||||||
|
"FOUNDER_OF",
|
||||||
|
"RELATED_TO",
|
||||||
|
"PART_OF",
|
||||||
|
]
|
||||||
|
|
||||||
|
async def extract(text: str, context: str) -> ExtractionResult:
|
||||||
|
"""
|
||||||
|
从文本中抽取实体和关系
|
||||||
|
1. 构建抽取Prompt
|
||||||
|
2. 调用LLM
|
||||||
|
3. 解析返回结果
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
### 抽取Prompt示例
|
||||||
|
|
||||||
|
```
|
||||||
|
从以下文本中抽取知识图谱的实体和关系。
|
||||||
|
|
||||||
|
实体类型:
|
||||||
|
- ORGANIZATION (公司/组织)
|
||||||
|
- PRODUCT (产品)
|
||||||
|
- PERSON (人物)
|
||||||
|
...
|
||||||
|
|
||||||
|
关系类型:
|
||||||
|
- COMPETES_WITH (竞争对手)
|
||||||
|
- PARTNERS_WITH (合作伙伴)
|
||||||
|
...
|
||||||
|
|
||||||
|
文本内容:
|
||||||
|
{text}
|
||||||
|
|
||||||
|
请以JSON格式返回结果:
|
||||||
|
{
|
||||||
|
"entities": [
|
||||||
|
{"name": "实体名称", "entity_type": "类型", "confidence": "high/medium/low"}
|
||||||
|
],
|
||||||
|
"relations": [
|
||||||
|
{"source_entity": "源", "target_entity": "目标", "relation_type": "类型", "confidence": "high/medium/low"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 图谱构建
|
||||||
|
|
||||||
|
位置:`backend/app/services/knowledge/graph_builder.py`
|
||||||
|
|
||||||
|
### GraphBuilder
|
||||||
|
|
||||||
|
```python
|
||||||
|
class GraphBuilder:
|
||||||
|
async def build_from_chunk(
|
||||||
|
session: AsyncSession,
|
||||||
|
chunk_id: str,
|
||||||
|
context: str = None,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
从Chunk构建知识图谱
|
||||||
|
1. 获取Chunk内容
|
||||||
|
2. 调用EntityExtractor抽取
|
||||||
|
3. 存储到图谱
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
### 构建统计
|
||||||
|
|
||||||
|
返回统计信息:
|
||||||
|
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"entities_created": 5, # 新建实体数
|
||||||
|
"entities_existing": 2, # 已存在实体数
|
||||||
|
"relations_created": 3, # 新建关系数
|
||||||
|
"relations_existing": 1, # 已存在关系数
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 图谱查询
|
||||||
|
|
||||||
|
### 实体查询
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 查询实体
|
||||||
|
GET /api/v1/knowledge-graph/entities?type=brand&limit=20
|
||||||
|
|
||||||
|
# 响应
|
||||||
|
{
|
||||||
|
"entities": [
|
||||||
|
{
|
||||||
|
"id": "uuid",
|
||||||
|
"name": "微信",
|
||||||
|
"type": "brand",
|
||||||
|
"description": "...",
|
||||||
|
"properties": {...}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 关系查询
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 查询关系
|
||||||
|
GET /api/v1/knowledge-graph/relations?source_id=xxx&type=competes_with
|
||||||
|
|
||||||
|
# 响应
|
||||||
|
{
|
||||||
|
"relations": [
|
||||||
|
{
|
||||||
|
"id": "uuid",
|
||||||
|
"source_id": "xxx",
|
||||||
|
"target_id": "yyy",
|
||||||
|
"type": "competes_with",
|
||||||
|
"properties": {...}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 语义搜索
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 语义搜索
|
||||||
|
GET /api/v1/knowledge-graph/search?query=微信的竞争对手
|
||||||
|
|
||||||
|
# 响应
|
||||||
|
{
|
||||||
|
"results": [
|
||||||
|
{"entity": {...}, "score": 0.95},
|
||||||
|
{"entity": {...}, "score": 0.87}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 置信度评估
|
||||||
|
|
||||||
|
| 级别 | 说明 | 使用场景 |
|
||||||
|
|------|------|----------|
|
||||||
|
| high | 高置信度 | 直接使用 |
|
||||||
|
| medium | 中置信度 | 建议人工审核 |
|
||||||
|
| low | 低置信度 | 仅供参考 |
|
||||||
|
|
||||||
|
评估因素:
|
||||||
|
- 文本中提及的明确性
|
||||||
|
- 上下文的支持程度
|
||||||
|
- LLM模型的确定性
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,50 @@
|
||||||
|
# 知识库
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
GEO平台的知识库系统基于RAG(Retrieval-Augmented Generation)架构,为内容生成提供领域知识支持。
|
||||||
|
|
||||||
|
## 系统架构
|
||||||
|
|
||||||
|
```
|
||||||
|
文档上传 → 文本分块 → 向量化 → RAG检索 → LLM增强生成
|
||||||
|
```
|
||||||
|
|
||||||
|
## 核心功能
|
||||||
|
|
||||||
|
### 文档管理
|
||||||
|
|
||||||
|
- 支持多种文档格式(PDF、TXT、Markdown、DOCX)
|
||||||
|
- 文档版本管理
|
||||||
|
- 分类和标签
|
||||||
|
|
||||||
|
### 文本分块
|
||||||
|
|
||||||
|
- 智能分块策略
|
||||||
|
- 重叠窗口机制
|
||||||
|
- 元数据保留
|
||||||
|
|
||||||
|
### 向量化
|
||||||
|
|
||||||
|
- 支持多种嵌入模型
|
||||||
|
- 向量数据库存储
|
||||||
|
- 相似度检索
|
||||||
|
|
||||||
|
## API接口
|
||||||
|
|
||||||
|
知识库相关API位于 `backend/app/api/knowledge.py`
|
||||||
|
|
||||||
|
### 主要端点
|
||||||
|
|
||||||
|
| 方法 | 路径 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| GET | /api/v1/knowledge/bases | 获取知识库列表 |
|
||||||
|
| POST | /api/v1/knowledge/bases | 创建知识库 |
|
||||||
|
| POST | /api/v1/knowledge/bases/{id}/documents | 上传文档 |
|
||||||
|
| GET | /api/v1/knowledge/search | 搜索知识 |
|
||||||
|
|
||||||
|
## 配置
|
||||||
|
|
||||||
|
环境变量:
|
||||||
|
- `KNOWLEDGE_EMBEDDING_MODEL` - 嵌入模型
|
||||||
|
- `KNOWLEDGE_VECTOR_DB` - 向量数据库类型
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
# 知识图谱
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
知识图谱模块用于构建和管理品牌相关的实体关系图谱,支持语义搜索和智能推理。
|
||||||
|
|
||||||
|
## 核心功能
|
||||||
|
|
||||||
|
### 实体管理
|
||||||
|
|
||||||
|
- 品牌实体
|
||||||
|
- 产品实体
|
||||||
|
- 竞品实体
|
||||||
|
- 行业概念实体
|
||||||
|
|
||||||
|
### 关系构建
|
||||||
|
|
||||||
|
- 品牌-产品关系
|
||||||
|
- 品牌-竞品关系
|
||||||
|
- 产品-行业关系
|
||||||
|
- 概念上下位关系
|
||||||
|
|
||||||
|
## API接口
|
||||||
|
|
||||||
|
知识图谱API位于 `backend/app/api/knowledge_graph.py`
|
||||||
|
|
||||||
|
### 主要端点
|
||||||
|
|
||||||
|
| 方法 | 路径 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| GET | /api/v1/knowledge-graph/entities | 获取实体列表 |
|
||||||
|
| POST | /api/v1/knowledge-graph/entities | 创建实体 |
|
||||||
|
| GET | /api/v1/knowledge-graph/relations | 获取关系列表 |
|
||||||
|
| POST | /api/v1/knowledge-graph/relations | 创建关系 |
|
||||||
|
| GET | /api/v1/knowledge-graph/search | 语义搜索 |
|
||||||
|
|
||||||
|
## 数据模型
|
||||||
|
|
||||||
|
### Entity (实体)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class Entity:
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
type: str # brand/product/competitor/concept
|
||||||
|
properties: dict
|
||||||
|
```
|
||||||
|
|
||||||
|
### Relation (关系)
|
||||||
|
|
||||||
|
```python
|
||||||
|
class Relation:
|
||||||
|
id: str
|
||||||
|
source_id: str
|
||||||
|
target_id: str
|
||||||
|
relation_type: str # produces/competes_with/part_of 等(完整取值见 RelationType)
|
||||||
|
properties: dict
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,208 @@
|
||||||
|
# 监控指标定义
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
本文档详细描述监控系统的指标定义和LLM成本追踪。
|
||||||
|
|
||||||
|
## Prometheus指标
|
||||||
|
|
||||||
|
位置:`backend/app/monitoring/metrics.py`
|
||||||
|
|
||||||
|
### API层指标
|
||||||
|
|
||||||
|
| 指标名 | 类型 | 标签 | 说明 |
|
||||||
|
|--------|------|------|------|
|
||||||
|
| geo_api_requests_total | Counter | method, endpoint, status | HTTP请求总数 |
|
||||||
|
| geo_api_request_duration_seconds | Histogram | method, endpoint | 请求延迟分布 |
|
||||||
|
| geo_api_requests_in_progress | Gauge | method, endpoint | 当前处理中的请求数 |
|
||||||
|
|
||||||
|
### Agent层指标
|
||||||
|
|
||||||
|
| 指标名 | 类型 | 标签 | 说明 |
|
||||||
|
|--------|------|------|------|
|
||||||
|
| geo_agent_executions_total | Counter | agent_name, status | Agent执行总数 |
|
||||||
|
| geo_agent_execution_duration_seconds | Histogram | agent_name | Agent执行耗时 |
|
||||||
|
| geo_agent_running_tasks | Gauge | agent_name | 当前运行的任务数 |
|
||||||
|
|
||||||
|
### LLM层指标
|
||||||
|
|
||||||
|
| 指标名 | 类型 | 标签 | 说明 |
|
||||||
|
|--------|------|------|------|
|
||||||
|
| geo_llm_requests_total | Counter | provider, model, status | LLM请求总数 |
|
||||||
|
| geo_llm_request_duration_seconds | Histogram | provider, model | LLM请求耗时 |
|
||||||
|
| geo_llm_tokens_total | Counter | provider, model, token_type | Token消耗总量 |
|
||||||
|
| geo_llm_cost_estimated | Gauge | provider, model | 预估成本(USD) |
|
||||||
|
|
||||||
|
### 业务层指标
|
||||||
|
|
||||||
|
| 指标名 | 类型 | 标签 | 说明 |
|
||||||
|
|--------|------|------|------|
|
||||||
|
| geo_brands_total | Gauge | - | 品牌总数 |
|
||||||
|
| geo_queries_total | Counter | platform, status | 查询总数 |
|
||||||
|
| geo_content_generated_total | Counter | - | 生成内容总数 |
|
||||||
|
| geo_citations_detected_total | Counter | platform | 引用检测总数 |
|
||||||
|
|
||||||
|
## LLM成本追踪
|
||||||
|
|
||||||
|
位置:`backend/app/monitoring/llm_metrics.py`
|
||||||
|
|
||||||
|
### 成本估算表
|
||||||
|
|
||||||
|
```python
|
||||||
|
LLM_COST_PER_TOKEN = {
|
||||||
|
# OpenAI
|
||||||
|
("openai", "gpt-4o"): {
|
||||||
|
"prompt": 0.000005, # $5/1M tokens
|
||||||
|
"completion": 0.000015, # $15/1M tokens
|
||||||
|
},
|
||||||
|
("openai", "gpt-4o-mini"): {
|
||||||
|
"prompt": 0.00000015, # $0.15/1M tokens
|
||||||
|
"completion": 0.0000006, # $0.60/1M tokens
|
||||||
|
},
|
||||||
|
("openai", "gpt-4-turbo"): {
|
||||||
|
"prompt": 0.00001, # $10/1M tokens
|
||||||
|
"completion": 0.00003, # $30/1M tokens
|
||||||
|
},
|
||||||
|
# DeepSeek
|
||||||
|
("deepseek", "deepseek-chat"): {
|
||||||
|
"prompt": 0.00000014, # $0.14/1M tokens
|
||||||
|
"completion": 0.00000028, # $0.28/1M tokens
|
||||||
|
},
|
||||||
|
("deepseek", "deepseek-coder"): {
|
||||||
|
"prompt": 0.00000014,
|
||||||
|
"completion": 0.00000028,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 成本计算公式
|
||||||
|
|
||||||
|
```
|
||||||
|
总成本 = prompt_tokens * prompt_price + completion_tokens * completion_price
|
||||||
|
```
|
||||||
|
|
||||||
|
### LLMMetricsWrapper
|
||||||
|
|
||||||
|
```python
|
||||||
|
class LLMMetricsWrapper:
|
||||||
|
def record_request(
|
||||||
|
self,
|
||||||
|
status: str,
|
||||||
|
duration: float,
|
||||||
|
prompt_tokens: int = None,
|
||||||
|
completion_tokens: int = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
记录LLM请求指标
|
||||||
|
1. 记录请求数和耗时
|
||||||
|
2. 记录Token消耗
|
||||||
|
3. 估算并记录成本
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
## Histogram buckets
|
||||||
|
|
||||||
|
### API延迟
|
||||||
|
|
||||||
|
```python
|
||||||
|
buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0)
|
||||||
|
# 单位:秒
|
||||||
|
# 这些桶用于估算 P50/P90/P99 延迟分位数
|
||||||
|
```
|
||||||
|
|
||||||
|
### Agent执行耗时
|
||||||
|
|
||||||
|
```python
|
||||||
|
buckets=(0.1, 0.5, 1.0, 5.0, 10.0, 30.0, 60.0, 120.0)
|
||||||
|
# 单位:秒
|
||||||
|
```
|
||||||
|
|
||||||
|
### LLM请求耗时
|
||||||
|
|
||||||
|
```python
|
||||||
|
buckets=(0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0)
|
||||||
|
# 单位:秒
|
||||||
|
```
|
||||||
|
|
||||||
|
## 健康检查
|
||||||
|
|
||||||
|
位置:`backend/app/services/health_checker.py`
|
||||||
|
|
||||||
|
### 检查端点
|
||||||
|
|
||||||
|
| 路径 | 说明 | 检查项 |
|
||||||
|
|------|------|--------|
|
||||||
|
| GET /health | 综合健康检查 | 所有检查项 |
|
||||||
|
| GET /health/ready | 就绪检查 | 数据库、Redis |
|
||||||
|
| GET /health/live | 存活检查 | 服务运行状态 |
|
||||||
|
|
||||||
|
### 检查项
|
||||||
|
|
||||||
|
| 检查项 | 说明 | 超时 |
|
||||||
|
|--------|------|------|
|
||||||
|
| database | 数据库连接 | 5s |
|
||||||
|
| redis | Redis连接 | 5s |
|
||||||
|
| disk | 磁盘空间 | - |
|
||||||
|
| memory | 内存使用率 | - |
|
||||||
|
|
||||||
|
## 告警规则
|
||||||
|
|
||||||
|
### 告警条件
|
||||||
|
|
||||||
|
| 规则 | 条件 | 级别 | 说明 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| API响应超时 | p99 > 5s | Warning | 99分位响应时间超过5秒 |
|
||||||
|
| API错误率高 | error_rate > 5% | Error | 错误率超过5% |
|
||||||
|
| 队列积压 | queue_depth > 1000 | Warning | 队列深度超过1000 |
|
||||||
|
| Agent离线 | heartbeat_timeout | Critical | 心跳超时 |
|
||||||
|
|
||||||
|
### 告警级别
|
||||||
|
|
||||||
|
| 级别 | 说明 | 通知方式 |
|
||||||
|
|------|------|----------|
|
||||||
|
| Warning | 警告 | 日志 |
|
||||||
|
| Error | 错误 | 日志+邮件 |
|
||||||
|
| Critical | 严重 | 日志+邮件+短信 |
|
||||||
|
|
||||||
|
## 监控集成
|
||||||
|
|
||||||
|
### Prometheus端点
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /metrics
|
||||||
|
```
|
||||||
|
|
||||||
|
返回Prometheus格式的指标数据。
|
||||||
|
|
||||||
|
### Grafana仪表板
|
||||||
|
|
||||||
|
推荐配置的仪表板:
|
||||||
|
|
||||||
|
1. **API Dashboard**
|
||||||
|
- QPS
|
||||||
|
- 延迟分布 (P50/P90/P99)
|
||||||
|
- 错误率
|
||||||
|
|
||||||
|
2. **Agent Dashboard**
|
||||||
|
- 各Agent执行次数
|
||||||
|
- 执行耗时
|
||||||
|
- 成功率
|
||||||
|
|
||||||
|
3. **LLM Dashboard**
|
||||||
|
- 请求数
|
||||||
|
- Token消耗
|
||||||
|
- 成本趋势
|
||||||
|
|
||||||
|
4. **业务 Dashboard**
|
||||||
|
- 品牌数量
|
||||||
|
- 内容生成量
|
||||||
|
- 引用检测量
|
||||||
|
|
||||||
|
## 指标收集流程
|
||||||
|
|
||||||
|
```
|
||||||
|
1. API请求 → 中间件记录 (method, endpoint, status, duration)
|
||||||
|
2. Agent执行 → AgentHooks记录 (agent_name, status, duration)
|
||||||
|
3. LLM调用 → LLMMetricsWrapper记录 (provider, model, tokens, cost)
|
||||||
|
4. 后台任务 → 定时汇总写入数据库
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,64 @@
|
||||||
|
# 监控模块
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
监控系统用于追踪系统运行状态、收集性能指标、检测异常并发送告警。
|
||||||
|
|
||||||
|
## 监控范围
|
||||||
|
|
||||||
|
### 系统监控
|
||||||
|
|
||||||
|
- CPU、内存、磁盘使用率
|
||||||
|
- 网络流量
|
||||||
|
- 进程状态
|
||||||
|
|
||||||
|
### 应用监控
|
||||||
|
|
||||||
|
- API响应时间
|
||||||
|
- 请求成功率
|
||||||
|
- 错误率
|
||||||
|
- 并发连接数
|
||||||
|
|
||||||
|
### 业务监控
|
||||||
|
|
||||||
|
- Agent运行状态
|
||||||
|
- 任务队列深度
|
||||||
|
- 知识库检索延迟
|
||||||
|
|
||||||
|
## 监控指标
|
||||||
|
|
||||||
|
位置:`backend/app/monitoring/`
|
||||||
|
|
||||||
|
| 指标 | 类型 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| http_requests_total | Counter | HTTP请求总数 |
|
||||||
|
| http_request_duration_seconds | Histogram | 请求延迟分布 |
|
||||||
|
| agent_tasks_total | Counter | Agent任务总数 |
|
||||||
|
| agent_task_duration_seconds | Histogram | 任务执行时间 |
|
||||||
|
| queue_depth | Gauge | 队列深度 |
|
||||||
|
|
||||||
|
## 健康检查
|
||||||
|
|
||||||
|
### 端点
|
||||||
|
|
||||||
|
| 路径 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| GET /health | 服务健康检查 |
|
||||||
|
| GET /health/ready | 就绪检查 |
|
||||||
|
| GET /health/live | 存活检查 |
|
||||||
|
|
||||||
|
### 检查项
|
||||||
|
|
||||||
|
- 数据库连接
|
||||||
|
- Redis连接
|
||||||
|
- 磁盘空间
|
||||||
|
- 内存使用率
|
||||||
|
|
||||||
|
## 告警规则
|
||||||
|
|
||||||
|
| 规则 | 条件 | 级别 |
|
||||||
|
|------|------|------|
|
||||||
|
| API响应超时 | p99 > 5s | Warning |
|
||||||
|
| API错误率高 | error_rate > 5% | Error |
|
||||||
|
| 队列积压 | queue_depth > 1000 | Warning |
|
||||||
|
| Agent离线 | heartbeat_timeout | Critical |
|
||||||
|
|
@ -0,0 +1,207 @@
|
||||||
|
# 内容生成Pipeline编排逻辑
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
本文档描述内容生成Pipeline的完整编排逻辑。
|
||||||
|
|
||||||
|
## Pipeline架构
|
||||||
|
|
||||||
|
```
|
||||||
|
用户请求 → PipelineEngine.execute()
|
||||||
|
│
|
||||||
|
├─ 1. 加载Pipeline定义 (YAML)
|
||||||
|
│
|
||||||
|
├─ 2. 构建执行上下文
|
||||||
|
│
|
||||||
|
├─ 3. 拓扑排序确定执行顺序
|
||||||
|
│
|
||||||
|
└─ 4. 逐阶段执行
|
||||||
|
│
|
||||||
|
├─ Stage 1 (无依赖)
|
||||||
|
├─ Stage 2 (依赖Stage 1)
|
||||||
|
├─ Stage 3 (依赖Stage 1, 2)
|
||||||
|
└─ ...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pipeline定义结构
|
||||||
|
|
||||||
|
Pipeline通过YAML文件定义,位于 `backend/pipelines/` 目录:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
name: content_production
|
||||||
|
version: "1.0"
|
||||||
|
description: 内容生产Pipeline
|
||||||
|
|
||||||
|
variables:
|
||||||
|
brand_name: ""
|
||||||
|
keywords: ""
|
||||||
|
platform: "zhihu"
|
||||||
|
|
||||||
|
stages:
|
||||||
|
- name: topic_selection
|
||||||
|
agent: content_generator
|
||||||
|
action: select_topic
|
||||||
|
inputs:
|
||||||
|
brand: "${brand_name}"
|
||||||
|
keywords: "${keywords}"
|
||||||
|
outputs: [selected_topic, topic_score]
|
||||||
|
|
||||||
|
- name: content_generation
|
||||||
|
agent: content_generator
|
||||||
|
action: generate
|
||||||
|
depends_on: [topic_selection]
|
||||||
|
inputs:
|
||||||
|
topic: "${stages.topic_selection.outputs.selected_topic}"
|
||||||
|
brand: "${brand_name}"
|
||||||
|
outputs: [content]
|
||||||
|
|
||||||
|
- name: deai_processing
|
||||||
|
agent: deai_agent
|
||||||
|
action: humanize
|
||||||
|
depends_on: [content_generation]
|
||||||
|
inputs:
|
||||||
|
content: "${stages.content_generation.outputs.content}"
|
||||||
|
outputs: [natural_content]
|
||||||
|
|
||||||
|
- name: seo_optimization
|
||||||
|
agent: geo_optimizer
|
||||||
|
action: optimize
|
||||||
|
depends_on: [deai_processing]
|
||||||
|
inputs:
|
||||||
|
content: "${stages.deai_processing.outputs.natural_content}"
|
||||||
|
keywords: "${keywords}"
|
||||||
|
outputs: [optimized_content]
|
||||||
|
```
|
||||||
|
|
||||||
|
## 核心组件
|
||||||
|
|
||||||
|
### PipelineEngine
|
||||||
|
|
||||||
|
位置:`backend/app/agent_framework/pipeline/engine.py`
|
||||||
|
|
||||||
|
| 方法 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| execute(pipeline, context) | 执行完整Pipeline |
|
||||||
|
| _execute_stage(stage, exec_context, stages_context) | 执行单个阶段 |
|
||||||
|
| _dispatch_and_wait(stage, inputs) | 分发任务并等待结果 |
|
||||||
|
| _topological_sort(stages) | 拓扑排序 |
|
||||||
|
| _resolve_variables(template, context) | 解析变量引用 |
|
||||||
|
|
||||||
|
### PipelineLoader
|
||||||
|
|
||||||
|
位置:`backend/app/agent_framework/pipeline/loader.py`
|
||||||
|
|
||||||
|
| 方法 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| load(pipeline_name) | 从YAML加载Pipeline |
|
||||||
|
| load_from_yaml(yaml_content) | 从字符串加载 |
|
||||||
|
| validate_dag(stages) | 验证DAG无环 |
|
||||||
|
| resolve_variables(template, context) | 解析${...}变量 |
|
||||||
|
|
||||||
|
### PipelineSchema
|
||||||
|
|
||||||
|
位置:`backend/app/agent_framework/pipeline/schema.py`
|
||||||
|
|
||||||
|
| 类 | 说明 |
|
||||||
|
|----|------|
|
||||||
|
| Pipeline | Pipeline定义 |
|
||||||
|
| PipelineStage | 单个阶段定义 |
|
||||||
|
| StageResult | 阶段执行结果 |
|
||||||
|
| PipelineResult | Pipeline执行结果 |
|
||||||
|
| StageStatus | 阶段状态枚举 |
|
||||||
|
|
||||||
|
## 变量解析
|
||||||
|
|
||||||
|
支持 `${var.path}` 格式的变量引用:
|
||||||
|
|
||||||
|
| 格式 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `${brand_name}` | 全局变量 |
|
||||||
|
| `${stages.step1.outputs.result}` | 上游阶段输出 |
|
||||||
|
| `${stages.step1.outputs.result.substring(0,10)}` | 字符串截取 |
|
||||||
|
|
||||||
|
## 条件执行
|
||||||
|
|
||||||
|
Stage支持 `condition` 字段进行条件执行:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
- name: optional_stage
|
||||||
|
agent: some_agent
|
||||||
|
action: do_something
|
||||||
|
condition: "${enable_optional} == true"
|
||||||
|
```
|
||||||
|
|
||||||
|
支持的比较操作:
|
||||||
|
- `${var}` - 变量存在且非空
|
||||||
|
- `${var} == 'value'` - 等于
|
||||||
|
- `${var} != 'value'` - 不等于
|
||||||
|
|
||||||
|
## 重试机制
|
||||||
|
|
||||||
|
每个Stage支持 `retry_count` 配置:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
- name: unstable_stage
|
||||||
|
agent: some_agent
|
||||||
|
retry_count: 3 # 失败后重试3次
|
||||||
|
```
|
||||||
|
|
||||||
|
## 错误处理
|
||||||
|
|
||||||
|
| 配置 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| continue_on_failure: false | 阶段失败则Pipeline失败(默认) |
|
||||||
|
| continue_on_failure: true | 阶段失败仍继续下游 |
|
||||||
|
|
||||||
|
## 内容生成服务层Pipeline
|
||||||
|
|
||||||
|
除了Agent Pipeline,还有服务层的内容处理Pipeline:
|
||||||
|
|
||||||
|
位置:`backend/app/services/content/content_pipeline.py`
|
||||||
|
|
||||||
|
```
|
||||||
|
用户内容 → RuleValidator (规则校验)
|
||||||
|
├─ 高严重问题 → 中断Pipeline
|
||||||
|
└─ 通过 → SensitiveFilter (敏感词过滤)
|
||||||
|
└─ SEOOptimizer (SEO优化)
|
||||||
|
└─ HTMLGenerator (HTML生成)
|
||||||
|
└─ 输出 (html/markdown/plain)
|
||||||
|
```
|
||||||
|
|
||||||
|
### ContentPipeline类
|
||||||
|
|
||||||
|
```python
|
||||||
|
class ContentPipeline:
|
||||||
|
async def run(self, request: dict) -> PipelineResponse:
|
||||||
|
"""
|
||||||
|
request = {
|
||||||
|
"content": "原始内容",
|
||||||
|
"title": "标题",
|
||||||
|
"platform": "目标平台",
|
||||||
|
"optimize_for": ["validation", "sensitive", "seo"],
|
||||||
|
"output_formats": ["html", "markdown", "plain"]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
### 阶段说明
|
||||||
|
|
||||||
|
| 阶段 | 组件 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| validation | RuleValidator | 校验标题长度、内容长度、AI模式检测 |
|
||||||
|
| sensitive_filter | SensitiveFilter | 敏感词过滤替换 |
|
||||||
|
| seo_optimization | SEOOptimizer | 关键词密度调整、位置检查 |
|
||||||
|
| html_generation | HTMLGenerator | 生成HTML/Markdown/纯文本 |
|
||||||
|
|
||||||
|
## Dry-Run模式
|
||||||
|
|
||||||
|
当PipelineEngine初始化时未传入dispatcher,进入dry-run模式:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Dry-run模式(测试/开发)
|
||||||
|
engine = PipelineEngine(dispatcher=None)
|
||||||
|
result = await engine.execute(pipeline, context)
|
||||||
|
# 返回模拟结果用于测试Pipeline定义
|
||||||
|
```
|
||||||
|
|
||||||
|
生产环境若触发dry-run会记录ERROR级别日志。
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
# 平台规则中心
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
平台规则中心管理各AI平台的内容规则和发布要求,确保生成的内容符合各平台标准。
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
### 规则管理
|
||||||
|
|
||||||
|
- 规则分类(关键词、长度、格式、敏感词)
|
||||||
|
- 规则版本控制
|
||||||
|
- 规则优先级
|
||||||
|
|
||||||
|
### 规则检查
|
||||||
|
|
||||||
|
- 实时检查
|
||||||
|
- 批量检查
|
||||||
|
- 检查结果反馈
|
||||||
|
|
||||||
|
## API接口
|
||||||
|
|
||||||
|
平台规则API位于 `backend/app/api/platform_rules.py`。
|
||||||
|
|
||||||
|
### 主要端点
|
||||||
|
|
||||||
|
| 方法 | 路径 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| GET | /api/v1/platform-rules | 获取规则列表 |
|
||||||
|
| POST | /api/v1/platform-rules | 创建规则 |
|
||||||
|
| PUT | /api/v1/platform-rules/{id} | 更新规则 |
|
||||||
|
| DELETE | /api/v1/platform-rules/{id} | 删除规则 |
|
||||||
|
| POST | /api/v1/platform-rules/check | 检查内容 |
|
||||||
|
|
||||||
|
### 检查维度
|
||||||
|
|
||||||
|
| 维度 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| keyword_coverage | 关键词覆盖度 |
|
||||||
|
| readability | 可读性评分 |
|
||||||
|
| tone_consistency | 语气一致性 |
|
||||||
|
| length_compliance | 长度合规 |
|
||||||
|
| fact_accuracy | 事实准确性 |
|
||||||
|
| geo_optimization | GEO优化度 |
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue