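feat: implement P0-P2 features + GEO workflow analysis and planning

P0 urgent fixes:
- Implement the diagnosis analysis page (SEO + GEO, 5 + 6 dimensions)
- Fix E2E tests: unify the dashboard title as '品牌健康中心' (Brand Health Center)
- Fix inconsistent API paths in the recommendations module
- Fix the HTTP method mismatch in the alerts module (POST→PATCH)

P1 feature work:
- Implement the monitoring and optimization page (alert list + settings)
- Implement the organization management page (members/roles/invitations)
- Implement the 5-dimension SEO diagnosis backend service (63 tests)
- Implement the 6-dimension GEO diagnosis backend service (40 tests)
- Implement the diagnosis API endpoints (TDD, 6 tests)
- Connect the frontend diagnosis page to the real API

P2 feature work:
- Implement the alert settings API endpoints (TDD, 11 tests)
- Implement the plan quota warning service (TDD, 37 tests)
- Implement the email notification service (TDD, 30 tests)

GEO workflow analysis:
- 10-step closed-loop process design (replacing the original 7 steps)
- Technical proposals for the 7 missing nodes
- 4-phase, 12-week development plan
- Complete technical architecture design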
chiguyong 2026-05-25 09:45:18 +08:00
parent cbedb09383
commit 65e2f3c380
179 changed files with 23476 additions and 14877 deletions

16
.codegraph/.gitignore vendored Normal file

@ -0,0 +1,16 @@
# CodeGraph data files
# These are local to each machine and should not be committed
# Database
*.db
*.db-wal
*.db-shm
# Cache
cache/
# Logs
*.log
# Hook markers
.dirty


@ -36,6 +36,12 @@ ZHIPU_API_KEY=
 # Tongyi Qianwen (optional)
 TONGYI_API_KEY=
+# ============================================================
+# Alibaba Cloud Bailian (image generation)
+# ============================================================
+# Wanxiang text-to-image V1 API key
+ALIYUN_DASHSCOPE_API_KEY=
 # ============================================================
 # LLM Provider configuration
 # ============================================================

6
backend/.env.test Normal file

@ -0,0 +1,6 @@
DATABASE_URL=sqlite+aiosqlite:///./test.db
REDIS_URL=redis://localhost:6379
ENVIRONMENT=testing
LOG_LEVEL=info
SECRET_KEY=test-secret-key-for-testing-only
CORS_ORIGINS=http://localhost:3000


@ -0,0 +1,143 @@
"""Add knowledge graph tables

Revision ID: f7a8b9c0de56
Revises: e5f7g9h1cd45
Create Date: 2026-05-24 12:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
import enum

# revision identifiers, used by Alembic.
revision: str = "f7a8b9c0de56"
down_revision: Union[str, None] = "e5f7g9h1cd45"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ------------------------------------------------------------------ #
    # 1. Create the entity type enum
    # ------------------------------------------------------------------ #
    entity_type_enum = enum.Enum(
        'EntityType',
        [
            'ORGANIZATION', 'PRODUCT', 'PERSON', 'LOCATION',
            'TECHNOLOGY', 'BRAND', 'EVENT', 'CONCEPT', 'OTHER'
        ]
    )
    # ------------------------------------------------------------------ #
    # 2. Create the relation type enum
    # ------------------------------------------------------------------ #
    relation_type_enum = enum.Enum(
        'RelationType',
        [
            'COMPETES_WITH', 'PARTNERS_WITH', 'ACQUIRES', 'SUBSIDIARY_OF',
            'PRODUCES', 'USES_TECHNOLOGY', 'PART_OF',
            'LOCATED_IN', 'FOUNDED_IN',
            'CEO_OF', 'FOUNDER_OF',
            'RELATED_TO', 'MENTIONED_IN', 'ALSO_KNOWN_AS'
        ]
    )
    # ------------------------------------------------------------------ #
    # 3. knowledge_entities table
    # ------------------------------------------------------------------ #
    op.create_table(
        "knowledge_entities",
        sa.Column(
            "id",
            postgresql.UUID(as_uuid=True),
            primary_key=True,
            nullable=False,
        ),
        sa.Column(
            "knowledge_base_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("knowledge_bases.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("name", sa.String(500), nullable=False),
        sa.Column("entity_type", sa.Enum(entity_type_enum, name="entitytype"), nullable=False),
        sa.Column("description", sa.Text, nullable=True),
        sa.Column("properties", postgresql.JSONB(astext_type=sa.Text()), nullable=True, server_default="{}"),
        sa.Column(
            "source_chunk_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("knowledge_chunks.id", ondelete="SET NULL"),
            nullable=True,
        ),
        sa.Column("confidence", sa.String(20), nullable=True),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
    )
    op.create_index("ix_entities_name", "knowledge_entities", ["name"])
    op.create_index("ix_entities_kb_name", "knowledge_entities", ["knowledge_base_id", "name"])
    op.create_index("ix_entities_kb_type", "knowledge_entities", ["knowledge_base_id", "entity_type"])
    # ------------------------------------------------------------------ #
    # 4. knowledge_relations table
    # ------------------------------------------------------------------ #
    op.create_table(
        "knowledge_relations",
        sa.Column(
            "id",
            postgresql.UUID(as_uuid=True),
            primary_key=True,
            nullable=False,
        ),
        sa.Column(
            "source_entity_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("knowledge_entities.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column(
            "target_entity_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("knowledge_entities.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("relation_type", sa.Enum(relation_type_enum, name="relationtype"), nullable=False),
        sa.Column("properties", postgresql.JSONB(astext_type=sa.Text()), nullable=True, server_default="{}"),
        sa.Column(
            "source_chunk_id",
            postgresql.UUID(as_uuid=True),
            sa.ForeignKey("knowledge_chunks.id", ondelete="SET NULL"),
            nullable=True,
        ),
        sa.Column("confidence", sa.String(20), nullable=True),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
    )
    op.create_index("ix_relations_source", "knowledge_relations", ["source_entity_id"])
    op.create_index("ix_relations_target", "knowledge_relations", ["target_entity_id"])
    op.create_index("ix_relations_type", "knowledge_relations", ["relation_type"])


def downgrade() -> None:
    # Drop the tables (relations first, since it references entities;
    # the foreign key constraints are dropped with their tables)
    op.drop_table("knowledge_relations")
    op.drop_table("knowledge_entities")
    # Drop the enum types
    op.execute("DROP TYPE IF EXISTS relationtype")
    op.execute("DROP TYPE IF EXISTS entitytype")
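
The migration above builds its enum classes with the functional `enum.Enum` API rather than a class statement. A minimal sketch of what that call produces, using a shortened member list: the member names are the strings passed in and the values are auto-assigned integers. As far as I know, SQLAlchemy's `sa.Enum` persists the member names of a Python enum class by default, so those names are what would end up as the database labels; treat that as an assumption to confirm against the SQLAlchemy version in use.

```python
import enum

# Same functional-API call style as the migration, with a shortened member list.
EntityType = enum.Enum("EntityType", ["ORGANIZATION", "PRODUCT", "PERSON"])

# Names are the strings passed in; values are integers assigned from 1 upward.
names = [member.name for member in EntityType]
values = [member.value for member in EntityType]

print(names)
print(values)
```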


@ -1,4 +1,5 @@
 """Alerts API endpoints - alert notification interface"""
+import logging
 import uuid
 from fastapi import APIRouter, Depends, HTTPException, Query, status
@ -9,6 +10,7 @@ from app.api.deps import get_current_user
 from app.database import get_db
 from app.models.alert import Alert
 from app.models.alert_setting import AlertSetting
+from app.models.brand import Brand
 from app.models.user import User
 from app.schemas.alert import (
     AlertResponse,
@ -24,9 +26,35 @@ from app.schemas.alert import (
 )
 from app.services.alert_engine import AlertEngine
+logger = logging.getLogger(__name__)
 router = APIRouter()
+async def verify_brand_ownership(
+    brand_id: uuid.UUID,
+    current_user: User,
+    db: AsyncSession,
+) -> Brand:
+    """Verify that the brand belongs to the current user."""
+    stmt = select(Brand).where(
+        and_(
+            Brand.id == brand_id,
+            Brand.user_id == current_user.id,
+        )
+    )
+    result = await db.execute(stmt)
+    brand = result.scalar_one_or_none()
+    if not brand:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"品牌 {brand_id} 不存在或不属于当前用户",
+        )
+    return brand
 # ============================================================
 # Alert endpoints
 # ============================================================
@ -207,21 +235,7 @@ async def update_alert_settings(
     for item in data.settings:
         # Verify that the brand belongs to the current user
-        from app.models.brand import Brand
-        brand_stmt = select(Brand).where(
-            and_(
-                Brand.id == item.brand_id,
-                Brand.user_id == current_user.id,
-            )
-        )
-        brand_result = await db.execute(brand_stmt)
-        brand = brand_result.scalar_one_or_none()
-        if not brand:
-            raise HTTPException(
-                status_code=status.HTTP_404_NOT_FOUND,
-                detail=f"品牌 {item.brand_id} 不存在或不属于当前用户",
-            )
+        await verify_brand_ownership(item.brand_id, current_user, db)
         # Look up the existing setting
         existing_stmt = select(AlertSetting).where(
@ -257,6 +271,7 @@ async def update_alert_settings(
     for setting in updated_settings:
         await db.refresh(setting)
+    logger.info(f"批量更新告警设置: user={current_user.id}, count={len(updated_settings)}")
     return {"items": updated_settings, "total": len(updated_settings)}
@ -292,3 +307,74 @@ async def update_single_setting(
     await db.refresh(setting)
     return setting
+@router.post("/settings", response_model=AlertSettingResponse, status_code=status.HTTP_201_CREATED)
+async def create_alert_setting(
+    data: AlertSettingCreate,
+    current_user: User = Depends(get_current_user),
+    db: AsyncSession = Depends(get_db),
+):
+    """Create an alert setting."""
+    # Verify that the brand belongs to the current user
+    await verify_brand_ownership(data.brand_id, current_user, db)
+    # Check whether a setting with the same brand + alert type already exists
+    existing_stmt = select(AlertSetting).where(
+        and_(
+            AlertSetting.brand_id == data.brand_id,
+            AlertSetting.alert_type == data.alert_type,
+        )
+    )
+    existing_result = await db.execute(existing_stmt)
+    existing = existing_result.scalar_one_or_none()
+    if existing:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"品牌 {data.brand_id} 的告警类型 {data.alert_type} 已存在",
+        )
+    # Create the new setting
+    setting = AlertSetting(
+        brand_id=data.brand_id,
+        user_id=current_user.id,
+        alert_type=data.alert_type,
+        enabled=data.enabled,
+        threshold=data.threshold,
+    )
+    db.add(setting)
+    await db.commit()
+    await db.refresh(setting)
+    logger.info(f"创建告警设置: user={current_user.id}, brand={data.brand_id}, type={data.alert_type}")
+    return setting
+@router.delete("/settings/{setting_id}", status_code=status.HTTP_204_NO_CONTENT)
+async def delete_alert_setting(
+    setting_id: uuid.UUID,
+    current_user: User = Depends(get_current_user),
+    db: AsyncSession = Depends(get_db),
+):
+    """Delete an alert setting."""
+    stmt = select(AlertSetting).where(
+        and_(
+            AlertSetting.id == setting_id,
+            AlertSetting.user_id == current_user.id,
+        )
+    )
+    result = await db.execute(stmt)
+    setting = result.scalar_one_or_none()
+    if not setting:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="告警设置不存在",
+        )
+    await db.delete(setting)
+    await db.commit()
+    logger.info(f"删除告警设置: user={current_user.id}, setting={setting_id}")
+    return None
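
The two rules the new alert-settings endpoints enforce (a brand must belong to the caller, and a brand may hold only one setting per alert type) can be illustrated without a database. The sketch below is an in-memory stand-in, not the project's code: `NotFound`, `Conflict`, the `Brand` dataclass, the dict-based store, and the `"score_drop"` alert type are all hypothetical names chosen for illustration.

```python
import uuid
from dataclasses import dataclass


class NotFound(Exception):
    """Stands in for HTTPException(404) in this sketch."""


class Conflict(Exception):
    """Stands in for the HTTPException(400) raised on duplicates."""


@dataclass(frozen=True)
class Brand:
    id: uuid.UUID
    user_id: uuid.UUID


def verify_brand_ownership(brands, brand_id, user_id):
    # One lookup covers both "missing" and "owned by someone else",
    # so a caller cannot tell the two cases apart.
    for brand in brands:
        if brand.id == brand_id and brand.user_id == user_id:
            return brand
    raise NotFound(f"brand {brand_id} not found for this user")


def create_setting(settings, brands, brand_id, user_id, alert_type):
    verify_brand_ownership(brands, brand_id, user_id)
    key = (brand_id, alert_type)  # uniqueness is per brand + alert type
    if key in settings:
        raise Conflict(f"{alert_type} already configured for brand {brand_id}")
    settings[key] = {"user_id": user_id, "enabled": True}
    return settings[key]


alice, bob = uuid.uuid4(), uuid.uuid4()
brand = Brand(id=uuid.uuid4(), user_id=alice)
settings = {}

create_setting(settings, [brand], brand.id, alice, "score_drop")
```

Calling `create_setting` a second time with the same brand and alert type raises `Conflict`, and calling it as `bob` raises `NotFound`, mirroring the 400 and 404 responses in the endpoint.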


@ -237,4 +237,165 @@ async def generate_topics(
         return {"status": "success", "topics": topics}
     except LLMError as e:
         raise HTTPException(status_code=502, detail=str(e))
+# ==================== Master topic library endpoints ====================
+class TopicGenerateRequest(BaseModel):
+    """Master topic generation request"""
+    params: dict  # master topic template parameters
+    platform: str = "通用"
+    style: str = "专业严谨"
+@router.get("/topics")
+async def list_topics():
+    """List all master topic templates."""
+    from app.services.content.topic_templates import list_topic_templates
+    templates = list_topic_templates()
+    return [
+        {
+            "id": t.id,
+            "name": t.name,
+            "description": t.description,
+            "icon": t.icon,
+            "recommended_platforms": t.recommended_platforms,
+            "word_count_range": list(t.word_count_range),
+            "required_params": t.required_params,
+            "optional_params": t.optional_params,
+        }
+        for t in templates
+    ]
+@router.get("/topics/{topic_id}")
+async def get_topic(topic_id: str):
+    """Get the details of a master topic."""
+    from app.services.content.topic_templates import get_topic_template
+    template = get_topic_template(topic_id)
+    if not template:
+        raise HTTPException(status_code=404, detail="Topic not found")
+    return {
+        "id": template.id,
+        "name": template.name,
+        "description": template.description,
+        "icon": template.icon,
+        "prompt_template": template.prompt_template,
+        "seo_tips": template.seo_tips,
+        "recommended_platforms": template.recommended_platforms,
+        "word_count_range": list(template.word_count_range),
+        "required_params": template.required_params,
+        "optional_params": template.optional_params,
+    }
+@router.post("/topics/{topic_id}/generate")
+async def generate_with_topic(
+    topic_id: str,
+    request: TopicGenerateRequest,
+    db: AsyncSession = Depends(get_db),
+    current_user: User = Depends(get_current_user),
+):
+    """Generate content from a master topic."""
+    from app.services.content.topic_templates import get_topic_template, render_topic_prompt
+    from app.services.llm import LLMError, LLMFactory
+    from app.agent_framework.prompts import DEAI_TEMPLATE, GEO_OPTIMIZER_TEMPLATE
+    template = get_topic_template(topic_id)
+    if not template:
+        raise HTTPException(status_code=404, detail="Topic not found")
+    # Validate required parameters
+    for param in template.required_params:
+        if param not in request.params:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Missing required parameter: {param}"
+            )
+    org_id = getattr(current_user, "organization_id", None)
+    if not org_id:
+        raise HTTPException(status_code=403, detail="用户未关联组织")
+    try:
+        provider = LLMFactory.get_default()
+        # Render the prompt
+        prompt = render_topic_prompt(topic_id, request.params)
+        # Stage 1: generate the draft
+        response = await provider.chat(
+            [{"role": "user", "content": prompt}],
+            temperature=0.7,
+            max_tokens=4000
+        )
+        content = response.content
+        # Stage 2: de-AI pass (rewrite so the text reads less machine-generated)
+        deai_variables = {
+            "original_content": content,
+            "target_style": "自然流畅",
+            "preserve_structure": "",
+        }
+        messages = DEAI_TEMPLATE.render(deai_variables)
+        response = await provider.chat(messages, temperature=0.9, max_tokens=len(content) * 2)
+        content = response.content
+        # Stage 3: GEO optimization
+        geo_variables = {
+            "original_content": content,
+            "target_keywords": request.params.get("keywords", ""),
+            "target_platform": request.platform,
+            "optimization_level": "moderate",
+        }
+        messages = GEO_OPTIMIZER_TEMPLATE.render(geo_variables)
+        response = await provider.chat(messages, temperature=0.5, max_tokens=len(content) * 2)
+        optimized = response.content
+        # Persist to the database
+        content_obj = Content(
+            organization_id=org_id,
+            title=request.params.get("product_name") or request.params.get("topic") or topic_id,
+            content_type="article",
+            body=optimized,
+            status="draft",
+            target_platforms=[request.platform],
+            keywords=[request.params.get("keywords", "")],
+            extra_metadata={
+                "original_content": content,
+                "topic_id": topic_id,
+                "topic_name": template.name,
+                "brand_name": request.params.get("brand_name", ""),
+                "content_style": request.style,
+            },
+            created_by=current_user.id,
+            current_version=1,
+        )
+        db.add(content_obj)
+        await db.flush()
+        version = ContentVersion(
+            content_id=content_obj.id,
+            version_number=1,
+            title=content_obj.title,
+            body=optimized,
+            change_summary="母题库自动生成",
+            created_by=current_user.id,
+        )
+        db.add(version)
+        await db.commit()
+        await db.refresh(content_obj)
+        return {
+            "topic_id": topic_id,
+            "content": content,
+            "optimized_content": optimized,
+            "content_id": str(content_obj.id),
+            "seo_tips": template.seo_tips,
+        }
+    except LLMError as e:
+        raise HTTPException(status_code=502, detail=f"LLM调用失败: {str(e)}")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"内容生成异常: {str(e)}")
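
The `generate_with_topic` endpoint chains three LLM calls (draft, de-AI rewrite, GEO optimization), and sizes the second and third calls with `max_tokens=len(content) * 2`. The sketch below reproduces only that control flow against a fake provider so it can run offline. `FakeProvider`, `FakeResponse`, `missing_params`, and `run_pipeline` are hypothetical names; the real provider, prompt templates, and persistence step are not modeled, and unlike the endpoint this sketch returns only the final text.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class FakeResponse:
    content: str


class FakeProvider:
    """Hypothetical stand-in for the LLM provider; records each call."""

    def __init__(self):
        self.calls = []

    async def chat(self, messages, temperature, max_tokens):
        self.calls.append((temperature, max_tokens))
        return FakeResponse(content=f"pass{len(self.calls)}({messages[-1]['content']})")


def missing_params(required, params):
    # Like the endpoint's required-parameter check, but collects every gap.
    return [name for name in required if name not in params]


async def run_pipeline(provider, prompt):
    # Stage 1: draft, with the same temperature and token cap as the endpoint.
    first = await provider.chat([{"role": "user", "content": prompt}], 0.7, 4000)
    draft = first.content
    # Stage 2 (de-AI, 0.9) and stage 3 (GEO optimization, 0.5).
    for temperature in (0.9, 0.5):
        budget = len(draft) * 2  # same token-budget heuristic as the endpoint
        reply = await provider.chat([{"role": "user", "content": draft}], temperature, budget)
        draft = reply.content
    return draft


provider = FakeProvider()
result = asyncio.run(run_pipeline(provider, "hello"))
print(result)
```

One design point the sketch makes visible: the token budget of each later stage is derived from the character length of the previous stage's output, which is only a rough proxy for token count.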


@ -0,0 +1,150 @@
"""Diagnosis API endpoints - SEO and GEO diagnosis"""
import logging
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.services.seo_diagnosis import SEODiagnosisService
from app.services.geo_diagnosis import GEODiagnosisService, GEODiagnosisInput

logger = logging.getLogger(__name__)
router = APIRouter()


@router.get("/seo/{brand_id}")
async def get_seo_diagnosis(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Get the SEO diagnosis result for a brand.

    Returns a 5-dimension SEO diagnosis:
    - Technical SEO (25)
    - On-page SEO (20)
    - Content quality (20)
    - Backlink analysis (15)
    - User experience (20)
    """
    brand = await _get_brand_or_404(brand_id, current_user, db)
    try:
        service = SEODiagnosisService()
        result = service.diagnose()
        logger.info(f"SEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={result.overall_score}")
        return result.to_dict()
    except Exception as e:
        logger.error(f"SEO诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="SEO诊断服务异常,请稍后重试",
        )


@router.get("/geo/{brand_id}")
async def get_geo_diagnosis(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Get the GEO diagnosis result for a brand.

    Returns a 6-dimension GEO diagnosis:
    - Content extractability (20)
    - Entity clarity (15)
    - E-E-A-T signals (20)
    - Schema markup (15)
    - Topical authority (15)
    - Citation readiness (15)
    """
    brand = await _get_brand_or_404(brand_id, current_user, db)
    try:
        input_data = GEODiagnosisInput()
        service = GEODiagnosisService()
        result = service.diagnose(input_data)
        logger.info(f"GEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={result.overall_score}")
        return result.to_dict()
    except Exception as e:
        logger.error(f"GEO诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="GEO诊断服务异常,请稍后重试",
        )


@router.get("/combined/{brand_id}")
async def get_combined_diagnosis(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Get the combined diagnosis result for a brand.

    Runs both the SEO and GEO diagnoses and returns the combined score
    together with the detailed results of each.
    """
    brand = await _get_brand_or_404(brand_id, current_user, db)
    try:
        seo_service = SEODiagnosisService()
        seo_result = seo_service.diagnose()
        geo_service = GEODiagnosisService()
        geo_result = geo_service.diagnose(GEODiagnosisInput())
        # Combined score is the unweighted mean of the two overall scores
        combined_score = round((seo_result.overall_score + geo_result.overall_score) / 2, 2)
        logger.info(
            f"综合诊断完成: brand_id={brand_id}, brand={brand.name}, "
            f"seo_score={seo_result.overall_score}, "
            f"geo_score={geo_result.overall_score}, "
            f"combined_score={combined_score}"
        )
        return {
            "seo_score": seo_result.overall_score,
            "geo_score": geo_result.overall_score,
            "combined_score": combined_score,
            "seo_diagnosis": seo_result.to_dict(),
            "geo_diagnosis": geo_result.to_dict(),
        }
    except Exception as e:
        logger.error(f"综合诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="综合诊断服务异常,请稍后重试",
        )


async def _get_brand_or_404(
    brand_id: uuid.UUID,
    current_user: User,
    db: AsyncSession,
) -> Brand:
    """Fetch the brand owned by the current user, or raise a 404."""
    stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
    result = await db.execute(stmt)
    brand = result.scalar_one_or_none()
    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )
    return brand
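
The endpoint docstrings above list a maximum score per dimension, and the combined endpoint averages the two overall scores. A small sketch of that arithmetic follows; the dictionary keys are hypothetical English labels for the dimensions, while the point values and the rounding rule are taken from the code above. It assumes, without confirmation from the services themselves, that each overall score is out of 100.

```python
# Maximum points per dimension, copied from the endpoint docstrings.
SEO_WEIGHTS = {
    "technical": 25,
    "on_page": 20,
    "content_quality": 20,
    "backlinks": 15,
    "user_experience": 20,
}
GEO_WEIGHTS = {
    "extractability": 20,
    "entity_clarity": 15,
    "eeat": 20,
    "schema_markup": 15,
    "topical_authority": 15,
    "citation_readiness": 15,
}


def combined_score(seo_score, geo_score):
    # Unweighted mean, rounded to two decimals, as in get_combined_diagnosis.
    return round((seo_score + geo_score) / 2, 2)


seo_total = sum(SEO_WEIGHTS.values())
geo_total = sum(GEO_WEIGHTS.values())
example = combined_score(72.5, 64.0)

print(seo_total, geo_total, example)
```

Both dimension sets sum to 100, so the two overall scores are on the same scale and a plain mean is a reasonable first choice; weighting SEO and GEO differently would be a product decision the commit does not make.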

182
backend/app/api/image.py Normal file

@ -0,0 +1,182 @@
"""Image generation API routes"""
from typing import Optional
from fastapi import APIRouter, HTTPException, status
from pydantic import BaseModel, Field
from app.services.image_generator import (
    IMAGE_STYLES,
    LAYOUT_OPTIONS,
    ImageGenerator,
    ImageGenerationError,
    PLATFORM_IMAGE_SPECS,
)

router = APIRouter(prefix="/image", tags=["图片生成"])


class GenerateCoverRequest(BaseModel):
    """Cover image generation request"""
    title: str = Field(..., description="文章标题")
    platform: str = Field(..., description="目标平台")
    image_type: str = Field(default="cover", description="图片类型: cover/inline")
    style: str = Field(default="modern", description="风格选项")
    layout: str = Field(default="centered", description="排版选项")
    custom_prompt: Optional[str] = Field(default=None, description="自定义提示词")


class ImageResultResponse(BaseModel):
    """Image generation result response"""
    url: str
    width: int
    height: int
    prompt: str
    platform: str
    task_id: str


class ImageSpecsResponse(BaseModel):
    """Platform image specification response"""
    platform: str
    specs: dict


class StyleOption(BaseModel):
    """Style option"""
    value: str
    name: str


class LayoutOption(BaseModel):
    """Layout option"""
    value: str
    name: str


class ImageConfigResponse(BaseModel):
    """Image generation configuration response"""
    platforms: list[str]
    styles: list[StyleOption]
    layouts: list[LayoutOption]


@router.post("/generate-cover", response_model=ImageResultResponse)
async def generate_cover(request: GenerateCoverRequest):
    """Generate a cover image.

    Uses Alibaba Cloud Bailian Wanxiang text-to-image V1 to generate the
    image and automatically adapts it to the target platform's size
    specifications.

    Args:
        request: Generation request parameters

    Returns:
        ImageResultResponse: The image URL and metadata

    Raises:
        HTTPException: If the platform is unsupported or the API call fails
    """
    # Validate that the platform is supported
    supported_platforms = ImageGenerator.get_supported_platforms()
    if request.platform not in supported_platforms:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"不支持的平台: {request.platform},支持的平台: {', '.join(supported_platforms)}",
        )
    # Validate the style option
    if request.style not in IMAGE_STYLES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"不支持的风格: {request.style},支持的风格: {', '.join(IMAGE_STYLES.keys())}",
        )
    # Validate the layout option
    if request.layout not in LAYOUT_OPTIONS:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"不支持的排版: {request.layout},支持的排版: {', '.join(LAYOUT_OPTIONS.keys())}",
        )
    # Validate the image type
    if request.image_type not in ("cover", "inline"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="图片类型只支持: cover, inline",
        )
    generator = ImageGenerator()
    try:
        result = await generator.generate_cover(
            title=request.title,
            platform=request.platform,
            image_type=request.image_type,
            style=request.style,
            layout=request.layout,
            custom_prompt=request.custom_prompt,
        )
        return ImageResultResponse(
            url=result.url,
            width=result.width,
            height=result.height,
            prompt=result.prompt,
            platform=result.platform,
            task_id=result.task_id,
        )
    except ImageGenerationError as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"图片生成失败: {str(e)}",
        )


@router.get("/platforms", response_model=list[str])
async def get_supported_platforms():
    """List the supported platforms."""
    return ImageGenerator.get_supported_platforms()


@router.get("/platforms/{platform}/specs", response_model=ImageSpecsResponse)
async def get_platform_specs(platform: str):
    """Get the image specifications for a platform.

    Args:
        platform: Platform identifier

    Returns:
        ImageSpecsResponse: The platform's image specifications

    Raises:
        HTTPException: If the platform is unsupported
    """
    specs = ImageGenerator.get_platform_specs(platform)
    if specs is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"不支持的平台: {platform}",
        )
    return ImageSpecsResponse(platform=platform, specs=specs)


@router.get("/config", response_model=ImageConfigResponse)
async def get_image_config():
    """Get the image generation configuration (style and layout options, etc.)."""
    styles = [
        StyleOption(value=key, name=value["name"])
        for key, value in IMAGE_STYLES.items()
    ]
    layouts = [
        LayoutOption(value=key, name=value["name"])
        for key, value in LAYOUT_OPTIONS.items()
    ]
    return ImageConfigResponse(
        platforms=ImageGenerator.get_supported_platforms(),
        styles=styles,
        layouts=layouts,
    )
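
The `generate_cover` endpoint validates its inputs in a fixed order (platform, style, layout, image type) and stops at the first failure. A table-driven sketch of the same ordering is shown below. The option tables here are hypothetical placeholders: the real `IMAGE_STYLES`, `LAYOUT_OPTIONS`, and platform list live in `app.services.image_generator`, whose contents are not part of this diff.

```python
# Hypothetical option tables, for illustration only.
SUPPORTED_PLATFORMS = ["wechat", "zhihu"]
IMAGE_STYLES = {"modern": {"name": "Modern"}}
LAYOUT_OPTIONS = {"centered": {"name": "Centered"}}
IMAGE_TYPES = ("cover", "inline")


def first_validation_error(platform, style="modern", layout="centered", image_type="cover"):
    # Checks run in the same order as the endpoint: platform, style, layout, type.
    checks = [
        ("platform", platform, SUPPORTED_PLATFORMS),
        ("style", style, IMAGE_STYLES),
        ("layout", layout, LAYOUT_OPTIONS),
        ("image_type", image_type, IMAGE_TYPES),
    ]
    for field, value, allowed in checks:
        if value not in allowed:
            # Joining a dict joins its keys, which is what the message needs.
            return f"unsupported {field}: {value}; allowed: {', '.join(allowed)}"
    return None


print(first_validation_error("tiktok", style="retro"))
```

Because the checks short-circuit, a request with several invalid fields only reports the first one; returning all errors at once would be an alternative design.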


@ -29,10 +29,15 @@ from app.schemas.knowledge import (
     KnowledgeBaseCreate,
     KnowledgeBaseResponse,
     KnowledgeSearchRequest,
+    RetrieveRequest,
     SearchResponse,
     SearchResultItem,
+    UpdateDocumentRequest,
 )
 from app.services.knowledge import MockEmbedder, RAGService
+from app.services.knowledge.enhanced_rag import EnhancedRAG
+from app.services.knowledge.incremental_index import IncrementalIndexService
+from app.services.knowledge.chunker import ChunkerFactory
 logger = logging.getLogger(__name__)
 router = APIRouter()
@ -499,3 +504,151 @@ async def knowledge_search(
         total=len(items),
         latency_ms=latency_ms,
     )
+@router.post("/bases/{kb_id}/chunks/preview")
+async def preview_chunks(
+    kb_id: uuid.UUID,
+    text: str,
+    strategy: str = "recursive",
+    chunk_size: int = 500,
+):
+    """Preview how a text would be chunked."""
+    chunker = ChunkerFactory.create(strategy)
+    # Override chunk_size for this preview
+    if strategy == "recursive":
+        chunker.STRATEGY.chunk_size = chunk_size
+    elif strategy == "semantic":
+        # Semantic chunks may be larger; keep the size an integer
+        chunker.STRATEGY.chunk_size = int(chunk_size * 1.5)
+    elif strategy == "fixed":
+        chunker.STRATEGY.chunk_size = chunk_size
+    preview = chunker.preview(text, max_chunks=10)
+    return {
+        "strategy": strategy,
+        "chunk_count": len(preview),
+        "preview": preview,
+        "strategies": [
+            {
+                "name": s.name,
+                "description": s.description,
+                "recommended_size": s.chunk_size,
+            }
+            for s in ChunkerFactory.list_strategies()
+        ],
+    }
+# ---------------------------------------------------------------------------
+# Incremental indexing API
+# ---------------------------------------------------------------------------
+@router.post("/bases/{kb_id}/documents/{doc_id}/reindex")
+async def reindex_document(
+    kb_id: uuid.UUID,
+    doc_id: uuid.UUID,
+    db: AsyncSession = Depends(get_db),
+    current_user: User = Depends(get_current_user),
+):
+    """Re-index a single document."""
+    org_id = current_user.organization_id
+    if not org_id:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
+    await _get_kb(db, kb_id, org_id)
+    index_service = IncrementalIndexService(_rag_service)
+    result = await index_service.add_document(
+        db, str(kb_id), str(doc_id)
+    )
+    return result
+@router.post("/bases/{kb_id}/documents/{doc_id}/update")
+async def update_document_content(
+    kb_id: uuid.UUID,
+    doc_id: uuid.UUID,
+    request: UpdateDocumentRequest,
+    db: AsyncSession = Depends(get_db),
+    current_user: User = Depends(get_current_user),
+):
+    """Update a document's content (incremental)."""
+    org_id = current_user.organization_id
+    if not org_id:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
+    await _get_kb(db, kb_id, org_id)
+    index_service = IncrementalIndexService(_rag_service)
+    result = await index_service.update_document(
+        db, str(doc_id), request.content
+    )
+    return result
+@router.delete("/bases/{kb_id}/documents/{doc_id}")
+async def delete_document_incremental(
+    kb_id: uuid.UUID,
+    doc_id: uuid.UUID,
+    db: AsyncSession = Depends(get_db),
+    current_user: User = Depends(get_current_user),
+):
+    """Delete a document."""
+    org_id = current_user.organization_id
+    if not org_id:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
+    await _get_kb(db, kb_id, org_id)
+    index_service = IncrementalIndexService(_rag_service)
+    result = await index_service.delete_document(db, str(doc_id))
+    return result
+@router.post("/bases/{kb_id}/rebuild")
+async def rebuild_knowledge_base(
+    kb_id: uuid.UUID,
+    force: bool = False,
+    db: AsyncSession = Depends(get_db),
+    current_user: User = Depends(get_current_user),
+):
+    """Rebuild the knowledge base index."""
+    org_id = current_user.organization_id
+    if not org_id:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
+    await _get_kb(db, kb_id, org_id)
+    index_service = IncrementalIndexService(_rag_service)
+    result = await index_service.rebuild_knowledge_base(
+        db, str(kb_id), force
+    )
+    return result
+@router.post("/bases/{kb_id}/retrieve")
+async def enhanced_retrieve(
+    kb_id: uuid.UUID,
+    request: RetrieveRequest,
+    db: AsyncSession = Depends(get_db),
+    current_user: User = Depends(get_current_user),
+):
+    """Enhanced retrieval (supports re-ranking and compression)."""
+    org_id = current_user.organization_id
+    if not org_id:
+        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
+    await _get_kb(db, kb_id, org_id)
+    enhanced_rag = EnhancedRAG(_rag_service, _rag_service.embedder)
+    results = await enhanced_rag.retrieve_with_rerank(
+        db,
+        request.query,
+        [str(kb_id)],
+        top_k=request.top_k or 5,
+        use_rerank=request.use_rerank,
+        use_compression=request.use_compression,
+    )
+    return {"results": results, "query": request.query}
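
The `preview_chunks` endpoint in this file overrides `chunker.STRATEGY.chunk_size` in place. If `STRATEGY` is shared between requests (the diff does not show whether it is), that override would persist after the preview. The sketch below shows one possible alternative under that assumption: derive a per-request copy and leave the defaults untouched. `ChunkStrategy`, `DEFAULTS`, and `strategy_for_request` are hypothetical names, and the default sizes are made up for illustration.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ChunkStrategy:
    """Hypothetical strategy record; the real ChunkerFactory types may differ."""
    name: str
    chunk_size: int


# Illustrative defaults only, not the project's actual values.
DEFAULTS = {
    "recursive": ChunkStrategy("recursive", 500),
    "semantic": ChunkStrategy("semantic", 750),
    "fixed": ChunkStrategy("fixed", 500),
}


def strategy_for_request(strategy, chunk_size):
    # Semantic chunks get 1.5x the requested size, truncated to an int.
    effective = int(chunk_size * 1.5) if strategy == "semantic" else chunk_size
    # replace() returns a copy, so the shared default is left untouched.
    return replace(DEFAULTS[strategy], chunk_size=effective)


custom = strategy_for_request("semantic", 300)
print(custom, DEFAULTS["semantic"])
```

Freezing the dataclass makes accidental in-place mutation raise an error, which is the property that keeps concurrent previews from affecting each other.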


@ -0,0 +1,115 @@
"""Knowledge graph API"""
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db, get_current_user
from app.models.user import User
from app.services.knowledge.graph_builder import GraphBuilder
from app.services.knowledge.graph_query import GraphQuery

router = APIRouter(prefix="/knowledge-bases", tags=["知识图谱"])


@router.post("/{kb_id}/graph/build")
async def build_graph(
    kb_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Build a knowledge graph from a knowledge base.

    Runs entity and relation extraction over every chunk in the knowledge base.
    """
    # TODO: implement batch building
    # For now, only building from a single chunk is implemented
    return {"message": "Use /graph/build-chunk to build from specific chunk"}


@router.post("/{kb_id}/graph/build-chunk/{chunk_id}")
async def build_graph_from_chunk(
    kb_id: UUID,
    chunk_id: UUID,
    context: Optional[str] = None,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Build the graph from a single chunk."""
    builder = GraphBuilder()
    try:
        stats = await builder.build_from_chunk(db, str(chunk_id), context)
        return {
            "status": "success",
            "stats": stats,
        }
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))


@router.get("/{kb_id}/graph/statistics")
async def get_graph_statistics(
    kb_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Get graph statistics."""
    query = GraphQuery()
    stats = await query.get_statistics(db, str(kb_id))
    return stats


@router.get("/{kb_id}/graph/entities/search")
async def search_entities(
    kb_id: UUID,
    q: str,
    entity_type: Optional[str] = None,
    limit: int = 20,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Search entities."""
    query = GraphQuery()
    entities = await query.search_entities(
        db, str(kb_id), q, entity_type, limit
    )
    return {"entities": entities}


@router.get("/{kb_id}/graph/entities/{entity_id}")
async def get_entity(
    kb_id: UUID,
    entity_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Get entity details."""
    query = GraphQuery()
    entity = await query.get_entity(db, str(entity_id))
    if not entity:
        raise HTTPException(status_code=404, detail="Entity not found")
    # Fetch the neighbors
    neighbors = await query.get_entity_neighbors(db, str(entity_id))
    entity["neighbors"] = neighbors
    return entity


@router.get("/{kb_id}/graph/path")
async def find_path(
    kb_id: UUID,
    source: str,
    target: str,
    max_hops: int = 3,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Find a path between two entities."""
    query = GraphQuery()
    path = await query.get_entity_path(db, source, target, max_hops)
    return {"path": path, "hops": len(path)}
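
The `/graph/path` endpoint delegates to `GraphQuery.get_entity_path`, whose implementation is not part of this diff. As a rough illustration of a hop-limited path search, here is a breadth-first sketch over an in-memory edge list. It is an assumption about the general technique, not the project's algorithm; the function, the entity names, and the choice to treat relations as undirected are all hypothetical.

```python
from collections import deque


def find_path(edges, source, target, max_hops=3):
    """Return the shortest node path from source to target, or [] if none
    exists within max_hops. Edges are treated as undirected."""
    neighbors = {}
    for a, b in edges:
        neighbors.setdefault(a, set()).add(b)
        neighbors.setdefault(b, set()).add(a)

    queue = deque([[source]])
    seen = {source}
    while queue:
        path = queue.popleft()
        if path[-1] == target:
            return path
        if len(path) - 1 >= max_hops:  # hops = nodes - 1
            continue
        for nxt in sorted(neighbors.get(path[-1], ())):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(path + [nxt])
    return []


edges = [("Acme", "Widget"), ("Widget", "Steel"), ("Steel", "Mine")]
path = find_path(edges, "Acme", "Mine")
print(path, len(path) - 1)
```

Note that when a path is represented as a list of nodes, the hop count is `len(path) - 1`. The endpoint above reports `len(path)`, which is only the hop count if `get_entity_path` returns a list of edges; which of the two it returns cannot be determined from this diff.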


@ -1,9 +1,23 @@
 from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
 from sqlalchemy.orm import declarative_base
-from sqlalchemy import text
+from sqlalchemy import text, JSON
+from sqlalchemy.types import TypeDecorator
+from sqlalchemy.dialects.postgresql import JSONB
 from app.config import settings
+class JSONType(TypeDecorator):
+    """A JSON type that uses JSONB on PostgreSQL and JSON on other databases."""
+    impl = JSON
+    cache_ok = True
+    def load_dialect_impl(self, dialect):
+        if dialect.name == "postgresql":
+            return dialect.type_descriptor(JSONB())
+        return dialect.type_descriptor(JSON())
 engine = create_async_engine(
     settings.DATABASE_URL,
     pool_size=10,  # connection pool size


@ -4,7 +4,8 @@ from datetime import datetime, timezone
from fastapi import FastAPI, HTTPException, Request, Depends from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse, Response
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text from sqlalchemy import text
@ -31,9 +32,12 @@ from app.api.subscriptions import router as subscription_router
from app.api.alerts import router as alerts_router from app.api.alerts import router as alerts_router
from app.api.dashboard import router as dashboard_router from app.api.dashboard import router as dashboard_router
from app.api.brands import router as brands_router from app.api.brands import router as brands_router
from app.api.diagnosis import router as diagnosis_router
from app.api.onboarding import router as onboarding_router from app.api.onboarding import router as onboarding_router
from app.api.platforms import router as platforms_router from app.api.platforms import router as platforms_router
from app.api.platform_rules import router as platform_rules_router from app.api.platform_rules import router as platform_rules_router
from app.api.image import router as image_router
from app.api.knowledge_graph import router as knowledge_graph_router
from app.config import settings from app.config import settings
from app.database import engine, Base from app.database import engine, Base
from app.schemas.common import ErrorResponse, ErrorCode from app.schemas.common import ErrorResponse, ErrorCode
@ -41,6 +45,7 @@ from app.middleware.rate_limit import RateLimitMiddleware
from app.middleware.logging_middleware import RequestLoggingMiddleware from app.middleware.logging_middleware import RequestLoggingMiddleware
from app.middleware.request_id import RequestIdMiddleware from app.middleware.request_id import RequestIdMiddleware
from app.middleware.metrics import MetricsMiddleware from app.middleware.metrics import MetricsMiddleware
from app.monitoring.middleware import MonitoringMiddleware
from app.database import get_db from app.database import get_db
from app.workers.scheduler import query_scheduler from app.workers.scheduler import query_scheduler
@ -49,6 +54,9 @@ from app.workers.scheduler import query_scheduler
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
import app.models import app.models
# Initialize the monitoring module
import app.monitoring
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all) await conn.run_sync(Base.metadata.create_all)
@ -131,6 +139,7 @@ async def add_security_headers(request, call_next):
app.add_middleware(RequestLoggingMiddleware) app.add_middleware(RequestLoggingMiddleware)
app.add_middleware(RateLimitMiddleware) app.add_middleware(RateLimitMiddleware)
app.add_middleware(MetricsMiddleware) app.add_middleware(MetricsMiddleware)
app.add_middleware(MonitoringMiddleware)
app.add_middleware(RequestIdMiddleware) app.add_middleware(RequestIdMiddleware)
app.include_router(auth_router, prefix="/api/v1/auth", tags=["认证"]) app.include_router(auth_router, prefix="/api/v1/auth", tags=["认证"])
@ -150,9 +159,12 @@ app.include_router(analytics_router, prefix="/api/v1/analytics", tags=["监测
app.include_router(alerts_router, prefix="/api/v1/alerts", tags=["告警通知"]) app.include_router(alerts_router, prefix="/api/v1/alerts", tags=["告警通知"])
app.include_router(dashboard_router, prefix="/api/v1/dashboard", tags=["仪表盘"]) app.include_router(dashboard_router, prefix="/api/v1/dashboard", tags=["仪表盘"])
app.include_router(brands_router, prefix="/api/v1/brands", tags=["品牌管理"]) app.include_router(brands_router, prefix="/api/v1/brands", tags=["品牌管理"])
app.include_router(diagnosis_router, prefix="/api/v1/diagnosis", tags=["诊断服务"])
app.include_router(onboarding_router, prefix="/api/v1") app.include_router(onboarding_router, prefix="/api/v1")
app.include_router(platforms_router, prefix="/api/v1") app.include_router(platforms_router, prefix="/api/v1")
app.include_router(platform_rules_router) app.include_router(platform_rules_router)
app.include_router(image_router, prefix="/api/v1")
app.include_router(knowledge_graph_router, prefix="/api/v1/knowledge-bases")
@app.get("/health", tags=["可观测性"]) @app.get("/health", tags=["可观测性"])
@ -203,3 +215,90 @@ async def readiness_check(db: AsyncSession = Depends(get_db)):
"timestamp": datetime.now(timezone.utc).isoformat(), "timestamp": datetime.now(timezone.utc).isoformat(),
}, },
) )
@app.get("/metrics", tags=["可观测性"])
async def metrics():
    """Prometheus metrics endpoint."""
    return Response(
        content=generate_latest(),
        media_type=CONTENT_TYPE_LATEST,
    )
# ---- Detailed health check endpoints ----
from app.services.health_checker import HealthChecker
from app.services.app_state import app_state


@app.get("/health/detailed", tags=["可观测性"])
async def detailed_health(
    db: AsyncSession = Depends(get_db),
):
    """
    Detailed health check.

    Returns the health status of every dependency:
    - database: database connection
    - redis: Redis cache
    - llm_providers: LLM service providers
    - storage: file storage

    Status values:
    - healthy: all components are working
    - degraded: some components are failing, but the service can still respond
    - unhealthy: a core component is failing
    """
    checker = HealthChecker(db, settings.REDIS_URL)
    health_result = await checker.check_all()
    # Attach application info
    health_result["app"] = app_state.get_info()
    return health_result
@app.get("/health/liveness", tags=["可观测性"])
async def liveness():
    """
    Liveness probe.

    Intended for the Kubernetes livenessProbe.
    Returns 200 as long as the application process is alive.
    """
    return {"status": "alive"}
@app.get("/health/readiness", tags=["可观测性"])
async def readiness(db: AsyncSession = Depends(get_db)):
    """
    Readiness probe.

    Intended for the Kubernetes readinessProbe.
    Checks whether the core dependencies are ready.
    """
    checker = HealthChecker(db, settings.REDIS_URL)
    # Only check the core dependencies: the database and Redis
    db_result = await checker.check_database()
    redis_result = await checker.check_redis()
    if db_result.healthy and redis_result.healthy:
        return {
            "status": "ready",
            "checks": {
                "database": db_result.healthy,
                "redis": redis_result.healthy,
            }
        }
    else:
        raise HTTPException(
            status_code=503,
            detail={
                "status": "not_ready",
                "checks": {
                    "database": db_result.healthy,
                    "redis": redis_result.healthy,
                }
            }
        )
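Review note: the status vocabulary in the /health/detailed docstring (healthy, degraded, unhealthy) can be sketched as a pure aggregation function. HealthChecker.check_all() is not part of this diff, so `aggregate_status` and `CORE_COMPONENTS` are hypothetical names, and treating the database and Redis as the core components is an assumption borrowed from the readiness probe.

```python
from typing import Mapping

# Assumption: "core" means the same two dependencies the readiness probe checks.
CORE_COMPONENTS = frozenset({"database", "redis"})


def aggregate_status(checks: Mapping[str, bool]) -> str:
    """Collapse per-component health flags into one overall status string."""
    failed = {name for name, healthy in checks.items() if not healthy}
    if not failed:
        return "healthy"
    if failed & CORE_COMPONENTS:
        # At least one core dependency is down.
        return "unhealthy"
    # Only non-core components are failing.
    return "degraded"
```

Keeping the aggregation free of I/O makes the three outcomes easy to unit test without a database or Redis instance.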

View File

@ -15,6 +15,12 @@ from app.models.knowledge import (
KnowledgeChunk, KnowledgeChunk,
KnowledgeSearchLog, KnowledgeSearchLog,
) )
from app.models.knowledge_graph import (
KnowledgeEntity,
KnowledgeRelation,
EntityType,
RelationType,
)
from app.models.analytics import PublishRecord, ContentMetrics, OptimizationInsight from app.models.analytics import PublishRecord, ContentMetrics, OptimizationInsight
from app.models.distribution import DistributionSchedule from app.models.distribution import DistributionSchedule
# Missing model imports - left over from the refactor # Missing model imports - left over from the refactor
@ -48,6 +54,10 @@ __all__ = [
"KnowledgeDocument", "KnowledgeDocument",
"KnowledgeChunk", "KnowledgeChunk",
"KnowledgeSearchLog", "KnowledgeSearchLog",
"KnowledgeEntity",
"KnowledgeRelation",
"EntityType",
"RelationType",
"PublishRecord", "PublishRecord",
"ContentMetrics", "ContentMetrics",
"OptimizationInsight", "OptimizationInsight",

View File

@ -3,10 +3,9 @@ from datetime import datetime
from sqlalchemy import String, Integer, ForeignKey, Index, func, Text from sqlalchemy import String, Integer, ForeignKey, Index, func, Text
from sqlalchemy import Uuid from sqlalchemy import Uuid
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base from app.database import Base, JSONType
class AgentRegistry(Base): class AgentRegistry(Base):
@ -24,7 +23,7 @@ class AgentRegistry(Base):
version: Mapped[str | None] = mapped_column(String(20), nullable=True) version: Mapped[str | None] = mapped_column(String(20), nullable=True)
endpoint: Mapped[str | None] = mapped_column(String(500), nullable=True) endpoint: Mapped[str | None] = mapped_column(String(500), nullable=True)
status: Mapped[str] = mapped_column(String(20), server_default="offline", nullable=False) status: Mapped[str] = mapped_column(String(20), server_default="offline", nullable=False)
capabilities: Mapped[dict | None] = mapped_column(JSONB, nullable=True) capabilities: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
last_heartbeat: Mapped[datetime | None] = mapped_column(nullable=True) last_heartbeat: Mapped[datetime | None] = mapped_column(nullable=True)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
server_default=func.now(), server_default=func.now(),
@ -68,7 +67,7 @@ class AgentConfig(Base):
nullable=False, nullable=False,
) )
config_key: Mapped[str] = mapped_column(String(100), nullable=False) config_key: Mapped[str] = mapped_column(String(100), nullable=False)
config_value: Mapped[dict] = mapped_column(JSONB, nullable=False) config_value: Mapped[dict] = mapped_column(JSONType, nullable=False)
description: Mapped[str | None] = mapped_column(String(500), nullable=True) description: Mapped[str | None] = mapped_column(String(500), nullable=True)
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
server_default=func.now(), server_default=func.now(),
@ -111,8 +110,8 @@ class AgentTask(Base):
task_type: Mapped[str] = mapped_column(String(50), nullable=False) task_type: Mapped[str] = mapped_column(String(50), nullable=False)
status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False) status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False)
priority: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False) priority: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
input_data: Mapped[dict | None] = mapped_column(JSONB, nullable=True) input_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
output_data: Mapped[dict | None] = mapped_column(JSONB, nullable=True) output_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
error_message: Mapped[str | None] = mapped_column(Text, nullable=True) error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
created_by: Mapped[uuid.UUID | None] = mapped_column( created_by: Mapped[uuid.UUID | None] = mapped_column(
Uuid(as_uuid=True), Uuid(as_uuid=True),
@ -184,7 +183,7 @@ class AgentTaskLog(Base):
) )
log_level: Mapped[str] = mapped_column(String(10), nullable=False) log_level: Mapped[str] = mapped_column(String(10), nullable=False)
message: Mapped[str] = mapped_column(Text, nullable=False) message: Mapped[str] = mapped_column(Text, nullable=False)
extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True) extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONType, nullable=True)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
server_default=func.now(), server_default=func.now(),
nullable=False, nullable=False,

View File

@ -3,10 +3,9 @@ from datetime import datetime
from sqlalchemy import String, Integer, Boolean, ForeignKey, Index, func, Text from sqlalchemy import String, Integer, Boolean, ForeignKey, Index, func, Text
from sqlalchemy import Uuid from sqlalchemy import Uuid
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base from app.database import Base, JSONType
class BrandKnowledge(Base): class BrandKnowledge(Base):
@ -25,7 +24,7 @@ class BrandKnowledge(Base):
category: Mapped[str] = mapped_column(String(50), nullable=False) category: Mapped[str] = mapped_column(String(50), nullable=False)
title: Mapped[str] = mapped_column(String(200), nullable=False) title: Mapped[str] = mapped_column(String(200), nullable=False)
content: Mapped[str] = mapped_column(Text, nullable=False) content: Mapped[str] = mapped_column(Text, nullable=False)
extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True) extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONType, nullable=True)
is_active: Mapped[bool] = mapped_column(Boolean, server_default="true", nullable=False) is_active: Mapped[bool] = mapped_column(Boolean, server_default="true", nullable=False)
created_by: Mapped[uuid.UUID | None] = mapped_column( created_by: Mapped[uuid.UUID | None] = mapped_column(
Uuid(as_uuid=True), Uuid(as_uuid=True),

View File

@ -3,10 +3,9 @@ from datetime import datetime
from sqlalchemy import String, Integer, ForeignKey, Index, func, Text from sqlalchemy import String, Integer, ForeignKey, Index, func, Text
from sqlalchemy import Uuid from sqlalchemy import Uuid
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base from app.database import Base, JSONType
class Content(Base): class Content(Base):
@ -31,9 +30,9 @@ class Content(Base):
content_type: Mapped[str] = mapped_column(String(50), nullable=False) content_type: Mapped[str] = mapped_column(String(50), nullable=False)
body: Mapped[str | None] = mapped_column(Text, nullable=True) body: Mapped[str | None] = mapped_column(Text, nullable=True)
status: Mapped[str] = mapped_column(String(20), server_default="draft", nullable=False) status: Mapped[str] = mapped_column(String(20), server_default="draft", nullable=False)
target_platforms: Mapped[list | None] = mapped_column(JSONB, nullable=True) target_platforms: Mapped[list | None] = mapped_column(JSONType, nullable=True)
keywords: Mapped[list | None] = mapped_column(JSONB, nullable=True) keywords: Mapped[list | None] = mapped_column(JSONType, nullable=True)
extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True) extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONType, nullable=True)
created_by: Mapped[uuid.UUID | None] = mapped_column( created_by: Mapped[uuid.UUID | None] = mapped_column(
Uuid(as_uuid=True), Uuid(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"), ForeignKey("users.id", ondelete="SET NULL"),

View File

@ -4,10 +4,9 @@ from datetime import datetime
from sqlalchemy import String, Integer, ForeignKey, Index, func, Text from sqlalchemy import String, Integer, ForeignKey, Index, func, Text
from sqlalchemy import Uuid from sqlalchemy import Uuid
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base from app.database import Base, JSONType
class DistributionSchedule(Base): class DistributionSchedule(Base):
@ -29,9 +28,9 @@ class DistributionSchedule(Base):
ForeignKey("contents.id", ondelete="SET NULL"), ForeignKey("contents.id", ondelete="SET NULL"),
nullable=True, nullable=True,
) )
platforms: Mapped[list | None] = mapped_column(JSONB, nullable=True) platforms: Mapped[list | None] = mapped_column(JSONType, nullable=True)
"""[{platform, platform_name, scheduled_time, status}]""" """[{platform, platform_name, scheduled_time, status}]"""
tips: Mapped[list | None] = mapped_column(JSONB, nullable=True) tips: Mapped[list | None] = mapped_column(JSONType, nullable=True)
status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False) status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False)
created_by: Mapped[uuid.UUID | None] = mapped_column( created_by: Mapped[uuid.UUID | None] = mapped_column(
Uuid(as_uuid=True), Uuid(as_uuid=True),

View File

@ -3,10 +3,9 @@ from datetime import datetime
from sqlalchemy import String, Integer, ForeignKey, Index, func, Text from sqlalchemy import String, Integer, ForeignKey, Index, func, Text
from sqlalchemy import Uuid from sqlalchemy import Uuid
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base from app.database import Base, JSONType
# pgvector Vector type - imported conditionally # pgvector Vector type - imported conditionally
try: try:
@ -92,7 +91,7 @@ class KnowledgeDocument(Base):
status: Mapped[str] = mapped_column(String(20), server_default="processing", nullable=False) # "processing" / "ready" / "failed" status: Mapped[str] = mapped_column(String(20), server_default="processing", nullable=False) # "processing" / "ready" / "failed"
error_message: Mapped[str | None] = mapped_column(Text, nullable=True) error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
# mapped_column("metadata") to avoid SQLAlchemy reserved keyword conflict # mapped_column("metadata") to avoid SQLAlchemy reserved keyword conflict
extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True) extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONType, nullable=True)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
server_default=func.now(), server_default=func.now(),
nullable=False, nullable=False,
@ -152,7 +151,7 @@ class KnowledgeChunk(Base):
chunk_index: Mapped[int] = mapped_column(Integer, nullable=False) chunk_index: Mapped[int] = mapped_column(Integer, nullable=False)
token_count: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False) token_count: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
# mapped_column("metadata") to avoid SQLAlchemy reserved keyword conflict # mapped_column("metadata") to avoid SQLAlchemy reserved keyword conflict
extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True) extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONType, nullable=True)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
server_default=func.now(), server_default=func.now(),
nullable=False, nullable=False,
@ -189,7 +188,7 @@ class KnowledgeSearchLog(Base):
nullable=True, nullable=True,
) )
query: Mapped[str] = mapped_column(Text, nullable=False) query: Mapped[str] = mapped_column(Text, nullable=False)
knowledge_base_ids: Mapped[list | None] = mapped_column(JSONB, nullable=True) knowledge_base_ids: Mapped[list | None] = mapped_column(JSONType, nullable=True)
results_count: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False) results_count: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
latency_ms: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False) latency_ms: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(

View File

@ -0,0 +1,175 @@
"""Knowledge graph data models"""
import uuid
from datetime import datetime
import enum
from sqlalchemy import String, Text, DateTime, ForeignKey, Index, Enum, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base, JSONType
class EntityType(str, enum.Enum):
    """Entity types"""
    ORGANIZATION = "ORGANIZATION"  # organization / company
    PRODUCT = "PRODUCT"  # product
    PERSON = "PERSON"  # person
    LOCATION = "LOCATION"  # location
    TECHNOLOGY = "TECHNOLOGY"  # technology
    BRAND = "BRAND"  # brand
    EVENT = "EVENT"  # event
    CONCEPT = "CONCEPT"  # concept
    OTHER = "OTHER"  # other


class RelationType(str, enum.Enum):
    """Relation types"""
    # Organization relations
    COMPETES_WITH = "COMPETES_WITH"  # competitor
    PARTNERS_WITH = "PARTNERS_WITH"  # partner
    ACQUIRES = "ACQUIRES"  # acquisition
    SUBSIDIARY_OF = "SUBSIDIARY_OF"  # subsidiary
    # Product relations
    PRODUCES = "PRODUCES"  # produces
    USES_TECHNOLOGY = "USES_TECHNOLOGY"  # uses a technology
    PART_OF = "PART_OF"  # part of (product line)
    # Location relations
    LOCATED_IN = "LOCATED_IN"  # located in
    FOUNDED_IN = "FOUNDED_IN"  # founded in
    # Person relations
    CEO_OF = "CEO_OF"  # CEO
    FOUNDER_OF = "FOUNDER_OF"  # founder
    # Generic relations
    RELATED_TO = "RELATED_TO"  # related to
    MENTIONED_IN = "MENTIONED_IN"  # mentioned in
    ALSO_KNOWN_AS = "ALSO_KNOWN_AS"  # also known as
class KnowledgeEntity(Base):
    """Knowledge entity"""
    __tablename__ = "knowledge_entities"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    knowledge_base_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("knowledge_bases.id", ondelete="CASCADE"),
        nullable=False,
    )
    # Entity information
    name: Mapped[str] = mapped_column(String(500), nullable=False, index=True)
    entity_type: Mapped[EntityType] = mapped_column(Enum(EntityType), nullable=False, index=True)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Extended properties (JSON)
    properties: Mapped[dict | None] = mapped_column(JSONType, default=dict)
    # Source information
    source_chunk_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("knowledge_chunks.id", ondelete="SET NULL"),
        nullable=True,
    )
    confidence: Mapped[str | None] = mapped_column(String(20), nullable=True)  # confidence label: high/medium/low
    # Timestamps
    created_at: Mapped[datetime] = mapped_column(
        server_default=func.now(),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )
    # Indexes
    __table_args__ = (
        Index("ix_entities_kb_name", "knowledge_base_id", "name"),
        Index("ix_entities_kb_type", "knowledge_base_id", "entity_type"),
    )
    # Relationships
    outgoing_relations: Mapped[list["KnowledgeRelation"]] = relationship(
        "KnowledgeRelation",
        foreign_keys="KnowledgeRelation.source_entity_id",
        back_populates="source_entity",
        cascade="all, delete-orphan",
    )
    incoming_relations: Mapped[list["KnowledgeRelation"]] = relationship(
        "KnowledgeRelation",
        foreign_keys="KnowledgeRelation.target_entity_id",
        back_populates="target_entity",
        cascade="all, delete-orphan",
    )
class KnowledgeRelation(Base):
    """Knowledge relation"""
    __tablename__ = "knowledge_relations"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
    # The two ends of the relation
    source_entity_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("knowledge_entities.id", ondelete="CASCADE"),
        nullable=False,
    )
    target_entity_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("knowledge_entities.id", ondelete="CASCADE"),
        nullable=False,
    )
    # Relation information
    relation_type: Mapped[RelationType] = mapped_column(Enum(RelationType), nullable=False, index=True)
    # Extended properties
    properties: Mapped[dict | None] = mapped_column(JSONType, default=dict)
    # Source information
    source_chunk_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("knowledge_chunks.id", ondelete="SET NULL"),
        nullable=True,
    )
    confidence: Mapped[str | None] = mapped_column(String(20), nullable=True)
    # Timestamps
    created_at: Mapped[datetime] = mapped_column(
        server_default=func.now(),
        nullable=False,
    )
    # Relationships
    source_entity: Mapped["KnowledgeEntity"] = relationship(
        "KnowledgeEntity",
        foreign_keys=[source_entity_id],
        back_populates="outgoing_relations",
    )
    target_entity: Mapped["KnowledgeEntity"] = relationship(
        "KnowledgeEntity",
        foreign_keys=[target_entity_id],
        back_populates="incoming_relations",
    )
    # Indexes
    __table_args__ = (
        Index("ix_relations_source", "source_entity_id"),
        Index("ix_relations_target", "target_entity_id"),
        Index("ix_relations_type", "relation_type"),
    )
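Review note: the entity/relation pair above forms a directed, typed graph. The sketch below shows that shape in memory, without a database. `Relation`, `build_outgoing_index`, and `neighbors` are illustrative names that do not exist in this commit, and the sample entities are made up.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class Relation:
    """In-memory stand-in for a KnowledgeRelation row (illustrative only)."""
    source: str
    relation_type: str
    target: str


def build_outgoing_index(relations):
    """Group relations by source entity, similar in spirit to outgoing_relations."""
    index = defaultdict(list)
    for rel in relations:
        index[rel.source].append(rel)
    return index


def neighbors(index, entity, relation_type=None):
    """Return the targets reachable from `entity`, optionally filtered by type."""
    return [
        rel.target
        for rel in index.get(entity, [])
        if relation_type is None or rel.relation_type == relation_type
    ]


# Made-up sample data using relation type values from RelationType.
relations = [
    Relation("Acme", "PRODUCES", "Widget"),
    Relation("Acme", "COMPETES_WITH", "Globex"),
    Relation("Widget", "USES_TECHNOLOGY", "Rust"),
]
index = build_outgoing_index(relations)
```

Calling `neighbors(index, "Acme", "PRODUCES")` follows only the outgoing PRODUCES edges, which is the kind of lookup the `ix_relations_source` and `ix_relations_type` indexes would presumably serve on the SQL side.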

View File

@ -3,10 +3,9 @@ from datetime import datetime
from sqlalchemy import String, Integer, ForeignKey, Index, func, Text from sqlalchemy import String, Integer, ForeignKey, Index, func, Text
from sqlalchemy import Uuid from sqlalchemy import Uuid
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base from app.database import Base, JSONType
class LifecycleProject(Base): class LifecycleProject(Base):
@ -23,7 +22,7 @@ class LifecycleProject(Base):
nullable=False, nullable=False,
) )
brand_name: Mapped[str] = mapped_column(String(100), nullable=False) brand_name: Mapped[str] = mapped_column(String(100), nullable=False)
brand_aliases: Mapped[list] = mapped_column(JSONB, server_default="[]", nullable=False) brand_aliases: Mapped[list] = mapped_column(JSONType, server_default="[]", nullable=False)
current_stage: Mapped[int] = mapped_column(Integer, server_default="1", nullable=False) current_stage: Mapped[int] = mapped_column(Integer, server_default="1", nullable=False)
status: Mapped[str] = mapped_column(String(20), server_default="active", nullable=False) status: Mapped[str] = mapped_column(String(20), server_default="active", nullable=False)
created_by: Mapped[uuid.UUID] = mapped_column( created_by: Mapped[uuid.UUID] = mapped_column(
@ -77,7 +76,7 @@ class ProjectStage(Base):
started_at: Mapped[datetime | None] = mapped_column(nullable=True) started_at: Mapped[datetime | None] = mapped_column(nullable=True)
completed_at: Mapped[datetime | None] = mapped_column(nullable=True) completed_at: Mapped[datetime | None] = mapped_column(nullable=True)
notes: Mapped[str | None] = mapped_column(Text, nullable=True) notes: Mapped[str | None] = mapped_column(Text, nullable=True)
metrics: Mapped[dict | None] = mapped_column(JSONB, nullable=True) metrics: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
# Relationships # Relationships
project: Mapped["LifecycleProject"] = relationship( project: Mapped["LifecycleProject"] = relationship(

View File

@ -3,10 +3,9 @@ from datetime import datetime
from sqlalchemy import String, Boolean, ForeignKey, Index, func, Text from sqlalchemy import String, Boolean, ForeignKey, Index, func, Text
from sqlalchemy import Uuid from sqlalchemy import Uuid
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base from app.database import Base, JSONType
class PlatformRule(Base): class PlatformRule(Base):
@ -21,7 +20,7 @@ class PlatformRule(Base):
rule_category: Mapped[str] = mapped_column(String(50), nullable=False) rule_category: Mapped[str] = mapped_column(String(50), nullable=False)
rule_name: Mapped[str] = mapped_column(String(200), nullable=False) rule_name: Mapped[str] = mapped_column(String(200), nullable=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True) description: Mapped[str | None] = mapped_column(Text, nullable=True)
check_criteria: Mapped[dict | None] = mapped_column(JSONB, nullable=True) check_criteria: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
severity: Mapped[str] = mapped_column(String(20), nullable=False) severity: Mapped[str] = mapped_column(String(20), nullable=False)
is_active: Mapped[bool] = mapped_column(Boolean, server_default="true", nullable=False) is_active: Mapped[bool] = mapped_column(Boolean, server_default="true", nullable=False)
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(

View File

@ -0,0 +1,13 @@
"""Monitoring module"""
import os
from app.monitoring.metrics import *
from app.monitoring.middleware import MonitoringMiddleware
from app.monitoring.agent_hooks import agent_execution_context, record_agent_execution
from app.monitoring.llm_metrics import get_llm_metrics, LLMMetricsWrapper
# Set service information
SERVICE_INFO.info({
    "version": "1.0.0",
    "environment": os.getenv("ENVIRONMENT", "development"),
})

View File

@ -0,0 +1,57 @@
"""Agent execution metric hooks"""
import time
from contextlib import asynccontextmanager
from typing import Optional
from app.monitoring.metrics import (
AGENT_EXECUTIONS_TOTAL,
AGENT_EXECUTION_DURATION_SECONDS,
AGENT_RUNNING_TASKS,
)
@asynccontextmanager
async def agent_execution_context(agent_name: str):
    """Agent execution context manager that records metrics automatically."""
    # Increment the running-task gauge
    AGENT_RUNNING_TASKS.labels(agent_name=agent_name).inc()
    start_time = time.perf_counter()
    status = "success"
    try:
        yield
    except Exception:
        status = "failure"
        raise
    finally:
        # Record execution time and status
        duration = time.perf_counter() - start_time
        AGENT_EXECUTIONS_TOTAL.labels(
            agent_name=agent_name,
            status=status
        ).inc()
        AGENT_EXECUTION_DURATION_SECONDS.labels(
            agent_name=agent_name
        ).observe(duration)
        # Decrement the running-task gauge
        AGENT_RUNNING_TASKS.labels(agent_name=agent_name).dec()
def record_agent_execution(
    agent_name: str,
    status: str,
    duration: float
):
    """Record agent execution metrics manually."""
    AGENT_EXECUTIONS_TOTAL.labels(
        agent_name=agent_name,
        status=status
    ).inc()
    AGENT_EXECUTION_DURATION_SECONDS.labels(
        agent_name=agent_name
    ).observe(duration)
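Review note: a minimal usage sketch of the context-manager pattern in this file. To stay runnable without prometheus_client, it records into plain `collections.Counter` objects instead of the real metrics; `tracked_execution`, `executions`, `running`, and `durations` are stand-in names, not part of this commit.

```python
import asyncio
import time
from collections import Counter
from contextlib import asynccontextmanager

# Stand-ins for the Prometheus counter, gauge, and histogram.
executions = Counter()
running = Counter()
durations = []


@asynccontextmanager
async def tracked_execution(agent_name: str):
    """Same control flow as agent_execution_context, with stdlib stand-ins."""
    running[agent_name] += 1
    start = time.perf_counter()
    status = "success"
    try:
        yield
    except Exception:
        status = "failure"
        raise
    finally:
        # Runs on both paths, so the running count always returns to zero.
        durations.append(time.perf_counter() - start)
        executions[(agent_name, status)] += 1
        running[agent_name] -= 1


async def demo():
    # One successful run.
    async with tracked_execution("writer"):
        await asyncio.sleep(0)
    # One failing run; the exception still propagates to the caller.
    try:
        async with tracked_execution("writer"):
            raise ValueError("boom")
    except ValueError:
        pass


asyncio.run(demo())
```

The `finally` block is what keeps the running-task count balanced: it decrements whether the body succeeds or raises.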

View File

@ -0,0 +1,102 @@
"""LLM call metrics wrapper"""
import time
from typing import Optional
from app.monitoring.metrics import (
LLM_REQUESTS_TOTAL,
LLM_REQUEST_DURATION_SECONDS,
LLM_TOKENS_TOTAL,
LLM_COST_ESTIMATED,
)
# LLM cost estimates (USD per token)
LLM_COST_PER_TOKEN = {
    # OpenAI
    ("openai", "gpt-4o"): {"prompt": 0.000005, "completion": 0.000015},
    ("openai", "gpt-4o-mini"): {"prompt": 0.00000015, "completion": 0.0000006},
    ("openai", "gpt-4-turbo"): {"prompt": 0.00001, "completion": 0.00003},
    # DeepSeek
    ("deepseek", "deepseek-chat"): {"prompt": 0.00000014, "completion": 0.00000028},
    ("deepseek", "deepseek-coder"): {"prompt": 0.00000014, "completion": 0.00000028},
}
class LLMMetricsWrapper:
    """Metrics wrapper for LLM calls"""

    def __init__(self, provider: str, model: str):
        self.provider = provider
        self.model = model

    def record_request(
        self,
        status: str,
        duration: float,
        prompt_tokens: Optional[int] = None,
        completion_tokens: Optional[int] = None,
    ):
        """Record metrics for one LLM request."""
        # Record the request count and duration
        LLM_REQUESTS_TOTAL.labels(
            provider=self.provider,
            model=self.model,
            status=status
        ).inc()
        LLM_REQUEST_DURATION_SECONDS.labels(
            provider=self.provider,
            model=self.model
        ).observe(duration)
        # Record token usage
        if prompt_tokens is not None:
            LLM_TOKENS_TOTAL.labels(
                provider=self.provider,
                model=self.model,
                token_type="prompt"
            ).inc(prompt_tokens)
        if completion_tokens is not None:
            LLM_TOKENS_TOTAL.labels(
                provider=self.provider,
                model=self.model,
                token_type="completion"
            ).inc(completion_tokens)
        # Estimate the cost
        cost = self._estimate_cost(prompt_tokens, completion_tokens)
        if cost > 0:
            LLM_COST_ESTIMATED.labels(
                provider=self.provider,
                model=self.model
            ).inc(cost)
    def _estimate_cost(
        self,
        prompt_tokens: Optional[int],
        completion_tokens: Optional[int]
    ) -> float:
        """Estimate the cost of one request."""
        cost_info = LLM_COST_PER_TOKEN.get((self.provider, self.model))
        if not cost_info:
            return 0.0
        total = 0.0
        if prompt_tokens:
            total += prompt_tokens * cost_info["prompt"]
        if completion_tokens:
            total += completion_tokens * cost_info["completion"]
        return total
# Global cache of LLM metrics recorders
_llm_metrics_cache: dict[str, LLMMetricsWrapper] = {}


def get_llm_metrics(provider: str, model: str) -> LLMMetricsWrapper:
    """Return the LLM metrics wrapper for a provider/model pair (cached)."""
    key = f"{provider}:{model}"
    if key not in _llm_metrics_cache:
        _llm_metrics_cache[key] = LLMMetricsWrapper(provider, model)
    return _llm_metrics_cache[key]
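Review note: a worked example of the cost arithmetic in `_estimate_cost`, using the per-token prices this file lists for ("openai", "gpt-4o"). `estimate_cost` is a standalone copy of the same logic for illustration, and the token counts are made up.

```python
import math

# Prices copied from the LLM_COST_PER_TOKEN entry for ("openai", "gpt-4o").
GPT_4O_PRICES = {"prompt": 0.000005, "completion": 0.000015}


def estimate_cost(prices, prompt_tokens, completion_tokens):
    """Price each token type separately and sum; None counts as zero tokens."""
    total = 0.0
    if prompt_tokens:
        total += prompt_tokens * prices["prompt"]
    if completion_tokens:
        total += completion_tokens * prices["completion"]
    return total


# 1,000 prompt tokens cost 0.005 USD and 500 completion tokens cost 0.0075 USD.
cost = estimate_cost(GPT_4O_PRICES, prompt_tokens=1000, completion_tokens=500)
```

The result is about 0.0125 USD. Because this is floating-point arithmetic, comparisons in tests are safer with `math.isclose` than with `==`.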

View File

@ -0,0 +1,97 @@
"""Prometheus metric definitions"""
from prometheus_client import Counter, Histogram, Gauge, Info
# API-layer metrics
API_REQUESTS_TOTAL = Counter(
"geo_api_requests_total",
"Total API requests",
["method", "endpoint", "status"]
)
API_REQUEST_DURATION_SECONDS = Histogram(
"geo_api_request_duration_seconds",
"API request duration in seconds",
["method", "endpoint"],
buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0)
)
API_REQUESTS_IN_PROGRESS = Gauge(
"geo_api_requests_in_progress",
"Number of requests currently being processed",
["method", "endpoint"]
)
# Agent-layer metrics
AGENT_EXECUTIONS_TOTAL = Counter(
"geo_agent_executions_total",
"Total agent executions",
["agent_name", "status"]
)
AGENT_EXECUTION_DURATION_SECONDS = Histogram(
"geo_agent_execution_duration_seconds",
"Agent execution duration in seconds",
["agent_name"],
buckets=(0.1, 0.5, 1.0, 5.0, 10.0, 30.0, 60.0, 120.0)
)
AGENT_RUNNING_TASKS = Gauge(
"geo_agent_running_tasks",
"Number of tasks currently running",
["agent_name"]
)
# LLM-layer metrics
LLM_REQUESTS_TOTAL = Counter(
"geo_llm_requests_total",
"Total LLM requests",
["provider", "model", "status"]
)
LLM_REQUEST_DURATION_SECONDS = Histogram(
"geo_llm_request_duration_seconds",
"LLM request duration in seconds",
["provider", "model"],
buckets=(0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0)
)
LLM_TOKENS_TOTAL = Counter(
"geo_llm_tokens_total",
"Total LLM tokens used",
["provider", "model", "token_type"]
)
LLM_COST_ESTIMATED = Gauge(
"geo_llm_cost_estimated",
"Estimated LLM cost in USD",
["provider", "model"]
)
# Business-layer metrics
BRAND_COUNT = Gauge(
"geo_brands_total",
"Total number of brands"
)
QUERY_COUNT_TOTAL = Counter(
"geo_queries_total",
"Total number of queries executed",
["platform", "status"]
)
CONTENT_GENERATED_TOTAL = Counter(
"geo_content_generated_total",
"Total content generated"
)
CITATION_DETECTED_TOTAL = Counter(
"geo_citations_detected_total",
"Total citations detected",
["platform"]
)
# System information
SERVICE_INFO = Info(
"geo_service",
"GEO service information"
)
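Review note: the `buckets=` tuples above define upper bounds (`le` labels) for Prometheus histograms, and the exposed bucket series are cumulative. The sketch below shows, with the standard library only, which exposed buckets grow for a single observation. `buckets_hit` is a hypothetical helper for illustration and is not how prometheus_client stores samples internally.

```python
from bisect import bisect_left

# Upper bounds copied from API_REQUEST_DURATION_SECONDS.
BUCKETS = (0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0)


def buckets_hit(duration: float, bounds=BUCKETS):
    """Return the `le` bounds whose cumulative count grows for this observation."""
    # Index of the first bound that is >= duration (bounds are inclusive).
    first = bisect_left(bounds, duration)
    # Every larger bound also counts it, and +Inf counts everything.
    return [*bounds[first:], float("inf")]
```

A 0.3 s request therefore lands in the 0.5 bucket and every bucket above it, while anything slower than 10 s is visible only in `+Inf`. That is worth keeping in mind when choosing bounds for slow agent or LLM calls.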

View File

@ -0,0 +1,86 @@
"""Monitoring middleware"""
import time
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from app.monitoring.metrics import (
API_REQUESTS_TOTAL,
API_REQUEST_DURATION_SECONDS,
API_REQUESTS_IN_PROGRESS,
)
# Paths to exclude (no metrics are recorded for them)
EXCLUDED_PATHS = {"/health", "/ready", "/metrics", "/docs", "/openapi.json"}
class MonitoringMiddleware(BaseHTTPMiddleware):
    """API monitoring middleware"""

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Skip excluded paths
        if request.url.path in EXCLUDED_PATHS:
            return await call_next(request)
        # Derive the endpoint identifier used as a metric label
        endpoint = self._get_endpoint_label(request)
        # Increment the in-progress request gauge
        API_REQUESTS_IN_PROGRESS.labels(
            method=request.method,
            endpoint=endpoint
        ).inc()
        # Record the start time
        start_time = time.perf_counter()
        try:
            # Run the request
            response = await call_next(request)
            status_code = response.status_code
        except Exception:
            status_code = 500
            raise
        finally:
            # Compute the elapsed time
            duration = time.perf_counter() - start_time
            # Record metrics
            API_REQUESTS_TOTAL.labels(
                method=request.method,
                endpoint=endpoint,
                status=str(status_code)
            ).inc()
            API_REQUEST_DURATION_SECONDS.labels(
                method=request.method,
                endpoint=endpoint
            ).observe(duration)
            # Decrement the in-progress request gauge
            API_REQUESTS_IN_PROGRESS.labels(
                method=request.method,
                endpoint=endpoint
            ).dec()
        return response
def _get_endpoint_label(self, request: Request) -> str:
"""提取端点标签"""
path = request.url.path
# 规范化路径替换ID等参数
parts = path.strip("/").split("/")
# 处理常见模式:/api/v1/resources/{id}
if len(parts) >= 4 and parts[0] == "api":
resource = parts[2] if len(parts) > 2 else "unknown"
action = parts[3] if len(parts) > 3 else "list"
# 映射到规范标签
if action.isdigit():
return f"{resource}_detail"
return f"{resource}_{action}"
return "other"
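The label extraction above treats only all-digit path segments as IDs, so a UUID in the path would produce a new label value per resource. Below is a standalone sketch that also collapses UUIDs, assuming the same `/api/<version>/<resource>/...` layout. `endpoint_label` and `_ID_RE` are illustrative names, not part of this commit.

```python
import re

# Matches purely numeric IDs and canonical 8-4-4-4-12 UUIDs.
_ID_RE = re.compile(
    r"^(\d+|[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}"
    r"-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})$"
)


def endpoint_label(path: str) -> str:
    """Collapse a request path into a bounded set of label values."""
    parts = [p for p in path.strip("/").split("/") if p]
    if len(parts) < 3 or parts[0] != "api":
        return "other"
    resource = parts[2]
    if len(parts) == 3:
        # A bare collection path such as /api/v1/brands.
        return f"{resource}_list"
    action = parts[3]
    if _ID_RE.match(action):
        # Any numeric or UUID identifier maps to one shared label.
        return f"{resource}_detail"
    return f"{resource}_{action}"
```

Keeping label values bounded matters because every distinct label combination becomes its own time series in Prometheus.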


@ -74,3 +74,27 @@ class ChunkPreview(BaseModel):
    token_count: int

    model_config = {"from_attributes": True}


# ---------- Incremental indexing schemas ----------
class UpdateDocumentRequest(BaseModel):
    """Request to update a document's content."""
    content: str


class RetrieveRequest(BaseModel):
    """Enhanced retrieval request."""
    query: str = Field(..., min_length=1)
    top_k: Optional[int] = Field(default=5, ge=1, le=50)
    use_rerank: Optional[bool] = Field(default=True)
    use_compression: Optional[bool] = Field(default=False)


class RebuildResponse(BaseModel):
    """Index rebuild response."""
    total: int
    processed: int
    skipped: int
    failed: int
    errors: list[dict]


@ -0,0 +1,54 @@
"""Application state management."""
import os
import platform
import time


class AppState:
    """Application state."""

    def __init__(self):
        self.start_time = time.time()
        self.platform = platform.system()
        self.python_version = platform.python_version()
        self.version = os.getenv("APP_VERSION", "1.0.0")
        self.environment = os.getenv("ENVIRONMENT", "development")

    def get_uptime_seconds(self) -> float:
        """Return the uptime in seconds."""
        return time.time() - self.start_time

    def get_uptime_formatted(self) -> str:
        """Return the uptime as a human-readable string."""
        seconds = self.get_uptime_seconds()
        days = int(seconds // 86400)
        hours = int((seconds % 86400) // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)
        # Every component carries its unit; without one the joined
        # digits would be ambiguous.
        parts = []
        if days > 0:
            parts.append(f"{days}天")
        if hours > 0:
            parts.append(f"{hours}小时")
        if minutes > 0:
            parts.append(f"{minutes}分钟")
        parts.append(f"{secs}秒")
        return "".join(parts)

    def get_info(self) -> dict:
        """Return application information."""
        return {
            "version": self.version,
            "environment": self.environment,
            "platform": self.platform,
            "python_version": self.python_version,
            "uptime_seconds": round(self.get_uptime_seconds(), 2),
            "uptime_formatted": self.get_uptime_formatted(),
        }


# Global instance
app_state = AppState()
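The arithmetic in `get_uptime_formatted` can also be expressed with `divmod`. The sketch below is a standalone illustration: `format_uptime` is a hypothetical helper, and the English unit suffixes are an assumption chosen for readability rather than what the service returns.

```python
def format_uptime(seconds: float) -> str:
    """Format a duration such as '1d 2h 3m 4s'.

    Zero-valued days, hours, and minutes are omitted; seconds are always shown.
    """
    total = int(seconds)
    days, rem = divmod(total, 86400)   # 86400 seconds per day
    hours, rem = divmod(rem, 3600)
    minutes, secs = divmod(rem, 60)
    parts = []
    if days:
        parts.append(f"{days}d")
    if hours:
        parts.append(f"{hours}h")
    if minutes:
        parts.append(f"{minutes}m")
    parts.append(f"{secs}s")
    return " ".join(parts)
```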


@ -0,0 +1,362 @@
"""Definitions of the eight core GEO content topic templates."""
from dataclasses import dataclass
from typing import Optional


@dataclass
class TopicTemplate:
    """Topic template."""
    id: str  # identifier, e.g. product_comparison
    name: str  # display name, e.g. 产品对比 (product comparison)
    description: str  # description
    icon: str  # emoji icon
    prompt_template: str  # prompt template
    seo_tips: list[str]  # SEO tips
    recommended_platforms: list[str]  # recommended platforms
    word_count_range: tuple[int, int]  # recommended word-count range
    required_params: list[str]  # required parameters
    optional_params: list[str]  # optional parameters
TOPIC_TEMPLATES = {
# 1. 产品对比
"product_comparison": TopicTemplate(
id="product_comparison",
name="产品对比",
description="品牌与竞品的功能、价格、体验对比",
icon="⚖️",
prompt_template="""
请基于以下信息生成一篇产品对比文章
你的品牌{brand_name}
对比品牌{competitor_name}
对比维度{comparison_dimensions}
目标关键词{keywords}
内容风格{content_style}
要求
1. 客观中立不刻意贬低竞品
2. 突出自身优势但有理有据
3. 适合AI平台引用结构清晰数据支撑
4. 包含对比表格或图表
5. 字数{word_count}
文章结构
1. 引言简要介绍两个产品
2. 外观设计对比
3. 功能特性对比
4. 性价比对比
5. 适用场景分析
6. 结论总结优劣给出建议
""",
seo_tips=[
"标题包含品牌名+对比+关键词",
"使用对比表格增强可读性",
"自然嵌入目标关键词",
],
recommended_platforms=["知乎", "百家号", "公众号"],
word_count_range=(1500, 3000),
required_params=["brand_name", "competitor_name", "comparison_dimensions", "keywords"],
optional_params=["content_style", "word_count"],
),
# 2. 使用指南
"how_to_guide": TopicTemplate(
id="how_to_guide",
name="使用指南",
description="产品使用方法教程,帮助用户解决问题",
icon="📖",
prompt_template="""
请基于以下信息生成一篇使用指南文章
产品名称{product_name}
核心功能{core_features}
目标用户{target_audience}
目标关键词{keywords}
内容风格{content_style}
要求
1. 步骤清晰操作性强
2. 图文并茂描述需要的图片类型
3. 适合搜索引擎收录
4. 适合AI平台引用
5. 字数{word_count}
文章结构
1. 引言介绍产品价值
2. 基础设置/准备工作
3. 核心功能使用步骤分点详述
4. 进阶技巧/常见问题
5. 总结与进阶学习建议
""",
seo_tips=[
"标题包含产品名+使用方法/教程",
"使用有序列表标注步骤",
"包含FAQ解答常见问题",
],
recommended_platforms=["知乎", "小红书", "简书"],
word_count_range=(1000, 2500),
required_params=["product_name", "core_features", "keywords"],
optional_params=["target_audience", "content_style", "word_count"],
),
# 3. 行业趋势
"industry_trends": TopicTemplate(
id="industry_trends",
name="行业趋势",
description="行业动态、发展方向和未来预测分析",
icon="📈",
prompt_template="""
请基于以下信息生成一篇行业趋势分析文章
行业名称{industry_name}
品牌视角{brand_perspective}
分析维度{analysis_dimensions}
目标关键词{keywords}
内容风格{content_style}
要求
1. 数据支撑引用权威来源
2. 趋势分析有理有据
3. 适合AI平台引用观点鲜明数据翔实
4. 字数{word_count}
文章结构
1. 引言行业现状概述
2. 核心趋势分析3-5个趋势
3. 趋势背后的驱动因素
4. 品牌如何把握趋势
5. 展望与建议
""",
seo_tips=[
"标题包含行业名+趋势/预测",
"引用数据增加可信度",
"使用时间线展示演变",
],
recommended_platforms=["知乎", "百家号", "公众号"],
word_count_range=(2000, 4000),
required_params=["industry_name", "brand_perspective", "keywords"],
optional_params=["analysis_dimensions", "content_style", "word_count"],
),
# 4. 专家观点
"expert_opinion": TopicTemplate(
id="expert_opinion",
name="专家观点",
description="行业专家见解和专业分析",
icon="💡",
prompt_template="""
请基于以下信息生成一篇专家观点文章
主题{topic}
专家身份{expert_identity}
核心观点{core_opinion}
目标关键词{keywords}
内容风格{content_style}
要求
1. 观点鲜明论证充分
2. 结合实际案例
3. 适合AI平台引用
4. 字数{word_count}
文章结构
1. 引言抛出核心观点
2. 观点详细阐述
3. 案例支撑
4. 对行业的影响
5. 结论与建议
""",
seo_tips=[
"标题体现专家身份+观点",
"使用引用格式突出核心观点",
"增加互动性问题",
],
recommended_platforms=["知乎", "公众号", "微博"],
word_count_range=(1500, 3000),
required_params=["topic", "expert_identity", "core_opinion", "keywords"],
optional_params=["content_style", "word_count"],
),
# 5. 案例研究
"case_study": TopicTemplate(
id="case_study",
name="案例研究",
description="成功案例深度分析",
icon="🏆",
prompt_template="""
请基于以下信息生成一篇案例研究文章
案例名称{case_name}
企业背景{company_background}
核心成果{core_results}
成功要素{success_factors}
目标关键词{keywords}
内容风格{content_style}
要求
1. 故事性强有代入感
2. 数据支撑成果
3. 可复制的方法论
4. 适合AI平台引用
5. 字数{word_count}
文章结构
1. 案例背景介绍
2. 面临的挑战
3. 解决方案详述
4. 成果数据展示
5. 成功经验总结
6. 可复制建议
""",
seo_tips=[
"标题包含行业+案例+成果",
"使用数据图表展示成果",
"突出可复制的关键步骤",
],
recommended_platforms=["知乎", "百家号", "公众号"],
word_count_range=(2000, 4000),
required_params=["case_name", "company_background", "core_results", "keywords"],
optional_params=["success_factors", "content_style", "word_count"],
),
# 6. 数据报告
"data_report": TopicTemplate(
id="data_report",
name="数据报告",
description="行业数据和分析报告",
icon="📊",
prompt_template="""
请基于以下信息生成一篇数据报告文章
报告主题{report_topic}
数据范围{data_scope}
核心发现{core_findings}
目标关键词{keywords}
内容风格{content_style}
要求
1. 数据翔实来源可靠
2. 图表丰富直观易懂
3. 深度分析洞察本质
4. 适合AI平台引用
5. 字数{word_count}
文章结构
1. 研究背景与方法
2. 核心数据发现
3. 深度分析
4. 趋势解读
5. 建议与展望
""",
seo_tips=[
"标题包含数据+报告+关键词",
"使用图表增强说服力",
"包含数据来源说明",
],
recommended_platforms=["知乎", "百家号", "公众号"],
word_count_range=(3000, 5000),
required_params=["report_topic", "data_scope", "core_findings", "keywords"],
optional_params=["content_style", "word_count"],
),
# 7. 问题解答
"faq": TopicTemplate(
id="faq",
name="问题解答",
description="常见问题解答",
icon="",
prompt_template="""
请基于以下信息生成一篇FAQ文章
主题{topic}
目标用户{target_audience}
问题列表{questions}
目标关键词{keywords}
内容风格{content_style}
要求
1. 问题覆盖面广击中痛点
2. 回答简洁明了实操性强
3. 适合搜索引擎长尾词
4. 适合AI平台引用
5. 字数{word_count}
文章结构
1. 引言为什么需要了解这些
2. 核心问题解答每个问题独立成段
3. 延伸问题
4. 总结下一步建议
""",
seo_tips=[
"标题包含主题+常见问题/FAQ",
"每个问题使用H2标签",
"覆盖长尾搜索词",
],
recommended_platforms=["知乎", "简书", "小程序"],
word_count_range=(1000, 2000),
required_params=["topic", "questions", "keywords"],
optional_params=["target_audience", "content_style", "word_count"],
),
# 8. 评测报告
"review": TopicTemplate(
id="review",
name="评测报告",
description="产品深度评测:优缺点全面分析",
icon="🔬",
prompt_template="""
请基于以下信息生成一篇评测报告
产品名称{product_name}
评测维度{review_dimensions}
竞品参照{competitor_reference}
目标关键词{keywords}
内容风格{content_style}
要求
1. 客观公正优点不夸大
2. 缺点实事求是
3. 有具体数据和场景支撑
4. 适合AI平台引用
5. 字数{word_count}
文章结构
1. 评测背景与方法
2. 外观设计评价
3. 核心功能评测
4. 性能测试数据
5. 优缺点总结
6. 适合人群分析
7. 购买建议
""",
seo_tips=[
"标题包含产品名+评测/测评",
"使用评分表格直观展示",
"包含具体测试数据",
],
recommended_platforms=["知乎", "小红书", "百家号"],
word_count_range=(2000, 4000),
required_params=["product_name", "review_dimensions", "keywords"],
optional_params=["competitor_reference", "content_style", "word_count"],
),
}
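def get_topic_template(topic_id: str) -> Optional[TopicTemplate]:
    """Return the topic template for ``topic_id``, or None if it is unknown."""
    return TOPIC_TEMPLATES.get(topic_id)


def list_topic_templates() -> list[TopicTemplate]:
    """List all topic templates."""
    return list(TOPIC_TEMPLATES.values())


def render_topic_prompt(topic_id: str, params: dict) -> str:
    """Render the prompt for a topic."""
    template = get_topic_template(topic_id)
    if not template:
        raise ValueError(f"Unknown topic: {topic_id}")
    missing = [p for p in template.required_params if p not in params]
    if missing:
        raise ValueError(f"Missing required params for {topic_id}: {missing}")
    # Work on a copy so the caller's dict is not mutated. Optional
    # parameters default to "" so an omitted one cannot raise KeyError.
    values = {name: "" for name in template.optional_params}
    values.update({"content_style": "专业严谨", "word_count": 2000})
    values.update(params)
    return template.prompt_template.format(**values)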
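Each template declares `required_params` and `optional_params`, which makes it possible to validate input before formatting. The sketch below shows one way to do that generically by reading the placeholder names out of the template itself. `render_prompt` and its `defaults` argument are illustrative assumptions, not functions from this commit.

```python
from string import Formatter
from typing import Optional


def render_prompt(
    template: str,
    required: list[str],
    params: dict,
    defaults: Optional[dict] = None,
) -> str:
    """Render `template` without mutating `params`.

    Raises ValueError naming every missing required parameter. Any other
    placeholder absent from `params` falls back to `defaults`, then to "".
    """
    missing = [name for name in required if name not in params]
    if missing:
        raise ValueError(f"Missing required params: {', '.join(missing)}")
    # Collect every placeholder name that appears in the template.
    fields = {name for _, name, _, _ in Formatter().parse(template) if name}
    values = {name: "" for name in fields}
    values.update(defaults or {})
    values.update(params)
    return template.format(**values)
```

Reading the placeholders from the template keeps the rendering step correct even if a template and its declared parameter lists drift apart.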


@ -106,6 +106,11 @@ PLATFORM_RULES: dict[str, dict] = {
            "不得使用AI水文内容需有信息增量",
            "专业背书内容需有据可查",
        ],
        "image_rules": {
            "cover": {"width": 690, "height": 280, "ratio": "2.5:1"},
            "inline": {"width": 500, "height": 375, "ratio": "4:3"},
            "max_size_mb": 5,
        },
        "seo_tips": [
            "回答开头直接给出结论",
            "使用数据和案例支撑",
@ -186,6 +191,11 @@ PLATFORM_RULES: dict[str, dict] = {
            "正文不含外部链接(仅支持公众号链接和小程序)",
            "不得包含未经授权的商标/品牌名称",
        ],
        "image_rules": {
            "cover": {"width": 900, "height": 383, "ratio": "2.35:1"},
            "inline": {"width": 800, "height": 600, "ratio": "4:3"},
            "max_size_mb": 10,
        },
        "seo_tips": [
            "首段包含核心关键词",
            "使用小标题分段(适配搜一搜)",
@ -262,6 +272,11 @@ PLATFORM_RULES: dict[str, dict] = {
            "正文需含至少1张配图",
            "不得搬运/洗稿",
        ],
        "image_rules": {
            "cover": {"width": 600, "height": 400, "ratio": "3:2"},
            "inline": {"width": 800, "height": 600, "ratio": "4:3"},
            "max_size_mb": 5,
        },
        "seo_tips": [
            "标题包含百度搜索热词",
            "文章结构化H2小标题",
@ -341,6 +356,11 @@ PLATFORM_RULES: dict[str, dict] = {
            "首发原创优先推荐",
            "配图清晰不模糊",
        ],
        "image_rules": {
            "cover": {"width": 1024, "height": 678, "ratio": "1.5:1"},
            "inline": {"width": 800, "height": 600, "ratio": "4:3"},
            "max_size_mb": 5,
        },
        "seo_tips": [
            "标题含核心关键词",
            "文章1500字以上推荐更高",
@ -417,6 +437,11 @@ PLATFORM_RULES: dict[str, dict] = {
            "话题标签有助于曝光",
            "配图有助于转发",
        ],
        "image_rules": {
            "cover": {"width": 980, "height": 560, "ratio": "1.75:1"},
            "inline": {"width": 800, "height": 600, "ratio": "4:3"},
            "max_size_mb": 5,
        },
        "seo_tips": [
            "热门话题可增加曝光",
            "短句更易阅读",
@ -493,6 +518,11 @@ PLATFORM_RULES: dict[str, dict] = {
            "不得出现其他平台引流信息",
            "图片不含水印",
        ],
        "image_rules": {
            "cover": {"width": 1080, "height": 1080, "ratio": "1:1"},
            "inline": {"width": 1242, "height": 1660, "ratio": "3:4"},
            "max_size_mb": 10,
        },
        "seo_tips": [
            "标题含数字更吸引点击",
            "正文用短句+emoji分段",
@ -572,6 +602,11 @@ PLATFORM_RULES: dict[str, dict] = {
            "封面和标题很重要",
            "互动有助于推荐",
        ],
        "image_rules": {
            "cover": {"width": 1920, "height": 1080, "ratio": "16:9"},
            "inline": {"width": 800, "height": 600, "ratio": "4:3"},
            "max_size_mb": 5,
        },
        "seo_tips": [
            "标题包含关键词",
            "封面吸引人",
@ -646,6 +681,11 @@ PLATFORM_RULES: dict[str, dict] = {
            "文艺风格更受欢迎",
            "配图有助于阅读",
        ],
        "image_rules": {
            "cover": {"width": 800, "height": 600, "ratio": "4:3"},
            "inline": {"width": 800, "height": 600, "ratio": "4:3"},
            "max_size_mb": 5,
        },
        "seo_tips": [
            "标题包含关键词",
            "合理使用专题",
@ -721,6 +761,11 @@ PLATFORM_RULES: dict[str, dict] = {
            "鼓励原创技术文章",
            "禁止低质量搬运",
        ],
        "image_rules": {
            "cover": {"width": 1024, "height": 768, "ratio": "4:3"},
            "inline": {"width": 800, "height": 600, "ratio": "4:3"},
            "max_size_mb": 5,
        },
        "seo_tips": [
            "标题包含技术关键词",
            "代码块有助于阅读",
@ -796,6 +841,11 @@ PLATFORM_RULES: dict[str, dict] = {
            "话题标签2-5个",
            "文案简短有吸引力",
        ],
        "image_rules": {
            "cover": {"width": 1080, "height": 1920, "ratio": "9:16"},
            "inline": {"width": 1080, "height": 1920, "ratio": "9:16"},
            "max_size_mb": 10,
        },
        "seo_tips": [
            "前3秒决定完播率",
            "标题含热点关键词",


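@ -0,0 +1,382 @@
"""
Email notification service.

Sends alert notifications, quota warnings, and similar emails.

Features:
- Template engine: renders email content by variable substitution
- Content generation: builds alert-notification and quota-warning emails
- Sending: supports real SMTP and a simulation mode
- Queue management: adds messages in bulk and sends them
- Error handling and retries: automatic retry mechanism
"""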
from __future__ import annotations
import logging
import re
import smtplib
import time
import uuid
from dataclasses import dataclass, field
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.mime.base import MIMEBase
from email import encoders
from typing import Any
logger = logging.getLogger(__name__)
EMAIL_TEMPLATES = {
    "alert_notification": {
        "subject": "[GEO平台] 告警通知:{alert_type}",
        "body_html": """
<h2>告警通知</h2>
<p>品牌:<strong>{brand_name}</strong></p>
<p>告警类型:{alert_type}</p>
<p>严重程度:{severity}</p>
<p>详情:{description}</p>
<p>时间:{timestamp}</p>
""",
        "body_text": "告警通知 - 品牌:{brand_name}, 类型:{alert_type}, 严重程度:{severity}"
    },
    "quota_warning": {
        "subject": "[GEO平台] 额度预警:{quota_type}",
        "body_html": """
<h2>额度预警</h2>
<p>您的{quota_type}使用量已达到{usage_percentage}%</p>
<p>已使用:{used} / 总额度:{limit}</p>
<p>建议操作:{recommended_action}</p>
""",
        "body_text": "额度预警 - {quota_type}使用量:{usage_percentage}%"
    }
}
@dataclass
class EmailMessage:
to: str
subject: str
body_html: str
body_text: str
attachments: list[dict] = field(default_factory=list)
metadata: dict = field(default_factory=dict)
@dataclass
class EmailSendResult:
success: bool
message_id: str | None
error: str | None
retry_count: int
class EmailService:
"""邮件通知服务
提供邮件模板渲染内容生成发送支持模拟模式队列管理等功能
"""
def __init__(
self,
simulate_mode: bool = True,
smtp_host: str = "localhost",
smtp_port: int = 587,
smtp_user: str = "",
smtp_password: str = "",
max_retries: int = 3,
):
self.simulate_mode = simulate_mode
self.smtp_host = smtp_host
self.smtp_port = smtp_port
self.smtp_user = smtp_user
self.smtp_password = smtp_password
self.max_retries = max_retries
self._queue: list[EmailMessage] = []
def validate_email(self, email: str) -> bool:
"""验证邮箱地址格式
Args:
email: 邮箱地址
Returns:
是否为有效邮箱地址
"""
if not email:
return False
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
return bool(re.match(pattern, email))
def _safe_format(self, template: str, variables: dict[str, Any]) -> str:
"""安全格式化模板,缺失变量时保留原占位符
Args:
template: 模板字符串
variables: 变量字典
Returns:
格式化后的字符串
"""
try:
return template.format(**variables)
except KeyError:
result = template
for key, value in variables.items():
result = result.replace("{" + key + "}", str(value))
return result
def render_template(
self,
template_name: str,
to: str,
variables: dict[str, Any],
) -> EmailMessage:
"""渲染邮件模板
Args:
template_name: 模板名称
to: 收件人邮箱
variables: 模板变量
Returns:
EmailMessage邮件消息对象
Raises:
ValueError: 模板不存在
"""
if template_name not in EMAIL_TEMPLATES:
raise ValueError(f"模板不存在: {template_name}")
template = EMAIL_TEMPLATES[template_name]
subject = self._safe_format(template["subject"], variables)
body_html = self._safe_format(template["body_html"], variables)
body_text = self._safe_format(template["body_text"], variables)
return EmailMessage(
to=to,
subject=subject,
body_html=body_html,
body_text=body_text,
metadata=variables,
)
def generate_alert_email(
self,
to: str,
alert_type: str,
brand_name: str,
severity: str,
description: str,
timestamp: str,
) -> EmailMessage:
"""生成告警通知邮件
Args:
to: 收件人邮箱
alert_type: 告警类型
brand_name: 品牌名称
severity: 严重程度
description: 告警详情
timestamp: 时间戳
Returns:
EmailMessage邮件消息对象
"""
variables = {
"alert_type": alert_type,
"brand_name": brand_name,
"severity": severity,
"description": description,
"timestamp": timestamp,
}
return self.render_template("alert_notification", to, variables)
def generate_quota_warning_email(
self,
to: str,
quota_type: str,
usage_percentage: int,
used: int,
limit: int,
recommended_action: str,
) -> EmailMessage:
"""生成额度预警邮件
Args:
to: 收件人邮箱
quota_type: 额度类型
usage_percentage: 使用百分比
used: 已使用量
limit: 总额度
recommended_action: 建议操作
Returns:
EmailMessage邮件消息对象
"""
variables = {
"quota_type": quota_type,
"usage_percentage": usage_percentage,
"used": used,
"limit": limit,
"recommended_action": recommended_action,
}
return self.render_template("quota_warning", to, variables)
def add_attachment(self, msg: EmailMessage, filename: str, content: bytes) -> None:
"""添加附件到邮件
Args:
msg: 邮件消息对象
filename: 附件文件名
content: 附件内容
"""
msg.attachments.append({
"filename": filename,
"content": content,
})
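    def _create_mime_message(self, msg: EmailMessage) -> MIMEMultipart:
        """Build the MIME message.

        Args:
            msg: the EmailMessage to convert

        Returns:
            A multipart/mixed message whose first part is a
            multipart/alternative body (plain text, then HTML).
        """
        mime_msg = MIMEMultipart("mixed")
        mime_msg["From"] = self.smtp_user
        mime_msg["To"] = msg.to
        mime_msg["Subject"] = msg.subject

        # The two bodies are alternatives, not separate parts. Clients
        # render the last alternative they support, so the plain-text
        # fallback goes first and the HTML version last.
        body = MIMEMultipart("alternative")
        body.attach(MIMEText(msg.body_text, "plain", "utf-8"))
        body.attach(MIMEText(msg.body_html, "html", "utf-8"))
        mime_msg.attach(body)

        for attachment in msg.attachments:
            part = MIMEBase("application", "octet-stream")
            part.set_payload(attachment["content"])
            encoders.encode_base64(part)
            # Passing filename as a parameter lets the email package
            # quote or encode it as needed.
            part.add_header(
                "Content-Disposition",
                "attachment",
                filename=attachment["filename"],
            )
            mime_msg.attach(part)
        return mime_msg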
def send_email(self, msg: EmailMessage) -> EmailSendResult:
"""发送邮件
Args:
msg: EmailMessage邮件消息对象
Returns:
EmailSendResult发送结果对象
"""
if not self.validate_email(msg.to):
logger.warning(f"无效的邮箱地址: {msg.to}")
return EmailSendResult(
success=False,
message_id=None,
error=f"无效的邮箱地址: {msg.to}",
retry_count=0,
)
if self.simulate_mode:
message_id = f"sim_{uuid.uuid4().hex[:8]}"
logger.info(f"[模拟模式] 邮件发送成功: {msg.to}, ID: {message_id}")
return EmailSendResult(
success=True,
message_id=message_id,
error=None,
retry_count=0,
)
        retry_count = 0
        last_error = None
        total_attempts = self.max_retries + 1
        while retry_count <= self.max_retries:
            attempt = retry_count + 1
            try:
                logger.info(f"尝试发送邮件到 {msg.to} (尝试 {attempt}/{total_attempts})")
                # The context manager closes the connection even when
                # STARTTLS, login, or sending fails.
                with smtplib.SMTP(self.smtp_host, self.smtp_port) as server:
                    server.starttls()
                    server.login(self.smtp_user, self.smtp_password)
                    server.send_message(self._create_mime_message(msg))
                message_id = f"smtp_{uuid.uuid4().hex[:8]}"
                logger.info(f"邮件发送成功: {msg.to}, ID: {message_id}")
                return EmailSendResult(
                    success=True,
                    message_id=message_id,
                    error=None,
                    retry_count=retry_count,
                )
            except Exception as e:
                # SMTP errors keep their prefix; anything else is reported as-is.
                prefix = "SMTP错误: " if isinstance(e, smtplib.SMTPException) else ""
                last_error = f"{prefix}{e}"
                logger.error(f"邮件发送失败 (尝试 {attempt}/{total_attempts}): {last_error}")
                retry_count += 1
                if retry_count <= self.max_retries:
                    # Linear backoff: 1s, 2s, 3s, ...
                    time.sleep(retry_count)
        # At this point retry_count equals the number of attempts made.
        logger.error(f"邮件发送最终失败,共尝试 {retry_count} 次: {msg.to}, 错误: {last_error}")
        return EmailSendResult(
            success=False,
            message_id=None,
            error=last_error,
            retry_count=retry_count,
        )
def add_to_queue(self, msg: EmailMessage) -> None:
"""添加邮件到队列
Args:
msg: EmailMessage邮件消息对象
"""
self._queue.append(msg)
logger.debug(f"邮件已添加到队列: {msg.to}, 队列长度: {len(self._queue)}")
def get_queue(self) -> list[EmailMessage]:
"""获取队列中的邮件
Returns:
邮件消息列表
"""
return self._queue.copy()
def clear_queue(self) -> None:
"""清空队列"""
count = len(self._queue)
self._queue.clear()
logger.info(f"队列已清空,移除了 {count} 封邮件")
def send_queue(self) -> list[EmailSendResult]:
"""发送队列中的所有邮件
Returns:
发送结果列表
"""
results = []
messages = self._queue.copy()
self._queue.clear()
logger.info(f"开始发送队列中的 {len(messages)} 封邮件")
for msg in messages:
result = self.send_email(msg)
results.append(result)
success_count = sum(1 for r in results if r.success)
logger.info(f"队列发送完成: 成功 {success_count}/{len(results)}")
return results
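The MIME assembly in `_create_mime_message` can also be written with the standard library's newer `email.message.EmailMessage` API, which builds the plain/HTML alternative and the attachment container itself. The sketch below is a standalone illustration: `build_message` is a hypothetical helper, and the class is imported under the alias `StdEmailMessage` only to avoid confusion with this module's own `EmailMessage` dataclass.

```python
from email.message import EmailMessage as StdEmailMessage


def build_message(sender, to, subject, text, html, attachments=()):
    """Build a message with plain-text and HTML alternatives plus attachments.

    `attachments` is an iterable of (filename, bytes) pairs.
    """
    msg = StdEmailMessage()
    msg["From"] = sender
    msg["To"] = to
    msg["Subject"] = subject
    msg.set_content(text)                      # text/plain fallback
    msg.add_alternative(html, subtype="html")  # richer alternative, listed last
    for filename, content in attachments:
        # Adding an attachment wraps the existing body in multipart/mixed.
        msg.add_attachment(
            content,
            maintype="application",
            subtype="octet-stream",
            filename=filename,
        )
    return msg
```

A message built this way can be passed to `smtplib.SMTP.send_message` in the same way as the `MIMEMultipart` object.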

File diff suppressed because it is too large


@ -0,0 +1,186 @@
"""详细健康检查服务"""
import time
from dataclasses import dataclass
from typing import Optional
import redis.asyncio as aioredis
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
@dataclass
class HealthCheckResult:
"""健康检查结果"""
name: str
healthy: bool
latency_ms: Optional[float] = None
message: Optional[str] = None
details: Optional[dict] = None
class HealthChecker:
"""健康检查服务"""
def __init__(self, db: AsyncSession, redis_url: str):
self.db = db
self.redis_url = redis_url
async def check_database(self) -> HealthCheckResult:
"""检查数据库连接"""
start = time.perf_counter()
try:
await self.db.execute(text("SELECT 1"))
latency = (time.perf_counter() - start) * 1000
return HealthCheckResult(
name="database",
healthy=True,
latency_ms=round(latency, 2),
message="Connection OK",
)
except Exception as e:
latency = (time.perf_counter() - start) * 1000
return HealthCheckResult(
name="database",
healthy=False,
latency_ms=round(latency, 2),
message=f"Connection failed: {str(e)}",
)
    async def check_redis(self) -> HealthCheckResult:
        """Check the Redis connection."""
        start = time.perf_counter()
        redis = None
        try:
            redis = aioredis.from_url(
                self.redis_url,
                socket_connect_timeout=2,
            )
            await redis.ping()
            latency = (time.perf_counter() - start) * 1000
            return HealthCheckResult(
                name="redis",
                healthy=True,
                latency_ms=round(latency, 2),
                message="Connection OK",
            )
        except Exception as e:
            latency = (time.perf_counter() - start) * 1000
            return HealthCheckResult(
                name="redis",
                healthy=False,
                latency_ms=round(latency, 2),
                message=f"Connection failed: {str(e)}",
            )
        finally:
            # Release the client even when the ping fails.
            if redis is not None:
                await redis.aclose()
async def check_llm_providers(self) -> HealthCheckResult:
"""检查LLM服务提供商"""
from app.config import settings
from app.services.llm.factory import LLMFactory
providers = {}
all_healthy = True
# 检查默认provider
try:
provider_name = getattr(settings, 'DEFAULT_LLM_PROVIDER', 'openai')
provider = LLMFactory.create(provider_name)
providers[provider_name] = {
"healthy": True,
"available": True,
}
except Exception as e:
providers[getattr(settings, 'DEFAULT_LLM_PROVIDER', 'openai')] = {
"healthy": False,
"error": str(e),
}
all_healthy = False
# 检查所有已注册的provider
for name in LLMFactory.list_providers():
if name not in providers:
try:
provider = LLMFactory.create(name)
providers[name] = {
"healthy": True,
"available": True,
}
except Exception as e:
providers[name] = {
"healthy": False,
"error": str(e),
}
all_healthy = False
return HealthCheckResult(
name="llm_providers",
healthy=all_healthy,
message="All providers healthy" if all_healthy else "Some providers unhealthy",
details={"providers": providers},
)
    async def check_storage(self) -> HealthCheckResult:
        """Check storage (local filesystem)."""
        import os

        storage_path = "/data/documents"
        try:
            if not os.path.exists(storage_path):
                # A missing directory is tolerated, but this check does not
                # create it, so report it as absent rather than created.
                return HealthCheckResult(
                    name="storage",
                    healthy=True,
                    message=f"Storage path {storage_path} does not exist yet",
                    details={"path": storage_path, "exists": False},
                )
            # Verify read/write access with a throwaway file.
            test_file = os.path.join(storage_path, ".health_check")
            with open(test_file, "w") as f:
                f.write("ok")
            os.remove(test_file)
            return HealthCheckResult(
                name="storage",
                healthy=True,
                message=f"Storage path {storage_path} is writable",
                details={"path": storage_path, "exists": True},
            )
        except Exception as e:
            return HealthCheckResult(
                name="storage",
                healthy=False,
                message=f"Storage check failed: {str(e)}",
            )
async def check_all(self) -> dict:
"""执行所有健康检查"""
import asyncio
# 并行执行所有检查
checks = [
self.check_database(),
self.check_redis(),
self.check_llm_providers(),
self.check_storage(),
]
results = await asyncio.gather(*checks)
# 汇总结果
all_healthy = all(r.healthy for r in results)
return {
"status": "healthy" if all_healthy else "degraded",
"timestamp": time.time(),
"checks": {
r.name: {
"healthy": r.healthy,
"latency_ms": r.latency_ms,
"message": r.message,
"details": r.details,
}
for r in results
},
}
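`check_all` runs its checks concurrently with `asyncio.gather`. Below is a minimal standalone sketch of the same aggregation with a per-check timeout, so that one slow dependency cannot stall the whole report. `run_checks` is a hypothetical helper, and the convention that a check signals failure by raising is an assumption made for this example.

```python
import asyncio


async def run_checks(checks: dict, timeout: float = 2.0) -> dict:
    """Run named async checks concurrently and summarize the outcome.

    `checks` maps a name to a zero-argument coroutine function that
    returns normally when healthy and raises when unhealthy.
    """

    async def guarded(name, check):
        try:
            await asyncio.wait_for(check(), timeout)
            return name, {"healthy": True, "message": "ok"}
        except asyncio.TimeoutError:
            return name, {"healthy": False, "message": f"timed out after {timeout}s"}
        except Exception as exc:
            # Report the failure instead of aborting the other checks.
            return name, {"healthy": False, "message": str(exc)}

    pairs = await asyncio.gather(*(guarded(n, c) for n, c in checks.items()))
    results = dict(pairs)
    all_healthy = all(r["healthy"] for r in results.values())
    return {"status": "healthy" if all_healthy else "degraded", "checks": results}
```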


@ -0,0 +1,312 @@
"""阿里云百炼图片生成服务"""
import os
from dataclasses import dataclass
from typing import Optional
import httpx
# 平台尺寸适配
PLATFORM_IMAGE_SPECS = {
"zhihu": {
"cover": {"width": 690, "height": 280, "ratio": "2.5:1"},
"inline": {"width": 500, "height": 375, "ratio": "4:3"},
},
"wechat": {
"cover": {"width": 900, "height": 383, "ratio": "2.35:1"},
"inline": {"width": 800, "height": 600, "ratio": "4:3"},
},
"xiaohongshu": {
"cover": {"width": 1080, "height": 1080, "ratio": "1:1"}, # 方版
"inline": {"width": 1242, "height": 1660, "ratio": "3:4"}, # 竖版
},
"toutiao": {
"cover": {"width": 1024, "height": 678, "ratio": "1.5:1"},
},
"baijiahao": {
"cover": {"width": 600, "height": 400, "ratio": "3:2"},
},
"weibo": {
"cover": {"width": 980, "height": 560, "ratio": "1.75:1"},
},
"bilibili": {
"cover": {"width": 1920, "height": 1080, "ratio": "16:9"},
},
"jianshu": {
"cover": {"width": 800, "height": 600, "ratio": "4:3"},
},
"juejin": {
"cover": {"width": 1024, "height": 768, "ratio": "4:3"},
},
"douyin": {
"cover": {"width": 1080, "height": 1920, "ratio": "9:16"}, # 竖版短视频封面
},
}
# 风格选项
IMAGE_STYLES = {
"modern": {"name": "现代简约", "prompt": "modern minimalist style, clean design, professional"},
"tech": {"name": "科技感", "prompt": "tech style, futuristic, digital, blue tones"},
"elegant": {"name": "优雅商务", "prompt": "elegant business style, sophisticated, premium"},
"creative": {"name": "创意活力", "prompt": "creative vibrant style, colorful, dynamic"},
"minimal": {"name": "极简主义", "prompt": "ultra minimal, white space, typography focus"},
}
# 排版选项
LAYOUT_OPTIONS = {
"centered": {"name": "居中排版", "prompt": "centered composition, text in middle"},
"left_text": {"name": "左文右图", "prompt": "left side text, right side visual"},
"top_text": {"name": "上文下图", "prompt": "text on top, visual below"},
"text_overlay": {"name": "文字叠加", "prompt": "text overlay on background image"},
}
@dataclass
class ImageResult:
"""图片生成结果"""
url: str
width: int
height: int
prompt: str
platform: str
task_id: str
class ImageGenerationError(Exception):
"""图片生成异常"""
pass
class ImageGenerator:
"""阿里云百炼图片生成服务(万相-文生图V1"""
def __init__(self):
self.api_key = os.getenv("ALIYUN_DASHSCOPE_API_KEY")
self.base_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text2image/image-synthesis"
self.timeout = 120.0 # 异步任务等待超时时间(秒)
async def generate_cover(
self,
title: str,
platform: str,
image_type: str = "cover",
style: str = "modern",
layout: str = "centered",
custom_prompt: Optional[str] = None,
) -> ImageResult:
"""生成封面图
Args:
title: 文章标题
platform: 目标平台
image_type: 图片类型 (cover/inline)
style: 风格选项
layout: 排版选项
custom_prompt: 自定义提示词可选
Returns:
ImageResult: 包含生成结果的 dataclass
"""
# 1. 获取平台尺寸
specs = PLATFORM_IMAGE_SPECS.get(platform, PLATFORM_IMAGE_SPECS["zhihu"])
size_spec = specs.get(image_type, specs["cover"])
# 2. 构建提示词
if custom_prompt:
prompt = custom_prompt
else:
prompt = self._build_prompt(title, platform, style, layout)
# 3. 调用百炼API异步
task_id = await self._create_task(prompt, size_spec)
# 4. 轮询结果
result = await self._wait_for_result(task_id)
return ImageResult(
url=result["image_url"],
width=size_spec["width"],
height=size_spec["height"],
prompt=prompt,
platform=platform,
task_id=task_id,
)
def _build_prompt(self, title: str, platform: str, style: str, layout: str) -> str:
"""构建AI提示词
Args:
title: 文章标题
platform: 目标平台
style: 风格选项
layout: 排版选项
Returns:
str: 构造的英文提示词
"""
style_prompt = IMAGE_STYLES.get(style, IMAGE_STYLES["modern"])["prompt"]
layout_prompt = LAYOUT_OPTIONS.get(layout, LAYOUT_OPTIONS["centered"])["prompt"]
# 平台特定要求
platform_notes = {
"xiaohongshu": "warm tones, lifestyle photography",
"wechat": "professional, clean, suitable for WeChat article cover",
"zhihu": "intellectual, professional, suitable for long-form content",
"toutiao": "eye-catching, clear hierarchy, news style",
"baijiahao": "clear, professional, suitable for news media",
"weibo": "social media friendly, engaging",
"bilibili": "anime style friendly, vibrant, suitable for video platform",
"jianshu": "literary, elegant, clean layout",
"juejin": "tech blog style, developer friendly",
"douyin": "vertical video style, eye-catching, short form content",
}
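        platform_note = platform_notes.get(platform, "")
        # Skip empty segments so an unknown platform does not leave ", ," behind.
        parts = [title, style_prompt, layout_prompt, platform_note, "high quality, 4K"]
        return ", ".join(p for p in parts if p)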
async def _create_task(self, prompt: str, size_spec: dict) -> str:
"""创建异步任务
Args:
prompt: 英文提示词
size_spec: 尺寸规格
Returns:
str: 任务ID
Raises:
ImageGenerationError: API调用失败时
"""
if not self.api_key:
raise ImageGenerationError("ALIYUN_DASHSCOPE_API_KEY 未设置")
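        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            # DashScope returns a task_id only when the request is
            # submitted in asynchronous mode.
            "X-DashScope-Async": "enable",
        }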
payload = {
"model": "wanx-v1", # 万相-文生图V1
"input": {
"prompt": prompt,
},
"parameters": {
"size": f"{size_spec['width']}x{size_spec['height']}",
"n": 1,
},
}
async with httpx.AsyncClient(timeout=30.0) as client:
try:
response = await client.post(
self.base_url,
headers=headers,
json=payload,
)
response.raise_for_status()
data = response.json()
if data.get("code"):
raise ImageGenerationError(f"API错误: {data.get('message', data.get('code'))}")
# 返回任务ID
task_id = data.get("output", {}).get("task_id")
if not task_id:
raise ImageGenerationError("未获取到任务ID")
return task_id
except httpx.HTTPStatusError as e:
raise ImageGenerationError(f"HTTP错误: {e.response.status_code}")
except httpx.RequestError as e:
raise ImageGenerationError(f"请求错误: {e}")
    async def _wait_for_result(self, task_id: str) -> dict:
        """Poll until the task finishes.

        Args:
            task_id: task ID

        Returns:
            dict: result containing ``image_url``

        Raises:
            ImageGenerationError: the task failed or timed out
        """
        import asyncio

        # DashScope documents task queries under /api/v1/tasks/{task_id}
        # rather than under the synthesis endpoint itself.
        status_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        poll_interval = 2.0  # seconds between polls
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.timeout

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            while True:
                try:
                    response = await client.get(status_url, headers=headers)
                    response.raise_for_status()
                except httpx.HTTPStatusError as e:
                    raise ImageGenerationError(f"HTTP错误: {e.response.status_code}")
                except httpx.RequestError as e:
                    raise ImageGenerationError(f"请求错误: {e}")

                output = response.json().get("output", {})
                # Compare case-insensitively: the API reports upper-case
                # states such as PENDING, RUNNING, SUCCEEDED, and FAILED.
                status = str(output.get("task_status", "")).upper()
                if status == "SUCCEEDED":
                    images = output.get("results", [])
                    if images and images[0].get("url"):
                        return {"image_url": images[0]["url"]}
                    raise ImageGenerationError("未获取到生成图片URL")
                if status == "FAILED":
                    error_msg = output.get("message", "任务失败")
                    raise ImageGenerationError(f"图片生成失败: {error_msg}")

                # Any other state means the task is still in progress.
                if loop.time() >= deadline:
                    raise ImageGenerationError("图片生成超时")
                await asyncio.sleep(poll_interval)
@staticmethod
def get_platform_specs(platform: str) -> Optional[dict]:
"""获取平台的图片规格
Args:
platform: 平台标识
Returns:
Optional[dict]: 平台图片规格如果没有则返回None
"""
return PLATFORM_IMAGE_SPECS.get(platform)
@staticmethod
def get_supported_platforms() -> list[str]:
"""获取支持的平台列表"""
return list(PLATFORM_IMAGE_SPECS.keys())
@staticmethod
def get_styles() -> dict:
"""获取所有风格选项"""
return IMAGE_STYLES
@staticmethod
def get_layouts() -> dict:
"""获取所有排版选项"""
return LAYOUT_OPTIONS
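The polling logic in `_wait_for_result` can be separated from the HTTP call, which makes it testable without network access. The sketch below is a standalone illustration under stated assumptions: `poll_until_done` is a hypothetical helper, `fetch_status` stands in for the real status request, and the status strings are compared case-insensitively because their exact casing is an assumption here.

```python
import asyncio
import time


async def poll_until_done(fetch_status, timeout: float = 120.0, interval: float = 2.0) -> dict:
    """Call `fetch_status()` until it reports a terminal state or time runs out.

    `fetch_status` is an async callable returning a dict with a
    "task_status" key. Raises RuntimeError on failure and TimeoutError
    when the deadline passes.
    """
    deadline = time.monotonic() + timeout
    while True:
        data = await fetch_status()
        status = str(data.get("task_status", "")).upper()
        if status == "SUCCEEDED":
            return data
        if status == "FAILED":
            raise RuntimeError(data.get("message", "task failed"))
        # Stop before sleeping if the next poll would land past the deadline.
        if time.monotonic() + interval > deadline:
            raise TimeoutError("task did not finish in time")
        await asyncio.sleep(interval)
```

A monotonic clock is used for the deadline so that system clock adjustments cannot shorten or extend the wait.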


@ -2,5 +2,21 @@ from .rag_service import RAGService
from .chunker import RecursiveChunker
from .embedder import EmbeddingService, OpenAIEmbedder, MockEmbedder
from .retriever import HybridRetriever
from .entity_extractor import EntityExtractor, ExtractionResult, ExtractedEntity, ExtractedRelation
from .graph_builder import GraphBuilder
from .graph_query import GraphQuery

__all__ = [
    "RAGService",
    "RecursiveChunker",
    "EmbeddingService",
    "OpenAIEmbedder",
    "MockEmbedder",
    "HybridRetriever",
    "EntityExtractor",
    "ExtractionResult",
    "ExtractedEntity",
    "ExtractedRelation",
    "GraphBuilder",
    "GraphQuery",
]


@ -1,178 +1,212 @@
 """
-RecursiveChunker: recursive semantic chunker
-Splits a document by separators in priority order (paragraph, sentence) into chunks suitable for embedding
+Chunking strategies - supports multiple chunking methods
 """
 import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Optional
 
-class RecursiveChunker:
-    """Recursive semantic chunker"""
-
-    def __init__(
-        self,
-        chunk_size: int = 512,
-        chunk_overlap: int = 50,
-        min_chunk_size: int = 100,
-    ):
-        self.chunk_size = chunk_size
-        self.chunk_overlap = chunk_overlap
-        self.min_chunk_size = min_chunk_size
-        # Separator priority: paragraph > sentence > word
-        self.separators = ["\n\n", "\n", "", ".", "", "!", "", "?", "", ";", " "]
-
-    # ------------------------------------------------------------------
-    # Public API
-    # ------------------------------------------------------------------
+@dataclass
+class ChunkStrategy:
+    """Chunking strategy configuration"""
+    name: str
+    description: str
+    chunk_size: int  # number of characters
+    chunk_overlap: int  # number of overlapping characters
+    min_chunk_size: int
+
+class BaseChunker(ABC):
+    """Base class for chunkers"""
+
+    STRATEGY: ChunkStrategy = None
+
+    @abstractmethod
     def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
-        """
-        Recursively split the text into chunks
-
-        Returns:
-            list of dicts:
-            {
-                "content": str,
-                "chunk_index": int,
-                "token_count": int,
-                "metadata": dict,
-            }
-        """
-        if not text or not text.strip():
-            return []
-
-        raw_chunks = self._split_recursive(text.strip(), self.separators)
-
-        # Merge chunks that are too short & add overlap
-        merged = self._merge_small_chunks(raw_chunks)
-        result = []
-        for idx, content in enumerate(merged):
-            result.append(
-                {
-                    "content": content,
-                    "chunk_index": idx,
-                    "token_count": self._estimate_tokens(content),
-                    "metadata": metadata or {},
-                }
-            )
-        return result
-
-    # ------------------------------------------------------------------
-    # Internal helpers
-    # ------------------------------------------------------------------
-    def _split_recursive(self, text: str, separators: list[str]) -> list[str]:
-        """
-        Split recursively: try the current separator; if a chunk is too large, continue with the next-level separator
-        """
-        if not separators:
-            # Last resort: force a character-level split
-            return self._hard_split(text)
-        sep = separators[0]
-        remaining_seps = separators[1:]
-        # Split with the current separator first
-        splits = text.split(sep)
-        # Drop empty strings, but keep the separator semantics (re-joined below)
-        splits = [s for s in splits if s.strip()]
-        if len(splits) <= 1:
-            # This separator cannot split the text; try the next level
-            return self._split_recursive(text, remaining_seps)
-        chunks: list[str] = []
-        current_buffer = ""
-        for piece in splits:
-            candidate = (current_buffer + sep + piece).strip() if current_buffer else piece.strip()
-            if self._estimate_tokens(candidate) <= self.chunk_size:
-                current_buffer = candidate
-            else:
-                # The current buffer has reached chunk_size
-                if current_buffer:
-                    # Is the buffer itself too large? If so, split it recursively
-                    if self._estimate_tokens(current_buffer) > self.chunk_size:
-                        chunks.extend(self._split_recursive(current_buffer, remaining_seps))
-                    else:
-                        chunks.append(current_buffer)
-                # Handle piece on its own
-                if self._estimate_tokens(piece) > self.chunk_size:
-                    chunks.extend(self._split_recursive(piece, remaining_seps))
-                    current_buffer = ""
-                else:
-                    current_buffer = piece.strip()
-        if current_buffer:
-            if self._estimate_tokens(current_buffer) > self.chunk_size:
-                chunks.extend(self._split_recursive(current_buffer, remaining_seps))
-            else:
-                chunks.append(current_buffer)
-        return [c for c in chunks if c.strip()]
-
-    def _merge_small_chunks(self, chunks: list[str]) -> list[str]:
-        """
-        Merge chunks that are too short (< min_chunk_size tokens) and add overlapping text between adjacent chunks
-        """
-        if not chunks:
-            return []
-        merged: list[str] = []
-        buffer = chunks[0]
-        for chunk in chunks[1:]:
-            if self._estimate_tokens(buffer) < self.min_chunk_size:
-                buffer = buffer + "\n" + chunk
-            else:
-                merged.append(buffer)
-                # Add overlap: use the last few tokens of the previous chunk as a prefix
-                overlap_text = self._get_overlap_prefix(buffer)
-                buffer = (overlap_text + "\n" + chunk).strip() if overlap_text else chunk
-        merged.append(buffer)
-        return [c for c in merged if c.strip()]
-
-    def _get_overlap_prefix(self, text: str) -> str:
-        """Take the tail of the text as the overlap prefix for the next chunk (estimated by tokens)."""
-        if self.chunk_overlap <= 0:
-            return ""
-        # Simple implementation: cut proportionally by characters
-        words = text.split()
-        if not words:
-            return ""
-        # Estimate about 1.5 tokens per word (mixed Chinese and English)
-        token_per_word = 1.5
-        overlap_words = int(self.chunk_overlap / token_per_word)
-        overlap_words = max(1, min(overlap_words, len(words)))
-        return " ".join(words[-overlap_words:])
-
-    def _hard_split(self, text: str) -> list[str]:
-        """Force a character-level split (last resort)."""
-        # Rough estimate: 1 token ≈ 2 characters (Chinese)
-        char_limit = self.chunk_size * 2
+        """Perform chunking"""
+        pass
+
+    def preview(self, text: str, max_chunks: int = 5) -> list[str]:
+        """Preview the chunking result"""
+        chunks = self.chunk(text)
+        return [c["content"][:200] + "..." if len(c["content"]) > 200 else c["content"]
+                for c in chunks[:max_chunks]]
+
+class RecursiveChunker(BaseChunker):
+    """Recursive chunker (existing implementation)"""
+
+    STRATEGY = ChunkStrategy(
+        name="recursive",
+        description="优先按段落分割,过长时按句子分割",
+        chunk_size=500,
+        chunk_overlap=50,
+        min_chunk_size=50,
+    )
+
+    # Split patterns (in priority order)
+    SEPARATORS = [
+        r"\n\n+",  # double newline (paragraph)
+        r"\n",  # single newline
+        r"[。!?!?]\s*",  # end of sentence
+        r"[,;]\s*",  # clause
+        r"\s+",  # whitespace
+    ]
+
+    def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
         chunks = []
-        start = 0
-        while start < len(text):
-            end = start + char_limit
-            chunks.append(text[start:end])
-            start = end - self.chunk_overlap * 2  # add character-level overlap
-            if start <= 0:
-                start = end
+        metadata = metadata or {}
+        # Split by paragraph
+        segments = re.split(r"\n\n+", text)
+        current_chunk = ""
+        for segment in segments:
+            if len(current_chunk) + len(segment) <= self.STRATEGY.chunk_size:
+                current_chunk += segment + "\n\n"
+            else:
+                # The current chunk is large enough; save it
+                if len(current_chunk.strip()) >= self.STRATEGY.min_chunk_size:
+                    chunks.append({
+                        "content": current_chunk.strip(),
+                        "chunk_index": len(chunks),
+                        "metadata": metadata,
+                    })
+                # Handle overlong paragraphs
+                if len(segment) > self.STRATEGY.chunk_size:
+                    current_chunk = segment
+                else:
+                    # Keep the overlap
+                    overlap = current_chunk[-self.STRATEGY.chunk_overlap:]
+                    current_chunk = overlap + segment + "\n\n"
+        # Handle the last chunk
+        if len(current_chunk.strip()) >= self.STRATEGY.min_chunk_size:
+            chunks.append({
+                "content": current_chunk.strip(),
+                "chunk_index": len(chunks),
+                "metadata": metadata,
+            })
         return chunks
 
-    def _estimate_tokens(self, text: str) -> int:
-        """
-        Estimate the token count
-        Rule: each Chinese character counts as 1 token; each English word counts as 1.3 tokens (BPE fragmentation factor)
-        """
-        if not text:
-            return 0
-        # Count Chinese characters
-        chinese_chars = len(re.findall(r"[\u4e00-\u9fff\u3400-\u4dbf]", text))
-        # After removing Chinese characters, count English words
-        non_chinese = re.sub(r"[\u4e00-\u9fff\u3400-\u4dbf]", " ", text)
-        english_words = len(non_chinese.split())
-        return int(chinese_chars + english_words * 1.3)
+class SemanticChunker(BaseChunker):
+    """Semantic chunker - splits on semantic boundaries"""
+
+    STRATEGY = ChunkStrategy(
+        name="semantic",
+        description="根据语义边界(标题、段落)自动分块",
+        chunk_size=800,
+        chunk_overlap=100,
+        min_chunk_size=100,
+    )
+
+    # Semantic boundary patterns
+    SEMANTIC_PATTERNS = [
+        (r"^#{1,6}\s+(.+)$", "heading"),  # Markdown heading
+        (r"^【(.+?)】\s*$", "heading"),  # Chinese bracketed heading
+        (r"^第[一二三四五六七八九十百]+[章节条]", "heading"),  # chapter/section heading
+        (r"^(\d+\.)+\s+", "heading"),  # numbered heading
+    ]
+
+    def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
+        chunks = []
+        metadata = metadata or {}
+        lines = text.split("\n")
+        current_chunk = ""
+        current_section = None
+        for line in lines:
+            # Check whether this line is a semantic boundary
+            is_boundary = False
+            for pattern, boundary_type in self.SEMANTIC_PATTERNS:
+                if re.match(pattern, line.strip()):
+                    is_boundary = True
+                    current_section = line.strip()
+                    break
+            # If it is a boundary and the current chunk is not empty, save it
+            if is_boundary and current_chunk.strip():
+                chunks.append({
+                    "content": current_chunk.strip(),
+                    "chunk_index": len(chunks),
+                    "section": current_section,
+                    "metadata": metadata,
+                })
+                # Keep the overlap
+                overlap = current_chunk[-self.STRATEGY.chunk_overlap:]
+                current_chunk = overlap + line + "\n"
+            else:
+                current_chunk += line + "\n"
+            # Check the chunk size
+            if len(current_chunk) >= self.STRATEGY.chunk_size:
+                chunks.append({
+                    "content": current_chunk.strip(),
+                    "chunk_index": len(chunks),
+                    "section": current_section,
+                    "metadata": metadata,
+                })
+                overlap = current_chunk[-self.STRATEGY.chunk_overlap:]
+                current_chunk = overlap
+        # Handle the last chunk
+        if current_chunk.strip():
+            chunks.append({
+                "content": current_chunk.strip(),
+                "chunk_index": len(chunks),
+                "section": current_section,
+                "metadata": metadata,
+            })
+        return chunks
+
+class FixedLengthChunker(BaseChunker):
+    """Fixed-length chunker"""
+
+    STRATEGY = ChunkStrategy(
+        name="fixed",
+        description="按固定长度强制分块",
+        chunk_size=300,
+        chunk_overlap=30,
+        min_chunk_size=50,
+    )
+
+    def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
+        chunks = []
+        metadata = metadata or {}
+        # Collapse redundant whitespace
+        text = re.sub(r"\s+", " ", text)
+        for i in range(0, len(text), self.STRATEGY.chunk_size - self.STRATEGY.chunk_overlap):
+            chunk_text = text[i:i + self.STRATEGY.chunk_size]
+            if len(chunk_text.strip()) >= self.STRATEGY.min_chunk_size:
+                chunks.append({
+                    "content": chunk_text.strip(),
+                    "chunk_index": len(chunks),
+                    "metadata": metadata,
+                })
+        return chunks
+
+class ChunkerFactory:
+    """Factory for chunking strategies"""
+
+    STRATEGIES = {
+        "recursive": RecursiveChunker,
+        "semantic": SemanticChunker,
+        "fixed": FixedLengthChunker,
+    }
+
+    @classmethod
+    def create(cls, strategy: str = "recursive") -> BaseChunker:
+        """Create a chunker"""
+        chunker_cls = cls.STRATEGIES.get(strategy, RecursiveChunker)
+        return chunker_cls()
+
+    @classmethod
+    def list_strategies(cls) -> list[ChunkStrategy]:
+        """List all strategies"""
+        return [chunker_cls.STRATEGY for chunker_cls in cls.STRATEGIES.values()]
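
The fixed-length strategy in this file slides a window of `chunk_size` characters forward by `chunk_size - chunk_overlap` each step. A minimal standalone sketch of that windowing follows; the function name `fixed_length_chunks` and the explicit guard against a non-positive step are illustrative assumptions, not part of the commit.

```python
def fixed_length_chunks(text: str, chunk_size: int = 300, overlap: int = 30) -> list[str]:
    """Split text into character windows of chunk_size that overlap by `overlap`.

    Hypothetical helper for illustration only; it mirrors the stepping logic
    of a fixed-length chunker but skips whitespace cleanup and size filtering.
    """
    if chunk_size <= overlap:
        # A non-positive step would make range() raise or loop incorrectly.
        raise ValueError("chunk_size must be larger than overlap")
    step = chunk_size - overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]


if __name__ == "__main__":
    sample = "".join(str(i % 10) for i in range(700))
    for index, piece in enumerate(fixed_length_chunks(sample)):
        print(index, len(piece))
```

Each window starts `step` characters after the previous one, so the last `overlap` characters of one chunk reappear at the start of the next.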

View File

@ -0,0 +1,149 @@
"""
增强版RAG服务 - 包含重排序和上下文压缩
"""
import re
from typing import Optional
from app.services.llm.factory import LLMFactory
class EnhancedRAG:
"""增强版RAG检索服务"""
def __init__(self, rag_service, embedder):
self.rag = rag_service
self.embedder = embedder
self.llm = LLMFactory.create()
async def retrieve_with_rerank(
self,
session,
query: str,
kb_ids: list[str],
top_k: int = 5,
use_rerank: bool = True,
use_compression: bool = False,
) -> list[dict]:
"""
增强检索流程
1. 初始检索扩大候选集
2. 可选重排序
3. 可选上下文压缩
"""
# Step 1: 初始检索
initial_k = top_k * 4 if use_rerank else top_k
candidates = await self.rag.search(
session, query, kb_ids, top_k=initial_k
)
if not candidates:
return []
# Step 2: 可选重排序
if use_rerank and len(candidates) > top_k:
candidates = await self._rerank(query, candidates, top_k)
# Step 3: 可选上下文压缩
if use_compression:
candidates = await self._compress(candidates, query)
return candidates[:top_k]
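    async def _rerank(
        self,
        query: str,
        candidates: list[dict],
        top_k: int,
    ) -> list[dict]:
        """
        Rerank candidates by relevance using the LLM.
        Scores each candidate against the query, then sorts by that score.
        The caller truncates the sorted list to top_k.
        """
        reranked = []
        for item in candidates:
            # Take a snippet of the candidate content
            content = item.get("content", "")[:500]  # limit length
            # Build the evaluation prompt
            prompt = f"""评估以下查询与文档片段的相关性。
查询{query}
文档片段
{content}
请只返回一个0到1之间的小数表示相关性分数0表示完全不相关1表示完全相关只返回数字"""
            try:
                response = await self.llm.generate(prompt)
                # Extract the first number; accepts "1", "1.0", "0.85" and ".5"
                # (the previous pattern read "1.0" as ".0", i.e. a score of 0)
                match = re.search(r"\d*\.\d+|\d+", response)
                relevance = float(match.group()) if match else 0.5
                # Clamp to the 0-1 range the prompt asks for
                relevance = min(1.0, max(0.0, relevance))
            except Exception:
                relevance = item.get("score", 0.5)
            item["relevance_score"] = relevance
            reranked.append(item)
        # Sort by relevance score, descending
        reranked.sort(key=lambda x: x["relevance_score"], reverse=True)
        return reranked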
    async def _compress(
        self,
        candidates: list[dict],
        query: str,
        max_context_tokens: int = 2000,
    ) -> list[dict]:
        """
        Context compression.
        Keeps chunks as-is while they fit in the token budget; the first chunk
        that does not fit is compressed against the query and the rest are dropped.
        """
        compressed = []
        current_tokens = 0
        for chunk in candidates:
            content = chunk.get("content", "")
            # Rough token estimate: about 2 characters per token
            est_tokens = len(content) // 2
            if current_tokens + est_tokens <= max_context_tokens:
                chunk["compressed"] = False
                compressed.append(chunk)
                current_tokens += est_tokens
            else:
                # Over budget: try to compress this chunk, then stop
                compressed_chunk = await self._compress_chunk(content, query)
                chunk["compressed"] = True
                chunk["compressed_content"] = compressed_chunk
                compressed.append(chunk)
                break
        return compressed
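    async def _compress_chunk(
        self,
        content: str,
        query: str,
    ) -> str:
        """Compress a single chunk, keeping the content relevant to the query"""
        prompt = f"""从以下文本中提取与问题最相关的部分,保持原文的表达方式,不要总结或改写。
问题{query}
原文
{content}
直接返回提取的内容不要解释"""
        try:
            result = await self.llm.generate(prompt)
            return result.strip()
        except Exception:
            # If compression fails, fall back to the original text,
            # truncated to 500 characters when it is longer than that
            return (content[:500] + "...") if len(content) > 500 else content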
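
The rerank step depends on pulling a numeric score out of free-form LLM output, which may come back as `1`, `1.0`, `.5`, or wrapped in extra words. The sketch below shows one defensive way to parse such a reply; the helper name `parse_relevance` and the default of 0.5 are assumptions chosen for illustration.

```python
import re


def parse_relevance(response: str, default: float = 0.5) -> float:
    """Return the first number found in `response`, clamped to [0, 1].

    Hypothetical helper: falls back to `default` when no number is present.
    """
    # Try a decimal such as "0.85" or ".5" first, then a bare integer like "1".
    match = re.search(r"\d*\.\d+|\d+", response)
    if not match:
        return default
    return min(1.0, max(0.0, float(match.group())))


if __name__ == "__main__":
    for reply in ["0.85", "1.0", "Score: 1", ".5", "not sure"]:
        print(repr(reply), parse_relevance(reply))
```

Clamping keeps an off-scale reply (for example a score out of 10) from dominating the sort order.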

View File

@ -0,0 +1,168 @@
"""实体和关系抽取服务"""
import re
import json
from typing import Optional
from dataclasses import dataclass, field
from app.services.llm.factory import LLMFactory
@dataclass
class ExtractedEntity:
"""抽取的实体"""
name: str
entity_type: str
description: Optional[str] = None
properties: dict = field(default_factory=dict)
@dataclass
class ExtractedRelation:
"""抽取的关系"""
source_entity: str
target_entity: str
relation_type: str
properties: dict = field(default_factory=dict)
@dataclass
class ExtractionResult:
"""抽取结果"""
entities: list[ExtractedEntity] = field(default_factory=list)
relations: list[ExtractedRelation] = field(default_factory=list)
class EntityExtractor:
"""实体和关系抽取服务"""
# 实体类型映射
ENTITY_TYPES = [
"ORGANIZATION", # 公司/组织
"PRODUCT", # 产品
"PERSON", # 人物
"LOCATION", # 地点
"TECHNOLOGY", # 技术
"BRAND", # 品牌
"EVENT", # 事件
"CONCEPT", # 概念
]
# 关系类型映射
RELATION_TYPES = [
"COMPETES_WITH", # 竞争对手
"PARTNERS_WITH", # 合作伙伴
"PRODUCES", # 生产
"USES_TECHNOLOGY", # 使用技术
"LOCATED_IN", # 位于
"FOUNDED_IN", # 成立于
"CEO_OF", # CEO
"FOUNDER_OF", # 创始人
"RELATED_TO", # 相关
"PART_OF", # 属于
]
def __init__(self):
self.llm = LLMFactory.create()
async def extract(self, text: str, context: Optional[str] = None) -> ExtractionResult:
"""
从文本中抽取实体和关系
Args:
text: 待处理的文本
context: 可选的上下文信息如品牌名行业等
Returns:
ExtractionResult: 包含实体和关系的抽取结果
"""
# 构建抽取Prompt
prompt = self._build_extraction_prompt(text, context)
# 调用LLM
response = await self.llm.generate(prompt)
# 解析结果
return self._parse_response(response)
def _build_extraction_prompt(self, text: str, context: Optional[str] = None) -> str:
"""构建抽取Prompt"""
entity_types = "\n".join([f"- {t}" for t in self.ENTITY_TYPES])
relation_types = "\n".join([f"- {t}" for t in self.RELATION_TYPES])
context_hint = f"\n\n附加上下文:{context}" if context else ""
return f"""从以下文本中抽取知识图谱的实体和关系。
要求
1. 实体必须从文本中明确提及不能臆造
2. 关系必须有文本依据不能臆造
3. 每个实体和关系都要有置信度说明high/medium/low
实体类型
{entity_types}
关系类型
{relation_types}
{context_hint}
文本内容
{text}
请以JSON格式返回结果
{{
"entities": [
{{
"name": "实体名称",
"entity_type": "实体类型",
"description": "实体描述(可选)",
"confidence": "high/medium/low"
}}
],
"relations": [
{{
"source_entity": "源实体名称",
"target_entity": "目标实体名称",
"relation_type": "关系类型",
"confidence": "high/medium/low"
}}
]
}}
只返回JSON不要有其他内容"""
    def _parse_response(self, response: str) -> ExtractionResult:
        """Parse the result returned by the LLM"""
        # Extract the JSON object
        json_match = re.search(r'\{[\s\S]*\}', response)
        if not json_match:
            return ExtractionResult(entities=[], relations=[])
        try:
            data = json.loads(json_match.group())
        except json.JSONDecodeError:
            return ExtractionResult(entities=[], relations=[])
        # Parse entities, skipping malformed items instead of raising KeyError
        entities = [
            ExtractedEntity(
                name=e["name"],
                entity_type=e["entity_type"],
                description=e.get("description"),
                properties={"confidence": e.get("confidence", "medium")},
            )
            for e in data.get("entities") or []
            if isinstance(e, dict) and e.get("name") and e.get("entity_type")
        ]
        # Parse relations, skipping items that lack any required field
        relations = [
            ExtractedRelation(
                source_entity=r["source_entity"],
                target_entity=r["target_entity"],
                relation_type=r["relation_type"],
                properties={"confidence": r.get("confidence", "medium")},
            )
            for r in data.get("relations") or []
            if isinstance(r, dict)
            and r.get("source_entity")
            and r.get("target_entity")
            and r.get("relation_type")
        ]
        return ExtractionResult(entities=entities, relations=relations)
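
LLM replies often wrap the requested JSON in prose or a Markdown fence, so the extractor first isolates the outermost `{...}` span before decoding. A minimal standalone sketch of that step is shown here; `extract_json_object` is a hypothetical name, and returning an empty dict on failure is an assumption made for the example.

```python
import json
import re


def extract_json_object(response: str) -> dict:
    """Return the outermost JSON object embedded in `response`, or {} on failure.

    Hypothetical helper for illustration; greedy matching takes everything from
    the first "{" to the last "}".
    """
    match = re.search(r"\{[\s\S]*\}", response)
    if not match:
        return {}
    try:
        data = json.loads(match.group())
    except json.JSONDecodeError:
        return {}
    return data if isinstance(data, dict) else {}


if __name__ == "__main__":
    reply = 'Here is the result:\n{"entities": [{"name": "Acme"}], "relations": []}\nDone.'
    print(extract_json_object(reply))
```

Because the match is greedy, a reply containing two separate JSON objects would be captured as one invalid span and decoded as a failure.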

View File

@ -0,0 +1,187 @@
"""知识图谱构建服务"""
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.knowledge_graph import (
KnowledgeEntity,
KnowledgeRelation,
EntityType,
RelationType,
)
from app.models.knowledge import KnowledgeChunk, KnowledgeDocument
from app.services.knowledge.entity_extractor import EntityExtractor, ExtractionResult
class GraphBuilder:
"""知识图谱构建服务"""
def __init__(self):
self.extractor = EntityExtractor()
async def build_from_chunk(
self,
session: AsyncSession,
chunk_id: str,
context: Optional[str] = None,
) -> dict:
"""
从Chunk构建知识图谱
Args:
session: 数据库会话
chunk_id: Chunk ID
context: 可选的上下文如品牌名
Returns:
构建统计信息
"""
# 1. 获取Chunk内容
chunk = await session.get(KnowledgeChunk, chunk_id)
if not chunk:
raise ValueError(f"Chunk not found: {chunk_id}")
# 2. 抽取实体和关系
result = await self.extractor.extract(chunk.content, context)
# 3. 存储到图谱
stats = await self._store_extraction(session, chunk_id, result)
return stats
async def _store_extraction(
self,
session: AsyncSession,
chunk_id: str,
result: ExtractionResult,
) -> dict:
"""存储抽取结果"""
stats = {
"entities_created": 0,
"entities_existing": 0,
"relations_created": 0,
"relations_existing": 0,
}
# 实体名称到ID的映射
entity_map = {}
# 4. 存储实体
for extracted_entity in result.entities:
# 检查是否已存在
existing, created = await self._get_or_create_entity(
session,
chunk_id,
extracted_entity,
)
entity_map[extracted_entity.name] = existing.id
if created:
stats["entities_created"] += 1
else:
stats["entities_existing"] += 1
# 5. 存储关系
for extracted_relation in result.relations:
# 查找实体ID
source_id = entity_map.get(extracted_relation.source_entity)
target_id = entity_map.get(extracted_relation.target_entity)
if not source_id or not target_id:
continue # 跳过找不到实体的情况
# 创建关系
created = await self._create_relation(
session,
chunk_id,
source_id,
target_id,
extracted_relation,
)
if created:
stats["relations_created"] += 1
else:
stats["relations_existing"] += 1
await session.commit()
return stats
async def _get_or_create_entity(
self,
session: AsyncSession,
chunk_id: str,
extracted_entity,
) -> tuple:
"""获取或创建实体"""
kb_id = await self._get_chunk_kb_id(session, chunk_id)
# 查找现有实体
stmt = select(KnowledgeEntity).where(
KnowledgeEntity.knowledge_base_id == kb_id,
KnowledgeEntity.name == extracted_entity.name,
)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
return (existing, False)
# 创建新实体
entity = KnowledgeEntity(
knowledge_base_id=kb_id,
name=extracted_entity.name,
entity_type=EntityType(extracted_entity.entity_type),
description=extracted_entity.description,
properties=extracted_entity.properties or {},
source_chunk_id=chunk_id,
confidence=extracted_entity.properties.get("confidence") if extracted_entity.properties else None,
)
session.add(entity)
await session.flush()
return (entity, True)
async def _create_relation(
self,
session: AsyncSession,
chunk_id: str,
source_id: str,
target_id: str,
extracted_relation,
) -> bool:
"""创建关系(如果不存在)"""
# 检查是否已存在
stmt = select(KnowledgeRelation).where(
KnowledgeRelation.source_entity_id == source_id,
KnowledgeRelation.target_entity_id == target_id,
KnowledgeRelation.relation_type == RelationType(extracted_relation.relation_type),
)
result = await session.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
return False
# 创建关系
relation = KnowledgeRelation(
source_entity_id=source_id,
target_entity_id=target_id,
relation_type=RelationType(extracted_relation.relation_type),
properties=extracted_relation.properties or {},
source_chunk_id=chunk_id,
confidence=extracted_relation.properties.get("confidence") if extracted_relation.properties else None,
)
session.add(relation)
return True
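    async def _get_chunk_kb_id(self, session: AsyncSession, chunk_id: str) -> str:
        """Get the ID of the knowledge base that the chunk belongs to"""
        chunk = await session.get(KnowledgeChunk, chunk_id)
        if not chunk:
            raise ValueError(f"Chunk not found: {chunk_id}")
        # Resolve kb_id through the parent document
        doc = await session.get(KnowledgeDocument, chunk.document_id)
        if not doc:
            # Guard against an orphaned chunk instead of failing with AttributeError
            raise ValueError(f"Document not found for chunk: {chunk_id}")
        return doc.knowledge_base_id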

View File

@ -0,0 +1,249 @@
"""知识图谱查询服务"""
from typing import Optional
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.models.knowledge_graph import (
KnowledgeEntity,
KnowledgeRelation,
)
class GraphQuery:
"""知识图谱查询服务"""
async def get_entity(
self,
session: AsyncSession,
entity_id: str,
) -> Optional[dict]:
"""根据ID获取实体详情"""
entity = await session.get(KnowledgeEntity, entity_id)
if not entity:
return None
return self._entity_to_dict(entity)
async def search_entities(
self,
session: AsyncSession,
kb_id: str,
query: str,
entity_type: Optional[str] = None,
limit: int = 20,
) -> list[dict]:
"""搜索实体"""
stmt = select(KnowledgeEntity).where(
KnowledgeEntity.knowledge_base_id == kb_id,
KnowledgeEntity.name.ilike(f"%{query}%"),
)
if entity_type:
stmt = stmt.where(KnowledgeEntity.entity_type == entity_type)
stmt = stmt.limit(limit)
result = await session.execute(stmt)
return [self._entity_to_dict(e) for e in result.scalars()]
    async def get_entity_neighbors(
        self,
        session: AsyncSession,
        entity_id: str,
        max_depth: int = 1,
    ) -> Optional[dict]:
        """Get an entity's neighbors (directly connected entities).

        Returns None if the entity does not exist. max_depth is currently
        unused: only direct (one-hop) neighbors are returned.
        """
        entity = await session.get(KnowledgeEntity, entity_id)
        if not entity:
            return None
        neighbors = {
            "entity": self._entity_to_dict(entity),
            "incoming": [],  # incoming edges (other entities point to this one)
            "outgoing": [],  # outgoing edges (this entity points to others)
        }
        # Fetch incoming edges
        incoming_stmt = (
            select(KnowledgeRelation, KnowledgeEntity)
            .join(KnowledgeEntity, KnowledgeRelation.source_entity_id == KnowledgeEntity.id)
            .where(KnowledgeRelation.target_entity_id == entity_id)
        )
        incoming_result = await session.execute(incoming_stmt)
        for rel, source_entity in incoming_result:
            neighbors["incoming"].append({
                "relation": self._relation_to_dict(rel),
                "entity": self._entity_to_dict(source_entity),
            })
        # Fetch outgoing edges
        outgoing_stmt = (
            select(KnowledgeRelation, KnowledgeEntity)
            .join(KnowledgeEntity, KnowledgeRelation.target_entity_id == KnowledgeEntity.id)
            .where(KnowledgeRelation.source_entity_id == entity_id)
        )
        outgoing_result = await session.execute(outgoing_stmt)
        for rel, target_entity in outgoing_result:
            neighbors["outgoing"].append({
                "relation": self._relation_to_dict(rel),
                "entity": self._entity_to_dict(target_entity),
            })
        return neighbors
    async def get_entity_path(
        self,
        session: AsyncSession,
        source_name: str,
        target_name: str,
        max_hops: int = 3,
    ) -> list[dict]:
        """
        Find a path between two entities using a simple BFS over outgoing edges.

        Names are matched across all knowledge bases; if a name occurs more
        than once, the first match is used.
        """
        # Resolve entity IDs. scalars().first() is used instead of
        # scalar_one_or_none(), which raises when several rows share a name.
        source_stmt = select(KnowledgeEntity).where(
            KnowledgeEntity.name == source_name
        )
        source_result = await session.execute(source_stmt)
        source_entity = source_result.scalars().first()
        target_stmt = select(KnowledgeEntity).where(
            KnowledgeEntity.name == target_name
        )
        target_result = await session.execute(target_stmt)
        target_entity = target_result.scalars().first()
        if not source_entity or not target_entity:
            return []
        # BFS for a path
        visited = {str(source_entity.id)}
        queue = [(str(source_entity.id), [])]
        while queue:
            current_id, path = queue.pop(0)
            if current_id == str(target_entity.id):
                # Path found; return it
                return await self._format_path(path, session)
            if len(path) >= max_hops:
                continue
            # Explore neighbors
            neighbors_stmt = (
                select(KnowledgeRelation, KnowledgeEntity)
                .join(KnowledgeEntity, KnowledgeRelation.target_entity_id == KnowledgeEntity.id)
                .where(KnowledgeRelation.source_entity_id == current_id)
            )
            neighbors_result = await session.execute(neighbors_stmt)
            for rel, neighbor in neighbors_result:
                neighbor_id = str(neighbor.id)
                if neighbor_id not in visited:
                    visited.add(neighbor_id)
                    new_path = path + [{
                        "from": current_id,
                        "relation": rel.relation_type.value,
                        "to": neighbor_id,
                    }]
                    queue.append((neighbor_id, new_path))
        return []
async def get_statistics(
self,
session: AsyncSession,
kb_id: str,
) -> dict:
"""获取图谱统计信息"""
# 实体数量
entity_count_stmt = select(func.count()).where(
KnowledgeEntity.knowledge_base_id == kb_id
)
entity_count_result = await session.execute(entity_count_stmt)
entity_count = entity_count_result.scalar() or 0
# 关系数量
kb_entities = select(KnowledgeEntity.id).where(
KnowledgeEntity.knowledge_base_id == kb_id
)
relation_count_stmt = select(func.count()).where(
KnowledgeRelation.source_entity_id.in_(kb_entities)
)
relation_count_result = await session.execute(relation_count_stmt)
relation_count = relation_count_result.scalar() or 0
# 实体类型分布
type_dist_stmt = (
select(
KnowledgeEntity.entity_type,
func.count()
)
.where(KnowledgeEntity.knowledge_base_id == kb_id)
.group_by(KnowledgeEntity.entity_type)
)
type_dist_result = await session.execute(type_dist_stmt)
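        # Use the enum's value so keys match the "entity_type" field of _entity_to_dict
        entity_type_dist = {str(getattr(k, "value", k)): v for k, v in type_dist_result}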
# 关系类型分布
rel_type_dist_stmt = (
select(
KnowledgeRelation.relation_type,
func.count()
)
.where(KnowledgeRelation.source_entity_id.in_(kb_entities))
.group_by(KnowledgeRelation.relation_type)
)
rel_type_dist_result = await session.execute(rel_type_dist_stmt)
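        # Use the enum's value so keys match the "relation_type" field of _relation_to_dict
        relation_type_dist = {str(getattr(k, "value", k)): v for k, v in rel_type_dist_result}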
return {
"entity_count": entity_count,
"relation_count": relation_count,
"entity_type_distribution": entity_type_dist,
"relation_type_distribution": relation_type_dist,
}
def _entity_to_dict(self, entity: KnowledgeEntity) -> dict:
"""实体转字典"""
return {
"id": str(entity.id),
"name": entity.name,
"entity_type": entity.entity_type.value,
"description": entity.description,
"properties": entity.properties,
"confidence": entity.confidence,
}
def _relation_to_dict(self, relation: KnowledgeRelation) -> dict:
"""关系转字典"""
return {
"id": str(relation.id),
"source_id": str(relation.source_entity_id),
"target_id": str(relation.target_entity_id),
"relation_type": relation.relation_type.value,
"properties": relation.properties,
"confidence": relation.confidence,
}
async def _format_path(self, path: list, session: AsyncSession) -> list[dict]:
"""格式化路径,返回实体名称"""
formatted = []
for step in path:
# 获取实体名称
from_entity = await session.get(KnowledgeEntity, step["from"])
to_entity = await session.get(KnowledgeEntity, step["to"])
formatted.append({
"from": from_entity.name if from_entity else step["from"],
"relation": step["relation"],
"to": to_entity.name if to_entity else step["to"],
})
return formatted
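
The path lookup in this file is a breadth-first search bounded by a hop limit. The sketch below shows the same idea over an in-memory adjacency map, using `collections.deque` so that dequeuing is O(1); the `find_path` name and the `edges` structure are assumptions for illustration and do not touch the database.

```python
from collections import deque


def find_path(edges: dict, source: str, target: str, max_hops: int = 3) -> list[tuple]:
    """Return the shortest path as (from, relation, to) steps, or [] if none.

    Hypothetical helper: `edges` maps a node to a list of
    (relation, neighbor) pairs and only outgoing edges are followed.
    """
    visited = {source}
    queue = deque([(source, [])])
    while queue:
        node, path = queue.popleft()
        if node == target:
            return path
        if len(path) >= max_hops:
            continue  # do not expand beyond the hop limit
        for relation, neighbor in edges.get(node, []):
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, path + [(node, relation, neighbor)]))
    return []


if __name__ == "__main__":
    graph = {
        "A": [("PRODUCES", "B")],
        "B": [("USES_TECHNOLOGY", "C")],
    }
    print(find_path(graph, "A", "C"))
```

Because BFS explores nodes in order of distance, the first time the target is dequeued its path has the fewest hops.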

View File

@ -0,0 +1,242 @@
"""
增量索引服务
"""
import hashlib
from typing import Optional
from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.knowledge import KnowledgeChunk, KnowledgeDocument
class IncrementalIndexService:
"""增量索引服务 - 支持文档的增删改"""
def __init__(self, rag_service):
self.rag = rag_service
async def add_document(
self,
session: AsyncSession,
kb_id: str,
document_id: str,
) -> dict:
"""
增量添加文档
不重建全量索引只处理单个文档
"""
# 检查是否已存在
existing = await self._get_document_status(session, document_id)
if existing and existing.get("status") == "ready":
return {"action": "skip", "reason": "already_indexed"}
# 执行增量摄入
chunk_count = await self.rag.ingest_document(session, document_id)
return {
"action": "indexed",
"document_id": document_id,
"chunk_count": chunk_count,
}
async def update_document(
self,
session: AsyncSession,
document_id: str,
new_content: str,
) -> dict:
"""
增量更新文档
策略
1. 计算新内容hash
2. 若hash未变跳过
3. 若hash改变删除旧chunks生成新的
"""
# 计算新hash
new_hash = hashlib.sha256(new_content.encode()).hexdigest()
# 获取旧hash
old_hash = await self._get_content_hash(session, document_id)
if new_hash == old_hash:
return {"action": "skip", "reason": "content_unchanged"}
# 删除旧chunks
deleted = await self._delete_document_chunks(session, document_id)
# 更新文档内容
await self._update_document_content(session, document_id, new_content, new_hash)
# 增量摄入新内容
chunk_count = await self.rag.ingest_document(session, document_id)
return {
"action": "updated",
"document_id": document_id,
"deleted_chunks": deleted,
"new_chunks": chunk_count,
}
async def delete_document(
self,
session: AsyncSession,
document_id: str,
) -> dict:
"""
删除文档
删除文档及其所有chunks
"""
# 删除chunks
deleted = await self._delete_document_chunks(session, document_id)
# 删除文档记录
await self._delete_document(session, document_id)
return {
"action": "deleted",
"document_id": document_id,
"deleted_chunks": deleted,
}
async def rebuild_knowledge_base(
self,
session: AsyncSession,
kb_id: str,
force: bool = False,
) -> dict:
"""
重建知识库索引
Args:
force: 是否强制重建即使状态是ready
"""
# 获取所有文档
stmt = select(KnowledgeDocument).where(
KnowledgeDocument.knowledge_base_id == kb_id
)
result = await session.execute(stmt)
documents = result.scalars().all()
stats = {
"total": len(documents),
"processed": 0,
"skipped": 0,
"failed": 0,
"errors": [],
}
for doc in documents:
try:
if doc.status == "ready" and not force:
stats["skipped"] += 1
continue
# 删除旧chunks
await self._delete_document_chunks(session, str(doc.id))
# 重新摄入
await self.rag.ingest_document(session, str(doc.id))
stats["processed"] += 1
except Exception as e:
stats["failed"] += 1
stats["errors"].append({
"document_id": str(doc.id),
"error": str(e),
})
return stats
# ------------------------------------------------------------------
# 辅助方法
# ------------------------------------------------------------------
async def _get_document_status(
self,
session: AsyncSession,
document_id: str,
) -> Optional[dict]:
"""获取文档状态"""
stmt = select(KnowledgeDocument).where(
KnowledgeDocument.id == document_id
)
result = await session.execute(stmt)
doc = result.scalar_one_or_none()
if not doc:
return None
return {
"status": doc.status,
"content_hash": getattr(doc, "content_hash", None),
}
async def _get_content_hash(
self,
session: AsyncSession,
document_id: str,
) -> Optional[str]:
"""获取文档内容hash"""
status = await self._get_document_status(session, document_id)
return status.get("content_hash") if status else None
async def _delete_document_chunks(
self,
session: AsyncSession,
document_id: str,
) -> int:
"""删除文档的所有chunks"""
# 统计要删除的数量
count_stmt = select(func.count()).where(
KnowledgeChunk.document_id == document_id
)
count_result = await session.execute(count_stmt)
count = count_result.scalar() or 0
# 删除
delete_stmt = delete(KnowledgeChunk).where(
KnowledgeChunk.document_id == document_id
)
await session.execute(delete_stmt)
return count
async def _update_document_content(
self,
session: AsyncSession,
document_id: str,
content: str,
content_hash: str,
):
"""更新文档内容和hash"""
stmt = (
update(KnowledgeDocument)
.where(KnowledgeDocument.id == document_id)
.values(
content=content,
content_hash=content_hash,
status="pending", # 标记为待处理
)
)
await session.execute(stmt)
async def _delete_document(
self,
session: AsyncSession,
document_id: str,
):
"""删除文档记录"""
stmt = select(KnowledgeDocument).where(
KnowledgeDocument.id == document_id
)
result = await session.execute(stmt)
doc = result.scalar_one_or_none()
if doc:
await session.delete(doc)
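
The update path in this service decides whether to re-index by comparing a SHA-256 hash of the new content with the stored hash. A minimal sketch of that decision follows, kept free of any database access; `content_hash`, `plan_update`, and the returned dictionary shape are illustrative assumptions.

```python
import hashlib
from typing import Optional


def content_hash(text: str) -> str:
    """Return the SHA-256 hex digest of the UTF-8 encoded text."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def plan_update(old_hash: Optional[str], new_content: str) -> dict:
    """Decide whether a document needs re-indexing.

    Hypothetical helper: skips when the hash is unchanged, otherwise
    reports the new hash so the caller can replace the old chunks.
    """
    new_hash = content_hash(new_content)
    if new_hash == old_hash:
        return {"action": "skip", "reason": "content_unchanged"}
    return {"action": "reindex", "content_hash": new_hash}


if __name__ == "__main__":
    stored = content_hash("hello")
    print(plan_update(stored, "hello"))
    print(plan_update(stored, "hello world"))
```

A document with no stored hash (`None`) never matches, so it is always scheduled for indexing.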

View File

@ -0,0 +1,165 @@
"""文档解析器 - 支持多种格式"""
import io
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
@dataclass
class ParsedDocument:
"""解析后的文档"""
title: str
content: str
metadata: dict
class BaseParser(ABC):
"""解析器基类"""
@abstractmethod
async def parse(self, content: bytes) -> ParsedDocument:
"""解析文档内容"""
pass
class PDFParser(BaseParser):
"""PDF解析器"""
    async def parse(self, content: bytes) -> ParsedDocument:
        """Parse a PDF with PyMuPDF"""
        import fitz

        # Name the file type explicitly when opening from an in-memory stream
        doc = fitz.open(stream=content, filetype="pdf")
        text_parts = []
        metadata = {}
        # Extract metadata
        if doc.metadata:
            metadata = {
                "author": doc.metadata.get("author", ""),
                "title": doc.metadata.get("title", ""),
                "subject": doc.metadata.get("subject", ""),
            }
        # Extract the text of each page
        for page_num, page in enumerate(doc):
            text = page.get_text()
            if text.strip():
                text_parts.append(f"[第{page_num + 1}页]\n{text}")
        # Extract the table of contents (if present)
        toc = doc.get_toc()
        if toc:
            metadata["has_toc"] = True
            metadata["toc_items"] = len(toc)
        doc.close()
        return ParsedDocument(
            title=metadata.get("title", "未命名文档") or "未命名文档",
            content="\n\n".join(text_parts),
            metadata=metadata,
        )
class DocxParser(BaseParser):
"""Word文档解析器"""
async def parse(self, content: bytes) -> ParsedDocument:
"""使用python-docx解析Word"""
from docx import Document
doc = Document(io.BytesIO(content))
paragraphs = []
metadata = {}
        # Extract core properties
        core_props = doc.core_properties
        metadata = {
            "author": getattr(core_props, "author", "") or "",
            "title": getattr(core_props, "title", "") or "",
            "subject": getattr(core_props, "subject", "") or "",
            # created/modified may be None; avoid storing the string "None"
            "created": str(getattr(core_props, "created", None) or ""),
            "modified": str(getattr(core_props, "modified", None) or ""),
        }
# 提取段落
for para in doc.paragraphs:
text = para.text.strip()
if text:
paragraphs.append(text)
# 提取表格
for i, table in enumerate(doc.tables):
table_text = []
for row in table.rows:
cells = [cell.text.strip() for cell in row.cells]
if any(cells):
table_text.append(" | ".join(cells))
if table_text:
paragraphs.append(f"[表格{i+1}]\n" + "\n".join(table_text))
return ParsedDocument(
title=metadata.get("title", "未命名文档") or "未命名文档",
content="\n\n".join(paragraphs),
metadata=metadata,
)
class MarkdownParser(BaseParser):
"""Markdown解析器"""
async def parse(self, content: bytes) -> ParsedDocument:
"""解析Markdown"""
text = content.decode("utf-8")
# 提取标题(第一个#开头的内容)
lines = text.split("\n")
title = "未命名文档"
for line in lines:
line = line.strip()
if line.startswith("# "):
title = line[2:].strip()
break
return ParsedDocument(
title=title,
content=text,
metadata={"format": "markdown"},
)
class TextParser(BaseParser):
"""纯文本解析器"""
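    async def parse(self, content: bytes) -> ParsedDocument:
        """Parse plain text"""
        text = content.decode("utf-8")
        # Use the first line as the title; fall back to the default when it is empty
        # (str.split always returns at least one element, so `lines` is never empty)
        lines = text.split("\n")
        title = lines[0].strip()[:50] or "未命名文档"
        return ParsedDocument(
            title=title,
            content=text,
            metadata={"format": "text"},
        )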
class ParserFactory:
"""解析器工厂"""
PARSERS = {
".pdf": PDFParser,
".docx": DocxParser,
".md": MarkdownParser,
".txt": TextParser,
".html": MarkdownParser, # HTML当Markdown处理
}
@classmethod
def create(cls, file_extension: str) -> BaseParser:
"""创建解析器"""
parser_cls = cls.PARSERS.get(file_extension.lower())
if not parser_cls:
raise ValueError(f"Unsupported format: {file_extension}")
return parser_cls()
@classmethod
def supported_formats(cls) -> list[str]:
"""支持的格式"""
return list(cls.PARSERS.keys())
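The extension-to-parser lookup used by `ParserFactory` can be sketched in isolation. This is a minimal, synchronous illustration, not the code from this commit: `ParsedDoc`, `PlainTextParser`, `REGISTRY`, and `create_parser` are hypothetical stand-ins, and the field names are assumed from the `ParsedDocument(...)` calls above.

```python
from dataclasses import dataclass, field


@dataclass
class ParsedDoc:
    """Stand-in for ParsedDocument (fields assumed from the calls in the diff)."""
    title: str
    content: str
    metadata: dict = field(default_factory=dict)


class PlainTextParser:
    """Hypothetical parser: the first non-empty line becomes the title."""

    def parse(self, raw: bytes) -> ParsedDoc:
        text = raw.decode("utf-8")
        first = next((ln.strip() for ln in text.splitlines() if ln.strip()), "")
        return ParsedDoc(
            title=first[:50] or "untitled",
            content=text,
            metadata={"format": "text"},
        )


# Extension -> parser class registry, mirroring the PARSERS dict.
REGISTRY = {".txt": PlainTextParser}


def create_parser(extension: str):
    """Look up a parser by extension, case-insensitively."""
    parser_cls = REGISTRY.get(extension.lower())
    if parser_cls is None:
        raise ValueError(f"Unsupported format: {extension}")
    return parser_cls()


doc = create_parser(".TXT").parse(b"\nQuarterly report\nbody text")
print(doc.title)
```

Keeping the registry as plain data means a new format only needs one dictionary entry, and the list of supported formats can be derived from the keys rather than maintained separately.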


@ -0,0 +1,96 @@
"""文档上传服务"""
import hashlib
import os
from dataclasses import dataclass
from typing import Optional
import shortuuid
from app.services.knowledge.parsers import ParserFactory, ParsedDocument
@dataclass
class UploadResult:
"""上传结果"""
document_id: str
title: str
content: str
content_hash: str
file_size: int
file_type: str
metadata: dict
class DocumentUploader:
"""文档上传服务"""
SUPPORTED_FORMATS = {".pdf", ".docx", ".md", ".txt", ".html"}
MAX_SIZE_MB = 50
def __init__(self, storage_path: str = "/data/documents"):
self.storage_path = storage_path
async def upload(
self,
file_content: bytes,
filename: str,
kb_id: str,
) -> UploadResult:
"""上传并解析文档"""
# 1. 验证格式
ext = self._get_extension(filename)
if ext not in self.SUPPORTED_FORMATS:
raise ValueError(f"Unsupported format: {ext}")
# 2. 验证大小
size_mb = len(file_content) / (1024 * 1024)
if size_mb > self.MAX_SIZE_MB:
raise ValueError(f"File too large: {size_mb:.1f}MB > {self.MAX_SIZE_MB}MB")
# 3. 解析文档
parser = ParserFactory.create(ext)
parsed = await parser.parse(file_content)
# 4. 生成ID和哈希
doc_id = shortuuid.uuid()
content_hash = hashlib.sha256(parsed.content.encode()).hexdigest()
# 5. 保存文件
file_path = self._save_file(file_content, kb_id, doc_id, ext)
return UploadResult(
document_id=doc_id,
title=parsed.title,
content=parsed.content,
content_hash=content_hash,
file_size=len(file_content),
file_type=ext,
metadata={
**parsed.metadata,
"original_filename": filename,
"stored_path": file_path,
},
)
def _get_extension(self, filename: str) -> str:
"""获取文件扩展名"""
if "." not in filename:
raise ValueError("File has no extension")
return "." + filename.rsplit(".", 1)[1].lower()
def _save_file(
self,
content: bytes,
kb_id: str,
doc_id: str,
ext: str
) -> str:
"""保存文件到存储路径"""
# 创建目录
dir_path = os.path.join(self.storage_path, kb_id)
os.makedirs(dir_path, exist_ok=True)
# 保存文件
file_path = os.path.join(dir_path, f"{doc_id}{ext}")
with open(file_path, "wb") as f:
f.write(content)
return file_path
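`_save_file` joins `kb_id` straight into the storage path, so a value such as `../outside` would resolve outside the storage root. One possible guard is sketched below; `safe_document_path` is a hypothetical helper that is not part of this commit, and whether `kb_id` is already validated upstream is not visible in the diff.

```python
import os
import tempfile


def safe_document_path(storage_root: str, kb_id: str, doc_id: str, ext: str) -> str:
    """Build the target path and refuse anything that escapes storage_root."""
    root = os.path.realpath(storage_root)
    candidate = os.path.realpath(os.path.join(root, kb_id, f"{doc_id}{ext}"))
    # commonpath equals root only when candidate lies inside root.
    if os.path.commonpath([root, candidate]) != root:
        raise ValueError(f"Path escapes storage root: {kb_id!r}")
    return candidate


with tempfile.TemporaryDirectory() as tmp_root:
    print(safe_document_path(tmp_root, "kb1", "abc", ".txt"))
    try:
        safe_document_path(tmp_root, "../outside", "abc", ".txt")
    except ValueError as exc:
        print("rejected:", exc)
```

Resolving both paths with `realpath` before comparing also covers symlinked storage directories and absolute `kb_id` values, which `os.path.join` would otherwise let replace the root entirely.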


@ -1,12 +1,14 @@
 import asyncio
 import json
 import os
+import time
 from typing import AsyncIterator

 import httpx

 from .base import LLMError, LLMProvider, LLMResponse
 from .rate_limiter import get_rate_limiter
+from app.monitoring.llm_metrics import get_llm_metrics

 _DEFAULT_MODEL = "deepseek-chat"
 _DEFAULT_MAX_CONTEXT = 64_000
@ -75,21 +77,40 @@ class DeepSeekProvider(LLMProvider):
         if stop:
             payload["stop"] = stop

-        data = await self._request_with_retry(payload, stream=False)
-
-        choice = data["choices"][0]
-        content = choice["message"]["content"]
-        usage = data.get("usage", {})
-
-        return LLMResponse(
-            content=content,
-            model=data.get("model", self._model),
-            usage={
-                "prompt_tokens": usage.get("prompt_tokens", 0),
-                "completion_tokens": usage.get("completion_tokens", 0),
-                "total_tokens": usage.get("total_tokens", 0),
-            },
-        )
+        start_time = time.perf_counter()
+        metrics = get_llm_metrics(self.provider_name, self._model)
+
+        try:
+            data = await self._request_with_retry(payload, stream=False)
+
+            choice = data["choices"][0]
+            content = choice["message"]["content"]
+            usage = data.get("usage", {})
+
+            duration = time.perf_counter() - start_time
+            metrics.record_request(
+                status="success",
+                duration=duration,
+                prompt_tokens=usage.get("prompt_tokens"),
+                completion_tokens=usage.get("completion_tokens"),
+            )
+
+            return LLMResponse(
+                content=content,
+                model=data.get("model", self._model),
+                usage={
+                    "prompt_tokens": usage.get("prompt_tokens", 0),
+                    "completion_tokens": usage.get("completion_tokens", 0),
+                    "total_tokens": usage.get("total_tokens", 0),
+                },
+            )
+        except Exception as e:
+            duration = time.perf_counter() - start_time
+            metrics.record_request(
+                status="error",
+                duration=duration,
+            )
+            raise

     async def chat_stream(
         self,


@ -1,12 +1,14 @@
 import asyncio
 import json
 import os
+import time
 from typing import AsyncIterator

 import httpx

 from .base import LLMError, LLMProvider, LLMResponse
 from .rate_limiter import get_rate_limiter
+from app.monitoring.llm_metrics import get_llm_metrics

 # Supported models and their context lengths (Bailian Coding Plan + OpenAI)
 _OPENAI_MODELS: dict[str, int] = {
@ -90,21 +92,40 @@ class OpenAIProvider(LLMProvider):
         if stop:
             payload["stop"] = stop

-        data = await self._request_with_retry(payload, stream=False)
-
-        choice = data["choices"][0]
-        content = choice["message"]["content"]
-        usage = data.get("usage", {})
-
-        return LLMResponse(
-            content=content,
-            model=data.get("model", self._model),
-            usage={
-                "prompt_tokens": usage.get("prompt_tokens", 0),
-                "completion_tokens": usage.get("completion_tokens", 0),
-                "total_tokens": usage.get("total_tokens", 0),
-            },
-        )
+        start_time = time.perf_counter()
+        metrics = get_llm_metrics(self.provider_name, self._model)
+
+        try:
+            data = await self._request_with_retry(payload, stream=False)
+
+            choice = data["choices"][0]
+            content = choice["message"]["content"]
+            usage = data.get("usage", {})
+
+            duration = time.perf_counter() - start_time
+            metrics.record_request(
+                status="success",
+                duration=duration,
+                prompt_tokens=usage.get("prompt_tokens"),
+                completion_tokens=usage.get("completion_tokens"),
+            )
+
+            return LLMResponse(
+                content=content,
+                model=data.get("model", self._model),
+                usage={
+                    "prompt_tokens": usage.get("prompt_tokens", 0),
+                    "completion_tokens": usage.get("completion_tokens", 0),
+                    "total_tokens": usage.get("total_tokens", 0),
+                },
+            )
+        except Exception as e:
+            duration = time.perf_counter() - start_time
+            metrics.record_request(
+                status="error",
+                duration=duration,
+            )
+            raise

     async def chat_stream(
         self,
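Both providers repeat the same timing and success/error bookkeeping around the request. As a rough sketch of how that pattern could be factored out, the async context manager below records one entry per request. `InMemoryMetrics` and `timed_request` are hypothetical names; only the `record_request(status=..., duration=..., ...)` call shape is taken from the diff, and the real object returned by `get_llm_metrics` is not shown here.

```python
import asyncio
import time
from contextlib import asynccontextmanager


class InMemoryMetrics:
    """Hypothetical stand-in for the object returned by get_llm_metrics()."""

    def __init__(self):
        self.records = []

    def record_request(self, status, duration, **extra):
        self.records.append({"status": status, "duration": duration, **extra})


@asynccontextmanager
async def timed_request(metrics):
    """Time the wrapped block and record exactly one success or error entry."""
    start = time.perf_counter()
    tokens = {}  # the caller fills in token counts once the response arrives
    try:
        yield tokens
    except Exception:
        metrics.record_request(status="error", duration=time.perf_counter() - start)
        raise
    else:
        metrics.record_request(
            status="success", duration=time.perf_counter() - start, **tokens
        )


async def demo():
    metrics = InMemoryMetrics()
    async with timed_request(metrics) as tokens:
        await asyncio.sleep(0)  # placeholder for the real HTTP call
        tokens["prompt_tokens"] = 12
    try:
        async with timed_request(metrics):
            raise RuntimeError("upstream failure")
    except RuntimeError:
        pass
    return metrics.records


records = asyncio.run(demo())
print([r["status"] for r in records])
```

The error branch re-raises after recording, so callers see the same exception they would without the wrapper, which matches the bare `raise` in the hunks above.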


@ -0,0 +1,324 @@
"""
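Plan quota warning service.
Monitors each user's usage (API calls, queries, and so on) and raises a warning as usage approaches the plan limit.
Features:
- Quota usage lookup: report the user's current usage for each quota type
- Usage calculation: compute usage as a percentage of the limit
- Threshold checks: 80% warning, 90% critical, 100% limit reached
- Warning messages: generate a message with a recommended action
- Quota reset: reset the usage counter
Quota types:
- api_calls: number of API calls
- queries: number of queries
- content_generation: number of content generations
- storage: storage space (MB)
Warning thresholds:
- warning: 80%
- critical: 90%
- exhausted: 100%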
"""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
class QuotaType(str, Enum):
"""额度类型枚举"""
API_CALLS = "api_calls"
QUERIES = "queries"
CONTENT_GENERATION = "content_generation"
STORAGE = "storage"
QUOTA_LIMITS = {
"free": {
"api_calls": 1000,
"queries": 50,
"content_generation": 10,
"storage": 100,
},
"basic": {
"api_calls": 10000,
"queries": 500,
"content_generation": 100,
"storage": 1000,
},
"pro": {
"api_calls": 100000,
"queries": 5000,
"content_generation": 1000,
"storage": 10000,
},
"unlimited": {
"api_calls": -1,
"queries": -1,
"content_generation": -1,
"storage": -1,
},
}
WARNING_THRESHOLDS = {
"warning": 0.80,
"critical": 0.90,
"exhausted": 1.00,
}
WARNING_MESSAGES = {
"warning": "{quota_type}额度已使用{percentage:.0f}%",
"critical": "{quota_type}额度已使用{percentage:.0f}%,即将耗尽",
"exhausted": "{quota_type}额度已使用{percentage:.0f}%,已耗尽",
}
RECOMMENDED_ACTIONS = {
"warning": "请关注使用情况,考虑升级套餐",
"critical": "额度即将耗尽,建议立即升级套餐",
"exhausted": "额度已耗尽,请升级套餐或等待重置",
}
@dataclass
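class QuotaUsage:
"""Quota usage record.
Attributes:
quota_type: quota type
used: amount used
limit: quota limit (-1 means unlimited)
usage_percentage: usage as a percentage (0 when unlimited; exceeds 100 when usage is over the limit)
status: one of ok/warning/critical/exhausted/unlimited
remaining: remaining quota (-1 means unlimited; negative when usage is over the limit)
"""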
quota_type: str
used: int
limit: int
usage_percentage: float
status: str
remaining: int
@dataclass
class QuotaWarning:
"""预警数据结构
Attributes:
quota_type: 额度类型
status: 预警状态
usage_percentage: 使用百分比
message: 预警消息
recommended_action: 建议操作
"""
quota_type: str
status: str
usage_percentage: float
message: str
recommended_action: str
class QuotaService:
"""套餐额度预警服务
提供额度查询使用率计算预警检查等功能
"""
def calculate_usage_percentage(self, used: int, limit: int) -> float:
"""计算额度使用百分比
Args:
used: 已使用量
limit: 额度限制(-1表示无限)
Returns:
使用百分比(0-100)无限额度时返回0.0
"""
if limit == -1 or limit == 0:
return 0.0
return (used / limit) * 100.0
def get_quota_status(self, usage_percentage: float, limit: int = 0) -> str:
"""根据使用率获取额度状态
Args:
usage_percentage: 使用百分比
limit: 额度限制(-1表示无限)
Returns:
状态字符串: ok/warning/critical/exhausted/unlimited
"""
if limit == -1:
return "unlimited"
if usage_percentage >= 100.0:
return "exhausted"
if usage_percentage >= 90.0:
return "critical"
if usage_percentage >= 80.0:
return "warning"
return "ok"
def get_quota_limit(self, plan: str, quota_type: QuotaType) -> int:
"""获取指定套餐的额度限制
Args:
plan: 套餐名称(free/basic/pro/unlimited)
quota_type: 额度类型
Returns:
额度限制值(-1表示无限)
"""
return QUOTA_LIMITS.get(plan, {}).get(quota_type.value, 0)
def get_remaining(self, used: int, limit: int) -> int:
"""计算剩余额度
Args:
used: 已使用量
limit: 额度限制(-1表示无限)
Returns:
剩余额度(-1表示无限)
"""
if limit == -1:
return -1
return limit - used
def _format_warning_message(self, quota_type: str, status: str, percentage: float) -> str:
"""格式化预警消息
Args:
quota_type: 额度类型
status: 预警状态
percentage: 使用百分比
Returns:
格式化后的预警消息
"""
template = WARNING_MESSAGES.get(status, "{quota_type}额度使用情况")
return template.format(quota_type=quota_type, percentage=percentage)
def _get_recommended_action(self, status: str) -> str:
"""获取建议操作
Args:
status: 预警状态
Returns:
建议操作描述
"""
return RECOMMENDED_ACTIONS.get(status, "请关注使用情况")
def generate_warning(
self,
quota_type: QuotaType,
status: str,
usage_percentage: float,
) -> QuotaWarning:
"""生成预警信息
Args:
quota_type: 额度类型
status: 预警状态
usage_percentage: 使用百分比
Returns:
QuotaWarning预警数据对象
"""
return QuotaWarning(
quota_type=quota_type.value,
status=status,
usage_percentage=usage_percentage,
message=self._format_warning_message(quota_type.value, status, usage_percentage),
recommended_action=self._get_recommended_action(status),
)
def check_quota(
self,
plan: str,
quota_type: QuotaType,
used: int,
) -> QuotaUsage:
"""检查指定套餐的额度使用情况
Args:
plan: 套餐名称
quota_type: 额度类型
used: 已使用量
Returns:
QuotaUsage额度使用数据对象
"""
limit = self.get_quota_limit(plan, quota_type)
percentage = self.calculate_usage_percentage(used, limit)
status = self.get_quota_status(percentage, limit)
remaining = self.get_remaining(used, limit)
return QuotaUsage(
quota_type=quota_type.value,
used=used,
limit=limit,
usage_percentage=percentage,
status=status,
remaining=remaining,
)
def reset_quota(self, used: int, reset_to: int = 0) -> int:
"""重置额度使用量
Args:
used: 当前使用量(未使用仅为接口兼容)
reset_to: 重置到的值默认为0
Returns:
重置后的使用量
"""
return reset_to
def get_all_quota_usage(
self,
plan: str,
api_calls_used: int = 0,
queries_used: int = 0,
content_generation_used: int = 0,
storage_used: int = 0,
) -> list[QuotaUsage]:
"""获取所有额度类型的使用情况
Args:
plan: 套餐名称
api_calls_used: API调用已使用量
queries_used: 查询已使用量
content_generation_used: 内容生成已使用量
storage_used: 存储已使用量(MB)
Returns:
所有额度类型的使用情况列表
"""
return [
self.check_quota(plan, QuotaType.API_CALLS, api_calls_used),
self.check_quota(plan, QuotaType.QUERIES, queries_used),
self.check_quota(plan, QuotaType.CONTENT_GENERATION, content_generation_used),
self.check_quota(plan, QuotaType.STORAGE, storage_used),
]
def get_warnings(self, usage_list: list[QuotaUsage]) -> list[QuotaWarning]:
"""从使用情况列表中生成预警
Args:
usage_list: 额度使用情况列表
Returns:
预警信息列表(仅包含需要预警的项)
"""
warnings = []
for usage in usage_list:
if usage.status in ["warning", "critical", "exhausted"]:
warning = self.generate_warning(
quota_type=QuotaType(usage.quota_type),
status=usage.status,
usage_percentage=usage.usage_percentage,
)
warnings.append(warning)
return warnings
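The threshold logic can be checked with a small worked example. The sketch below is a standalone approximation of `calculate_usage_percentage` plus `get_quota_status`, not the service itself; `classify_usage` and `THRESHOLDS` are hypothetical names, and the limit of 50 corresponds to the `queries` quota of the `free` plan in `QUOTA_LIMITS`.

```python
# Thresholds ordered from most to least severe, so the first match wins.
THRESHOLDS = [(100.0, "exhausted"), (90.0, "critical"), (80.0, "warning")]


def classify_usage(used: int, limit: int) -> tuple[float, str]:
    """Return (percentage, status); a limit of -1 means unlimited."""
    if limit == -1:
        return 0.0, "unlimited"
    if limit == 0:
        # The service reports 0% for a zero limit; that behavior is mirrored here.
        return 0.0, "ok"
    # Multiply before dividing so whole-number boundary cases stay exact.
    percentage = used * 100.0 / limit
    for floor, status in THRESHOLDS:
        if percentage >= floor:
            return percentage, status
    return percentage, "ok"


for used in (39, 40, 45, 50):
    print(used, classify_usage(used, 50))
```

With a limit of 50, usage of 39 stays `ok` at 78%, 40 reaches `warning` at 80%, 45 reaches `critical` at 90%, and 50 is `exhausted`.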

File diff suppressed because it is too large


@ -40,3 +40,11 @@ aiosqlite
# PDF生成 # PDF生成
fpdf2>=2.7 fpdf2>=2.7
# 监控
prometheus-client>=0.19.0
# 文档解析
PyMuPDF>=1.23.0
python-docx>=1.1.0
shortuuid>=1.0.0


@ -0,0 +1,227 @@
"""Agent基类测试"""
import pytest
from datetime import datetime, timezone
from app.agent_framework.base import BaseAgent
from app.agent_framework.protocol import (
AgentCapability,
AgentStatus,
TaskMessage,
TaskResult,
TaskStatus,
)
class ConcreteTestAgent(BaseAgent):
"""用于测试的BaseAgent实现"""
def __init__(self):
super().__init__(
name="concrete_test_agent",
agent_type="test_type",
version="1.0.0",
)
self._execute_called = False
self._execute_task_data = None
def get_capabilities(self) -> AgentCapability:
return AgentCapability(
agent_name=self.name,
agent_type=self.agent_type,
version=self.version,
supported_tasks=["test_task"],
max_concurrency=3,
description="测试用Agent",
)
async def execute(self, task: TaskMessage) -> TaskResult:
self._execute_called = True
self._execute_task_data = task
return TaskResult(
task_id=task.task_id,
agent_name=self.name,
status=TaskStatus.COMPLETED,
output_data={"result": "success"},
error_message=None,
started_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
metrics={"test": True},
)
class TestBaseAgent:
"""Agent基类测试"""
def test_agent_initialization(self):
"""测试Agent初始化"""
agent = ConcreteTestAgent()
assert agent.name == "concrete_test_agent"
assert agent.agent_type == "test_type"
assert agent.version == "1.0.0"
assert agent.status == AgentStatus.OFFLINE
def test_get_capabilities(self):
"""测试获取Agent能力"""
agent = ConcreteTestAgent()
capability = agent.get_capabilities()
assert isinstance(capability, AgentCapability)
assert capability.agent_name == "concrete_test_agent"
assert "test_task" in capability.supported_tasks
def test_agent_status_transitions(self):
"""测试Agent状态转换"""
agent = ConcreteTestAgent()
# 初始状态
assert agent.status == AgentStatus.OFFLINE
# 模拟设置状态
agent._status = AgentStatus.ONLINE
assert agent.status == AgentStatus.ONLINE
agent._status = AgentStatus.BUSY
assert agent.status == AgentStatus.BUSY
def test_agent_running_tasks_tracking(self):
"""测试运行任务跟踪"""
agent = ConcreteTestAgent()
assert len(agent._running_tasks) == 0
# 模拟添加任务
agent._running_tasks.add("task-1")
agent._running_tasks.add("task-2")
assert len(agent._running_tasks) == 2
# 模拟移除任务
agent._running_tasks.discard("task-1")
assert len(agent._running_tasks) == 1
assert "task-1" not in agent._running_tasks
assert "task-2" in agent._running_tasks
def test_agent_semaphore_initialization(self):
"""测试信号量初始化"""
agent = ConcreteTestAgent()
capability = agent.get_capabilities()
max_concurrency = capability.max_concurrency
# 初始化信号量
import asyncio
agent._semaphore = asyncio.Semaphore(max_concurrency)
assert agent._semaphore is not None
assert agent._semaphore._value == max_concurrency
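def test_is_idle_property(self):
"""Idle-state check (placeholder until an is_idle accessor is confirmed on BaseAgent)."""
agent = ConcreteTestAgent()
# OFFLINE or ONLINE should count as idle. Whether BaseAgent exposes an
# is_idle property or method is not confirmed here, so only the status
# itself is asserted.
agent._status = AgentStatus.OFFLINE
assert agent.status == AgentStatus.OFFLINE
# BUSY is not idle
agent._status = AgentStatus.BUSY
assert agent.status == AgentStatus.BUSY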
class TestTaskMessage:
"""TaskMessage测试"""
def test_task_message_creation(self):
"""测试TaskMessage创建"""
task = TaskMessage(
task_id=str(uuid.uuid4()),
agent_name="test_agent",
task_type="test_task",
priority=5,
input_data={"key": "value"},
callback_url=None,
created_at=datetime.now(timezone.utc),
)
assert task.task_id is not None
assert task.agent_name == "test_agent"
assert task.task_type == "test_task"
assert task.priority == 5
assert task.input_data == {"key": "value"}
def test_task_message_to_dict(self):
"""测试TaskMessage序列化"""
task = TaskMessage(
task_id="test-uuid",
agent_name="test_agent",
task_type="test_task",
priority=5,
input_data={"key": "value"},
callback_url=None,
created_at=datetime.now(timezone.utc),
)
data = task.to_dict()
assert data["task_id"] == "test-uuid"
assert data["agent_name"] == "test_agent"
assert data["task_type"] == "test_task"
def test_task_message_from_dict(self):
"""测试TaskMessage反序列化"""
data = {
"task_id": "test-uuid",
"agent_name": "test_agent",
"task_type": "test_task",
"priority": 5,
"input_data": {"key": "value"},
"callback_url": None,
"created_at": datetime.now(timezone.utc).isoformat(),
}
task = TaskMessage.from_dict(data)
assert task.task_id == "test-uuid"
assert task.agent_name == "test_agent"
class TestTaskResult:
"""TaskResult测试"""
def test_task_result_creation(self):
"""测试TaskResult创建"""
result = TaskResult(
task_id="test-uuid",
agent_name="test_agent",
status=TaskStatus.COMPLETED,
output_data={"result": "success"},
error_message=None,
started_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
metrics={"elapsed": 1.5},
)
assert result.task_id == "test-uuid"
assert result.agent_name == "test_agent"
assert result.status == TaskStatus.COMPLETED
def test_task_result_to_dict(self):
"""测试TaskResult序列化"""
result = TaskResult(
task_id="test-uuid",
agent_name="test_agent",
status=TaskStatus.COMPLETED,
output_data={"result": "success"},
error_message=None,
started_at=datetime.now(timezone.utc),
completed_at=datetime.now(timezone.utc),
metrics={"elapsed": 1.5},
)
data = result.to_dict()
assert data["task_id"] == "test-uuid"
assert data["status"] == TaskStatus.COMPLETED
assert data["output_data"] == {"result": "success"}
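import uuid  # noqa: E402 -- imported late; used by TestTaskMessage above and would normally sit with the imports at the top of the file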


@ -0,0 +1,214 @@
"""任务分发器测试"""
import pytest
import uuid
from datetime import datetime, timezone
from app.agent_framework.dispatcher import TaskDispatcher
from app.agent_framework.registry import AgentRegistry
from app.agent_framework.protocol import (
AgentCapability,
TaskMessage,
)
from app.config import settings
def is_database_available():
"""检查数据库是否可用(同步方式)"""
try:
from sqlalchemy import create_engine, text
from app.config import settings
# 从URL创建同步引擎进行测试
sync_url = settings.DATABASE_URL.replace('+aiosqlite', '').replace('+asyncpg', '')
if 'sqlite' in sync_url:
engine = create_engine(sync_url)
else:
engine = create_engine(sync_url, connect_args={"connect_timeout": 1})
with engine.connect() as conn:
conn.execute(text("SELECT 1"))
engine.dispose()
return True
except Exception:
return False
# 检查服务是否可用
_db_available = None
def check_db():
global _db_available
if _db_available is None:
try:
_db_available = is_database_available()
except Exception:
_db_available = False
return _db_available
def is_redis_available():
"""检查Redis是否可用"""
import redis
try:
r = redis.Redis.from_url(settings.REDIS_URL)
r.ping()
return True
except Exception:
return False
_redis_available = None
def check_redis():
global _redis_available
if _redis_available is None:
try:
_redis_available = is_redis_available()
except Exception:
_redis_available = False
return _redis_available
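The `check_db` and `check_redis` helpers cache a probe result in a module-level global. A possible alternative, sketched under the assumption that a probe signals failure by raising, is to memoize the probe with `functools.lru_cache`. `cached_probe` and `fake_ping` are hypothetical names used only for illustration.

```python
import functools
from typing import Callable


def cached_probe(probe: Callable[[], None]) -> Callable[[], bool]:
    """Wrap a probe so it runs at most once; any exception means 'unavailable'."""

    @functools.lru_cache(maxsize=None)
    def check() -> bool:
        try:
            probe()
            return True
        except Exception:
            return False

    return check


attempts = []


def fake_ping() -> None:
    """Simulated probe that records each call and always fails."""
    attempts.append(1)
    raise ConnectionError("service down")


check_service = cached_probe(fake_ping)
print(check_service(), check_service(), len(attempts))
```

Because the wrapper is generic, the database and Redis checks could share it, which would also remove the copy of `is_database_available` duplicated across the test modules.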
class TestTaskDispatcher:
"""任务分发器测试"""
@pytest.fixture
def dispatcher(self):
"""创建分发器实例"""
return TaskDispatcher(settings.REDIS_URL)
@pytest.mark.asyncio
async def test_dispatcher_initialization(self, dispatcher):
"""测试分发器初始化"""
assert dispatcher is not None
assert dispatcher._redis_url == settings.REDIS_URL
assert dispatcher._redis is None
@pytest.mark.asyncio
async def test_get_task_status_not_found(self, dispatcher):
"""测试获取不存在的任务状态"""
if not check_db():
pytest.skip("数据库不可用,跳过此测试")
non_existent_id = str(uuid.uuid4())
from app.agent_framework.exceptions import TaskNotFoundError
with pytest.raises(TaskNotFoundError):
await dispatcher.get_task_status(non_existent_id)
@pytest.mark.asyncio
async def test_dispatch_without_agent(self, dispatcher):
"""测试分发任务到不存在的Agent"""
if not check_db():
pytest.skip("数据库不可用,跳过此测试")
task = TaskMessage(
task_id=str(uuid.uuid4()),
agent_name="non_existent_agent",
task_type="test_task",
priority=5,
input_data={},
callback_url=None,
created_at=datetime.now(timezone.utc),
)
from app.agent_framework.exceptions import TaskDispatchError
with pytest.raises(TaskDispatchError):
await dispatcher.dispatch(task)
@pytest.mark.asyncio
async def test_cancel_task_not_found(self, dispatcher):
"""测试取消不存在的任务"""
if not check_db():
pytest.skip("数据库不可用,跳过此测试")
non_existent_id = str(uuid.uuid4())
from app.agent_framework.exceptions import TaskNotFoundError
with pytest.raises(TaskNotFoundError):
await dispatcher.cancel_task(non_existent_id)
@pytest.mark.asyncio
async def test_handle_progress(self, dispatcher):
"""测试处理进度上报"""
if not check_db():
pytest.skip("数据库不可用,跳过此测试")
from app.agent_framework.protocol import TaskProgress
# 创建一个假的progress对象
progress = TaskProgress(
task_id=str(uuid.uuid4()),
agent_name="non_existent",
progress=0.5,
message="测试进度",
updated_at=datetime.now(timezone.utc),
)
# 不应抛出异常
await dispatcher.handle_progress(progress)
@pytest.mark.asyncio
async def test_close_dispatcher(self, dispatcher):
"""测试关闭分发器"""
if not check_redis():
pytest.skip("Redis不可用跳过此测试")
# 先获取redis连接
await dispatcher._get_redis()
assert dispatcher._redis is not None
# 关闭
await dispatcher.close()
assert dispatcher._redis is None
@pytest.mark.asyncio
async def test_dispatch_and_query_flow(self, dispatcher):
"""测试完整分发和查询流程"""
if not check_db() or not check_redis():
pytest.skip("数据库或Redis不可用跳过此测试")
# 1. 注册一个测试Agent
registry = AgentRegistry()
agent_name = f"test_dispatch_agent_{uuid.uuid4().hex[:8]}"
capability = AgentCapability(
agent_name=agent_name,
agent_type="test_type",
version="1.0.0",
supported_tasks=["test_task"],
max_concurrency=3,
description="测试Agent",
)
await registry.register(capability, endpoint=f"agent:{agent_name}")
# 2. 尝试分发任务
task = TaskMessage(
task_id=str(uuid.uuid4()),
agent_name=agent_name,
task_type="test_task",
priority=5,
input_data={"test": "data"},
callback_url=None,
created_at=datetime.now(timezone.utc),
)
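# The agent is registered but may not be online, so a dispatch failure is tolerated.
# The assertion sits in the else branch: inside the try block, "except Exception"
# would swallow an AssertionError and the test could never fail.
try:
task_id = await dispatcher.dispatch(task)
except Exception:
# The agent may be offline; this is expected
pass
else:
assert task_id is not None
finally:
# Clean up even if the assertion fails
await registry.unregister(agent_name)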
@pytest.mark.asyncio
async def test_retry_failed_tasks_empty(self, dispatcher):
"""测试重试失败任务(无失败任务)"""
if not check_db():
pytest.skip("数据库不可用,跳过此测试")
result = await dispatcher.retry_failed_tasks(max_retries=3)
# 无失败任务时不应抛出异常
assert result is None or result == [] or isinstance(result, int)


@ -0,0 +1,228 @@
"""Agent注册表测试"""
import pytest
import uuid
from app.agent_framework.registry import AgentRegistry
from app.agent_framework.protocol import AgentCapability, AgentStatus
def is_database_available():
"""检查数据库是否可用(同步方式)"""
try:
# 直接尝试同步方式检查
from sqlalchemy import create_engine, text
from app.config import settings
# 从URL创建同步引擎进行测试
sync_url = settings.DATABASE_URL.replace('+aiosqlite', '').replace('+asyncpg', '')
if 'sqlite' in sync_url:
engine = create_engine(sync_url)
else:
engine = create_engine(sync_url, connect_args={"connect_timeout": 1})
with engine.connect() as conn:
conn.execute(text("SELECT 1"))
engine.dispose()
return True
except Exception:
return False
# 检查服务是否可用
_db_available = None
def check_db():
global _db_available
if _db_available is None:
try:
_db_available = is_database_available()
except Exception:
_db_available = False
return _db_available
class TestAgentRegistry:
"""Agent注册表测试"""
@pytest.mark.asyncio
async def test_register_agent(self):
"""测试Agent注册"""
if not check_db():
pytest.skip("数据库不可用,跳过此测试")
registry = AgentRegistry()
agent_name = f"test_agent_{uuid.uuid4().hex[:8]}"
capability = AgentCapability(
agent_name=agent_name,
agent_type="citation_detector",
version="1.0.0",
supported_tasks=["citation_detect", "citation_detect_single"],
max_concurrency=3,
description="测试Agent",
)
agent_id = await registry.register(capability, endpoint=f"agent:{agent_name}")
# 验证注册成功
assert agent_id is not None
assert len(agent_id) > 0
# 清理
await registry.unregister(agent_name)
@pytest.mark.asyncio
async def test_get_registered_agent(self):
"""测试获取已注册的Agent"""
if not check_db():
pytest.skip("数据库不可用,跳过此测试")
registry = AgentRegistry()
agent_name = f"test_agent_{uuid.uuid4().hex[:8]}"
capability = AgentCapability(
agent_name=agent_name,
agent_type="citation_detector",
version="1.0.0",
supported_tasks=["citation_detect"],
max_concurrency=3,
description="测试Agent",
)
await registry.register(capability, endpoint=f"agent:{agent_name}")
retrieved = await registry.get_agent(agent_name)
assert retrieved is not None
assert retrieved["name"] == agent_name
assert retrieved["agent_type"] == "citation_detector"
# 清理
await registry.unregister(agent_name)
@pytest.mark.asyncio
async def test_list_all_agents(self):
"""测试列出所有Agent"""
if not check_db():
pytest.skip("数据库不可用,跳过此测试")
registry = AgentRegistry()
agents = await registry.list_agents()
assert isinstance(agents, list)
@pytest.mark.asyncio
async def test_unregister_agent(self):
"""测试取消注册"""
if not check_db():
pytest.skip("数据库不可用,跳过此测试")
registry = AgentRegistry()
agent_name = f"test_agent_{uuid.uuid4().hex[:8]}"
capability = AgentCapability(
agent_name=agent_name,
agent_type="citation_detector",
version="1.0.0",
supported_tasks=["citation_detect"],
max_concurrency=3,
description="测试Agent",
)
await registry.register(capability, endpoint=f"agent:{agent_name}")
await registry.unregister(agent_name)
retrieved = await registry.get_agent(agent_name)
# 注销后状态应为OFFLINE或None
assert retrieved is None or retrieved["status"] == AgentStatus.OFFLINE
@pytest.mark.asyncio
async def test_update_heartbeat(self):
"""测试心跳更新"""
if not check_db():
pytest.skip("数据库不可用,跳过此测试")
registry = AgentRegistry()
agent_name = f"test_agent_{uuid.uuid4().hex[:8]}"
capability = AgentCapability(
agent_name=agent_name,
agent_type="citation_detector",
version="1.0.0",
supported_tasks=["citation_detect"],
max_concurrency=3,
description="测试Agent",
)
await registry.register(capability, endpoint=f"agent:{agent_name}")
await registry.update_heartbeat(agent_name)
agent = await registry.get_agent(agent_name)
assert agent is not None
assert agent["last_heartbeat"] is not None
# 清理
await registry.unregister(agent_name)
@pytest.mark.asyncio
async def test_get_available_agent(self):
"""测试根据任务类型获取可用Agent"""
if not check_db():
pytest.skip("数据库不可用,跳过此测试")
registry = AgentRegistry()
agent_name = f"test_agent_{uuid.uuid4().hex[:8]}"
capability = AgentCapability(
agent_name=agent_name,
agent_type="citation_detector",
version="1.0.0",
supported_tasks=["citation_detect", "citation_detect_single"],
max_concurrency=3,
description="测试Agent",
)
await registry.register(capability, endpoint=f"agent:{agent_name}")
available = await registry.get_available_agent("citation_detect")
# 可能返回None因为Agent状态不是ONLINE
# 这里只验证方法能正常执行
assert available is None or isinstance(available, str)
# 清理
await registry.unregister(agent_name)
@pytest.mark.asyncio
async def test_agent_reregistration(self):
"""测试Agent重复注册应该更新而非报错"""
if not check_db():
pytest.skip("数据库不可用,跳过此测试")
registry = AgentRegistry()
agent_name = f"test_agent_{uuid.uuid4().hex[:8]}"
capability1 = AgentCapability(
agent_name=agent_name,
agent_type="citation_detector",
version="1.0.0",
supported_tasks=["citation_detect"],
max_concurrency=3,
description="测试AgentV1",
)
capability2 = AgentCapability(
agent_name=agent_name,
agent_type="citation_detector",
version="2.0.0",
supported_tasks=["citation_detect", "new_task"],
max_concurrency=5,
description="测试AgentV2",
)
await registry.register(capability1, endpoint=f"agent:{agent_name}")
# 验证第一次注册成功
agent_data = await registry.get_agent(agent_name)
assert agent_data is not None
# 重新注册同名Agent
agent_id2 = await registry.register(capability2, endpoint=f"agent:{agent_name}")
# 应该成功且不报错
assert agent_id2 is not None
# 清理
await registry.unregister(agent_name)


@ -0,0 +1,224 @@
"""Agent集成测试"""
import pytest
import uuid
from datetime import datetime, timezone
from app.agent_framework.agents.citation_detector import CitationDetectorAgent
from app.agent_framework.agents.content_generator_agent import ContentGeneratorAgent
from app.agent_framework.protocol import (
TaskMessage,
TaskStatus,
)
class TestCitationDetectorAgent:
"""引用检测Agent测试"""
def test_agent_initialization(self):
"""测试Agent初始化"""
agent = CitationDetectorAgent()
assert agent.name == "citation_detector"
assert agent.agent_type.value == "citation_detector"
assert agent.version == "1.0.0"
def test_get_capabilities(self):
"""测试获取Agent能力"""
agent = CitationDetectorAgent()
capability = agent.get_capabilities()
assert capability.agent_name == "citation_detector"
assert "citation_detect" in capability.supported_tasks
assert "citation_detect_single" in capability.supported_tasks
assert capability.max_concurrency == 3
@pytest.mark.asyncio
async def test_execute_with_invalid_task_type(self):
"""测试执行不支持的任务类型"""
agent = CitationDetectorAgent()
task = TaskMessage(
task_id=str(uuid.uuid4()),
agent_name="citation_detector",
task_type="invalid_task_type",
priority=5,
input_data={},
callback_url=None,
created_at=datetime.now(timezone.utc),
)
result = await agent.execute(task)
assert result.status == TaskStatus.FAILED
assert result.error_message is not None
assert "Unsupported task type" in result.error_message
@pytest.mark.asyncio
async def test_execute_single_detect_missing_params(self):
"""测试单平台检测缺少必需参数"""
agent = CitationDetectorAgent()
task = TaskMessage(
task_id=str(uuid.uuid4()),
agent_name="citation_detector",
task_type="citation_detect_single",
priority=5,
input_data={}, # 缺少keyword, platform, target_brand
callback_url=None,
created_at=datetime.now(timezone.utc),
)
result = await agent.execute(task)
assert result.status == TaskStatus.FAILED
assert result.error_message is not None
@pytest.mark.asyncio
async def test_execute_full_detect_missing_query_id(self):
"""测试完整检测缺少query_id"""
agent = CitationDetectorAgent()
task = TaskMessage(
task_id=str(uuid.uuid4()),
agent_name="citation_detector",
task_type="citation_detect",
priority=5,
input_data={}, # 缺少query_id
callback_url=None,
created_at=datetime.now(timezone.utc),
)
result = await agent.execute(task)
assert result.status == TaskStatus.FAILED
assert "query_id" in result.error_message or "must contain" in result.error_message
def test_compatibility_methods_exist(self):
"""测试向后兼容方法存在"""
agent = CitationDetectorAgent()
assert hasattr(agent, 'execute_query_compat')
assert hasattr(agent, 'execute_single_platform_compat')
assert callable(agent.execute_query_compat)
assert callable(agent.execute_single_platform_compat)
class TestContentGeneratorAgent:
"""内容生成Agent测试"""
def test_agent_initialization(self):
"""测试Agent初始化"""
agent = ContentGeneratorAgent()
assert agent.name == "content_generator"
assert agent.agent_type.value == "content_generator"
assert agent.version == "1.0.0"
def test_get_capabilities(self):
"""测试获取Agent能力"""
agent = ContentGeneratorAgent()
capability = agent.get_capabilities()
assert capability.agent_name == "content_generator"
assert "generate_topics" in capability.supported_tasks
assert "generate_article" in capability.supported_tasks
assert capability.max_concurrency == 2
@pytest.mark.asyncio
async def test_execute_with_invalid_task_type(self):
"""测试执行不支持的任务类型"""
agent = ContentGeneratorAgent()
task = TaskMessage(
task_id=str(uuid.uuid4()),
agent_name="content_generator",
task_type="invalid_task_type",
priority=5,
input_data={},
callback_url=None,
created_at=datetime.now(timezone.utc),
)
result = await agent.execute(task)
assert result.status == TaskStatus.FAILED
assert "Unsupported task type" in result.error_message
@pytest.mark.asyncio
async def test_generate_topics_missing_keyword(self):
"""测试生成选题缺少关键词"""
agent = ContentGeneratorAgent()
task = TaskMessage(
task_id=str(uuid.uuid4()),
agent_name="content_generator",
task_type="generate_topics",
priority=5,
input_data={}, # 缺少target_keyword
callback_url=None,
created_at=datetime.now(timezone.utc),
)
# 由于没有真实LLM调用和知识库这个测试会调用LLM
# 我们主要验证方法能正常执行
result = await agent.execute(task)
# 结果可能是FAILED因为缺少必要参数或LLM调用失败
assert result is not None
assert result.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]
@pytest.mark.asyncio
async def test_generate_article_missing_keyword(self):
"""测试生成文章缺少关键词"""
agent = ContentGeneratorAgent()
task = TaskMessage(
task_id=str(uuid.uuid4()),
agent_name="content_generator",
task_type="generate_article",
priority=5,
input_data={}, # 缺少target_keyword
callback_url=None,
created_at=datetime.now(timezone.utc),
)
result = await agent.execute(task)
assert result is not None
# 缺少必要参数可能导致失败
assert result.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]
    def test_extract_json_method(self):
        """Test the JSON extraction method."""
        agent = ContentGeneratorAgent()
        # Plain JSON
        json_text = '{"title": "测试标题", "reason": "测试原因"}'
        extracted = agent._extract_json(json_text)
        assert "title" in extracted
        # JSON wrapped in a markdown code fence
        md_text = '```json\n{"title": "测试"}\n```'
        extracted = agent._extract_json(md_text)
        assert "title" in extracted
class TestAgentProtocol:
    """Agent protocol tests."""

    def test_agent_type_enum_values(self):
        """Test the AgentType enum values."""
        from app.agent_framework.protocol import AgentType
        assert AgentType.CITATION_DETECTOR.value == "citation_detector"
        assert AgentType.CONTENT_GENERATOR.value == "content_generator"

    def test_task_status_enum_values(self):
        """Test the TaskStatus enum values."""
        assert TaskStatus.PENDING.value == "pending"
        assert TaskStatus.RUNNING.value == "running"
        assert TaskStatus.COMPLETED.value == "completed"
        assert TaskStatus.FAILED.value == "failed"
        assert TaskStatus.CANCELLED.value == "cancelled"

    def test_agent_status_enum_values(self):
        """Test the AgentStatus enum values."""
        from app.agent_framework.protocol import AgentStatus
        assert AgentStatus.ONLINE.value == "online"
        assert AgentStatus.OFFLINE.value == "offline"
        assert AgentStatus.BUSY.value == "busy"
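The `test_extract_json_method` test above exercises `_extract_json` on plain JSON and on JSON wrapped in a markdown code fence, but the method itself is not part of this diff. The following is a minimal sketch of how such a helper could behave. The function name `extract_json`, the regular expression, and the decision to return a parsed object are all assumptions for illustration, not the agent's actual implementation.

```python
import json
import re
from typing import Any

# Hypothetical stand-in for ContentGeneratorAgent._extract_json.
# The fence marker is built at runtime so this listing stays readable.
FENCE = "`" * 3
_FENCE_RE = re.compile(FENCE + r"(?:json)?\s*(.*?)\s*" + FENCE, re.DOTALL)


def extract_json(text: str) -> Any:
    """Parse JSON from plain text or from inside a markdown code fence."""
    match = _FENCE_RE.search(text)
    # Fall back to the whole text when no fence is present.
    payload = match.group(1) if match else text.strip()
    return json.loads(payload)


plain = extract_json('{"title": "测试标题", "reason": "测试原因"}')
fenced = extract_json(FENCE + 'json\n{"title": "测试"}\n' + FENCE)
```

Both calls return a dictionary, so a membership check such as `"title" in extracted` holds in either case.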

View File

@ -0,0 +1,329 @@
"""Alert settings API tests."""
import uuid
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.main import app
from app.models.user import User
from app.models.brand import Brand
from app.models.alert_setting import AlertSetting
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def async_engine():
    """Create an async SQLite engine for tests."""
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest_asyncio.fixture
async def async_session(async_engine):
    """Create an async database session for tests."""
    async_session_maker = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
    async with async_session_maker() as session:
        yield session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Create a test user."""
    user = User(
        id=uuid.uuid4(),
        email="test@example.com",
        password_hash=hash_password("Test@123456"),
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    async_session.add(user)
    await async_session.commit()
    await async_session.refresh(user)
    return user


@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Create a test brand."""
    brand = Brand(
        id=uuid.uuid4(),
        user_id=test_user.id,
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    async_session.add(brand)
    await async_session.commit()
    await async_session.refresh(brand)
    return brand


@pytest_asyncio.fixture
async def test_alert_setting(async_session, test_user, test_brand):
    """Create a test alert setting."""
    setting = AlertSetting(
        id=uuid.uuid4(),
        brand_id=test_brand.id,
        user_id=test_user.id,
        alert_type="score_drop",
        enabled=True,
        threshold=5.0,
    )
    async_session.add(setting)
    await async_session.commit()
    await async_session.refresh(setting)
    return setting
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Create an async HTTP client for API tests."""
    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        yield client
    app.dependency_overrides.clear()


@pytest.fixture
def auth_headers(test_user):
    """Build the authentication request headers."""
    token = create_access_token(data={"sub": str(test_user.id)})
    return {"Authorization": f"Bearer {token}"}


# ==================== Test classes ====================
class TestAlertSettingsAPI:
    """Alert settings API tests."""

    @pytest.mark.asyncio
    async def test_get_alert_settings_success(self, async_client, test_alert_setting):
        """Test fetching alert settings: the settings list is returned."""
        response = await async_client.get("/api/v1/alerts/settings")
        assert response.status_code == 200
        data = response.json()
        assert "items" in data
        assert "total" in data
        assert data["total"] >= 1
        assert len(data["items"]) >= 1
        first_item = data["items"][0]
        assert "id" in first_item
        assert "brand_id" in first_item
        assert "alert_type" in first_item
        assert "enabled" in first_item
        assert "threshold" in first_item

    @pytest.mark.asyncio
    async def test_get_alert_settings_by_brand(self, async_client, test_brand, test_alert_setting):
        """Test filtering alert settings by brand."""
        response = await async_client.get(
            f"/api/v1/alerts/settings?brand_id={test_brand.id}"
        )
        assert response.status_code == 200
        data = response.json()
        assert data["total"] >= 1
        for item in data["items"]:
            assert item["brand_id"] == str(test_brand.id)
    @pytest.mark.asyncio
    async def test_update_alert_settings_success(self, async_client, test_brand):
        """Test updating alert settings: the update succeeds and returns the new settings."""
        update_data = {
            "settings": [
                {
                    "brand_id": str(test_brand.id),
                    "alert_type": "score_drop",
                    "enabled": True,
                    "threshold": 20.0,
                }
            ]
        }
        response = await async_client.put(
            "/api/v1/alerts/settings", json=update_data
        )
        assert response.status_code == 200
        data = response.json()
        assert "items" in data
        assert len(data["items"]) >= 1
        updated_setting = data["items"][0]
        assert updated_setting["alert_type"] == "score_drop"
        assert updated_setting["threshold"] == 20.0
        assert updated_setting["enabled"] is True

    @pytest.mark.asyncio
    async def test_create_alert_setting_success(self, async_client, test_brand):
        """Test creating an alert setting: create a default setting for a new brand."""
        create_data = {
            "brand_id": str(test_brand.id),
            "alert_type": "negative_sentiment",
            "enabled": True,
            "threshold": 1.0,
        }
        response = await async_client.post(
            "/api/v1/alerts/settings", json=create_data
        )
        assert response.status_code == 201
        data = response.json()
        assert data["alert_type"] == "negative_sentiment"
        assert data["threshold"] == 1.0
        assert data["enabled"] is True
        assert "id" in data
    @pytest.mark.asyncio
    async def test_delete_alert_setting_success(self, async_client, test_alert_setting):
        """Test deleting an alert setting: the deletion succeeds."""
        response = await async_client.delete(
            f"/api/v1/alerts/settings/{test_alert_setting.id}"
        )
        assert response.status_code == 204
        get_response = await async_client.get("/api/v1/alerts/settings")
        data = get_response.json()
        deleted_ids = [item["id"] for item in data["items"]]
        assert str(test_alert_setting.id) not in deleted_ids

    @pytest.mark.asyncio
    async def test_delete_alert_setting_not_found(self, async_client):
        """Test deleting an alert setting that does not exist."""
        non_existent_id = uuid.uuid4()
        response = await async_client.delete(
            f"/api/v1/alerts/settings/{non_existent_id}"
        )
        assert response.status_code == 404
class TestAlertSettingsValidation:
    """Alert settings validation tests."""

    @pytest.mark.asyncio
    async def test_brand_not_found_returns_404(self, async_client):
        """Test that a nonexistent brand returns 404."""
        non_existent_brand_id = uuid.uuid4()
        create_data = {
            "brand_id": str(non_existent_brand_id),
            "alert_type": "score_drop",
            "enabled": True,
            "threshold": 5.0,
        }
        response = await async_client.post(
            "/api/v1/alerts/settings", json=create_data
        )
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_unauthorized_returns_401(self, async_session):
        """Test that an unauthenticated request returns 401."""
        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            headers = {"Authorization": "Bearer invalid_token"}
            response = await client.get(
                "/api/v1/alerts/settings",
                headers=headers
            )
            assert response.status_code == 401
        app.dependency_overrides.clear()
    @pytest.mark.asyncio
    async def test_invalid_alert_type_returns_422(self, async_client, test_brand):
        """Test that an invalid alert type returns 422."""
        create_data = {
            "brand_id": str(test_brand.id),
            "alert_type": "invalid_type",
            "enabled": True,
            "threshold": 5.0,
        }
        response = await async_client.post(
            "/api/v1/alerts/settings", json=create_data
        )
        assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_negative_threshold_returns_422(self, async_client, test_brand):
        """Test that a negative threshold returns 422."""
        create_data = {
            "brand_id": str(test_brand.id),
            "alert_type": "score_drop",
            "enabled": True,
            "threshold": -10.0,
        }
        response = await async_client.post(
            "/api/v1/alerts/settings", json=create_data
        )
        assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_threshold_over_100_returns_422(self, async_client, test_brand):
        """Test that a threshold above 100 returns 422."""
        create_data = {
            "brand_id": str(test_brand.id),
            "alert_type": "score_drop",
            "enabled": True,
            "threshold": 150.0,
        }
        response = await async_client.post(
            "/api/v1/alerts/settings", json=create_data
        )
        assert response.status_code == 422
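The validation tests above imply two input rules: `alert_type` must come from a fixed set, and `threshold` must not be negative or exceed 100. The sketch below restates those rules with the standard library only. The class name `AlertSettingInput`, the helper `is_valid`, the inclusive bounds at 0 and 100, and the contents of `ALLOWED_ALERT_TYPES` (limited to the two types that appear in these tests) are assumptions; the real request schema is not shown in this diff.

```python
from dataclasses import dataclass

# Assumed set: only these two alert types appear in the tests above.
ALLOWED_ALERT_TYPES = {"score_drop", "negative_sentiment"}


@dataclass(frozen=True)
class AlertSettingInput:
    """Hypothetical request payload for creating an alert setting."""

    brand_id: str
    alert_type: str
    enabled: bool
    threshold: float

    def __post_init__(self) -> None:
        # Reject unknown alert types (the API answers 422 in the tests).
        if self.alert_type not in ALLOWED_ALERT_TYPES:
            raise ValueError(f"unsupported alert_type: {self.alert_type!r}")
        # Reject thresholds outside the assumed inclusive range 0..100.
        if not 0.0 <= self.threshold <= 100.0:
            raise ValueError("threshold must be between 0 and 100")


def is_valid(**fields) -> bool:
    """Return True when the payload passes both validation rules."""
    try:
        AlertSettingInput(**fields)
    except ValueError:
        return False
    return True
```

For example, `is_valid(brand_id="b", alert_type="score_drop", enabled=True, threshold=150.0)` returns `False`, which mirrors the 422 case in `test_threshold_over_100_returns_422`.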

View File

@ -0,0 +1,451 @@
"""Authentication API tests."""
import uuid
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.main import app
from app.models.user import User
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def async_engine():
    """Create an async SQLite engine for tests."""
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest_asyncio.fixture
async def async_session(async_engine):
    """Create an async database session for tests."""
    async_session_maker = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
    async with async_session_maker() as session:
        yield session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Create a test user."""
    user = User(
        id=uuid.uuid4(),
        email="test@example.com",
        password_hash=hash_password("Test@123456"),
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    async_session.add(user)
    await async_session.commit()
    await async_session.refresh(user)
    return user
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Create an async HTTP client for API tests."""
    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        yield client
    app.dependency_overrides.clear()


@pytest.fixture
def auth_headers(test_user):
    """Build the authentication request headers."""
    token = create_access_token(data={"sub": str(test_user.id)})
    return {"Authorization": f"Bearer {token}"}


# ==================== Test classes ====================
class TestAuthAPI:
    """Authentication API tests."""

    @pytest.mark.asyncio
    async def test_register_success(self, async_session):
        """Test successful user registration."""
        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/v1/auth/register",
                json={
                    "email": f"test_{uuid.uuid4()}@example.com",
                    "name": "Test User",
                    "password": "Test@123456"
                }
            )
            assert response.status_code == 201
            data = response.json()
            assert "id" in data
            assert data["email"] is not None
            assert data["name"] == "Test User"
        app.dependency_overrides.clear()
    @pytest.mark.asyncio
    async def test_register_duplicate_email(self, async_session):
        """Test that registering with a duplicate email fails."""
        email = f"test_{uuid.uuid4()}@example.com"

        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            # First registration
            response1 = await client.post(
                "/api/v1/auth/register",
                json={
                    "email": email,
                    "name": "Test User 1",
                    "password": "Test@123456"
                }
            )
            assert response1.status_code == 201
            # Second registration with the same email
            response2 = await client.post(
                "/api/v1/auth/register",
                json={
                    "email": email,
                    "name": "Test User 2",
                    "password": "Test@123456"
                }
            )
            assert response2.status_code == 400
        app.dependency_overrides.clear()
    @pytest.mark.asyncio
    async def test_login_success(self, async_session):
        """Test successful user login."""
        email = f"test_{uuid.uuid4()}@example.com"
        password = "Test@123456"
        # Create the user first
        user = User(
            id=uuid.uuid4(),
            email=email,
            password_hash=hash_password(password),
            name="Test User",
            plan="free",
            max_queries=5,
            is_active=True,
            email_verified=True,
        )
        async_session.add(user)
        await async_session.commit()

        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/v1/auth/login",
                json={
                    "email": email,
                    "password": password
                }
            )
            assert response.status_code == 200
            data = response.json()
            assert "access_token" in data
            assert "refresh_token" in data
            assert data["token_type"] == "bearer"
            assert "user" in data
        app.dependency_overrides.clear()
    @pytest.mark.asyncio
    async def test_login_wrong_password(self, async_session):
        """Test login with a wrong password."""
        email = f"test_{uuid.uuid4()}@example.com"
        # Create the user
        user = User(
            id=uuid.uuid4(),
            email=email,
            password_hash=hash_password("Correct@123456"),
            name="Test User",
            plan="free",
            max_queries=5,
            is_active=True,
            email_verified=True,
        )
        async_session.add(user)
        await async_session.commit()

        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/v1/auth/login",
                json={
                    "email": email,
                    "password": "WrongPassword"
                }
            )
            assert response.status_code == 401
        app.dependency_overrides.clear()
    @pytest.mark.asyncio
    async def test_login_nonexistent_user(self, async_session):
        """Test login for a user that does not exist."""
        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/v1/auth/login",
                json={
                    "email": "nonexistent@example.com",
                    "password": "Test@123456"
                }
            )
            # The error message is uniform: it does not distinguish a
            # nonexistent user from a wrong password.
            # Either 401 (authentication failed) or 429 (rate limited) is a valid response.
            assert response.status_code in [401, 429]
        app.dependency_overrides.clear()
    @pytest.mark.asyncio
    async def test_get_current_user(self, async_client, test_user):
        """Test fetching the current user's information."""
        response = await async_client.get("/api/v1/auth/me")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == str(test_user.id)
        assert data["email"] == test_user.email

    @pytest.mark.asyncio
    async def test_change_password_success(self, async_client, async_session, test_user):
        """Test a successful password change."""
        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            # Authenticate the user with a token
            token = create_access_token(data={"sub": str(test_user.id)})
            headers = {"Authorization": f"Bearer {token}"}
            response = await client.put(
                "/api/v1/auth/change-password",
                headers=headers,
                json={
                    "old_password": "Test@123456",
                    "new_password": "NewPass@123456"
                }
            )
            assert response.status_code == 200
            data = response.json()
            assert data["message"] == "密码修改成功"
        app.dependency_overrides.clear()
    @pytest.mark.asyncio
    async def test_change_password_wrong_old(self, async_client, test_user):
        """Test a password change with a wrong old password."""
        token = create_access_token(data={"sub": str(test_user.id)})
        headers = {"Authorization": f"Bearer {token}"}
        response = await async_client.put(
            "/api/v1/auth/change-password",
            headers=headers,
            json={
                "old_password": "WrongOldPass",
                "new_password": "NewPass@123456"
            }
        )
        assert response.status_code == 400
        assert "旧密码错误" in response.json()["detail"]

    @pytest.mark.asyncio
    async def test_update_profile(self, async_client, test_user):
        """Test updating the user profile."""
        token = create_access_token(data={"sub": str(test_user.id)})
        headers = {"Authorization": f"Bearer {token}"}
        response = await async_client.put(
            "/api/v1/auth/profile",
            headers=headers,
            json={
                "name": "Updated Name",
                "avatar_url": "https://example.com/avatar.png"
            }
        )
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "Updated Name"
    @pytest.mark.asyncio
    async def test_refresh_token(self, async_session, test_user):
        """Test refreshing the token."""
        from app.services.auth import create_refresh_token

        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        # Create a refresh token
        refresh_token = create_refresh_token(data={"sub": str(test_user.id)})
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/v1/auth/refresh",
                json={
                    "refresh_token": refresh_token
                }
            )
            assert response.status_code == 200
            data = response.json()
            assert "access_token" in data
            assert "refresh_token" in data
        app.dependency_overrides.clear()
    @pytest.mark.asyncio
    async def test_forgot_password(self, async_session):
        """Test a forgot-password request."""
        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/v1/auth/forgot-password",
                json={
                    "email": "test@example.com"
                }
            )
            # Success is returned whether or not the email exists,
            # which prevents user enumeration.
            assert response.status_code == 200
            assert "message" in response.json()
        app.dependency_overrides.clear()

    @pytest.mark.asyncio
    async def test_resend_verification(self, async_session):
        """Test resending the verification code."""
        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.post(
                "/api/v1/auth/resend-verification",
                json={
                    "email": "test@example.com"
                }
            )
            assert response.status_code == 200
            assert "message" in response.json()
        app.dependency_overrides.clear()
    @pytest.mark.asyncio
    async def test_unauthorized_access(self, async_session):
        """Test unauthorized access to a protected endpoint."""
        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            # Access a protected endpoint without a token
            response = await client.get("/api/v1/auth/me")
            assert response.status_code == 401
        app.dependency_overrides.clear()

    @pytest.mark.asyncio
    async def test_invalid_token(self, async_session):
        """Test access with an invalid token."""
        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            response = await client.get(
                "/api/v1/auth/me",
                headers={"Authorization": "Bearer invalid_token"}
            )
            assert response.status_code == 401
        app.dependency_overrides.clear()
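Several tests above set `app.dependency_overrides[get_db]` by hand and call `app.dependency_overrides.clear()` as their last statement. If an assertion fails first, that cleanup line never runs and the override can leak into later tests. One possible way to guard against this is a small context manager that restores the registry in a `finally` block. The sketch below is an assumption-laden illustration: `dependency_overrides` is a hypothetical helper name, and a plain dictionary stands in for `app.dependency_overrides` so the example runs without FastAPI.

```python
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator

Dependency = Callable[..., Any]


@contextmanager
def dependency_overrides(
    registry: Dict[Dependency, Dependency],
    overrides: Dict[Dependency, Dependency],
) -> Iterator[None]:
    """Apply overrides, then restore the previous registry state on exit."""
    saved = dict(registry)
    registry.update(overrides)
    try:
        yield
    finally:
        # Runs even when the body raises, e.g. on a failed assertion.
        registry.clear()
        registry.update(saved)


# Stand-ins for app.dependency_overrides and get_db (hypothetical here).
registry: Dict[Dependency, Dependency] = {}


def get_db() -> str:
    return "real"


def fake_get_db() -> str:
    return "fake"


try:
    with dependency_overrides(registry, {get_db: fake_get_db}):
        assert registry[get_db] is fake_get_db
        raise AssertionError("simulated failing test")
except AssertionError:
    pass

leaked = get_db in registry
```

In a real test the first argument would be `app.dependency_overrides`, which FastAPI exposes as a dictionary, and the HTTP requests and assertions would go inside the `with` block.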

View File

@ -0,0 +1,321 @@
"""Brand API tests."""
import uuid
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.main import app
from app.models.user import User
from app.models.brand import Brand
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def async_engine():
    """Create an async SQLite engine for tests."""
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest_asyncio.fixture
async def async_session(async_engine):
    """Create an async database session for tests."""
    async_session_maker = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
    async with async_session_maker() as session:
        yield session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Create a test user."""
    user = User(
        id=uuid.uuid4(),
        email="test@example.com",
        password_hash=hash_password("Test@123456"),
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    async_session.add(user)
    await async_session.commit()
    await async_session.refresh(user)
    return user


@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Create a test brand."""
    brand = Brand(
        id=uuid.uuid4(),
        user_id=test_user.id,
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    async_session.add(brand)
    await async_session.commit()
    await async_session.refresh(brand)
    return brand
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Create an async HTTP client for API tests."""
    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        yield client
    app.dependency_overrides.clear()


@pytest.fixture
def auth_headers(test_user):
    """Build the authentication request headers."""
    token = create_access_token(data={"sub": str(test_user.id)})
    return {"Authorization": f"Bearer {token}"}


# ==================== Test classes ====================
class TestBrandsAPI:
    """Brand API tests."""

    @pytest.mark.asyncio
    async def test_list_brands_empty(self, async_client):
        """Test fetching an empty brand list."""
        response = await async_client.get("/api/v1/brands/")
        assert response.status_code == 200
        data = response.json()
        assert data["items"] == []
        assert data["total"] == 0

    @pytest.mark.asyncio
    async def test_create_brand(self, async_client, async_session, test_user):
        """Test creating a brand."""
        brand_data = {
            "name": "华为",
            "aliases": ["Huawei", "HW"],
            "website": "https://www.huawei.com",
            "industry": "technology",
            "platforms": ["wenxin", "kimi", "doubao"],
            "frequency": "daily",
        }
        response = await async_client.post("/api/v1/brands/", json=brand_data)
        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "华为"
        assert data["aliases"] == ["Huawei", "HW"]
        assert data["website"] == "https://www.huawei.com"
        assert data["industry"] == "technology"
        assert data["platforms"] == ["wenxin", "kimi", "doubao"]
        assert data["frequency"] == "daily"
        assert data["status"] == "active"
        assert data["user_id"] == str(test_user.id)
        assert "id" in data

    @pytest.mark.asyncio
    async def test_create_brand_minimal(self, async_client):
        """Test creating a brand with minimal data."""
        brand_data = {
            "name": "minimal_brand",
        }
        response = await async_client.post("/api/v1/brands/", json=brand_data)
        assert response.status_code == 201
        data = response.json()
        assert data["name"] == "minimal_brand"
        assert data["aliases"] == []
        assert data["platforms"] == ["wenxin", "kimi"]  # default value
    @pytest.mark.asyncio
    async def test_list_brands(self, async_client, async_session, test_user):
        """Test listing multiple brands."""
        # Create several brands
        for i in range(3):
            brand = Brand(
                user_id=test_user.id,
                name=f"Brand {i}",
                platforms=["wenxin"],
            )
            async_session.add(brand)
        await async_session.commit()
        response = await async_client.get("/api/v1/brands/")
        assert response.status_code == 200
        data = response.json()
        assert len(data["items"]) == 3
        assert data["total"] == 3

    @pytest.mark.asyncio
    async def test_get_brand_by_id(self, async_client, test_brand):
        """Test fetching a brand by ID."""
        response = await async_client.get(f"/api/v1/brands/{test_brand.id}/")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == str(test_brand.id)
        assert data["name"] == "Test Brand"
        assert data["aliases"] == ["TestBrand", "TB"]

    @pytest.mark.asyncio
    async def test_get_brand_not_found(self, async_client):
        """Test fetching a brand that does not exist."""
        non_existent_id = uuid.uuid4()
        response = await async_client.get(f"/api/v1/brands/{non_existent_id}/")
        assert response.status_code == 404
        assert "品牌不存在" in response.json()["detail"]
    @pytest.mark.asyncio
    async def test_update_brand(self, async_client, test_brand):
        """Test updating a brand."""
        # Note: the BrandUpdate schema does not allow updating the name field.
        update_data = {
            "aliases": ["Updated", "Alias"],
            "frequency": "daily",
        }
        response = await async_client.put(
            f"/api/v1/brands/{test_brand.id}/", json=update_data
        )
        assert response.status_code == 200
        data = response.json()
        assert data["aliases"] == ["Updated", "Alias"]
        assert data["frequency"] == "daily"
        assert data["name"] == "Test Brand"  # name stays unchanged

    @pytest.mark.asyncio
    async def test_update_brand_partial(self, async_client, test_brand):
        """Test partially updating a brand."""
        update_data = {
            "frequency": "monthly",
        }
        response = await async_client.put(
            f"/api/v1/brands/{test_brand.id}/", json=update_data
        )
        assert response.status_code == 200
        data = response.json()
        # Only frequency is updated; the other fields stay unchanged.
        assert data["frequency"] == "monthly"
        assert data["name"] == "Test Brand"
        assert data["aliases"] == ["TestBrand", "TB"]

    @pytest.mark.asyncio
    async def test_update_brand_not_found(self, async_client):
        """Test updating a brand that does not exist."""
        non_existent_id = uuid.uuid4()
        response = await async_client.put(
            f"/api/v1/brands/{non_existent_id}/", json={"name": "New Name"}
        )
        assert response.status_code == 404
    @pytest.mark.asyncio
    async def test_delete_brand(self, async_client, test_brand):
        """Test deleting a brand."""
        response = await async_client.delete(f"/api/v1/brands/{test_brand.id}/")
        assert response.status_code == 204
        # Verify that the brand has been deleted
        get_response = await async_client.get(f"/api/v1/brands/{test_brand.id}/")
        assert get_response.status_code == 404

    @pytest.mark.asyncio
    async def test_delete_brand_not_found(self, async_client):
        """Test deleting a brand that does not exist."""
        non_existent_id = uuid.uuid4()
        response = await async_client.delete(f"/api/v1/brands/{non_existent_id}/")
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_brand_unauthorized_access(self, async_session):
        """Test unauthorized access to the brand API (invalid token)."""
        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            # Request the brand list with an invalid token
            headers = {"Authorization": "Bearer invalid_token"}
            response = await client.get(
                "/api/v1/brands/",
                headers=headers
            )
            # An invalid token should return 401
            assert response.status_code == 401
        app.dependency_overrides.clear()
class TestBrandsValidation:
    """Brand API validation tests."""

    @pytest.mark.asyncio
    async def test_create_brand_name_too_short(self, async_client):
        """Test a brand name that is too short."""
        brand_data = {
            "name": "A",  # minimum length is 2
        }
        response = await async_client.post("/api/v1/brands/", json=brand_data)
        assert response.status_code == 422  # validation error

    @pytest.mark.asyncio
    async def test_create_brand_invalid_frequency(self, async_client):
        """Test an invalid update frequency."""
        brand_data = {
            "name": "Valid Brand",
            "frequency": "invalid_frequency",
        }
        response = await async_client.post("/api/v1/brands/", json=brand_data)
        # The frequency field is not strictly validated in BrandCreate,
        # although it should hold a valid value.
        # This test mainly checks that the API does not crash.
        assert response.status_code in [201, 422]

View File

@ -0,0 +1,459 @@
"""Content API tests."""
import uuid
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.main import app
from app.models.user import User
from app.models.brand import Brand
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def async_engine():
    """Create an async SQLite engine for tests."""
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest_asyncio.fixture
async def async_session(async_engine):
    """Create an async database session for tests."""
    async_session_maker = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
    async with async_session_maker() as session:
        yield session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Create a test user."""
    user = User(
        id=uuid.uuid4(),
        email="test@example.com",
        password_hash=hash_password("Test@123456"),
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
        organization_id=uuid.uuid4(),  # the content API requires an organization_id
    )
    async_session.add(user)
    await async_session.commit()
    await async_session.refresh(user)
    return user


@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Create a test brand."""
    brand = Brand(
        id=uuid.uuid4(),
        user_id=test_user.id,
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    async_session.add(brand)
    await async_session.commit()
    await async_session.refresh(brand)
    return brand
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Create an async HTTP client for API tests."""
    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        yield client
    app.dependency_overrides.clear()


@pytest.fixture
def auth_headers(test_user):
    """Build the authentication request headers."""
    token = create_access_token(data={"sub": str(test_user.id)})
    return {"Authorization": f"Bearer {token}"}


# ==================== Test classes ====================
class TestContentAPI:
    """Content management API tests."""

    @pytest.mark.asyncio
    async def test_list_contents_empty(self, async_client):
        """Test fetching an empty content list."""
        response = await async_client.get("/api/v1/contents/")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        assert len(data) == 0

    @pytest.mark.asyncio
    async def test_create_content(self, async_client, test_user):
        """Test creating content."""
        content_data = {
            "title": "测试文章标题",
            "body": "这是测试文章的内容",
            "content_type": "article",
            "tags": ["测试", "文章"],
        }
        response = await async_client.post("/api/v1/contents/", json=content_data)
        assert response.status_code == 201
        data = response.json()
        assert data["title"] == "测试文章标题"
        assert data["body"] == "这是测试文章的内容"
        assert data["content_type"] == "article"
        assert data["status"] == "draft"
        assert "id" in data
    @pytest.mark.asyncio
    async def test_get_content_by_id(self, async_client, test_user):
        """Test fetching content by ID."""
        # Create the content first
        create_response = await async_client.post(
            "/api/v1/contents/",
            json={
                "title": "测试内容",
                "body": "测试内容正文",
                "content_type": "article",
            }
        )
        content_id = create_response.json()["id"]
        # Fetch the content
        response = await async_client.get(f"/api/v1/contents/{content_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == content_id
        assert data["title"] == "测试内容"

    @pytest.mark.asyncio
    async def test_get_content_not_found(self, async_client):
        """Test fetching content that does not exist."""
        non_existent_id = uuid.uuid4()
        response = await async_client.get(f"/api/v1/contents/{non_existent_id}")
        assert response.status_code == 404
@pytest.mark.asyncio
async def test_update_content(self, async_client, test_user):
"""测试更新内容"""
# 先创建内容
create_response = await async_client.post(
"/api/v1/contents/",
json={
"title": "原始标题",
"body": "原始内容",
"content_type": "article",
}
)
content_id = create_response.json()["id"]
# 更新内容
update_data = {
"title": "更新后的标题",
"body": "更新后的内容",
"status": "published",
}
response = await async_client.put(
f"/api/v1/contents/{content_id}", json=update_data
)
assert response.status_code == 200
data = response.json()
assert data["title"] == "更新后的标题"
assert data["body"] == "更新后的内容"
@pytest.mark.asyncio
async def test_update_content_partial(self, async_client, test_user):
"""Partially update a content item"""
# Create the content first
create_response = await async_client.post(
"/api/v1/contents/",
json={
"title": "原始标题",
"body": "原始内容",
}
)
content_id = create_response.json()["id"]
# Update the title only
response = await async_client.put(
f"/api/v1/contents/{content_id}",
json={"title": "新标题"}
)
assert response.status_code == 200
data = response.json()
assert data["title"] == "新标题"
assert data["body"] == "原始内容"  # body stays unchanged
@pytest.mark.asyncio
async def test_delete_content(self, async_client, test_user):
"""Delete a content item"""
# Create the content first
create_response = await async_client.post(
"/api/v1/contents/",
json={
"title": "待删除内容",
"body": "内容正文",
}
)
content_id = create_response.json()["id"]
# Delete the content
response = await async_client.delete(f"/api/v1/contents/{content_id}")
assert response.status_code == 204
# Confirm it is gone
get_response = await async_client.get(f"/api/v1/contents/{content_id}")
assert get_response.status_code == 404
@pytest.mark.asyncio
async def test_publish_content(self, async_client, test_user):
"""Publish a content item"""
# Create the content first
create_response = await async_client.post(
"/api/v1/contents/",
json={
"title": "待发布内容",
"body": "内容正文",
}
)
content_id = create_response.json()["id"]
# Publish the content
response = await async_client.post(f"/api/v1/contents/{content_id}/publish")
assert response.status_code == 200
data = response.json()
assert data["status"] == "published"
assert data["published_at"] is not None
@pytest.mark.asyncio
async def test_list_contents_with_filter(self, async_client, test_user):
"""Filter the content list by content type"""
# Create several items of different types
for i, content_type in enumerate(["article", "post", "article"]):
await async_client.post(
"/api/v1/contents/",
json={
"title": f"内容 {i}",
"body": "正文",
"content_type": content_type,
}
)
# Filter by type
response = await async_client.get("/api/v1/contents/?content_type=article")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
for item in data:
assert item["content_type"] == "article"
@pytest.mark.asyncio
async def test_list_contents_with_status_filter(self, async_client, test_user):
"""Filter the content list by status"""
# Create one draft and one published item
draft_response = await async_client.post(
"/api/v1/contents/",
json={"title": "草稿", "body": "正文", "content_type": "article"}
)
published_response = await async_client.post(
"/api/v1/contents/",
json={"title": "已发布", "body": "正文", "content_type": "article"}
)
await async_client.post(f"/api/v1/contents/{published_response.json()['id']}/publish")
# Fetch drafts only
response = await async_client.get("/api/v1/contents/?status=draft")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
for item in data:
assert item["status"] == "draft"
class TestContentGenerationAPI:
"""Content generation API tests"""
@pytest.mark.asyncio
async def test_list_topics(self, async_client):
"""List the master topic library"""
response = await async_client.get("/api/v1/content/topics")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
# At least one master topic should exist
assert len(data) > 0
# Check the structure of each master topic
for topic in data:
assert "id" in topic
assert "name" in topic
assert "description" in topic
@pytest.mark.asyncio
async def test_get_topic_detail(self, async_client):
"""Fetch master topic details"""
# Use the first available topic id
topics_response = await async_client.get("/api/v1/content/topics")
topics = topics_response.json()
if len(topics) > 0:
topic_id = topics[0]["id"]
response = await async_client.get(f"/api/v1/content/topics/{topic_id}")
assert response.status_code == 200
data = response.json()
assert data["id"] == topic_id
assert "prompt_template" in data
@pytest.mark.asyncio
async def test_get_topic_not_found(self, async_client):
"""Fetching a non-existent master topic returns 404"""
response = await async_client.get("/api/v1/content/topics/nonexistent_topic")
assert response.status_code == 404
class TestBrandKnowledgeAPI:
"""Brand knowledge base API tests"""
@pytest.mark.asyncio
async def test_list_brand_knowledge_empty(self, async_client):
"""Listing an empty brand knowledge base returns an empty list"""
response = await async_client.get("/api/v1/contents/knowledge/")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) == 0
@pytest.mark.asyncio
async def test_create_brand_knowledge(self, async_client, test_user):
"""Create a brand knowledge entry"""
knowledge_data = {
"category": "product",
"title": "产品介绍",
"body": "这是产品的详细介绍",
"source": "官网",
}
response = await async_client.post(
"/api/v1/contents/knowledge/", json=knowledge_data
)
assert response.status_code == 201
data = response.json()
assert data["title"] == "产品介绍"
assert data["category"] == "product"
assert data["body"] == "这是产品的详细介绍"
assert data["source"] == "官网"
assert "id" in data
@pytest.mark.asyncio
async def test_list_brand_knowledge_with_category(self, async_client, test_user):
"""List brand knowledge by category"""
# Create entries in different categories
categories = ["product", "technology", "product"]
for i, cat in enumerate(categories):
await async_client.post(
"/api/v1/contents/knowledge/",
json={"category": cat, "title": f"知识 {i}", "body": "正文"}
)
# Filter by the "product" category
response = await async_client.get("/api/v1/contents/knowledge/?category=product")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
for item in data:
assert item["category"] == "product"
@pytest.mark.asyncio
async def test_update_brand_knowledge(self, async_client, test_user):
"""Update a brand knowledge entry"""
# Create the entry first
create_response = await async_client.post(
"/api/v1/contents/knowledge/",
json={"category": "test", "title": "原始标题", "body": "原始内容"}
)
knowledge_id = create_response.json()["id"]
# Update
response = await async_client.put(
f"/api/v1/contents/knowledge/{knowledge_id}",
json={"title": "新标题", "body": "新内容"}
)
assert response.status_code == 200
data = response.json()
assert data["title"] == "新标题"
assert data["body"] == "新内容"
@pytest.mark.asyncio
async def test_delete_brand_knowledge(self, async_client, test_user):
"""Delete a brand knowledge entry"""
# Create the entry first
create_response = await async_client.post(
"/api/v1/contents/knowledge/",
json={"category": "test", "title": "待删除", "body": "正文"}
)
knowledge_id = create_response.json()["id"]
# Delete
response = await async_client.delete(f"/api/v1/contents/knowledge/{knowledge_id}")
assert response.status_code == 204
# Confirm the deletion by listing the remaining entries
list_response = await async_client.get("/api/v1/contents/knowledge/")
knowledge_ids = [item["id"] for item in list_response.json()]
assert knowledge_id not in knowledge_ids
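# Illustrative sketch, not part of this commit: most tests above repeat the
# same "create first, then act" step. A small helper could centralise it.
# FakeClient and FakeResponse are hypothetical stand-ins so the sketch runs
# without the application; in the real suite the async_client fixture would
# be passed in instead. The default payload values are assumptions.

```python
import asyncio


class FakeResponse:
    """Hypothetical stand-in for an HTTP response object."""

    def __init__(self, status_code, payload):
        self.status_code = status_code
        self._payload = payload

    def json(self):
        return self._payload


class FakeClient:
    """Hypothetical in-memory client exposing only an async post()."""

    def __init__(self):
        self.items = []

    async def post(self, url, json):
        # Mimic the API: assign an id and default the status to "draft".
        item = {"id": str(len(self.items) + 1), "status": "draft", **json}
        self.items.append(item)
        return FakeResponse(201, item)


async def create_content(client, **overrides):
    """Create one content item and return its id."""
    payload = {"title": "测试内容", "body": "内容正文", "content_type": "article"}
    payload.update(overrides)
    response = await client.post("/api/v1/contents/", json=payload)
    assert response.status_code == 201
    return response.json()["id"]


content_id = asyncio.run(create_content(FakeClient(), title="待删除内容"))
```

# With such a helper, a delete or publish test would start with
# `content_id = await create_content(async_client, title=...)`.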


@ -0,0 +1,221 @@
"""Diagnosis API tests"""
import uuid
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.main import app
from app.models.user import User
from app.models.brand import Brand
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password
@pytest_asyncio.fixture
async def async_engine():
    """Create an async in-memory SQLite engine for tests"""
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Create an async database session for tests"""
    async_session_maker = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
    async with async_session_maker() as session:
        yield session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Create a test user"""
    user = User(
        id=uuid.uuid4(),
        email="test@example.com",
        password_hash=hash_password("Test@123456"),
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    async_session.add(user)
    await async_session.commit()
    await async_session.refresh(user)
    return user
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Create a test brand owned by the test user"""
    brand = Brand(
        id=uuid.uuid4(),
        user_id=test_user.id,
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    async_session.add(brand)
    await async_session.commit()
    await async_session.refresh(brand)
    return brand
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Create an async HTTP client for API tests"""
    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        yield client
    app.dependency_overrides.clear()
class TestDiagnosisAPI:
"""Diagnosis API tests"""
@pytest.mark.asyncio
async def test_seo_diagnosis_success(self, async_client, test_brand):
"""The SEO diagnosis endpoint returns successfully"""
response = await async_client.get(f"/api/v1/diagnosis/seo/{test_brand.id}")
assert response.status_code == 200
data = response.json()
assert "overall_score" in data
assert "health_level" in data
assert "dimensions" in data
assert "recommendations" in data
assert isinstance(data["overall_score"], (int, float))
assert isinstance(data["dimensions"], list)
assert isinstance(data["recommendations"], list)
@pytest.mark.asyncio
async def test_geo_diagnosis_success(self, async_client, test_brand):
"""The GEO diagnosis endpoint returns successfully"""
response = await async_client.get(f"/api/v1/diagnosis/geo/{test_brand.id}")
assert response.status_code == 200
data = response.json()
assert "overall_score" in data
assert "health_level" in data
assert "dimensions" in data
assert "recommendations" in data
assert isinstance(data["overall_score"], (int, float))
assert isinstance(data["dimensions"], list)
assert isinstance(data["recommendations"], list)
@pytest.mark.asyncio
async def test_combined_diagnosis_success(self, async_client, test_brand):
"""The combined diagnosis endpoint returns successfully"""
response = await async_client.get(f"/api/v1/diagnosis/combined/{test_brand.id}")
assert response.status_code == 200
data = response.json()
assert "seo_score" in data
assert "geo_score" in data
assert "combined_score" in data
assert "seo_diagnosis" in data
assert "geo_diagnosis" in data
assert isinstance(data["seo_score"], (int, float))
assert isinstance(data["geo_score"], (int, float))
assert isinstance(data["combined_score"], (int, float))
@pytest.mark.asyncio
async def test_diagnosis_brand_not_found(self, async_client):
"""A non-existent brand returns 404"""
non_existent_id = uuid.uuid4()
seo_response = await async_client.get(f"/api/v1/diagnosis/seo/{non_existent_id}")
assert seo_response.status_code == 404
geo_response = await async_client.get(f"/api/v1/diagnosis/geo/{non_existent_id}")
assert geo_response.status_code == 404
combined_response = await async_client.get(f"/api/v1/diagnosis/combined/{non_existent_id}")
assert combined_response.status_code == 404
@pytest.mark.asyncio
async def test_diagnosis_unauthorized_access(self, async_session):
"""An unauthenticated request returns 401"""
async def override_get_db():
yield async_session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
headers = {"Authorization": "Bearer invalid_token"}
seo_response = await client.get(f"/api/v1/diagnosis/seo/{uuid.uuid4()}", headers=headers)
assert seo_response.status_code == 401
geo_response = await client.get(f"/api/v1/diagnosis/geo/{uuid.uuid4()}", headers=headers)
assert geo_response.status_code == 401
combined_response = await client.get(f"/api/v1/diagnosis/combined/{uuid.uuid4()}", headers=headers)
assert combined_response.status_code == 401
app.dependency_overrides.clear()
@pytest.mark.asyncio
async def test_diagnosis_result_format(self, async_client, test_brand):
"""The diagnosis result has the expected format"""
response = await async_client.get(f"/api/v1/diagnosis/seo/{test_brand.id}")
assert response.status_code == 200
data = response.json()
assert 0 <= data["overall_score"] <= 100
assert data["health_level"] in ["excellent", "good", "pass", "danger"]
for dimension in data["dimensions"]:
assert "name" in dimension
assert "score" in dimension
assert "max_score" in dimension
assert "percentage" in dimension
assert "status" in dimension
assert "items" in dimension
assert isinstance(dimension["items"], list)
for item in dimension["items"]:
assert "name" in item
assert "status" in item
assert "description" in item
assert "suggestion" in item
assert "score" in item
for rec in data["recommendations"]:
assert "priority" in rec
assert "dimension" in rec
assert "description" in rec
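# Illustrative sketch, not part of this commit: the key-by-key assertions in
# test_diagnosis_result_format could be folded into one reusable checker that
# reports every missing field at once. The key names are copied from the test
# above; the function name and the sample payload are hypothetical.

```python
REQUIRED_TOP_LEVEL = ("overall_score", "health_level", "dimensions", "recommendations")
REQUIRED_DIMENSION = ("name", "score", "max_score", "percentage", "status", "items")
REQUIRED_ITEM = ("name", "status", "description", "suggestion", "score")
REQUIRED_RECOMMENDATION = ("priority", "dimension", "description")


def missing_keys(data):
    """Return dotted paths of keys missing from a diagnosis payload."""
    missing = [key for key in REQUIRED_TOP_LEVEL if key not in data]
    for d_index, dimension in enumerate(data.get("dimensions", [])):
        prefix = f"dimensions[{d_index}]"
        missing += [f"{prefix}.{key}" for key in REQUIRED_DIMENSION if key not in dimension]
        for i_index, item in enumerate(dimension.get("items", [])):
            missing += [
                f"{prefix}.items[{i_index}].{key}"
                for key in REQUIRED_ITEM
                if key not in item
            ]
    for r_index, rec in enumerate(data.get("recommendations", [])):
        missing += [
            f"recommendations[{r_index}].{key}"
            for key in REQUIRED_RECOMMENDATION
            if key not in rec
        ]
    return missing


# Hypothetical payload with a few fields deliberately left out.
sample = {
    "overall_score": 72,
    "health_level": "good",
    "dimensions": [{"name": "technical", "score": 8, "max_score": 10, "items": []}],
    "recommendations": [{"priority": "high", "dimension": "technical"}],
}
problems = missing_keys(sample)
```

# A test could then assert `missing_keys(response.json()) == []`, so a failure
# message lists all absent fields rather than only the first one.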


@ -0,0 +1,204 @@
"""Health check API tests"""
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.main import app
from app.api.deps import get_db
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def async_engine():
    """Create an async in-memory SQLite engine for tests"""
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Create an async database session for tests"""
    async_session_maker = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
    async with async_session_maker() as session:
        yield session
@pytest_asyncio.fixture
async def async_client(async_session):
    """Create an async HTTP client for API tests"""
    async def override_get_db():
        yield async_session

    app.dependency_overrides[get_db] = override_get_db
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        yield client
    app.dependency_overrides.clear()
# ==================== Test classes ====================
class TestHealthAPI:
"""Health check API tests"""
@pytest.mark.asyncio
async def test_basic_health(self):
"""Basic health check"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "timestamp" in data
@pytest.mark.asyncio
async def test_liveness(self):
"""Liveness probe"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/health/liveness")
assert response.status_code == 200
data = response.json()
assert data["status"] == "alive"
@pytest.mark.asyncio
async def test_ready_endpoint(self):
"""Ready endpoint"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/ready")
# Returns 200 or 503 depending on the state of the dependent services
assert response.status_code in [200, 503]
data = response.json()
assert "status" in data
assert "checks" in data
assert "database" in data["checks"]
assert "redis" in data["checks"]
@pytest.mark.asyncio
async def test_metrics(self):
"""Prometheus metrics endpoint"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/metrics")
assert response.status_code == 200
# Prometheus metrics are served as plain text
assert "text/plain" in response.headers.get("content-type", "")
@pytest.mark.asyncio
async def test_detailed_health(self, async_client):
"""Detailed health check"""
response = await async_client.get("/health/detailed")
assert response.status_code == 200
data = response.json()
# Check the response structure
assert "checks" in data
assert "app" in data
@pytest.mark.asyncio
async def test_readiness_probe(self, async_client):
"""Readiness probe"""
response = await async_client.get("/health/readiness")
# A healthy service returns 200; an unhealthy one returns 503
assert response.status_code in [200, 503]
data = response.json()
assert "status" in data
assert "checks" in data
class TestHealthEndpointsStructure:
"""Health check endpoint structure tests"""
@pytest.mark.asyncio
async def test_health_endpoint_returns_json(self):
"""The health endpoint returns JSON"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/health")
assert response.status_code == 200
assert "application/json" in response.headers.get("content-type", "")
@pytest.mark.asyncio
async def test_detailed_health_has_required_fields(self, async_client):
"""The detailed health check contains the required fields"""
response = await async_client.get("/health/detailed")
assert response.status_code == 200
data = response.json()
# Check the app info
if "app" in data:
app_info = data["app"]
assert isinstance(app_info, dict)
@pytest.mark.asyncio
async def test_ready_checks_database_and_redis(self, async_client):
"""The readiness check covers both the database and Redis"""
response = await async_client.get("/ready")
data = response.json()
checks = data.get("checks", {})
assert "database" in checks
assert "redis" in checks
class TestHealthCheckIndependence:
"""Health check independence tests"""
@pytest.mark.asyncio
async def test_health_no_auth_required(self):
"""The health check requires no authentication"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
# Request without a token
response = await client.get("/health")
assert response.status_code == 200
@pytest.mark.asyncio
async def test_liveness_no_auth_required(self):
"""The liveness probe requires no authentication"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/health/liveness")
assert response.status_code == 200
@pytest.mark.asyncio
async def test_metrics_no_auth_required(self):
"""The metrics endpoint requires no authentication"""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/metrics")
assert response.status_code == 200
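# Illustrative sketch, not part of this commit: the readiness tests accept
# either 200 or 503 and expect a "checks" mapping with "database" and "redis"
# entries. One way an endpoint might derive both from per-dependency results
# is shown below. The function name and the "ready" / "not_ready" / "ok" /
# "failed" strings are assumptions, not the application's actual values.

```python
def readiness_response(checks):
    """Map per-dependency booleans to an (HTTP status code, body) pair."""
    ready = all(checks.values())
    body = {
        "status": "ready" if ready else "not_ready",
        "checks": {name: "ok" if ok else "failed" for name, ok in checks.items()},
    }
    # 503 tells an orchestrator to stop routing traffic to this instance.
    return (200 if ready else 503), body


status_code, body = readiness_response({"database": True, "redis": False})
```

# Keeping the mapping in a pure function like this would let the 200 and 503
# branches be tested deterministically, without a live database or Redis.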


@ -0,0 +1 @@
"""Content generation pipeline test package"""


@ -0,0 +1,207 @@
"""Content generation pipeline tests"""
import pytest
from app.services.content.content_pipeline import ContentPipeline, PipelineResponse
class TestContentPipeline:
"""Content pipeline tests"""
@pytest.mark.asyncio
async def test_pipeline_complete_run(self):
"""Run the full pipeline end to end"""
pipeline = ContentPipeline()
request = {
"content": "这是一篇关于华为手机评测的深度文章,"
"详细介绍了华为Mate系列的特点和性能。",
"title": "华为手机评测",
"platform": "zhihu",
"optimize_for": ["validation", "sensitive", "seo"],
"output_formats": ["html", "markdown", "plain"],
"keyword": "华为手机"
}
result = await pipeline.run(request)
# Check the return type
assert isinstance(result, PipelineResponse)
assert result.stages is not None
assert len(result.stages) > 0
# Check the outputs
assert result.outputs is not None
assert result.outputs.html is not None or result.outputs.markdown is not None
@pytest.mark.asyncio
async def test_pipeline_validation_only(self):
"""Validation-only mode"""
pipeline = ContentPipeline()
request = {
"content": "这是一篇测试文章内容",
"title": "测试标题",
"platform": "zhihu",
"optimize_for": ["validation"]
}
result = await pipeline.run(request)
assert result.stages is not None
# Validation-only mode generates no output
validation_stage = next((s for s in result.stages if s.name == "validation"), None)
assert validation_stage is not None
@pytest.mark.asyncio
async def test_pipeline_sensitive_filter_only(self):
"""Sensitive-word filtering only"""
pipeline = ContentPipeline()
request = {
"content": "这是一篇正常内容的文章",
"title": "测试标题",
"platform": "zhihu",
"optimize_for": ["sensitive"]
}
result = await pipeline.run(request)
assert result.stages is not None
assert len(result.stages) >= 1
@pytest.mark.asyncio
async def test_pipeline_seo_only(self):
"""SEO optimization only"""
pipeline = ContentPipeline()
request = {
"content": "华为手机是知名的手机品牌",
"title": "华为手机评测",
"platform": "zhihu",
"optimize_for": ["seo"],
"keyword": "华为手机"
}
result = await pipeline.run(request)
assert result.stages is not None
@pytest.mark.asyncio
async def test_pipeline_with_validation_fail(self):
"""Validation failure scenario"""
pipeline = ContentPipeline()
request = {
"content": "内容",
"title": "这个标题太长了超过了三十个字符的限制了哈哈哈啊",
"platform": "wechat",
"optimize_for": ["validation"]
}
result = await pipeline.run(request)
# Later stages should not run once validation fails
validation_stage = next((s for s in result.stages if s.name == "validation"), None)
assert validation_stage is not None
assert validation_stage.passed is False
@pytest.mark.asyncio
async def test_pipeline_stage_timing(self):
"""Each stage records its duration"""
pipeline = ContentPipeline()
request = {
"content": "这是一篇测试文章内容",
"title": "测试标题",
"platform": "zhihu",
"optimize_for": ["validation", "sensitive", "seo"]
}
result = await pipeline.run(request)
for stage in result.stages:
assert stage.name is not None
assert hasattr(stage, 'duration')
assert stage.duration >= 0
@pytest.mark.asyncio
async def test_pipeline_multi_platform(self):
"""Multi-platform adaptation"""
pipeline = ContentPipeline()
platforms = ["zhihu", "wechat", "xiaohongshu"]
results = []
for platform in platforms:
result = await pipeline.run({
"content": "<p>测试内容</p>",
"title": "测试标题",
"platform": platform,
"optimize_for": ["validation", "sensitive"]
})
results.append(result)
assert result.stages is not None
# Every platform should have produced a result
assert len(results) == len(platforms)
@pytest.mark.asyncio
async def test_pipeline_output_formats(self):
"""Different output formats"""
pipeline = ContentPipeline()
request = {
"content": "测试内容",
"title": "测试标题",
"platform": "zhihu",
"optimize_for": [],
"output_formats": ["html", "markdown", "plain"]
}
result = await pipeline.run(request)
assert result.outputs is not None
# HTML generation is the default
assert result.outputs.html is not None
@pytest.mark.asyncio
async def test_pipeline_empty_optimize_for(self):
"""Empty optimize_for parameter"""
pipeline = ContentPipeline()
request = {
"content": "测试内容",
"title": "测试标题",
"platform": "zhihu",
"optimize_for": [],
"output_formats": ["html"]
}
result = await pipeline.run(request)
# No optimization stages, but the HTML generation stage should still run
assert result.outputs is not None
@pytest.mark.asyncio
async def test_pipeline_error_handling(self):
"""Error handling for an invalid platform"""
pipeline = ContentPipeline()
try:
result = await pipeline.run({
"content": "内容",
"title": "标题",
"platform": "invalid_platform"
})
# An error should be reported
assert result.error is not None or len(result.stages) > 0
except ValueError as e:
assert "不支持的平台" in str(e)
@pytest.mark.asyncio
async def test_pipeline_validate_only_method(self):
"""The validate_only method"""
pipeline = ContentPipeline()
result = await pipeline.validate_only(
content="测试内容",
title="测试标题",
platform="zhihu"
)
assert result is not None
assert hasattr(result, 'is_valid')
assert hasattr(result, 'score')
assert hasattr(result, 'issues')
assert hasattr(result, 'passed')
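# Illustrative sketch, not part of this commit: the tests above expect each
# stage to expose name, passed and duration, and expect later stages to be
# skipped once validation fails. A minimal runner with that behaviour might
# look like the following. StageResult, run_stages and the lambda checks are
# hypothetical and are not the real ContentPipeline implementation.

```python
import time
from dataclasses import dataclass


@dataclass
class StageResult:
    name: str
    passed: bool
    duration: float  # seconds spent in the stage


def run_stages(stages, content):
    """Run (name, check) pairs in order; stop after the first failing stage."""
    results = []
    for name, check in stages:
        started = time.perf_counter()
        passed = bool(check(content))
        results.append(StageResult(name, passed, time.perf_counter() - started))
        if not passed:
            break
    return results


# Hypothetical checks: a minimum length rule and a filter that always passes.
stages = [
    ("validation", lambda text: len(text) >= 5),
    ("sensitive", lambda text: True),
]
short_run = run_stages(stages, "内容")
full_run = run_stages(stages, "这是一篇测试文章内容")
```

# Stopping at the first failure is what lets a test look up the "validation"
# stage and assert `passed is False` without later stages having run.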


@ -0,0 +1,265 @@
"""HTML generator tests"""
import pytest
from app.services.content.html_generator import HTMLGenerator
class TestHTMLGenerator:
"""HTML generator tests"""
def test_generate_basic_html(self):
"""Basic HTML generation"""
generator = HTMLGenerator()
html = generator.generate(
content="<p>这是测试内容</p>",
platform="zhihu"
)
assert html is not None
assert isinstance(html, str)
def test_generate_for_zhihu(self):
"""HTML generation for Zhihu"""
generator = HTMLGenerator()
html = generator.generate(
content="<h1>华为手机评测</h1><p>这是一篇关于华为手机的详细评测文章。</p>",
platform="zhihu"
)
assert html is not None
assert "华为手机评测" in html
def test_generate_for_wechat(self):
"""HTML generation for WeChat Official Accounts"""
generator = HTMLGenerator()
html = generator.generate(
content="<p>华为手机非常好用</p>",
platform="wechat"
)
assert html is not None
def test_generate_for_xiaohongshu(self):
"""HTML generation for Xiaohongshu"""
generator = HTMLGenerator()
html = generator.generate(
content="<p>种草笔记内容</p>",
platform="xiaohongshu"
)
assert html is not None
def test_filter_banned_tags(self):
"""Banned tags are filtered out"""
generator = HTMLGenerator()
html = generator.generate(
content="<script>alert('xss')</script><p>正常内容</p>",
platform="zhihu"
)
# The script tag should be removed
assert "<script>" not in html
assert "正常内容" in html
def test_filter_banned_tags_wechat_external_links(self):
"""External links are filtered for WeChat Official Accounts"""
generator = HTMLGenerator()
html = generator.generate(
content="<a href='http://baidu.com'>外部链接</a><p>内容</p>",
platform="wechat"
)
# WeChat Official Accounts should have external links filtered out
assert "http://baidu.com" not in html
def test_filter_banned_tags_wechat_preserves_internal(self):
"""Internal links are preserved for WeChat Official Accounts"""
generator = HTMLGenerator()
html = generator.generate(
content="<a href='https://mp.weixin.qq.com/s/test'>内部链接</a><p>内容</p>",
platform="wechat"
)
# WeChat Official Accounts should keep internal links
assert html is not None
def test_to_markdown_h1_conversion(self):
"""H1 tag converts to Markdown"""
generator = HTMLGenerator()
md = generator.to_markdown("<h1>标题</h1>")
assert "# 标题" in md
def test_to_markdown_h2_conversion(self):
"""H2 tag converts to Markdown"""
generator = HTMLGenerator()
md = generator.to_markdown("<h2>二级标题</h2>")
assert "## 二级标题" in md
def test_to_markdown_h3_conversion(self):
"""H3 tag converts to Markdown"""
generator = HTMLGenerator()
md = generator.to_markdown("<h3>三级标题</h3>")
assert "### 三级标题" in md
def test_to_markdown_paragraph_conversion(self):
"""Paragraph tag converts to Markdown"""
generator = HTMLGenerator()
md = generator.to_markdown("<p>段落内容</p>")
assert "段落内容" in md
def test_to_markdown_br_conversion(self):
"""Line-break tag converts to Markdown"""
generator = HTMLGenerator()
md = generator.to_markdown("第一行<br>第二行")
assert "第一行" in md
assert "第二行" in md
def test_to_markdown_list_conversion(self):
"""List item tag converts to Markdown"""
generator = HTMLGenerator()
md = generator.to_markdown("<li>列表项</li>")
assert "- 列表项" in md
def test_to_markdown_code_inline(self):
"""Inline code converts to Markdown"""
generator = HTMLGenerator()
md = generator.to_markdown("<code>代码</code>")
assert "`代码`" in md
def test_to_markdown_blockquote(self):
"""Blockquote tag converts to Markdown"""
generator = HTMLGenerator()
md = generator.to_markdown("<blockquote>引用内容</blockquote>")
assert "> 引用内容" in md
def test_to_markdown_pre_block(self):
"""Code block converts to Markdown"""
generator = HTMLGenerator()
md = generator.to_markdown("<pre>代码块</pre>")
assert "```" in md
def test_to_markdown_strips_residual_tags(self):
"""Markdown conversion strips leftover tags"""
generator = HTMLGenerator()
md = generator.to_markdown("<div>内容</div>")
# The div tag should be removed
assert "<div>" not in md
def test_to_plain_text_basic(self):
"""Basic plain-text conversion"""
generator = HTMLGenerator()
plain = generator.to_plain("<h1>标题</h1><p>段落</p>")
assert "标题" in plain
assert "段落" in plain
def test_to_plain_text_removes_tags(self):
"""Plain text removes all tags"""
generator = HTMLGenerator()
plain = generator.to_plain("<script>alert(1)</script><p>内容</p>")
assert "<script>" not in plain
assert "<p>" not in plain
def test_to_plain_text_decodes_html_entities(self):
"""Plain text decodes HTML entities"""
generator = HTMLGenerator()
plain = generator.to_plain("&lt;&gt;&amp;&quot;")
assert "<" in plain
assert ">" in plain
assert "&" in plain
assert '"' in plain
def test_to_plain_text_removes_extra_spaces(self):
"""Plain text collapses extra spaces"""
generator = HTMLGenerator()
plain = generator.to_plain("内容 多个 空格")
assert " " not in plain
def test_to_plain_text_removes_extra_newlines(self):
"""Plain text collapses extra newlines"""
generator = HTMLGenerator()
plain = generator.to_plain("内容\n\n\n换行")
# There should be no more than 2 consecutive newlines
assert "\n\n\n" not in plain
def test_generate_format_html(self):
"""HTML format output"""
generator = HTMLGenerator()
html = generator.generate(
content="<p>内容</p>",
platform="zhihu",
format="html"
)
assert html is not None
def test_generate_format_markdown(self):
"""Markdown format output"""
generator = HTMLGenerator()
result = generator.generate(
content="<h1>标题</h1>",
platform="zhihu",
format="markdown"
)
assert "# 标题" in result
def test_generate_format_plain(self):
"""Plain-text format output"""
generator = HTMLGenerator()
result = generator.generate(
content="<p>内容</p>",
platform="zhihu",
format="plain"
)
assert "内容" in result
assert "<p>" not in result
def test_generate_invalid_platform(self):
"""Handling of an invalid platform"""
generator = HTMLGenerator()
html = generator.generate(
content="<p>内容</p>",
platform="invalid_platform"
)
# An invalid platform should return the original content
assert html is not None
def test_generate_with_empty_content(self):
"""Generation with empty content"""
generator = HTMLGenerator()
html = generator.generate(
content="",
platform="zhihu"
)
assert html == ""
def test_to_markdown_empty_content(self):
"""Empty content converts to empty Markdown"""
generator = HTMLGenerator()
md = generator.to_markdown("")
assert md == ""
def test_to_plain_empty_content(self):
"""Empty content converts to empty plain text"""
generator = HTMLGenerator()
plain = generator.to_plain("")
assert plain == ""
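# Illustrative sketch, not part of this commit: the to_plain tests above
# expect tags to be stripped, HTML entities decoded, and runs of spaces and
# newlines collapsed. A standard-library-only function with that behaviour
# could look like the following. It is an assumption about one possible
# approach and not the actual HTMLGenerator.to_plain implementation;
# regex-based stripping is also only suitable for simple, trusted markup.

```python
import html
import re


def to_plain(markup):
    """Strip tags, decode entities, and normalise whitespace."""
    # Drop script/style elements together with their contents.
    text = re.sub(r"(?is)<(script|style)\b.*?</\1\s*>", "", markup)
    # Turn common block boundaries into newlines before removing the tags.
    text = re.sub(r"(?i)<br\s*/?>|</(p|div|h[1-6]|li)\s*>", "\n", text)
    # Remove every remaining tag.
    text = re.sub(r"<[^>]+>", "", text)
    # Decode entities last, so decoded "<" and ">" are not mistaken for tags.
    text = html.unescape(text)
    # Collapse runs of spaces/tabs, then runs of three or more newlines.
    text = re.sub(r"[ \t]{2,}", " ", text)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()


plain = to_plain("<h1>标题</h1><p>段落</p>")
```

# Decoding entities after tag removal is the important ordering choice here:
# doing it first would turn "&lt;b&gt;" into a tag that then gets stripped.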


@ -0,0 +1,218 @@
"""Rule validator tests"""
import pytest
from app.services.content.rule_validator import (
RuleValidator,
ValidationIssue,
ValidationResult,
AI_Pattern
)
class TestRuleValidator:
"""Rule validator tests"""
def test_validate_title_length_pass(self):
"""A title within the length rule is reported in passed"""
validator = RuleValidator()
result = validator.validate(
content="这是一篇关于AI医疗的深度分析文章...",
title="AI医疗的发展趋势与未来展望",
platform="zhihu"
)
assert result.is_valid is True
assert any("标题长度合规" in p or "合规" in p for p in result.passed)
def test_validate_title_length_fail(self):
"""A title over the length limit produces an issue"""
validator = RuleValidator()
result = validator.validate(
content="内容",
title="这个标题太长了超过了三十个字符的限制了哈哈哈哈哈哈",
platform="wechat"
)
assert result.is_valid is False
assert any("超过" in i.message for i in result.issues if i.severity == "high")
def test_validate_content_length_pass(self):
"""Content within the length rule passes"""
validator = RuleValidator()
result = validator.validate(
content="A" * 1500,
title="测试标题",
platform="zhihu"
)
assert result.score >= 80
def test_validate_content_length_fail(self):
"""Overlong content produces an issue"""
validator = RuleValidator()
result = validator.validate(
content="A" * 30000,
title="测试标题",
platform="wechat"
)
assert any("超过" in i.message for i in result.issues if i.severity == "high")
def test_validate_zhihu_specific_rules(self):
"""Zhihu-specific rules"""
validator = RuleValidator()
result = validator.validate(
content="这是一个专业回答",
title="专业回答",
platform="zhihu"
)
assert result.score > 0
def test_validate_wechat_inducing_content(self):
"""Detect share-inducing content on WeChat Official Accounts"""
validator = RuleValidator()
result = validator.validate(
content="转发本文领取红包",
title="限时优惠",
platform="wechat"
)
# Share-inducing content should be flagged
assert any("诱导" in i.message for i in result.issues)
def test_validate_wechat_marketing_words(self):
"""Detect marketing phrases on WeChat Official Accounts"""
validator = RuleValidator()
result = validator.validate(
content="点击购买,限时优惠",
title="优惠信息",
platform="wechat"
)
# Marketing phrases should be flagged
assert any("营销" in i.message for i in result.issues)
def test_validate_xiaohongshu_cross_platform(self):
"""Detect cross-platform traffic diversion on Xiaohongshu"""
validator = RuleValidator()
result = validator.validate(
content="微信公众号搜索xxx获取更多内容",
title="种草笔记",
platform="xiaohongshu"
)
# Xiaohongshu should flag cross-platform traffic diversion
assert any("引流" in i.message for i in result.issues)
def test_validate_xiaohongshu_content_length(self):
"""Content length check on Xiaohongshu"""
validator = RuleValidator()
result = validator.validate(
content="A" * 100,
title="笔记",
platform="xiaohongshu"
)
# Content that is too short should be flagged
assert any("300" in i.message for i in result.issues)
def test_detect_ai_patterns_banned_words(self):
"""Detect banned words"""
validator = RuleValidator()
result = validator.detect_ai_patterns(
content="首先,其次,最后,总而言之,总之,总之",
platform="zhihu"
)
assert len(result) > 0
assert any("首先" in r.pattern or "总之" in r.pattern for r in result)
def test_detect_ai_patterns_banned_structures(self):
"""Detect banned structures"""
validator = RuleValidator()
result = validator.detect_ai_patterns(
content="第一,观点一。第二,观点二。第三,观点三。",
platform="zhihu"
)
assert len(result) > 0
def test_detect_ai_patterns_no_banned(self):
"""Content without banned words still returns a list"""
validator = RuleValidator()
result = validator.detect_ai_patterns(
content="这是一篇正常的人类写作内容",
platform="zhihu"
)
assert isinstance(result, list)
def test_detect_ai_patterns_invalid_platform(self):
"""An invalid platform returns an empty list"""
validator = RuleValidator()
result = validator.detect_ai_patterns(
content="内容",
platform="invalid_platform"
)
assert result == []
def test_get_optimization_tips(self):
"""Get optimization tips"""
validator = RuleValidator()
tips = validator.get_optimization_tips("zhihu")
assert isinstance(tips, list)
def test_get_optimization_tips_invalid_platform(self):
"""An invalid platform returns an empty list"""
validator = RuleValidator()
tips = validator.get_optimization_tips("invalid_platform")
assert tips == []
def test_validation_result_structure(self):
"""The validation result structure is complete"""
validator = RuleValidator()
result = validator.validate(
content="测试内容",
title="测试标题",
platform="zhihu"
)
assert isinstance(result, ValidationResult)
assert isinstance(result.is_valid, bool)
assert isinstance(result.score, int)
assert isinstance(result.issues, list)
assert isinstance(result.passed, list)
for issue in result.issues:
assert isinstance(issue, ValidationIssue)
assert issue.severity in ["high", "medium", "low"]
assert isinstance(issue.message, str)
assert isinstance(issue.category, str)
def test_validation_ai_pattern_structure(self):
"""AI特征结构完整性"""
validator = RuleValidator()
results = validator.detect_ai_patterns(
content="首先,其次,最后",
platform="zhihu"
)
for pattern in results:
assert isinstance(pattern, AI_Pattern)
assert isinstance(pattern.pattern, str)
assert pattern.type in ["banned_word", "banned_structure"]
assert pattern.severity in ["medium", "high"]
def test_validate_platform_rules(self):
"""测试平台规则校验返回passed列表"""
validator = RuleValidator()
result = validator.validate(
content="华为手机非常好用",
title="华为手机评测",
platform="zhihu"
)
assert hasattr(result, 'passed')
assert isinstance(result.passed, list)
def test_validation_score_calculation(self):
"""验证分数计算逻辑"""
validator = RuleValidator()
result = validator.validate(
content="A" * 100,
title="正常标题",
platform="zhihu"
)
# 无严重问题时应该得到较高的分数
high_severity_issues = [i for i in result.issues if i.severity == "high"]
if len(high_severity_issues) == 0:
assert result.score >= 80
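
The tests above pin down the observable behavior of `detect_ai_patterns` without showing how it works. The following is a minimal sketch of one way to satisfy them, not the project's implementation: `PatternHit`, `RULES`, `detect_ai_patterns_sketch`, the word list, the regular expression, and the severity assignments are all assumptions made for illustration.

```python
import re
from dataclasses import dataclass


@dataclass
class PatternHit:
    """Hypothetical stand-in for the AI_Pattern type used in the tests."""
    pattern: str
    type: str
    severity: str


# Hypothetical per-platform rule table; the real rules are not shown in this diff.
RULES = {
    "zhihu": {
        "banned_words": ["首先", "其次", "最后", "总之"],
        "banned_structures": [r"第一.*第二.*第三"],
    },
}


def detect_ai_patterns_sketch(content: str, platform: str) -> list:
    """Return one hit per banned word or structure found in the content."""
    rules = RULES.get(platform)
    if rules is None:
        # Unknown platforms yield an empty list, as the tests expect.
        return []
    hits = []
    for word in rules["banned_words"]:
        if word in content:
            hits.append(PatternHit(word, "banned_word", "medium"))
    for regex in rules["banned_structures"]:
        if re.search(regex, content, flags=re.DOTALL):
            hits.append(PatternHit(regex, "banned_structure", "high"))
    return hits
```

Keeping the rules in a plain table keyed by platform makes the invalid-platform case a simple dictionary miss.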


@ -0,0 +1,192 @@
"""Tests for the sensitive word filter."""
import pytest
from app.services.content.sensitive_filter import (
    SensitiveFilter,
    FoundWord,
    FilterResult
)


class TestSensitiveFilter:
    """Tests for the sensitive word filter."""

    def test_filter_politics_words(self):
        """Filter politically sensitive words."""
        filter_obj = SensitiveFilter()
        result = filter_obj.filter(
            content="这是一个关于台湾问题的分析",
            platform="zhihu"
        )
        assert isinstance(result, FilterResult)
        # Politically sensitive words should be replaced
        assert "台湾" not in result.filtered_content or "**" in result.filtered_content
        assert len(result.found_words) > 0
        assert result.found_words[0].category == "politics"

    def test_filter_medical_words(self):
        """Filter medical sensitive words."""
        filter_obj = SensitiveFilter()
        result = filter_obj.filter(
            content="这个药品效果很好",
            platform="wechat"
        )
        # Medical sensitive words should be detected
        assert result.found_words is not None
        assert isinstance(result.found_words, list)

    def test_filter_finance_words(self):
        """Filter financial sensitive words."""
        filter_obj = SensitiveFilter()
        result = filter_obj.filter(
            content="年化收益率10%",
            platform="zhihu"
        )
        # Financial sensitive word detection
        assert result.found_words is not None
        assert isinstance(result.found_words, list)

    def test_filter_no_sensitive_words(self):
        """Content without sensitive words is returned unchanged."""
        filter_obj = SensitiveFilter()
        result = filter_obj.filter(
            content="这是一个正常的产品介绍",
            platform="zhihu"
        )
        assert result.filtered_content == "这是一个正常的产品介绍"
        assert len(result.found_words) == 0
        assert len(result.replacements) == 0
    def test_filter_multiple_categories(self):
        """Filter several categories at once."""
        filter_obj = SensitiveFilter()
        result = filter_obj.filter(
            content="这是内容包含政治和医疗敏感词的内容",
            platform="wechat"
        )
        categories = [w.category for w in result.found_words]
        assert len(set(categories)) >= 1

    def test_add_custom_words(self):
        """Add custom sensitive words."""
        filter_obj = SensitiveFilter()
        filter_obj.add_custom_words("custom", ["自定义敏感词1", "自定义敏感词2"])
        result = filter_obj.filter(
            content="这是一段包含自定义敏感词1的内容",
            platform="zhihu"
        )
        # Custom sensitive words should be replaced
        assert "自定义敏感词1" not in result.filtered_content

    def test_filter_result_structure(self):
        """The filter result has a complete structure."""
        filter_obj = SensitiveFilter()
        result = filter_obj.filter(
            content="测试内容",
            platform="zhihu"
        )
        assert isinstance(result, FilterResult)
        assert isinstance(result.filtered_content, str)
        assert isinstance(result.found_words, list)
        assert isinstance(result.replacements, dict)

    def test_found_word_structure(self):
        """Each found word has a complete structure."""
        filter_obj = SensitiveFilter()
        result = filter_obj.filter(
            content="台湾是中华人民共和国不可分割的一部分",
            platform="zhihu"
        )
        if len(result.found_words) > 0:
            word = result.found_words[0]
            assert isinstance(word, FoundWord)
            assert isinstance(word.word, str)
            assert isinstance(word.category, str)
            assert isinstance(word.position, int)
            assert isinstance(word.replacement, str)

    def test_filter_replacement_mapping(self):
        """Sensitive word replacement mapping."""
        filter_obj = SensitiveFilter()
        result = filter_obj.filter(
            content="这是一个关于台湾和西藏的内容",
            platform="zhihu"
        )
        # The mapping should contain every replaced word
        assert len(result.replacements) > 0
        for original, replacement in result.replacements.items():
            assert len(original) == len(replacement)
            assert replacement == "*" * len(original)
    def test_filter_various_platforms(self):
        """Filter sensitive words on different platforms."""
        filter_obj = SensitiveFilter()
        platforms = ["zhihu", "wechat", "xiaohongshu", "douyin"]
        for platform in platforms:
            result = filter_obj.filter(
                content="测试内容",
                platform=platform
            )
            assert isinstance(result, FilterResult)
            assert result.filtered_content is not None

    def test_filter_empty_content(self):
        """Filter empty content."""
        filter_obj = SensitiveFilter()
        result = filter_obj.filter(
            content="",
            platform="zhihu"
        )
        assert result.filtered_content == ""
        assert len(result.found_words) == 0

    def test_filter_replacement_char(self):
        """Custom replacement character."""
        filter_obj = SensitiveFilter()
        filter_obj.replacement_char = "#"
        result = filter_obj.filter(
            content="台湾是中华人民共和国的一部分",
            platform="zhihu"
        )
        # The replacement character should be "#"
        if len(result.replacements) > 0:
            for original, replacement in result.replacements.items():
                assert replacement == "#" * len(original)

    def test_filter_with_zhihu_platform(self):
        """Filter sensitive words on the Zhihu platform."""
        filter_obj = SensitiveFilter()
        result = filter_obj.filter(
            content="这是一篇关于华为手机的专业评测",
            platform="zhihu"
        )
        # Normal content should contain no sensitive words
        assert result.filtered_content == "这是一篇关于华为手机的专业评测"

    def test_filter_found_words_positions(self):
        """Positions of sensitive words are recorded."""
        filter_obj = SensitiveFilter()
        result = filter_obj.filter(
            content="台湾位于中国大陆东南沿海",
            platform="zhihu"
        )
        # Check that positions are recorded correctly
        if len(result.found_words) > 0:
            for word in result.found_words:
                assert word.position >= 0
                assert result.filtered_content.find(word.replacement) >= 0
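
The replacement-mapping tests above imply that each sensitive word is masked with a run of the replacement character of the same length, so recorded positions stay valid after masking. Below is a minimal sketch of that idea under stated assumptions: `mask_sensitive`, `Found`, `MaskResult`, and the lexicon format are hypothetical and are not taken from `SensitiveFilter`, and overlapping words are not handled.

```python
from dataclasses import dataclass, field


@dataclass
class Found:
    """Hypothetical stand-in for FoundWord."""
    word: str
    category: str
    position: int
    replacement: str


@dataclass
class MaskResult:
    """Hypothetical stand-in for FilterResult."""
    filtered_content: str
    found_words: list = field(default_factory=list)
    replacements: dict = field(default_factory=dict)


def mask_sensitive(content: str, lexicon: dict, replacement_char: str = "*") -> MaskResult:
    """Mask every lexicon word with a same-length run of replacement_char."""
    filtered = content
    found = []
    replacements = {}
    for category, words in lexicon.items():
        for word in words:
            mask = replacement_char * len(word)
            # Record every occurrence, scanning the original content.
            start = content.find(word)
            while start != -1:
                found.append(Found(word, category, start, mask))
                start = content.find(word, start + len(word))
            if word in content:
                replacements[word] = mask
                filtered = filtered.replace(word, mask)
    found.sort(key=lambda f: f.position)
    return MaskResult(filtered, found, replacements)
```

Because the mask has the same length as the word it replaces, `position` indexes the same characters in both the original and the filtered text.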


@ -0,0 +1,211 @@
"""Tests for the SEO optimizer."""
import pytest
from app.services.content.seo_optimizer import SEOOptimizer, OptimizationResult


class TestSEOOptimizer:
    """Tests for the SEO optimizer."""

    def test_get_keyword_density_basic(self):
        """Basic keyword density calculation."""
        optimizer = SEOOptimizer()
        content = "华为手机华为手机华为手机"
        density = optimizer.get_keyword_density(content, "华为手机")
        assert density > 0
        assert isinstance(density, float)

    def test_get_keyword_density_zero_content(self):
        """Density calculation for empty content."""
        optimizer = SEOOptimizer()
        density = optimizer.get_keyword_density("", "华为")
        assert density == 0.0

    def test_get_keyword_density_zero_keyword(self):
        """Density calculation for an empty keyword."""
        optimizer = SEOOptimizer()
        density = optimizer.get_keyword_density("内容", "")
        assert density == 0.0

    def test_get_keyword_density_no_match(self):
        """Density when the keyword does not occur."""
        optimizer = SEOOptimizer()
        content = "这是一篇关于苹果的文章"
        density = optimizer.get_keyword_density(content, "华为")
        assert density == 0.0

    def test_optimize_with_keyword(self):
        """SEO optimization with a keyword."""
        optimizer = SEOOptimizer()
        result = optimizer.optimize(
            content="华为手机是知名的手机品牌,华为手机性能出色。",
            title="华为手机评测",
            platform="zhihu",
            keyword="华为手机"
        )
        assert isinstance(result, OptimizationResult)
        assert isinstance(result.density, float)
        assert isinstance(result.suggestions, list)
        assert isinstance(result.tips, list)

    def test_optimize_without_keyword(self):
        """SEO optimization without a keyword."""
        optimizer = SEOOptimizer()
        result = optimizer.optimize(
            content="这是一篇关于手机评测的文章",
            title="手机评测",
            platform="zhihu"
        )
        assert result.optimized_content is not None
        assert result.density == 0.0
    def test_optimize_density_too_low(self):
        """Suggestions when keyword density is too low."""
        optimizer = SEOOptimizer()
        result = optimizer.optimize(
            content="这是一篇简短的内容",
            title="标题",
            platform="zhihu",
            keyword="华为手机"
        )
        # A density that is too low should produce suggestions
        if result.density < 1.0:
            assert len(result.suggestions) > 0

    def test_optimize_density_too_high(self):
        """Suggestions when keyword density is too high."""
        optimizer = SEOOptimizer()
        result = optimizer.optimize(
            content="华为华为华为华为华为华为华为华为华为华为",
            title="华为",
            platform="zhihu",
            keyword="华为"
        )
        # A density that is too high should produce a suggestion containing "超过"
        if result.density > 3.0:
            assert any("超过" in s for s in result.suggestions)

    def test_optimize_keyword_in_title(self):
        """Check for the keyword in the title."""
        optimizer = SEOOptimizer()
        result = optimizer.optimize(
            content="手机内容",
            title="华为手机评测",
            platform="zhihu",
            keyword="华为手机"
        )
        # With the keyword in the title, suggestions are still returned as a list
        assert isinstance(result.suggestions, list)

    def test_optimize_keyword_in_first_paragraph(self):
        """Check for the keyword in the first 100 characters."""
        optimizer = SEOOptimizer()
        result = optimizer.optimize(
            content="华为手机是本文的主题...",
            title="评测",
            platform="zhihu",
            keyword="华为手机"
        )
        assert isinstance(result.suggestions, list)

    def test_optimize_different_platforms(self):
        """SEO optimization on different platforms."""
        optimizer = SEOOptimizer()
        platforms = ["zhihu", "wechat", "xiaohongshu", "baijiahao"]
        for platform in platforms:
            result = optimizer.optimize(
                content="测试内容",
                title="测试标题",
                platform=platform,
                keyword="测试"
            )
            assert isinstance(result, OptimizationResult)
    def test_optimization_result_structure(self):
        """The optimization result has a complete structure."""
        optimizer = SEOOptimizer()
        result = optimizer.optimize(
            content="华为手机是知名品牌",
            title="华为评测",
            platform="zhihu",
            keyword="华为"
        )
        assert isinstance(result, OptimizationResult)
        assert isinstance(result.optimized_content, str)
        assert isinstance(result.density, float)
        assert isinstance(result.suggestions, list)
        assert isinstance(result.tips, list)

    def test_optimize_returns_suggestions(self):
        """The suggestions list is returned."""
        optimizer = SEOOptimizer()
        result = optimizer.optimize(
            content="内容过短",
            title="标题",
            platform="zhihu",
            keyword="关键词"
        )
        assert isinstance(result.suggestions, list)

    def test_optimize_returns_tips(self):
        """The tips list is returned."""
        optimizer = SEOOptimizer()
        result = optimizer.optimize(
            content="测试内容",
            title="标题",
            platform="zhihu"
        )
        assert isinstance(result.tips, list)
    def test_density_calculation_accuracy(self):
        """Accuracy of the density calculation."""
        optimizer = SEOOptimizer()
        # "华为" is 2 characters and occurs twice: 2 * 2 = 4 matched characters,
        # so density = 4 / total characters * 100
        content = "华为手机华为"
        total_chars = len(content)  # 6
        keyword = "华为"
        keyword_count = content.count(keyword)  # 2
        expected_density = (len(keyword) * keyword_count) / total_chars * 100
        actual_density = optimizer.get_keyword_density(content, keyword)
        assert abs(actual_density - round(expected_density, 2)) < 0.1
    def test_optimize_empty_content(self):
        """Optimization of empty content."""
        optimizer = SEOOptimizer()
        result = optimizer.optimize(
            content="",
            title="标题",
            platform="zhihu",
            keyword="关键词"
        )
        assert result.density == 0.0

    def test_optimize_invalid_platform(self):
        """Optimization for an invalid platform."""
        optimizer = SEOOptimizer()
        result = optimizer.optimize(
            content="测试内容",
            title="标题",
            platform="invalid_platform",
            keyword="关键词"
        )
        # An invalid platform should still return list-typed suggestions and tips
        assert isinstance(result.suggestions, list)
        assert isinstance(result.tips, list)
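
The density formula exercised by `test_density_calculation_accuracy` (matched characters divided by total characters, times 100) can be sketched as a standalone function. This is an illustration consistent with the tests, not the body of `SEOOptimizer.get_keyword_density`; the name `keyword_density_sketch` and the rounding to two decimals are assumptions.

```python
def keyword_density_sketch(content: str, keyword: str) -> float:
    """Percentage of characters in content that belong to keyword occurrences."""
    if not content or not keyword:
        # Empty input has no meaningful density; the tests expect 0.0 here.
        return 0.0
    matched_chars = len(keyword) * content.count(keyword)
    return round(matched_chars / len(content) * 100, 2)
```

For the content "华为手机华为" and the keyword "华为", four of the six characters match, which gives roughly 66.67.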


@ -0,0 +1,496 @@
import pytest
from unittest.mock import Mock, patch, MagicMock
from app.services.email_service import (
    EmailService,
    EmailMessage,
    EmailSendResult,
    EMAIL_TEMPLATES,
)


class TestEmailMessage:
    """Tests for the email message data structure."""

    def test_email_message_creation(self):
        """Test creating an EmailMessage."""
        msg = EmailMessage(
            to="test@example.com",
            subject="测试邮件",
            body_html="<h1>测试</h1>",
            body_text="测试邮件",
        )
        assert msg.to == "test@example.com"
        assert msg.subject == "测试邮件"
        assert msg.body_html == "<h1>测试</h1>"
        assert msg.body_text == "测试邮件"
        assert msg.attachments == []
        assert msg.metadata == {}

    def test_email_message_with_attachments(self):
        """Test an email message with attachments."""
        msg = EmailMessage(
            to="test@example.com",
            subject="带附件",
            body_html="<p>内容</p>",
            body_text="内容",
            attachments=[{"filename": "report.pdf", "content": b"data"}],
        )
        assert len(msg.attachments) == 1
        assert msg.attachments[0]["filename"] == "report.pdf"

    def test_email_message_with_metadata(self):
        """Test an email message with metadata."""
        msg = EmailMessage(
            to="test@example.com",
            subject="测试",
            body_html="<p>内容</p>",
            body_text="内容",
            metadata={"brand_name": "test_brand", "alert_type": "error"},
        )
        assert msg.metadata["brand_name"] == "test_brand"
class TestEmailSendResult:
    """Tests for the email send result data structure."""

    def test_email_send_result_success(self):
        """Test a successful send result."""
        result = EmailSendResult(
            success=True,
            message_id="msg_123",
            error=None,
            retry_count=0,
        )
        assert result.success is True
        assert result.message_id == "msg_123"
        assert result.error is None
        assert result.retry_count == 0

    def test_email_send_result_failure(self):
        """Test a failed send result."""
        result = EmailSendResult(
            success=False,
            message_id=None,
            error="SMTP连接失败",
            retry_count=3,
        )
        assert result.success is False
        assert result.message_id is None
        assert result.error == "SMTP连接失败"
        assert result.retry_count == 3
class TestEmailTemplateRendering:
    """Tests for email template rendering."""

    @pytest.fixture
    def email_service(self):
        """Create an email service instance."""
        return EmailService()

    def test_render_alert_notification_template(self, email_service):
        """Test rendering the alert notification template."""
        variables = {
            "alert_type": "系统错误",
            "brand_name": "测试品牌",
            "severity": "",
            "description": "数据库连接失败",
            "timestamp": "2024-01-01 12:00:00",
        }
        msg = email_service.render_template("alert_notification", "admin@example.com", variables)
        assert msg.to == "admin@example.com"
        assert "[GEO平台] 告警通知:系统错误" in msg.subject
        assert "测试品牌" in msg.body_html
        assert "系统错误" in msg.body_html
        assert "" in msg.body_html

    def test_render_quota_warning_template(self, email_service):
        """Test rendering the quota warning template."""
        variables = {
            "quota_type": "API调用",
            "usage_percentage": 85,
            "used": 850,
            "limit": 1000,
            "recommended_action": "请升级套餐",
        }
        msg = email_service.render_template("quota_warning", "user@example.com", variables)
        assert msg.to == "user@example.com"
        assert "[GEO平台] 额度预警API调用" in msg.subject
        assert "85%" in msg.body_html
        assert "850" in msg.body_html
        assert "1000" in msg.body_html

    def test_render_template_missing_variables(self, email_service):
        """Test rendering a template with missing variables."""
        variables = {"alert_type": "系统错误"}
        msg = email_service.render_template("alert_notification", "admin@example.com", variables)
        assert msg is not None
        assert msg.to == "admin@example.com"

    def test_render_template_invalid_template(self, email_service):
        """Test an invalid template name."""
        with pytest.raises(ValueError, match="模板不存在"):
            email_service.render_template(
                "invalid_template",
                "admin@example.com",
                {"key": "value"},
            )

    def test_render_template_variable_substitution(self, email_service):
        """Test template variable substitution."""
        variables = {
            "brand_name": "品牌A",
            "alert_type": "告警B",
            "severity": "严重",
            "description": "描述C",
            "timestamp": "时间D",
        }
        msg = email_service.render_template("alert_notification", "test@example.com", variables)
        assert "品牌A" in msg.body_html
        assert "告警B" in msg.body_html
        assert "严重" in msg.body_html
        assert "描述C" in msg.body_html
        assert "时间D" in msg.body_html
class TestEmailGeneration:
    """Tests for email content generation."""

    @pytest.fixture
    def email_service(self):
        """Create an email service instance."""
        return EmailService()

    def test_generate_alert_notification_email(self, email_service):
        """Test generating an alert notification email."""
        msg = email_service.generate_alert_email(
            to="admin@example.com",
            alert_type="数据库告警",
            brand_name="测试品牌",
            severity="",
            description="数据库CPU使用率超过90%",
            timestamp="2024-01-01 12:00:00",
        )
        assert msg.to == "admin@example.com"
        assert "数据库告警" in msg.subject
        assert "测试品牌" in msg.body_html
        assert "" in msg.body_html
        assert msg.body_text != ""

    def test_generate_quota_warning_email(self, email_service):
        """Test generating a quota warning email."""
        msg = email_service.generate_quota_warning_email(
            to="user@example.com",
            quota_type="API调用",
            usage_percentage=85,
            used=850,
            limit=1000,
            recommended_action="建议升级套餐",
        )
        assert msg.to == "user@example.com"
        assert "API调用" in msg.subject
        assert "85%" in msg.body_html
        assert "850" in msg.body_html
        assert "1000" in msg.body_html

    def test_generate_email_has_both_formats(self, email_service):
        """Test that a generated email has both HTML and plain-text bodies."""
        msg = email_service.generate_alert_email(
            to="test@example.com",
            alert_type="测试",
            brand_name="品牌",
            severity="",
            description="描述",
            timestamp="时间",
        )
        assert msg.body_html != ""
        assert msg.body_text != ""
        assert "<" in msg.body_html
        assert "<" not in msg.body_text
class TestEmailSending:
    """Tests for sending email."""

    @pytest.fixture
    def email_service(self):
        """Create an email service instance (simulate mode)."""
        return EmailService(simulate_mode=True)

    def test_send_email_simulate_mode(self, email_service):
        """Test sending an email in simulate mode."""
        msg = EmailMessage(
            to="test@example.com",
            subject="测试",
            body_html="<p>测试</p>",
            body_text="测试",
        )
        result = email_service.send_email(msg)
        assert result.success is True
        assert result.message_id is not None
        assert result.error is None

    @patch("smtplib.SMTP")
    def test_send_email_real_smtp(self, mock_smtp):
        """Test the real SMTP send path (with SMTP mocked)."""
        mock_server = MagicMock()
        mock_smtp.return_value = mock_server
        service = EmailService(
            simulate_mode=False,
            smtp_host="smtp.example.com",
            smtp_port=587,
            smtp_user="user@example.com",
            smtp_password="password",
        )
        msg = EmailMessage(
            to="recipient@example.com",
            subject="测试",
            body_html="<p>内容</p>",
            body_text="内容",
        )
        result = service.send_email(msg)
        assert result.success is True
        mock_smtp.assert_called_once_with("smtp.example.com", 587)
        mock_server.starttls.assert_called_once()
        mock_server.login.assert_called_once_with("user@example.com", "password")

    @patch("smtplib.SMTP")
    def test_send_email_smtp_failure(self, mock_smtp):
        """Test an SMTP send failure."""
        mock_smtp.side_effect = Exception("连接失败")
        service = EmailService(
            simulate_mode=False,
            smtp_host="smtp.example.com",
            smtp_port=587,
            smtp_user="user@example.com",
            smtp_password="password",
        )
        msg = EmailMessage(
            to="test@example.com",
            subject="测试",
            body_html="<p>内容</p>",
            body_text="内容",
        )
        result = service.send_email(msg)
        assert result.success is False
        assert result.error is not None
        assert "连接失败" in result.error

    @patch("smtplib.SMTP")
    def test_send_email_with_retry(self, mock_smtp):
        """Test sending with retries configured; a first-attempt success records no retries."""
        mock_server = MagicMock()
        mock_smtp.return_value = mock_server
        service = EmailService(
            simulate_mode=False,
            smtp_host="smtp.example.com",
            smtp_port=587,
            smtp_user="user@example.com",
            smtp_password="password",
            max_retries=3,
        )
        msg = EmailMessage(
            to="test@example.com",
            subject="测试",
            body_html="<p>内容</p>",
            body_text="内容",
        )
        result = service.send_email(msg)
        assert result.success is True
        assert result.retry_count == 0
class TestEmailQueue:
    """Tests for the email queue."""

    @pytest.fixture
    def email_service(self):
        """Create an email service instance."""
        return EmailService(simulate_mode=True)

    def test_add_to_queue(self, email_service):
        """Test adding an email to the queue."""
        msg = EmailMessage(
            to="test@example.com",
            subject="测试",
            body_html="<p>内容</p>",
            body_text="内容",
        )
        email_service.add_to_queue(msg)
        assert len(email_service.get_queue()) == 1
        assert email_service.get_queue()[0].to == "test@example.com"

    def test_add_multiple_to_queue(self, email_service):
        """Test adding several emails to the queue."""
        messages = [
            EmailMessage(to=f"user{i}@example.com", subject=f"测试{i}", body_html="<p>内容</p>", body_text="内容")
            for i in range(5)
        ]
        for msg in messages:
            email_service.add_to_queue(msg)
        assert len(email_service.get_queue()) == 5

    def test_send_queue(self, email_service):
        """Test sending the queued emails."""
        messages = [
            EmailMessage(to=f"user{i}@example.com", subject=f"测试{i}", body_html="<p>内容</p>", body_text="内容")
            for i in range(3)
        ]
        for msg in messages:
            email_service.add_to_queue(msg)
        results = email_service.send_queue()
        assert len(results) == 3
        assert all(r.success for r in results)
        assert len(email_service.get_queue()) == 0

    def test_send_queue_empty(self, email_service):
        """Test sending an empty queue."""
        results = email_service.send_queue()
        assert results == []

    def test_clear_queue(self, email_service):
        """Test clearing the queue."""
        msg = EmailMessage(
            to="test@example.com",
            subject="测试",
            body_html="<p>内容</p>",
            body_text="内容",
        )
        email_service.add_to_queue(msg)
        assert len(email_service.get_queue()) == 1
        email_service.clear_queue()
        assert len(email_service.get_queue()) == 0
class TestEmailAttachments:
    """Tests for email attachments."""

    @pytest.fixture
    def email_service(self):
        """Create an email service instance."""
        return EmailService(simulate_mode=True)

    def test_add_attachment_to_message(self, email_service):
        """Test adding an attachment to an email."""
        msg = EmailMessage(
            to="test@example.com",
            subject="带附件",
            body_html="<p>内容</p>",
            body_text="内容",
        )
        email_service.add_attachment(msg, "report.pdf", b"PDF content")
        assert len(msg.attachments) == 1
        assert msg.attachments[0]["filename"] == "report.pdf"
        assert msg.attachments[0]["content"] == b"PDF content"

    def test_add_multiple_attachments(self, email_service):
        """Test adding several attachments."""
        msg = EmailMessage(
            to="test@example.com",
            subject="多附件",
            body_html="<p>内容</p>",
            body_text="内容",
        )
        email_service.add_attachment(msg, "file1.pdf", b"content1")
        email_service.add_attachment(msg, "file2.xlsx", b"content2")
        assert len(msg.attachments) == 2
        assert msg.attachments[0]["filename"] == "file1.pdf"
        assert msg.attachments[1]["filename"] == "file2.xlsx"
class TestEmailValidation:
    """Tests for email address validation."""

    @pytest.fixture
    def email_service(self):
        """Create an email service instance."""
        return EmailService()

    def test_validate_valid_email(self, email_service):
        """Test valid email addresses."""
        assert email_service.validate_email("test@example.com") is True
        assert email_service.validate_email("user.name@domain.org") is True
        assert email_service.validate_email("user+tag@example.com") is True

    def test_validate_invalid_email(self, email_service):
        """Test invalid email addresses."""
        assert email_service.validate_email("invalid") is False
        assert email_service.validate_email("invalid@") is False
        assert email_service.validate_email("@example.com") is False
        assert email_service.validate_email("test@") is False
        assert email_service.validate_email("") is False

    def test_send_to_invalid_email(self, email_service):
        """Test sending to an invalid email address."""
        msg = EmailMessage(
            to="invalid_email",
            subject="测试",
            body_html="<p>内容</p>",
            body_text="内容",
        )
        result = email_service.send_email(msg)
        assert result.success is False
        assert result.error is not None
class TestEmailTemplatesConstants:
    """Tests for the email template constants."""

    def test_email_templates_exist(self):
        """Test that the email templates exist."""
        assert "alert_notification" in EMAIL_TEMPLATES
        assert "quota_warning" in EMAIL_TEMPLATES

    def test_alert_notification_template_structure(self):
        """Test the structure of the alert notification template."""
        template = EMAIL_TEMPLATES["alert_notification"]
        assert "subject" in template
        assert "body_html" in template
        assert "body_text" in template
        assert "{alert_type}" in template["subject"]
        assert "{brand_name}" in template["body_html"]

    def test_quota_warning_template_structure(self):
        """Test the structure of the quota warning template."""
        template = EMAIL_TEMPLATES["quota_warning"]
        assert "subject" in template
        assert "body_html" in template
        assert "body_text" in template
        assert "{quota_type}" in template["subject"]
        assert "{usage_percentage}" in template["body_html"]
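
The template tests above require `{name}` placeholders, tolerance for missing variables, and a `ValueError` for unknown template names. One way to get all three with the standard library is `str.format_map` over a `dict` subclass that defines `__missing__`. The sketch below is an assumption about a possible approach, not the code in `email_service.py`; `TEMPLATES_SKETCH`, `render_sketch`, and the template text are hypothetical.

```python
class BlankDefault(dict):
    """Dictionary that yields an empty string for any missing placeholder."""

    def __missing__(self, key):
        return ""


# Hypothetical template table using the same {placeholder} style as EMAIL_TEMPLATES.
TEMPLATES_SKETCH = {
    "alert_notification": {
        "subject": "[GEO平台] 告警通知:{alert_type}",
        "body_text": "{brand_name} / {severity} / {description} / {timestamp}",
    },
}


def render_sketch(name: str, variables: dict) -> dict:
    """Fill every field of the named template; unknown names raise ValueError."""
    template = TEMPLATES_SKETCH.get(name)
    if template is None:
        raise ValueError(f"模板不存在: {name}")
    values = BlankDefault(variables)
    return {key: text.format_map(values) for key, text in template.items()}
```

Substituting an empty string keeps rendering from failing on partial input; a stricter service might prefer to collect the missing names and report them.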


@ -0,0 +1,605 @@
"""
Unit tests for the GEO diagnosis service.
Covers the diagnosis logic for the six dimensions, the scoring algorithm,
recommendation generation, and the service class.
"""
import pytest
from app.services.geo_diagnosis import (
    GEODiagnosisService,
    GEODiagnosisInput,
    diagnose_content_extractability,
    diagnose_entity_clarity,
    diagnose_eeat_signals,
    diagnose_schema_markup,
    diagnose_topic_authority,
    diagnose_citation_readiness,
    generate_recommendations,
    get_health_level,
    get_health_level_label,
)


class TestContentExtractability:
    """Tests for the content extractability diagnosis."""

    def test_all_pass(self):
        """All items pass."""
        result = diagnose_content_extractability(
            has_direct_answer=True,
            has_qa_headings=True,
            has_structured_data=True,
            has_internal_links=True,
            has_freshness_info=True,
            update_days_ago=10,
        )
        assert result.score == 20.0
        assert result.max_score == 20.0
        assert result.status == "pass"
        assert result.percentage == 100.0

    def test_all_fail(self):
        """All items fail."""
        result = diagnose_content_extractability(
            has_direct_answer=False,
            has_qa_headings=False,
            has_structured_data=False,
            has_internal_links=False,
            has_freshness_info=False,
        )
        assert result.score == 0.0
        assert result.status == "warning"

    def test_partial_pass(self):
        """Some items pass."""
        result = diagnose_content_extractability(
            has_direct_answer=True,
            has_qa_headings=True,
            has_structured_data=False,
            has_internal_links=False,
            has_freshness_info=False,
        )
        assert result.score == 11.0  # 6 + 5

    def test_freshness_recent(self):
        """Content freshness: recently updated."""
        result = diagnose_content_extractability(
            has_freshness_info=True,
            update_days_ago=10,
        )
        freshness_item = [i for i in result.items if i.name == "内容新鲜度"][0]
        assert freshness_item.score == 2.0
        assert freshness_item.status == "pass"

    def test_freshness_old(self):
        """Content freshness: stale update."""
        result = diagnose_content_extractability(
            has_freshness_info=True,
            update_days_ago=100,
        )
        freshness_item = [i for i in result.items if i.name == "内容新鲜度"][0]
        assert freshness_item.score == 0.5
        assert freshness_item.status == "warning"
class TestEntityClarity:
    """Tests for the entity clarity diagnosis."""

    def test_all_pass(self):
        """All items pass."""
        result = diagnose_entity_clarity(
            has_brand_definition=True,
            has_target_audience=True,
            has_unique_value=True,
            has_industry_classification=True,
        )
        assert result.score == 15.0
        assert result.max_score == 15.0
        assert result.status == "pass"

    def test_all_fail(self):
        """All items fail."""
        result = diagnose_entity_clarity(
            has_brand_definition=False,
            has_target_audience=False,
            has_unique_value=False,
            has_industry_classification=False,
        )
        assert result.score == 0.0
        assert result.status == "warning"

    def test_partial_pass(self):
        """Some items pass."""
        result = diagnose_entity_clarity(
            has_brand_definition=True,
            has_target_audience=True,
            has_unique_value=False,
            has_industry_classification=False,
        )
        assert result.score == 9.0  # 5 + 4
class TestEEATSignals:
    """Tests for the E-E-A-T signals diagnosis."""

    def test_all_pass(self):
        """All items pass."""
        result = diagnose_eeat_signals(
            has_author_bio=True,
            author_credentials_complete=1.0,
            has_certifications=True,
            certification_count=5,
            has_data_sources=True,
            authoritative_source_ratio=1.0,
            has_expert_endorsements=True,
            endorsement_count=5,
        )
        assert result.score == 20.0
        assert result.max_score == 20.0
        assert result.status == "pass"

    def test_all_fail(self):
        """All items fail."""
        result = diagnose_eeat_signals(
            has_author_bio=False,
            has_certifications=False,
            has_data_sources=False,
            has_expert_endorsements=False,
        )
        assert result.score == 0.0
        assert result.status == "warning"

    def test_author_partial(self):
        """Author credentials partially complete."""
        result = diagnose_eeat_signals(
            has_author_bio=True,
            author_credentials_complete=0.5,
        )
        author_item = [i for i in result.items if i.name == "作者资质"][0]
        assert author_item.score == 3.0  # 0.5 * 6.0
        assert author_item.status == "warning"

    def test_certification_tiers(self):
        """Certification count tiers."""
        # 5 or more
        result = diagnose_eeat_signals(has_certifications=True, certification_count=5)
        cert_item = [i for i in result.items if i.name == "专业认证"][0]
        assert cert_item.score == 5.0
        # 3-4
        result = diagnose_eeat_signals(has_certifications=True, certification_count=3)
        cert_item = [i for i in result.items if i.name == "专业认证"][0]
        assert cert_item.score == 4.0
        # 1-2
        result = diagnose_eeat_signals(has_certifications=True, certification_count=1)
        cert_item = [i for i in result.items if i.name == "专业认证"][0]
        assert cert_item.score == 2.5
class TestSchemaMarkup:
    """Tests for the Schema markup diagnosis."""

    def test_all_pass(self):
        """All items pass."""
        result = diagnose_schema_markup(
            has_organization=True,
            has_product=True,
            has_article=True,
            has_faq=True,
            has_howto=True,
            has_breadcrumb=True,
        )
        assert result.score == 15.0
        assert result.max_score == 15.0
        assert result.status == "pass"

    def test_all_fail(self):
        """All items fail."""
        result = diagnose_schema_markup(
            has_organization=False,
            has_product=False,
            has_article=False,
            has_faq=False,
            has_howto=False,
            has_breadcrumb=False,
        )
        assert result.score == 0.0
        assert result.status == "warning"

    def test_p0_only(self):
        """Only the required P0 items."""
        result = diagnose_schema_markup(
            has_organization=True,
            has_product=True,
            has_article=True,
        )
        assert result.score == 10.0  # 4 + 3 + 3
        assert result.status == "pass"

    def test_schema_count(self):
        """Schema count."""
        result = diagnose_schema_markup(
            has_organization=True,
            has_product=True,
        )
        assert result.detail["schema_count"] == 2
class TestTopicAuthority:
    """Tests for the topic authority diagnosis."""

    def test_all_pass(self):
        """All items pass."""
        result = diagnose_topic_authority(
            content_depth_score=1.0,
            topic_coverage_ratio=1.0,
            entity_consistency_score=1.0,
            cluster_completeness=1.0,
        )
        assert result.score == 15.0
        assert result.max_score == 15.0
        assert result.status == "pass"

    def test_all_fail(self):
        """All items fail."""
        result = diagnose_topic_authority(
            content_depth_score=0.0,
            topic_coverage_ratio=0.0,
            entity_consistency_score=0.0,
            cluster_completeness=0.0,
        )
        assert result.score == 0.0
        assert result.status == "warning"

    def test_partial_scores(self):
        """Partial scores."""
        result = diagnose_topic_authority(
            content_depth_score=0.8,
            topic_coverage_ratio=0.5,
            entity_consistency_score=0.7,
            cluster_completeness=0.4,
        )
        # 0.8*5 + 0.5*4 + 0.7*3 + 0.4*3 = 4 + 2 + 2.1 + 1.2 = 9.3
        assert result.score == pytest.approx(9.3, rel=0.01)
class TestCitationReadiness:
    """Tests for the citation readiness diagnosis."""

    def test_all_pass(self):
        """All items pass."""
        result = diagnose_citation_readiness(
            answer_ownership_rate=0.6,
            citation_accuracy=1.0,
            ai_sov=0.35,
            competitor_gap=0.0,
        )
        assert result.score == 15.0
        assert result.max_score == 15.0
        assert result.status == "pass"

    def test_all_fail(self):
        """All items fail."""
        result = diagnose_citation_readiness(
            answer_ownership_rate=0.0,
            citation_accuracy=0.0,
            ai_sov=0.0,
            competitor_gap=0.6,
        )
        assert result.score == 0.0
        assert result.status == "warning"

    def test_aor_tiers(self):
        """AOR (answer ownership rate) tiers."""
        # >= 50%
        result = diagnose_citation_readiness(answer_ownership_rate=0.5)
        aor_item = [i for i in result.items if i.name == "引用频率 (AOR)"][0]
        assert aor_item.score == 5.0
        # 30-49%
        result = diagnose_citation_readiness(answer_ownership_rate=0.3)
        aor_item = [i for i in result.items if i.name == "引用频率 (AOR)"][0]
        assert aor_item.score == 3.5
        # 10-29%
        result = diagnose_citation_readiness(answer_ownership_rate=0.1)
        aor_item = [i for i in result.items if i.name == "引用频率 (AOR)"][0]
        assert aor_item.score == 2.0

    def test_competitor_gap_tiers(self):
        """Competitor gap tiers."""
        # <= 10pp
        result = diagnose_citation_readiness(competitor_gap=0.05)
        gap_item = [i for i in result.items if i.name == "竞品对比"][0]
        assert gap_item.score == 3.0
        # 10-20pp
        result = diagnose_citation_readiness(competitor_gap=0.15)
        gap_item = [i for i in result.items if i.name == "竞品对比"][0]
        assert gap_item.score == 2.0
class TestRecommendations:
    """Tests for recommendation generation."""

    def test_generate_from_fail_items(self):
        """Generate P0 recommendations from failed items."""
        dimensions = [
            diagnose_content_extractability(
                has_direct_answer=False,
                has_qa_headings=False,
            ),
        ]
        recommendations = generate_recommendations(dimensions)
        assert len(recommendations) >= 2
        # Check that P0 recommendations exist (not necessarily all of them)
        p0_recs = [r for r in recommendations if r.priority == "P0"]
        assert len(p0_recs) >= 2

    def test_generate_from_warning_items(self):
        """Generate P1 recommendations from warning items."""
        dimensions = [
            diagnose_content_extractability(
                has_direct_answer=True,
                has_qa_headings=True,
                has_structured_data=True,
                has_internal_links=False,
                has_freshness_info=True,
                update_days_ago=50,
            ),
        ]
        recommendations = generate_recommendations(dimensions)
        p1_recs = [r for r in recommendations if r.priority == "P1"]
        assert len(p1_recs) >= 1

    def test_priority_ordering(self):
        """Recommendations are sorted by priority."""
        dimensions = [
            diagnose_content_extractability(
                has_direct_answer=False,
                has_qa_headings=False,
                has_structured_data=True,
                has_internal_links=False,
                has_freshness_info=True,
                update_days_ago=50,
            ),
        ]
        recommendations = generate_recommendations(dimensions)
        priorities = [r.priority for r in recommendations]
        assert priorities == sorted(priorities)

    def test_empty_dimensions(self):
        """Empty dimension list."""
        recommendations = generate_recommendations([])
        assert len(recommendations) == 0
class TestHealthLevel:
    """Tests for health levels."""

    def test_excellent(self):
        assert get_health_level(85) == "excellent"
        assert get_health_level(80) == "excellent"

    def test_good(self):
        assert get_health_level(70) == "good"
        assert get_health_level(60) == "good"

    def test_pass(self):
        assert get_health_level(50) == "pass"
        assert get_health_level(40) == "pass"

    def test_danger(self):
        assert get_health_level(30) == "danger"
        assert get_health_level(0) == "danger"

    def test_labels(self):
        assert get_health_level_label("excellent") == "优秀"
        assert get_health_level_label("good") == "良好"
        assert get_health_level_label("pass") == "及格"
        assert get_health_level_label("danger") == "危险"
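
The health-level tests above sample the points 0, 30, 40, 50, 60, 70, 80, and 85. A threshold mapping with cut-offs at 80, 60, and 40 is consistent with all of them, though the tests do not confirm where exactly each boundary lies. The sketch below assumes those cut-offs; `health_level_sketch` and `LABELS_SKETCH` are hypothetical names, not the functions in `geo_diagnosis.py`.

```python
# Assumed cut-offs, checked highest first; only the sampled scores are confirmed by the tests.
THRESHOLDS_SKETCH = [(80, "excellent"), (60, "good"), (40, "pass")]

# Labels as asserted in test_labels.
LABELS_SKETCH = {
    "excellent": "优秀",
    "good": "良好",
    "pass": "及格",
    "danger": "危险",
}


def health_level_sketch(score: float) -> str:
    """Map a 0-100 score to a health level using the assumed cut-offs."""
    for minimum, level in THRESHOLDS_SKETCH:
        if score >= minimum:
            return level
    return "danger"
```

Keeping the cut-offs in an ordered table means a new tier is a one-line change rather than another branch.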
class TestGEODiagnosisService:
    """Tests for the GEO diagnosis service class."""

    @pytest.fixture
    def service(self):
        return GEODiagnosisService()

    def test_full_diagnosis_all_pass(self, service):
        """Full diagnosis: every item passes."""
        input_data = GEODiagnosisInput(
            # Content extractability
            has_direct_answer=True,
            has_qa_headings=True,
            has_structured_data=True,
            has_internal_links=True,
            has_freshness_info=True,
            update_days_ago=10,
            # Entity clarity
            has_brand_definition=True,
            has_target_audience=True,
            has_unique_value=True,
            has_industry_classification=True,
            # E-E-A-T
            has_author_bio=True,
            author_credentials_complete=1.0,
            has_certifications=True,
            certification_count=5,
            has_data_sources=True,
            authoritative_source_ratio=1.0,
            has_expert_endorsements=True,
            endorsement_count=5,
            # Schema
            has_organization=True,
            has_product=True,
            has_article=True,
            has_faq=True,
            has_howto=True,
            has_breadcrumb=True,
            # Topical authority
            content_depth_score=1.0,
            topic_coverage_ratio=1.0,
            entity_consistency_score=1.0,
            cluster_completeness=1.0,
            # Citation readiness
            answer_ownership_rate=0.6,
            citation_accuracy=1.0,
            ai_sov=0.35,
            competitor_gap=0.0,
        )
        result = service.diagnose(input_data)
        assert result.overall_score == 100.0
        assert result.health_level == "excellent"
        assert len(result.dimensions) == 6
        assert len(result.recommendations) == 0

    def test_full_diagnosis_all_fail(self, service):
        """Full diagnosis: every item fails."""
        input_data = GEODiagnosisInput()
        result = service.diagnose(input_data)
        # Some items earn a small score even with default input,
        # so the total is not necessarily 0.
        assert result.overall_score < 10.0
        assert result.health_level == "danger"
        assert len(result.dimensions) == 6
        assert len(result.recommendations) > 0
def test_diagnose_from_dict(self, service):
"""Run a diagnosis from a dict."""
data = {
"has_direct_answer": True,
"has_brand_definition": True,
"has_author_bio": True,
"author_credentials_complete": 0.8,
"has_organization": True,
"content_depth_score": 0.8,
"answer_ownership_rate": 0.5,
}
result = service.diagnose_from_dict(data)
assert result.overall_score > 0
assert len(result.dimensions) == 6
def test_result_to_dict(self, service):
"""Convert the result to a dict."""
input_data = GEODiagnosisInput(
has_direct_answer=True,
has_brand_definition=True,
)
result = service.diagnose(input_data)
result_dict = result.to_dict()
assert "overall_score" in result_dict
assert "health_level" in result_dict
assert "health_level_label" in result_dict
assert "dimensions" in result_dict
assert "recommendations" in result_dict
assert len(result_dict["dimensions"]) == 6
def test_score_boundaries(self, service):
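"""Score boundaries: the overall score stays within [0, 100]."""
# Lowest score: default (all-failing) input
result = service.diagnose(GEODiagnosisInput())
assert result.overall_score >= 0.0
# Highest score: near-perfect input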
input_data = GEODiagnosisInput(
has_direct_answer=True,
has_qa_headings=True,
has_structured_data=True,
has_internal_links=True,
has_freshness_info=True,
update_days_ago=10,
has_brand_definition=True,
has_target_audience=True,
has_unique_value=True,
has_industry_classification=True,
has_author_bio=True,
author_credentials_complete=1.0,
has_certifications=True,
certification_count=5,
has_data_sources=True,
authoritative_source_ratio=1.0,
has_expert_endorsements=True,
endorsement_count=5,
has_organization=True,
has_product=True,
has_article=True,
has_faq=True,
has_howto=True,
has_breadcrumb=True,
content_depth_score=1.0,
topic_coverage_ratio=1.0,
entity_consistency_score=1.0,
cluster_completeness=1.0,
answer_ownership_rate=0.6,
citation_accuracy=0.95,
ai_sov=0.35,
competitor_gap=0.05,
)
result = service.diagnose(input_data)
assert result.overall_score <= 100.0
def test_health_levels(self, service):
"""Health level: a strong input should be rated excellent."""
# excellent (overall score >= 80)
input_data = GEODiagnosisInput(
has_direct_answer=True,
has_qa_headings=True,
has_structured_data=True,
has_internal_links=True,
has_freshness_info=True,
update_days_ago=10,
has_brand_definition=True,
has_target_audience=True,
has_unique_value=True,
has_industry_classification=True,
has_author_bio=True,
author_credentials_complete=0.9,
has_certifications=True,
certification_count=4,
has_data_sources=True,
authoritative_source_ratio=0.9,
has_expert_endorsements=True,
endorsement_count=4,
has_organization=True,
has_product=True,
has_article=True,
has_faq=True,
has_howto=True,
has_breadcrumb=True,
content_depth_score=0.9,
topic_coverage_ratio=0.9,
entity_consistency_score=0.9,
cluster_completeness=0.8,
answer_ownership_rate=0.6,
citation_accuracy=0.95,
ai_sov=0.35,
competitor_gap=0.05,
)
result = service.diagnose(input_data)
assert result.health_level == "excellent"
def test_dimension_scores_sum(self, service):
"""The dimension scores should add up to the overall score."""
input_data = GEODiagnosisInput(
has_direct_answer=True,
has_brand_definition=True,
)
result = service.diagnose(input_data)
# The sum of the dimension scores should equal the overall score (1% relative tolerance)
total = sum(dim.score for dim in result.dimensions)
assert result.overall_score == pytest.approx(total, rel=0.01)
def test_recommendations_generated(self, service):
"""Recommendations are generated for failing input."""
input_data = GEODiagnosisInput()
result = service.diagnose(input_data)
# When every item fails there should be at least one recommendation
assert len(result.recommendations) > 0
assert all(r.priority in ["P0", "P1", "P2"] for r in result.recommendations)
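
The threshold assertions in `TestHealthLevel` suggest a simple score-to-level mapping. Below is a minimal sketch of such a mapping, assuming cut-offs at 80, 60 and 40. The names `health_level_sketch`, `health_level_label_sketch` and `LABELS` are hypothetical; the real `get_health_level` and `get_health_level_label` live in the service module, which this part of the diff does not show, so the actual implementation may differ.

```python
# Hypothetical sketch inferred only from the test assertions above;
# it is not the service's actual implementation.

# User-facing labels, as asserted in test_labels.
LABELS = {
    "excellent": "优秀",
    "good": "良好",
    "pass": "及格",
    "danger": "危险",
}


def health_level_sketch(score: float) -> str:
    """Map an overall score (0-100) to a health level.

    Assumed cut-offs: >= 80 excellent, >= 60 good, >= 40 pass, else danger.
    """
    if score >= 80:
        return "excellent"
    if score >= 60:
        return "good"
    if score >= 40:
        return "pass"
    return "danger"


def health_level_label_sketch(level: str) -> str:
    """Return the display label for a health level."""
    return LABELS[level]


if __name__ == "__main__":
    for score in (85, 60, 40, 0):
        level = health_level_sketch(score)
        print(score, level, health_level_label_sketch(level))
```

Treating each cut-off as inclusive matches the boundary cases the tests check (80, 60 and 40 each fall in the higher band).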


@ -0,0 +1,332 @@
import pytest
from app.services.quota_service import (
QuotaService,
QuotaUsage,
QuotaWarning,
QuotaType,
QUOTA_LIMITS,
WARNING_THRESHOLDS,
)
class TestQuotaUsage:
"""额度使用数据结构测试"""
def test_quota_usage_creation(self):
"""测试QuotaUsage数据创建"""
usage = QuotaUsage(
quota_type="api_calls",
used=800,
limit=1000,
usage_percentage=80.0,
status="warning",
remaining=200,
)
assert usage.quota_type == "api_calls"
assert usage.used == 800
assert usage.limit == 1000
assert usage.usage_percentage == 80.0
assert usage.status == "warning"
assert usage.remaining == 200
def test_quota_usage_unlimited(self):
"""测试无限额度QuotaUsage"""
usage = QuotaUsage(
quota_type="api_calls",
used=5000,
limit=-1,
usage_percentage=0.0,
status="unlimited",
remaining=-1,
)
assert usage.limit == -1
assert usage.status == "unlimited"
assert usage.remaining == -1
class TestQuotaWarning:
"""预警数据结构测试"""
def test_quota_warning_creation(self):
"""测试QuotaWarning数据创建"""
warning = QuotaWarning(
quota_type="api_calls",
status="warning",
usage_percentage=80.0,
message="API调用额度已使用80%",
recommended_action="请关注使用情况,考虑升级套餐",
)
assert warning.quota_type == "api_calls"
assert warning.status == "warning"
assert warning.usage_percentage == 80.0
assert "80%" in warning.message
class TestQuotaService:
    """Tests for the quota warning service."""

    @pytest.fixture
    def quota_service(self):
        """Create a quota service instance."""
        return QuotaService()
def test_calculate_usage_percentage_basic(self, quota_service):
"""测试基础使用率计算"""
percentage = quota_service.calculate_usage_percentage(used=800, limit=1000)
assert percentage == 80.0
def test_calculate_usage_percentage_zero(self, quota_service):
"""测试0%使用率"""
percentage = quota_service.calculate_usage_percentage(used=0, limit=1000)
assert percentage == 0.0
def test_calculate_usage_percentage_full(self, quota_service):
"""测试100%使用率"""
percentage = quota_service.calculate_usage_percentage(used=1000, limit=1000)
assert percentage == 100.0
def test_calculate_usage_percentage_half(self, quota_service):
"""测试50%使用率"""
percentage = quota_service.calculate_usage_percentage(used=500, limit=1000)
assert percentage == 50.0
def test_calculate_usage_percentage_unlimited(self, quota_service):
"""测试无限额度使用率"""
percentage = quota_service.calculate_usage_percentage(used=5000, limit=-1)
assert percentage == 0.0
    def test_get_quota_status_ok(self, quota_service):
        """Below 80% usage the status is ok."""
        status = quota_service.get_quota_status(usage_percentage=50.0)
        assert status == "ok"

    def test_get_quota_status_warning(self, quota_service):
        """At 80% usage the status is warning."""
        status = quota_service.get_quota_status(usage_percentage=80.0)
        assert status == "warning"

    def test_get_quota_status_critical(self, quota_service):
        """At 90% usage the status is critical."""
        status = quota_service.get_quota_status(usage_percentage=90.0)
        assert status == "critical"

    def test_get_quota_status_exhausted(self, quota_service):
        """At 100% usage the status is exhausted."""
        status = quota_service.get_quota_status(usage_percentage=100.0)
        assert status == "exhausted"

    def test_get_quota_status_unlimited(self, quota_service):
        """A limit of -1 means the quota is unlimited."""
        status = quota_service.get_quota_status(usage_percentage=0.0, limit=-1)
        assert status == "unlimited"
def test_get_quota_limit_free_plan(self, quota_service):
"""测试免费套餐额度"""
limit = quota_service.get_quota_limit("free", QuotaType.API_CALLS)
assert limit == 1000
def test_get_quota_limit_basic_plan(self, quota_service):
"""测试基础套餐额度"""
limit = quota_service.get_quota_limit("basic", QuotaType.QUERIES)
assert limit == 500
def test_get_quota_limit_pro_plan(self, quota_service):
"""测试专业套餐额度"""
limit = quota_service.get_quota_limit("pro", QuotaType.CONTENT_GENERATION)
assert limit == 1000
def test_get_quota_limit_unlimited_plan(self, quota_service):
"""测试无限套餐额度"""
limit = quota_service.get_quota_limit("unlimited", QuotaType.API_CALLS)
assert limit == -1
def test_get_remaining_quota(self, quota_service):
"""测试剩余额度计算"""
remaining = quota_service.get_remaining(used=800, limit=1000)
assert remaining == 200
def test_get_remaining_quota_exhausted(self, quota_service):
"""测试额度耗尽"""
remaining = quota_service.get_remaining(used=1000, limit=1000)
assert remaining == 0
def test_get_remaining_quota_unlimited(self, quota_service):
"""测试无限额度"""
remaining = quota_service.get_remaining(used=5000, limit=-1)
assert remaining == -1
def test_get_remaining_quota_overuse(self, quota_service):
"""测试超额使用"""
remaining = quota_service.get_remaining(used=1200, limit=1000)
assert remaining == -200
def test_generate_warning_message_warning(self, quota_service):
"""测试生成警告消息"""
warning = quota_service.generate_warning(
quota_type=QuotaType.API_CALLS,
status="warning",
usage_percentage=80.0,
)
assert warning.quota_type == QuotaType.API_CALLS
assert warning.status == "warning"
assert warning.usage_percentage == 80.0
assert "80%" in warning.message
assert warning.recommended_action != ""
def test_generate_warning_message_critical(self, quota_service):
"""测试生成严重警告消息"""
warning = quota_service.generate_warning(
quota_type=QuotaType.QUERIES,
status="critical",
usage_percentage=90.0,
)
assert warning.status == "critical"
assert "90%" in warning.message
def test_generate_warning_message_exhausted(self, quota_service):
"""测试生成耗尽警告消息"""
warning = quota_service.generate_warning(
quota_type=QuotaType.CONTENT_GENERATION,
status="exhausted",
usage_percentage=100.0,
)
assert warning.status == "exhausted"
assert "100%" in warning.message
def test_check_quota_free_plan(self, quota_service):
"""测试检查免费套餐额度"""
usage = quota_service.check_quota(
plan="free",
quota_type=QuotaType.API_CALLS,
used=800,
)
assert usage.quota_type == QuotaType.API_CALLS
assert usage.limit == 1000
assert usage.used == 800
assert usage.remaining == 200
assert usage.usage_percentage == 80.0
assert usage.status == "warning"
    def test_check_quota_pro_plan_critical(self, quota_service):
        """Pro plan at 90% usage reports the critical status."""
        usage = quota_service.check_quota(
            plan="pro",
            quota_type=QuotaType.QUERIES,
            used=4500,
        )
        assert usage.limit == 5000
        assert usage.usage_percentage == 90.0
        assert usage.status == "critical"
def test_check_quota_unlimited_plan(self, quota_service):
"""测试无限套餐"""
usage = quota_service.check_quota(
plan="unlimited",
quota_type=QuotaType.API_CALLS,
used=100000,
)
assert usage.limit == -1
assert usage.status == "unlimited"
assert usage.usage_percentage == 0.0
assert usage.remaining == -1
def test_check_quota_exhausted(self, quota_service):
"""测试额度耗尽"""
usage = quota_service.check_quota(
plan="free",
quota_type=QuotaType.QUERIES,
used=50,
)
assert usage.limit == 50
assert usage.used == 50
assert usage.remaining == 0
assert usage.usage_percentage == 100.0
assert usage.status == "exhausted"
def test_check_quota_ok(self, quota_service):
"""测试额度正常"""
usage = quota_service.check_quota(
plan="basic",
quota_type=QuotaType.STORAGE,
used=500,
)
assert usage.limit == 1000
assert usage.usage_percentage == 50.0
assert usage.status == "ok"
def test_reset_quota_usage(self, quota_service):
"""测试额度重置"""
usage = quota_service.check_quota(
plan="free",
quota_type=QuotaType.API_CALLS,
used=800,
)
assert usage.used == 800
reset_usage = quota_service.reset_quota(used=800)
assert reset_usage == 0
def test_reset_quota_to_specific_value(self, quota_service):
"""测试重置到特定值"""
reset_usage = quota_service.reset_quota(used=500, reset_to=100)
assert reset_usage == 100
def test_get_all_quota_usage(self, quota_service):
"""测试获取所有额度使用情况"""
all_usage = quota_service.get_all_quota_usage(
plan="free",
api_calls_used=800,
queries_used=40,
content_generation_used=8,
storage_used=50,
)
assert len(all_usage) == 4
assert any(u.quota_type == QuotaType.API_CALLS and u.used == 800 for u in all_usage)
assert any(u.quota_type == QuotaType.QUERIES and u.used == 40 for u in all_usage)
def test_get_warnings_from_usage(self, quota_service):
"""测试从使用情况生成预警"""
all_usage = quota_service.get_all_quota_usage(
plan="free",
api_calls_used=800,
queries_used=45,
content_generation_used=5,
storage_used=20,
)
warnings = quota_service.get_warnings(all_usage)
assert len(warnings) >= 1
assert any(w.status in ["warning", "critical", "exhausted"] for w in warnings)
def test_boundary_value_0_percent(self, quota_service):
"""测试边界值0%"""
usage = quota_service.check_quota(
plan="free",
quota_type=QuotaType.API_CALLS,
used=0,
)
assert usage.usage_percentage == 0.0
assert usage.status == "ok"
def test_boundary_value_100_percent(self, quota_service):
"""测试边界值100%"""
usage = quota_service.check_quota(
plan="free",
quota_type=QuotaType.API_CALLS,
used=1000,
)
assert usage.usage_percentage == 100.0
assert usage.status == "exhausted"
    def test_warning_thresholds_constants(self):
        """Warning threshold constants (fractions of the limit)."""
        assert WARNING_THRESHOLDS["warning"] == 0.80
        assert WARNING_THRESHOLDS["critical"] == 0.90
        assert WARNING_THRESHOLDS["exhausted"] == 1.00

    def test_quota_limits_constants(self):
        """Per-plan quota limit constants; -1 means unlimited."""
        assert QUOTA_LIMITS["free"]["api_calls"] == 1000
        assert QUOTA_LIMITS["free"]["queries"] == 50
        assert QUOTA_LIMITS["basic"]["api_calls"] == 10000
        assert QUOTA_LIMITS["pro"]["api_calls"] == 100000
        assert QUOTA_LIMITS["unlimited"]["api_calls"] == -1
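
Taken together, the quota tests describe three small calculations: usage percentage, status, and remaining quota, with `-1` as the "unlimited" sentinel. The sketch below reproduces that behaviour under stated assumptions. The function names and the `UNLIMITED` constant are hypothetical, a limit is assumed to be either positive or `-1`, and the real `QuotaService` in `app.services.quota_service` may be structured differently.

```python
# Hypothetical sketch inferred from the quota tests above;
# not the actual QuotaService implementation.

UNLIMITED = -1  # sentinel the tests use for plans without a cap

# Same values the tests assert for WARNING_THRESHOLDS (fractions of the limit).
THRESHOLDS = {"warning": 0.80, "critical": 0.90, "exhausted": 1.00}


def usage_percentage(used: int, limit: int) -> float:
    """Return usage as a percentage; unlimited quotas always report 0.0."""
    if limit == UNLIMITED:
        return 0.0
    # Multiply before dividing so round figures such as 80.0 stay exact.
    return used * 100.0 / limit


def quota_status(percentage: float, limit: int = 0) -> str:
    """Classify usage, checking the most severe threshold first."""
    if limit == UNLIMITED:
        return "unlimited"
    if percentage >= THRESHOLDS["exhausted"] * 100:
        return "exhausted"
    if percentage >= THRESHOLDS["critical"] * 100:
        return "critical"
    if percentage >= THRESHOLDS["warning"] * 100:
        return "warning"
    return "ok"


def remaining(used: int, limit: int) -> int:
    """Return the remaining quota; negative values indicate overuse."""
    if limit == UNLIMITED:
        return UNLIMITED
    return limit - used


if __name__ == "__main__":
    pct = usage_percentage(800, 1000)
    print(pct, quota_status(pct, 1000), remaining(800, 1000))
```

Checking thresholds from most to least severe keeps the branches mutually exclusive, and returning a negative remainder on overuse mirrors `test_get_remaining_quota_overuse` rather than clamping at zero.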


@ -0,0 +1,844 @@
"""
Unit tests for the SEO diagnosis service
"""
import pytest
from app.services.seo_diagnosis import (
SEODiagnosisService,
SEODiagnosisResult,
SEODimensionScore,
DiagnosisItem,
SEORecommendation,
DiagnosisStatus,
DimensionName,
TechnicalSEOData,
OnPageSEOData,
ContentQualityData,
BacklinkData,
UserExperienceData,
diagnose_technical_seo,
diagnose_on_page_seo,
diagnose_content_quality,
diagnose_backlinks,
diagnose_user_experience,
generate_recommendations,
)
class TestDiagnosisStatus:
    """Diagnosis status enum tests."""

    def test_status_values(self):
        """Status members compare equal to their string values."""
        assert DiagnosisStatus.PASS == "pass"
        assert DiagnosisStatus.WARNING == "warning"
        assert DiagnosisStatus.FAIL == "fail"


class TestDimensionName:
    """Dimension name enum tests."""

    def test_dimension_names(self):
        """Dimension names are the Chinese display strings used by the service."""
        assert DimensionName.TECHNICAL_SEO == "技术SEO"
        assert DimensionName.ON_PAGE_SEO == "页面SEO"
        assert DimensionName.CONTENT_QUALITY == "内容质量"
        assert DimensionName.BACKLINK_ANALYSIS == "外链分析"
        assert DimensionName.USER_EXPERIENCE == "用户体验"
class TestDiagnosisItem:
"""诊断项数据结构测试"""
def test_create_item(self):
"""测试创建诊断项"""
item = DiagnosisItem(
name="测试项",
status=DiagnosisStatus.PASS,
description="测试描述",
suggestion="测试建议",
score=1.0,
)
assert item.name == "测试项"
assert item.status == DiagnosisStatus.PASS
assert item.score == 1.0
def test_item_with_details(self):
"""测试带详情的诊断项"""
item = DiagnosisItem(
name="测试项",
status=DiagnosisStatus.WARNING,
description="测试描述",
suggestion="测试建议",
details={"key": "value"},
)
assert item.details == {"key": "value"}
class TestSEODimensionScore:
"""维度评分数据结构测试"""
def test_create_dimension_score(self):
"""测试创建维度评分"""
dim = SEODimensionScore(
name="测试维度",
score=20.0,
max_score=25.0,
items=[],
status=DiagnosisStatus.PASS,
)
assert dim.score == 20.0
assert dim.max_score == 25.0
assert dim.percentage == 80.0
def test_percentage_calculation(self):
"""测试得分率计算"""
dim = SEODimensionScore(
name="测试维度",
score=15.0,
max_score=25.0,
items=[],
status=DiagnosisStatus.PASS,
)
assert dim.percentage == 60.0
def test_status_calculation_all_pass(self):
"""测试全部通过时的状态"""
items = [
DiagnosisItem(name="项1", status=DiagnosisStatus.PASS, description="", suggestion=""),
DiagnosisItem(name="项2", status=DiagnosisStatus.PASS, description="", suggestion=""),
]
dim = SEODimensionScore(
name="测试维度",
score=10.0,
max_score=10.0,
items=items,
status=DiagnosisStatus.PASS,
)
assert dim.status == DiagnosisStatus.PASS
def test_status_calculation_with_warnings(self):
"""测试有警告时的状态"""
items = [
DiagnosisItem(name="项1", status=DiagnosisStatus.PASS, description="", suggestion=""),
DiagnosisItem(name="项2", status=DiagnosisStatus.WARNING, description="", suggestion=""),
DiagnosisItem(name="项3", status=DiagnosisStatus.WARNING, description="", suggestion=""),
]
dim = SEODimensionScore(
name="测试维度",
score=7.0,
max_score=10.0,
items=items,
status=DiagnosisStatus.PASS,
)
assert dim.status == DiagnosisStatus.WARNING
def test_status_calculation_with_fails(self):
"""测试有失败时的状态"""
items = [
DiagnosisItem(name="项1", status=DiagnosisStatus.FAIL, description="", suggestion=""),
DiagnosisItem(name="项2", status=DiagnosisStatus.PASS, description="", suggestion=""),
DiagnosisItem(name="项3", status=DiagnosisStatus.PASS, description="", suggestion=""),
DiagnosisItem(name="项4", status=DiagnosisStatus.PASS, description="", suggestion=""),
]
dim = SEODimensionScore(
name="测试维度",
score=7.0,
max_score=10.0,
items=items,
status=DiagnosisStatus.PASS,
)
# One FAIL out of four items is 25%, which does not exceed the 30% threshold; because a FAIL is present the status is still WARNING
assert dim.status == DiagnosisStatus.WARNING
def test_status_calculation_many_fails(self):
"""测试大量失败时的状态"""
items = [
DiagnosisItem(name="项1", status=DiagnosisStatus.FAIL, description="", suggestion=""),
DiagnosisItem(name="项2", status=DiagnosisStatus.FAIL, description="", suggestion=""),
DiagnosisItem(name="项3", status=DiagnosisStatus.FAIL, description="", suggestion=""),
DiagnosisItem(name="项4", status=DiagnosisStatus.PASS, description="", suggestion=""),
]
dim = SEODimensionScore(
name="测试维度",
score=5.0,
max_score=10.0,
items=items,
status=DiagnosisStatus.PASS,
)
assert dim.status == DiagnosisStatus.FAIL
class TestSEODiagnosisResult:
"""诊断结果数据结构测试"""
def test_create_result(self):
"""测试创建诊断结果"""
result = SEODiagnosisResult(
overall_score=75.0,
dimensions=[],
recommendations=[],
)
assert result.overall_score == 75.0
assert result.health_level == "good"
def test_health_level_excellent(self):
"""测试优秀等级"""
result = SEODiagnosisResult(
overall_score=85.0,
dimensions=[],
recommendations=[],
)
assert result.health_level == "excellent"
def test_health_level_good(self):
"""测试良好等级"""
result = SEODiagnosisResult(
overall_score=70.0,
dimensions=[],
recommendations=[],
)
assert result.health_level == "good"
def test_health_level_pass(self):
"""测试及格等级"""
result = SEODiagnosisResult(
overall_score=50.0,
dimensions=[],
recommendations=[],
)
assert result.health_level == "pass"
def test_health_level_danger(self):
"""测试危险等级"""
result = SEODiagnosisResult(
overall_score=30.0,
dimensions=[],
recommendations=[],
)
assert result.health_level == "danger"
    def test_score_clamping(self):
        """The overall score is clamped to the range [0, 100]."""
        # Values above 100 are clamped down to 100
        result = SEODiagnosisResult(
            overall_score=150.0,
            dimensions=[],
            recommendations=[],
        )
        assert result.overall_score == 100.0
        # Negative values are clamped up to 0
        result = SEODiagnosisResult(
            overall_score=-10.0,
            dimensions=[],
            recommendations=[],
        )
        assert result.overall_score == 0.0
def test_to_dict(self):
"""测试字典转换"""
result = SEODiagnosisResult(
overall_score=75.0,
dimensions=[],
recommendations=[],
)
d = result.to_dict()
assert d["overall_score"] == 75.0
assert d["health_level"] == "good"
assert d["health_level_label"] == "良好"
assert "dimensions" in d
assert "recommendations" in d
class TestTechnicalSEODiagnosis:
"""技术SEO诊断测试"""
def test_perfect_technical_seo(self):
"""测试完美技术SEO"""
data = TechnicalSEOData(
is_indexed=True,
crawl_errors=0,
lcp_seconds=2.0,
fid_ms=50.0,
cls_score=0.05,
has_robots_txt=True,
robots_txt_blocks_important=False,
has_sitemap=True,
sitemap_valid=True,
url_structure_normalized=True,
)
result = diagnose_technical_seo(data)
assert result.score == result.max_score
assert result.status == DiagnosisStatus.PASS
def test_indexed_fail(self):
"""测试未索引情况"""
data = TechnicalSEOData(is_indexed=False)
result = diagnose_technical_seo(data)
assert any(item.status == DiagnosisStatus.FAIL for item in result.items if item.name == "索引状态")
def test_crawl_errors_warning(self):
"""测试少量爬取错误"""
data = TechnicalSEOData(crawl_errors=3)
result = diagnose_technical_seo(data)
crawl_item = next(item for item in result.items if item.name == "爬取错误")
assert crawl_item.status == DiagnosisStatus.WARNING
def test_crawl_errors_fail(self):
"""测试大量爬取错误"""
data = TechnicalSEOData(crawl_errors=10)
result = diagnose_technical_seo(data)
crawl_item = next(item for item in result.items if item.name == "爬取错误")
assert crawl_item.status == DiagnosisStatus.FAIL
def test_core_web_vitals_pass(self):
"""测试Core Web Vitals通过"""
data = TechnicalSEOData(
lcp_seconds=2.0,
fid_ms=80.0,
cls_score=0.05,
)
result = diagnose_technical_seo(data)
cwv_items = [item for item in result.items if item.name in ["LCP", "FID", "CLS"]]
assert all(item.status == DiagnosisStatus.PASS for item in cwv_items)
def test_core_web_vitals_warning(self):
"""测试Core Web Vitals警告"""
data = TechnicalSEOData(
lcp_seconds=3.0,
fid_ms=200.0,
cls_score=0.15,
)
result = diagnose_technical_seo(data)
cwv_items = [item for item in result.items if item.name in ["LCP", "FID", "CLS"]]
assert any(item.status == DiagnosisStatus.WARNING for item in cwv_items)
def test_core_web_vitals_fail(self):
"""测试Core Web Vitals失败"""
data = TechnicalSEOData(
lcp_seconds=5.0,
fid_ms=400.0,
cls_score=0.3,
)
result = diagnose_technical_seo(data)
cwv_items = [item for item in result.items if item.name in ["LCP", "FID", "CLS"]]
assert all(item.status == DiagnosisStatus.FAIL for item in cwv_items)
def test_robots_txt_blocks_important(self):
"""测试robots.txt阻止重要页面"""
data = TechnicalSEOData(
has_robots_txt=True,
robots_txt_blocks_important=True,
)
result = diagnose_technical_seo(data)
robots_item = next(item for item in result.items if item.name == "robots.txt")
assert robots_item.status == DiagnosisStatus.FAIL
def test_missing_sitemap(self):
"""测试缺少sitemap"""
data = TechnicalSEOData(has_sitemap=False)
result = diagnose_technical_seo(data)
sitemap_item = next(item for item in result.items if item.name == "sitemap")
assert sitemap_item.status == DiagnosisStatus.FAIL
class TestOnPageSEODiagnosis:
"""页面SEO诊断测试"""
def test_perfect_on_page_seo(self):
"""测试完美页面SEO"""
data = OnPageSEOData(
has_title=True,
title_length=50,
title_keyword_stuffing=False,
has_meta_description=True,
meta_description_length=140,
h1_count=1,
h_structure_valid=True,
keyword_density=2.0,
internal_links=10,
broken_internal_links=0,
images_without_alt=0,
total_images=5,
)
result = diagnose_on_page_seo(data)
assert result.score == result.max_score
assert result.status == DiagnosisStatus.PASS
def test_missing_title(self):
"""测试缺少Title"""
data = OnPageSEOData(has_title=False)
result = diagnose_on_page_seo(data)
title_item = next(item for item in result.items if item.name == "Title标签")
assert title_item.status == DiagnosisStatus.FAIL
def test_title_too_long(self):
"""测试Title过长"""
data = OnPageSEOData(title_length=80)
result = diagnose_on_page_seo(data)
title_item = next(item for item in result.items if item.name == "Title标签")
assert title_item.status == DiagnosisStatus.WARNING
def test_keyword_stuffing(self):
"""测试关键词堆砌"""
data = OnPageSEOData(title_keyword_stuffing=True)
result = diagnose_on_page_seo(data)
title_item = next(item for item in result.items if item.name == "Title标签")
assert title_item.status == DiagnosisStatus.WARNING
def test_multiple_h1(self):
"""测试多个H1"""
data = OnPageSEOData(h1_count=3)
result = diagnose_on_page_seo(data)
h_item = next(item for item in result.items if item.name == "H标签结构")
assert h_item.status == DiagnosisStatus.WARNING
def test_broken_links_warning(self):
"""测试少量死链"""
data = OnPageSEOData(broken_internal_links=2)
result = diagnose_on_page_seo(data)
link_item = next(item for item in result.items if item.name == "内链结构")
assert link_item.status == DiagnosisStatus.WARNING
def test_broken_links_fail(self):
"""测试大量死链"""
data = OnPageSEOData(broken_internal_links=10)
result = diagnose_on_page_seo(data)
link_item = next(item for item in result.items if item.name == "内链结构")
assert link_item.status == DiagnosisStatus.FAIL
def test_images_without_alt(self):
"""测试图片缺少Alt"""
data = OnPageSEOData(
images_without_alt=3,
total_images=5,
)
result = diagnose_on_page_seo(data)
alt_item = next(item for item in result.items if item.name == "图片Alt文本")
assert alt_item.status == DiagnosisStatus.FAIL
class TestContentQualityDiagnosis:
"""内容质量诊断测试"""
def test_perfect_content_quality(self):
"""测试完美内容质量"""
data = ContentQualityData(
readability_score=80.0,
word_count=2000,
topic_coverage=0.9,
has_author_info=True,
has_publication_date=True,
last_updated_days=10,
has_citations=True,
citation_authority=0.9,
duplicate_content_ratio=0.02,
has_expert_review=True,
)
result = diagnose_content_quality(data)
assert result.score == result.max_score
assert result.status == DiagnosisStatus.PASS
def test_low_readability(self):
"""测试低可读性"""
data = ContentQualityData(readability_score=40.0)
result = diagnose_content_quality(data)
readability_item = next(item for item in result.items if item.name == "可读性")
assert readability_item.status == DiagnosisStatus.FAIL
def test_shallow_content(self):
"""测试内容深度不足"""
data = ContentQualityData(
word_count=500,
topic_coverage=0.4,
)
result = diagnose_content_quality(data)
depth_item = next(item for item in result.items if item.name == "信息深度")
assert depth_item.status == DiagnosisStatus.FAIL
def test_missing_author(self):
"""测试缺少作者信息"""
data = ContentQualityData(has_author_info=False)
result = diagnose_content_quality(data)
author_item = next(item for item in result.items if item.name == "作者资质")
assert author_item.status == DiagnosisStatus.WARNING
def test_stale_content(self):
"""测试过时内容"""
data = ContentQualityData(last_updated_days=200)
result = diagnose_content_quality(data)
freshness_item = next(item for item in result.items if item.name == "内容新鲜度")
assert freshness_item.status == DiagnosisStatus.FAIL
def test_high_duplicate_ratio(self):
"""测试高重复内容比例"""
data = ContentQualityData(duplicate_content_ratio=0.5)
result = diagnose_content_quality(data)
duplicate_item = next(item for item in result.items if item.name == "重复内容")
assert duplicate_item.status == DiagnosisStatus.FAIL
class TestBacklinkDiagnosis:
"""外链分析诊断测试"""
def test_perfect_backlinks(self):
"""测试完美外链"""
data = BacklinkData(
total_backlinks=200,
referring_domains=50,
high_authority_links=20,
toxic_links=0,
nofollow_ratio=0.3,
anchor_text_diversity=0.9,
exact_match_anchor_ratio=0.1,
)
result = diagnose_backlinks(data)
assert result.score == result.max_score
assert result.status == DiagnosisStatus.PASS
def test_few_referring_domains(self):
"""测试引用域名少"""
data = BacklinkData(referring_domains=5)
result = diagnose_backlinks(data)
domain_item = next(item for item in result.items if item.name == "引用域名")
assert domain_item.status == DiagnosisStatus.FAIL
def test_toxic_links_warning(self):
"""测试少量毒性链接"""
data = BacklinkData(
total_backlinks=100,
toxic_links=3,
)
result = diagnose_backlinks(data)
toxic_item = next(item for item in result.items if item.name == "毒性链接")
assert toxic_item.status == DiagnosisStatus.WARNING
def test_toxic_links_fail(self):
"""测试大量毒性链接"""
data = BacklinkData(
total_backlinks=50,
toxic_links=10,
)
result = diagnose_backlinks(data)
toxic_item = next(item for item in result.items if item.name == "毒性链接")
assert toxic_item.status == DiagnosisStatus.FAIL
def test_low_anchor_diversity(self):
"""测试锚文本多样性低"""
data = BacklinkData(
anchor_text_diversity=0.3,
exact_match_anchor_ratio=0.6,
)
result = diagnose_backlinks(data)
anchor_item = next(item for item in result.items if item.name == "锚文本分布")
assert anchor_item.status == DiagnosisStatus.FAIL
class TestUserExperienceDiagnosis:
"""用户体验诊断测试"""
def test_perfect_ux(self):
"""测试完美用户体验"""
data = UserExperienceData(
is_mobile_friendly=True,
mobile_viewport_set=True,
page_load_time=1.5,
has_https=True,
has_breadcrumbs=True,
conversion_path_clear=True,
has_cta=True,
form_usability=0.95,
has_search=True,
)
result = diagnose_user_experience(data)
assert result.score == result.max_score
assert result.status == DiagnosisStatus.PASS
def test_not_mobile_friendly(self):
"""测试不移动友好"""
data = UserExperienceData(is_mobile_friendly=False)
result = diagnose_user_experience(data)
mobile_item = next(item for item in result.items if item.name == "移动适配")
assert mobile_item.status == DiagnosisStatus.FAIL
def test_slow_page_load(self):
"""测试页面加载慢"""
data = UserExperienceData(page_load_time=5.0)
result = diagnose_user_experience(data)
speed_item = next(item for item in result.items if item.name == "页面速度")
assert speed_item.status == DiagnosisStatus.FAIL
def test_missing_https(self):
"""测试缺少HTTPS"""
data = UserExperienceData(has_https=False)
result = diagnose_user_experience(data)
https_item = next(item for item in result.items if item.name == "HTTPS")
assert https_item.status == DiagnosisStatus.FAIL
def test_missing_cta(self):
"""测试缺少CTA"""
data = UserExperienceData(has_cta=False)
result = diagnose_user_experience(data)
cta_item = next(item for item in result.items if item.name == "CTA")
assert cta_item.status == DiagnosisStatus.WARNING
class TestRecommendations:
"""优化建议生成测试"""
def test_generate_recommendations(self):
"""测试建议生成"""
result = SEODiagnosisResult(
overall_score=60.0,
dimensions=[
SEODimensionScore(
name="测试维度",
score=10.0,
max_score=20.0,
items=[
DiagnosisItem(
name="失败项",
status=DiagnosisStatus.FAIL,
description="描述",
suggestion="修复建议",
),
DiagnosisItem(
name="警告项",
status=DiagnosisStatus.WARNING,
description="描述",
suggestion="优化建议",
),
DiagnosisItem(
name="通过项",
status=DiagnosisStatus.PASS,
description="描述",
suggestion="保持",
),
],
status=DiagnosisStatus.WARNING,
),
],
recommendations=[],
)
recommendations = generate_recommendations(result)
assert len(recommendations) == 2
assert recommendations[0].priority == "high"
assert recommendations[1].priority == "medium"
def test_recommendations_sorted_by_priority(self):
"""测试建议按优先级排序"""
result = SEODiagnosisResult(
overall_score=50.0,
dimensions=[
SEODimensionScore(
name="维度1",
score=5.0,
max_score=10.0,
items=[
DiagnosisItem(
name="警告项",
status=DiagnosisStatus.WARNING,
description="",
suggestion="",
),
],
status=DiagnosisStatus.WARNING,
),
SEODimensionScore(
name="维度2",
score=5.0,
max_score=10.0,
items=[
DiagnosisItem(
name="失败项",
status=DiagnosisStatus.FAIL,
description="",
suggestion="",
),
],
status=DiagnosisStatus.WARNING,
),
],
recommendations=[],
)
recommendations = generate_recommendations(result)
assert recommendations[0].priority == "high"
assert recommendations[1].priority == "medium"
class TestSEODiagnosisService:
"""Tests for the SEO diagnosis service."""
@pytest.fixture
def service(self):
"""创建诊断服务实例"""
return SEODiagnosisService()
def test_full_diagnosis_with_defaults(self, service):
"""测试使用默认数据的完整诊断"""
result = service.diagnose()
assert isinstance(result, SEODiagnosisResult)
assert 0 <= result.overall_score <= 100
assert len(result.dimensions) == 5
assert isinstance(result.recommendations, list)
def test_diagnosis_returns_all_dimensions(self, service):
"""测试诊断返回所有维度"""
result = service.diagnose()
dimension_names = [dim.name for dim in result.dimensions]
assert "技术SEO" in dimension_names
assert "页面SEO" in dimension_names
assert "内容质量" in dimension_names
assert "外链分析" in dimension_names
assert "用户体验" in dimension_names
def test_diagnosis_with_custom_data(self, service):
"""测试使用自定义数据的诊断"""
technical_data = TechnicalSEOData(
is_indexed=False,
crawl_errors=10,
)
result = service.diagnose(technical_data=technical_data)
assert result.overall_score < 100
def test_diagnose_technical_only(self, service):
"""测试仅技术SEO诊断"""
result = service.diagnose_technical_only()
assert isinstance(result, SEODimensionScore)
assert result.name == DimensionName.TECHNICAL_SEO
def test_diagnose_on_page_only(self, service):
"""测试仅页面SEO诊断"""
result = service.diagnose_on_page_only()
assert isinstance(result, SEODimensionScore)
assert result.name == DimensionName.ON_PAGE_SEO
def test_diagnose_content_only(self, service):
"""测试仅内容质量诊断"""
result = service.diagnose_content_only()
assert isinstance(result, SEODimensionScore)
assert result.name == DimensionName.CONTENT_QUALITY
def test_diagnose_backlinks_only(self, service):
"""测试仅外链分析"""
result = service.diagnose_backlinks_only()
assert isinstance(result, SEODimensionScore)
assert result.name == DimensionName.BACKLINK_ANALYSIS
def test_diagnose_ux_only(self, service):
"""测试仅用户体验诊断"""
result = service.diagnose_ux_only()
assert isinstance(result, SEODimensionScore)
assert result.name == DimensionName.USER_EXPERIENCE
    def test_diagnosis_with_poor_data(self, service):
        """Diagnose with poor data in every dimension."""
        technical_data = TechnicalSEOData(
            is_indexed=False,
            crawl_errors=20,
            lcp_seconds=6.0,
            fid_ms=500.0,
            cls_score=0.4,
            has_robots_txt=False,
            has_sitemap=False,
        )
        on_page_data = OnPageSEOData(
            has_title=False,
            has_meta_description=False,
            h1_count=0,
            broken_internal_links=10,
        )
        content_data = ContentQualityData(
            readability_score=30.0,
            word_count=200,
            last_updated_days=365,
            duplicate_content_ratio=0.6,
        )
        backlink_data = BacklinkData(
            referring_domains=2,
            toxic_links=20,
            anchor_text_diversity=0.2,
        )
        ux_data = UserExperienceData(
            is_mobile_friendly=False,
            page_load_time=8.0,
            has_https=False,
        )
        result = service.diagnose(
            technical_data=technical_data,
            on_page_data=on_page_data,
            content_data=content_data,
            backlink_data=backlink_data,
            ux_data=ux_data,
        )
        # Poor input everywhere should land in the lowest health band
        # and produce at least one high-priority recommendation.
        assert result.overall_score < 30
        assert result.health_level == "danger"
        assert len(result.recommendations) > 0
        assert any(rec.priority == "high" for rec in result.recommendations)

    def test_diagnosis_with_excellent_data(self, service):
        """Diagnose with excellent data in every dimension."""
        technical_data = TechnicalSEOData(
            is_indexed=True,
            crawl_errors=0,
            lcp_seconds=1.5,
            fid_ms=50.0,
            cls_score=0.03,
        )
        on_page_data = OnPageSEOData(
            title_length=45,
            meta_description_length=140,
            keyword_density=2.0,
        )
        content_data = ContentQualityData(
            readability_score=85.0,
            word_count=2500,
            topic_coverage=0.95,
            has_expert_review=True,
            last_updated_days=5,
        )
        backlink_data = BacklinkData(
            referring_domains=50,
            high_authority_links=20,
            toxic_links=0,
            anchor_text_diversity=0.9,
        )
        ux_data = UserExperienceData(
            page_load_time=1.2,
            form_usability=0.95,
        )
        result = service.diagnose(
            technical_data=technical_data,
            on_page_data=on_page_data,
            content_data=content_data,
            backlink_data=backlink_data,
            ux_data=ux_data,
        )
        assert result.overall_score >= 80
        assert result.health_level == "excellent"

    def test_result_to_dict_format(self, service):
        """Check the shape of the dictionary returned by to_dict()."""
        result = service.diagnose()
        d = result.to_dict()
        assert "overall_score" in d
        assert "health_level" in d
        assert "health_level_label" in d
        assert "dimensions" in d
        assert "recommendations" in d
        assert isinstance(d["dimensions"], list)
        assert isinstance(d["recommendations"], list)
        # Per-dimension keys are only checked when at least one dimension exists.
        if d["dimensions"]:
            dim = d["dimensions"][0]
            assert "name" in dim
            assert "score" in dim
            assert "max_score" in dim
            assert "percentage" in dim
            assert "status" in dim
            assert "items" in dim

3
backend/uv.lock Normal file

@ -0,0 +1,3 @@
version = 1
revision = 3
requires-python = ">=3.14"


@ -1,156 +0,0 @@
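# GEO Platform - Project Overview

## Project Positioning

The **GEO (Generative Engine Optimization) Platform** is a **SaaS platform that optimizes brand exposure in AI search engines** for enterprise customers.

As generative AI search engines (ChatGPT, Perplexity, Google SGE, Kimi, and others) gain ground, traditional SEO alone can no longer keep a brand visible in AI-generated answers. The GEO Platform helps companies raise their brand citation rate and exposure in AI search engines through systematic diagnosis, strategy planning, content production, distribution, and monitoring.

## Core Business Value

- **AI citation detection**: automatically scans how mainstream AI platforms cite the brand and identifies cited, uncited, and competitor-cited scenarios
- **Intelligent strategy generation**: generates GEO optimization strategies and content production plans from the diagnosis results
- **Automated content production**: uses the AI Agent framework to generate content assets that meet GEO optimization standards
- **Multi-channel distribution**: distributes content to target platforms and tracks the results
- **Continuous monitoring and optimization**: closes the monitoring loop and keeps tracking how the brand performs in AI search engines over time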
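## Business Lifecycle (5 Stages)

Business operations on the GEO Platform follow a complete five-stage lifecycle:

| Stage | Name | Core goal | Key actions |
|-------|------|-----------|-------------|
| Stage 1 | Diagnosis | Understand the brand's current standing in AI search | Query execution, citation detection, competitor analysis |
| Stage 2 | Strategy | Define the GEO optimization strategy | Strategy generation, rule definition, scheduling |
| Stage 3 | Content production | Produce content that meets GEO standards | Content generation, quality checks, asset preparation |
| Stage 4 | Distribution | Distribute content to target channels | Channel management, publishing, result tracking |
| Stage 5 | Monitoring and optimization | Monitor continuously and optimize results | Performance tracking, report generation, strategy iteration |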
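## Dual Operating Modes

The GEO Platform supports two operating modes to suit different customers:

### Mode 1: Self-Service Subscription (SaaS)

- **Target customers**: small and medium-sized businesses, and brands with their own operations team
- **Key characteristics**:
  - Customers register and subscribe on their own
  - Customers run the full lifecycle themselves through the platform
  - Feature access depends on the subscription tier (Basic / Pro / Enterprise)
  - Multi-user team collaboration is supported
- **Pricing**: monthly or annual subscription

### Mode 2: Managed Service (Fully Hosted)

- **Target customers**: large enterprises and brands that want a professional team to operate on their behalf
- **Key characteristics**:
  - A professional operations team carries out GEO optimization for the customer
  - Customers are onboarded in managed mode and do not need to operate the system directly
  - The operations team uses the full feature set of the platform to serve the customer
  - Customized strategies and dedicated service are provided
- **Pricing**: per project or monthly service fee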
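### Permissions by Mode

| Module | Self-service | Managed service |
|--------|--------------|-----------------|
| Query management | Yes (quota by subscription tier) | Yes (unlimited) |
| Citation detection | Yes | Yes |
| Strategy generation | Yes | Yes |
| Content production | Yes (AI generation quota) | Yes |
| Distribution | Yes | Yes |
| Monitoring reports | Yes | Yes |
| Multi-user management | Enterprise only | Yes |
| Managed-customer administration | No | Yes |
| White-label reports | No | Yes |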
## Tech Stack Summary

### Frontend

- **Framework**: Next.js 15 + React 19 + TypeScript
- **Styling**: Tailwind CSS 4 + shadcn/ui
- **State management**: React Server Components + SWR
- **Authentication**: NextAuth.js v5
- **Charts**: Recharts

### Backend

- **Framework**: FastAPI + Python 3.12
- **Database**: PostgreSQL + SQLAlchemy 2.0
- **ORM**: SQLAlchemy 2.0 + Alembic migrations
- **Async**: Celery + Redis task queue
- **Authentication**: JWT + OAuth2

### AI Agent Framework

- **Architecture**: modular Agent registration mechanism
- **Communication**: asynchronous, message-queue-based protocol
- **Core Agents**: CitationDetector / ContentGenerator / RuleChecker / CompetitorAnalyzer / PerformanceTracker

### Infrastructure

- **Containers**: Docker + Docker Compose
- **Deployment**: cloud servers behind an Nginx reverse proxy
- **Monitoring**: logging plus health checks
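## Quick Start

### Prerequisites

1. Make sure Docker and Docker Compose are installed
2. Clone the project repository
3. Copy `.env.example` to `.env` and set the required environment variables

### Start the Development Environment

```bash
# Start all services (frontend + backend + database + Redis)
docker-compose up -d
# Run the backend database migrations (subshell keeps the working directory unchanged)
(cd backend && alembic upgrade head)
# Start the frontend development server
cd frontend
npm run dev
```

### Default URLs

- Frontend: `http://localhost:3000`
- Backend API: `http://localhost:8000`
- API docs: `http://localhost:8000/docs`

### Initial Account

After the system is initialized, create the first administrator account through the registration flow, or ask the operations team for the default administrator credentials.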
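## Documentation Map

| Directory | Contents |
|-----------|----------|
| `00-project/` | Project overview, architecture design, tech stack |
| `01-requirements/` | Business requirements, feature list, user stories |
| `02-design/` | UI design, database design, Agent framework design |
| `03-development/` | Development conventions, TDD workflow, module guides |
| `04-testing/` | Test strategy, test plans, test reports |
| `05-deployment/` | Deployment guide, Docker configuration, environment configuration |
| `06-progress/` | Iteration plans, task board, weekly reports |

## Project Milestones

| Phase | Goal | Status |
|-------|------|--------|
| Phase 0 | Infrastructure setup + documentation system | In progress |
| Phase 1 | Stage 1 diagnosis (queries + detection) | Planned |
| Phase 2 | Stage 2 strategy + Stage 3 content production | Planned |
| Phase 3 | Stage 4 distribution + Stage 5 monitoring | Planned |
| Phase 4 | Managed-service mode + admin console | Planned |
| Phase 5 | Performance optimization + stability improvements | Planned |
| Phase 6 | Production deployment + official release | Planned |

---
*GEO Platform - making brands visible in the AI era.*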


@ -1,336 +0,0 @@
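# GEO Platform - System Architecture Design

## Overall Architecture

The GEO Platform uses a layered architecture made up of the following four layers: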
```
┌─────────────────────────────────────────────────────────────────┐
│ Presentation Layer │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────────┐ │
│ │ Next.js │ │ React 19 │ │ Tailwind CSS + shadcn │ │
│ │ App Router │ │ Components │ │ UI Components │ │
│ └─────────────┘ └─────────────┘ └─────────────────────────┘ │
├─────────────────────────────────────────────────────────────────┤
│ API Gateway Layer │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────────┐ │
│ │ FastAPI │ │ REST API │ │ JWT / OAuth2 │ │
│ │ Routers │ │ Endpoints │ │ Authentication │ │
│ └─────────────┘ └─────────────┘ └─────────────────────────┘ │
├─────────────────────────────────────────────────────────────────┤
│ Service & Agent Layer │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────────────┐ │
│ │ Citation │ │ Content │ │ Rule │ │ Competitor │ │
│ │ Detector │ │Generator │ │ Checker │ │ Analyzer │ │
│ └──────────┘ └──────────┘ └──────────┘ └──────────────────┘ │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────────────┐ │
│ │Performance│ │ Task │ │ Query │ │ Report │ │
│ │ Tracker │ │ Scheduler│ │ Engine │ │ Generator │ │
│ └──────────┘ └──────────┘ └──────────┘ └──────────────────┘ │
├─────────────────────────────────────────────────────────────────┤
│ Infrastructure Layer │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌────────────────┐ │
│ │PostgreSQL│ │ Redis │ │ Celery │ │ Docker │ │
│ │ (Data) │ │ (Cache) │ │ (Queue) │ │ (Container) │ │
│ └──────────┘ └──────────┘ └──────────┘ └────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
```
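### The Four Layers

#### 1. Presentation Layer

- **Next.js App Router**: file-system-based routing with server-side rendering (SSR) and static site generation (SSG)
- **React 19**: current React features, including Server Components, Actions, and form state management
- **Tailwind CSS + shadcn/ui**: utility-first CSS plus a reusable headless component library, for building a consistent UI quickly

**Key design decisions**:

- The BFF (Backend for Frontend) pattern is used; the frontend talks to the backend through an API client
- Authentication state is managed by NextAuth.js, with both Credentials and OAuth supported
- Data fetching prefers Server Components to reduce the amount of client-side JavaScript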
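#### 2. API Gateway Layer

- **FastAPI**: a high-performance Python web framework with native async support and automatic API documentation
- **RESTful API**: standard REST interface design that returns JSON
- **Authentication and authorization**: JWT tokens plus the OAuth2 password flow, with role-based permission control

**Modules**:

| Module | Responsibility | File |
|--------|----------------|------|
| Auth API | Registration, login, token refresh, permission checks | `api/auth.py` |
| Query API | Query creation, execution, result retrieval | `api/queries.py` |
| Citation API | Citation record management, statistics | `api/citations.py` |
| Report API | Report generation, export, viewing | `api/reports.py` |
| Admin API | User management, system configuration, subscription management | `api/admin.py` |
| Subscription API | Subscription plans, payment, quota management | `api/subscriptions.py` |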
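#### 3. Service & Agent Layer

This layer holds the core business logic of the GEO Platform. It contains two kinds of components: conventional Services and AI Agents.

**Conventional Services**:

| Service | Responsibility |
|---------|----------------|
| QueryService | Query lifecycle management, AI platform calls, result aggregation |
| CitationService | Running the citation detection algorithm, storing citation records, confidence scoring |
| ReportService | Report template rendering, data aggregation, export to multiple formats |
| UserService | User CRUD, roles and permissions, subscription status |
| SubscriptionService | Subscription plan management, quota control, upgrades and downgrades |

**AI Agent cluster**: see `docs/02-design/agent-framework.md`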
#### 4. Infrastructure Layer

| Component | Purpose | Technology |
|-----------|---------|------------|
| PostgreSQL | Primary database for business data | PostgreSQL 15+ |
| Redis | Cache, Celery broker, session storage | Redis 7+ |
| Celery | Asynchronous task queue, scheduled jobs | Celery 5+ |
| Docker | Containerized deployment | Docker + Compose |
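## Module Decoupling Principles

### 1. Domain-Driven Design (DDD)

- **Independent domain model**: the core business models (Query, Citation, User, Subscription) do not depend on any framework
- **Clear bounded contexts**: each module owns its data model and business rules
- **Anti-corruption layer (ACL)**: adapters for external AI platforms are isolated from the core domain by an anti-corruption layer

### 2. Dependency Inversion

```
API Layer → Service Layer → Repository Layer → Database
    ↑             ↑                ↑
via interface via interface  via interface
```

- The Service layer defines interfaces; the Repository layer implements them
- Dependencies are easy to mock in unit tests
- The database can be migrated or replaced later

### 3. Configuration Separate from Code

- All environment-specific configuration is injected through environment variables
- Application configuration is centralized in `app/config.py`
- Each environment (dev / staging / prod) uses its own `.env` file

### 4. Asynchronous Decoupling

- Long-running operations (AI query execution, report generation) are processed asynchronously by Celery
- The frontend gets task status updates by polling or over WebSocket
- Task state is persisted, so tasks can resume from a checkpoint and be retried after failure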
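## Decoupled Communication Protocol for AI Agents

### Communication Model

AI Agents on the GEO Platform communicate through an **asynchronous, message-queue-based protocol**, which keeps Agents loosely coupled and lets each one be deployed and scaled independently.

```
┌─────────────┐     ┌─────────────┐     ┌─────────────┐
│   Agent A   │────▶│   Message   │────▶│   Agent B   │
│ (producer)  │     │    Queue    │     │ (consumer)  │
└─────────────┘     └─────────────┘     └─────────────┘
                    ┌─────────────┐
                    │    Agent    │
                    │  registry   │
                    └─────────────┘
```

### Message Format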
```json
{
"message_id": "uuid",
"timestamp": "2026-01-01T00:00:00Z",
"sender": "agent_name",
"recipient": "agent_name_or_broadcast",
"message_type": "task_request | task_result | status_update | heartbeat",
"payload": {
"task_id": "uuid",
"task_type": "detect_citations | generate_content | check_rules | analyze_competitors | track_performance",
"parameters": {},
"data": {},
"priority": 1
},
"correlation_id": "uuid",
"reply_to": "agent_name"
}
```
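
The envelope above could be modeled as follows. This is a minimal sketch, not the project's implementation: the `AgentMessage` class, its `reply` helper, and the default-value choices are assumptions made for illustration; only the field names and the four `message_type` values come from the JSON format.

```python
import json
import uuid
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import Any, Optional

# The four message types listed in the format above.
MESSAGE_TYPES = {"task_request", "task_result", "status_update", "heartbeat"}


def _new_id() -> str:
    return str(uuid.uuid4())


@dataclass
class AgentMessage:
    """Hypothetical in-memory form of the message envelope."""

    sender: str
    recipient: str
    message_type: str
    payload: dict[str, Any]
    reply_to: Optional[str] = None
    message_id: str = field(default_factory=_new_id)
    correlation_id: str = field(default_factory=_new_id)
    timestamp: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )

    def __post_init__(self) -> None:
        if self.message_type not in MESSAGE_TYPES:
            raise ValueError(f"unknown message_type: {self.message_type}")

    def to_json(self) -> str:
        return json.dumps(asdict(self), ensure_ascii=False)

    def reply(self, payload: dict[str, Any]) -> "AgentMessage":
        # A result keeps the request's correlation_id so the caller can match it.
        return AgentMessage(
            sender=self.recipient,
            recipient=self.reply_to or self.sender,
            message_type="task_result",
            payload=payload,
            correlation_id=self.correlation_id,
        )
```

Reusing `correlation_id` on the reply is what lets a producer pair an asynchronous result with the request that caused it.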
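### Communication Patterns

| Pattern | Description | Use case |
|---------|-------------|----------|
| Point-to-point | Agent A sends a task directly to Agent B | Explicit one-to-one task assignment |
| Broadcast | An Agent sends a message to all subscribers | Status updates, configuration change notices |
| Publish/subscribe | An Agent publishes a message to a specific topic | Event-driven architecture |
| Request/response | The caller waits synchronously for the result | Short tasks that need immediate feedback |

### Agent Registration and Discovery

- **Registry**: Redis serves as a lightweight registry
- **Registration**: on startup, each Agent registers its capabilities and status with the registry
- **Health checks**: Agents send periodic heartbeats; Agents that stop responding are removed from the registry automatically
- **Load balancing**: tasks are dispatched to an execution node based on Agent load

### Error Handling and Retries

- Messages that fail to be consumed go to a dead-letter queue (DLQ) automatically
- The retry count and retry interval are configurable (exponential backoff)
- Failed tasks can be re-triggered manually or compensated automatically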
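
The retry policy above can be sketched as below. This is an illustration under stated assumptions: `backoff_delays`, `run_with_retry`, and the `DeadLetter` exception are hypothetical names, and the base delay and cap are arbitrary; the source only says that retries use exponential backoff and that exhausted messages go to a DLQ.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


class DeadLetter(Exception):
    """Hypothetical signal that a task exhausted its retries (DLQ candidate)."""


def backoff_delays(
    max_retries: int, base_seconds: float = 1.0, cap_seconds: float = 60.0
) -> list[float]:
    # Delay doubles after every failed attempt and is capped at cap_seconds.
    return [min(cap_seconds, base_seconds * 2**n) for n in range(max_retries)]


def run_with_retry(
    task: Callable[[], T],
    max_retries: int = 3,
    base_seconds: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    delays = backoff_delays(max_retries, base_seconds)
    for attempt in range(max_retries + 1):
        try:
            return task()
        except Exception as exc:
            if attempt == max_retries:
                # Out of retries: the caller would route the message to the DLQ.
                raise DeadLetter(f"gave up after {max_retries} retries") from exc
            sleep(delays[attempt])
    raise AssertionError("unreachable")
```

Passing `sleep` as a parameter keeps the policy testable without real waiting.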
## Permission Design for the Two Modes

### User Role Model

```
                 ┌────────────┐
                 │    User    │
                 └─────┬──────┘
        ┌──────────────┼──────────────┐
        ▼              ▼              ▼
  ┌────────────┐ ┌────────────┐ ┌────────────┐
  │  EndUser   │ │   Admin    │ │   Agent    │
  │ (end user) │ │  (admin)   │ │ (operator) │
  └─────┬──────┘ └────────────┘ └─────┬──────┘
        │                             │
        ▼                             ▼
┌───────────────┐             ┌───────────────┐
│ Self-Service  │             │ Proxy-Service │
│ subscription  │             │  managed ops  │
└───────────────┘             └───────────────┘
```
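### Permission Matrix

| Feature | EndUser (Basic) | EndUser (Pro) | EndUser (Enterprise) | Admin | Agent |
|---------|-----------------|---------------|----------------------|-------|-------|
| Create queries | 10/month | 50/month | Unlimited | Unlimited | Unlimited |
| Citation detection | Yes | Yes | Yes | Yes | Yes |
| Strategy generation | Basic | Advanced | Customized | All | All |
| Content generation | 5/month | 20/month | Unlimited | Unlimited | Unlimited |
| Report export | PDF | PDF + Excel | All formats | All formats | All formats |
| Team management | No | No | Yes | Yes | Yes |
| Managed-customer administration | No | No | No | Yes | Yes |
| White-label reports | No | No | No | Yes | Yes |
| System configuration | No | No | No | Yes | No |
| User management | No | No | No | Yes | No |

### Permission Implementation

- **Role-based access control (RBAC)**: users receive a set of permissions through their role
- **Attribute-based access control (ABAC)**: fine-grained control that also considers attributes such as subscription tier and team membership
- **Middleware checks**: FastAPI dependency injection provides a single place for permission checks
- **Frontend integration**: frontend routes and components render according to the user's permissions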
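### API Permission Control Example

```python
from fastapi import Depends, HTTPException

# User, get_current_user, router, and QueryCreate are defined elsewhere in the project.


# Permission checks implemented as FastAPI dependencies
async def require_admin(current_user: User = Depends(get_current_user)):
    if not current_user.is_admin:
        # detail: "administrator privileges required"
        raise HTTPException(status_code=403, detail="需要管理员权限")
    return current_user


# A plain (non-async) factory: Depends() needs the callable it returns,
# not a coroutine object.
def require_subscription(min_tier: str):
    async def checker(current_user: User = Depends(get_current_user)):
        if not current_user.subscription.meets(min_tier):
            # detail: "subscription tier too low"
            raise HTTPException(status_code=403, detail="订阅等级不足")
        return current_user
    return checker


# Usage in a route
@router.post("/queries")
async def create_query(
    data: QueryCreate,
    user: User = Depends(require_subscription("pro")),
):
    ...
```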
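
The example above relies on `subscription.meets(min_tier)` without showing it. One way such a check might work is an ordered tier comparison, sketched below. The `Tier` enum and `Subscription` class are hypothetical; only the tier names Basic / Pro / Enterprise and the `meets` call come from the document.

```python
from dataclasses import dataclass
from enum import IntEnum


class Tier(IntEnum):
    """Subscription tiers in ascending order (assumed ordering)."""

    BASIC = 1
    PRO = 2
    ENTERPRISE = 3


@dataclass(frozen=True)
class Subscription:
    tier: Tier

    def meets(self, min_tier: str) -> bool:
        # Tier["PRO"] looks a member up by name; an unknown name raises KeyError.
        return self.tier >= Tier[min_tier.upper()]
```

With this shape, `Subscription(Tier.PRO).meets("pro")` is true while `Subscription(Tier.BASIC).meets("pro")` is false, which is the behavior the route dependency needs.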
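## Data Flow Architecture

### Query Execution Flow

```
User creates query
  │
  ▼
API receives request
  │
  ▼
Check permissions/quota
  │
  ▼
Create query record
  │
  ▼
Submit Celery task
  │
  ▼
┌───────────────────────┐
│ Worker runs query     │
│ - call AI platforms   │
│ - aggregate responses │
└─┬─────────────────────┘
  ▼
┌───────────────────────┐
│ CitationDetector      │
│ - parse citations     │
│ - score confidence    │
└─┬─────────────────────┘
  ▼
Save citation records
  │
  ▼
Update query status
  │
  ▼
Notify frontend of completion
```

### Report Generation Flow

```
User requests report
  │
  ▼
API receives request
  │
  ▼
Collect related data
  │
  ▼
ReportService aggregates data
  │
  ▼
Apply report template
  │
  ▼
Generate PDF/Excel/HTML
  │
  ▼
Return download link
```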
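## Extensibility

### Horizontal Scaling

- **Stateless services**: the API service is stateless and can be scaled horizontally behind a load balancer
- **Database read/write splitting**: primary-replica replication is planned so reads can be spread across replicas
- **Cache layer**: Redis caches hot data to reduce database load

### AI Platform Adapter Extension

The GEO Platform has to integrate with several AI search engine platforms. The adapter design follows the open-closed principle:

```
┌─────────────────────────────────────────┐
│ PlatformAdapter (abstract base class)   │
│ - execute_query()                       │
│ - parse_response()                      │
│ - validate_config()                     │
└─────────────────────────────────────────┘
        │            │            │
        ▼            ▼            ▼
  ┌──────────┐ ┌──────────┐ ┌──────────┐
  │ ChatGPT  │ │   Kimi   │ │ Perplex  │
  │ Adapter  │ │ Adapter  │ │ Adapter  │
  └──────────┘ └──────────┘ └──────────┘
```

Adding an AI platform only requires implementing the `PlatformAdapter` interface; existing code does not change.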
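
A minimal sketch of how such an adapter interface might look in Python. The three method names come from the diagram above; their signatures, the `EchoAdapter` stand-in, and the `ADAPTERS` registry with `register` and `run_query` are assumptions for illustration, and no real AI platform is called.

```python
from abc import ABC, abstractmethod
from typing import Any


class PlatformAdapter(ABC):
    """Abstract base class; parameter and return types are assumed."""

    name: str

    @abstractmethod
    def validate_config(self, config: dict[str, Any]) -> bool: ...

    @abstractmethod
    def execute_query(self, query: str) -> str: ...

    @abstractmethod
    def parse_response(self, raw: str) -> dict[str, Any]: ...


class EchoAdapter(PlatformAdapter):
    """Hypothetical stand-in that echoes the query instead of calling a platform."""

    name = "echo"

    def validate_config(self, config: dict[str, Any]) -> bool:
        return "api_key" in config

    def execute_query(self, query: str) -> str:
        return f"echo: {query}"

    def parse_response(self, raw: str) -> dict[str, Any]:
        return {"platform": self.name, "text": raw}


# Hypothetical registry: callers look adapters up by name and never
# reference a concrete adapter class.
ADAPTERS: dict[str, PlatformAdapter] = {}


def register(adapter: PlatformAdapter) -> None:
    ADAPTERS[adapter.name] = adapter


def run_query(platform: str, query: str) -> dict[str, Any]:
    adapter = ADAPTERS[platform]
    return adapter.parse_response(adapter.execute_query(query))
```

Because `run_query` depends only on the abstract interface, supporting another platform means registering one more subclass, which is the open-closed property the text describes.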
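---
*This document describes the overall system architecture of the GEO Platform. For detailed module design, see the design document of each subsystem.*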


@ -1,44 +0,0 @@
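# GEO Platform - Changelog

## Overview

This document records every version update of the GEO Platform, newest first.

> **TODO**: This document is a placeholder and still needs to be filled in.

## Versioning

Versions follow [Semantic Versioning](https://semver.org/lang/zh-CN/): `MAJOR.MINOR.PATCH`

- **MAJOR**: incompatible API changes
- **MINOR**: backward-compatible new functionality
- **PATCH**: backward-compatible bug fixes

## Release Notes

### [Unreleased]

#### Added

- [ ] Documentation system established
- [ ] Base architecture set up

#### Changed

- None

#### Fixed

- None

### [0.1.0] - Pending release

#### Added

- [ ] TODO: fill in the Phase 1 feature list

#### Changed

- None

#### Fixed

- None

---
*This document is incomplete and is updated with each release.*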


@ -1,70 +0,0 @@
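# GEO Platform - Tech Stack

## Overview

This document describes the tech stack used by the GEO Platform and the reasoning behind each choice.

> **TODO**: This document is a placeholder and still needs to be filled in.

## Frontend Stack

### To be written

- [ ] Next.js 15 core features and selection rationale
- [ ] Use of new React 19 features
- [ ] TypeScript type system practices
- [ ] Tailwind CSS 4 utility-first styling
- [ ] shadcn/ui component customization
- [ ] NextAuth.js v5 authentication
- [ ] SWR data fetching strategy
- [ ] Recharts chart implementation

## Backend Stack

### To be written

- [ ] FastAPI async framework features
- [ ] Python 3.12 type hint best practices
- [ ] SQLAlchemy 2.0 ORM usage
- [ ] Alembic database migration management
- [ ] Celery asynchronous task queue
- [ ] Redis as cache and message broker
- [ ] JWT + OAuth2 authentication
- [ ] Pydantic data validation

## AI Agent Stack

### To be written

- [ ] Choice of core Agent framework library
- [ ] LLM selection and invocation
- [ ] Prompt engineering framework
- [ ] Vector database (optional)
- [ ] Model evaluation and monitoring

## Infrastructure Stack

### To be written

- [ ] Docker containerization
- [ ] Docker Compose service orchestration
- [ ] PostgreSQL configuration
- [ ] Redis cluster configuration (production)
- [ ] Nginx reverse proxy and load balancing
- [ ] Log collection
- [ ] Monitoring and alerting

## Development Toolchain

### To be written

- [ ] Code formatting: Black / Prettier
- [ ] Linting: Ruff / ESLint
- [ ] Type checking: mypy
- [ ] Test frameworks: pytest / Jest / Playwright
- [ ] Git workflow and hooks
- [ ] CI/CD tooling

---
*This document is incomplete. See the tech stack summary in `README.md` for the basics.*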


@ -1,318 +0,0 @@
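# GEO Business Lifecycle Definition

## Overview

GEO (Generative Engine Optimization) operations follow a complete five-stage lifecycle. Each stage has defined core actions, inputs and outputs, and corresponding feature modules. The lifecycle applies to both the **self-service subscription** and the **managed service** modes.

```
┌──────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐
│ Stage 1  │───▶│ Stage 2  │───▶│ Stage 3  │───▶│ Stage 4  │───▶│ Stage 5  │
│ Diagnosis│    │ Strategy │    │ Content  │    │Distribute│    │Monitoring│
└────┬─────┘    └────┬─────┘    └────┬─────┘    └────┬─────┘    └────┬─────┘
     │               │               │               │               │
     └───────────────┴───────────────┴───────────────┴───────────────┘
                   Continuous iteration and optimization
```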
---
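## Stage 1: Diagnosis

### Goal

Build a full picture of the brand's current standing in mainstream AI search engines: where the brand is cited, where it is not yet cited but could be, and how competitors are cited.

### Core Actions

| Action | Description | How it runs |
|--------|-------------|-------------|
| Brand query execution | Send brand-related queries to multiple AI platforms | Automatic |
| Response collection | Collect the generated content returned by each AI platform | Automatic |
| Citation detection | Analyze whether, how, and where the generated content cites the brand | AI Agent, automatic |
| Competitor comparison | Analyze how competitors are cited for the same queries | AI Agent, automatic |
| Diagnosis report generation | Summarize the analysis into a diagnosis report | Automatic |

### Inputs

- Brand name and brand keywords
- List of target AI platforms (ChatGPT, Kimi, Perplexity, and so on)
- Core query templates (industry questions, product comparisons, brand awareness, and so on)
- List of main competitors

### Outputs

- Raw query responses from each platform
- Citation detection results (cited / not cited / competitor cited)
- Citation confidence scores
- Competitor citation comparison data
- Diagnosis report (PDF/HTML)

### Feature Modules

- **Query management system**: query creation, execution, status tracking
- **Citation detection engine**: AI response parsing, citation recognition, confidence scoring
- **Competitor analysis module**: competitor citation comparison, SWOT analysis
- **Diagnosis report module**: data aggregation, report template rendering, multi-format export

### Key Metrics

| Metric | Description | Formula |
|--------|-------------|---------|
| Citation rate | Share of queries in which the brand is cited | Queries with a brand citation / total queries |
| Average confidence | Mean confidence score of citation detection | Sum of confidence over all citation records / number of citation records |
| Competitor citation rate | Share of queries in which a competitor is cited | Queries with a competitor citation / total queries |
| Platform coverage | Share of target AI platforms already checked | Platforms checked / target platforms |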
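
The four formulas in the table can be computed directly from per-query outcomes. The sketch below is illustrative only: the `QueryOutcome` record and the `stage1_metrics` function are hypothetical names, and returning 0.0 for an empty denominator is an assumption the source does not specify.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class QueryOutcome:
    """Hypothetical record of one query run on one AI platform."""

    platform: str
    brand_cited: bool
    competitor_cited: bool
    confidence: Optional[float] = None  # set only when the brand is cited


def _ratio(numerator: float, denominator: float) -> float:
    # Assumption: an empty denominator yields 0.0 instead of raising.
    return numerator / denominator if denominator else 0.0


def stage1_metrics(
    outcomes: list[QueryOutcome], target_platforms: list[str]
) -> dict[str, float]:
    cited = [o for o in outcomes if o.brand_cited]
    scores = [o.confidence for o in cited if o.confidence is not None]
    targets = set(target_platforms)
    covered = {o.platform for o in outcomes} & targets
    return {
        "citation_rate": _ratio(len(cited), len(outcomes)),
        "avg_confidence": _ratio(sum(scores), len(scores)),
        "competitor_citation_rate": _ratio(
            sum(o.competitor_cited for o in outcomes), len(outcomes)
        ),
        "platform_coverage": _ratio(len(covered), len(targets)),
    }
```

For four outcomes where the brand is cited twice, the citation rate is 2 / 4 = 0.5; the other metrics follow the same numerator-over-denominator pattern.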
---
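## Stage 2: Strategy

### Goal

Use the Stage 1 diagnosis to define a targeted GEO optimization strategy, including the optimization direction, the content production plan, and the execution timeline.

### Core Actions

| Action | Description | How it runs |
|--------|-------------|-------------|
| Diagnosis interpretation | Analyze the diagnosis report to find key problems and opportunities | AI Agent + human confirmation |
| Automatic strategy generation | Generate a GEO optimization strategy from the diagnosis | AI Agent, automatic |
| Rule base definition | Define the rules that constrain content generation and optimization | System templates + manual adjustment |
| Content scheduling | Plan content topics, volume, and schedule | Automatic scheduling |
| Goal setting | Set quantifiable optimization goals | Set manually, with system suggestions |

### Inputs

- Stage 1 diagnosis report
- Brand positioning and core messages
- Target audience profile
- Budget and resource constraints
- Industry benchmark data

### Outputs

- GEO optimization strategy document
- Content production plan
- Rule base configuration
- Optimization goals and KPIs
- Strategy execution timeline

### Feature Modules

- **Strategy generation engine**: generates optimization strategies from diagnosis data
- **Rule management system**: rule creation, versioning, activation control
- **Scheduling system**: content topic planning, publication scheduling
- **Goal management system**: KPI setting, goal tracking, deviation alerts

### Strategy Types

| Strategy type | When to use | Expected effect |
|---------------|-------------|-----------------|
| Content gap strategy | Queries where the brand is not cited | Higher citation coverage |
| Content improvement strategy | The brand is cited, but citation quality is low | Better citation quality and ranking |
| Competitor counter-strategy | Competitors are cited more often than the brand | Win citation share from competitors |
| Platform focus strategy | Weak performance on a specific platform | Better performance on that platform |
| Full optimization strategy | Overall performance needs improvement | Systematic improvement across metrics |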
---
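## Stage 3: Content Production

### Goal

Produce high-quality content assets that meet GEO optimization standards, following the strategy and plan from Stage 2.

### Core Actions

| Action | Description | How it runs |
|--------|-------------|-------------|
| Topic generation | Generate concrete content topics from the strategy and plan | AI Agent, automatic |
| Draft generation | Generate articles, Q&A, and explanations that meet GEO standards | AI Agent, automatic |
| Quality rule check | Check that content complies with the rule base | RuleChecker Agent |
| Human review | Key content is reviewed by a person before publication | Manual |
| Asset preparation | Prepare images, data charts, and other supporting assets | System-generated + manual upload |

### Inputs

- Stage 2 strategy document and content plan
- Rule base configuration
- Brand asset library (logo, product descriptions, case studies, and so on)
- Industry knowledge and reference material

### Outputs

- GEO-optimized content assets (articles, Q&A, white papers, and so on)
- Content quality scores
- Content metadata (keywords, target platforms, applicable queries, and so on)
- Asset package

### Feature Modules

- **Content generation engine**: generates content from topics and rules
- **Rule check engine**: content quality validation and rule compliance checks
- **Asset management system**: asset upload, management, smart matching
- **Content versioning**: drafts, version comparison, publication control

### Content Types

| Content type | Description | Target platforms |
|--------------|-------------|------------------|
| Q&A content | Structured answers to common questions | ChatGPT, Kimi |
| Article content | In-depth industry articles or brand stories | Perplexity, search engines |
| Data content | Research reports, statistics | All platforms |
| Comparison content | Objective comparison of the product with competitors | All platforms |
| Guide content | How-to guides, tutorials | All platforms |

### GEO Content Quality Standards

- **Authority**: content should demonstrate the brand's expertise and industry authority
- **Structure**: use a clear heading hierarchy and structured formatting
- **Data support**: include statistics, case studies, and other credible sources
- **Keyword optimization**: work target keywords and long-tail terms in naturally
- **Platform fit**: adapt the content form to what each AI platform favors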
---
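## Stage 4: Distribution

### Goal

Distribute the content produced in Stage 3 to target channels, execute the GEO optimization strategy, and track distribution results.

### Core Actions

| Action | Description | How it runs |
|--------|-------------|-------------|
| Channel configuration | Manage the target channels and platforms for distribution | System configuration |
| Publishing | Publish content to target channels | Automatic or semi-automatic |
| Result tracking | Track performance data after publication | Automatic |
| Link building | Link brand content with authoritative sources | System-assisted |
| Social sync | Sync content to social media for wider reach | System integration |

### Inputs

- Content assets produced in Stage 3
- Channel configuration
- Publication schedule
- Distribution strategy parameters

### Outputs

- Publication records
- Distribution result data (impressions, clicks, citations)
- Channel performance report
- Link building report

### Feature Modules

- **Channel management system**: channel configuration, status monitoring, onboarding
- **Content distribution engine**: publishing, status sync, retry on failure
- **Result tracking system**: UTM tracking, referral analysis, data collection
- **Link management system**: link generation, tracking, dead-link detection

### Distribution Channels

| Channel type | Channels | Content form |
|--------------|----------|--------------|
| Owned media | Official blog, knowledge base, help center | Articles, guides |
| Third-party platforms | Zhihu, Baijiahao, Toutiao | Articles, Q&A |
| Social platforms | WeChat Official Accounts, LinkedIn, Twitter | Short posts, image posts |
| Industry platforms | Industry forums, vertical communities | Specialist articles |
| Authoritative sources | Wikipedia, industry report libraries | Data, citation sources |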
---
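## Stage 5: Monitoring and Optimization

### Goal

Continuously monitor how the brand's performance in AI search engines changes, evaluate the effect of optimization, and iterate on the strategy based on the data.

### Core Actions

| Action | Description | How it runs |
|--------|-------------|-------------|
| Periodic query execution | Re-run brand queries on the configured schedule | Automatic |
| Citation change tracking | Compare against historical data and track citation trends | Automatic |
| KPI calculation | Calculate citation rate, ranking, coverage, and other KPIs | Automatic |
| Report generation | Generate periodic performance reports | Automatic |
| Strategy iteration | Adjust the optimization strategy based on performance data | AI Agent + human decision |

### Inputs

- Historical data and execution records from Stages 1-4
- KPI targets and baselines
- Results of periodic query runs
- Latest competitor performance data

### Outputs

- Real-time monitoring dashboard
- Periodic performance reports (weekly / monthly / quarterly)
- Trend charts
- Strategy adjustment suggestions
- Alert notifications

### Feature Modules

- **Performance tracking system**: KPI calculation, trend analysis, alert triggering
- **Report generation system**: report templates, automatic rendering, multi-format export
- **Dashboard system**: visual charts, real-time data, drill-down analysis
- **Alert system**: thresholds, multi-channel notification, alert escalation

### Monitoring Cadence

| Cadence | What is monitored | Report type |
|---------|-------------------|-------------|
| Real time | Query execution status, system health | System monitoring |
| Daily | Citation changes, anomaly detection | Daily report |
| Weekly | Overall performance, trend analysis | Weekly report |
| Monthly | Full effect evaluation, strategy suggestions | Monthly report |
| Quarterly | In-depth analysis, competitor comparison, strategy iteration | Quarterly report |

### Core KPIs

| KPI | Description | Target |
|-----|-------------|--------|
| Citation rate growth | Increase in the brand citation rate | ≥ 5% per month |
| Average rank | Average position of the brand among citations | Top 3 |
| Platform coverage | Share of optimized platforms that cite the brand | ≥ 80% |
| Content conversion rate | Share of distributed content that leads to citations | ≥ 30% |
| Advantage over competitors | Brand citation rate versus competitor citation rate | Stay ahead |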
---
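## The Lifecycle as a Closed Loop

```
┌──────────────────────────────────────────────┐
│ Stage 5: Monitoring and optimization         │
│ - track results, generate reports, iterate   │
└──────────────────────┬───────────────────────┘
┌──────────────────────▼───────────────────────┐
│ Stage 1: Diagnosis and analysis              │
│ - run queries, detect citations, competitors │
└──────────────────────┬───────────────────────┘
┌──────────────────────▼───────────────────────┐
│ Stage 2: Strategy planning                   │
│ - generate strategy, set rules, schedule     │
└──────────────────────┬───────────────────────┘
┌──────────────────────▼───────────────────────┐
│ Stage 3: Content production                  │
│ - generate content, check quality, assets    │
└──────────────────────┬───────────────────────┘
┌──────────────────────▼───────────────────────┐
│ Stage 4: Distribution and execution          │
│ - manage channels, publish, track results    │
└──────────────────────────────────────────────┘
```

The GEO business lifecycle is a continuously iterating closed loop. Each time Stage 5 completes, the diagnosis (Stage 1) runs again against the new baseline, forming a data-driven cycle of continuous optimization.

---
*This document defines the complete business lifecycle of the GEO Platform. For the detailed feature design of each stage, see the feature list.*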


@ -1,337 +0,0 @@
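# GEO Platform - Complete Feature List

## Overview

This document lists every feature of the GEO Platform, grouped by the five stages of the GEO business lifecycle (Stages 1-5) plus the general modules. Each entry gives the feature name, a description, the modes it applies to, and its priority.

**Priority levels**:

- **P0 (Critical)**: core MVP features that must be implemented
- **P1 (High)**: important features to implement soon
- **P2 (Medium)**: enhancements for later iterations
- **P3 (Low)**: refinements for the long term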
---
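## Stage 1: Diagnosis

### 1.1 Query Management

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S1-F01 | Query template management | Create, edit, and delete brand query templates, with variable substitution | Yes | Yes | P0 |
| S1-F02 | Batch query creation | Generate queries in bulk from templates, with Excel/CSV import | Yes | Yes | P0 |
| S1-F03 | Multi-platform query execution | Send queries to multiple AI platforms (ChatGPT/Kimi/Perplexity) in parallel | Yes | Yes | P0 |
| S1-F04 | Query status tracking | Track query status in real time (queued / running / completed / failed) | Yes | Yes | P0 |
| S1-F05 | Query result viewing | View the raw responses returned by each platform | Yes | Yes | P0 |
| S1-F06 | Query history | Browse past queries with filtering, search, and pagination | Yes | Yes | P0 |
| S1-F07 | Query retry | Retry failed queries automatically or manually | Yes | Yes | P0 |
| S1-F08 | Query scheduling | Configure when and how often queries run (immediate / scheduled / recurring) | Yes | Yes | P1 |
| S1-F09 | Query result export | Export query results as JSON/Excel/PDF | Yes | Yes | P1 |
| S1-F10 | Query queue management | View and manage the execution queue, with priority adjustment | No | Yes | P2 |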
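### 1.2 Citation Detection

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S1-F11 | Automatic citation recognition | Detect whether an AI response cites the brand | Yes | Yes | P0 |
| S1-F12 | Citation extraction | Extract the cited passage and its context | Yes | Yes | P0 |
| S1-F13 | Citation position marking | Mark where in the response the citation appears | Yes | Yes | P0 |
| S1-F14 | Citation confidence scoring | Rate the confidence of each detection (high / medium / low) | Yes | Yes | P0 |
| S1-F15 | Citation type classification | Classify citations (direct / indirect / not cited / competitor cited) | Yes | Yes | P0 |
| S1-F16 | Citation record management | View and manage all citation detection records | Yes | Yes | P0 |
| S1-F17 | Citation statistics | Aggregate citation data by platform, time, type, and other dimensions | Yes | Yes | P1 |
| S1-F18 | Citation trend analysis | Analyze how the citation rate changes over time | Yes | Yes | P1 |
| S1-F19 | Citation comparison | Compare citations across queries or points in time | Yes | Yes | P1 |
| S1-F20 | Citation detail view | View the full context of a citation and its associated query | Yes | Yes | P0 |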
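### 1.3 Competitor Analysis

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S1-F21 | Competitor list management | Maintain the list of competitor brands (create, read, update, delete) | Yes | Yes | P0 |
| S1-F22 | Competitor query execution | Run the same queries for competitors to collect citation data | Yes | Yes | P0 |
| S1-F23 | Competitor citation comparison | Compare brand and competitor citations for the same queries | Yes | Yes | P0 |
| S1-F24 | Competitor citation ranking | Rank the brand against competitors by citations | Yes | Yes | P1 |
| S1-F25 | Competitor SWOT analysis | Generate a competitor SWOT analysis from citation data | Yes | Yes | P1 |
| S1-F26 | Competitor monitoring | Monitor changes in competitor citation performance | Yes | Yes | P2 |
| S1-F27 | Competitor strategy inference | Infer competitors' GEO strategy from their citation patterns | Yes | Yes | P3 |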
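### 1.4 Diagnosis Reports

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S1-F28 | Diagnosis report generation | Generate a diagnosis report from query and citation data | Yes | Yes | P0 |
| S1-F29 | Report template management | Manage diagnosis report templates and styles | Yes | Yes | P1 |
| S1-F30 | Multi-format export | Export as PDF/HTML/Excel | Yes | Yes | P0 |
| S1-F31 | Online report viewing | View the rendered diagnosis report online | Yes | Yes | P0 |
| S1-F32 | Report sharing | Share a report by link or email | Yes | Yes | P2 |
| S1-F33 | White-label reports (managed) | Generate reports without GEO branding | No | Yes | P1 |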
---
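## Stage 2: Strategy

### 2.1 Strategy Generation

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S2-F01 | Automatic strategy recommendation | Recommend GEO optimization strategies from the diagnosis | Yes | Yes | P0 |
| S2-F02 | Strategy editing and confirmation | Edit and confirm AI-generated strategies | Yes | Yes | P0 |
| S2-F03 | Strategy versioning | Keep strategy history, with comparison and rollback | Yes | Yes | P1 |
| S2-F04 | Strategy library | Maintain a library of strategy templates for reuse and reference | Yes | Yes | P1 |
| S2-F05 | Strategy effect forecasting | Forecast the expected effect of a strategy from historical data | Yes | Yes | P2 |

### 2.2 Rule Management

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S2-F06 | Rule base creation | Create rules that constrain content generation (keywords, format, style, and so on) | Yes | Yes | P0 |
| S2-F07 | Rule base editing | Edit existing rule base configuration | Yes | Yes | P0 |
| S2-F08 | Rule base versioning | Keep the version history of rule bases | Yes | Yes | P1 |
| S2-F09 | Rule base enable/disable | Control whether a rule base is active | Yes | Yes | P0 |
| S2-F10 | Rule base import/export | Import and export rule base configuration | Yes | Yes | P2 |

### 2.3 Scheduling

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S2-F11 | Content topic planning | Generate the list of content topics from the strategy | Yes | Yes | P0 |
| S2-F12 | Calendar view | Show the publication plan as a calendar | Yes | Yes | P0 |
| S2-F13 | Schedule adjustment | Drag and drop to change publication time and order | Yes | Yes | P1 |
| S2-F14 | Schedule conflict detection | Detect time and resource conflicts in the schedule | Yes | Yes | P1 |
| S2-F15 | Plan versus actual | Compare the planned schedule with actual execution | Yes | Yes | P2 |

### 2.4 Goal Management

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S2-F16 | KPI setting | Set quantifiable optimization goals (citation rate, ranking, and so on) | Yes | Yes | P0 |
| S2-F17 | Goal tracking | Track progress toward goals in real time | Yes | Yes | P1 |
| S2-F18 | Deviation alerts | Raise an alert when actual performance deviates from the goal | Yes | Yes | P1 |
| S2-F19 | Goal adjustment | Adjust target values as circumstances change | Yes | Yes | P2 |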
---
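## Stage 3: Content Production

### 3.1 Content Generation

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S3-F01 | Automatic content generation | Generate GEO-optimized content from topics and rules | Yes | Yes | P0 |
| S3-F02 | Generation parameters | Configure generation parameters (length, style, tone, and so on) | Yes | Yes | P0 |
| S3-F03 | Multiple versions | Generate several versions of the same topic to choose from | Yes | Yes | P1 |
| S3-F04 | Content editing | Edit generated content online | Yes | Yes | P0 |
| S3-F05 | Content preview | Preview how the content will finally appear | Yes | Yes | P0 |
| S3-F06 | Save as draft | Save content in draft state | Yes | Yes | P0 |
| S3-F07 | Content type selection | Choose the type of content to generate (article / Q&A / guide, and so on) | Yes | Yes | P0 |
| S3-F08 | Batch generation | Generate content for multiple topics in bulk | Yes | Yes | P1 |

### 3.2 Quality Checks

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S3-F09 | Rule compliance check | Check that content complies with the rule base | Yes | Yes | P0 |
| S3-F10 | Quality scoring | Score content quality (readability, SEO, GEO, and other dimensions) | Yes | Yes | P0 |
| S3-F11 | Issue highlighting | Highlight problems in the content with suggested improvements | Yes | Yes | P0 |
| S3-F12 | One-click fix | Offer one-click fixes for detected problems | Yes | Yes | P1 |
| S3-F13 | Flag for human review | Flag content that needs human review | Yes | Yes | P0 |
| S3-F14 | Review workflow | Define the approval flow for content review | No | Yes | P1 |

### 3.3 Asset Management

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S3-F15 | Asset upload | Upload images, videos, documents, and other assets | Yes | Yes | P0 |
| S3-F16 | Asset library | Manage asset categories, tags, and search | Yes | Yes | P0 |
| S3-F17 | Smart image matching | Recommend or generate images based on the content topic | Yes | Yes | P2 |
| S3-F18 | Copyright check | Check the copyright and usage rights of assets | Yes | Yes | P3 |

### 3.4 Content Versions

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S3-F19 | Version history | View all past versions of a piece of content | Yes | Yes | P1 |
| S3-F20 | Version comparison | Compare the differences between versions | Yes | Yes | P1 |
| S3-F21 | Version rollback | Restore a previous version of the content | Yes | Yes | P1 |
| S3-F22 | Publication control | Control the publication state and visibility of content | Yes | Yes | P0 |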
---
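## Stage 4: Distribution

### 4.1 Channel Management

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S4-F01 | Channel configuration | Configure the target channels for distribution | Yes | Yes | P0 |
| S4-F02 | Channel status monitoring | Monitor the connection and availability of each channel | Yes | Yes | P0 |
| S4-F03 | Channel groups | Organize channels into groups | Yes | Yes | P1 |
| S4-F04 | Channel authorization | Manage channel API authorization and tokens | Yes | Yes | P0 |
| S4-F05 | Channel extension | Add new distribution channels | Yes | Yes | P2 |

### 4.2 Publishing

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S4-F06 | Single publish | Publish one piece of content to a chosen channel | Yes | Yes | P0 |
| S4-F07 | Batch publish | Publish multiple pieces of content to multiple channels | Yes | Yes | P1 |
| S4-F08 | Scheduled publish | Set a future publication time to run automatically | Yes | Yes | P1 |
| S4-F09 | Publish preview | Preview how content will appear on each channel | Yes | Yes | P0 |
| S4-F10 | Unpublish | Withdraw published content | Yes | Yes | P1 |
| S4-F11 | Publish status tracking | Track publication status in real time | Yes | Yes | P0 |

### 4.3 Result Tracking

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S4-F12 | UTM parameter generation | Generate UTM tracking parameters for distributed content | Yes | Yes | P1 |
| S4-F13 | Click data collection | Collect click and visit data for content | Yes | Yes | P1 |
| S4-F14 | Referral analysis | Analyze referrals and conversions after distribution | Yes | Yes | P2 |
| S4-F15 | Channel comparison | Compare distribution results across channels | Yes | Yes | P1 |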
---
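## Stage 5: Monitoring and Optimization

### 5.1 Performance Monitoring

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S5-F01 | KPI dashboard | Real-time dashboard of the core KPIs | Yes | Yes | P0 |
| S5-F02 | Citation rate trend | Chart of the brand citation rate over time | Yes | Yes | P0 |
| S5-F03 | Platform comparison | Compare brand performance across AI platforms | Yes | Yes | P0 |
| S5-F04 | Rank change tracking | Track changes in the brand's rank among citations | Yes | Yes | P1 |
| S5-F05 | Live competitor comparison | Compare the latest brand and competitor performance | Yes | Yes | P1 |
| S5-F06 | Anomaly detection | Detect abnormal fluctuations in the data | Yes | Yes | P1 |
| S5-F07 | Alert notifications | Notify when data is abnormal or a goal is reached | Yes | Yes | P1 |
| S5-F08 | Drill-down | Drill from summary data down to detailed records | Yes | Yes | P1 |

### 5.2 Report Generation

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S5-F09 | Daily report | Generate a daily performance summary | Yes | Yes | P1 |
| S5-F10 | Weekly report | Generate a weekly overall analysis | Yes | Yes | P0 |
| S5-F11 | Monthly report | Generate a monthly in-depth analysis | Yes | Yes | P0 |
| S5-F12 | Quarterly report | Generate a quarterly strategic review | Yes | Yes | P1 |
| S5-F13 | Custom report | Choose the time range and data dimensions of a report | Yes | Yes | P1 |
| S5-F14 | Scheduled report delivery | Send reports to given email addresses on a schedule | Yes | Yes | P2 |
| S5-F15 | White-label reports | Generate branded reports without GEO marks | No | Yes | P1 |

### 5.3 Strategy Iteration

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| S5-F16 | Attribution analysis | Analyze how optimization actions relate to changes in results | Yes | Yes | P2 |
| S5-F17 | Strategy adjustment suggestions | Generate strategy adjustment suggestions from performance data | Yes | Yes | P1 |
| S5-F18 | A/B test management | Manage A/B tests of different strategies | Yes | Yes | P2 |
| S5-F19 | Optimization log | Record the plan and result of each optimization | Yes | Yes | P1 |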
---
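## General Modules

### 6.1 Users and Authentication

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| G-F01 | User registration | Register with an email address or phone number | Yes | Yes | P0 |
| G-F02 | User login | Log in with email and password or a verification code | Yes | Yes | P0 |
| G-F03 | Password recovery | Recover a password by email or phone | Yes | Yes | P0 |
| G-F04 | Profile | Manage personal details, avatar, and contact information | Yes | Yes | P0 |
| G-F05 | Change password | Change the login password | Yes | Yes | P0 |
| G-F06 | OAuth login | Third-party login such as Google or WeChat | Yes | Yes | P2 |
| G-F07 | Single sign-on (SSO) | Enterprise SSO integration | No | Yes | P3 |

### 6.2 Team Management (Enterprise / Managed Service)

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| G-F08 | Team member management | Invite, add, and remove team members | Enterprise | Yes | P1 |
| G-F09 | Role and permission assignment | Assign roles and permissions to team members | Enterprise | Yes | P1 |
| G-F10 | Team resource management | Manage queries, content, and reports shared by the team | Enterprise | Yes | P2 |
| G-F11 | Audit log | View the actions taken by team members | Enterprise | Yes | P2 |

### 6.3 Managed Service (Managed Mode Only)

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| G-F12 | Customer management | Manage the list of managed customers | No | Yes | P0 |
| G-F13 | Customer project assignment | Assign dedicated query and optimization projects to customers | No | Yes | P0 |
| G-F14 | Customer data isolation | Keep data fully isolated between customers | No | Yes | P0 |
| G-F15 | White-label configuration | Configure customer branding (logo, colors, and so on) | No | Yes | P1 |
| G-F16 | Customer report center | Generate and send dedicated reports for customers | No | Yes | P1 |
| G-F17 | Customer billing | Manage customer invoices and settlement | No | Yes | P2 |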
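### 6.4 Subscription and Billing

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| G-F18 | View subscription plan | View the features and quotas of the current plan | Yes | Yes | P0 |
| G-F19 | Upgrade/downgrade | Upgrade or downgrade the subscription tier | Yes | Yes | P0 |
| G-F20 | Quota usage | View quota usage for the current period | Yes | Yes | P0 |
| G-F21 | Payment management | Manage payment methods and billing history | Yes | Yes | P0 |
| G-F22 | Invoice management | Request and download invoices | Yes | Yes | P1 |
| G-F23 | Managed-service billing | Billing and settlement for managed-service mode | No | Yes | P2 |

### 6.5 System Administration (Admin)

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| G-F24 | User management | Manage all platform users (view, disable, delete) | No | No | P0 |
| G-F25 | Subscription plan management | Configure and manage subscription plans | No | No | P1 |
| G-F26 | System configuration | Configure system parameters and global settings | No | No | P1 |
| G-F27 | Platform adapter management | Manage the onboarding and configuration of AI platform adapters | No | No | P1 |
| G-F28 | System monitoring | View system status and performance metrics | No | No | P1 |
| G-F29 | Log viewer | View system logs and audit logs | No | No | P2 |
| G-F30 | Announcements | Publish platform-wide announcements and notices | No | No | P3 |

### 6.6 Notifications and Messages

| ID | Feature | Description | Self-service | Managed service | Priority |
|----|---------|-------------|--------------|-----------------|----------|
| G-F31 | In-app messages | Receive system notifications and task completion reminders | Yes | Yes | P0 |
| G-F32 | Email notifications | Receive important notifications by email | Yes | Yes | P1 |
| G-F33 | Notification settings | Configure notification types and delivery methods | Yes | Yes | P1 |
| G-F34 | Message center | View past notifications and messages | Yes | Yes | P0 |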
---
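## Priority Summary

### P0 (MVP Core) - Must Be Implemented in the First Phase

**Stage 1**: query template management, batch queries, multi-platform execution, status tracking, result viewing, history, query retry, citation recognition, citation extraction, position marking, confidence scoring, type classification, record management, citation detail view, competitor list, competitor queries, competitor comparison, diagnosis report generation, multi-format export, online viewing

**Stage 2**: strategy recommendation, strategy editing, rule base creation/editing, rule base enable/disable, topic planning, calendar view, KPI setting

**Stage 3**: automatic content generation, generation parameters, content editing, content preview, save as draft, type selection, rule check, quality scoring, issue highlighting, human review, asset upload/management, publication control

**Stage 4**: channel configuration, status monitoring, authorization management, single publish, publish preview, status tracking

**Stage 5**: KPI dashboard, citation rate trend, platform comparison, weekly/monthly reports

**General**: registration/login/password recovery, profile, change password, customer management, customer project assignment, customer data isolation, view subscription, upgrade/downgrade, quota usage, payment management, user management, in-app messages, message center

### P1 (Important) - Second Phase

**Stage 1**: query scheduling, result export, statistics, trend analysis, comparison analysis, report templates, white-label reports, competitor ranking, SWOT analysis

**Stage 2**: strategy versioning, strategy library, rule versioning, schedule adjustment, conflict detection, goal tracking, deviation alerts

**Stage 3**: multiple versions, batch generation, one-click fix, review workflow, version history/comparison/rollback

**Stage 4**: channel groups, batch publish, scheduled publish, unpublish, UTM generation, click data collection, channel comparison

**Stage 5**: rank tracking, live competitor comparison, anomaly detection, alert notifications, drill-down, daily report, quarterly report, custom reports, white-label reports, strategy adjustment suggestions, optimization log

**General**: team member management, roles and permissions, white-label configuration, customer reports, invoice management, subscription plan management, system configuration, adapter management, system monitoring, email notifications, notification settings

### P2/P3 (Enhancements and Refinements) - Later Iterations

Includes query queue management, competitor monitoring, report sharing, competitor strategy inference, strategy effect forecasting, rule base import/export, plan versus actual, goal adjustment, smart image matching, copyright check, channel extension, referral analysis, scheduled report delivery, attribution analysis, A/B testing, OAuth login, team resources, audit log, customer billing, managed-service billing, log viewer, announcements, and SSO.

---
*This document is the complete feature list of the GEO Platform. Implementation can be phased according to the iteration plan.*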


@ -0,0 +1,14 @@
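# Project Overview

This directory contains the GEO Platform's project overview, architecture design, tech stack description, and changelog.

## Contents

- [README](./README.md) - Project introduction
- [System architecture](./architecture.md) - System architecture design
- [Tech stack](./tech-stack.md) - Tech stack description
- [Changelog](./changelog.md) - Version history

---
*GEO Platform - making brands visible in the AI era.*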


@ -0,0 +1,121 @@
# System Architecture

## Overall Architecture

```
┌─────────────────────────────────────────────────────────────┐
│                     Frontend (Next.js)                      │
├─────────────────────────────────────────────────────────────┤
│ Dashboard │ Diagnosis │ Content │ Knowledge │ Monitoring    │
└─────────────────────────────────────────────────────────────┘
                               │
                               ▼
┌─────────────────────────────────────────────────────────────┐
│                    API Gateway (FastAPI)                    │
├─────────────────────────────────────────────────────────────┤
│ Auth │ Brands │ Diagnosis │ Content │ Knowledge │ Monitoring│
└─────────────────────────────────────────────────────────────┘
          ┌────────────────────┼────────────────────┐
          ▼                    ▼                    ▼
 ┌─────────────────┐  ┌─────────────────┐  ┌─────────────────┐
 │   PostgreSQL    │  │      Redis      │  │  LLM Providers  │
 │ (data storage)  │  │  (cache/queue)  │  │  (AI services)  │
 └─────────────────┘  └─────────────────┘  └─────────────────┘
```
## 诊断架构
### SEO诊断
```
网站URL输入 → 爬虫抓取 → 技术分析 → 内容分析 → 外链分析 → 生成SEO诊断报告
```
**诊断维度:**
- 技术SEO索引、爬取、Core Web Vitals
- 页面SEOTitle/Meta、H标签、内链
- 内容质量E-E-A-T、新鲜度、重复内容
- 外链分析(质量、毒性、锚文本)
- 用户体验(移动适配、页面速度)
### GEO诊断
```
品牌信息输入 → 内容可提取性检测 → 实体清晰度检测 → E-E-A-T信号检测
→ Schema标记检测 → 主题权威检测 → AI平台引用检测 → 生成GEO诊断报告
```
**诊断维度:**
- 内容可提取性(直接回答块、问答式标题、列表表格)
- 实体清晰度(品牌定义、目标受众、差异化价值)
- E-E-A-T信号作者资质、专业认证、数据来源
- Schema标记Organization、Product、FAQPage等
- 主题权威(内容深度、话题覆盖度、实体信号一致性)
- 引用就绪度引用频率、引用质量、AI声量占比
## Agent Framework
```
┌─────────────────────────────────────────────────────────────┐
│ Agent Framework │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Dispatcher │→│ Registry │→│ Monitor │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ Agents │ │
│ ├───────────┬───────────┬───────────┬────────────┤ │
│ │ Citation │ Content │ DeAI │ GEO │ │
│ │ Detector │ Generator │ Agent │ Optimizer │ │
│ └───────────┴───────────┴───────────┴────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ Services │ │
│ ├───────────┬───────────┬───────────┬────────────┤ │
│ │ RuleValid │ SEOOptim │ Sensitive │ HTMLGen │ │
│ └───────────┴───────────┴───────────┴────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
```
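**Agent responsibilities:**
| Agent | Responsibility | Input | Output |
|-------|------|------|------|
| CitationDetector | Citation detection | AI platform responses, brand name | Citation detection result |
| ContentGenerator | Content generation | Topic, rule library, brand assets | GEO-optimized content |
| DeAIAgent | De-AI processing | AI-generated content | Naturalized content |
| GEOOptimizer | GEO optimization | Original content, keyword strategy | Optimized content |
**Note:** The `SEOOptimizer` in the current project actually performs GEO optimization (content structuring, entity optimization), not traditional SEO optimization (technical SEO). Later versions should separate the two explicitly:
- **SEOOptimizer** → technical SEO optimization (site technical layer)
- **GEOOptimizer** → content and entity optimization (AI citation layer)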
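## Content Generation Pipeline
```
User input → topic template selection → content generation → de-AI processing → GEO optimization → HTML generation → output
```
**Note:** The "SEO optimization" step in the original pipeline is actually GEO optimization (content structuring), not traditional SEO optimization.
## Knowledge Base System
```
Document upload → text chunking → embedding → RAG retrieval → LLM-augmented generation
```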
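The text chunking step in the flow above can be sketched as fixed-size chunking with overlap. This is only an illustration, not the project's chunker: `chunk_text`, `chunk_size`, and `overlap` are hypothetical names, and sizes are counted in characters rather than tokens.

```python
def chunk_text(text: str, chunk_size: int = 200, overlap: int = 50) -> list[str]:
    """Split text into fixed-size character windows that overlap.

    Hypothetical helper for illustration; a real chunker would likely
    count tokens and respect sentence boundaries.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be in [0, chunk_size)")

    step = chunk_size - overlap  # how far the window advances each time
    chunks: list[str] = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reaches the end of the text
    return chunks
```

Overlap keeps a sentence that straddles a boundary visible in both neighbouring chunks, at the cost of storing some text twice.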
## Monitoring
```
API requests → Prometheus metrics → Grafana visualization
Health checks → alert notifications
```
## Database Design
Core tables: users, organizations, brands, competitors, queries, citations, alerts, contents, knowledge_bases, knowledge_entities, knowledge_relations


@ -0,0 +1,71 @@
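# Changelog
## v2.0.0 (current version)
### New Features
- Agent Framework architecture
- Content generation Pipeline
- Knowledge base RAG system
- Knowledge graph
- Platform rules center
- Monitoring module
- GEO topic template library
- Image generation service
### Core Modules
- [x] Diagnostic analysis (100%)
- [x] Competitor analysis (100%)
- [x] Content production (100%)
- [x] Knowledge base (100%)
- [x] Monitoring (100%)
### Agent Framework
- [x] TaskDispatcher - task dispatcher
- [x] AgentRegistry - registry
- [x] CitationDetector - citation detection
- [x] ContentGenerator - content generation
- [x] DeAIAgent - de-AI processing
- [x] GEOOptimizer - GEO optimization
- [x] PipelineEngine - Pipeline orchestration engine
### Content Pipeline
- [x] RuleValidator - rule validation
- [x] SensitiveFilter - sensitive-word filtering
- [x] SEOOptimizer - SEO optimization
- [x] HTMLGenerator - HTML generation
### Knowledge Base
- [x] RAG service
- [x] Document parsers (PDF/DOCX/Markdown/TXT)
- [x] Chunking strategies (Recursive/Semantic/Fixed)
- [x] Embedding service
- [x] Hybrid retriever
- [x] Enhanced retrieval (reranking/context compression)
### Knowledge Graph
- [x] Entity extraction
- [x] Relation extraction
- [x] Graph construction
- [x] Graph queries
### Platform Rules
- [x] Rules for 10 platforms (Zhihu/WeChat/Baijiahao/Toutiao/Weibo/Xiaohongshu/Bilibili/Jianshu/Juejin/Douyin)
- [x] AI sensitivity configuration
- [x] Sensitive-word filtering
- [x] HTML rules
### Test Coverage
- Backend unit tests: ~480
- Frontend tests: 75
- E2E tests: 65
---
## v1.0.0 (initial version)
### Initial Release
- Project initialization
- Base architecture setup
- Basic query functionality
- Citation detection
- Basic Dashboard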


@ -0,0 +1,64 @@
# Tech Stack
## Frontend Stack
| Component | Technology | Version |
|------|------|------|
| Framework | Next.js | 14+ |
| UI library | React | 18+ |
| Language | TypeScript | 5.x |
| Styling | Tailwind CSS | 4.x |
| Component library | shadcn/ui | - |
| Charts | Recharts | - |
| State management | SWR | - |
| Authentication | NextAuth.js | v5 |
## Backend Stack
| Component | Technology | Version |
|------|------|------|
| Framework | FastAPI | 0.109+ |
| Language | Python | 3.12+ |
| ORM | SQLAlchemy | 2.0+ |
| Database | PostgreSQL | 15+ |
| Cache | Redis | 7+ |
| Task queue | Celery | 5+ |
| Authentication | JWT + OAuth2 | - |
## AI Agent Framework
| Component | Technology |
|------|------|
| Agent foundation | In-house modular framework |
| Message queue | Redis Pub/Sub |
| Registry | Redis Hash |
| Task dispatch | Dispatcher + Registry |
## Infrastructure
| Component | Technology |
|------|------|
| Containerization | Docker + Docker Compose |
| Reverse proxy | Nginx |
| Monitoring | Prometheus + Grafana |
## Project Directory Structure
```
geo/
├── backend/ # FastAPI backend
│ ├── app/
│ │ ├── api/ # API routes
│ │ ├── agent_framework/ # Agent framework
│ │ ├── models/ # Data models
│ │ ├── schemas/ # Pydantic models
│ │ ├── services/ # Business logic
│ │ ├── workers/ # Async tasks
│ │ └── monitoring/ # Monitoring module
│ └── requirements.txt
├── frontend/ # Next.js frontend
│ ├── app/ # Pages
│ ├── components/ # Components
│ └── lib/ # Utility functions
└── docs/ # Documentation
```


@ -1,502 +0,0 @@
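# GEO Platform - AI Agent Framework Design
## Overview
The GEO platform's AI Agent framework is the system's core intelligence layer, responsible for running business tasks that need AI capabilities. The framework uses a modular, pluggable design and supports dynamic Agent registration, task dispatch, and state management.
This document defines the Agent framework's overall architecture, core components, and communication mechanisms.
## Agent List
### CitationDetector (citation detection Agent)
| Attribute | Description |
|------|------|
| **Responsibility** | Parse AI platform responses and identify whether they contain brand citations |
| **Input** | Raw AI platform response text, brand name list, competitor name list |
| **Output** | Citation detection result (cited / not cited / competitor cited), citation snippets, confidence score |
| **Core capabilities** | Natural language understanding, entity recognition, citation relationship determination |
| **Trigger** | Triggered automatically after a query task completes |
| **Priority** | P0 |
**Processing flow**:
```
Receive response text ──▶ Text preprocessing ──▶ Brand entity recognition ──▶ Citation relationship determination
Competitor citation detection ──▶ Confidence assessment
Result formatting ──▶ Report result
```
**Citation type determination**:
| Type | Criterion | Confidence weight |
|------|----------|------------|
| `direct_quote` | The brand name is mentioned directly and used as an information source | High |
| `indirect_reference` | Brand information is mentioned but the source is not explicitly attributed | Medium |
| `no_reference` | The brand is not mentioned at all | Certain |
| `competitor_reference` | A competitor is cited while the brand is not | High |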
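
The four citation types in the table can be told apart with a rough sketch like the one below. It swaps the Agent's natural language understanding for plain substring matching, so it cannot tell whether a mention is really used as an information source. `classify_citation` and the `brand_facts` parameter (phrases that identify the brand without naming it) are hypothetical names introduced for illustration.

```python
from typing import Optional, Sequence


def classify_citation(
    response: str,
    brand: str,
    competitors: Sequence[str],
    brand_facts: Optional[Sequence[str]] = None,  # hypothetical: brand-specific phrases
) -> str:
    """Return one of the four citation types using naive substring checks."""
    text = response.casefold()

    # The brand is named explicitly.
    if brand.casefold() in text:
        return "direct_quote"
    # Brand information appears without the brand being named.
    if any(fact.casefold() in text for fact in (brand_facts or [])):
        return "indirect_reference"
    # Only a competitor is mentioned.
    if any(name.casefold() in text for name in competitors):
        return "competitor_reference"
    return "no_reference"
```

The order of the checks encodes the table's rule that `competitor_reference` applies only when the brand itself is not cited.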
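### ContentGenerator (content generation Agent)
| Attribute | Description |
|------|------|
| **Responsibility** | Generate GEO-optimized content assets according to strategy and rules |
| **Input** | Content topic, rule library configuration, brand assets, reference data |
| **Output** | GEO-optimized content (articles/Q&A/guides, etc.), content metadata |
| **Core capabilities** | Creative writing, SEO/GEO optimization, multi-style adaptation |
| **Trigger** | Triggered manually by the user or automatically on schedule |
| **Priority** | P0 |
**Processing flow**:
```
Receive generation task ──▶ Requirements analysis ──▶ Asset collection ──▶ Outline generation
Content writing ──▶ GEO optimization
Self-check ──▶ Output result
```
**Supported content types**:
| Type | Description | When to use |
|------|------|----------|
| `article` | Long-form article (800-2000 characters) | In-depth industry content |
| `qa` | Q&A pair | Common brand questions |
| `guide` | How-to guide | Product usage scenarios |
| `comparison` | Comparative analysis | Competing against rivals |
| `data_report` | Data report | Industry research |
### RuleChecker (rule checking Agent)
| Attribute | Description |
|------|------|
| **Responsibility** | Check whether content meets the constraints in the rule library |
| **Input** | Content to check, rule library configuration |
| **Output** | Check result (pass/fail), issue list, improvement suggestions |
| **Core capabilities** | Rule engine, content quality assessment, semantic analysis |
| **Trigger** | Triggered automatically after ContentGenerator finishes, or manually |
| **Priority** | P0 |
**Check dimensions**:
| Dimension | Description | Example rule |
|------|------|----------|
| `keyword_coverage` | Keyword coverage | Must include the core keywords |
| `readability` | Readability score | Flesch Reading Ease score |
| `tone_consistency` | Tone consistency | Keep a professional, objective tone |
| `length_compliance` | Length compliance | Match the target platform's preferred length |
| `fact_accuracy` | Factual accuracy | Brand information is accurate |
| `geo_optimization` | Degree of GEO optimization | Follows GEO best practices |
### CompetitorAnalyzer (competitor analysis Agent)
| Attribute | Description |
|------|------|
| **Responsibility** | Analyze competitors' GEO performance and strategy |
| **Input** | Competitor names, query results, citation data |
| **Output** | Competitor analysis report, SWOT analysis, strategy suggestions |
| **Core capabilities** | Competitor data comparison, pattern recognition, strategy inference |
| **Trigger** | Triggered automatically within the diagnostic analysis flow |
| **Priority** | P1 |
**Analysis dimensions**:
| Dimension | Description |
|------|------|
| `citation_share` | Competitors' share of citations |
| `platform_presence` | Frequency of appearance on each platform |
| `content_pattern` | Content patterns in which competitors are cited |
| `strength_weakness` | Competitors' strengths and weaknesses |
| `strategy_inference` | Inferred competitor GEO strategy |
### PerformanceTracker (performance tracking Agent)
| Attribute | Description |
|------|------|
| **Responsibility** | Continuously track the brand's GEO performance, compute KPIs, and detect anomalies |
| **Input** | Historical citation data, current query results, KPI configuration |
| **Output** | KPI results, trend analysis, anomaly alerts, optimization suggestions |
| **Core capabilities** | Time-series analysis, trend prediction, anomaly detection |
| **Trigger** | Runs automatically at the configured interval |
| **Priority** | P1 |
**Tracked metrics**:
| Metric | Calculation | Update frequency |
|------|----------|----------|
| `citation_rate` | Cited queries / total queries | Daily |
| `avg_confidence` | Mean confidence of citation records | Per query |
| `platform_coverage` | Covered platforms / target platforms | Weekly |
| `ranking_position` | Brand's average rank among citations | Weekly |
| `competitor_gap` | Brand citation rate - average competitor citation rate | Weekly |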
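
The ratio-style metrics in the table translate directly into code. The sketch below assumes a hypothetical `QueryRecord` shape and returns `0.0` for empty inputs; both choices are illustrative, not taken from the project.

```python
from dataclasses import dataclass
from statistics import mean
from typing import Optional, Sequence


@dataclass
class QueryRecord:
    """Hypothetical record of one query result."""
    brand_cited: bool
    confidence: Optional[float] = None  # only set when the brand was cited


def citation_rate(records: Sequence[QueryRecord]) -> float:
    """Cited queries / total queries."""
    if not records:
        return 0.0
    return sum(r.brand_cited for r in records) / len(records)


def avg_confidence(records: Sequence[QueryRecord]) -> float:
    """Mean confidence over records that carry a citation."""
    scores = [r.confidence for r in records if r.brand_cited and r.confidence is not None]
    return mean(scores) if scores else 0.0


def platform_coverage(covered: set[str], targets: set[str]) -> float:
    """Covered platforms / target platforms."""
    return len(covered & targets) / len(targets) if targets else 0.0


def competitor_gap(brand_rate: float, competitor_rates: Sequence[float]) -> float:
    """Brand citation rate minus the average competitor citation rate."""
    return brand_rate - (mean(competitor_rates) if competitor_rates else 0.0)
```

A negative `competitor_gap` means competitors are, on average, cited more often than the brand.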
## Communication Protocol
### Inter-Agent Communication Model
Agents on the GEO platform communicate through an **asynchronous protocol based on Redis message queues**:
```
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│ Agent A │◀────▶│ Redis │◀────▶│ Agent B │
│ (Producer) │ │ Queue │ │ (Consumer) │
└─────────────┘ └──────┬──────┘ └─────────────┘
┌──────▼──────┐
│ Registry │
│ (Hash Map) │
└─────────────┘
```
### Message Format
```python
from datetime import datetime


class AgentMessage:
    """Standard message format for inter-Agent communication"""
    message_id: str  # UUID, unique message identifier
    timestamp: datetime  # Timestamp in ISO 8601 format
    sender: str  # Name of the sending Agent
    recipient: str  # Name of the receiving Agent, or "broadcast"
    message_type: str  # Message type
    correlation_id: str  # Correlation ID, used to trace the request-response chain
    reply_to: str  # Agent to reply to
    payload: dict  # Message payload
    ttl: int  # Message time-to-live (seconds), default 300
```
**Message types (message_type)**:
| Type | Description | When used |
|------|------|----------|
| `task_request` | Task request | Agent A asks Agent B to run a task |
| `task_result` | Task result | Agent B returns the task's result |
| `status_update` | Status update | An Agent reports a change in its own status |
| `heartbeat` | Heartbeat | An Agent sends a periodic health signal |
| `config_update` | Configuration update | Notify Agents of a configuration change |
| `error_report` | Error report | An Agent reports an execution error |
### Message Payload (payload)
```python
from datetime import datetime


class TaskPayload:
    """Payload for task requests/results"""
    task_id: str  # Unique task identifier
    task_type: str  # Task type
    parameters: dict  # Task parameters
    data: dict  # Input/output data
    priority: int  # Priority 1-10, default 5
    deadline: datetime  # Deadline (optional)
    retry_count: int  # Number of retries so far
    max_retries: int  # Maximum number of retries
```
### Communication Patterns
#### 1. Point-to-Point
```
Agent A ──task_request──▶ Queue ──▶ Agent B
Agent A ◀──task_result─── Queue ◀── Agent B
```
- Each Agent has its own dedicated input queue
- Tasks are assigned round-robin
- Suited to scenarios with an explicit task assignment
#### 2. Publish/Subscribe (Pub/Sub)
```
Agent A ──config_update──▶ Topic
┌───────────────┼───────────────┐
▼ ▼ ▼
Agent B Agent C Agent D
```
- Uses Redis Pub/Sub channels
- Suited to configuration changes and status broadcasts
#### 3. Workflow
```
┌────────────┐ ┌────────────┐ ┌────────────┐
│ QueryTask │───▶│ Citation │───▶│ Report │
│ Completed │ │ Detector │ │ Generator │
└────────────┘ └────────────┘ └────────────┘
```
- Event-driven workflow orchestration
- Uses correlation_id to tie the whole workflow together
- Supports parallel and serial execution
## Registration
### Agent Registration Flow
```
Agent starts
Read configuration (capability declaration, resource requirements, version info)
Connect to Redis
Register with the Registry
Create a dedicated input queue
Start the heartbeat timer
Start listening for tasks
```
### Registry Data Structure
```python
# Redis Hash: agent:registry:{agent_name}
{
    "name": "CitationDetector",
    "version": "1.0.0",
    "capabilities": "[\"citation_detect\", \"entity_recognition\"]",
    "status": "online",  # online / busy / offline
    "queue_name": "agent:citation_detector",
    "registered_at": "2026-01-01T00:00:00Z",
    "last_heartbeat": "2026-01-01T00:01:00Z",
    "current_task": "task_123",  # ID of the task being executed; null when idle
    "concurrency": 5,  # Maximum concurrency
    "load": 0.6  # Current load, 0-1
}
```
### Service Discovery
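```python
import json

import redis

# Assumes decode_responses=True so hash fields come back as str.
client = redis.Redis(decode_responses=True)
# Fetch all online Agents. HGETALL does not accept wildcards, so scan for
# the registry keys and read each hash individually.
agents = [client.hgetall(key) for key in client.scan_iter("agent:registry:*")]
online_agents = [a for a in agents if a.get("status") == "online"]
# Filter Agents by capability (stored as a JSON-encoded list)
capable_agents = [a for a in online_agents if "citation_detect" in json.loads(a["capabilities"])]
# Pick the Agent with the lowest load (raises ValueError if none qualify)
selected = min(capable_agents, key=lambda a: float(a["load"]))
```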
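## Task Dispatch
### Dispatch Strategies
| Strategy | Description | When to use |
|------|------|----------|
| `round_robin` | Rotate through Agents | Load balancing across Agents of the same type |
| `least_load` | Lowest load first | When Agents differ in performance |
| `capability_match` | Capability matching | Tasks that need a specific capability |
| `priority_queue` | Priority queue | Tasks with clear priority differences |
| `sticky` | Sticky assignment | Tasks that need context continuity |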
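
Three of the strategies above are small enough to sketch in a few lines. `AgentInfo`, `capability_match`, `least_load`, and `RoundRobin` are hypothetical names; the sketch works on an in-memory list rather than the Redis-backed Registry.

```python
from dataclasses import dataclass
from itertools import count
from typing import Sequence


@dataclass(frozen=True)
class AgentInfo:
    """Hypothetical in-memory view of one Registry entry."""
    name: str
    load: float  # current load, 0-1
    capabilities: frozenset[str]


def capability_match(agents: Sequence[AgentInfo], capability: str) -> list[AgentInfo]:
    """capability_match: keep only Agents that declare the capability."""
    return [a for a in agents if capability in a.capabilities]


def least_load(agents: Sequence[AgentInfo]) -> AgentInfo:
    """least_load: pick the Agent with the lowest current load."""
    return min(agents, key=lambda a: a.load)


class RoundRobin:
    """round_robin: hand out Agents in rotation."""

    def __init__(self) -> None:
        self._counter = count()

    def pick(self, agents: Sequence[AgentInfo]) -> AgentInfo:
        return agents[next(self._counter) % len(agents)]
```

In practice the strategies compose: filter with `capability_match` first, then apply `least_load` or `RoundRobin` to the survivors.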
### Task Dispatch Flow
```
Task submitted
Parse task type ──▶ Determine required capability
Query Registry ──▶ Filter available Agents
Apply dispatch strategy ──▶ Select target Agent
Serialize message ──▶ Write to target queue
Update task status ──▶ Wait for result
```
### Task State Machine
```
PENDING ──submit task (enqueue)──▶ QUEUED ──start execution──▶ RUNNING
RUNNING ──execution succeeded──▶ COMPLETED
RUNNING ──execution failed──▶ FAILED
FAILED ──retry count < max retries──▶ RETRY ──re-enqueue──▶ QUEUED
```
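The state machine above can be enforced with a small transition table. This is a sketch under assumed names (`TaskState`, `TRANSITIONS`, `advance`), not the framework's actual implementation.

```python
from enum import Enum


class TaskState(str, Enum):
    PENDING = "pending"
    QUEUED = "queued"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    RETRY = "retry"


# Allowed moves, read straight off the diagram.
TRANSITIONS: dict[TaskState, set[TaskState]] = {
    TaskState.PENDING: {TaskState.QUEUED},
    TaskState.QUEUED: {TaskState.RUNNING},
    TaskState.RUNNING: {TaskState.COMPLETED, TaskState.FAILED},
    TaskState.FAILED: {TaskState.RETRY},
    TaskState.RETRY: {TaskState.QUEUED},
    TaskState.COMPLETED: set(),
}


def advance(current: TaskState, target: TaskState,
            retry_count: int = 0, max_retries: int = 3) -> TaskState:
    """Return target if the move is legal, otherwise raise ValueError."""
    if target not in TRANSITIONS[current]:
        raise ValueError(f"illegal transition {current.value} -> {target.value}")
    # FAILED may only move to RETRY while retries remain.
    if target is TaskState.RETRY and retry_count >= max_retries:
        raise ValueError("retry budget exhausted")
    return target
```

Keeping the table as data makes it easy to reject impossible updates, such as a task jumping from `PENDING` straight to `RUNNING`.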
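## Status Reporting
### Reported Fields
Each Agent must periodically report the following status information:
| Field | Description | Reporting frequency |
|------|------|----------|
| `status` | Current status | On every change + heartbeat |
| `current_task` | Task currently being executed | On every change |
| `load` | Current load, 0-1 | On heartbeat |
| `queue_depth` | Number of pending tasks | On heartbeat |
| `processed_count` | Total tasks processed | On heartbeat |
| `error_count` | Number of errors | On heartbeat |
| `avg_process_time` | Average processing time | On heartbeat |
### Heartbeat
```python
# Heartbeat interval: 30 seconds
HEARTBEAT_INTERVAL = 30
# Heartbeat timeout: 90 seconds (3 heartbeat periods)
HEARTBEAT_TIMEOUT = 90
# Heartbeat message format
heartbeat_message = {
    "message_type": "heartbeat",
    "sender": "CitationDetector",
    "timestamp": "2026-01-01T00:01:00Z",
    "payload": {
        "status": "online",
        "load": 0.4,
        "queue_depth": 2,
        "processed_count": 150,
        "error_count": 3,
        "avg_process_time": 2.5
    }
}
```
### Status Monitoring
- The Registry periodically checks Agents for heartbeat timeouts
- Timed-out Agents are automatically marked `offline`
- Tasks they were executing are automatically redistributed
- Administrators can view Agent status in the Dashboard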
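
The timeout check the Registry performs can be sketched as below, using the 90-second threshold from the heartbeat settings. `find_timed_out` and the in-memory mapping of Agent name to last heartbeat are assumptions for illustration; the real Registry would read these values from its store.

```python
from datetime import datetime, timedelta, timezone

HEARTBEAT_TIMEOUT = timedelta(seconds=90)  # 3 heartbeat periods of 30 s


def find_timed_out(last_heartbeats: dict[str, datetime], now: datetime) -> list[str]:
    """Return the names of Agents whose last heartbeat is older than the timeout."""
    return sorted(
        name
        for name, seen_at in last_heartbeats.items()
        if now - seen_at > HEARTBEAT_TIMEOUT
    )


if __name__ == "__main__":
    now = datetime(2026, 1, 1, 0, 5, tzinfo=timezone.utc)
    beats = {
        "CitationDetector": now - timedelta(seconds=30),
        "ContentGenerator": now - timedelta(seconds=120),
    }
    print(find_timed_out(beats, now))
```

Each name returned would then be marked `offline` and have its in-flight tasks redistributed.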
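## Configuration Management
### Configuration Layers
| Layer | Description | Precedence |
|------|------|--------|
| `system_default` | System default configuration | Lowest |
| `tenant_config` | Tenant-level configuration (overrides defaults) | Medium |
| `agent_override` | Agent-local configuration (overrides tenant) | Highest |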
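
Resolving the three layers amounts to merging dictionaries from lowest to highest precedence. `resolve_config` below is a hypothetical helper, and the shallow merge is an assumption; nested settings would need a deeper merge.

```python
from typing import Any, Mapping, Optional


def resolve_config(
    system_default: Mapping[str, Any],
    tenant_config: Optional[Mapping[str, Any]] = None,
    agent_override: Optional[Mapping[str, Any]] = None,
) -> dict[str, Any]:
    """Shallow-merge the layers; later layers win on conflicting keys."""
    merged = dict(system_default)        # lowest precedence
    merged.update(tenant_config or {})   # overrides the defaults
    merged.update(agent_override or {})  # highest precedence
    return merged
```

For example, a tenant that shortens `task_timeout` keeps that value unless an individual Agent overrides it again.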
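### Configuration Options
```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class AgentConfig:
    """Agent configuration options"""
    # General settings
    max_concurrency: int = 5  # Maximum concurrency
    task_timeout: int = 300  # Task timeout (seconds)
    retry_policy: str = "exponential"  # Retry policy
    max_retries: int = 3  # Maximum number of retries
    # CitationDetector only
    confidence_threshold: float = 0.7  # Confidence threshold
    enable_semantic_analysis: bool = True
    # ContentGenerator only
    default_content_length: int = 1500
    # default_factory gives each instance its own list
    content_styles: List[str] = field(
        default_factory=lambda: ["professional", "casual", "technical"]
    )
    # RuleChecker only
    strict_mode: bool = False
    custom_rules: List[dict] = field(default_factory=list)
```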
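### Dynamic Configuration Updates
```
Administrator updates configuration
Write to the Redis configuration store
Publish a config_update message
Each Agent receives the message
Agents hot-reload the configuration (no restart needed)
```
## Management Requirements
### Agent Dashboard
Administrators need a Dashboard to monitor and manage all Agents:
| Feature | Description | Priority |
|------|------|--------|
| Agent list | View the status of every registered Agent | P0 |
| Real-time status board | Show each Agent's online status, load, and task count | P0 |
| Task queue monitoring | View the number of pending tasks in each queue | P0 |
| Task execution logs | View detailed logs of task execution | P1 |
| Error alerts | Send an alert when an Agent misbehaves | P1 |
| Configuration management | Edit Agent configuration online | P1 |
| Agent start/stop | Start or stop an Agent remotely | P2 |
| Version management | View and upgrade Agent versions | P2 |
| Performance statistics | View Agent performance metrics and trends | P2 |
### Agent Log Format
```python
# Log format
{
    "timestamp": "2026-01-01T00:00:00Z",
    "level": "INFO",  # DEBUG / INFO / WARNING / ERROR
    "agent": "CitationDetector",
    "task_id": "task_123",
    "message": "Starting citation detection",
    "metadata": {
        "query_id": "q_456",
        "brand": "ExampleBrand"
    }
}
```
### Error Handling
| Error type | Handling | Alert level |
|----------|----------|----------|
| Task execution failure | Retry automatically; mark as failed once max retries is reached | Warning |
| Agent heartbeat timeout | Mark offline and redistribute tasks | Error |
| Queue backlog | Trigger elastic scale-out or an alert | Warning |
| Agent crash | Restart automatically (with supervisor) | Critical |
| Configuration error | Refuse to load; keep the last valid configuration | Error |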
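
The "retry automatically, then mark as failed" rule, combined with the `exponential` retry policy from the configuration options, could look like the sketch below. The base delay, growth factor, and cap are assumptions (the document only names the policy), and `run_with_retry` is a hypothetical helper.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def run_with_retry(
    task: Callable[[], T],
    max_retries: int = 3,
    base: float = 1.0,    # assumed first delay, in seconds
    factor: float = 2.0,  # assumed growth factor
    cap: float = 60.0,    # assumed upper bound on any single delay
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Run task, retrying with exponential backoff; re-raise once retries run out."""
    attempt = 0
    while True:
        try:
            return task()
        except Exception:
            if attempt >= max_retries:
                raise  # caller marks the task as FAILED
            sleep(min(cap, base * factor ** attempt))
            attempt += 1
```

Passing `sleep` as a parameter keeps the helper testable without real waiting.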
---
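*This document defines the AI Agent framework design for the GEO platform; each Agent's concrete algorithms and model choices are described in detail in the module guides.*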


@ -1,27 +0,0 @@
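# GEO Platform - Component Library Design
## Overview
This document defines the design of the GEO platform's UI component library, which extends and customizes shadcn/ui.
> **TODO**: This document is a placeholder; full content is still to be added.
## Content To Be Added
- [ ] Component library architecture
- [ ] Base component list (Button, Input, Card, etc.)
- [ ] Business component list (QueryCard, CitationBadge, ReportPreview, etc.)
- [ ] Component Props interface definitions
- [ ] Component usage examples
- [ ] Component theming approach
- [ ] Component development guidelines
- [ ] Component documentation generation (Storybook)
## References
- [UI Style Guide](./ui-style-guide.md)
- [Frontend System Architecture](../.qoder/repowiki/zh/content/前端系统架构/UI组件库.md)
---
*This document is still to be completed.*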


@ -1,31 +0,0 @@
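# GEO Platform - Database Schema Design
## Overview
This document defines the complete database schema for the GEO platform, including table structures, field definitions, index design, and relationship diagrams.
> **TODO**: This document is a placeholder; full content is still to be added.
## Content To Be Added
- [ ] Entity-relationship diagram (ER Diagram)
- [ ] User tables (users, user_profiles, teams, etc.)
- [ ] Query tables (queries, query_tasks, query_results, etc.)
- [ ] Citation detection tables (citation_records, citation_analysis, etc.)
- [ ] Subscription tables (subscriptions, plans, payments, etc.)
- [ ] Content tables (contents, content_versions, rules, etc.)
- [ ] Report tables (reports, report_templates, etc.)
- [ ] Channel tables (channels, channel_configs, etc.)
- [ ] Agency operations tables (clients, projects, etc.)
- [ ] Index design notes
- [ ] Data migration strategy
## References
- [Data Model Design](../.qoder/repowiki/zh/content/后端系统架构/数据模型设计.md)
- [Database Design](../.qoder/repowiki/zh/content/数据库设计/数据库设计.md)
- [Table Structure Design](../.qoder/repowiki/zh/content/数据库设计/表结构设计.md)
---
*This document is still to be completed.*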


@ -1,443 +0,0 @@
# GEO Platform - UI Style Guide
## Overview
This document defines the front-end UI style conventions for the GEO platform to keep the whole platform visually consistent. All front-end developers, UI designers, and product managers should follow this guide when designing and building interfaces.
The GEO platform's UI design follows these core principles:
- **Professional and trustworthy**: convey the professionalism and credibility of an enterprise SaaS platform
- **Simple and efficient**: reduce visual noise so users can focus on data and actions
- **Clear hierarchy**: guide attention with a clear information hierarchy
- **Responsive**: adapt to desktop, tablet, and mobile devices
## Style Summary
### Design Language
| Trait | Description | Where used |
|------|------|----------|
| Flat design | Remove redundant decoration; emphasize content and function | Global |
| Card-based layout | Organize information modules in card containers | Dashboard, list pages |
| Breathing room | Generous padding and margins | Global |
| Subtle depth | Distinguish layers with shadows and borders | Cards, dialogs, dropdowns |
| Data visualization | Rich charts and data displays | Dashboard, report pages |
| Status color coding | Use color to express status intuitively | Status badges, progress bars |
### Overall Mood
- **Primary color**: deep blue/indigo family, conveying professionalism, technology, and trust
- **Accent colors**: emerald green for success/positive, amber for warning, red for error
- **Neutrals**: a rich gray scale for text and borders
- **Background**: light backgrounds by default, with dark mode as an option
## Design Tokens
### Color System
#### Primary
| Token | Value | Usage |
|-------|------|------|
| `--color-primary-50` | `#EEF2FF` | Very light background, hover state |
| `--color-primary-100` | `#E0E7FF` | Light background, selected state |
| `--color-primary-200` | `#C7D2FE` | Highlighted border |
| `--color-primary-300` | `#A5B4FC` | Primary color in disabled state |
| `--color-primary-400` | `#818CF8` | Secondary emphasis |
| `--color-primary-500` | `#6366F1` | **Primary** - buttons, links, icons |
| `--color-primary-600` | `#4F46E5` | Button hover, darker tone |
| `--color-primary-700` | `#4338CA` | Button active |
| `--color-primary-800` | `#3730A3` | Text on dark backgrounds |
| `--color-primary-900` | `#312E81` | Headings, dark elements |
#### Semantic
| Token | Value | Usage |
|-------|------|------|
| `--color-success-50` | `#F0FDF4` | Light background for success state |
| `--color-success-500` | `#22C55E` | **Success** - success messages, positive metrics |
| `--color-success-600` | `#16A34A` | Success hover |
| `--color-warning-50` | `#FFFBEB` | Light background for warning state |
| `--color-warning-500` | `#F59E0B` | **Warning** - warning messages, pending items |
| `--color-warning-600` | `#D97706` | Warning hover |
| `--color-error-50` | `#FEF2F2` | Light background for error state |
| `--color-error-500` | `#EF4444` | **Error** - error messages, failed state |
| `--color-error-600` | `#DC2626` | Error hover |
| `--color-info-50` | `#EFF6FF` | Light background for info state |
| `--color-info-500` | `#3B82F6` | **Info** - informational messages |
#### Neutral
| Token | Value | Usage |
|-------|------|------|
| `--color-gray-50` | `#F9FAFB` | Page background, hover background |
| `--color-gray-100` | `#F3F4F6` | Card background, divider background |
| `--color-gray-200` | `#E5E7EB` | Borders, divider lines |
| `--color-gray-300` | `#D1D5DB` | Disabled borders, placeholders |
| `--color-gray-400` | `#9CA3AF` | Secondary text, icons |
| `--color-gray-500` | `#6B7280` | Supporting text |
| `--color-gray-600` | `#4B5563` | Body text |
| `--color-gray-700` | `#374151` | Heading text |
| `--color-gray-800` | `#1F2937` | Dark headings |
| `--color-gray-900` | `#111827` | Main headings, dark text |
### Border Radius
| Token | Value | Usage |
|-------|-----|------|
| `--radius-none` | `0` | Square corners (tables, special cases) |
| `--radius-sm` | `4px` | Small elements (tags, badges) |
| `--radius-md` | `6px` | Default (buttons, inputs) |
| `--radius-lg` | `8px` | Cards, containers |
| `--radius-xl` | `12px` | Large cards, modals |
| `--radius-2xl` | `16px` | Page-level containers |
| `--radius-full` | `9999px` | Circles (avatars, status dots) |
### Shadow
| Token | Value | Usage |
|-------|-----|------|
| `--shadow-sm` | `0 1px 2px 0 rgba(0,0,0,0.05)` | Slight elevation (button hover) |
| `--shadow-md` | `0 4px 6px -1px rgba(0,0,0,0.1), 0 2px 4px -2px rgba(0,0,0,0.1)` | Card default |
| `--shadow-lg` | `0 10px 15px -3px rgba(0,0,0,0.1), 0 4px 6px -4px rgba(0,0,0,0.1)` | Dropdown menus, floating layers |
| `--shadow-xl` | `0 20px 25px -5px rgba(0,0,0,0.1), 0 8px 10px -6px rgba(0,0,0,0.1)` | Modals, drawers |
| `--shadow-inner` | `inset 0 2px 4px 0 rgba(0,0,0,0.05)` | Inset effect |
### Typography
#### Font Families
| Token | Value | Usage |
|-------|-----|------|
| `--font-sans` | `Inter, "PingFang SC", "Microsoft YaHei", sans-serif` | Body text, UI text |
| `--font-mono` | `"JetBrains Mono", "Fira Code", monospace` | Code, data values |
#### Type Scale
| Token | Size | Weight | Line height | Usage |
|-------|------|------|------|------|
| `--text-xs` | `12px` | 400 | `16px` | Helper text, timestamps |
| `--text-sm` | `14px` | 400 | `20px` | Secondary text, descriptions |
| `--text-base` | `16px` | 400 | `24px` | **Default body text** |
| `--text-lg` | `18px` | 500 | `28px` | Subheadings, emphasized text |
| `--text-xl` | `20px` | 600 | `28px` | Card titles |
| `--text-2xl` | `24px` | 600 | `32px` | Page subheadings |
| `--text-3xl` | `30px` | 700 | `36px` | Page titles |
| `--text-4xl` | `36px` | 700 | `40px` | Large headings, numeric displays |
### Spacing
| Token | Value | Usage |
|-------|-----|------|
| `--space-1` | `4px` | Minimal spacing |
| `--space-2` | `8px` | Compact spacing (icon and text) |
| `--space-3` | `12px` | Small spacing |
| `--space-4` | `16px` | **Default spacing** (base padding) |
| `--space-5` | `20px` | Medium spacing |
| `--space-6` | `24px` | Card padding |
| `--space-8` | `32px` | Section spacing |
| `--space-10` | `40px` | Large section spacing |
| `--space-12` | `48px` | Page-level spacing |
| `--space-16` | `64px` | Large page spacing |
## Component Customization Plan
The GEO platform is built on the shadcn/ui component library; the following are the customization specs for the core components.
### Button
| Variant | Background | Text | Border | Hover | Usage |
|------|------|------|------|-------|------|
| `default` | `primary-500` | white | none | `primary-600` | Primary action |
| `secondary` | `gray-100` | `gray-900` | none | `gray-200` | Secondary action |
| `outline` | transparent | `gray-700` | `gray-200` | `gray-50` | Supporting action |
| `ghost` | transparent | `gray-700` | none | `gray-100` | Minimal action |
| `destructive` | `error-500` | white | none | `error-600` | Dangerous action |
| `link` | transparent | `primary-500` | none | underline | Link style |
**Sizes**:
- `sm`: `h-8 px-3 text-sm`
- `md` (default): `h-10 px-4 text-base`
- `lg`: `h-12 px-6 text-lg`
### Card
```
Default style:
- Background: white
- Radius: `radius-lg` (8px)
- Shadow: `shadow-md`
- Padding: `space-6` (24px)
- Border: `1px solid gray-200`
Hover state:
- Raised shadow: `shadow-lg`
- Transition: `transition-shadow duration-200`
```
**Card variants**:
| Variant | Description |
|------|------|
| `default` | Standard card for displaying information |
| `stats` | Statistics card: large number + description |
| `interactive` | Interactive card with clear hover feedback |
| `flat` | Flat card without shadow, used for list items |
### Input
```
Default style:
- Height: `h-10`
- Radius: `radius-md` (6px)
- Border: `1px solid gray-300`
- Padding: `px-3`
- Font: `text-base`
- Background: white
Focus state:
- Border: `primary-500`
- Shadow: `0 0 0 2px primary-100`
- Transition: `transition-all duration-200`
Error state:
- Border: `error-500`
- Shadow: `0 0 0 2px error-50`
```
### Badge
| Variant | Background | Text | Usage |
|------|------|------|------|
| `default` | `primary-100` | `primary-700` | Default badge |
| `secondary` | `gray-100` | `gray-700` | Secondary badge |
| `success` | `success-50` | `success-600` | Success state |
| `warning` | `warning-50` | `warning-600` | Warning state |
| `error` | `error-50` | `error-600` | Error state |
| `info` | `info-50` | `info-600` | Info state |
### Table
```
Default style:
- Header background: `gray-50`
- Header text: `gray-700` font-medium text-sm
- Row height: `h-12`
- Row border: `border-b border-gray-200`
- Row hover: `bg-gray-50`
- Cell padding: `px-4 py-3`
- Empty state: centered illustration + hint text
```
**Special row styles**:
| State | Style |
|------|------|
| Selected row | `bg-primary-50` |
| Disabled row | `opacity-50` |
| Highlighted row | `bg-warning-50` |
### Modal / Dialog
```
Default style:
- Overlay: `bg-black/50`
- Container background: white
- Radius: `radius-xl` (12px)
- Shadow: `shadow-xl`
- Max width: `max-w-lg` (512px)
- Padding: `p-6`
- Animation: fade-in + scale-in
Header:
- Title: `text-lg font-semibold`
- Close button: ghost button in the top-right corner
Footer actions:
- Right-aligned
- Primary button on the right, secondary button on the left
- Spacing: `gap-3`
```
### Toast
| Type | Icon | Background | Border |
|------|------|------|------|
| `success` | CheckCircle | white | `success-200` |
| `warning` | AlertTriangle | white | `warning-200` |
| `error` | XCircle | white | `error-200` |
| `info` | Info | white | `info-200` |
```
Default style:
- Position: fixed, top-right
- Max width: `400px`
- Radius: `radius-lg`
- Shadow: `shadow-lg`
- Padding: `p-4`
- Auto-dismiss: `5 seconds`
- Animation: slide-in-right + fade-out
```
### Navigation
**Sidebar navigation**:
```
Default style:
- Width: `w-64` (256px)
- Background: `gray-900`
- Text: `gray-300`
Selected state:
- Background: `primary-600`
- Text: white
- Left border indicator: `3px solid primary-400`
Hover state:
- Background: `gray-800`
- Text: white
```
**Top navigation**:
```
Default style:
- Height: `h-16`
- Background: white
- Border: `border-b border-gray-200`
- Shadow: `shadow-sm`
```
## Interaction Guidelines
### Transitions
| Scenario | Duration | Easing | Properties |
|------|------|----------|------|
| Button hover | `150ms` | `ease-in-out` | `background-color, border-color, box-shadow` |
| Card hover | `200ms` | `ease-out` | `box-shadow, transform` |
| Modal appears | `200ms` | `cubic-bezier(0.16, 1, 0.3, 1)` | `opacity, transform` |
| Dropdown menu | `150ms` | `ease-out` | `opacity, transform` |
| Toast appears/disappears | `300ms` | `ease-in-out` | `opacity, transform` |
| Page transition | `200ms` | `ease-in-out` | `opacity` |
| Data-loading skeleton | `pulse 2s infinite` | - | `opacity` |
### Loading States
| Scenario | Style |
|------|------|
| Page loading | Full-screen skeleton + Logo |
| Data loading | Card/table skeleton |
| Button loading | Spinner inside the button, clicks disabled |
| Infinite scroll | Spinner at the bottom |
| File upload | Progress bar + percentage |
### Empty State
```
Unified empty-state style:
- Centered layout
- Illustration/icon: `primary-200` tone
- Title: `text-lg font-medium text-gray-700`
- Description: `text-sm text-gray-500`
- Action button (optional): `primary` button
```
### Error State
```
Unified error-state style:
- Error icon: AlertTriangle in `error-500`
- Title: `text-lg font-medium text-gray-900`
- Description: `text-sm text-gray-600`
- Retry button: `outline` button
- Error details (optional): expandable collapsible panel
```
### Form Validation
```
Validation message style:
- Error message: `text-xs text-error-500 mt-1`
- Error icon: `error-500` icon on the right side of the input
- Success icon: `success-500` Check icon on the right side of the input
- Live validation: triggered `300ms` after input
```
## Layout Guidelines
### Page Structure
```
┌─────────────────────────────────────────────┐
│ Top navigation bar (h-16, fixed) │
├──────────┬──────────────────────────────────┤
│ │ Page title area │
│ Sidebar │ ┌────────────────────────────┐ │
│ (w-64) │ │ Content area │ │
│ fixed │ │ ┌──────┐ ┌──────┐ ┌──────┐ │ │
│ │ │ │ Card │ │ Card │ │ Card │ │ │
│ │ │ └──────┘ └──────┘ └──────┘ │ │
│ │ │ ┌────────────────────────┐ │ │
│ │ │ │ Table/List │ │ │
│ │ │ └────────────────────────┘ │ │
│ │ └────────────────────────────┘ │
└──────────┴──────────────────────────────────┘
```
### Responsive Breakpoints
| Breakpoint | Width | Layout adjustment |
|------|------|----------|
| `sm` | `640px` | Basic mobile adaptation |
| `md` | `768px` | Sidebar becomes collapsible |
| `lg` | `1024px` | Full desktop layout |
| `xl` | `1280px` | Widescreen extension |
| `2xl` | `1536px` | Ultra-wide screens |
### Grid System
- Uses a mix of CSS Grid and Flexbox
- Dashboard card grid: `grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-4`
- Card gap: `gap-6`
- Page max width: `max-w-7xl` (1280px), centered
## Chart Guidelines
The GEO platform makes heavy use of data visualization charts, implemented with Recharts.
### Color Palette
```
Chart primary palette:
- Series 1: `#6366F1` (primary-500)
- Series 2: `#22C55E` (success-500)
- Series 3: `#F59E0B` (warning-500)
- Series 4: `#EF4444` (error-500)
- Series 5: `#8B5CF6` (violet-500)
- Series 6: `#06B6D4` (cyan-500)
```
### Chart Styles
```
General style:
- Background: transparent
- Grid lines: `gray-200`, dashed
- Axis text: `gray-500`, text-xs
- Legend: bottom, centered
- Tooltip: white background, rounded corners, shadow
Bar chart:
- Radius: `radius-sm` on top
- Gap: `barGap=4`
Line chart:
- Line width: `2px`
- Data points: `r=4`, hover `r=6`
- Fill area: `fillOpacity=0.1`
Pie/donut chart:
- Inner radius: `60%`
- Labels: external leader lines + percentage
```
---
*This document defines the UI style conventions for the GEO platform; all front-end implementations should strictly follow the Design Tokens and component specs above.*


@ -0,0 +1,24 @@
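# Module Guide
This directory contains detailed descriptions of the GEO platform's core modules.
## Contents
- [Diagnosis module](./diagnosis.md) - SEO diagnosis and GEO diagnosis ✅ Done
- [Agent framework](./agent-framework.md) - AI Agent framework design
- [Content Pipeline](./content-pipeline.md) - Content generation pipeline
- [Knowledge base](./knowledge-base.md) - RAG knowledge base system
- [Knowledge graph](./knowledge-graph.md) - Knowledge graph construction
- [Platform rules](./platform-rules.md) - Platform rules center
- [Monitoring module](./monitoring.md) - System monitoring
- [Topic template library](./topic-templates.md) - Topic template management
## Diagnosis Module API Endpoints
| Endpoint | Description | Status |
|------|------|------|
| `GET /api/v1/diagnosis/seo/{brand_id}` | Get the brand's SEO diagnosis result | ✅ Done |
| `GET /api/v1/diagnosis/geo/{brand_id}` | Get the brand's GEO diagnosis result | ✅ Done |
| `GET /api/v1/diagnosis/combined/{brand_id}` | Get the brand's combined SEO+GEO diagnosis result | ✅ Done |
For details, see the [diagnosis module documentation](./diagnosis.md).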


@ -0,0 +1,167 @@
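# AI Agent Framework
## Overview
The GEO platform's AI Agent framework is the system's core intelligence layer. It uses a modular, pluggable design and supports dynamic Agent registration, task dispatch, and state management.
## Architecture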
```
┌─────────────────────────────────────────────────────────────┐
│ Agent Framework │
├─────────────────────────────────────────────────────────────┤
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────┐ │
│ │ Dispatcher │→│ Registry │→│ PipelineEngine │ │
│ └─────────────┘ └─────────────┘ └─────────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────┐ │
│ │ Agents │ │
│ ├───────────┬───────────┬───────────┬────────────┤ │
│ │ Citation │ Content │ DeAI │ GEO │ │
│ │ Detector │ Generator │ Agent │ Optimizer │ │
│ └───────────┴───────────┴───────────┴────────────┘ │
└─────────────────────────────────────────────────────────────┘
```
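## Core Components
### Dispatcher
Receives task requests and dispatches them to a suitable Agent.
- Location: `backend/app/agent_framework/dispatcher.py`
- Responsibilities: task routing, load balancing, priority handling, result aggregation
### Registry
Manages registration information and status for all Agents.
- Location: `backend/app/agent_framework/registry.py`
- Storage: PostgreSQL database (AgentRegistryModel)
- Functions: Agent registration, health checks, service discovery, heartbeat updates
### PipelineEngine (orchestration engine)
Orchestrates the execution of multi-stage Agent task chains.
- Location: `backend/app/agent_framework/pipeline/engine.py`
- Responsibilities: DAG execution, variable passing, timeout control, conditional execution
## Agent List
### CitationDetector (citation detection Agent)
| Attribute | Description |
|------|------|
| File | `backend/app/agent_framework/agents/citation_detector.py` |
| Responsibility | Parse AI platform responses and identify how the brand is cited |
| Input | Raw AI platform response, brand name, competitor names |
| Output | Citation detection result (cited / not cited / competitor cited) |
**Citation types**:
- `direct_quote` - the brand name is mentioned directly
- `indirect_reference` - brand information is mentioned
- `no_reference` - the brand is not mentioned
- `competitor_reference` - a competitor is cited
### ContentGenerator (content generation Agent)
| Attribute | Description |
|------|------|
| File | `backend/app/agent_framework/agents/content_generator_agent.py` |
| Responsibility | Generate GEO-optimized content assets |
| Input | Content topic, rule library configuration, brand assets |
| Output | GEO-optimized content (articles/Q&A/guides, etc.) |
**Supported content types**:
- `article` - long-form article (800-2000 characters)
- `qa` - Q&A pair
- `guide` - how-to guide
- `comparison` - comparative analysis
- `data_report` - data report
### DeAI Agent (de-AI Agent)
| Attribute | Description |
|------|------|
| File | `backend/app/agent_framework/agents/deai_agent.py` |
| Responsibility | Remove traces of AI generation from content |
| Input | AI-generated content |
| Output | Naturalized content |
### GEOOptimizer (GEO optimization Agent)
| Attribute | Description |
|------|------|
| File | `backend/app/agent_framework/agents/geo_optimizer_agent.py` |
| Responsibility | Apply GEO optimization to content |
| Input | Original content, keyword strategy |
| Output | Optimized content |
## Communication Protocol
Asynchronous communication based on the database plus Redis Queue:
```python
from datetime import datetime


class TaskMessage:
    task_id: str  # UUID
    agent_name: str  # Target Agent name
    task_type: str  # Task type
    priority: int  # Priority (0-9)
    input_data: dict  # Input parameters
    callback_url: str  # Callback URL
    created_at: datetime  # Creation time
    timeout_seconds: int  # Timeout
```
### Communication Flow
```
1. External request → Dispatcher.dispatch(TaskMessage) → Redis Queue
2. Redis Queue → Agent consumes → executes the task
3. Agent finishes → TaskResult → Dispatcher.handle_result
4. Dispatcher updates the database → triggers the callback (if any)
```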
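The four-step flow above can be simulated in-process to see how the pieces fit. The sketch uses `asyncio.Queue` as a stand-in for Redis Queue and a plain dict as a stand-in for the database; the class bodies and the `agent_worker` and `demo` names are assumptions for illustration, not the project's `dispatcher.py`.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class TaskMessage:
    task_id: str
    agent_name: str
    task_type: str
    input_data: dict = field(default_factory=dict)


@dataclass
class TaskResult:
    task_id: str
    agent_name: str
    status: str
    output_data: dict = field(default_factory=dict)


class Dispatcher:
    def __init__(self) -> None:
        self.queue: asyncio.Queue = asyncio.Queue()  # stand-in for Redis Queue
        self.task_status: dict[str, str] = {}  # stand-in for the database

    async def dispatch(self, message: TaskMessage) -> None:
        """Step 1: record the task and push it onto the queue."""
        self.task_status[message.task_id] = "pending"
        await self.queue.put(message)

    def handle_result(self, result: TaskResult) -> None:
        """Steps 3-4: persist the final status (callback omitted)."""
        self.task_status[result.task_id] = result.status


async def agent_worker(dispatcher: Dispatcher) -> None:
    """Step 2: consume one message, run the task, report the result."""
    message = await dispatcher.queue.get()
    dispatcher.task_status[message.task_id] = "running"
    output = {"echo": message.input_data}  # placeholder for real Agent work
    dispatcher.handle_result(
        TaskResult(message.task_id, message.agent_name, "completed", output)
    )


async def demo() -> dict[str, str]:
    dispatcher = Dispatcher()
    await dispatcher.dispatch(
        TaskMessage("t-1", "CitationDetector", "citation_detect", {"brand": "ExampleBrand"})
    )
    await agent_worker(dispatcher)
    return dispatcher.task_status
```

Calling `asyncio.run(demo())` walks one task from `pending` through `running` to `completed`.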
## Registration
An Agent registers with the Registry when it starts:
```python
{
    "name": "CitationDetector",
    "version": "1.0.0",
    "status": "online",  # online/busy/offline
    "endpoint": "http://...",
    "capabilities": {...},
    "last_heartbeat": "2024-01-01T00:00:00Z"
}
```
### Heartbeat
- Heartbeat timeout threshold: 90 seconds
- Registry.check_health() runs periodically and marks timed-out Agents as OFFLINE
## Task State Machine
```
PENDING → RUNNING → COMPLETED
FAILED → RETRY → RUNNING
CANCELLED
```
## Prompt Templates
Agents use a unified Prompt template system:
| File | Description |
|------|------|
| `prompts/base_template.py` | Base template structure |
| `prompts/content_generator.py` | Content generation template |
| `prompts/deai_agent.py` | De-AI template |
| `prompts/geo_optimizer.py` | GEO optimization template |
| `prompts/rule_checker.py` | Rule checking template |
| `prompts/topic_selector.py` | Topic template selection template |


@ -0,0 +1,109 @@
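# Inter-Agent Communication Protocol
## Overview
This document describes in detail the communication protocol and data structures used between Agents in the Agent Framework.
## Message Types
### TaskMessage
Task message sent from the dispatcher to an Agent.
| Field | Type | Required | Description |
|------|------|------|------|
| task_id | str | Yes | Task ID in UUID format |
| agent_name | str | Yes | Target Agent name |
| task_type | str | Yes | Task type |
| priority | int | No | Priority (0-9, 9 is highest) |
| input_data | dict | Yes | Input parameters |
| callback_url | str | No | Callback URL |
| created_at | datetime | Yes | Creation time |
| timeout_seconds | int | No | Timeout, default 300 seconds |
### TaskResult
Result message returned from an Agent.
| Field | Type | Required | Description |
|------|------|------|------|
| task_id | str | Yes | ID of the corresponding task |
| agent_name | str | Yes | Name of the Agent that ran the task |
| status | str | Yes | Task status (completed/failed/cancelled) |
| output_data | dict | No | Output data |
| error_message | str | No | Error message |
| started_at | datetime | Yes | Start time |
| completed_at | datetime | Yes | Completion time |
| metrics | dict | No | Execution metrics (duration, token usage, etc.) |
### TaskProgress
Progress information reported by an Agent during execution.
| Field | Type | Required | Description |
|------|------|------|------|
| task_id | str | Yes | Task ID |
| agent_name | str | Yes | Agent name |
| progress | float | Yes | Progress (0.0-1.0) |
| message | str | Yes | Progress description |
| updated_at | datetime | Yes | Update time |
## Status Enums
### TaskStatus
```python
from enum import Enum


class TaskStatus(str, Enum):
    PENDING = "pending"  # Waiting to run
    RUNNING = "running"  # Running
    COMPLETED = "completed"  # Completed
    FAILED = "failed"  # Failed
    CANCELLED = "cancelled"  # Cancelled
```
### AgentStatus
```python
from enum import Enum


class AgentStatus(str, Enum):
    ONLINE = "online"  # Online
    OFFLINE = "offline"  # Offline
    BUSY = "busy"  # Busy
```
## Communication Flow
```
1. External request → Dispatcher.dispatch(TaskMessage) → Redis Queue
2. Redis Queue → Agent consumes → executes the task
3. Agent finishes → TaskResult → Dispatcher.handle_result
4. Dispatcher updates the database → triggers the callback (if any)
```
## Agent Types
| Type | Description | Responsibility |
|------|------|------|
| CITATION_DETECTOR | Citation detection Agent | Parse AI platform responses and identify brand citations |
| CONTENT_GENERATOR | Content generation Agent | Generate GEO-optimized content |
| DEAI_AGENT | De-AI Agent | Remove traces of AI generation |
| GEO_OPTIMIZER | GEO optimization Agent | SEO and GEO optimization |
| RULE_CHECKER | Rule checking Agent | Content compliance review |
| COMPETITOR_ANALYZER | Competitor analysis Agent | Collect and analyze competitor data |
| PERFORMANCE_TRACKER | Performance tracking Agent | Track content performance |
## Task Types
| Agent | task_type | Description |
|-------|-----------|------|
| CitationDetector | citation_detect | Detect brand citations |
| ContentGenerator | generate | Generate content |
| DeAIAgent | humanize | De-AI processing |
| GEOOptimizer | optimize | GEO optimization |
## Heartbeat
An Agent stays online by updating its heartbeat at regular intervals:
- Heartbeat timeout threshold: 90 seconds
- After a timeout the Agent is marked OFFLINE
- Registry.check_health() periodically checks the status of all Agents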


@ -0,0 +1,87 @@
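# Content Generation Pipeline
## Overview
The content generation Pipeline is one of the GEO platform's core features. It turns user input into final content that meets GEO optimization standards.
## Pipeline Flow
```
User input → topic template selection → content generation → de-AI processing → SEO optimization → HTML generation → output
```
## Stages
### 1. User Input
Receives the user's content generation request, including:
- Brand information
- Target platform
- Content topic
- Keyword strategy
### 2. Topic Template Selection
Selects a suitable template from the topic template library:
- Location: `backend/app/agent_framework/pipeline/loader.py`
- Function: match a topic template by content type and industry
### 3. Content Generation
Calls the ContentGenerator Agent to produce a first draft:
- Generates content with an LLM
- Applies the brand style guide
- Respects rule library constraints
### 4. De-AI Processing
Processes the content with the DeAI Agent:
- Rewrites mechanical sentence patterns
- Increases linguistic variety
- Preserves semantic consistency
### 5. SEO Optimization
Optimizes with the GEOOptimizer Agent:
- Keyword density adjustment
- Semantic relevance improvement
- Structured data addition
### 6. HTML Generation
Converts the content to HTML:
- Responsive design
- SEO-friendly tag structure
- Platform adaptation
## Pipeline Engine
Location: `backend/app/agent_framework/pipeline/engine.py`
```python
class PipelineEngine:
    """Content generation Pipeline engine"""

    async def run(self, request: ContentRequest) -> ContentResult:
        # 1. Load the topic template
        template = await self.loader.load(request.template_id)
        # 2. Generate content
        draft = await self.generator.generate(request, template)
        # 3. De-AI processing
        naturalized = await self.deai.process(draft)
        # 4. SEO optimization
        optimized = await self.optimizer.optimize(naturalized)
        # 5. HTML generation
        html = await self.html_generator.generate(optimized)
        return ContentResult(html=html, metadata={})
```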


@ -0,0 +1,431 @@
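# Diagnosis Module
## Overview
The diagnosis module is a core feature of the GEO platform. It provides two capabilities, **traditional SEO diagnosis** and **GEO diagnosis**, to help users fully understand the search visibility of their website and brand in the AI era.
## SEO Diagnosis vs GEO Diagnosis
### Key Differences
| Dimension | SEO diagnosis | GEO diagnosis |
|------|---------|---------|
| **Optimization goal** | Page ranking | Brand being cited by AI |
| **Diagnosis target** | Website | Brand entity + content |
| **Success metrics** | Rankings, traffic, click-through rate | Citation frequency, AI share of voice |
| **User path** | Click a link to visit the website | AI recommends the brand directly |
| **Time to results** | 3-6 months | 2 weeks to 1 month |
---
## SEO Diagnosis
Traditional SEO diagnosis evaluates how well a website is optimized for search engines and covers the following dimensions:
### 1. Technical SEO Diagnosis
| Check | Description | Tool/method |
|--------|------|----------|
| Index status | Whether the site is indexed correctly by search engines | site:domain.com, Search Console |
| Crawl errors | 404, 5xx, redirect chains | Crawler tools |
| Core Web Vitals | LCP < 2.5s, FID < 100ms, CLS < 0.1 | PageSpeed Insights |
| URL structure | Canonicalization, duplicate URLs | Crawl analysis |
| robots.txt | Whether important pages are blocked | File inspection |
| sitemap | Sitemap completeness | XML validation |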
### 2. 页面SEO诊断
| 诊断项 | 说明 |
|--------|------|
| Title/Meta标签 | 是否完整、是否关键词堆砌 |
| H标签结构 | 层级是否清晰 |
| 关键词密度 | 是否合理分布 |
| 内链结构 | 是否有死链、锚文本是否相关 |
| 图片Alt | 是否添加描述性Alt文本 |
### 3. 内容质量诊断
| 诊断项 | 说明 |
|--------|------|
| 可读性 | 内容是否易于理解 |
| 信息深度 | 是否全面覆盖主题 |
| E-E-A-T | 经验、专业性、权威性、可信度 |
| 内容新鲜度 | 是否定期更新 |
| 重复内容 | 是否有大量重复页面 |
### 4. 外链分析
| 诊断项 | 说明 |
|--------|------|
| 反向链接质量 | 链接来源权威性 |
| 毒性信号 | 是否有垃圾链接 |
| 锚文本分布 | 是否自然多样 |
### 5. 用户体验诊断
| 诊断项 | 说明 |
|--------|------|
| 移动适配 | 移动端显示是否正常 |
| 页面速度 | 加载时间是否达标 |
| 转化路径 | 用户操作是否顺畅 |
---
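## GEO diagnosis
GEO diagnosis evaluates how readily a brand is cited by generative AI engines. It covers the following dimensions:
### 1. Content extractability diagnosis
AI must be able to extract and understand the content easily:
| Check | Description | Priority |
|--------|------|--------|
| Direct answer block | Whether the first paragraph gives a concise, explicit answer | P0 |
| Question-style headings | Whether H2/H3 headings are phrased as questions | P0 |
| Lists and tables | Whether information is presented in structured form | P0 |
| Links to sub-intent pages | Whether the page links to related in-depth content | P1 |
| Freshness | Whether an update date and author information are shown | P1 |
### 2. Entity clarity diagnosis
AI must be able to understand what the brand is:
| Check | Description | Acceptance criterion |
|--------|------|----------|
| Brand definition | Whether it is clear what the brand does | AI comprehension accuracy ≥ 95% |
| Target audience | Whether it is clear whom the brand serves | Entity recognition accuracy ≥ 90% |
| Differentiated value | Why a user should choose this brand | Uniqueness score ≥ 80 |
| Industry classification | Which industry the brand belongs to | Classification accuracy ≥ 95% |
### 3. E-E-A-T signal diagnosis
AI must be able to verify the brand's credibility:
| Check | Description | Acceptance criterion |
|--------|------|----------|
| Author credentials | Whether content authors have a professional background | Author bio completeness ≥ 90% |
| Professional certification | Whether industry certifications or awards exist | Certification display rate ≥ 80% |
| Data sources | Whether reliable data is cited | Citations from authoritative sources ≥ 70% |
| Expert endorsement | Whether industry experts endorse the brand | Number of endorsements ≥ 3 |
### 4. Schema markup diagnosis
Structured data helps AI understand the content:
| Schema type | Use case | Priority | Implementation difficulty |
|-----------|---------|--------|---------|
| Organization | Company home page | P0 (required) | ⭐ Easy |
| Product | Product page | P0 (required) | ⭐⭐ Medium |
| Article/BlogPosting | Blog post | P0 (required) | ⭐ Easy |
| FAQPage | FAQ | P1 (recommended) | ⭐ Easy |
| HowTo | How-to guide | P1 (recommended) | ⭐⭐ Medium |
| BreadcrumbList | Navigation structure | P1 (recommended) | ⭐ Easy |
| Review/Rating | Reviews and ratings | P2 (optional) | ⭐⭐ Medium |
### 5. Topical authority diagnosis
AI must be able to verify the brand's authority in a specific domain:
| Check | Description | Acceptance criterion |
|--------|------|----------|
| Content depth | Whether the topic is covered comprehensively | Content quality QScore ≥ 4.6/5 |
| Topic coverage | Whether related subtopics are covered | Topic coverage ≥ 80% |
| Entity signal consistency | Whether entity signals agree across pages | Consistency score ≥ 85% |
| Internal link network | Whether pages form a topical content cluster | Cluster completeness ≥ 70% |
### 6. Citation readiness diagnosis
Evaluates how likely the brand is to be cited in AI answers:
| Check | Description | Acceptance criterion |
|--------|------|----------|
| Citation frequency | How often the brand is mentioned in AI answers | AOR (Answer Ownership Rate) ≥ 50% |
| Citation quality | Whether cited content is accurate and complete | Citation accuracy ≥ 90% |
| AI share of voice | The brand's share of AI answers | AI SOV ≥ 30% |
| Competitor comparison | Performance in AI answers relative to competitors | Gap ≤ 10 pp |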
---
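## Diagnosis flow
### SEO diagnosis flow
```
1. Crawl the site → 2. Technical analysis → 3. Content analysis → 4. Backlink analysis → 5. Generate report
```
### GEO diagnosis flow
```
1. Brand information input → 2. Content extractability check → 3. Entity clarity check → 4. E-E-A-T signal check
→ 5. Schema markup check → 6. Topical authority check → 7. AI platform citation check → 8. Generate diagnosis report
```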
---
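## Diagnosis report output
### SEO diagnosis report
Contains:
- Technical SEO score
- On-page SEO score
- Content quality score
- Backlink quality score
- User experience score
- Overall score
- Prioritized fix recommendations
### GEO diagnosis report
Contains:
- Content extractability score
- Entity clarity score
- E-E-A-T signal score
- Schema markup completeness
- Topical authority score
- AI platform citation rate
- Overall score
- Prioritized optimization recommendations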
---
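## Technical implementation
### API endpoints
| Endpoint | Method | Description | Status |
|------|------|------|------|
| `/api/v1/diagnosis/seo/{brand_id}` | GET | Get the SEO diagnosis result for a brand | ✅ Done |
| `/api/v1/diagnosis/geo/{brand_id}` | GET | Get the GEO diagnosis result for a brand | ✅ Done |
| `/api/v1/diagnosis/combined/{brand_id}` | GET | Get the combined SEO+GEO diagnosis result for a brand | ✅ Done |
### API response examples
#### SEO diagnosis response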
```json
{
"success": true,
"data": {
"brand_id": "brand_123",
"diagnosis_type": "seo",
"overall_score": 78,
"dimensions": {
"technical_seo": {
"score": 85,
"status": "good",
"issues": []
},
"page_seo": {
"score": 72,
"status": "needs_improvement",
"issues": [
{
"type": "missing_meta_description",
"severity": "medium",
"description": "Some pages are missing a Meta description",
"affected_pages": 12
}
]
},
"content_quality": {
"score": 80,
"status": "good",
"issues": []
},
"backlinks": {
"score": 65,
"status": "needs_improvement",
"issues": []
},
"user_experience": {
"score": 88,
"status": "excellent",
"issues": []
}
},
"recommendations": [
{
"priority": "high",
"action": "Add the missing Meta descriptions",
"impact": "Improves the page relevance score"
}
],
"diagnosed_at": "2024-01-15T10:30:00Z"
}
}
```
#### GEO diagnosis response
```json
{
"success": true,
"data": {
"brand_id": "brand_123",
"diagnosis_type": "geo",
"overall_score": 72,
"dimensions": {
"content_extractability": {
"score": 75,
"status": "good",
"metrics": {
"direct_answer_blocks": 85,
"qa_headings": 70,
"structured_data": 80
}
},
"entity_clarity": {
"score": 80,
"status": "good",
"metrics": {
"brand_definition_accuracy": 95,
"target_audience_clarity": 85,
"differentiation_score": 75
}
},
"eeat_signals": {
"score": 68,
"status": "needs_improvement",
"metrics": {
"author_credentials": 60,
"certifications": 70,
"data_sources": 75,
"expert_endorsements": 50
}
},
"schema_markup": {
"score": 70,
"status": "good",
"coverage": {
"organization": true,
"product": true,
"article": true,
"faq": false,
"howto": false
}
},
"topic_authority": {
"score": 74,
"status": "good",
"metrics": {
"content_depth": 80,
"topic_coverage": 72,
"entity_consistency": 85,
"internal_linking": 65
}
},
"citation_readiness": {
"score": 65,
"status": "needs_improvement",
"metrics": {
"citation_frequency": 60,
"citation_accuracy": 75,
"ai_sov": 55,
"competitor_gap": -8
}
}
},
"recommendations": [
{
"priority": "high",
"action": "Strengthen E-E-A-T signals",
"impact": "Improves how AI assesses the brand's credibility"
}
],
"diagnosed_at": "2024-01-15T10:30:00Z"
}
}
```
#### Combined diagnosis response
```json
{
"success": true,
"data": {
"brand_id": "brand_123",
"diagnosis_type": "combined",
"seo_score": 78,
"geo_score": 72,
"combined_score": 75,
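"seo_summary": {
"strengths": ["Solid technical SEO foundation", "Excellent user experience"],
"weaknesses": ["Backlink quality needs improvement", "Some pages are missing a Meta description"]
},
"geo_summary": {
"strengths": ["High entity clarity", "Well-structured content"],
"weaknesses": ["Insufficient E-E-A-T signals", "Low AI citation rate"]
},
"priority_actions": [
{
"type": "seo",
"action": "Build high-quality backlinks",
"expected_impact": "Raises domain authority"
},
{
"type": "geo",
"action": "Add expert endorsements and certifications",
"expected_impact": "Raises the AI citation rate"
}
],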
"diagnosed_at": "2024-01-15T10:30:00Z"
}
}
```
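### Backend implementation
| Component | File | Responsibility |
|------|------|------|
| SEO diagnosis service | `backend/app/services/seo_diagnosis.py` | Runs the SEO diagnosis analysis |
| GEO diagnosis service | `backend/app/services/geo_diagnosis.py` | Runs the GEO diagnosis analysis |
| Citation detection Agent | `backend/app/agent_framework/agents/citation_detector.py` | Detects citations on AI platforms |
| Diagnosis report generation | `backend/app/services/diagnosis_report.py` | Generates the diagnosis report |
| Diagnosis routes | `backend/app/api/routes/diagnosis.py` | Defines the API endpoints |
### Frontend implementation
| Page | Path | Description | Status |
|------|------|------|------|
| SEO diagnosis | `/dashboard/seo-diagnosis` | SEO diagnosis entry point and report view | ✅ Done |
| GEO diagnosis | `/dashboard/geo-diagnosis` | GEO diagnosis entry point and report view | ✅ Done |
| Combined diagnosis | `/dashboard/diagnosis` | Combined SEO+GEO diagnosis report | ✅ Done |
The frontend implementation includes:
- Triggering diagnosis tasks and showing progress
- Multi-dimension score visualization (radar chart, progress bars)
- Issue list and fix recommendations
- Comparison against historical diagnosis records
- Diagnosis report export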
---
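## Improvement suggestions
### Current issues
| Issue | Description | Priority | Status |
|------|------|--------|------|
| Incomplete diagnosis definition | Only citation detection was implemented; full GEO diagnosis was missing | P0 | ✅ Resolved |
| SEO diagnosis missing | Traditional SEO diagnosis was not implemented | P1 | ✅ Resolved |
| Frontend/backend disconnect | The diagnosis page was a placeholder | P0 | ✅ Resolved |
### Improvement plan
| Phase | Goal | Duration | Status |
|------|------|------|------|
| Phase 1 | Complete GEO diagnosis by adding the 6 dimensions | 2 weeks | ✅ Done |
| Phase 2 | Add SEO diagnosis | 2 weeks | ✅ Done |
| Phase 3 | Integrate the SEO+GEO diagnosis report | 1 week | ✅ Done |
| Phase 4 | Implement the frontend diagnosis pages | 1 week | ✅ Done |
### Follow-up optimization
| Phase | Goal | Estimated duration |
|------|------|----------|
| Phase 5 | Export diagnosis reports to PDF/Excel | 1 week |
| Phase 6 | Historical diagnosis trend analysis | 2 weeks |
| Phase 7 | Competitor diagnosis comparison | 2 weeks |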

@ -0,0 +1,242 @@
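# Knowledge Graph Entity Type Definitions
## Overview
This document describes the entity types and relation types used in the knowledge graph.
## Entity types (EntityType)
Location: `backend/app/models/knowledge_graph.py`
### Type definitions
| Type | Enum value | Description | Examples |
|------|--------|------|------|
| ORGANIZATION | organization | Company/organization | Tencent, Alibaba |
| PRODUCT | product | Product | WeChat, Honor of Kings |
| PERSON | person | Person | Ma Huateng, Zhang Xiaolong |
| LOCATION | location | Place | Shenzhen, Guangzhou |
| TECHNOLOGY | technology | Technology | AI, blockchain |
| BRAND | brand | Brand | WeChat Pay, Tencent Cloud |
| EVENT | event | Event | Tencent 2020 annual meeting |
| CONCEPT | concept | Concept | Digital transformation |
### Attributes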
```python
class Entity:
    id: str               # UUID
    name: str             # Entity name
    entity_type: str      # Entity type
    description: str      # Description
    properties: dict      # Custom properties
    confidence: str       # Confidence (high/medium/low)
    source_chunk_id: str  # Source Chunk ID
```
## Relation types (RelationType)
Location: `backend/app/models/knowledge_graph.py`
### Type definitions
| Type | Enum value | Description | Example |
|------|--------|------|------|
| COMPETES_WITH | competes_with | Competitor | WeChat ↔ Alipay |
| PARTNERS_WITH | partners_with | Partner | Tencent ↔ JD.com |
| PRODUCES | produces | Produces | Apple ↔ iPhone |
| USES_TECHNOLOGY | uses_technology | Uses technology | WeChat ↔ AI |
| LOCATED_IN | located_in | Located in | Tencent ↔ Shenzhen |
| FOUNDED_IN | founded_in | Founded in | Tencent ↔ 1998 |
| CEO_OF | ceo_of | CEO | Ma Huateng ↔ Tencent |
| FOUNDER_OF | founder_of | Founder | Ma Huateng ↔ Tencent |
| RELATED_TO | related_to | Related | AI ↔ machine learning |
| PART_OF | part_of | Part of | iPhone ↔ Apple |
### Attributes
```python
class Relation:
    id: str                # UUID
    source_entity_id: str  # Source entity ID
    target_entity_id: str  # Target entity ID
    relation_type: str     # Relation type
    properties: dict       # Custom properties
    confidence: str        # Confidence (high/medium/low)
    source_chunk_id: str   # Source Chunk ID
```
## Extraction flow
Location: `backend/app/services/knowledge/entity_extractor.py`
### EntityExtractor
```python
class EntityExtractor:
    ENTITY_TYPES = [
        "ORGANIZATION",
        "PRODUCT",
        "PERSON",
        "LOCATION",
        "TECHNOLOGY",
        "BRAND",
        "EVENT",
        "CONCEPT",
    ]
    RELATION_TYPES = [
        "COMPETES_WITH",
        "PARTNERS_WITH",
        "PRODUCES",
        "USES_TECHNOLOGY",
        "LOCATED_IN",
        "FOUNDED_IN",
        "CEO_OF",
        "FOUNDER_OF",
        "RELATED_TO",
        "PART_OF",
    ]

    async def extract(self, text: str, context: str) -> ExtractionResult:
        """
        Extract entities and relations from text.

        1. Build the extraction prompt
        2. Call the LLM
        3. Parse the returned result
        """
```
### Example extraction prompt
```
Extract knowledge graph entities and relations from the text below.
Entity types:
- ORGANIZATION (company/organization)
- PRODUCT (product)
- PERSON (person)
...
Relation types:
- COMPETES_WITH (competitor)
- PARTNERS_WITH (partner)
...
Text:
{text}
Return the result as JSON:
{
  "entities": [
    {"name": "entity name", "entity_type": "type", "confidence": "high/medium/low"}
  ],
  "relations": [
    {"source_entity": "source", "target_entity": "target", "relation_type": "type", "confidence": "high/medium/low"}
  ]
}
```
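Step 3 of the extraction flow, parsing the LLM's JSON reply, might look like the sketch below. It assumes the reply follows the JSON shape in the prompt example. The `ExtractedEntity`, `ExtractedRelation`, and `parse_extraction` names are hypothetical, and the rule of dropping items with unknown types or dangling endpoints is an illustrative choice, not something this document specifies.

```python
import json
from dataclasses import dataclass

ENTITY_TYPES = {"ORGANIZATION", "PRODUCT", "PERSON", "LOCATION",
                "TECHNOLOGY", "BRAND", "EVENT", "CONCEPT"}
RELATION_TYPES = {"COMPETES_WITH", "PARTNERS_WITH", "PRODUCES",
                  "USES_TECHNOLOGY", "LOCATED_IN", "FOUNDED_IN",
                  "CEO_OF", "FOUNDER_OF", "RELATED_TO", "PART_OF"}
CONFIDENCE_LEVELS = {"high", "medium", "low"}


@dataclass(frozen=True)
class ExtractedEntity:  # hypothetical result type
    name: str
    entity_type: str
    confidence: str


@dataclass(frozen=True)
class ExtractedRelation:  # hypothetical result type
    source_entity: str
    target_entity: str
    relation_type: str
    confidence: str


def parse_extraction(raw: str):
    """Parse the LLM reply, keeping only items with known types."""
    payload = json.loads(raw)
    entities = []
    for item in payload.get("entities", []):
        entity_type = str(item.get("entity_type", "")).upper()
        confidence = str(item.get("confidence", "")).lower()
        if (item.get("name") and entity_type in ENTITY_TYPES
                and confidence in CONFIDENCE_LEVELS):
            entities.append(
                ExtractedEntity(item["name"], entity_type, confidence))
    known_names = {entity.name for entity in entities}
    relations = []
    for item in payload.get("relations", []):
        relation_type = str(item.get("relation_type", "")).upper()
        confidence = str(item.get("confidence", "")).lower()
        # Drop relations whose endpoints were not accepted as entities.
        if (item.get("source_entity") in known_names
                and item.get("target_entity") in known_names
                and relation_type in RELATION_TYPES
                and confidence in CONFIDENCE_LEVELS):
            relations.append(ExtractedRelation(
                item["source_entity"], item["target_entity"],
                relation_type, confidence))
    return entities, relations
```

Normalizing case before validation tolerates replies that use the lowercase enum values instead of the uppercase type names.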
## Graph construction
Location: `backend/app/services/knowledge/graph_builder.py`
### GraphBuilder
```python
class GraphBuilder:
    async def build_from_chunk(
        self,
        session: AsyncSession,
        chunk_id: str,
        context: str | None = None,
    ) -> dict:
        """
        Build the knowledge graph from a Chunk.

        1. Fetch the Chunk content
        2. Run EntityExtractor on it
        3. Store the result in the graph
        """
```
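### Build statistics
The call returns the following statistics:
```python
{
    "entities_created": 5,    # Entities newly created
    "entities_existing": 2,   # Entities that already existed
    "relations_created": 3,   # Relations newly created
    "relations_existing": 1,  # Relations that already existed
}
```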
## Graph queries
### Entity query
```python
# Query entities
GET /api/v1/knowledge-graph/entities?type=brand&limit=20
# Response
{
  "entities": [
    {
      "id": "uuid",
      "name": "WeChat",
      "type": "brand",
      "description": "...",
      "properties": {...}
    }
  ]
}
```
### Relation query
```python
# Query relations
GET /api/v1/knowledge-graph/relations?source_id=xxx&type=competes_with
# Response
{
  "relations": [
    {
      "id": "uuid",
      "source_id": "xxx",
      "target_id": "yyy",
      "type": "competes_with",
      "properties": {...}
    }
  ]
}
```
### Semantic search
```python
# Semantic search
GET /api/v1/knowledge-graph/search?query=微信的竞争对手
# Response
{
  "results": [
    {"entity": {...}, "score": 0.95},
    {"entity": {...}, "score": 0.87}
  ]
}
```
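## Confidence assessment
| Level | Description | Usage |
|------|------|----------|
| high | High confidence | Use directly |
| medium | Medium confidence | Manual review recommended |
| low | Low confidence | For reference only |
Assessment factors:
- How explicitly the text mentions the item
- How strongly the context supports it
- How certain the LLM is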

File diff suppressed because it is too large

@ -0,0 +1,50 @@
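# Knowledge Base
## Overview
The GEO platform's knowledge base is built on a RAG (Retrieval-Augmented Generation) architecture and supplies domain knowledge to content generation.
## System architecture
```
Document upload → Text chunking → Vectorization → RAG retrieval → LLM-augmented generation
```
## Core features
### Document management
- Supports multiple document formats: PDF, TXT, Markdown, DOCX
- Document version management
- Categories and tags
### Text chunking
- Smart chunking strategy
- Overlapping window mechanism
- Metadata retention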
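The overlapping window mechanism can be illustrated with a character-based chunker. This is only a sketch under assumed defaults: the `chunk_text` name, the 500/50 sizes, and the metadata keys are hypothetical, and the platform's "smart" strategy presumably also respects sentence or paragraph boundaries, which is not modeled here.

```python
def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list:
    """Split text into fixed-size chunks that share `overlap` characters."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    step = chunk_size - overlap  # how far the window advances each time
    chunks = []
    for index, start in enumerate(range(0, len(text), step)):
        piece = text[start:start + chunk_size]
        # Keep positional metadata so a hit can be traced to its source.
        chunks.append({
            "index": index,
            "start": start,
            "end": start + len(piece),
            "text": piece,
        })
        if start + chunk_size >= len(text):
            break  # the window already reached the end of the text
    return chunks
```

The overlap means a sentence cut at a chunk boundary still appears whole in at least one neighboring chunk, at the cost of some duplicated storage.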
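### Vectorization
- Supports multiple embedding models
- Vector database storage
- Similarity retrieval
## API
The knowledge base API lives in `backend/app/api/knowledge.py`
### Main endpoints
| Method | Path | Description |
|------|------|------|
| GET | /api/v1/knowledge/bases | List knowledge bases |
| POST | /api/v1/knowledge/bases | Create a knowledge base |
| POST | /api/v1/knowledge/bases/{id}/documents | Upload a document |
| GET | /api/v1/knowledge/search | Search knowledge |
## Configuration
Environment variables:
- `KNOWLEDGE_EMBEDDING_MODEL` - embedding model
- `KNOWLEDGE_VECTOR_DB` - vector database type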

@ -0,0 +1,58 @@
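# Knowledge Graph
## Overview
The knowledge graph module builds and manages the graph of entities and relations around a brand, and supports semantic search and reasoning.
## Core features
### Entity management
- Brand entities
- Product entities
- Competitor entities
- Industry concept entities
### Relation building
- Brand-product relations
- Brand-competitor relations
- Product-industry relations
- Hypernym/hyponym relations between concepts
## API
The knowledge graph API lives in `backend/app/api/knowledge_graph.py`
### Main endpoints
| Method | Path | Description |
|------|------|------|
| GET | /api/v1/knowledge-graph/entities | List entities |
| POST | /api/v1/knowledge-graph/entities | Create an entity |
| GET | /api/v1/knowledge-graph/relations | List relations |
| POST | /api/v1/knowledge-graph/relations | Create a relation |
| GET | /api/v1/knowledge-graph/search | Semantic search |
## Data model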
### Entity
```python
class Entity:
    id: str
    name: str
    type: str  # brand/product/competitor/concept
    properties: dict
```
### Relation
```python
class Relation:
    id: str
    source_id: str
    target_id: str
    relation_type: str  # produces/competes_with/belongs_to
    properties: dict
```

@ -0,0 +1,208 @@
# Monitoring Metric Definitions
## Overview
This document describes the metrics defined by the monitoring system and how LLM cost is tracked.
## Prometheus metrics
Location: `backend/app/monitoring/metrics.py`
### API-layer metrics
| Metric | Type | Labels | Description |
|--------|------|------|------|
| geo_api_requests_total | Counter | method, endpoint, status | Total HTTP requests |
| geo_api_request_duration_seconds | Histogram | method, endpoint | Request latency distribution |
| geo_api_requests_in_progress | Gauge | method, endpoint | Requests currently in progress |
### Agent-layer metrics
| Metric | Type | Labels | Description |
|--------|------|------|------|
| geo_agent_executions_total | Counter | agent_name, status | Total Agent executions |
| geo_agent_execution_duration_seconds | Histogram | agent_name | Agent execution duration |
| geo_agent_running_tasks | Gauge | agent_name | Tasks currently running |
### LLM-layer metrics
| Metric | Type | Labels | Description |
|--------|------|------|------|
| geo_llm_requests_total | Counter | provider, model, status | Total LLM requests |
| geo_llm_request_duration_seconds | Histogram | provider, model | LLM request duration |
| geo_llm_tokens_total | Counter | provider, model, token_type | Total tokens consumed |
| geo_llm_cost_estimated | Gauge | provider, model | Estimated cost (USD) |
### Business-layer metrics
| Metric | Type | Labels | Description |
|--------|------|------|------|
| geo_brands_total | Gauge | - | Total brands |
| geo_queries_total | Counter | platform, status | Total queries |
| geo_content_generated_total | Counter | - | Total content items generated |
| geo_citations_detected_total | Counter | platform | Total citations detected |
## LLM cost tracking
Location: `backend/app/monitoring/llm_metrics.py`
### Cost estimation table
```python
LLM_COST_PER_TOKEN = {
    # OpenAI
    ("openai", "gpt-4o"): {
        "prompt": 0.000005,       # $5/1M tokens
        "completion": 0.000015,   # $15/1M tokens
    },
    ("openai", "gpt-4o-mini"): {
        "prompt": 0.00000015,     # $0.15/1M tokens
        "completion": 0.0000006,  # $0.60/1M tokens
    },
    ("openai", "gpt-4-turbo"): {
        "prompt": 0.00001,        # $10/1M tokens
        "completion": 0.00003,    # $30/1M tokens
    },
    # DeepSeek
    ("deepseek", "deepseek-chat"): {
        "prompt": 0.00000014,      # $0.14/1M tokens
        "completion": 0.00000028,  # $0.28/1M tokens
    },
    ("deepseek", "deepseek-coder"): {
        "prompt": 0.00000014,
        "completion": 0.00000028,
    },
}
```
### Cost formula
```
total_cost = prompt_tokens * prompt_price + completion_tokens * completion_price
```
```
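As a worked instance of the formula, the sketch below looks up per-token prices and applies them. The `estimate_cost` function name and the choice to return `None` for an unlisted model are assumptions for illustration; the two price entries are copied from the cost table in this document.

```python
from typing import Optional

# Subset of the cost table above, in USD per token.
LLM_COST_PER_TOKEN = {
    ("openai", "gpt-4o"): {"prompt": 0.000005, "completion": 0.000015},
    ("deepseek", "deepseek-chat"): {"prompt": 0.00000014,
                                    "completion": 0.00000028},
}


def estimate_cost(provider: str, model: str,
                  prompt_tokens: int, completion_tokens: int) -> Optional[float]:
    """Return the estimated USD cost, or None if the model has no price."""
    prices = LLM_COST_PER_TOKEN.get((provider, model))
    if prices is None:
        return None  # unknown model: do not guess a price
    return (prompt_tokens * prices["prompt"]
            + completion_tokens * prices["completion"])
```

For example, 1,000 prompt tokens and 500 completion tokens on `gpt-4o` come to 1000 × $0.000005 + 500 × $0.000015 = $0.0125.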
### LLMMetricsWrapper
```python
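class LLMMetricsWrapper:
    def record_request(
        self,
        status: str,
        duration: float,
        prompt_tokens: int | None = None,
        completion_tokens: int | None = None,
    ):
        """
        Record metrics for one LLM request.

        1. Record the request count and duration
        2. Record token consumption
        3. Estimate and record the cost
        """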
```
## Histogram buckets
### API latency
```python
buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0)
# Unit: seconds
# Used to estimate P50/P90/P99
```
### Agent execution duration
```python
buckets=(0.1, 0.5, 1.0, 5.0, 10.0, 30.0, 60.0, 120.0)
# Unit: seconds
```
### LLM request duration
```python
buckets=(0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0)
# Unit: seconds
```
## Health checks
Location: `backend/app/services/health_checker.py`
### Endpoints
| Path | Description | Checks |
|------|------|--------|
| GET /health | Overall health check | All checks |
| GET /health/ready | Readiness check | Database, Redis |
| GET /health/live | Liveness check | Service is running |
### Checks
| Check | Description | Timeout |
|--------|------|------|
| database | Database connection | 5s |
| redis | Redis connection | 5s |
| disk | Disk space | - |
| memory | Memory usage | - |
## Alert rules
### Alert conditions
| Rule | Condition | Level | Description |
|------|------|------|------|
| API response timeout | p99 > 5s | Warning | 99th-percentile response time exceeds 5 seconds |
| High API error rate | error_rate > 5% | Error | Error rate exceeds 5% |
| Queue backlog | queue_depth > 1000 | Warning | Queue depth exceeds 1000 |
| Agent offline | heartbeat_timeout | Critical | Heartbeat timed out |
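The three threshold-based rules can be evaluated with a small table-driven check, sketched below. The `AlertRule` class, the metric key names (`p99_latency_seconds`, `error_rate`, `queue_depth`), and the `evaluate` function are hypothetical; the thresholds and levels are taken from the table. The Agent-offline rule depends on heartbeat state rather than a numeric threshold, so it is left out of this sketch.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AlertRule:  # hypothetical rule shape
    name: str
    metric: str
    threshold: float
    level: str


RULES = [
    AlertRule("API response timeout", "p99_latency_seconds", 5.0, "Warning"),
    AlertRule("High API error rate", "error_rate", 0.05, "Error"),
    AlertRule("Queue backlog", "queue_depth", 1000, "Warning"),
]


def evaluate(metrics: dict) -> list:
    """Return (rule name, level) for every rule whose threshold is exceeded."""
    return [
        (rule.name, rule.level)
        for rule in RULES
        # A missing metric is treated as 0, so it never fires an alert.
        if metrics.get(rule.metric, 0) > rule.threshold
    ]
```

The comparison is strictly greater-than, matching the `>` conditions in the table, so a queue depth of exactly 1000 does not alert.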
### Alert levels
| Level | Description | Notification |
|------|------|----------|
| Warning | Warning | Log |
| Error | Error | Log + email |
| Critical | Critical | Log + email + SMS |
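## Monitoring integration
### Prometheus endpoint
```
GET /metrics
```
Returns metric data in Prometheus format.
### Grafana dashboards
Recommended dashboards:
1. **API Dashboard**
- QPS
- Latency distribution (P50/P90/P99)
- Error rate
2. **Agent Dashboard**
- Execution count per Agent
- Execution duration
- Success rate
3. **LLM Dashboard**
- Request count
- Token consumption
- Cost trend
4. **Business Dashboard**
- Number of brands
- Content generation volume
- Citation detection volume
## Metric collection flow
```
1. API request → recorded by middleware (method, endpoint, status, duration)
2. Agent execution → recorded by AgentHooks (agent_name, status, duration)
3. LLM call → recorded by LLMMetricsWrapper (provider, model, tokens, cost)
4. Background task → periodically aggregated and written to the database
```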

@ -0,0 +1,64 @@
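# Monitoring Module
## Overview
The monitoring system tracks system health, collects performance metrics, detects anomalies, and sends alerts.
## Monitoring scope
### System monitoring
- CPU, memory, and disk usage
- Network traffic
- Process status
### Application monitoring
- API response time
- Request success rate
- Error rate
- Concurrent connections
### Business monitoring
- Agent status
- Task queue depth
- Knowledge base retrieval latency
## Metrics
Location: `backend/app/monitoring/`
| Metric | Type | Description |
|------|------|------|
| http_requests_total | Counter | Total HTTP requests |
| http_request_duration_seconds | Histogram | Request latency distribution |
| agent_tasks_total | Counter | Total Agent tasks |
| agent_task_duration_seconds | Histogram | Task execution time |
| queue_depth | Gauge | Queue depth |
## Health checks
### Endpoints
| Path | Description |
|------|------|
| GET /health | Service health check |
| GET /health/ready | Readiness check |
| GET /health/live | Liveness check |
### Checks
- Database connection
- Redis connection
- Disk space
- Memory usage
## Alert rules
| Rule | Condition | Level |
|------|------|------|
| API response timeout | p99 > 5s | Warning |
| High API error rate | error_rate > 5% | Error |
| Queue backlog | queue_depth > 1000 | Warning |
| Agent offline | heartbeat_timeout | Critical |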

@ -0,0 +1,207 @@
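# Content Generation Pipeline Orchestration
## Overview
This document describes the full orchestration logic of the content generation Pipeline.
## Pipeline architecture
```
User request → PipelineEngine.execute()
  ├─ 1. Load the Pipeline definition (YAML)
  ├─ 2. Build the execution context
  ├─ 3. Topologically sort to determine execution order
  └─ 4. Execute stage by stage
       ├─ Stage 1 (no dependencies)
       ├─ Stage 2 (depends on Stage 1)
       ├─ Stage 3 (depends on Stage 1, 2)
       └─ ...
```
## Pipeline definition structure
Pipelines are defined in YAML files located in the `backend/pipelines/` directory: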
```yaml
name: content_production
version: "1.0"
description: Content production Pipeline
variables:
  brand_name: ""
  keywords: ""
  platform: "zhihu"
stages:
  - name: topic_selection
    agent: content_generator
    action: select_topic
    inputs:
      brand: "${brand_name}"
      keywords: "${keywords}"
    outputs: [selected_topic, topic_score]
  - name: content_generation
    agent: content_generator
    action: generate
    depends_on: [topic_selection]
    inputs:
      topic: "${stages.topic_selection.outputs.selected_topic}"
      brand: "${brand_name}"
    outputs: [content]
  - name: deai_processing
    agent: deai_agent
    action: humanize
    depends_on: [content_generation]
    inputs:
      content: "${stages.content_generation.outputs.content}"
    outputs: [natural_content]
  - name: seo_optimization
    agent: geo_optimizer
    action: optimize
    depends_on: [deai_processing]
    inputs:
      content: "${stages.deai_processing.outputs.natural_content}"
      keywords: "${keywords}"
    outputs: [optimized_content]
```
## Core components
### PipelineEngine
Location: `backend/app/agent_framework/pipeline/engine.py`
| Method | Description |
|------|------|
| execute(pipeline, context) | Executes the full Pipeline |
| _execute_stage(stage, exec_context, stages_context) | Executes a single stage |
| _dispatch_and_wait(stage, inputs) | Dispatches a task and waits for the result |
| _topological_sort(stages) | Topological sort |
| _resolve_variables(template, context) | Resolves variable references |
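The ordering step can be sketched with the standard library's `graphlib` module (Python 3.9+). This is not the engine's `_topological_sort` implementation, which this document does not show; the `execution_order` function and the plain-dict stage shape with `name` and `depends_on` keys are assumptions modeled on the YAML definition.

```python
from graphlib import CycleError, TopologicalSorter


def execution_order(stages: list) -> list:
    """Return stage names in an order that respects depends_on."""
    names = {stage["name"] for stage in stages}
    sorter = TopologicalSorter()
    for stage in stages:
        dependencies = stage.get("depends_on", [])
        unknown = set(dependencies) - names
        if unknown:
            raise ValueError(
                f"stage {stage['name']!r} depends on unknown stages: "
                f"{sorted(unknown)}")
        # add(node, *predecessors): predecessors must run first.
        sorter.add(stage["name"], *dependencies)
    try:
        return list(sorter.static_order())
    except CycleError as exc:
        # exc.args[1] holds the list of nodes that form the cycle.
        raise ValueError(f"pipeline contains a cycle: {exc.args[1]}") from exc
```

Rejecting unknown dependencies and cycles up front gives the same guarantee that a DAG validation step is meant to provide before any stage runs.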
### PipelineLoader
Location: `backend/app/agent_framework/pipeline/loader.py`
| Method | Description |
|------|------|
| load(pipeline_name) | Loads a Pipeline from YAML |
| load_from_yaml(yaml_content) | Loads from a string |
| validate_dag(stages) | Verifies that the DAG has no cycles |
| resolve_variables(template, context) | Resolves ${...} variables |
### PipelineSchema
Location: `backend/app/agent_framework/pipeline/schema.py`
| Class | Description |
|----|------|
| Pipeline | Pipeline definition |
| PipelineStage | Single stage definition |
| StageResult | Stage execution result |
| PipelineResult | Pipeline execution result |
| StageStatus | Stage status enum |
## Variable resolution
Variable references in the `${var.path}` format are supported:
| Format | Description |
|------|------|
| `${brand_name}` | Global variable |
| `${stages.step1.outputs.result}` | Output of an upstream stage |
| `${stages.step1.outputs.result.substring(0,10)}` | Substring extraction |
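A minimal resolver for the first two formats might look like this. It is a sketch, not the loader's `resolve_variables` implementation: the nested-dict context layout is an assumption, and the `.substring(...)` form is deliberately not handled (such references are left untouched by the pattern).

```python
import re

# Matches ${name} and dotted paths such as ${stages.step1.outputs.result}.
_VARIABLE = re.compile(r"\$\{([A-Za-z_][\w.]*)\}")


def resolve_variables(template: str, context: dict) -> str:
    """Replace each ${path} with the value found by walking the context."""

    def lookup(match: re.Match) -> str:
        value = context
        for key in match.group(1).split("."):
            if not isinstance(value, dict) or key not in value:
                raise KeyError(f"unresolved variable: {match.group(0)}")
            value = value[key]
        return str(value)

    return _VARIABLE.sub(lookup, template)
```

Raising on an unresolved reference, rather than substituting an empty string, surfaces typos in stage names at resolution time instead of letting an empty input reach an Agent.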
## Conditional execution
A Stage can use the `condition` field to run conditionally:
```yaml
- name: optional_stage
  agent: some_agent
  action: do_something
  condition: "${enable_optional} == true"
```
Supported comparisons:
- `${var}` - the variable exists and is not empty
- `${var} == 'value'` - equals
- `${var} != 'value'` - not equals
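The three supported forms can be evaluated with two regular expressions, as in the sketch below. The `evaluate_condition` function is hypothetical, and two details are assumptions rather than documented behavior: comparisons are done on string forms, and Python booleans are compared as lowercase `true`/`false` so that the YAML example `${enable_optional} == true` works.

```python
import re

_EXISTS = re.compile(r"^\$\{(\w+)\}$")
_COMPARISON = re.compile(r"^\$\{(\w+)\}\s*(==|!=)\s*(.+)$")


def evaluate_condition(condition: str, variables: dict) -> bool:
    """Evaluate `${var}`, `${var} == 'value'`, or `${var} != 'value'`."""
    condition = condition.strip()
    exists = _EXISTS.match(condition)
    if exists:
        # True when the variable is present and not empty.
        return variables.get(exists.group(1)) not in (None, "", [], {})
    comparison = _COMPARISON.match(condition)
    if comparison is None:
        raise ValueError(f"unsupported condition: {condition!r}")
    name, operator, literal = comparison.groups()
    expected = literal.strip().strip("'\"")  # drop optional quotes
    actual = variables.get(name)
    # Assumption: booleans compare as lowercase 'true'/'false'.
    actual_text = str(actual).lower() if isinstance(actual, bool) else str(actual)
    is_equal = actual_text == expected
    return is_equal if operator == "==" else not is_equal
```

Restricting the grammar to these patterns avoids calling `eval` on strings that come from Pipeline definition files.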
## Retry mechanism
Each Stage supports a `retry_count` setting:
```yaml
- name: unstable_stage
  agent: some_agent
  retry_count: 3  # Retry up to 3 times after a failure
```
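Read literally, `retry_count: 3` allows one initial attempt plus up to three retries. A sketch of that behavior for an async stage is shown below; the `run_with_retry` helper and its optional `delay_seconds` pause are hypothetical and not part of the documented engine.

```python
import asyncio


async def run_with_retry(action, retry_count: int = 0,
                         delay_seconds: float = 0.0):
    """Await action(); on failure retry up to retry_count more times."""
    attempts = max(retry_count, 0) + 1  # initial attempt + retries
    last_error = None
    for attempt in range(1, attempts + 1):
        try:
            return await action()
        except Exception as exc:  # a real engine would likely narrow this
            last_error = exc
            if attempt < attempts and delay_seconds > 0:
                await asyncio.sleep(delay_seconds)
    raise last_error
```

Usage would be along the lines of `await run_with_retry(lambda: dispatch(stage), retry_count=stage.retry_count)`, where `dispatch` and `stage` stand in for whatever the engine actually uses.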
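## Error handling
| Setting | Description |
|------|------|
| continue_on_failure: false | A failed stage fails the Pipeline (default) |
| continue_on_failure: true | Downstream stages still run after a stage fails |
## Service-layer content Pipeline
In addition to the Agent Pipeline, there is a service-layer Pipeline for content processing:
Location: `backend/app/services/content/content_pipeline.py`
```
User content → RuleValidator (rule validation)
  ├─ High-severity issue → abort the Pipeline
  └─ Passed → SensitiveFilter (sensitive-word filtering)
       └─ SEOOptimizer (SEO optimization)
            └─ HTMLGenerator (HTML generation)
                 └─ Output (html/markdown/plain)
```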
### ContentPipeline class
```python
class ContentPipeline:
    async def run(self, request: dict) -> PipelineResponse:
        """
        request = {
            "content": "raw content",
            "title": "title",
            "platform": "target platform",
            "optimize_for": ["validation", "sensitive", "seo"],
            "output_formats": ["html", "markdown", "plain"]
        }
        """
```
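### Stages
| Stage | Component | Description |
|------|------|------|
| validation | RuleValidator | Validates title length and content length, and detects AI patterns |
| sensitive_filter | SensitiveFilter | Filters and replaces sensitive words |
| seo_optimization | SEOOptimizer | Adjusts keyword density and checks keyword placement |
| html_generation | HTMLGenerator | Generates HTML/Markdown/plain text |
## Dry-run mode
When PipelineEngine is initialized without a dispatcher, it runs in dry-run mode:
```python
# Dry-run mode (testing/development)
engine = PipelineEngine(dispatcher=None)
result = await engine.execute(pipeline, context)
# Returns simulated results, useful for testing Pipeline definitions
```
If dry-run is triggered in production, an ERROR-level log entry is recorded.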

@ -0,0 +1,44 @@
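# Platform Rule Center
## Overview
The platform rule center manages the content rules and publishing requirements of each AI platform so that generated content meets every platform's standards.
## Features
### Rule management
- Rule categories (keywords, length, format, sensitive words)
- Rule version control
- Rule priority
### Rule checking
- Real-time checks
- Batch checks
- Check result feedback
## API
The platform rule API lives in `backend/app/api/platform_rules.py`
### Main endpoints
| Method | Path | Description |
|------|------|------|
| GET | /api/v1/platform-rules | List rules |
| POST | /api/v1/platform-rules | Create a rule |
| PUT | /api/v1/platform-rules/{id} | Update a rule |
| DELETE | /api/v1/platform-rules/{id} | Delete a rule |
| POST | /api/v1/platform-rules/check | Check content |
### Check dimensions
| Dimension | Description |
|------|------|
| keyword_coverage | Keyword coverage |
| readability | Readability score |
| tone_consistency | Tone consistency |
| length_compliance | Length compliance |
| fact_accuracy | Factual accuracy |
| geo_optimization | Degree of GEO optimization |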

Some files were not shown because too many files have changed in this diff