diff --git a/.codegraph/.gitignore b/.codegraph/.gitignore new file mode 100644 index 0000000..9de0f16 --- /dev/null +++ b/.codegraph/.gitignore @@ -0,0 +1,16 @@ +# CodeGraph data files +# These are local to each machine and should not be committed + +# Database +*.db +*.db-wal +*.db-shm + +# Cache +cache/ + +# Logs +*.log + +# Hook markers +.dirty diff --git a/backend/.env.example b/backend/.env.example index 2eee706..27e27db 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -36,6 +36,12 @@ ZHIPU_API_KEY= # 通义千问 (可选) TONGYI_API_KEY= +# ============================================================ +# 阿里云百炼(图片生成) +# ============================================================ +# 万相-文生图V1 API Key +ALIYUN_DASHSCOPE_API_KEY= + # ============================================================ # LLM Provider 配置 # ============================================================ diff --git a/backend/.env.test b/backend/.env.test new file mode 100644 index 0000000..84c1de5 --- /dev/null +++ b/backend/.env.test @@ -0,0 +1,6 @@ +DATABASE_URL=sqlite+aiosqlite:///./test.db +REDIS_URL=redis://localhost:6379 +ENVIRONMENT=testing +LOG_LEVEL=info +SECRET_KEY=test-secret-key-for-testing-only +CORS_ORIGINS=http://localhost:3000 diff --git a/backend/alembic/versions/f7a8b9c0de56_add_knowledge_graph_tables.py b/backend/alembic/versions/f7a8b9c0de56_add_knowledge_graph_tables.py new file mode 100644 index 0000000..f9161c7 --- /dev/null +++ b/backend/alembic/versions/f7a8b9c0de56_add_knowledge_graph_tables.py @@ -0,0 +1,143 @@ +"""Add knowledge graph tables + +Revision ID: f7a8b9c0de56 +Revises: e5f7g9h1cd45 +Create Date: 2026-05-24 12:00:00.000000 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql +import enum + +# revision identifiers, used by Alembic. +revision: str = "f7a8b9c0de56" +down_revision: Union[str, None] = "e5f7g9h1cd45" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ------------------------------------------------------------------ # + # 1. 创建实体类型枚举 + # ------------------------------------------------------------------ # + entity_type_enum = enum.Enum( + 'EntityType', + [ + 'ORGANIZATION', 'PRODUCT', 'PERSON', 'LOCATION', + 'TECHNOLOGY', 'BRAND', 'EVENT', 'CONCEPT', 'OTHER' + ] + ) + + # ------------------------------------------------------------------ # + # 2. 创建关系类型枚举 + # ------------------------------------------------------------------ # + relation_type_enum = enum.Enum( + 'RelationType', + [ + 'COMPETES_WITH', 'PARTNERS_WITH', 'ACQUIRES', 'SUBSIDIARY_OF', + 'PRODUCES', 'USES_TECHNOLOGY', 'PART_OF', + 'LOCATED_IN', 'FOUNDED_IN', + 'CEO_OF', 'FOUNDER_OF', + 'RELATED_TO', 'MENTIONED_IN', 'ALSO_KNOWN_AS' + ] + ) + + # ------------------------------------------------------------------ # + # 3. knowledge_entities 表 + # ------------------------------------------------------------------ # + op.create_table( + "knowledge_entities", + sa.Column( + "id", + postgresql.UUID(as_uuid=True), + primary_key=True, + nullable=False, + ), + sa.Column( + "knowledge_base_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("knowledge_bases.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("name", sa.String(500), nullable=False), + sa.Column("entity_type", sa.Enum(entity_type_enum, name="entitytype"), nullable=False), + sa.Column("description", sa.Text, nullable=True), + sa.Column("properties", postgresql.JSONB(astext_type=sa.Text()), nullable=True, server_default="{}"), + sa.Column( + "source_chunk_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("knowledge_chunks.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("confidence", sa.String(20), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + ) + op.create_index("ix_entities_name", "knowledge_entities", ["name"]) + op.create_index("ix_entities_kb_name", "knowledge_entities", ["knowledge_base_id", "name"]) + op.create_index("ix_entities_kb_type", "knowledge_entities", ["knowledge_base_id", "entity_type"]) + + # ------------------------------------------------------------------ # + # 4. knowledge_relations 表 + # ------------------------------------------------------------------ # + op.create_table( + "knowledge_relations", + sa.Column( + "id", + postgresql.UUID(as_uuid=True), + primary_key=True, + nullable=False, + ), + sa.Column( + "source_entity_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("knowledge_entities.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "target_entity_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("knowledge_entities.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("relation_type", sa.Enum(relation_type_enum, name="relationtype"), nullable=False), + sa.Column("properties", postgresql.JSONB(astext_type=sa.Text()), nullable=True, server_default="{}"), + sa.Column( + "source_chunk_id", + postgresql.UUID(as_uuid=True), + sa.ForeignKey("knowledge_chunks.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("confidence", sa.String(20), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + ) + op.create_index("ix_relations_source", "knowledge_relations", ["source_entity_id"]) + op.create_index("ix_relations_target", "knowledge_relations", ["target_entity_id"]) + op.create_index("ix_relations_type", "knowledge_relations", ["relation_type"]) + + +def downgrade() -> None: + # 删除表(注意外键约束会自动处理) + op.drop_table("knowledge_relations") + op.drop_table("knowledge_entities") + # 删除枚举类型 + op.execute("DROP TYPE IF EXISTS relationtype") + op.execute("DROP TYPE IF EXISTS entitytype") diff --git a/backend/app/api/alerts.py b/backend/app/api/alerts.py index 9e65071..c5fa5a1 100644 --- a/backend/app/api/alerts.py +++ b/backend/app/api/alerts.py @@ -1,4 +1,5 @@ """Alerts API endpoints - 告警通知接口""" +import logging import uuid from fastapi import APIRouter, Depends, HTTPException, Query, status @@ -9,6 +10,7 @@ from app.api.deps import get_current_user from app.database import get_db from app.models.alert import Alert from app.models.alert_setting import AlertSetting +from app.models.brand import Brand from app.models.user import User from app.schemas.alert import ( AlertResponse, @@ -24,9 +26,35 @@ from app.schemas.alert import ( ) from app.services.alert_engine import AlertEngine +logger = logging.getLogger(__name__) + router = APIRouter() +async def verify_brand_ownership( + brand_id: uuid.UUID, + current_user: User, + db: AsyncSession, +) -> Brand: + """验证品牌属于当前用户""" + stmt = select(Brand).where( + and_( + Brand.id == brand_id, + Brand.user_id == current_user.id, + ) + ) + result = await db.execute(stmt) + brand = result.scalar_one_or_none() + + if not brand: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"品牌 {brand_id} 不存在或不属于当前用户", + ) + + return brand + + # ============================================================ # 告警接口 # ============================================================ @@ -207,21 +235,7 @@ async def update_alert_settings( for item in data.settings: # 验证品牌属于当前用户 - from app.models.brand import Brand - brand_stmt = select(Brand).where( - and_( - Brand.id == item.brand_id, - Brand.user_id == current_user.id, - ) - ) - brand_result = await db.execute(brand_stmt) - brand = brand_result.scalar_one_or_none() - - if not brand: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"品牌 {item.brand_id} 不存在或不属于当前用户", - ) + await verify_brand_ownership(item.brand_id, current_user, db) # 查找现有设置 existing_stmt = select(AlertSetting).where( @@ -257,6 +271,7 @@ async def update_alert_settings( for setting in updated_settings: await db.refresh(setting) + logger.info(f"批量更新告警设置: user={current_user.id}, count={len(updated_settings)}") return {"items": updated_settings, "total": len(updated_settings)} @@ -292,3 +307,74 @@ async def update_single_setting( await db.refresh(setting) return setting + + +@router.post("/settings", response_model=AlertSettingResponse, status_code=status.HTTP_201_CREATED) +async def create_alert_setting( + data: AlertSettingCreate, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """创建告警设置""" + # 验证品牌属于当前用户 + await verify_brand_ownership(data.brand_id, current_user, db) + + # 检查是否已存在相同品牌+类型的设置 + existing_stmt = select(AlertSetting).where( + and_( + AlertSetting.brand_id == data.brand_id, + AlertSetting.alert_type == data.alert_type, + ) + ) + existing_result = await db.execute(existing_stmt) + existing = existing_result.scalar_one_or_none() + + if existing: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"品牌 {data.brand_id} 的告警类型 {data.alert_type} 已存在", + ) + + # 创建新设置 + setting = AlertSetting( + brand_id=data.brand_id, + user_id=current_user.id, + alert_type=data.alert_type, + enabled=data.enabled, + threshold=data.threshold, + ) + db.add(setting) + await db.commit() + await db.refresh(setting) + + logger.info(f"创建告警设置: user={current_user.id}, brand={data.brand_id}, type={data.alert_type}") + return setting + + +@router.delete("/settings/{setting_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_alert_setting( + setting_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """删除告警设置""" + stmt = select(AlertSetting).where( + and_( + AlertSetting.id == setting_id, + AlertSetting.user_id == current_user.id, + ) + ) + result = await db.execute(stmt) + setting = result.scalar_one_or_none() + + if not setting: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="告警设置不存在", + ) + + await db.delete(setting) + await db.commit() + + logger.info(f"删除告警设置: user={current_user.id}, setting={setting_id}") + return None diff --git a/backend/app/api/content.py b/backend/app/api/content.py index 4011f1b..afda162 100644 --- a/backend/app/api/content.py +++ b/backend/app/api/content.py @@ -237,4 +237,165 @@ async def generate_topics( return {"status": "success", "topics": topics} except LLMError as e: - raise HTTPException(status_code=502, detail=str(e)) \ No newline at end of file + raise HTTPException(status_code=502, detail=str(e)) + + +# ==================== 母题库接口 ==================== + +class TopicGenerateRequest(BaseModel): + """母题生成请求""" + params: dict # 母题模板参数 + platform: str = "通用" + style: str = "专业严谨" + + +@router.get("/topics") +async def list_topics(): + """获取所有母题库列表""" + from app.services.content.topic_templates import list_topic_templates + templates = list_topic_templates() + return [ + { + "id": t.id, + "name": t.name, + "description": t.description, + "icon": t.icon, + "recommended_platforms": t.recommended_platforms, + "word_count_range": list(t.word_count_range), + "required_params": t.required_params, + "optional_params": t.optional_params, + } + for t in templates + ] + + +@router.get("/topics/{topic_id}") +async def get_topic(topic_id: str): + """获取母题详情""" + from app.services.content.topic_templates import get_topic_template + template = get_topic_template(topic_id) + if not template: + raise HTTPException(status_code=404, detail="Topic not found") + return { + "id": template.id, + "name": template.name, + "description": template.description, + "icon": template.icon, + "prompt_template": template.prompt_template, + "seo_tips": template.seo_tips, + "recommended_platforms": template.recommended_platforms, + "word_count_range": list(template.word_count_range), + "required_params": template.required_params, + "optional_params": template.optional_params, + } + + +@router.post("/topics/{topic_id}/generate") +async def generate_with_topic( + topic_id: str, + request: TopicGenerateRequest, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """使用母题生成内容""" + from app.services.content.topic_templates import get_topic_template, render_topic_prompt + from app.services.llm import LLMError, LLMFactory + from app.agent_framework.prompts import DEAI_TEMPLATE, GEO_OPTIMIZER_TEMPLATE + + template = get_topic_template(topic_id) + if not template: + raise HTTPException(status_code=404, detail="Topic not found") + + # 验证必填参数 + for param in template.required_params: + if param not in request.params: + raise HTTPException( + status_code=400, + detail=f"Missing required parameter: {param}" + ) + + org_id = getattr(current_user, "organization_id", None) + if not org_id: + raise HTTPException(status_code=403, detail="用户未关联组织") + + try: + provider = LLMFactory.get_default() + + # 渲染Prompt + prompt = render_topic_prompt(topic_id, request.params) + + # 调用内容生成 + response = await provider.chat( + [{"role": "user", "content": prompt}], + temperature=0.7, + max_tokens=4000 + ) + content = response.content + + # 去AI化处理 + deai_variables = { + "original_content": content, + "target_style": "自然流畅", + "preserve_structure": "是", + } + messages = DEAI_TEMPLATE.render(deai_variables) + response = await provider.chat(messages, temperature=0.9, max_tokens=len(content) * 2) + content = response.content + + # GEO优化 + geo_variables = { + "original_content": content, + "target_keywords": request.params.get("keywords", ""), + "target_platform": request.platform, + "optimization_level": "moderate", + } + messages = GEO_OPTIMIZER_TEMPLATE.render(geo_variables) + response = await provider.chat(messages, temperature=0.5, max_tokens=len(content) * 2) + optimized = response.content + + # 存入数据库 + content_obj = Content( + organization_id=org_id, + title=request.params.get("product_name") or request.params.get("topic") or topic_id, + content_type="article", + body=optimized, + status="draft", + target_platforms=[request.platform], + keywords=[request.params.get("keywords", "")], + extra_metadata={ + "original_content": content, + "topic_id": topic_id, + "topic_name": template.name, + "brand_name": request.params.get("brand_name", ""), + "content_style": request.style, + }, + created_by=current_user.id, + current_version=1, + ) + db.add(content_obj) + await db.flush() + + version = ContentVersion( + content_id=content_obj.id, + version_number=1, + title=content_obj.title, + body=optimized, + change_summary="母题库自动生成", + created_by=current_user.id, + ) + db.add(version) + await db.commit() + await db.refresh(content_obj) + + return { + "topic_id": topic_id, + "content": content, + "optimized_content": optimized, + "content_id": str(content_obj.id), + "seo_tips": template.seo_tips, + } + + except LLMError as e: + raise HTTPException(status_code=502, detail=f"LLM调用失败: {str(e)}") + except Exception as e: + raise HTTPException(status_code=500, detail=f"内容生成异常: {str(e)}") \ No newline at end of file diff --git a/backend/app/api/diagnosis.py b/backend/app/api/diagnosis.py new file mode 100644 index 0000000..2a6af11 --- /dev/null +++ b/backend/app/api/diagnosis.py @@ -0,0 +1,150 @@ +"""诊断API端点 - 提供SEO和GEO诊断功能""" +import logging +import uuid +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_current_user +from app.database import get_db +from app.models.user import User +from app.models.brand import Brand +from app.services.seo_diagnosis import SEODiagnosisService +from app.services.geo_diagnosis import GEODiagnosisService, GEODiagnosisInput + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +@router.get("/seo/{brand_id}") +async def get_seo_diagnosis( + brand_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """ + 获取品牌的SEO诊断结果 + + 返回5维度SEO诊断: + - 技术SEO (25分) + - 页面SEO (20分) + - 内容质量 (20分) + - 外链分析 (15分) + - 用户体验 (20分) + """ + brand = await _get_brand_or_404(brand_id, current_user, db) + + try: + service = SEODiagnosisService() + result = service.diagnose() + + logger.info(f"SEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={result.overall_score}") + + return result.to_dict() + except Exception as e: + logger.error(f"SEO诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="SEO诊断服务异常,请稍后重试", + ) + + +@router.get("/geo/{brand_id}") +async def get_geo_diagnosis( + brand_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """ + 获取品牌的GEO诊断结果 + + 返回6维度GEO诊断: + - 内容可提取性 (20分) + - 实体清晰度 (15分) + - E-E-A-T信号 (20分) + - Schema标记 (15分) + - 主题权威 (15分) + - 引用就绪度 (15分) + """ + brand = await _get_brand_or_404(brand_id, current_user, db) + + try: + input_data = GEODiagnosisInput() + service = GEODiagnosisService() + result = service.diagnose(input_data) + + logger.info(f"GEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={result.overall_score}") + + return result.to_dict() + except Exception as e: + logger.error(f"GEO诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="GEO诊断服务异常,请稍后重试", + ) + + +@router.get("/combined/{brand_id}") +async def get_combined_diagnosis( + brand_id: uuid.UUID, + current_user: User = Depends(get_current_user), + db: AsyncSession = Depends(get_db), +): + """ + 获取品牌的综合诊断结果 + + 结合SEO和GEO诊断,返回综合评分和详细诊断结果 + """ + brand = await _get_brand_or_404(brand_id, current_user, db) + + try: + seo_service = SEODiagnosisService() + seo_result = seo_service.diagnose() + + geo_service = GEODiagnosisService() + geo_result = geo_service.diagnose(GEODiagnosisInput()) + + combined_score = round((seo_result.overall_score + geo_result.overall_score) / 2, 2) + + logger.info( + f"综合诊断完成: brand_id={brand_id}, brand={brand.name}, " + f"seo_score={seo_result.overall_score}, " + f"geo_score={geo_result.overall_score}, " + f"combined_score={combined_score}" + ) + + return { + "seo_score": seo_result.overall_score, + "geo_score": geo_result.overall_score, + "combined_score": combined_score, + "seo_diagnosis": seo_result.to_dict(), + "geo_diagnosis": geo_result.to_dict(), + } + except Exception as e: + logger.error(f"综合诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="综合诊断服务异常,请稍后重试", + ) + + +async def _get_brand_or_404( + brand_id: uuid.UUID, + current_user: User, + db: AsyncSession, +) -> Brand: + """获取品牌或抛出404异常""" + stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id) + result = await db.execute(stmt) + brand = result.scalar_one_or_none() + + if not brand: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="品牌不存在", + ) + + return brand diff --git a/backend/app/api/image.py b/backend/app/api/image.py new file mode 100644 index 0000000..05a24b7 --- /dev/null +++ b/backend/app/api/image.py @@ -0,0 +1,182 @@ +"""图片生成 API 路由""" + +from typing import Optional + +from fastapi import APIRouter, HTTPException, status +from pydantic import BaseModel, Field + +from app.services.image_generator import ( + IMAGE_STYLES, + LAYOUT_OPTIONS, + ImageGenerator, + ImageGenerationError, + PLATFORM_IMAGE_SPECS, +) + +router = APIRouter(prefix="/image", tags=["图片生成"]) + + +class GenerateCoverRequest(BaseModel): + """生成封面图请求""" + title: str = Field(..., description="文章标题") + platform: str = Field(..., description="目标平台") + image_type: str = Field(default="cover", description="图片类型: cover/inline") + style: str = Field(default="modern", description="风格选项") + layout: str = Field(default="centered", description="排版选项") + custom_prompt: Optional[str] = Field(default=None, description="自定义提示词") + + +class ImageResultResponse(BaseModel): + """图片生成结果响应""" + url: str + width: int + height: int + prompt: str + platform: str + task_id: str + + +class ImageSpecsResponse(BaseModel): + """平台图片规格响应""" + platform: str + specs: dict + + +class StyleOption(BaseModel): + """风格选项""" + value: str + name: str + + +class LayoutOption(BaseModel): + """排版选项""" + value: str + name: str + + +class ImageConfigResponse(BaseModel): + """图片生成配置响应""" + platforms: list[str] + styles: list[StyleOption] + layouts: list[LayoutOption] + + +@router.post("/generate-cover", response_model=ImageResultResponse) +async def generate_cover(request: GenerateCoverRequest): + """生成封面图 + + 基于阿里云百炼(万相-文生图V1)生成封面图,自动适配目标平台的尺寸规格。 + + Args: + request: 生成请求参数 + + Returns: + ImageResultResponse: 包含图片URL和元数据 + + Raises: + HTTPException: 当平台不支持或API调用失败时 + """ + # 验证平台是否支持 + supported_platforms = ImageGenerator.get_supported_platforms() + if request.platform not in supported_platforms: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"不支持的平台: {request.platform},支持的平台: {', '.join(supported_platforms)}", + ) + + # 验证风格选项 + if request.style not in IMAGE_STYLES: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"不支持的风格: {request.style},支持的风格: {', '.join(IMAGE_STYLES.keys())}", + ) + + # 验证排版选项 + if request.layout not in LAYOUT_OPTIONS: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"不支持的排版: {request.layout},支持的排版: {', '.join(LAYOUT_OPTIONS.keys())}", + ) + + # 验证图片类型 + if request.image_type not in ("cover", "inline"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="图片类型只支持: cover, inline", + ) + + generator = ImageGenerator() + + try: + result = await generator.generate_cover( + title=request.title, + platform=request.platform, + image_type=request.image_type, + style=request.style, + layout=request.layout, + custom_prompt=request.custom_prompt, + ) + + return ImageResultResponse( + url=result.url, + width=result.width, + height=result.height, + prompt=result.prompt, + platform=result.platform, + task_id=result.task_id, + ) + + except ImageGenerationError as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"图片生成失败: {str(e)}", + ) + + +@router.get("/platforms", response_model=list[str]) +async def get_supported_platforms(): + """获取支持的平台列表""" + return ImageGenerator.get_supported_platforms() + + +@router.get("/platforms/{platform}/specs", response_model=ImageSpecsResponse) +async def get_platform_specs(platform: str): + """获取指定平台的图片规格 + + Args: + platform: 平台标识 + + Returns: + ImageSpecsResponse: 平台图片规格 + + Raises: + HTTPException: 当平台不支持时 + """ + specs = ImageGenerator.get_platform_specs(platform) + if specs is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"不支持的平台: {platform}", + ) + + return ImageSpecsResponse(platform=platform, specs=specs) + + +@router.get("/config", response_model=ImageConfigResponse) +async def get_image_config(): + """获取图片生成配置(风格、排版选项等)""" + styles = [ + StyleOption(value=key, name=value["name"]) + for key, value in IMAGE_STYLES.items() + ] + + layouts = [ + LayoutOption(value=key, name=value["name"]) + for key, value in LAYOUT_OPTIONS.items() + ] + + return ImageConfigResponse( + platforms=ImageGenerator.get_supported_platforms(), + styles=styles, + layouts=layouts, + ) \ No newline at end of file diff --git a/backend/app/api/knowledge.py b/backend/app/api/knowledge.py index bceafd0..2361f80 100644 --- a/backend/app/api/knowledge.py +++ b/backend/app/api/knowledge.py @@ -29,10 +29,15 @@ from app.schemas.knowledge import ( KnowledgeBaseCreate, KnowledgeBaseResponse, KnowledgeSearchRequest, + RetrieveRequest, SearchResponse, SearchResultItem, + UpdateDocumentRequest, ) from app.services.knowledge import MockEmbedder, RAGService +from app.services.knowledge.enhanced_rag import EnhancedRAG +from app.services.knowledge.incremental_index import IncrementalIndexService +from app.services.knowledge.chunker import ChunkerFactory logger = logging.getLogger(__name__) router = APIRouter() @@ -499,3 +504,151 @@ async def knowledge_search( total=len(items), latency_ms=latency_ms, ) + + +@router.post("/bases/{kb_id}/chunks/preview") +async def preview_chunks( + kb_id: uuid.UUID, + text: str, + strategy: str = "recursive", + chunk_size: int = 500, +): + """预览分块效果""" + chunker = ChunkerFactory.create(strategy) + + # 临时修改chunk_size + if strategy == "recursive": + chunker.STRATEGY.chunk_size = chunk_size + elif strategy == "semantic": + chunker.STRATEGY.chunk_size = chunk_size * 1.5 # 语义块可以更大 + elif strategy == "fixed": + chunker.STRATEGY.chunk_size = chunk_size + + preview = chunker.preview(text, max_chunks=10) + + return { + "strategy": strategy, + "chunk_count": len(preview), + "preview": preview, + "strategies": [ + { + "name": s.name, + "description": s.description, + "recommended_size": s.chunk_size, + } + for s in ChunkerFactory.list_strategies() + ], + } + + +# --------------------------------------------------------------------------- +# 增量索引 API +# --------------------------------------------------------------------------- + +@router.post("/bases/{kb_id}/documents/{doc_id}/reindex") +async def reindex_document( + kb_id: uuid.UUID, + doc_id: uuid.UUID, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """重新索引单个文档""" + org_id = current_user.organization_id + if not org_id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found") + + await _get_kb(db, kb_id, org_id) + + index_service = IncrementalIndexService(_rag_service) + result = await index_service.add_document( + db, str(kb_id), str(doc_id) + ) + return result + + +@router.post("/bases/{kb_id}/documents/{doc_id}/update") +async def update_document_content( + kb_id: uuid.UUID, + doc_id: uuid.UUID, + request: UpdateDocumentRequest, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """更新文档内容(增量)""" + org_id = current_user.organization_id + if not org_id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found") + + await _get_kb(db, kb_id, org_id) + + index_service = IncrementalIndexService(_rag_service) + result = await index_service.update_document( + db, str(doc_id), request.content + ) + return result + + +@router.delete("/bases/{kb_id}/documents/{doc_id}") +async def delete_document_incremental( + kb_id: uuid.UUID, + doc_id: uuid.UUID, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """删除文档""" + org_id = current_user.organization_id + if not org_id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found") + + await _get_kb(db, kb_id, org_id) + + index_service = IncrementalIndexService(_rag_service) + result = await index_service.delete_document(db, str(doc_id)) + return result + + +@router.post("/bases/{kb_id}/rebuild") +async def rebuild_knowledge_base( + kb_id: uuid.UUID, + force: bool = False, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """重建知识库索引""" + org_id = current_user.organization_id + if not org_id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found") + + await _get_kb(db, kb_id, org_id) + + index_service = IncrementalIndexService(_rag_service) + result = await index_service.rebuild_knowledge_base( + db, str(kb_id), force + ) + return result + + +@router.post("/bases/{kb_id}/retrieve") +async def enhanced_retrieve( + kb_id: uuid.UUID, + request: RetrieveRequest, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """增强检索(支持重排序和压缩)""" + org_id = current_user.organization_id + if not org_id: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found") + + await _get_kb(db, kb_id, org_id) + + enhanced_rag = EnhancedRAG(_rag_service, _rag_service.embedder) + results = await enhanced_rag.retrieve_with_rerank( + db, + request.query, + [str(kb_id)], + top_k=request.top_k or 5, + use_rerank=request.use_rerank, + use_compression=request.use_compression, + ) + return {"results": results, "query": request.query} diff --git a/backend/app/api/knowledge_graph.py b/backend/app/api/knowledge_graph.py new file mode 100644 index 0000000..006585f --- /dev/null +++ b/backend/app/api/knowledge_graph.py @@ -0,0 +1,115 @@ +"""知识图谱API""" +from typing import Optional +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_db, get_current_user +from app.models.user import User +from app.services.knowledge.graph_builder import GraphBuilder +from app.services.knowledge.graph_query import GraphQuery + +router = APIRouter(prefix="/knowledge-bases", tags=["知识图谱"]) + + +@router.post("/{kb_id}/graph/build") +async def build_graph( + kb_id: UUID, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """ + 从知识库构建知识图谱 + + 对知识库中的所有Chunks执行实体和关系抽取 + """ + # TODO: 实现批量构建 + # 目前先实现单个Chunk的构建 + return {"message": "Use /graph/build-chunk to build from specific chunk"} + + +@router.post("/{kb_id}/graph/build-chunk/{chunk_id}") +async def build_graph_from_chunk( + kb_id: UUID, + chunk_id: UUID, + context: Optional[str] = None, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """从单个Chunk构建图谱""" + builder = GraphBuilder() + + try: + stats = await builder.build_from_chunk(db, str(chunk_id), context) + return { + "status": "success", + "stats": stats, + } + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@router.get("/{kb_id}/graph/statistics") +async def get_graph_statistics( + kb_id: UUID, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """获取图谱统计信息""" + query = GraphQuery() + stats = await query.get_statistics(db, str(kb_id)) + return stats + + +@router.get("/{kb_id}/graph/entities/search") +async def search_entities( + kb_id: UUID, + q: str, + entity_type: Optional[str] = None, + limit: int = 20, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """搜索实体""" + query = GraphQuery() + entities = await query.search_entities( + db, str(kb_id), q, entity_type, limit + ) + return {"entities": entities} + + +@router.get("/{kb_id}/graph/entities/{entity_id}") +async def get_entity( + kb_id: UUID, + entity_id: UUID, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """获取实体详情""" + query = GraphQuery() + entity = await query.get_entity(db, str(entity_id)) + + if not entity: + raise HTTPException(status_code=404, detail="Entity not found") + + # 获取邻居 + neighbors = await query.get_entity_neighbors(db, str(entity_id)) + entity["neighbors"] = neighbors + + return entity + + +@router.get("/{kb_id}/graph/path") +async def find_path( + kb_id: UUID, + source: str, + target: str, + max_hops: int = 3, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(get_current_user), +): + """查找两个实体之间的路径""" + query = GraphQuery() + path = await query.get_entity_path(db, source, target, max_hops) + return {"path": path, "hops": len(path)} diff --git a/backend/app/database.py b/backend/app/database.py index 0adc019..e5e28ff 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -1,9 +1,23 @@ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession from sqlalchemy.orm import declarative_base -from sqlalchemy import text +from sqlalchemy import text, JSON +from sqlalchemy.types import TypeDecorator +from sqlalchemy.dialects.postgresql import JSONB from app.config import settings + +class JSONType(TypeDecorator): + """A JSON type that uses JSONB on PostgreSQL and JSON on other databases.""" + impl = JSON + cache_ok = True + + def load_dialect_impl(self, dialect): + if dialect.name == "postgresql": + return dialect.type_descriptor(JSONB()) + return dialect.type_descriptor(JSON()) + + engine = create_async_engine( settings.DATABASE_URL, pool_size=10, # 连接池大小 diff --git a/backend/app/main.py b/backend/app/main.py index d6f9942..304bbe8 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -4,7 +4,8 @@ from datetime import datetime, timezone from fastapi import FastAPI, HTTPException, Request, Depends from fastapi.exceptions import RequestValidationError -from fastapi.responses import JSONResponse +from fastapi.responses import JSONResponse, Response +from prometheus_client import generate_latest, CONTENT_TYPE_LATEST from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import text @@ -31,9 +32,12 @@ from app.api.subscriptions import router as subscription_router from app.api.alerts import router as alerts_router from app.api.dashboard import router as dashboard_router from app.api.brands import router as brands_router +from app.api.diagnosis import router as diagnosis_router from app.api.onboarding import router as onboarding_router from app.api.platforms import router as platforms_router from app.api.platform_rules import router as platform_rules_router +from app.api.image import router as image_router +from app.api.knowledge_graph import router as knowledge_graph_router from app.config import settings from app.database import engine, Base from app.schemas.common import ErrorResponse, ErrorCode @@ -41,6 +45,7 @@ from app.middleware.rate_limit import RateLimitMiddleware from app.middleware.logging_middleware import RequestLoggingMiddleware from app.middleware.request_id import RequestIdMiddleware from app.middleware.metrics import MetricsMiddleware +from app.monitoring.middleware import MonitoringMiddleware from app.database import get_db from app.workers.scheduler import query_scheduler @@ -49,6 +54,9 @@ from app.workers.scheduler import query_scheduler async def lifespan(app: FastAPI): import app.models + # 初始化监控模块 + import app.monitoring + async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) @@ -131,6 +139,7 @@ async def add_security_headers(request, call_next): app.add_middleware(RequestLoggingMiddleware) app.add_middleware(RateLimitMiddleware) app.add_middleware(MetricsMiddleware) +app.add_middleware(MonitoringMiddleware) app.add_middleware(RequestIdMiddleware) app.include_router(auth_router, prefix="/api/v1/auth", tags=["认证"]) @@ -150,9 +159,12 @@ app.include_router(analytics_router, prefix="/api/v1/analytics", tags=["监测 app.include_router(alerts_router, prefix="/api/v1/alerts", tags=["告警通知"]) app.include_router(dashboard_router, prefix="/api/v1/dashboard", tags=["仪表盘"]) app.include_router(brands_router, prefix="/api/v1/brands", tags=["品牌管理"]) +app.include_router(diagnosis_router, prefix="/api/v1/diagnosis", tags=["诊断服务"]) app.include_router(onboarding_router, prefix="/api/v1") app.include_router(platforms_router, prefix="/api/v1") app.include_router(platform_rules_router) +app.include_router(image_router, prefix="/api/v1") +app.include_router(knowledge_graph_router, prefix="/api/v1/knowledge-bases") @app.get("/health", tags=["可观测性"]) @@ -203,3 +215,90 @@ async def readiness_check(db: AsyncSession = Depends(get_db)): "timestamp": datetime.now(timezone.utc).isoformat(), }, ) + + +@app.get("/metrics", tags=["可观测性"]) +async def metrics(): + """Prometheus指标端点""" + return Response( + content=generate_latest(), + media_type=CONTENT_TYPE_LATEST + ) + + +# ---- 详细健康检查端点 ---- +from app.services.health_checker import HealthChecker +from app.services.app_state import app_state + + +@app.get("/health/detailed", tags=["可观测性"]) +async def detailed_health( + db: AsyncSession = Depends(get_db), +): + """ + 详细健康检查 + + 返回所有依赖组件的健康状态: + - database: 数据库连接 + - redis: Redis缓存 + - llm_providers: LLM服务提供商 + - storage: 文件存储 + + 状态: + - healthy: 所有组件正常 + - degraded: 部分组件异常,但仍可服务 + - unhealthy: 核心组件异常 + """ + checker = HealthChecker(db, settings.REDIS_URL) + health_result = await checker.check_all() + + # 添加应用信息 + health_result["app"] = app_state.get_info() + + return health_result + + +@app.get("/health/liveness", tags=["可观测性"]) +async def liveness(): + """ + 存活探针 + + 用于Kubernetes livenessProbe + 只要应用进程存活就返回200 + """ + return {"status": "alive"} + + +@app.get("/health/readiness", tags=["可观测性"]) +async def readiness(db: AsyncSession = Depends(get_db)): + """ + 就绪探针 + + 用于Kubernetes readinessProbe + 检查核心依赖是否就绪 + """ + checker = HealthChecker(db, settings.REDIS_URL) + + # 只检查核心依赖:数据库和Redis + db_result = await checker.check_database() + redis_result = await checker.check_redis() + + if db_result.healthy and redis_result.healthy: + return { + "status": "ready", + "checks": { + "database": db_result.healthy, + "redis": redis_result.healthy, + } + } + else: + raise HTTPException( + status_code=503, + detail={ + "status": "not_ready", + "checks": { + "database": db_result.healthy, + "redis": redis_result.healthy, + } + } + ) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 392fa78..e4395be 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -15,6 +15,12 @@ from app.models.knowledge import ( KnowledgeChunk, KnowledgeSearchLog, ) +from app.models.knowledge_graph import ( + KnowledgeEntity, + KnowledgeRelation, + EntityType, + RelationType, +) from app.models.analytics import PublishRecord, ContentMetrics, OptimizationInsight from app.models.distribution import DistributionSchedule # 缺失的模型导入 - 重构后遗留 @@ -48,6 +54,10 @@ __all__ = [ "KnowledgeDocument", "KnowledgeChunk", "KnowledgeSearchLog", + "KnowledgeEntity", + "KnowledgeRelation", + "EntityType", + "RelationType", "PublishRecord", "ContentMetrics", "OptimizationInsight", diff --git a/backend/app/models/agent.py b/backend/app/models/agent.py index 41f8dd9..90fcacb 100644 --- a/backend/app/models/agent.py +++ b/backend/app/models/agent.py @@ -3,10 +3,9 @@ from datetime import datetime from sqlalchemy import String, Integer, ForeignKey, Index, func, Text from sqlalchemy import Uuid -from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column, relationship -from app.database import Base +from app.database import Base, JSONType class AgentRegistry(Base): @@ -24,7 +23,7 @@ class AgentRegistry(Base): version: Mapped[str | None] = mapped_column(String(20), nullable=True) endpoint: Mapped[str | None] = mapped_column(String(500), nullable=True) status: Mapped[str] = mapped_column(String(20), server_default="offline", nullable=False) - capabilities: Mapped[dict | None] = mapped_column(JSONB, nullable=True) + capabilities: Mapped[dict | None] = mapped_column(JSONType, nullable=True) last_heartbeat: Mapped[datetime | None] = mapped_column(nullable=True) created_at: Mapped[datetime] = mapped_column( server_default=func.now(), @@ -68,7 +67,7 @@ class AgentConfig(Base): nullable=False, ) config_key: Mapped[str] = mapped_column(String(100), nullable=False) - config_value: Mapped[dict] = mapped_column(JSONB, nullable=False) + config_value: Mapped[dict] = mapped_column(JSONType, nullable=False) description: Mapped[str | None] = mapped_column(String(500), nullable=True) updated_at: Mapped[datetime] = mapped_column( server_default=func.now(), @@ -111,8 +110,8 @@ class AgentTask(Base): task_type: Mapped[str] = mapped_column(String(50), nullable=False) status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False) priority: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False) - input_data: Mapped[dict | None] = mapped_column(JSONB, nullable=True) - output_data: Mapped[dict | None] = mapped_column(JSONB, nullable=True) + input_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True) + output_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True) error_message: Mapped[str | None] = mapped_column(Text, nullable=True) created_by: Mapped[uuid.UUID | None] = mapped_column( Uuid(as_uuid=True), @@ -184,7 +183,7 @@ class AgentTaskLog(Base): ) log_level: Mapped[str] = mapped_column(String(10), nullable=False) message: Mapped[str] = mapped_column(Text, nullable=False) - extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True) + extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONType, nullable=True) created_at: Mapped[datetime] = mapped_column( server_default=func.now(), nullable=False, diff --git a/backend/app/models/brand_knowledge.py b/backend/app/models/brand_knowledge.py index 21fc9da..be988e7 100644 --- a/backend/app/models/brand_knowledge.py +++ b/backend/app/models/brand_knowledge.py @@ -3,10 +3,9 @@ from datetime import datetime from sqlalchemy import String, Integer, Boolean, ForeignKey, Index, func, Text from sqlalchemy import Uuid -from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column, relationship -from app.database import Base +from app.database import Base, JSONType class BrandKnowledge(Base): @@ -25,7 +24,7 @@ class BrandKnowledge(Base): category: Mapped[str] = mapped_column(String(50), nullable=False) title: Mapped[str] = mapped_column(String(200), nullable=False) content: Mapped[str] = mapped_column(Text, nullable=False) - extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True) + extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONType, nullable=True) is_active: Mapped[bool] = mapped_column(Boolean, server_default="true", nullable=False) created_by: Mapped[uuid.UUID | None] = mapped_column( Uuid(as_uuid=True), diff --git a/backend/app/models/content.py b/backend/app/models/content.py index efd88ca..2fc4b8e 100644 --- a/backend/app/models/content.py +++ b/backend/app/models/content.py @@ -3,10 +3,9 @@ from datetime import datetime from sqlalchemy import String, Integer, ForeignKey, Index, func, Text from sqlalchemy import Uuid -from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column, relationship -from app.database import Base +from app.database import Base, JSONType class Content(Base): @@ -31,9 +30,9 @@ class Content(Base): content_type: Mapped[str] = mapped_column(String(50), nullable=False) body: Mapped[str | None] = mapped_column(Text, nullable=True) status: Mapped[str] = mapped_column(String(20), server_default="draft", nullable=False) - target_platforms: Mapped[list | None] = mapped_column(JSONB, nullable=True) - keywords: Mapped[list | None] = mapped_column(JSONB, nullable=True) - extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True) + target_platforms: Mapped[list | None] = mapped_column(JSONType, nullable=True) + keywords: Mapped[list | None] = mapped_column(JSONType, nullable=True) + extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONType, nullable=True) created_by: Mapped[uuid.UUID | None] = mapped_column( Uuid(as_uuid=True), ForeignKey("users.id", ondelete="SET NULL"), diff --git a/backend/app/models/distribution.py b/backend/app/models/distribution.py index 9d1b297..3aa7c06 100644 --- a/backend/app/models/distribution.py +++ b/backend/app/models/distribution.py @@ -4,10 +4,9 @@ from datetime import datetime from sqlalchemy import String, Integer, ForeignKey, Index, func, Text from sqlalchemy import Uuid -from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column, relationship -from app.database import Base +from app.database import Base, JSONType class DistributionSchedule(Base): @@ -29,9 +28,9 @@ class DistributionSchedule(Base): ForeignKey("contents.id", ondelete="SET NULL"), nullable=True, ) - platforms: Mapped[list | None] = mapped_column(JSONB, nullable=True) + platforms: Mapped[list | None] = mapped_column(JSONType, nullable=True) """[{platform, platform_name, scheduled_time, status}]""" - tips: Mapped[list | None] = mapped_column(JSONB, nullable=True) + tips: Mapped[list | None] = mapped_column(JSONType, nullable=True) status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False) created_by: Mapped[uuid.UUID | None] = mapped_column( Uuid(as_uuid=True), diff --git a/backend/app/models/knowledge.py b/backend/app/models/knowledge.py index 90d2481..dbb0e83 100644 --- a/backend/app/models/knowledge.py +++ b/backend/app/models/knowledge.py @@ -3,10 +3,9 @@ from datetime import datetime from sqlalchemy import String, Integer, ForeignKey, Index, func, Text from sqlalchemy import Uuid -from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column, relationship -from app.database import Base +from app.database import Base, JSONType # pgvector Vector type - imported conditionally try: @@ -92,7 +91,7 @@ class KnowledgeDocument(Base): status: Mapped[str] = mapped_column(String(20), server_default="processing", nullable=False) # "processing" / "ready" / "failed" error_message: Mapped[str | None] = mapped_column(Text, nullable=True) # mapped_column("metadata") to avoid SQLAlchemy reserved keyword conflict - extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True) + extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONType, nullable=True) created_at: Mapped[datetime] = mapped_column( server_default=func.now(), nullable=False, @@ -152,7 +151,7 @@ class KnowledgeChunk(Base): chunk_index: Mapped[int] = mapped_column(Integer, nullable=False) token_count: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False) # mapped_column("metadata") to avoid SQLAlchemy reserved keyword conflict - extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONB, nullable=True) + extra_metadata: Mapped[dict | None] = mapped_column("metadata", JSONType, nullable=True) created_at: Mapped[datetime] = mapped_column( server_default=func.now(), nullable=False, @@ -189,7 +188,7 @@ class KnowledgeSearchLog(Base): nullable=True, ) query: Mapped[str] = mapped_column(Text, nullable=False) - knowledge_base_ids: Mapped[list | None] = mapped_column(JSONB, nullable=True) + knowledge_base_ids: Mapped[list | None] = mapped_column(JSONType, nullable=True) results_count: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False) latency_ms: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False) created_at: Mapped[datetime] = mapped_column( diff --git a/backend/app/models/knowledge_graph.py b/backend/app/models/knowledge_graph.py new file mode 100644 index 0000000..fb8550d --- /dev/null +++ b/backend/app/models/knowledge_graph.py @@ -0,0 +1,175 @@ +"""知识图谱数据模型""" +import uuid +from datetime import datetime +import enum + +from sqlalchemy import String, Text, DateTime, ForeignKey, Index, Enum, func +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.database import Base, JSONType + + +class EntityType(str, enum.Enum): + """实体类型""" + ORGANIZATION = "ORGANIZATION" # 组织/公司 + PRODUCT = "PRODUCT" # 产品 + PERSON = "PERSON" # 人物 + LOCATION = "LOCATION" # 地点 + TECHNOLOGY = "TECHNOLOGY" # 技术 + BRAND = "BRAND" # 品牌 + EVENT = "EVENT" # 事件 + CONCEPT = "CONCEPT" # 概念 + OTHER = "OTHER" # 其他 + + +class RelationType(str, enum.Enum): + """关系类型""" + # 组织关系 + COMPETES_WITH = "COMPETES_WITH" # 竞争对手 + PARTNERS_WITH = "PARTNERS_WITH" # 合作伙伴 + ACQUIRES = "ACQUIRES" # 收购 + SUBSIDIARY_OF = "SUBSIDIARY_OF" # 子公司 + + # 产品关系 + PRODUCES = "PRODUCES" # 生产 + USES_TECHNOLOGY = "USES_TECHNOLOGY" # 使用技术 + PART_OF = "PART_OF" # 属于(产品线) + + # 地点关系 + LOCATED_IN = "LOCATED_IN" # 位于 + FOUNDED_IN = "FOUNDED_IN" # 成立于 + + # 人物关系 + CEO_OF = "CEO_OF" # CEO + FOUNDER_OF = "FOUNDER_OF" # 创始人 + + # 通用关系 + RELATED_TO = "RELATED_TO" # 相关 + MENTIONED_IN = "MENTIONED_IN" # 提及于 + ALSO_KNOWN_AS = "ALSO_KNOWN_AS" # 又名 + + +class KnowledgeEntity(Base): + """知识实体""" + __tablename__ = "knowledge_entities" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + knowledge_base_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("knowledge_bases.id", ondelete="CASCADE"), + nullable=False, + ) + + # 实体信息 + name: Mapped[str] = mapped_column(String(500), nullable=False, index=True) + entity_type: Mapped[EntityType] = mapped_column(Enum(EntityType), nullable=False, index=True) + description: Mapped[str | None] = mapped_column(Text, nullable=True) + + # 扩展属性(JSON) + properties: Mapped[dict | None] = mapped_column(JSONType, default=dict) + + # 来源信息 + source_chunk_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), + ForeignKey("knowledge_chunks.id", ondelete="SET NULL"), + nullable=True, + ) + confidence: Mapped[str | None] = mapped_column(String(20), nullable=True) # 置信度描述:high/medium/low + + # 元数据 + created_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) + + # 索引 + __table_args__ = ( + Index("ix_entities_kb_name", "knowledge_base_id", "name"), + Index("ix_entities_kb_type", "knowledge_base_id", "entity_type"), + ) + + # 关系 + outgoing_relations: Mapped[list["KnowledgeRelation"]] = relationship( + "KnowledgeRelation", + foreign_keys="KnowledgeRelation.source_entity_id", + back_populates="source_entity", + cascade="all, delete-orphan", + ) + incoming_relations: Mapped[list["KnowledgeRelation"]] = relationship( + "KnowledgeRelation", + foreign_keys="KnowledgeRelation.target_entity_id", + back_populates="target_entity", + cascade="all, delete-orphan", + ) + + +class KnowledgeRelation(Base): + """知识关系""" + __tablename__ = "knowledge_relations" + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + primary_key=True, + default=uuid.uuid4, + ) + + # 关系两端 + source_entity_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("knowledge_entities.id", ondelete="CASCADE"), + nullable=False, + ) + target_entity_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("knowledge_entities.id", ondelete="CASCADE"), + nullable=False, + ) + + # 关系信息 + relation_type: Mapped[RelationType] = mapped_column(Enum(RelationType), nullable=False, index=True) + + # 扩展属性 + properties: Mapped[dict | None] = mapped_column(JSONType, default=dict) + + # 来源信息 + source_chunk_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), + ForeignKey("knowledge_chunks.id", ondelete="SET NULL"), + nullable=True, + ) + confidence: Mapped[str | None] = mapped_column(String(20), nullable=True) + + # 元数据 + created_at: Mapped[datetime] = mapped_column( + server_default=func.now(), + nullable=False, + ) + + # 关系 + source_entity: Mapped["KnowledgeEntity"] = relationship( + "KnowledgeEntity", + foreign_keys=[source_entity_id], + back_populates="outgoing_relations", + ) + target_entity: Mapped["KnowledgeEntity"] = relationship( + "KnowledgeEntity", + foreign_keys=[target_entity_id], + back_populates="incoming_relations", + ) + + # 索引 + __table_args__ = ( + Index("ix_relations_source", "source_entity_id"), + Index("ix_relations_target", "target_entity_id"), + Index("ix_relations_type", "relation_type"), + ) \ No newline at end of file diff --git a/backend/app/models/lifecycle.py b/backend/app/models/lifecycle.py index 75718a9..041c488 100644 --- a/backend/app/models/lifecycle.py +++ b/backend/app/models/lifecycle.py @@ -3,10 +3,9 @@ from datetime import datetime from sqlalchemy import String, Integer, ForeignKey, Index, func, Text from sqlalchemy import Uuid -from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column, relationship -from app.database import Base +from app.database import Base, JSONType class LifecycleProject(Base): @@ -23,7 +22,7 @@ class LifecycleProject(Base): nullable=False, ) brand_name: Mapped[str] = mapped_column(String(100), nullable=False) - brand_aliases: Mapped[list] = mapped_column(JSONB, server_default="[]", nullable=False) + brand_aliases: Mapped[list] = mapped_column(JSONType, server_default="[]", nullable=False) current_stage: Mapped[int] = mapped_column(Integer, server_default="1", nullable=False) status: Mapped[str] = mapped_column(String(20), server_default="active", nullable=False) created_by: Mapped[uuid.UUID] = mapped_column( @@ -77,7 +76,7 @@ class ProjectStage(Base): started_at: Mapped[datetime | None] = mapped_column(nullable=True) completed_at: Mapped[datetime | None] = mapped_column(nullable=True) notes: Mapped[str | None] = mapped_column(Text, nullable=True) - metrics: Mapped[dict | None] = mapped_column(JSONB, nullable=True) + metrics: Mapped[dict | None] = mapped_column(JSONType, nullable=True) # Relationships project: Mapped["LifecycleProject"] = relationship( diff --git a/backend/app/models/platform_rule.py b/backend/app/models/platform_rule.py index 925ef6d..f68bdac 100644 --- a/backend/app/models/platform_rule.py +++ b/backend/app/models/platform_rule.py @@ -3,10 +3,9 @@ from datetime import datetime from sqlalchemy import String, Boolean, ForeignKey, Index, func, Text from sqlalchemy import Uuid -from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column, relationship -from app.database import Base +from app.database import Base, JSONType class PlatformRule(Base): @@ -21,7 +20,7 @@ class PlatformRule(Base): rule_category: Mapped[str] = mapped_column(String(50), nullable=False) rule_name: Mapped[str] = mapped_column(String(200), nullable=False) description: Mapped[str | None] = mapped_column(Text, nullable=True) - check_criteria: Mapped[dict | None] = mapped_column(JSONB, nullable=True) + check_criteria: Mapped[dict | None] = mapped_column(JSONType, nullable=True) severity: Mapped[str] = mapped_column(String(20), nullable=False) is_active: Mapped[bool] = mapped_column(Boolean, server_default="true", nullable=False) updated_at: Mapped[datetime] = mapped_column( diff --git a/backend/app/monitoring/__init__.py b/backend/app/monitoring/__init__.py new file mode 100644 index 0000000..79b529b --- /dev/null +++ b/backend/app/monitoring/__init__.py @@ -0,0 +1,13 @@ +"""监控模块""" +import os + +from app.monitoring.metrics import * +from app.monitoring.middleware import MonitoringMiddleware +from app.monitoring.agent_hooks import agent_execution_context, record_agent_execution +from app.monitoring.llm_metrics import get_llm_metrics, LLMMetricsWrapper + +# 设置服务信息 +SERVICE_INFO.info({ + "version": "1.0.0", + "environment": os.getenv("ENVIRONMENT", "development"), +}) diff --git a/backend/app/monitoring/agent_hooks.py b/backend/app/monitoring/agent_hooks.py new file mode 100644 index 0000000..b86302d --- /dev/null +++ b/backend/app/monitoring/agent_hooks.py @@ -0,0 +1,57 @@ +"""Agent执行指标钩子""" +import time +from contextlib import asynccontextmanager +from typing import Optional + +from app.monitoring.metrics import ( + AGENT_EXECUTIONS_TOTAL, + AGENT_EXECUTION_DURATION_SECONDS, + AGENT_RUNNING_TASKS, +) + + +@asynccontextmanager +async def agent_execution_context(agent_name: str): + """Agent执行上下文管理器 - 自动记录指标""" + # 增加运行任务计数 + AGENT_RUNNING_TASKS.labels(agent_name=agent_name).inc() + + start_time = time.perf_counter() + status = "success" + + try: + yield + except Exception as e: + status = "failure" + raise + finally: + # 记录执行时间和状态 + duration = time.perf_counter() - start_time + + AGENT_EXECUTIONS_TOTAL.labels( + agent_name=agent_name, + status=status + ).inc() + + AGENT_EXECUTION_DURATION_SECONDS.labels( + agent_name=agent_name + ).observe(duration) + + # 减少运行任务计数 + AGENT_RUNNING_TASKS.labels(agent_name=agent_name).dec() + + +def record_agent_execution( + agent_name: str, + status: str, + duration: float +): + """手动记录Agent执行指标""" + AGENT_EXECUTIONS_TOTAL.labels( + agent_name=agent_name, + status=status + ).inc() + + AGENT_EXECUTION_DURATION_SECONDS.labels( + agent_name=agent_name + ).observe(duration) diff --git a/backend/app/monitoring/llm_metrics.py b/backend/app/monitoring/llm_metrics.py new file mode 100644 index 0000000..85de516 --- /dev/null +++ b/backend/app/monitoring/llm_metrics.py @@ -0,0 +1,102 @@ +"""LLM调用指标包装""" +import time +from typing import Optional + +from app.monitoring.metrics import ( + LLM_REQUESTS_TOTAL, + LLM_REQUEST_DURATION_SECONDS, + LLM_TOKENS_TOTAL, + LLM_COST_ESTIMATED, +) + +# LLM成本估算(USD/token) +LLM_COST_PER_TOKEN = { + # OpenAI + ("openai", "gpt-4o"): {"prompt": 0.000005, "completion": 0.000015}, + ("openai", "gpt-4o-mini"): {"prompt": 0.00000015, "completion": 0.0000006}, + ("openai", "gpt-4-turbo"): {"prompt": 0.00001, "completion": 0.00003}, + # DeepSeek + ("deepseek", "deepseek-chat"): {"prompt": 0.00000014, "completion": 0.00000028}, + ("deepseek", "deepseek-coder"): {"prompt": 0.00000014, "completion": 0.00000028}, +} + + +class LLMMetricsWrapper: + """LLM调用指标包装器""" + + def __init__(self, provider: str, model: str): + self.provider = provider + self.model = model + + def record_request( + self, + status: str, + duration: float, + prompt_tokens: Optional[int] = None, + completion_tokens: Optional[int] = None, + ): + """记录LLM请求指标""" + # 记录请求数和耗时 + LLM_REQUESTS_TOTAL.labels( + provider=self.provider, + model=self.model, + status=status + ).inc() + + LLM_REQUEST_DURATION_SECONDS.labels( + provider=self.provider, + model=self.model + ).observe(duration) + + # 记录Token消耗 + if prompt_tokens is not None: + LLM_TOKENS_TOTAL.labels( + provider=self.provider, + model=self.model, + token_type="prompt" + ).inc(prompt_tokens) + + if completion_tokens is not None: + LLM_TOKENS_TOTAL.labels( + provider=self.provider, + model=self.model, + token_type="completion" + ).inc(completion_tokens) + + # 估算成本 + cost = self._estimate_cost(prompt_tokens, completion_tokens) + if cost > 0: + LLM_COST_ESTIMATED.labels( + provider=self.provider, + model=self.model + ).inc(cost) + + def _estimate_cost( + self, + prompt_tokens: Optional[int], + completion_tokens: Optional[int] + ) -> float: + """估算请求成本""" + cost_info = LLM_COST_PER_TOKEN.get((self.provider, self.model)) + if not cost_info: + return 0.0 + + total = 0.0 + if prompt_tokens: + total += prompt_tokens * cost_info["prompt"] + if completion_tokens: + total += completion_tokens * cost_info["completion"] + + return total + + +# 全局LLM指标记录器 +_llm_metrics_cache: dict[str, LLMMetricsWrapper] = {} + + +def get_llm_metrics(provider: str, model: str) -> LLMMetricsWrapper: + """获取LLM指标包装器(带缓存)""" + key = f"{provider}:{model}" + if key not in _llm_metrics_cache: + _llm_metrics_cache[key] = LLMMetricsWrapper(provider, model) + return _llm_metrics_cache[key] diff --git a/backend/app/monitoring/metrics.py b/backend/app/monitoring/metrics.py new file mode 100644 index 0000000..b0b69c0 --- /dev/null +++ b/backend/app/monitoring/metrics.py @@ -0,0 +1,97 @@ +"""Prometheus指标定义""" +from prometheus_client import Counter, Histogram, Gauge, Info + +# API层指标 +API_REQUESTS_TOTAL = Counter( + "geo_api_requests_total", + "Total API requests", + ["method", "endpoint", "status"] +) + +API_REQUEST_DURATION_SECONDS = Histogram( + "geo_api_request_duration_seconds", + "API request duration in seconds", + ["method", "endpoint"], + buckets=(0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0) +) + +API_REQUESTS_IN_PROGRESS = Gauge( + "geo_api_requests_in_progress", + "Number of requests currently being processed", + ["method", "endpoint"] +) + +# Agent层指标 +AGENT_EXECUTIONS_TOTAL = Counter( + "geo_agent_executions_total", + "Total agent executions", + ["agent_name", "status"] +) + +AGENT_EXECUTION_DURATION_SECONDS = Histogram( + "geo_agent_execution_duration_seconds", + "Agent execution duration in seconds", + ["agent_name"], + buckets=(0.1, 0.5, 1.0, 5.0, 10.0, 30.0, 60.0, 120.0) +) + +AGENT_RUNNING_TASKS = Gauge( + "geo_agent_running_tasks", + "Number of tasks currently running", + ["agent_name"] +) + +# LLM层指标 +LLM_REQUESTS_TOTAL = Counter( + "geo_llm_requests_total", + "Total LLM requests", + ["provider", "model", "status"] +) + +LLM_REQUEST_DURATION_SECONDS = Histogram( + "geo_llm_request_duration_seconds", + "LLM request duration in seconds", + ["provider", "model"], + buckets=(0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0) +) + +LLM_TOKENS_TOTAL = Counter( + "geo_llm_tokens_total", + "Total LLM tokens used", + ["provider", "model", "token_type"] +) + +LLM_COST_ESTIMATED = Gauge( + "geo_llm_cost_estimated", + "Estimated LLM cost in USD", + ["provider", "model"] +) + +# 业务层指标 +BRAND_COUNT = Gauge( + "geo_brands_total", + "Total number of brands" +) + +QUERY_COUNT_TOTAL = Counter( + "geo_queries_total", + "Total number of queries executed", + ["platform", "status"] +) + +CONTENT_GENERATED_TOTAL = Counter( + "geo_content_generated_total", + "Total content generated" +) + +CITATION_DETECTED_TOTAL = Counter( + "geo_citations_detected_total", + "Total citations detected", + ["platform"] +) + +# 系统信息 +SERVICE_INFO = Info( + "geo_service", + "GEO service information" +) diff --git a/backend/app/monitoring/middleware.py b/backend/app/monitoring/middleware.py new file mode 100644 index 0000000..eec355b --- /dev/null +++ b/backend/app/monitoring/middleware.py @@ -0,0 +1,86 @@ +"""监控中间件""" +import time +from typing import Callable + +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware + +from app.monitoring.metrics import ( + API_REQUESTS_TOTAL, + API_REQUEST_DURATION_SECONDS, + API_REQUESTS_IN_PROGRESS, +) + +# 需要排除的路径(不记录指标) +EXCLUDED_PATHS = {"/health", "/ready", "/metrics", "/docs", "/openapi.json"} + + +class MonitoringMiddleware(BaseHTTPMiddleware): + """API监控中间件""" + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + # 跳过排除路径 + if request.url.path in EXCLUDED_PATHS: + return await call_next(request) + + # 提取端点标识(用于指标标签) + endpoint = self._get_endpoint_label(request) + + # 增加活跃请求计数 + API_REQUESTS_IN_PROGRESS.labels( + method=request.method, + endpoint=endpoint + ).inc() + + # 记录开始时间 + start_time = time.perf_counter() + + try: + # 执行请求 + response = await call_next(request) + status_code = response.status_code + except Exception as e: + status_code = 500 + raise + finally: + # 计算耗时 + duration = time.perf_counter() - start_time + + # 记录指标 + API_REQUESTS_TOTAL.labels( + method=request.method, + endpoint=endpoint, + status=str(status_code) + ).inc() + + API_REQUEST_DURATION_SECONDS.labels( + method=request.method, + endpoint=endpoint + ).observe(duration) + + # 减少活跃请求计数 + API_REQUESTS_IN_PROGRESS.labels( + method=request.method, + endpoint=endpoint + ).dec() + + return response + + def _get_endpoint_label(self, request: Request) -> str: + """提取端点标签""" + path = request.url.path + + # 规范化路径(替换ID等参数) + parts = path.strip("/").split("/") + + # 处理常见模式:/api/v1/resources/{id} + if len(parts) >= 4 and parts[0] == "api": + resource = parts[2] if len(parts) > 2 else "unknown" + action = parts[3] if len(parts) > 3 else "list" + + # 映射到规范标签 + if action.isdigit(): + return f"{resource}_detail" + return f"{resource}_{action}" + + return "other" diff --git a/backend/app/schemas/knowledge.py b/backend/app/schemas/knowledge.py index 47883a8..00f8f01 100644 --- a/backend/app/schemas/knowledge.py +++ b/backend/app/schemas/knowledge.py @@ -74,3 +74,27 @@ class ChunkPreview(BaseModel): token_count: int model_config = {"from_attributes": True} + + +# ---------- 增量索引 Schemas ---------- + +class UpdateDocumentRequest(BaseModel): + """更新文档内容请求""" + content: str + + +class RetrieveRequest(BaseModel): + """增强检索请求""" + query: str = Field(..., min_length=1) + top_k: Optional[int] = Field(default=5, ge=1, le=50) + use_rerank: Optional[bool] = Field(default=True) + use_compression: Optional[bool] = Field(default=False) + + +class RebuildResponse(BaseModel): + """重建索引响应""" + total: int + processed: int + skipped: int + failed: int + errors: list[dict] diff --git a/backend/app/services/app_state.py b/backend/app/services/app_state.py new file mode 100644 index 0000000..7769888 --- /dev/null +++ b/backend/app/services/app_state.py @@ -0,0 +1,54 @@ +"""应用状态管理""" +import os +import platform +import time + + +class AppState: + """应用状态""" + + def __init__(self): + self.start_time = time.time() + self.platform = platform.system() + self.python_version = platform.python_version() + self.version = os.getenv("APP_VERSION", "1.0.0") + self.environment = os.getenv("ENVIRONMENT", "development") + + def get_uptime_seconds(self) -> float: + """获取运行时间(秒)""" + return time.time() - self.start_time + + def get_uptime_formatted(self) -> str: + """获取格式化的运行时间""" + seconds = self.get_uptime_seconds() + + days = int(seconds // 86400) + hours = int((seconds % 86400) // 3600) + minutes = int((seconds % 3600) // 60) + secs = int(seconds % 60) + + parts = [] + if days > 0: + parts.append(f"{days}天") + if hours > 0: + parts.append(f"{hours}小时") + if minutes > 0: + parts.append(f"{minutes}分钟") + parts.append(f"{secs}秒") + + return "".join(parts) + + def get_info(self) -> dict: + """获取应用信息""" + return { + "version": self.version, + "environment": self.environment, + "platform": self.platform, + "python_version": self.python_version, + "uptime_seconds": round(self.get_uptime_seconds(), 2), + "uptime_formatted": self.get_uptime_formatted(), + } + + +# 全局实例 +app_state = AppState() diff --git a/backend/app/services/content/topic_templates.py b/backend/app/services/content/topic_templates.py new file mode 100644 index 0000000..966d34c --- /dev/null +++ b/backend/app/services/content/topic_templates.py @@ -0,0 +1,362 @@ +"""GEO内容8大母题库定义""" + +from dataclasses import dataclass +from typing import Optional + +@dataclass +class TopicTemplate: + """母题模板""" + id: str # product_comparison + name: str # 产品对比 + description: str # 描述 + icon: str # emoji图标 + prompt_template: str # Prompt模板 + seo_tips: list[str] # SEO技巧 + recommended_platforms: list[str] # 推荐平台 + word_count_range: tuple[int, int] # 推荐字数范围 + required_params: list[str] # 必填参数 + optional_params: list[str] # 可选参数 + +TOPIC_TEMPLATES = { + # 1. 产品对比 + "product_comparison": TopicTemplate( + id="product_comparison", + name="产品对比", + description="品牌与竞品的功能、价格、体验对比", + icon="⚖️", + prompt_template=""" +请基于以下信息生成一篇产品对比文章: + +【你的品牌】{brand_name} +【对比品牌】{competitor_name} +【对比维度】{comparison_dimensions} +【目标关键词】{keywords} +【内容风格】{content_style} + +要求: +1. 客观中立,不刻意贬低竞品 +2. 突出自身优势但有理有据 +3. 适合AI平台引用(结构清晰、数据支撑) +4. 包含对比表格或图表 +5. 字数:约{word_count}字 + +文章结构: +1. 引言:简要介绍两个产品 +2. 外观设计对比 +3. 功能特性对比 +4. 性价比对比 +5. 适用场景分析 +6. 结论:总结优劣,给出建议 +""", + seo_tips=[ + "标题包含品牌名+对比+关键词", + "使用对比表格增强可读性", + "自然嵌入目标关键词", + ], + recommended_platforms=["知乎", "百家号", "公众号"], + word_count_range=(1500, 3000), + required_params=["brand_name", "competitor_name", "comparison_dimensions", "keywords"], + optional_params=["content_style", "word_count"], + ), + + # 2. 使用指南 + "how_to_guide": TopicTemplate( + id="how_to_guide", + name="使用指南", + description="产品使用方法教程,帮助用户解决问题", + icon="📖", + prompt_template=""" +请基于以下信息生成一篇使用指南文章: + +【产品名称】{product_name} +【核心功能】{core_features} +【目标用户】{target_audience} +【目标关键词】{keywords} +【内容风格】{content_style} + +要求: +1. 步骤清晰,操作性强 +2. 图文并茂(描述需要的图片类型) +3. 适合搜索引擎收录 +4. 适合AI平台引用 +5. 字数:约{word_count}字 + +文章结构: +1. 引言:介绍产品价值 +2. 基础设置/准备工作 +3. 核心功能使用步骤(分点详述) +4. 进阶技巧/常见问题 +5. 总结与进阶学习建议 +""", + seo_tips=[ + "标题包含产品名+使用方法/教程", + "使用有序列表标注步骤", + "包含FAQ解答常见问题", + ], + recommended_platforms=["知乎", "小红书", "简书"], + word_count_range=(1000, 2500), + required_params=["product_name", "core_features", "keywords"], + optional_params=["target_audience", "content_style", "word_count"], + ), + + # 3. 行业趋势 + "industry_trends": TopicTemplate( + id="industry_trends", + name="行业趋势", + description="行业动态、发展方向和未来预测分析", + icon="📈", + prompt_template=""" +请基于以下信息生成一篇行业趋势分析文章: + +【行业名称】{industry_name} +【品牌视角】{brand_perspective} +【分析维度】{analysis_dimensions} +【目标关键词】{keywords} +【内容风格】{content_style} + +要求: +1. 数据支撑,引用权威来源 +2. 趋势分析有理有据 +3. 适合AI平台引用(观点鲜明、数据翔实) +4. 字数:约{word_count}字 + +文章结构: +1. 引言:行业现状概述 +2. 核心趋势分析(3-5个趋势) +3. 趋势背后的驱动因素 +4. 品牌如何把握趋势 +5. 展望与建议 +""", + seo_tips=[ + "标题包含行业名+趋势/预测", + "引用数据增加可信度", + "使用时间线展示演变", + ], + recommended_platforms=["知乎", "百家号", "公众号"], + word_count_range=(2000, 4000), + required_params=["industry_name", "brand_perspective", "keywords"], + optional_params=["analysis_dimensions", "content_style", "word_count"], + ), + + # 4. 专家观点 + "expert_opinion": TopicTemplate( + id="expert_opinion", + name="专家观点", + description="行业专家见解和专业分析", + icon="💡", + prompt_template=""" +请基于以下信息生成一篇专家观点文章: + +【主题】{topic} +【专家身份】{expert_identity} +【核心观点】{core_opinion} +【目标关键词】{keywords} +【内容风格】{content_style} + +要求: +1. 观点鲜明,论证充分 +2. 结合实际案例 +3. 适合AI平台引用 +4. 字数:约{word_count}字 + +文章结构: +1. 引言:抛出核心观点 +2. 观点详细阐述 +3. 案例支撑 +4. 对行业的影响 +5. 结论与建议 +""", + seo_tips=[ + "标题体现专家身份+观点", + "使用引用格式突出核心观点", + "增加互动性问题", + ], + recommended_platforms=["知乎", "公众号", "微博"], + word_count_range=(1500, 3000), + required_params=["topic", "expert_identity", "core_opinion", "keywords"], + optional_params=["content_style", "word_count"], + ), + + # 5. 案例研究 + "case_study": TopicTemplate( + id="case_study", + name="案例研究", + description="成功案例深度分析", + icon="🏆", + prompt_template=""" +请基于以下信息生成一篇案例研究文章: + +【案例名称】{case_name} +【企业背景】{company_background} +【核心成果】{core_results} +【成功要素】{success_factors} +【目标关键词】{keywords} +【内容风格】{content_style} + +要求: +1. 故事性强,有代入感 +2. 数据支撑成果 +3. 可复制的方法论 +4. 适合AI平台引用 +5. 字数:约{word_count}字 + +文章结构: +1. 案例背景介绍 +2. 面临的挑战 +3. 解决方案详述 +4. 成果数据展示 +5. 成功经验总结 +6. 可复制建议 +""", + seo_tips=[ + "标题包含行业+案例+成果", + "使用数据图表展示成果", + "突出可复制的关键步骤", + ], + recommended_platforms=["知乎", "百家号", "公众号"], + word_count_range=(2000, 4000), + required_params=["case_name", "company_background", "core_results", "keywords"], + optional_params=["success_factors", "content_style", "word_count"], + ), + + # 6. 数据报告 + "data_report": TopicTemplate( + id="data_report", + name="数据报告", + description="行业数据和分析报告", + icon="📊", + prompt_template=""" +请基于以下信息生成一篇数据报告文章: + +【报告主题】{report_topic} +【数据范围】{data_scope} +【核心发现】{core_findings} +【目标关键词】{keywords} +【内容风格】{content_style} + +要求: +1. 数据翔实,来源可靠 +2. 图表丰富,直观易懂 +3. 深度分析,洞察本质 +4. 适合AI平台引用 +5. 字数:约{word_count}字 + +文章结构: +1. 研究背景与方法 +2. 核心数据发现 +3. 深度分析 +4. 趋势解读 +5. 建议与展望 +""", + seo_tips=[ + "标题包含数据+报告+关键词", + "使用图表增强说服力", + "包含数据来源说明", + ], + recommended_platforms=["知乎", "百家号", "公众号"], + word_count_range=(3000, 5000), + required_params=["report_topic", "data_scope", "core_findings", "keywords"], + optional_params=["content_style", "word_count"], + ), + + # 7. 问题解答 + "faq": TopicTemplate( + id="faq", + name="问题解答", + description="常见问题解答", + icon="❓", + prompt_template=""" +请基于以下信息生成一篇FAQ文章: + +【主题】{topic} +【目标用户】{target_audience} +【问题列表】{questions} +【目标关键词】{keywords} +【内容风格】{content_style} + +要求: +1. 问题覆盖面广,击中痛点 +2. 回答简洁明了,实操性强 +3. 适合搜索引擎长尾词 +4. 适合AI平台引用 +5. 字数:约{word_count}字 + +文章结构: +1. 引言:为什么需要了解这些 +2. 核心问题解答(每个问题独立成段) +3. 延伸问题 +4. 总结:下一步建议 +""", + seo_tips=[ + "标题包含主题+常见问题/FAQ", + "每个问题使用H2标签", + "覆盖长尾搜索词", + ], + recommended_platforms=["知乎", "简书", "小程序"], + word_count_range=(1000, 2000), + required_params=["topic", "questions", "keywords"], + optional_params=["target_audience", "content_style", "word_count"], + ), + + # 8. 评测报告 + "review": TopicTemplate( + id="review", + name="评测报告", + description="产品深度评测:优缺点全面分析", + icon="🔬", + prompt_template=""" +请基于以下信息生成一篇评测报告: + +【产品名称】{product_name} +【评测维度】{review_dimensions} +【竞品参照】{competitor_reference} +【目标关键词】{keywords} +【内容风格】{content_style} + +要求: +1. 客观公正,优点不夸大 +2. 缺点实事求是 +3. 有具体数据和场景支撑 +4. 适合AI平台引用 +5. 字数:约{word_count}字 + +文章结构: +1. 评测背景与方法 +2. 外观设计评价 +3. 核心功能评测 +4. 性能测试数据 +5. 优缺点总结 +6. 适合人群分析 +7. 购买建议 +""", + seo_tips=[ + "标题包含产品名+评测/测评", + "使用评分表格直观展示", + "包含具体测试数据", + ], + recommended_platforms=["知乎", "小红书", "百家号"], + word_count_range=(2000, 4000), + required_params=["product_name", "review_dimensions", "keywords"], + optional_params=["competitor_reference", "content_style", "word_count"], + ), +} + +def get_topic_template(topic_id: str) -> Optional[TopicTemplate]: + """获取母题模板""" + return TOPIC_TEMPLATES.get(topic_id) + +def list_topic_templates() -> list[TopicTemplate]: + """列出所有母题模板""" + return list(TOPIC_TEMPLATES.values()) + +def render_topic_prompt(topic_id: str, params: dict) -> str: + """渲染母题Prompt""" + template = get_topic_template(topic_id) + if not template: + raise ValueError(f"Unknown topic: {topic_id}") + + # 设置默认参数 + params.setdefault("content_style", "专业严谨") + params.setdefault("word_count", 2000) + + return template.prompt_template.format(**params) \ No newline at end of file diff --git a/backend/app/services/distribution/platform_rules.py b/backend/app/services/distribution/platform_rules.py index 9a171c9..627bad7 100644 --- a/backend/app/services/distribution/platform_rules.py +++ b/backend/app/services/distribution/platform_rules.py @@ -106,6 +106,11 @@ PLATFORM_RULES: dict[str, dict] = { "不得使用AI水文(内容需有信息增量)", "专业背书内容需有据可查", ], + "image_rules": { + "cover": {"width": 690, "height": 280, "ratio": "2.5:1"}, + "inline": {"width": 500, "height": 375, "ratio": "4:3"}, + "max_size_mb": 5, + }, "seo_tips": [ "回答开头直接给出结论", "使用数据和案例支撑", @@ -186,6 +191,11 @@ PLATFORM_RULES: dict[str, dict] = { "正文不含外部链接(仅支持公众号链接和小程序)", "不得包含未经授权的商标/品牌名称", ], + "image_rules": { + "cover": {"width": 900, "height": 383, "ratio": "2.35:1"}, + "inline": {"width": 800, "height": 600, "ratio": "4:3"}, + "max_size_mb": 10, + }, "seo_tips": [ "首段包含核心关键词", "使用小标题分段(适配搜一搜)", @@ -262,6 +272,11 @@ PLATFORM_RULES: dict[str, dict] = { "正文需含至少1张配图", "不得搬运/洗稿", ], + "image_rules": { + "cover": {"width": 600, "height": 400, "ratio": "3:2"}, + "inline": {"width": 800, "height": 600, "ratio": "4:3"}, + "max_size_mb": 5, + }, "seo_tips": [ "标题包含百度搜索热词", "文章结构化(H2小标题)", @@ -341,6 +356,11 @@ PLATFORM_RULES: dict[str, dict] = { "首发原创优先推荐", "配图清晰不模糊", ], + "image_rules": { + "cover": {"width": 1024, "height": 678, "ratio": "1.5:1"}, + "inline": {"width": 800, "height": 600, "ratio": "4:3"}, + "max_size_mb": 5, + }, "seo_tips": [ "标题含核心关键词", "文章1500字以上推荐更高", @@ -417,6 +437,11 @@ PLATFORM_RULES: dict[str, dict] = { "话题标签有助于曝光", "配图有助于转发", ], + "image_rules": { + "cover": {"width": 980, "height": 560, "ratio": "1.75:1"}, + "inline": {"width": 800, "height": 600, "ratio": "4:3"}, + "max_size_mb": 5, + }, "seo_tips": [ "热门话题可增加曝光", "短句更易阅读", @@ -493,6 +518,11 @@ PLATFORM_RULES: dict[str, dict] = { "不得出现其他平台引流信息", "图片不含水印", ], + "image_rules": { + "cover": {"width": 1080, "height": 1080, "ratio": "1:1"}, + "inline": {"width": 1242, "height": 1660, "ratio": "3:4"}, + "max_size_mb": 10, + }, "seo_tips": [ "标题含数字更吸引点击", "正文用短句+emoji分段", @@ -572,6 +602,11 @@ PLATFORM_RULES: dict[str, dict] = { "封面和标题很重要", "互动有助于推荐", ], + "image_rules": { + "cover": {"width": 1920, "height": 1080, "ratio": "16:9"}, + "inline": {"width": 800, "height": 600, "ratio": "4:3"}, + "max_size_mb": 5, + }, "seo_tips": [ "标题包含关键词", "封面吸引人", @@ -646,6 +681,11 @@ PLATFORM_RULES: dict[str, dict] = { "文艺风格更受欢迎", "配图有助于阅读", ], + "image_rules": { + "cover": {"width": 800, "height": 600, "ratio": "4:3"}, + "inline": {"width": 800, "height": 600, "ratio": "4:3"}, + "max_size_mb": 5, + }, "seo_tips": [ "标题包含关键词", "合理使用专题", @@ -721,6 +761,11 @@ PLATFORM_RULES: dict[str, dict] = { "鼓励原创技术文章", "禁止低质量搬运", ], + "image_rules": { + "cover": {"width": 1024, "height": 768, "ratio": "4:3"}, + "inline": {"width": 800, "height": 600, "ratio": "4:3"}, + "max_size_mb": 5, + }, "seo_tips": [ "标题包含技术关键词", "代码块有助于阅读", @@ -796,6 +841,11 @@ PLATFORM_RULES: dict[str, dict] = { "话题标签2-5个", "文案简短有吸引力", ], + "image_rules": { + "cover": {"width": 1080, "height": 1920, "ratio": "9:16"}, + "inline": {"width": 1080, "height": 1920, "ratio": "9:16"}, + "max_size_mb": 10, + }, "seo_tips": [ "前3秒决定完播率", "标题含热点关键词", diff --git a/backend/app/services/email_service.py b/backend/app/services/email_service.py new file mode 100644 index 0000000..6bbb785 --- /dev/null +++ b/backend/app/services/email_service.py @@ -0,0 +1,382 @@ +""" +邮件通知服务 + +支持发送告警通知、额度预警等邮件。 + +功能: +- 邮件模板引擎: 变量替换渲染邮件内容 +- 邮件内容生成: 告警通知、额度预警邮件生成 +- 邮件发送: 支持真实SMTP和模拟模式 +- 邮件队列管理: 批量添加和发送 +- 错误处理和重试: 自动重试机制 +""" +from __future__ import annotations + +import logging +import re +import smtplib +import time +import uuid +from dataclasses import dataclass, field +from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText +from email.mime.base import MIMEBase +from email import encoders +from typing import Any + +logger = logging.getLogger(__name__) + + +EMAIL_TEMPLATES = { + "alert_notification": { + "subject": "[GEO平台] 告警通知:{alert_type}", + "body_html": """ +

告警通知

+

品牌:{brand_name}

+

告警类型:{alert_type}

+

严重程度:{severity}

+

详情:{description}

+

时间:{timestamp}

+ """, + "body_text": "告警通知 - 品牌:{brand_name}, 类型:{alert_type}, severity:{severity}" + }, + "quota_warning": { + "subject": "[GEO平台] 额度预警:{quota_type}", + "body_html": """ +

额度预警

+

您的{quota_type}使用量已达到{usage_percentage}%

+

已使用:{used} / 总额度:{limit}

+

建议操作:{recommended_action}

+ """, + "body_text": "额度预警 - {quota_type}使用量:{usage_percentage}%" + } +} + + +@dataclass +class EmailMessage: + to: str + subject: str + body_html: str + body_text: str + attachments: list[dict] = field(default_factory=list) + metadata: dict = field(default_factory=dict) + + +@dataclass +class EmailSendResult: + success: bool + message_id: str | None + error: str | None + retry_count: int + + +class EmailService: + """邮件通知服务 + + 提供邮件模板渲染、内容生成、发送(支持模拟模式)、队列管理等功能。 + """ + + def __init__( + self, + simulate_mode: bool = True, + smtp_host: str = "localhost", + smtp_port: int = 587, + smtp_user: str = "", + smtp_password: str = "", + max_retries: int = 3, + ): + self.simulate_mode = simulate_mode + self.smtp_host = smtp_host + self.smtp_port = smtp_port + self.smtp_user = smtp_user + self.smtp_password = smtp_password + self.max_retries = max_retries + self._queue: list[EmailMessage] = [] + + def validate_email(self, email: str) -> bool: + """验证邮箱地址格式 + + Args: + email: 邮箱地址 + + Returns: + 是否为有效邮箱地址 + """ + if not email: + return False + pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' + return bool(re.match(pattern, email)) + + def _safe_format(self, template: str, variables: dict[str, Any]) -> str: + """安全格式化模板,缺失变量时保留原占位符 + + Args: + template: 模板字符串 + variables: 变量字典 + + Returns: + 格式化后的字符串 + """ + try: + return template.format(**variables) + except KeyError: + result = template + for key, value in variables.items(): + result = result.replace("{" + key + "}", str(value)) + return result + + def render_template( + self, + template_name: str, + to: str, + variables: dict[str, Any], + ) -> EmailMessage: + """渲染邮件模板 + + Args: + template_name: 模板名称 + to: 收件人邮箱 + variables: 模板变量 + + Returns: + EmailMessage邮件消息对象 + + Raises: + ValueError: 模板不存在 + """ + if template_name not in EMAIL_TEMPLATES: + raise ValueError(f"模板不存在: {template_name}") + + template = EMAIL_TEMPLATES[template_name] + + subject = self._safe_format(template["subject"], variables) + body_html = self._safe_format(template["body_html"], variables) + body_text = self._safe_format(template["body_text"], variables) + + return EmailMessage( + to=to, + subject=subject, + body_html=body_html, + body_text=body_text, + metadata=variables, + ) + + def generate_alert_email( + self, + to: str, + alert_type: str, + brand_name: str, + severity: str, + description: str, + timestamp: str, + ) -> EmailMessage: + """生成告警通知邮件 + + Args: + to: 收件人邮箱 + alert_type: 告警类型 + brand_name: 品牌名称 + severity: 严重程度 + description: 告警详情 + timestamp: 时间戳 + + Returns: + EmailMessage邮件消息对象 + """ + variables = { + "alert_type": alert_type, + "brand_name": brand_name, + "severity": severity, + "description": description, + "timestamp": timestamp, + } + return self.render_template("alert_notification", to, variables) + + def generate_quota_warning_email( + self, + to: str, + quota_type: str, + usage_percentage: int, + used: int, + limit: int, + recommended_action: str, + ) -> EmailMessage: + """生成额度预警邮件 + + Args: + to: 收件人邮箱 + quota_type: 额度类型 + usage_percentage: 使用百分比 + used: 已使用量 + limit: 总额度 + recommended_action: 建议操作 + + Returns: + EmailMessage邮件消息对象 + """ + variables = { + "quota_type": quota_type, + "usage_percentage": usage_percentage, + "used": used, + "limit": limit, + "recommended_action": recommended_action, + } + return self.render_template("quota_warning", to, variables) + + def add_attachment(self, msg: EmailMessage, filename: str, content: bytes) -> None: + """添加附件到邮件 + + Args: + msg: 邮件消息对象 + filename: 附件文件名 + content: 附件内容 + """ + msg.attachments.append({ + "filename": filename, + "content": content, + }) + + def _create_mime_message(self, msg: EmailMessage) -> MIMEMultipart: + """创建MIME邮件对象 + + Args: + msg: EmailMessage对象 + + Returns: + MIMEMultipart MIME邮件对象 + """ + mime_msg = MIMEMultipart() + mime_msg["From"] = self.smtp_user + mime_msg["To"] = msg.to + mime_msg["Subject"] = msg.subject + + mime_msg.attach(MIMEText(msg.body_html, "html", "utf-8")) + mime_msg.attach(MIMEText(msg.body_text, "plain", "utf-8")) + + for attachment in msg.attachments: + part = MIMEBase("application", "octet-stream") + part.set_payload(attachment["content"]) + encoders.encode_base64(part) + part.add_header( + "Content-Disposition", + f'attachment; filename="{attachment["filename"]}"', + ) + mime_msg.attach(part) + + return mime_msg + + def send_email(self, msg: EmailMessage) -> EmailSendResult: + """发送邮件 + + Args: + msg: EmailMessage邮件消息对象 + + Returns: + EmailSendResult发送结果对象 + """ + if not self.validate_email(msg.to): + logger.warning(f"无效的邮箱地址: {msg.to}") + return EmailSendResult( + success=False, + message_id=None, + error=f"无效的邮箱地址: {msg.to}", + retry_count=0, + ) + + if self.simulate_mode: + message_id = f"sim_{uuid.uuid4().hex[:8]}" + logger.info(f"[模拟模式] 邮件发送成功: {msg.to}, ID: {message_id}") + return EmailSendResult( + success=True, + message_id=message_id, + error=None, + retry_count=0, + ) + + retry_count = 0 + last_error = None + + while retry_count <= self.max_retries: + try: + logger.info(f"尝试发送邮件到 {msg.to} (尝试 {retry_count + 1}/{self.max_retries + 1})") + server = smtplib.SMTP(self.smtp_host, self.smtp_port) + server.starttls() + server.login(self.smtp_user, self.smtp_password) + + mime_msg = self._create_mime_message(msg) + server.send_message(mime_msg) + server.quit() + + message_id = f"smtp_{uuid.uuid4().hex[:8]}" + logger.info(f"邮件发送成功: {msg.to}, ID: {message_id}") + return EmailSendResult( + success=True, + message_id=message_id, + error=None, + retry_count=retry_count, + ) + except smtplib.SMTPException as e: + last_error = f"SMTP错误: {str(e)}" + logger.error(f"SMTP错误 (尝试 {retry_count + 1}/{self.max_retries + 1}): {e}") + retry_count += 1 + if retry_count <= self.max_retries: + time.sleep(1 * retry_count) + except Exception as e: + last_error = str(e) + logger.error(f"邮件发送失败 (尝试 {retry_count + 1}/{self.max_retries + 1}): {e}") + retry_count += 1 + if retry_count <= self.max_retries: + time.sleep(1 * retry_count) + + logger.error(f"邮件发送最终失败,已重试 {retry_count} 次: {msg.to}, 错误: {last_error}") + return EmailSendResult( + success=False, + message_id=None, + error=last_error, + retry_count=retry_count, + ) + + def add_to_queue(self, msg: EmailMessage) -> None: + """添加邮件到队列 + + Args: + msg: EmailMessage邮件消息对象 + """ + self._queue.append(msg) + logger.debug(f"邮件已添加到队列: {msg.to}, 队列长度: {len(self._queue)}") + + def get_queue(self) -> list[EmailMessage]: + """获取队列中的邮件 + + Returns: + 邮件消息列表 + """ + return self._queue.copy() + + def clear_queue(self) -> None: + """清空队列""" + count = len(self._queue) + self._queue.clear() + logger.info(f"队列已清空,移除了 {count} 封邮件") + + def send_queue(self) -> list[EmailSendResult]: + """发送队列中的所有邮件 + + Returns: + 发送结果列表 + """ + results = [] + messages = self._queue.copy() + self._queue.clear() + + logger.info(f"开始发送队列中的 {len(messages)} 封邮件") + + for msg in messages: + result = self.send_email(msg) + results.append(result) + + success_count = sum(1 for r in results if r.success) + logger.info(f"队列发送完成: 成功 {success_count}/{len(results)}") + + return results diff --git a/backend/app/services/geo_diagnosis.py b/backend/app/services/geo_diagnosis.py new file mode 100644 index 0000000..eaafb84 --- /dev/null +++ b/backend/app/services/geo_diagnosis.py @@ -0,0 +1,1078 @@ +""" +GEO诊断服务 - 6大维度检测系统 + +诊断维度(总分100): +- 内容可提取性 (Content Extractability): 20分 - AI能否轻松提取和理解内容 +- 实体清晰度 (Entity Clarity): 15分 - AI能否理解品牌是什么 +- E-E-A-T信号 (E-E-A-T Signals): 20分 - 经验、专业性、权威性、可信度 +- Schema标记 (Schema Markup): 15分 - 结构化数据完整性 +- 主题权威 (Topic Authority): 15分 - 品牌在特定领域的权威性 +- 引用就绪度 (Citation Readiness): 15分 - 品牌在AI回答中被引用的可能性 +""" +from __future__ import annotations + +import logging +from dataclasses import dataclass, field + +logger = logging.getLogger(__name__) + + +# ============================================================ +# 诊断数据结构 +# ============================================================ + +@dataclass +class DiagnosisItem: + """单个诊断项""" + name: str # 诊断项名称 + status: str # pass/warning/fail + description: str # 诊断说明 + suggestion: str # 优化建议 + score: float = 0.0 # 该项得分 + max_score: float = 0.0 # 该项满分 + + +@dataclass +class GEODimensionScore: + """单个维度的诊断评分详情""" + name: str # 维度名称 + score: float # 该维度得分 (0-max_score) + max_score: float # 该维度满分 + items: list[DiagnosisItem] = field(default_factory=list) + status: str = "pass" # pass/warning/fail + percentage: float = 0.0 # 得分率 (0-100) + detail: dict = field(default_factory=dict) + + +@dataclass +class GEORecommendation: + """优化建议""" + priority: str # P0/P1/P2 + dimension: str # 所属维度 + title: str # 建议标题 + description: str # 建议描述 + impact: str # 预期影响: high/medium/low + effort: str # 实施难度: easy/medium/hard + + +@dataclass +class GEODiagnosisResult: + """GEO诊断结果""" + overall_score: float = 0.0 # 综合评分 0-100 + dimensions: list[GEODimensionScore] = field(default_factory=list) + recommendations: list[GEORecommendation] = field(default_factory=list) + health_level: str = "danger" # excellent/good/pass/danger + + def __post_init__(self): + """计算健康等级""" + if self.overall_score >= 80: + self.health_level = "excellent" + elif self.overall_score >= 60: + self.health_level = "good" + elif self.overall_score >= 40: + self.health_level = "pass" + else: + self.health_level = "danger" + + def to_dict(self) -> dict: + """转换为字典格式""" + return { + "overall_score": round(self.overall_score, 2), + "health_level": self.health_level, + "health_level_label": get_health_level_label(self.health_level), + "dimensions": [ + { + "name": dim.name, + "score": round(dim.score, 2), + "max_score": dim.max_score, + "percentage": round(dim.percentage, 2), + "status": dim.status, + "items": [ + { + "name": item.name, + "status": item.status, + "description": item.description, + "suggestion": item.suggestion, + "score": round(item.score, 2), + "max_score": item.max_score, + } + for item in dim.items + ], + "detail": dim.detail, + } + for dim in self.dimensions + ], + "recommendations": [ + { + "priority": rec.priority, + "dimension": rec.dimension, + "title": rec.title, + "description": rec.description, + "impact": rec.impact, + "effort": rec.effort, + } + for rec in self.recommendations + ], + } + + +# ============================================================ +# 维度1: 内容可提取性诊断 (满分20) +# ============================================================ + +def diagnose_content_extractability( + has_direct_answer: bool = False, + has_qa_headings: bool = False, + has_structured_data: bool = False, + has_internal_links: bool = False, + has_freshness_info: bool = False, + update_days_ago: int | None = None, +) -> GEODimensionScore: + """ + 诊断内容可提取性 (满分20) + + AI需要能够轻松提取和理解内容。 + + Args: + has_direct_answer: 是否有直接回答块(页面首段简洁明确的答案) + has_qa_headings: 是否有问答式标题(H2/H3采用问题形式) + has_structured_data: 是否使用列表和表格等结构化数据 + has_internal_links: 是否有内链到子意图页 + has_freshness_info: 是否有内容新鲜度信息(更新日期和作者) + update_days_ago: 内容最后更新距今天数 + + Returns: + GEODimensionScore: 内容可提取性维度评分 + """ + max_score = 20.0 + items = [] + + # 1. 直接回答块 (P0, 6分) + direct_answer_score = 6.0 if has_direct_answer else 0.0 + items.append(DiagnosisItem( + name="直接回答块", + status="pass" if has_direct_answer else "fail", + description="页面首段是否包含简洁明确的答案,便于AI直接提取", + suggestion="在页面首段添加2-3句话的简洁答案,直接回答用户核心问题", + score=direct_answer_score, + max_score=6.0, + )) + + # 2. 问答式标题 (P0, 5分) + qa_headings_score = 5.0 if has_qa_headings else 0.0 + items.append(DiagnosisItem( + name="问答式标题", + status="pass" if has_qa_headings else "fail", + description="H2/H3标题是否采用问题形式,帮助AI理解内容结构", + suggestion="将关键H2/H3标题改为问题形式,如'什么是X'、'如何使用Y'", + score=qa_headings_score, + max_score=5.0, + )) + + # 3. 列表和表格 (P0, 4分) + structured_score = 4.0 if has_structured_data else 0.0 + items.append(DiagnosisItem( + name="列表和表格", + status="pass" if has_structured_data else "fail", + description="是否使用列表、表格等结构化数据展示信息", + suggestion="使用HTML列表(ul/ol)和表格(table)组织信息,便于AI解析", + score=structured_score, + max_score=4.0, + )) + + # 4. 内链到子意图页 (P1, 3分) + internal_links_score = 3.0 if has_internal_links else 0.0 + items.append(DiagnosisItem( + name="内链到子意图页", + status="pass" if has_internal_links else "warning", + description="是否链接到相关深度内容页面", + suggestion="添加内链到相关子话题页面,形成内容网络", + score=internal_links_score, + max_score=3.0, + )) + + # 5. 内容新鲜度 (P1, 2分) + freshness_score = 0.0 + freshness_status = "fail" + if has_freshness_info: + if update_days_ago is not None: + if update_days_ago <= 30: + freshness_score = 2.0 + freshness_status = "pass" + elif update_days_ago <= 90: + freshness_score = 1.5 + freshness_status = "warning" + else: + freshness_score = 0.5 + freshness_status = "warning" + else: + freshness_score = 1.0 + freshness_status = "warning" + + items.append(DiagnosisItem( + name="内容新鲜度", + status=freshness_status, + description="是否有更新日期和作者信息,体现内容时效性", + suggestion="在页面显眼位置展示最后更新日期和作者信息", + score=freshness_score, + max_score=2.0, + )) + + total_score = sum(item.score for item in items) + percentage = (total_score / max_score) * 100 + + # 维度状态:如果有fail项则为warning,全pass则为pass + has_fail = any(item.status == "fail" for item in items) + status = "warning" if has_fail else "pass" + + return GEODimensionScore( + name="内容可提取性", + score=total_score, + max_score=max_score, + items=items, + status=status, + percentage=round(percentage, 2), + detail={ + "has_direct_answer": has_direct_answer, + "has_qa_headings": has_qa_headings, + "has_structured_data": has_structured_data, + "has_internal_links": has_internal_links, + "has_freshness_info": has_freshness_info, + "update_days_ago": update_days_ago, + }, + ) + + +# ============================================================ +# 维度2: 实体清晰度诊断 (满分15) +# ============================================================ + +def diagnose_entity_clarity( + has_brand_definition: bool = False, + has_target_audience: bool = False, + has_unique_value: bool = False, + has_industry_classification: bool = False, +) -> GEODimensionScore: + """ + 诊断实体清晰度 (满分15) + + AI需要能够理解品牌是什么。 + + Args: + has_brand_definition: 是否清晰说明品牌做什么 + has_target_audience: 是否明确服务谁 + has_unique_value: 是否有差异化价值主张 + has_industry_classification: 是否有行业分类信息 + + Returns: + GEODimensionScore: 实体清晰度维度评分 + """ + max_score = 15.0 + items = [] + + # 1. 品牌定义 (5分) + brand_def_score = 5.0 if has_brand_definition else 0.0 + items.append(DiagnosisItem( + name="品牌定义", + status="pass" if has_brand_definition else "fail", + description="是否清晰说明品牌做什么,AI理解准确率目标≥95%", + suggestion="在首页和About页面添加清晰的品牌定义,包含核心业务和价值主张", + score=brand_def_score, + max_score=5.0, + )) + + # 2. 目标受众 (4分) + audience_score = 4.0 if has_target_audience else 0.0 + items.append(DiagnosisItem( + name="目标受众", + status="pass" if has_target_audience else "fail", + description="是否明确服务谁,实体识别准确率目标≥90%", + suggestion="明确描述目标用户群体,如'为中小企业提供XX服务'", + score=audience_score, + max_score=4.0, + )) + + # 3. 差异化价值 (3分) + value_score = 3.0 if has_unique_value else 0.0 + items.append(DiagnosisItem( + name="差异化价值", + status="pass" if has_unique_value else "warning", + description="为什么选择这个品牌,独特性评分目标≥80", + suggestion="突出品牌独特优势,如技术领先、服务优质、价格合理等", + score=value_score, + max_score=3.0, + )) + + # 4. 行业分类 (3分) + industry_score = 3.0 if has_industry_classification else 0.0 + items.append(DiagnosisItem( + name="行业分类", + status="pass" if has_industry_classification else "warning", + description="品牌属于什么行业,分类准确率目标≥95%", + suggestion="在页面中明确标注行业分类,如'SaaS'、'电子商务'等", + score=industry_score, + max_score=3.0, + )) + + total_score = sum(item.score for item in items) + percentage = (total_score / max_score) * 100 + + has_fail = any(item.status == "fail" for item in items) + status = "warning" if has_fail else "pass" + + return GEODimensionScore( + name="实体清晰度", + score=total_score, + max_score=max_score, + items=items, + status=status, + percentage=round(percentage, 2), + detail={ + "has_brand_definition": has_brand_definition, + "has_target_audience": has_target_audience, + "has_unique_value": has_unique_value, + "has_industry_classification": has_industry_classification, + }, + ) + + +# ============================================================ +# 维度3: E-E-A-T信号诊断 (满分20) +# ============================================================ + +def diagnose_eeat_signals( + has_author_bio: bool = False, + author_credentials_complete: float = 0.0, + has_certifications: bool = False, + certification_count: int = 0, + has_data_sources: bool = False, + authoritative_source_ratio: float = 0.0, + has_expert_endorsements: bool = False, + endorsement_count: int = 0, +) -> GEODimensionScore: + """ + 诊断E-E-A-T信号 (满分20) + + AI需要验证品牌的可信度(经验、专业性、权威性、可信度)。 + + Args: + has_author_bio: 是否有作者资质信息 + author_credentials_complete: 作者简介完整度 (0-1) + has_certifications: 是否有专业认证 + certification_count: 认证/奖项数量 + has_data_sources: 是否引用数据来源 + authoritative_source_ratio: 权威源引用比例 (0-1) + has_expert_endorsements: 是否有专家背书 + endorsement_count: 专家背书数量 + + Returns: + GEODimensionScore: E-E-A-T信号维度评分 + """ + max_score = 20.0 + items = [] + + # 1. 作者资质 (6分) + author_score = 0.0 + if has_author_bio: + author_score = author_credentials_complete * 6.0 + author_status = "pass" if author_score >= 5.4 else ("warning" if author_score >= 3.0 else "fail") + items.append(DiagnosisItem( + name="作者资质", + status=author_status, + description="内容作者是否有专业背景,作者简介完整度目标≥90%", + suggestion="添加作者详细简介,包含教育背景、工作经验、专业领域", + score=author_score, + max_score=6.0, + )) + + # 2. 专业认证 (5分) + cert_score = 0.0 + if has_certifications: + if certification_count >= 5: + cert_score = 5.0 + elif certification_count >= 3: + cert_score = 4.0 + elif certification_count >= 1: + cert_score = 2.5 + cert_status = "pass" if cert_score >= 4.0 else ("warning" if cert_score >= 2.0 else "fail") + items.append(DiagnosisItem( + name="专业认证", + status=cert_status, + description="是否有行业认证/奖项,认证展示率目标≥80%", + suggestion="展示行业认证、奖项、资质,如ISO认证、行业奖项等", + score=cert_score, + max_score=5.0, + )) + + # 3. 数据来源 (5分) + source_score = authoritative_source_ratio * 5.0 if has_data_sources else 0.0 + source_status = "pass" if source_score >= 4.0 else ("warning" if source_score >= 2.0 else "fail") + items.append(DiagnosisItem( + name="数据来源", + status=source_status, + description="是否引用可靠数据,引用权威源比例目标≥70%", + suggestion="引用权威机构数据,如政府报告、学术研究、行业报告", + score=source_score, + max_score=5.0, + )) + + # 4. 专家背书 (4分) + endorsement_score = 0.0 + if has_expert_endorsements: + if endorsement_count >= 5: + endorsement_score = 4.0 + elif endorsement_count >= 3: + endorsement_score = 3.0 + elif endorsement_count >= 1: + endorsement_score = 1.5 + endorsement_status = "pass" if endorsement_score >= 3.0 else ("warning" if endorsement_score >= 1.5 else "fail") + items.append(DiagnosisItem( + name="专家背书", + status=endorsement_status, + description="是否有行业专家认可,背书数量目标≥3", + suggestion="获取行业专家推荐、用户评价、案例研究", + score=endorsement_score, + max_score=4.0, + )) + + total_score = sum(item.score for item in items) + percentage = (total_score / max_score) * 100 + + has_fail = any(item.status == "fail" for item in items) + status = "warning" if has_fail else "pass" + + return GEODimensionScore( + name="E-E-A-T信号", + score=total_score, + max_score=max_score, + items=items, + status=status, + percentage=round(percentage, 2), + detail={ + "has_author_bio": has_author_bio, + "author_credentials_complete": round(author_credentials_complete, 2), + "has_certifications": has_certifications, + "certification_count": certification_count, + "has_data_sources": has_data_sources, + "authoritative_source_ratio": round(authoritative_source_ratio, 2), + "has_expert_endorsements": has_expert_endorsements, + "endorsement_count": endorsement_count, + }, + ) + + +# ============================================================ +# 维度4: Schema标记诊断 (满分15) +# ============================================================ + +def diagnose_schema_markup( + has_organization: bool = False, + has_product: bool = False, + has_article: bool = False, + has_faq: bool = False, + has_howto: bool = False, + has_breadcrumb: bool = False, +) -> GEODimensionScore: + """ + 诊断Schema标记完整性 (满分15) + + 结构化数据帮助AI理解内容。 + + Args: + has_organization: 是否有Organization标记(企业主页) + has_product: 是否有Product标记(产品页) + has_article: 是否有Article/BlogPosting标记(博客文章) + has_faq: 是否有FAQPage标记(常见问题) + has_howto: 是否有HowTo标记(操作指南) + has_breadcrumb: 是否有BreadcrumbList标记(导航结构) + + Returns: + GEODimensionScore: Schema标记维度评分 + """ + max_score = 15.0 + items = [] + + # 1. Organization (P0必须, 4分) + org_score = 4.0 if has_organization else 0.0 + items.append(DiagnosisItem( + name="Organization", + status="pass" if has_organization else "fail", + description="企业主页的Organization标记,包含名称、logo、联系方式", + suggestion="添加Organization Schema,包含@type: Organization、name、url、logo", + score=org_score, + max_score=4.0, + )) + + # 2. Product (P0必须, 3分) + product_score = 3.0 if has_product else 0.0 + items.append(DiagnosisItem( + name="Product", + status="pass" if has_product else "fail", + description="产品页的Product标记,包含名称、描述、价格、评价", + suggestion="为产品页添加Product Schema,包含name、description、offers、aggregateRating", + score=product_score, + max_score=3.0, + )) + + # 3. Article/BlogPosting (P0必须, 3分) + article_score = 3.0 if has_article else 0.0 + items.append(DiagnosisItem( + name="Article/BlogPosting", + status="pass" if has_article else "fail", + description="博客文章的Article标记,包含作者、发布日期、摘要", + suggestion="为文章添加Article或BlogPosting Schema,包含author、datePublished、headline", + score=article_score, + max_score=3.0, + )) + + # 4. FAQPage (P1推荐, 2分) + faq_score = 2.0 if has_faq else 0.0 + items.append(DiagnosisItem( + name="FAQPage", + status="pass" if has_faq else "warning", + description="常见问题的FAQPage标记", + suggestion="为FAQ页面添加FAQPage Schema,包含问题和答案对", + score=faq_score, + max_score=2.0, + )) + + # 5. HowTo (P1推荐, 2分) + howto_score = 2.0 if has_howto else 0.0 + items.append(DiagnosisItem( + name="HowTo", + status="pass" if has_howto else "warning", + description="操作指南的HowTo标记", + suggestion="为教程类内容添加HowTo Schema,包含步骤列表", + score=howto_score, + max_score=2.0, + )) + + # 6. BreadcrumbList (P1推荐, 1分) + breadcrumb_score = 1.0 if has_breadcrumb else 0.0 + items.append(DiagnosisItem( + name="BreadcrumbList", + status="pass" if has_breadcrumb else "warning", + description="导航结构的BreadcrumbList标记", + suggestion="添加BreadcrumbList Schema,帮助AI理解页面层级关系", + score=breadcrumb_score, + max_score=1.0, + )) + + total_score = sum(item.score for item in items) + percentage = (total_score / max_score) * 100 + + has_fail = any(item.status == "fail" for item in items) + status = "warning" if has_fail else "pass" + + return GEODimensionScore( + name="Schema标记", + score=total_score, + max_score=max_score, + items=items, + status=status, + percentage=round(percentage, 2), + detail={ + "has_organization": has_organization, + "has_product": has_product, + "has_article": has_article, + "has_faq": has_faq, + "has_howto": has_howto, + "has_breadcrumb": has_breadcrumb, + "schema_count": sum([ + has_organization, has_product, has_article, + has_faq, has_howto, has_breadcrumb, + ]), + }, + ) + + +# ============================================================ +# 维度5: 主题权威诊断 (满分15) +# ============================================================ + +def diagnose_topic_authority( + content_depth_score: float = 0.0, + topic_coverage_ratio: float = 0.0, + entity_consistency_score: float = 0.0, + cluster_completeness: float = 0.0, + total_content_count: int = 0, + topic_cluster_count: int = 0, +) -> GEODimensionScore: + """ + 诊断主题权威 (满分15) + + AI需要验证品牌在特定领域的权威性。 + + Args: + content_depth_score: 内容深度评分 (0-1),目标≥4.6/5即0.92 + topic_coverage_ratio: 话题覆盖度 (0-1),目标≥80% + entity_consistency_score: 实体信号一致性 (0-1),目标≥85% + cluster_completeness: 内链网络集群完整度 (0-1),目标≥70% + total_content_count: 总内容数量 + topic_cluster_count: 主题集群数量 + + Returns: + GEODimensionScore: 主题权威维度评分 + """ + max_score = 15.0 + items = [] + + # 1. 内容深度 (5分) + depth_score = content_depth_score * 5.0 + depth_status = "pass" if content_depth_score >= 0.8 else ("warning" if content_depth_score >= 0.5 else "fail") + items.append(DiagnosisItem( + name="内容深度", + status=depth_status, + description="是否全面覆盖主题,内容质量QScore目标≥4.6/5", + suggestion="增加内容深度,包含详细解释、案例分析、数据支撑", + score=depth_score, + max_score=5.0, + )) + + # 2. 话题覆盖度 (4分) + coverage_score = topic_coverage_ratio * 4.0 + coverage_status = "pass" if topic_coverage_ratio >= 0.8 else ("warning" if topic_coverage_ratio >= 0.5 else "fail") + items.append(DiagnosisItem( + name="话题覆盖度", + status=coverage_status, + description="是否覆盖相关子话题,话题覆盖率目标≥80%", + suggestion="创建覆盖核心话题及其子话题的内容矩阵", + score=coverage_score, + max_score=4.0, + )) + + # 3. 实体信号一致性 (3分) + consistency_score = entity_consistency_score * 3.0 + consistency_status = "pass" if entity_consistency_score >= 0.85 else ("warning" if entity_consistency_score >= 0.6 else "fail") + items.append(DiagnosisItem( + name="实体信号一致性", + status=consistency_status, + description="各页面实体信号是否一致,一致性评分目标≥85%", + suggestion="确保各页面使用一致的品牌名称、描述、行业分类", + score=consistency_score, + max_score=3.0, + )) + + # 4. 内链网络 (3分) + network_score = cluster_completeness * 3.0 + network_status = "pass" if cluster_completeness >= 0.7 else ("warning" if cluster_completeness >= 0.4 else "fail") + items.append(DiagnosisItem( + name="内链网络", + status=network_status, + description="是否形成主题内容集群,集群完整度目标≥70%", + suggestion="建立主题集群,通过内链将相关内容连接成网络", + score=network_score, + max_score=3.0, + )) + + total_score = sum(item.score for item in items) + percentage = (total_score / max_score) * 100 + + has_fail = any(item.status == "fail" for item in items) + status = "warning" if has_fail else "pass" + + return GEODimensionScore( + name="主题权威", + score=total_score, + max_score=max_score, + items=items, + status=status, + percentage=round(percentage, 2), + detail={ + "content_depth_score": round(content_depth_score, 2), + "topic_coverage_ratio": round(topic_coverage_ratio, 2), + "entity_consistency_score": round(entity_consistency_score, 2), + "cluster_completeness": round(cluster_completeness, 2), + "total_content_count": total_content_count, + "topic_cluster_count": topic_cluster_count, + }, + ) + + +# ============================================================ +# 维度6: 引用就绪度诊断 (满分15) +# ============================================================ + +def diagnose_citation_readiness( + answer_ownership_rate: float = 0.0, + citation_accuracy: float = 0.0, + ai_sov: float = 0.0, + competitor_gap: float = 0.0, + total_ai_responses: int = 0, + brand_mention_count: int = 0, + accurate_citation_count: int = 0, +) -> GEODimensionScore: + """ + 诊断引用就绪度 (满分15) + + 评估品牌在AI回答中被引用的可能性。 + + Args: + answer_ownership_rate: AOR - Answer Ownership Rate (0-1),目标≥50% + citation_accuracy: 引用准确率 (0-1),目标≥90% + ai_sov: AI Share of Voice (0-1),目标≥30% + competitor_gap: 与竞品差距 (pp),目标≤10pp + total_ai_responses: AI回答总数 + brand_mention_count: 品牌被提及次数 + accurate_citation_count: 准确引用次数 + + Returns: + GEODimensionScore: 引用就绪度维度评分 + """ + max_score = 15.0 + items = [] + + # 1. 引用频率 AOR (5分) + aor_score = 0.0 + if answer_ownership_rate >= 0.5: + aor_score = 5.0 + elif answer_ownership_rate >= 0.3: + aor_score = 3.5 + elif answer_ownership_rate >= 0.1: + aor_score = 2.0 + else: + aor_score = answer_ownership_rate * 10.0 + aor_status = "pass" if answer_ownership_rate >= 0.5 else ("warning" if answer_ownership_rate >= 0.2 else "fail") + items.append(DiagnosisItem( + name="引用频率 (AOR)", + status=aor_status, + description="品牌在AI回答中被提及的频率,AOR目标≥50%", + suggestion="优化内容结构,提高被AI引用的概率", + score=aor_score, + max_score=5.0, + )) + + # 2. 引用质量 (4分) + accuracy_score = citation_accuracy * 4.0 + accuracy_status = "pass" if citation_accuracy >= 0.9 else ("warning" if citation_accuracy >= 0.7 else "fail") + # 确保满分时得到满分 + if citation_accuracy >= 1.0: + accuracy_score = 4.0 + items.append(DiagnosisItem( + name="引用质量", + status=accuracy_status, + description="引用内容是否准确完整,引用准确率目标≥90%", + suggestion="确保内容准确无误,避免过时或错误信息", + score=accuracy_score, + max_score=4.0, + )) + + # 3. AI声量占比 (3分) + sov_score = 0.0 + if ai_sov >= 0.3: + sov_score = 3.0 + elif ai_sov >= 0.15: + sov_score = 2.0 + elif ai_sov >= 0.05: + sov_score = 1.0 + else: + sov_score = ai_sov * 10.0 + sov_status = "pass" if ai_sov >= 0.3 else ("warning" if ai_sov >= 0.1 else "fail") + items.append(DiagnosisItem( + name="AI声量占比", + status=sov_status, + description="品牌在AI回答中的占比,AI SOV目标≥30%", + suggestion="增加品牌曝光,提高在AI回答中的出现频率", + score=sov_score, + max_score=3.0, + )) + + # 4. 竞品对比 (3分) + gap_score = 0.0 + if competitor_gap <= 0.1: + gap_score = 3.0 + elif competitor_gap <= 0.2: + gap_score = 2.0 + elif competitor_gap <= 0.3: + gap_score = 1.0 + else: + gap_score = max(0.0, 3.0 - competitor_gap * 5) + # 确保差距为0时得满分,差距过大时得0分 + if competitor_gap <= 0.0: + gap_score = 3.0 + if competitor_gap >= 0.6: + gap_score = 0.0 + gap_status = "pass" if competitor_gap <= 0.1 else ("warning" if competitor_gap <= 0.25 else "fail") + items.append(DiagnosisItem( + name="竞品对比", + status=gap_status, + description="与竞品在AI回答中的表现差距,差距目标≤10pp", + suggestion="分析竞品优势,针对性优化内容策略", + score=gap_score, + max_score=3.0, + )) + + total_score = sum(item.score for item in items) + percentage = (total_score / max_score) * 100 + + has_fail = any(item.status == "fail" for item in items) + status = "warning" if has_fail else "pass" + + return GEODimensionScore( + name="引用就绪度", + score=total_score, + max_score=max_score, + items=items, + status=status, + percentage=round(percentage, 2), + detail={ + "answer_ownership_rate": round(answer_ownership_rate, 2), + "citation_accuracy": round(citation_accuracy, 2), + "ai_sov": round(ai_sov, 2), + "competitor_gap": round(competitor_gap, 2), + "total_ai_responses": total_ai_responses, + "brand_mention_count": brand_mention_count, + "accurate_citation_count": accurate_citation_count, + }, + ) + + +# ============================================================ +# 推荐生成 +# ============================================================ + +def generate_recommendations(dimensions: list[GEODimensionScore]) -> list[GEORecommendation]: + """ + 根据诊断结果生成优化建议 + + Args: + dimensions: 各维度诊断结果 + + Returns: + list[GEORecommendation]: 优化建议列表 + """ + recommendations = [] + + for dim in dimensions: + for item in dim.items: + if item.status == "fail": + priority = "P0" + impact = "high" + elif item.status == "warning": + priority = "P1" + impact = "medium" + else: + continue + + # 根据诊断项确定实施难度 + effort = "medium" + if "Schema" in item.name or "标记" in item.name: + effort = "easy" + elif "内容深度" in item.name or "话题覆盖" in item.name: + effort = "hard" + + recommendations.append(GEORecommendation( + priority=priority, + dimension=dim.name, + title=f"优化: {item.name}", + description=item.suggestion, + impact=impact, + effort=effort, + )) + + # 按优先级排序 + priority_order = {"P0": 0, "P1": 1, "P2": 2} + recommendations.sort(key=lambda r: priority_order.get(r.priority, 3)) + + return recommendations + + +# ============================================================ +# 工具函数 +# ============================================================ + +def get_health_level(score: float) -> str: + """ + 根据评分获取健康等级 + + 80+ -> excellent (优秀/绿) + 60-79 -> good (良好/黄) + 40-59 -> pass (及格/橙) + <40 -> danger (危险/红) + """ + if score >= 80: + return "excellent" + if score >= 60: + return "good" + if score >= 40: + return "pass" + return "danger" + + +def get_health_level_label(level: str) -> str: + """获取健康等级中文标签""" + labels = { + "excellent": "优秀", + "good": "良好", + "pass": "及格", + "danger": "危险", + } + return labels.get(level, "未知") + + +# ============================================================ +# GEODiagnosisService 服务类 +# ============================================================ + +@dataclass +class GEODiagnosisInput: + """GEO诊断输入参数""" + # 内容可提取性 + has_direct_answer: bool = False + has_qa_headings: bool = False + has_structured_data: bool = False + has_internal_links: bool = False + has_freshness_info: bool = False + update_days_ago: int | None = None + + # 实体清晰度 + has_brand_definition: bool = False + has_target_audience: bool = False + has_unique_value: bool = False + has_industry_classification: bool = False + + # E-E-A-T信号 + has_author_bio: bool = False + author_credentials_complete: float = 0.0 + has_certifications: bool = False + certification_count: int = 0 + has_data_sources: bool = False + authoritative_source_ratio: float = 0.0 + has_expert_endorsements: bool = False + endorsement_count: int = 0 + + # Schema标记 + has_organization: bool = False + has_product: bool = False + has_article: bool = False + has_faq: bool = False + has_howto: bool = False + has_breadcrumb: bool = False + + # 主题权威 + content_depth_score: float = 0.0 + topic_coverage_ratio: float = 0.0 + entity_consistency_score: float = 0.0 + cluster_completeness: float = 0.0 + total_content_count: int = 0 + topic_cluster_count: int = 0 + + # 引用就绪度 + answer_ownership_rate: float = 0.0 + citation_accuracy: float = 0.0 + ai_sov: float = 0.0 + competitor_gap: float = 0.0 + total_ai_responses: int = 0 + brand_mention_count: int = 0 + accurate_citation_count: int = 0 + + +class GEODiagnosisService: + """GEO诊断服务""" + + def diagnose(self, input_data: GEODiagnosisInput) -> GEODiagnosisResult: + """ + 执行GEO诊断 + + Args: + input_data: 诊断输入参数 + + Returns: + GEODiagnosisResult: 诊断结果 + """ + # 1. 内容可提取性诊断 (20分) + content_extractability = diagnose_content_extractability( + has_direct_answer=input_data.has_direct_answer, + has_qa_headings=input_data.has_qa_headings, + has_structured_data=input_data.has_structured_data, + has_internal_links=input_data.has_internal_links, + has_freshness_info=input_data.has_freshness_info, + update_days_ago=input_data.update_days_ago, + ) + + # 2. 实体清晰度诊断 (15分) + entity_clarity = diagnose_entity_clarity( + has_brand_definition=input_data.has_brand_definition, + has_target_audience=input_data.has_target_audience, + has_unique_value=input_data.has_unique_value, + has_industry_classification=input_data.has_industry_classification, + ) + + # 3. E-E-A-T信号诊断 (20分) + eeat_signals = diagnose_eeat_signals( + has_author_bio=input_data.has_author_bio, + author_credentials_complete=input_data.author_credentials_complete, + has_certifications=input_data.has_certifications, + certification_count=input_data.certification_count, + has_data_sources=input_data.has_data_sources, + authoritative_source_ratio=input_data.authoritative_source_ratio, + has_expert_endorsements=input_data.has_expert_endorsements, + endorsement_count=input_data.endorsement_count, + ) + + # 4. Schema标记诊断 (15分) + schema_markup = diagnose_schema_markup( + has_organization=input_data.has_organization, + has_product=input_data.has_product, + has_article=input_data.has_article, + has_faq=input_data.has_faq, + has_howto=input_data.has_howto, + has_breadcrumb=input_data.has_breadcrumb, + ) + + # 5. 主题权威诊断 (15分) + topic_authority = diagnose_topic_authority( + content_depth_score=input_data.content_depth_score, + topic_coverage_ratio=input_data.topic_coverage_ratio, + entity_consistency_score=input_data.entity_consistency_score, + cluster_completeness=input_data.cluster_completeness, + total_content_count=input_data.total_content_count, + topic_cluster_count=input_data.topic_cluster_count, + ) + + # 6. 引用就绪度诊断 (15分) + citation_readiness = diagnose_citation_readiness( + answer_ownership_rate=input_data.answer_ownership_rate, + citation_accuracy=input_data.citation_accuracy, + ai_sov=input_data.ai_sov, + competitor_gap=input_data.competitor_gap, + total_ai_responses=input_data.total_ai_responses, + brand_mention_count=input_data.brand_mention_count, + accurate_citation_count=input_data.accurate_citation_count, + ) + + # 汇总维度 + dimensions = [ + content_extractability, + entity_clarity, + eeat_signals, + schema_markup, + topic_authority, + citation_readiness, + ] + + # 计算综合评分 + overall_score = sum(dim.score for dim in dimensions) + overall_score = round(min(100.0, max(0.0, overall_score)), 2) + + # 生成推荐 + recommendations = generate_recommendations(dimensions) + + return GEODiagnosisResult( + overall_score=overall_score, + dimensions=dimensions, + recommendations=recommendations, + ) + + def diagnose_from_dict(self, data: dict) -> GEODiagnosisResult: + """ + 从字典执行GEO诊断(便捷方法) + + Args: + data: 诊断参数字典 + + Returns: + GEODiagnosisResult: 诊断结果 + """ + input_data = GEODiagnosisInput(**data) + return self.diagnose(input_data) diff --git a/backend/app/services/health_checker.py b/backend/app/services/health_checker.py new file mode 100644 index 0000000..673e8b5 --- /dev/null +++ b/backend/app/services/health_checker.py @@ -0,0 +1,186 @@ +"""详细健康检查服务""" +import time +from dataclasses import dataclass +from typing import Optional + +import redis.asyncio as aioredis +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession + + +@dataclass +class HealthCheckResult: + """健康检查结果""" + name: str + healthy: bool + latency_ms: Optional[float] = None + message: Optional[str] = None + details: Optional[dict] = None + + +class HealthChecker: + """健康检查服务""" + + def __init__(self, db: AsyncSession, redis_url: str): + self.db = db + self.redis_url = redis_url + + async def check_database(self) -> HealthCheckResult: + """检查数据库连接""" + start = time.perf_counter() + try: + await self.db.execute(text("SELECT 1")) + latency = (time.perf_counter() - start) * 1000 + + return HealthCheckResult( + name="database", + healthy=True, + latency_ms=round(latency, 2), + message="Connection OK", + ) + except Exception as e: + latency = (time.perf_counter() - start) * 1000 + return HealthCheckResult( + name="database", + healthy=False, + latency_ms=round(latency, 2), + message=f"Connection failed: {str(e)}", + ) + + async def check_redis(self) -> HealthCheckResult: + """检查Redis连接""" + start = time.perf_counter() + try: + redis = aioredis.from_url( + self.redis_url, + socket_connect_timeout=2, + ) + await redis.ping() + await redis.aclose() + + latency = (time.perf_counter() - start) * 1000 + return HealthCheckResult( + name="redis", + healthy=True, + latency_ms=round(latency, 2), + message="Connection OK", + ) + except Exception as e: + latency = (time.perf_counter() - start) * 1000 + return HealthCheckResult( + name="redis", + healthy=False, + latency_ms=round(latency, 2), + message=f"Connection failed: {str(e)}", + ) + + async def check_llm_providers(self) -> HealthCheckResult: + """检查LLM服务提供商""" + from app.config import settings + from app.services.llm.factory import LLMFactory + + providers = {} + all_healthy = True + + # 检查默认provider + try: + provider_name = getattr(settings, 'DEFAULT_LLM_PROVIDER', 'openai') + provider = LLMFactory.create(provider_name) + providers[provider_name] = { + "healthy": True, + "available": True, + } + except Exception as e: + providers[getattr(settings, 'DEFAULT_LLM_PROVIDER', 'openai')] = { + "healthy": False, + "error": str(e), + } + all_healthy = False + + # 检查所有已注册的provider + for name in LLMFactory.list_providers(): + if name not in providers: + try: + provider = LLMFactory.create(name) + providers[name] = { + "healthy": True, + "available": True, + } + except Exception as e: + providers[name] = { + "healthy": False, + "error": str(e), + } + all_healthy = False + + return HealthCheckResult( + name="llm_providers", + healthy=all_healthy, + message="All providers healthy" if all_healthy else "Some providers unhealthy", + details={"providers": providers}, + ) + + async def check_storage(self) -> HealthCheckResult: + """检查存储(本地文件系统)""" + import os + + storage_path = "/data/documents" + + try: + if os.path.exists(storage_path): + # 检查读写权限 + test_file = os.path.join(storage_path, ".health_check") + with open(test_file, "w") as f: + f.write("ok") + os.remove(test_file) + + return HealthCheckResult( + name="storage", + healthy=True, + message=f"Storage path {storage_path} is writable", + details={"path": storage_path}, + ) + else: + return HealthCheckResult( + name="storage", + healthy=True, + message=f"Storage path {storage_path} does not exist (will be created)", + details={"path": storage_path, "created": True}, + ) + except Exception as e: + return HealthCheckResult( + name="storage", + healthy=False, + message=f"Storage check failed: {str(e)}", + ) + + async def check_all(self) -> dict: + """执行所有健康检查""" + import asyncio + + # 并行执行所有检查 + checks = [ + self.check_database(), + self.check_redis(), + self.check_llm_providers(), + self.check_storage(), + ] + + results = await asyncio.gather(*checks) + + # 汇总结果 + all_healthy = all(r.healthy for r in results) + + return { + "status": "healthy" if all_healthy else "degraded", + "timestamp": time.time(), + "checks": { + r.name: { + "healthy": r.healthy, + "latency_ms": r.latency_ms, + "message": r.message, + "details": r.details, + } + for r in results + }, + } diff --git a/backend/app/services/image_generator.py b/backend/app/services/image_generator.py new file mode 100644 index 0000000..9a3439c --- /dev/null +++ b/backend/app/services/image_generator.py @@ -0,0 +1,312 @@ +"""阿里云百炼图片生成服务""" + +import os +from dataclasses import dataclass +from typing import Optional + +import httpx + +# 平台尺寸适配 +PLATFORM_IMAGE_SPECS = { + "zhihu": { + "cover": {"width": 690, "height": 280, "ratio": "2.5:1"}, + "inline": {"width": 500, "height": 375, "ratio": "4:3"}, + }, + "wechat": { + "cover": {"width": 900, "height": 383, "ratio": "2.35:1"}, + "inline": {"width": 800, "height": 600, "ratio": "4:3"}, + }, + "xiaohongshu": { + "cover": {"width": 1080, "height": 1080, "ratio": "1:1"}, # 方版 + "inline": {"width": 1242, "height": 1660, "ratio": "3:4"}, # 竖版 + }, + "toutiao": { + "cover": {"width": 1024, "height": 678, "ratio": "1.5:1"}, + }, + "baijiahao": { + "cover": {"width": 600, "height": 400, "ratio": "3:2"}, + }, + "weibo": { + "cover": {"width": 980, "height": 560, "ratio": "1.75:1"}, + }, + "bilibili": { + "cover": {"width": 1920, "height": 1080, "ratio": "16:9"}, + }, + "jianshu": { + "cover": {"width": 800, "height": 600, "ratio": "4:3"}, + }, + "juejin": { + "cover": {"width": 1024, "height": 768, "ratio": "4:3"}, + }, + "douyin": { + "cover": {"width": 1080, "height": 1920, "ratio": "9:16"}, # 竖版短视频封面 + }, +} + +# 风格选项 +IMAGE_STYLES = { + "modern": {"name": "现代简约", "prompt": "modern minimalist style, clean design, professional"}, + "tech": {"name": "科技感", "prompt": "tech style, futuristic, digital, blue tones"}, + "elegant": {"name": "优雅商务", "prompt": "elegant business style, sophisticated, premium"}, + "creative": {"name": "创意活力", "prompt": "creative vibrant style, colorful, dynamic"}, + "minimal": {"name": "极简主义", "prompt": "ultra minimal, white space, typography focus"}, +} + +# 排版选项 +LAYOUT_OPTIONS = { + "centered": {"name": "居中排版", "prompt": "centered composition, text in middle"}, + "left_text": {"name": "左文右图", "prompt": "left side text, right side visual"}, + "top_text": {"name": "上文下图", "prompt": "text on top, visual below"}, + "text_overlay": {"name": "文字叠加", "prompt": "text overlay on background image"}, +} + + +@dataclass +class ImageResult: + """图片生成结果""" + url: str + width: int + height: int + prompt: str + platform: str + task_id: str + + +class ImageGenerationError(Exception): + """图片生成异常""" + pass + + +class ImageGenerator: + """阿里云百炼图片生成服务(万相-文生图V1)""" + + def __init__(self): + self.api_key = os.getenv("ALIYUN_DASHSCOPE_API_KEY") + self.base_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text2image/image-synthesis" + self.timeout = 120.0 # 异步任务等待超时时间(秒) + + async def generate_cover( + self, + title: str, + platform: str, + image_type: str = "cover", + style: str = "modern", + layout: str = "centered", + custom_prompt: str = None, + ) -> ImageResult: + """生成封面图 + + Args: + title: 文章标题 + platform: 目标平台 + image_type: 图片类型 (cover/inline) + style: 风格选项 + layout: 排版选项 + custom_prompt: 自定义提示词(可选) + + Returns: + ImageResult: 包含生成结果的 dataclass + """ + # 1. 获取平台尺寸 + specs = PLATFORM_IMAGE_SPECS.get(platform, PLATFORM_IMAGE_SPECS["zhihu"]) + size_spec = specs.get(image_type, specs["cover"]) + + # 2. 构建提示词 + if custom_prompt: + prompt = custom_prompt + else: + prompt = self._build_prompt(title, platform, style, layout) + + # 3. 调用百炼API(异步) + task_id = await self._create_task(prompt, size_spec) + + # 4. 轮询结果 + result = await self._wait_for_result(task_id) + + return ImageResult( + url=result["image_url"], + width=size_spec["width"], + height=size_spec["height"], + prompt=prompt, + platform=platform, + task_id=task_id, + ) + + def _build_prompt(self, title: str, platform: str, style: str, layout: str) -> str: + """构建AI提示词 + + Args: + title: 文章标题 + platform: 目标平台 + style: 风格选项 + layout: 排版选项 + + Returns: + str: 构造的英文提示词 + """ + style_prompt = IMAGE_STYLES.get(style, IMAGE_STYLES["modern"])["prompt"] + layout_prompt = LAYOUT_OPTIONS.get(layout, LAYOUT_OPTIONS["centered"])["prompt"] + + # 平台特定要求 + platform_notes = { + "xiaohongshu": "warm tones, lifestyle, lifestyle photography", + "wechat": "professional, clean, suitable for WeChat article cover", + "zhihu": "intellectual, professional, suitable for long-form content", + "toutiao": "eye-catching, clear hierarchy, news style", + "baijiahao": "clear, professional, suitable for news media", + "weibo": "social media friendly, engaging", + "bilibili": "anime style friendly, vibrant, suitable for video platform", + "jianshu": "literary, elegant, clean layout", + "juejin": "tech blog style, developer friendly", + "douyin": "vertical video style, eye-catching, short form content", + } + platform_note = platform_notes.get(platform, "") + + return f"{title}, {style_prompt}, {layout_prompt}, {platform_note}, high quality, 4K" + + async def _create_task(self, prompt: str, size_spec: dict) -> str: + """创建异步任务 + + Args: + prompt: 英文提示词 + size_spec: 尺寸规格 + + Returns: + str: 任务ID + + Raises: + ImageGenerationError: API调用失败时 + """ + if not self.api_key: + raise ImageGenerationError("ALIYUN_DASHSCOPE_API_KEY 未设置") + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + payload = { + "model": "wanx-v1", # 万相-文生图V1 + "input": { + "prompt": prompt, + }, + "parameters": { + "size": f"{size_spec['width']}x{size_spec['height']}", + "n": 1, + }, + } + + async with httpx.AsyncClient(timeout=30.0) as client: + try: + response = await client.post( + self.base_url, + headers=headers, + json=payload, + ) + response.raise_for_status() + data = response.json() + + if data.get("code"): + raise ImageGenerationError(f"API错误: {data.get('message', data.get('code'))}") + + # 返回任务ID + task_id = data.get("output", {}).get("task_id") + if not task_id: + raise ImageGenerationError("未获取到任务ID") + + return task_id + + except httpx.HTTPStatusError as e: + raise ImageGenerationError(f"HTTP错误: {e.response.status_code}") + except httpx.RequestError as e: + raise ImageGenerationError(f"请求错误: {e}") + + async def _wait_for_result(self, task_id: str) -> dict: + """轮询等待结果 + + Args: + task_id: 任务ID + + Returns: + dict: 包含 image_url 的结果 + + Raises: + ImageGenerationError: 任务失败或超时 + """ + status_url = f"{self.base_url}/task/{task_id}" + + headers = { + "Authorization": f"Bearer {self.api_key}", + } + + async with httpx.AsyncClient(timeout=self.timeout) as client: + import asyncio + start_time = asyncio.get_event_loop().time() + poll_interval = 2.0 # 轮询间隔(秒) + + while True: + try: + response = await client.get(status_url, headers=headers) + response.raise_for_status() + data = response.json() + + status = data.get("output", {}).get("task_status") + + if status == "succeeded": + # 任务成功 + images = data.get("output", {}).get("results", []) + if images: + return {"image_url": images[0].get("url")} + raise ImageGenerationError("未获取到生成图片URL") + + elif status == "failed": + error_msg = data.get("output", {}).get("message", "任务失败") + raise ImageGenerationError(f"图片生成失败: {error_msg}") + + elif status == "pending": + # 等待中,继续轮询 + pass + + else: + # 未知状态,继续等待 + pass + + # 检查超时 + elapsed = asyncio.get_event_loop().time() - start_time + if elapsed > self.timeout: + raise ImageGenerationError("图片生成超时") + + # 等待后继续轮询 + await asyncio.sleep(poll_interval) + + except httpx.HTTPStatusError as e: + raise ImageGenerationError(f"HTTP错误: {e.response.status_code}") + except httpx.RequestError as e: + raise ImageGenerationError(f"请求错误: {e}") + + @staticmethod + def get_platform_specs(platform: str) -> Optional[dict]: + """获取平台的图片规格 + + Args: + platform: 平台标识 + + Returns: + Optional[dict]: 平台图片规格,如果没有则返回None + """ + return PLATFORM_IMAGE_SPECS.get(platform) + + @staticmethod + def get_supported_platforms() -> list[str]: + """获取支持的平台列表""" + return list(PLATFORM_IMAGE_SPECS.keys()) + + @staticmethod + def get_styles() -> dict: + """获取所有风格选项""" + return IMAGE_STYLES + + @staticmethod + def get_layouts() -> dict: + """获取所有排版选项""" + return LAYOUT_OPTIONS \ No newline at end of file diff --git a/backend/app/services/knowledge/__init__.py b/backend/app/services/knowledge/__init__.py index 5ad519c..ab0adb9 100644 --- a/backend/app/services/knowledge/__init__.py +++ b/backend/app/services/knowledge/__init__.py @@ -2,5 +2,21 @@ from .rag_service import RAGService from .chunker import RecursiveChunker from .embedder import EmbeddingService, OpenAIEmbedder, MockEmbedder from .retriever import HybridRetriever +from .entity_extractor import EntityExtractor, ExtractionResult, ExtractedEntity, ExtractedRelation +from .graph_builder import GraphBuilder +from .graph_query import GraphQuery -__all__ = ["RAGService", "RecursiveChunker", "EmbeddingService", "OpenAIEmbedder", "MockEmbedder", "HybridRetriever"] +__all__ = [ + "RAGService", + "RecursiveChunker", + "EmbeddingService", + "OpenAIEmbedder", + "MockEmbedder", + "HybridRetriever", + "EntityExtractor", + "ExtractionResult", + "ExtractedEntity", + "ExtractedRelation", + "GraphBuilder", + "GraphQuery", +] diff --git a/backend/app/services/knowledge/chunker.py b/backend/app/services/knowledge/chunker.py index c49e3f7..a7c9533 100644 --- a/backend/app/services/knowledge/chunker.py +++ b/backend/app/services/knowledge/chunker.py @@ -1,178 +1,212 @@ """ -RecursiveChunker: 递归语义分块器 -按优先级分隔符(段落→句子→词)将文档切割为适合embedding的块。 +分块策略 - 支持多种分块方式 """ import re +from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import Optional +@dataclass +class ChunkStrategy: + """分块策略配置""" + name: str + description: str + chunk_size: int # 字符数 + chunk_overlap: int # 重叠字符数 + min_chunk_size: int -class RecursiveChunker: - """递归语义分块器""" - - def __init__( - self, - chunk_size: int = 512, - chunk_overlap: int = 50, - min_chunk_size: int = 100, - ): - self.chunk_size = chunk_size - self.chunk_overlap = chunk_overlap - self.min_chunk_size = min_chunk_size - # 分隔符优先级:段落 > 句子 > 词 - self.separators = ["\n\n", "\n", "。", ".", "!", "!", "?", "?", ";", ";", " "] - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - +class BaseChunker(ABC): + """分块器基类""" + + STRATEGY: ChunkStrategy = None + + @abstractmethod def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]: - """ - 将文本递归分块。 + """执行分块""" + pass + + def preview(self, text: str, max_chunks: int = 5) -> list[str]: + """预览分块结果""" + chunks = self.chunk(text) + return [c["content"][:200] + "..." if len(c["content"]) > 200 else c["content"] + for c in chunks[:max_chunks]] - Returns: - list of dicts: - { - "content": str, - "chunk_index": int, - "token_count": int, - "metadata": dict, - } - """ - if not text or not text.strip(): - return [] - - raw_chunks = self._split_recursive(text.strip(), self.separators) - - # 合并过短的块 & 添加重叠 - merged = self._merge_small_chunks(raw_chunks) - result = [] - for idx, content in enumerate(merged): - result.append( - { - "content": content, - "chunk_index": idx, - "token_count": self._estimate_tokens(content), - "metadata": metadata or {}, - } - ) - return result - - # ------------------------------------------------------------------ - # Internal helpers - # ------------------------------------------------------------------ - - def _split_recursive(self, text: str, separators: list[str]) -> list[str]: - """ - 递归分割:尝试当前分隔符,若块太大则用下一级分隔符继续。 - """ - if not separators: - # 最后手段:按字符强制截断 - return self._hard_split(text) - - sep = separators[0] - remaining_seps = separators[1:] - - # 先用当前分隔符分割 - splits = text.split(sep) - # 去掉空串,但保留分隔符语义(拼回去) - splits = [s for s in splits if s.strip()] - - if len(splits) <= 1: - # 该分隔符无法分割,尝试下一级 - return self._split_recursive(text, remaining_seps) - - chunks: list[str] = [] - current_buffer = "" - - for piece in splits: - candidate = (current_buffer + sep + piece).strip() if current_buffer else piece.strip() - if self._estimate_tokens(candidate) <= self.chunk_size: - current_buffer = candidate - else: - # 当前 buffer 达到 chunk_size - if current_buffer: - # buffer 本身是否太大?若是,递归细分 - if self._estimate_tokens(current_buffer) > self.chunk_size: - chunks.extend(self._split_recursive(current_buffer, remaining_seps)) - else: - chunks.append(current_buffer) - # piece 单独处理 - if self._estimate_tokens(piece) > self.chunk_size: - chunks.extend(self._split_recursive(piece, remaining_seps)) - current_buffer = "" - else: - current_buffer = piece.strip() - - if current_buffer: - if self._estimate_tokens(current_buffer) > self.chunk_size: - chunks.extend(self._split_recursive(current_buffer, remaining_seps)) - else: - chunks.append(current_buffer) - - return [c for c in chunks if c.strip()] - - def _merge_small_chunks(self, chunks: list[str]) -> list[str]: - """ - 合并过短的块(< min_chunk_size token),并在相邻块间加入重叠文本。 - """ - if not chunks: - return [] - - merged: list[str] = [] - buffer = chunks[0] - - for chunk in chunks[1:]: - if self._estimate_tokens(buffer) < self.min_chunk_size: - buffer = buffer + "\n" + chunk - else: - merged.append(buffer) - # 添加重叠:取上一块末尾若干 token 作为前缀 - overlap_text = self._get_overlap_prefix(buffer) - buffer = (overlap_text + "\n" + chunk).strip() if overlap_text else chunk - - merged.append(buffer) - return [c for c in merged if c.strip()] - - def _get_overlap_prefix(self, text: str) -> str: - """截取文本末尾作为下一块的重叠前缀(按 token 估算)。""" - if self.chunk_overlap <= 0: - return "" - # 简单实现:按字符比例截取 - words = text.split() - if not words: - return "" - # 估算每个词约 1.5 token(中英混合) - token_per_word = 1.5 - overlap_words = int(self.chunk_overlap / token_per_word) - overlap_words = max(1, min(overlap_words, len(words))) - return " ".join(words[-overlap_words:]) - - def _hard_split(self, text: str) -> list[str]: - """按字符强制截断(最后手段)。""" - # 粗略:1 token ≈ 2 字符(中文) - char_limit = self.chunk_size * 2 +class RecursiveChunker(BaseChunker): + """递归分块器(现有实现)""" + + STRATEGY = ChunkStrategy( + name="recursive", + description="优先按段落分割,过长时按句子分割", + chunk_size=500, + chunk_overlap=50, + min_chunk_size=50, + ) + + # 分割模式(按优先级) + SEPARATORS = [ + r"\n\n+", # 双换行(段落) + r"\n", # 单换行 + r"[。!?!?]\s*", # 句子结束 + r"[,,;;]\s*", # 分句 + r"\s+", # 空格 + ] + + def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]: chunks = [] - start = 0 - while start < len(text): - end = start + char_limit - chunks.append(text[start:end]) - start = end - self.chunk_overlap * 2 # 加入字符级重叠 - if start <= 0: - start = end + metadata = metadata or {} + + # 按段落分割 + segments = re.split(r"\n\n+", text) + + current_chunk = "" + for segment in segments: + if len(current_chunk) + len(segment) <= self.STRATEGY.chunk_size: + current_chunk += segment + "\n\n" + else: + # 当前块足够大,保存 + if len(current_chunk.strip()) >= self.STRATEGY.min_chunk_size: + chunks.append({ + "content": current_chunk.strip(), + "chunk_index": len(chunks), + "metadata": metadata, + }) + + # 处理过长段落 + if len(segment) > self.STRATEGY.chunk_size: + current_chunk = segment + else: + # 保留重叠 + overlap = current_chunk[-self.STRATEGY.chunk_overlap:] + current_chunk = overlap + segment + "\n\n" + + # 处理最后一个块 + if len(current_chunk.strip()) >= self.STRATEGY.min_chunk_size: + chunks.append({ + "content": current_chunk.strip(), + "chunk_index": len(chunks), + "metadata": metadata, + }) + return chunks - def _estimate_tokens(self, text: str) -> int: - """ - 估算 token 数。 - 规则:中文字符每字计 1 token,英文单词计 1.3 token(BPE 碎片系数)。 - """ - if not text: - return 0 +class SemanticChunker(BaseChunker): + """语义分块器 - 按语义边界分割""" + + STRATEGY = ChunkStrategy( + name="semantic", + description="根据语义边界(标题、段落)自动分块", + chunk_size=800, + chunk_overlap=100, + min_chunk_size=100, + ) + + # 语义边界模式 + SEMANTIC_PATTERNS = [ + (r"^#{1,6}\s+(.+)$", "heading"), # Markdown标题 + (r"^【(.+?)】\s*$", "heading"), # 中文标题 + (r"^第[一二三四五六七八九十百]+[章节条]", "heading"), # 章节标题 + (r"^(\d+\.)+\s+", "heading"), # 数字编号 + ] + + def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]: + chunks = [] + metadata = metadata or {} + lines = text.split("\n") + + current_chunk = "" + current_section = None + + for line in lines: + # 检查是否是语义边界 + is_boundary = False + for pattern, boundary_type in self.SEMANTIC_PATTERNS: + if re.match(pattern, line.strip()): + is_boundary = True + current_section = line.strip() + break + + # 如果是边界且当前块不为空,保存 + if is_boundary and current_chunk.strip(): + chunks.append({ + "content": current_chunk.strip(), + "chunk_index": len(chunks), + "section": current_section, + "metadata": metadata, + }) + # 保留重叠 + overlap = current_chunk[-self.STRATEGY.chunk_overlap:] + current_chunk = overlap + line + "\n" + else: + current_chunk += line + "\n" + + # 检查块大小 + if len(current_chunk) >= self.STRATEGY.chunk_size: + chunks.append({ + "content": current_chunk.strip(), + "chunk_index": len(chunks), + "section": current_section, + "metadata": metadata, + }) + overlap = current_chunk[-self.STRATEGY.chunk_overlap:] + current_chunk = overlap + + # 处理最后一个块 + if current_chunk.strip(): + chunks.append({ + "content": current_chunk.strip(), + "chunk_index": len(chunks), + "section": current_section, + "metadata": metadata, + }) + + return chunks - # 中文字符计数 - chinese_chars = len(re.findall(r"[\u4e00-\u9fff\u3400-\u4dbf]", text)) - # 去掉中文后,计算英文单词数 - non_chinese = re.sub(r"[\u4e00-\u9fff\u3400-\u4dbf]", " ", text) - english_words = len(non_chinese.split()) +class FixedLengthChunker(BaseChunker): + """固定长度分块器""" + + STRATEGY = ChunkStrategy( + name="fixed", + description="按固定长度强制分块", + chunk_size=300, + chunk_overlap=30, + min_chunk_size=50, + ) + + def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]: + chunks = [] + metadata = metadata or {} + + # 移除多余空白 + text = re.sub(r"\s+", " ", text) + + for i in range(0, len(text), self.STRATEGY.chunk_size - self.STRATEGY.chunk_overlap): + chunk_text = text[i:i + self.STRATEGY.chunk_size] + + if len(chunk_text.strip()) >= self.STRATEGY.min_chunk_size: + chunks.append({ + "content": chunk_text.strip(), + "chunk_index": len(chunks), + "metadata": metadata, + }) + + return chunks - return int(chinese_chars + english_words * 1.3) +class ChunkerFactory: + """分块策略工厂""" + + STRATEGIES = { + "recursive": RecursiveChunker, + "semantic": SemanticChunker, + "fixed": FixedLengthChunker, + } + + @classmethod + def create(cls, strategy: str = "recursive") -> BaseChunker: + """创建分块器""" + chunker_cls = cls.STRATEGIES.get(strategy, RecursiveChunker) + return chunker_cls() + + @classmethod + def list_strategies(cls) -> list[ChunkStrategy]: + """列出所有策略""" + return [chunker_cls.STRATEGY for chunker_cls in cls.STRATEGIES.values()] \ No newline at end of file diff --git a/backend/app/services/knowledge/enhanced_rag.py b/backend/app/services/knowledge/enhanced_rag.py new file mode 100644 index 0000000..20a1a54 --- /dev/null +++ b/backend/app/services/knowledge/enhanced_rag.py @@ -0,0 +1,149 @@ +""" +增强版RAG服务 - 包含重排序和上下文压缩 +""" +import re +from typing import Optional + +from app.services.llm.factory import LLMFactory + + +class EnhancedRAG: + """增强版RAG检索服务""" + + def __init__(self, rag_service, embedder): + self.rag = rag_service + self.embedder = embedder + self.llm = LLMFactory.create() + + async def retrieve_with_rerank( + self, + session, + query: str, + kb_ids: list[str], + top_k: int = 5, + use_rerank: bool = True, + use_compression: bool = False, + ) -> list[dict]: + """ + 增强检索流程: + 1. 初始检索(扩大候选集) + 2. 可选:重排序 + 3. 可选:上下文压缩 + """ + # Step 1: 初始检索 + initial_k = top_k * 4 if use_rerank else top_k + candidates = await self.rag.search( + session, query, kb_ids, top_k=initial_k + ) + + if not candidates: + return [] + + # Step 2: 可选重排序 + if use_rerank and len(candidates) > top_k: + candidates = await self._rerank(query, candidates, top_k) + + # Step 3: 可选上下文压缩 + if use_compression: + candidates = await self._compress(candidates, query) + + return candidates[:top_k] + + async def _rerank( + self, + query: str, + candidates: list[dict], + top_k: int, + ) -> list[dict]: + """ + 使用LLM进行相关性重排序 + + 对每个候选计算与查询的相关性分数,然后排序 + """ + reranked = [] + + for item in candidates: + # 提取候选内容片段 + content = item.get("content", "")[:500] # 限制长度 + + # 构建评估Prompt + prompt = f"""评估以下查询与文档片段的相关性。 + +查询:{query} + +文档片段: +{content} + +请只返回一个0到1之间的小数,表示相关性分数。0表示完全不相关,1表示完全相关。只返回数字:""" + + try: + response = await self.llm.generate(prompt) + # 提取数字 + match = re.search(r'0?\.\d+', response) + relevance = float(match.group()) if match else 0.5 + except Exception: + relevance = item.get("score", 0.5) + + item["relevance_score"] = relevance + reranked.append(item) + + # 按相关性分数降序排序 + reranked.sort(key=lambda x: x["relevance_score"], reverse=True) + + return reranked + + async def _compress( + self, + candidates: list[dict], + query: str, + max_context_tokens: int = 2000, + ) -> list[dict]: + """ + 上下文压缩 + + 从每个chunk中提取与query相关的内容,减少token消耗 + """ + compressed = [] + current_tokens = 0 + + for chunk in candidates: + content = chunk.get("content", "") + + # 估算token(中文约1.5字符/token) + est_tokens = len(content) // 2 + + if current_tokens + est_tokens <= max_context_tokens: + chunk["compressed"] = False + compressed.append(chunk) + current_tokens += est_tokens + else: + # 尝试压缩这个chunk + compressed_chunk = await self._compress_chunk(content, query) + chunk["compressed"] = True + chunk["compressed_content"] = compressed_chunk + compressed.append(chunk) + break + + return compressed + + async def _compress_chunk( + self, + content: str, + query: str, + ) -> str: + """压缩单个chunk,保留与query相关的内容""" + prompt = f"""从以下文本中提取与问题最相关的部分,保持原文的表达方式,不要总结或改写。 + +问题:{query} + +原文: +{content} + +直接返回提取的内容(不要解释):""" + + try: + result = await self.llm.generate(prompt) + return result.strip() + except Exception: + # 压缩失败时返回原文 + return content[:500] + "..." \ No newline at end of file diff --git a/backend/app/services/knowledge/entity_extractor.py b/backend/app/services/knowledge/entity_extractor.py new file mode 100644 index 0000000..1c917a8 --- /dev/null +++ b/backend/app/services/knowledge/entity_extractor.py @@ -0,0 +1,168 @@ +"""实体和关系抽取服务""" +import re +import json +from typing import Optional +from dataclasses import dataclass, field + +from app.services.llm.factory import LLMFactory + + +@dataclass +class ExtractedEntity: + """抽取的实体""" + name: str + entity_type: str + description: Optional[str] = None + properties: dict = field(default_factory=dict) + + +@dataclass +class ExtractedRelation: + """抽取的关系""" + source_entity: str + target_entity: str + relation_type: str + properties: dict = field(default_factory=dict) + + +@dataclass +class ExtractionResult: + """抽取结果""" + entities: list[ExtractedEntity] = field(default_factory=list) + relations: list[ExtractedRelation] = field(default_factory=list) + + +class EntityExtractor: + """实体和关系抽取服务""" + + # 实体类型映射 + ENTITY_TYPES = [ + "ORGANIZATION", # 公司/组织 + "PRODUCT", # 产品 + "PERSON", # 人物 + "LOCATION", # 地点 + "TECHNOLOGY", # 技术 + "BRAND", # 品牌 + "EVENT", # 事件 + "CONCEPT", # 概念 + ] + + # 关系类型映射 + RELATION_TYPES = [ + "COMPETES_WITH", # 竞争对手 + "PARTNERS_WITH", # 合作伙伴 + "PRODUCES", # 生产 + "USES_TECHNOLOGY", # 使用技术 + "LOCATED_IN", # 位于 + "FOUNDED_IN", # 成立于 + "CEO_OF", # CEO + "FOUNDER_OF", # 创始人 + "RELATED_TO", # 相关 + "PART_OF", # 属于 + ] + + def __init__(self): + self.llm = LLMFactory.create() + + async def extract(self, text: str, context: Optional[str] = None) -> ExtractionResult: + """ + 从文本中抽取实体和关系 + + Args: + text: 待处理的文本 + context: 可选的上下文信息(如品牌名、行业等) + + Returns: + ExtractionResult: 包含实体和关系的抽取结果 + """ + # 构建抽取Prompt + prompt = self._build_extraction_prompt(text, context) + + # 调用LLM + response = await self.llm.generate(prompt) + + # 解析结果 + return self._parse_response(response) + + def _build_extraction_prompt(self, text: str, context: Optional[str] = None) -> str: + """构建抽取Prompt""" + entity_types = "\n".join([f"- {t}" for t in self.ENTITY_TYPES]) + relation_types = "\n".join([f"- {t}" for t in self.RELATION_TYPES]) + + context_hint = f"\n\n附加上下文:{context}" if context else "" + + return f"""从以下文本中抽取知识图谱的实体和关系。 + +要求: +1. 实体必须从文本中明确提及,不能臆造 +2. 关系必须有文本依据,不能臆造 +3. 每个实体和关系都要有置信度说明(high/medium/low) + +实体类型: +{entity_types} + +关系类型: +{relation_types} + +{context_hint} + +文本内容: +{text} + +请以JSON格式返回结果: +{{ + "entities": [ + {{ + "name": "实体名称", + "entity_type": "实体类型", + "description": "实体描述(可选)", + "confidence": "high/medium/low" + }} + ], + "relations": [ + {{ + "source_entity": "源实体名称", + "target_entity": "目标实体名称", + "relation_type": "关系类型", + "confidence": "high/medium/low" + }} + ] +}} + +只返回JSON,不要有其他内容:""" + + def _parse_response(self, response: str) -> ExtractionResult: + """解析LLM返回的结果""" + # 提取JSON + json_match = re.search(r'\{[\s\S]*\}', response) + if not json_match: + return ExtractionResult(entities=[], relations=[]) + + try: + data = json.loads(json_match.group()) + except json.JSONDecodeError: + return ExtractionResult(entities=[], relations=[]) + + # 解析实体 + entities = [ + ExtractedEntity( + name=e["name"], + entity_type=e["entity_type"], + description=e.get("description"), + properties={"confidence": e.get("confidence", "medium")}, + ) + for e in data.get("entities", []) + ] + + # 解析关系 + relations = [ + ExtractedRelation( + source_entity=r["source_entity"], + target_entity=r["target_entity"], + relation_type=r["relation_type"], + properties={"confidence": r.get("confidence", "medium")}, + ) + for r in data.get("relations", []) + ] + + return ExtractionResult(entities=entities, relations=relations) diff --git a/backend/app/services/knowledge/graph_builder.py b/backend/app/services/knowledge/graph_builder.py new file mode 100644 index 0000000..dedf653 --- /dev/null +++ b/backend/app/services/knowledge/graph_builder.py @@ -0,0 +1,187 @@ +"""知识图谱构建服务""" +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.knowledge_graph import ( + KnowledgeEntity, + KnowledgeRelation, + EntityType, + RelationType, +) +from app.models.knowledge import KnowledgeChunk, KnowledgeDocument +from app.services.knowledge.entity_extractor import EntityExtractor, ExtractionResult + + +class GraphBuilder: + """知识图谱构建服务""" + + def __init__(self): + self.extractor = EntityExtractor() + + async def build_from_chunk( + self, + session: AsyncSession, + chunk_id: str, + context: Optional[str] = None, + ) -> dict: + """ + 从Chunk构建知识图谱 + + Args: + session: 数据库会话 + chunk_id: Chunk ID + context: 可选的上下文(如品牌名) + + Returns: + 构建统计信息 + """ + # 1. 获取Chunk内容 + chunk = await session.get(KnowledgeChunk, chunk_id) + if not chunk: + raise ValueError(f"Chunk not found: {chunk_id}") + + # 2. 抽取实体和关系 + result = await self.extractor.extract(chunk.content, context) + + # 3. 存储到图谱 + stats = await self._store_extraction(session, chunk_id, result) + + return stats + + async def _store_extraction( + self, + session: AsyncSession, + chunk_id: str, + result: ExtractionResult, + ) -> dict: + """存储抽取结果""" + stats = { + "entities_created": 0, + "entities_existing": 0, + "relations_created": 0, + "relations_existing": 0, + } + + # 实体名称到ID的映射 + entity_map = {} + + # 4. 存储实体 + for extracted_entity in result.entities: + # 检查是否已存在 + existing, created = await self._get_or_create_entity( + session, + chunk_id, + extracted_entity, + ) + + entity_map[extracted_entity.name] = existing.id + if created: + stats["entities_created"] += 1 + else: + stats["entities_existing"] += 1 + + # 5. 存储关系 + for extracted_relation in result.relations: + # 查找实体ID + source_id = entity_map.get(extracted_relation.source_entity) + target_id = entity_map.get(extracted_relation.target_entity) + + if not source_id or not target_id: + continue # 跳过找不到实体的情况 + + # 创建关系 + created = await self._create_relation( + session, + chunk_id, + source_id, + target_id, + extracted_relation, + ) + + if created: + stats["relations_created"] += 1 + else: + stats["relations_existing"] += 1 + + await session.commit() + return stats + + async def _get_or_create_entity( + self, + session: AsyncSession, + chunk_id: str, + extracted_entity, + ) -> tuple: + """获取或创建实体""" + kb_id = await self._get_chunk_kb_id(session, chunk_id) + + # 查找现有实体 + stmt = select(KnowledgeEntity).where( + KnowledgeEntity.knowledge_base_id == kb_id, + KnowledgeEntity.name == extracted_entity.name, + ) + result = await session.execute(stmt) + existing = result.scalar_one_or_none() + + if existing: + return (existing, False) + + # 创建新实体 + entity = KnowledgeEntity( + knowledge_base_id=kb_id, + name=extracted_entity.name, + entity_type=EntityType(extracted_entity.entity_type), + description=extracted_entity.description, + properties=extracted_entity.properties or {}, + source_chunk_id=chunk_id, + confidence=extracted_entity.properties.get("confidence") if extracted_entity.properties else None, + ) + session.add(entity) + await session.flush() + + return (entity, True) + + async def _create_relation( + self, + session: AsyncSession, + chunk_id: str, + source_id: str, + target_id: str, + extracted_relation, + ) -> bool: + """创建关系(如果不存在)""" + # 检查是否已存在 + stmt = select(KnowledgeRelation).where( + KnowledgeRelation.source_entity_id == source_id, + KnowledgeRelation.target_entity_id == target_id, + KnowledgeRelation.relation_type == RelationType(extracted_relation.relation_type), + ) + result = await session.execute(stmt) + existing = result.scalar_one_or_none() + + if existing: + return False + + # 创建关系 + relation = KnowledgeRelation( + source_entity_id=source_id, + target_entity_id=target_id, + relation_type=RelationType(extracted_relation.relation_type), + properties=extracted_relation.properties or {}, + source_chunk_id=chunk_id, + confidence=extracted_relation.properties.get("confidence") if extracted_relation.properties else None, + ) + session.add(relation) + return True + + async def _get_chunk_kb_id(self, session: AsyncSession, chunk_id: str) -> str: + """获取Chunk所属的知识库ID""" + chunk = await session.get(KnowledgeChunk, chunk_id) + if not chunk: + raise ValueError(f"Chunk not found: {chunk_id}") + + # 通过document获取kb_id + doc = await session.get(KnowledgeDocument, chunk.document_id) + return doc.knowledge_base_id diff --git a/backend/app/services/knowledge/graph_query.py b/backend/app/services/knowledge/graph_query.py new file mode 100644 index 0000000..9c09082 --- /dev/null +++ b/backend/app/services/knowledge/graph_query.py @@ -0,0 +1,249 @@ +"""知识图谱查询服务""" +from typing import Optional + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from app.models.knowledge_graph import ( + KnowledgeEntity, + KnowledgeRelation, +) + + +class GraphQuery: + """知识图谱查询服务""" + + async def get_entity( + self, + session: AsyncSession, + entity_id: str, + ) -> Optional[dict]: + """根据ID获取实体详情""" + entity = await session.get(KnowledgeEntity, entity_id) + if not entity: + return None + + return self._entity_to_dict(entity) + + async def search_entities( + self, + session: AsyncSession, + kb_id: str, + query: str, + entity_type: Optional[str] = None, + limit: int = 20, + ) -> list[dict]: + """搜索实体""" + stmt = select(KnowledgeEntity).where( + KnowledgeEntity.knowledge_base_id == kb_id, + KnowledgeEntity.name.ilike(f"%{query}%"), + ) + + if entity_type: + stmt = stmt.where(KnowledgeEntity.entity_type == entity_type) + + stmt = stmt.limit(limit) + result = await session.execute(stmt) + + return [self._entity_to_dict(e) for e in result.scalars()] + + async def get_entity_neighbors( + self, + session: AsyncSession, + entity_id: str, + max_depth: int = 1, + ) -> dict: + """获取实体的邻居(直接关联的实体)""" + entity = await session.get(KnowledgeEntity, entity_id) + if not entity: + return None + + neighbors = { + "entity": self._entity_to_dict(entity), + "incoming": [], # 入边(别人指向我) + "outgoing": [], # 出边(我指向别人) + } + + # 获取入边 + incoming_stmt = ( + select(KnowledgeRelation, KnowledgeEntity) + .join(KnowledgeEntity, KnowledgeRelation.source_entity_id == KnowledgeEntity.id) + .where(KnowledgeRelation.target_entity_id == entity_id) + ) + incoming_result = await session.execute(incoming_stmt) + for rel, source_entity in incoming_result: + neighbors["incoming"].append({ + "relation": self._relation_to_dict(rel), + "entity": self._entity_to_dict(source_entity), + }) + + # 获取出边 + outgoing_stmt = ( + select(KnowledgeRelation, KnowledgeEntity) + .join(KnowledgeEntity, KnowledgeRelation.target_entity_id == KnowledgeEntity.id) + .where(KnowledgeRelation.source_entity_id == entity_id) + ) + outgoing_result = await session.execute(outgoing_stmt) + for rel, target_entity in outgoing_result: + neighbors["outgoing"].append({ + "relation": self._relation_to_dict(rel), + "entity": self._entity_to_dict(target_entity), + }) + + return neighbors + + async def get_entity_path( + self, + session: AsyncSession, + source_name: str, + target_name: str, + max_hops: int = 3, + ) -> list[dict]: + """ + 查找两个实体之间的路径 + + 使用简单BFS查找路径 + """ + # 获取实体ID + source_stmt = select(KnowledgeEntity).where( + KnowledgeEntity.name == source_name + ) + source_result = await session.execute(source_stmt) + source_entity = source_result.scalar_one_or_none() + + target_stmt = select(KnowledgeEntity).where( + KnowledgeEntity.name == target_name + ) + target_result = await session.execute(target_stmt) + target_entity = target_result.scalar_one_or_none() + + if not source_entity or not target_entity: + return [] + + # BFS查找路径 + visited = {str(source_entity.id)} + queue = [(str(source_entity.id), [])] + + while queue: + current_id, path = queue.pop(0) + + if current_id == str(target_entity.id): + # 找到路径,返回 + return await self._format_path(path, session) + + if len(path) >= max_hops: + continue + + # 探索邻居 + neighbors_stmt = ( + select(KnowledgeRelation, KnowledgeEntity) + .join(KnowledgeEntity, KnowledgeRelation.target_entity_id == KnowledgeEntity.id) + .where(KnowledgeRelation.source_entity_id == current_id) + ) + neighbors_result = await session.execute(neighbors_stmt) + + for rel, neighbor in neighbors_result: + neighbor_id = str(neighbor.id) + if neighbor_id not in visited: + visited.add(neighbor_id) + new_path = path + [{ + "from": current_id, + "relation": rel.relation_type.value, + "to": neighbor_id, + }] + queue.append((neighbor_id, new_path)) + + return [] + + async def get_statistics( + self, + session: AsyncSession, + kb_id: str, + ) -> dict: + """获取图谱统计信息""" + # 实体数量 + entity_count_stmt = select(func.count()).where( + KnowledgeEntity.knowledge_base_id == kb_id + ) + entity_count_result = await session.execute(entity_count_stmt) + entity_count = entity_count_result.scalar() or 0 + + # 关系数量 + kb_entities = select(KnowledgeEntity.id).where( + KnowledgeEntity.knowledge_base_id == kb_id + ) + relation_count_stmt = select(func.count()).where( + KnowledgeRelation.source_entity_id.in_(kb_entities) + ) + relation_count_result = await session.execute(relation_count_stmt) + relation_count = relation_count_result.scalar() or 0 + + # 实体类型分布 + type_dist_stmt = ( + select( + KnowledgeEntity.entity_type, + func.count() + ) + .where(KnowledgeEntity.knowledge_base_id == kb_id) + .group_by(KnowledgeEntity.entity_type) + ) + type_dist_result = await session.execute(type_dist_stmt) + entity_type_dist = {str(k): v for k, v in type_dist_result} + + # 关系类型分布 + rel_type_dist_stmt = ( + select( + KnowledgeRelation.relation_type, + func.count() + ) + .where(KnowledgeRelation.source_entity_id.in_(kb_entities)) + .group_by(KnowledgeRelation.relation_type) + ) + rel_type_dist_result = await session.execute(rel_type_dist_stmt) + relation_type_dist = {str(k): v for k, v in rel_type_dist_result} + + return { + "entity_count": entity_count, + "relation_count": relation_count, + "entity_type_distribution": entity_type_dist, + "relation_type_distribution": relation_type_dist, + } + + def _entity_to_dict(self, entity: KnowledgeEntity) -> dict: + """实体转字典""" + return { + "id": str(entity.id), + "name": entity.name, + "entity_type": entity.entity_type.value, + "description": entity.description, + "properties": entity.properties, + "confidence": entity.confidence, + } + + def _relation_to_dict(self, relation: KnowledgeRelation) -> dict: + """关系转字典""" + return { + "id": str(relation.id), + "source_id": str(relation.source_entity_id), + "target_id": str(relation.target_entity_id), + "relation_type": relation.relation_type.value, + "properties": relation.properties, + "confidence": relation.confidence, + } + + async def _format_path(self, path: list, session: AsyncSession) -> list[dict]: + """格式化路径,返回实体名称""" + formatted = [] + for step in path: + # 获取实体名称 + from_entity = await session.get(KnowledgeEntity, step["from"]) + to_entity = await session.get(KnowledgeEntity, step["to"]) + + formatted.append({ + "from": from_entity.name if from_entity else step["from"], + "relation": step["relation"], + "to": to_entity.name if to_entity else step["to"], + }) + + return formatted diff --git a/backend/app/services/knowledge/incremental_index.py b/backend/app/services/knowledge/incremental_index.py new file mode 100644 index 0000000..8a7c985 --- /dev/null +++ b/backend/app/services/knowledge/incremental_index.py @@ -0,0 +1,242 @@ +""" +增量索引服务 +""" +import hashlib +from typing import Optional + +from sqlalchemy import delete, func, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.knowledge import KnowledgeChunk, KnowledgeDocument + + +class IncrementalIndexService: + """增量索引服务 - 支持文档的增删改""" + + def __init__(self, rag_service): + self.rag = rag_service + + async def add_document( + self, + session: AsyncSession, + kb_id: str, + document_id: str, + ) -> dict: + """ + 增量添加文档 + + 不重建全量索引,只处理单个文档 + """ + # 检查是否已存在 + existing = await self._get_document_status(session, document_id) + + if existing and existing.get("status") == "ready": + return {"action": "skip", "reason": "already_indexed"} + + # 执行增量摄入 + chunk_count = await self.rag.ingest_document(session, document_id) + + return { + "action": "indexed", + "document_id": document_id, + "chunk_count": chunk_count, + } + + async def update_document( + self, + session: AsyncSession, + document_id: str, + new_content: str, + ) -> dict: + """ + 增量更新文档 + + 策略: + 1. 计算新内容hash + 2. 若hash未变,跳过 + 3. 若hash改变,删除旧chunks,生成新的 + """ + # 计算新hash + new_hash = hashlib.sha256(new_content.encode()).hexdigest() + + # 获取旧hash + old_hash = await self._get_content_hash(session, document_id) + + if new_hash == old_hash: + return {"action": "skip", "reason": "content_unchanged"} + + # 删除旧chunks + deleted = await self._delete_document_chunks(session, document_id) + + # 更新文档内容 + await self._update_document_content(session, document_id, new_content, new_hash) + + # 增量摄入新内容 + chunk_count = await self.rag.ingest_document(session, document_id) + + return { + "action": "updated", + "document_id": document_id, + "deleted_chunks": deleted, + "new_chunks": chunk_count, + } + + async def delete_document( + self, + session: AsyncSession, + document_id: str, + ) -> dict: + """ + 删除文档 + + 删除文档及其所有chunks + """ + # 删除chunks + deleted = await self._delete_document_chunks(session, document_id) + + # 删除文档记录 + await self._delete_document(session, document_id) + + return { + "action": "deleted", + "document_id": document_id, + "deleted_chunks": deleted, + } + + async def rebuild_knowledge_base( + self, + session: AsyncSession, + kb_id: str, + force: bool = False, + ) -> dict: + """ + 重建知识库索引 + + Args: + force: 是否强制重建(即使状态是ready) + """ + # 获取所有文档 + stmt = select(KnowledgeDocument).where( + KnowledgeDocument.knowledge_base_id == kb_id + ) + result = await session.execute(stmt) + documents = result.scalars().all() + + stats = { + "total": len(documents), + "processed": 0, + "skipped": 0, + "failed": 0, + "errors": [], + } + + for doc in documents: + try: + if doc.status == "ready" and not force: + stats["skipped"] += 1 + continue + + # 删除旧chunks + await self._delete_document_chunks(session, str(doc.id)) + + # 重新摄入 + await self.rag.ingest_document(session, str(doc.id)) + + stats["processed"] += 1 + + except Exception as e: + stats["failed"] += 1 + stats["errors"].append({ + "document_id": str(doc.id), + "error": str(e), + }) + + return stats + + # ------------------------------------------------------------------ + # 辅助方法 + # ------------------------------------------------------------------ + + async def _get_document_status( + self, + session: AsyncSession, + document_id: str, + ) -> Optional[dict]: + """获取文档状态""" + stmt = select(KnowledgeDocument).where( + KnowledgeDocument.id == document_id + ) + result = await session.execute(stmt) + doc = result.scalar_one_or_none() + + if not doc: + return None + + return { + "status": doc.status, + "content_hash": getattr(doc, "content_hash", None), + } + + async def _get_content_hash( + self, + session: AsyncSession, + document_id: str, + ) -> Optional[str]: + """获取文档内容hash""" + status = await self._get_document_status(session, document_id) + return status.get("content_hash") if status else None + + async def _delete_document_chunks( + self, + session: AsyncSession, + document_id: str, + ) -> int: + """删除文档的所有chunks""" + # 统计要删除的数量 + count_stmt = select(func.count()).where( + KnowledgeChunk.document_id == document_id + ) + count_result = await session.execute(count_stmt) + count = count_result.scalar() or 0 + + # 删除 + delete_stmt = delete(KnowledgeChunk).where( + KnowledgeChunk.document_id == document_id + ) + await session.execute(delete_stmt) + + return count + + async def _update_document_content( + self, + session: AsyncSession, + document_id: str, + content: str, + content_hash: str, + ): + """更新文档内容和hash""" + stmt = ( + update(KnowledgeDocument) + .where(KnowledgeDocument.id == document_id) + .values( + content=content, + content_hash=content_hash, + status="pending", # 标记为待处理 + ) + ) + await session.execute(stmt) + + async def _delete_document( + self, + session: AsyncSession, + document_id: str, + ): + """删除文档记录""" + stmt = select(KnowledgeDocument).where( + KnowledgeDocument.id == document_id + ) + result = await session.execute(stmt) + doc = result.scalar_one_or_none() + + if doc: + await session.delete(doc) \ No newline at end of file diff --git a/backend/app/services/knowledge/parsers.py b/backend/app/services/knowledge/parsers.py new file mode 100644 index 0000000..d81b506 --- /dev/null +++ b/backend/app/services/knowledge/parsers.py @@ -0,0 +1,165 @@ +"""文档解析器 - 支持多种格式""" +import io +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional + +@dataclass +class ParsedDocument: + """解析后的文档""" + title: str + content: str + metadata: dict + +class BaseParser(ABC): + """解析器基类""" + + @abstractmethod + async def parse(self, content: bytes) -> ParsedDocument: + """解析文档内容""" + pass + +class PDFParser(BaseParser): + """PDF解析器""" + + async def parse(self, content: bytes) -> ParsedDocument: + """使用PyMuPDF解析PDF""" + import fitz + + doc = fitz.open(stream=content) + text_parts = [] + metadata = {} + + # 提取元数据 + if doc.metadata: + metadata = { + "author": doc.metadata.get("author", ""), + "title": doc.metadata.get("title", ""), + "subject": doc.metadata.get("subject", ""), + } + + # 提取每页文本 + for page_num, page in enumerate(doc): + text = page.get_text() + if text.strip(): + text_parts.append(f"[第{page_num + 1}页]\n{text}") + + # 提取目录(如果存在) + toc = doc.get_toc() + if toc: + metadata["has_toc"] = True + metadata["toc_items"] = len(toc) + + doc.close() + + return ParsedDocument( + title=metadata.get("title", "未命名文档") or "未命名文档", + content="\n\n".join(text_parts), + metadata=metadata, + ) + +class DocxParser(BaseParser): + """Word文档解析器""" + + async def parse(self, content: bytes) -> ParsedDocument: + """使用python-docx解析Word""" + from docx import Document + + doc = Document(io.BytesIO(content)) + paragraphs = [] + metadata = {} + + # 提取核心属性 + core_props = doc.core_properties + metadata = { + "author": getattr(core_props, "author", "") or "", + "title": getattr(core_props, "title", "") or "", + "subject": getattr(core_props, "subject", "") or "", + "created": str(getattr(core_props, "created", "")) or "", + "modified": str(getattr(core_props, "modified", "")) or "", + } + + # 提取段落 + for para in doc.paragraphs: + text = para.text.strip() + if text: + paragraphs.append(text) + + # 提取表格 + for i, table in enumerate(doc.tables): + table_text = [] + for row in table.rows: + cells = [cell.text.strip() for cell in row.cells] + if any(cells): + table_text.append(" | ".join(cells)) + if table_text: + paragraphs.append(f"[表格{i+1}]\n" + "\n".join(table_text)) + + return ParsedDocument( + title=metadata.get("title", "未命名文档") or "未命名文档", + content="\n\n".join(paragraphs), + metadata=metadata, + ) + +class MarkdownParser(BaseParser): + """Markdown解析器""" + + async def parse(self, content: bytes) -> ParsedDocument: + """解析Markdown""" + text = content.decode("utf-8") + + # 提取标题(第一个#开头的内容) + lines = text.split("\n") + title = "未命名文档" + for line in lines: + line = line.strip() + if line.startswith("# "): + title = line[2:].strip() + break + + return ParsedDocument( + title=title, + content=text, + metadata={"format": "markdown"}, + ) + +class TextParser(BaseParser): + """纯文本解析器""" + + async def parse(self, content: bytes) -> ParsedDocument: + """解析纯文本""" + text = content.decode("utf-8") + + # 使用第一行作为标题 + lines = text.split("\n") + title = lines[0][:50] if lines else "未命名文档" + + return ParsedDocument( + title=title, + content=text, + metadata={"format": "text"}, + ) + +class ParserFactory: + """解析器工厂""" + + PARSERS = { + ".pdf": PDFParser, + ".docx": DocxParser, + ".md": MarkdownParser, + ".txt": TextParser, + ".html": MarkdownParser, # HTML当Markdown处理 + } + + @classmethod + def create(cls, file_extension: str) -> BaseParser: + """创建解析器""" + parser_cls = cls.PARSERS.get(file_extension.lower()) + if not parser_cls: + raise ValueError(f"Unsupported format: {file_extension}") + return parser_cls() + + @classmethod + def supported_formats(cls) -> list[str]: + """支持的格式""" + return list(cls.PARSERS.keys()) \ No newline at end of file diff --git a/backend/app/services/knowledge/uploader.py b/backend/app/services/knowledge/uploader.py new file mode 100644 index 0000000..0eacb37 --- /dev/null +++ b/backend/app/services/knowledge/uploader.py @@ -0,0 +1,96 @@ +"""文档上传服务""" +import hashlib +import os +from dataclasses import dataclass +from typing import Optional + +import shortuuid + +from app.services.knowledge.parsers import ParserFactory, ParsedDocument + +@dataclass +class UploadResult: + """上传结果""" + document_id: str + title: str + content: str + content_hash: str + file_size: int + file_type: str + metadata: dict + +class DocumentUploader: + """文档上传服务""" + + SUPPORTED_FORMATS = {".pdf", ".docx", ".md", ".txt", ".html"} + MAX_SIZE_MB = 50 + + def __init__(self, storage_path: str = "/data/documents"): + self.storage_path = storage_path + + async def upload( + self, + file_content: bytes, + filename: str, + kb_id: str, + ) -> UploadResult: + """上传并解析文档""" + # 1. 验证格式 + ext = self._get_extension(filename) + if ext not in self.SUPPORTED_FORMATS: + raise ValueError(f"Unsupported format: {ext}") + + # 2. 验证大小 + size_mb = len(file_content) / (1024 * 1024) + if size_mb > self.MAX_SIZE_MB: + raise ValueError(f"File too large: {size_mb:.1f}MB > {self.MAX_SIZE_MB}MB") + + # 3. 解析文档 + parser = ParserFactory.create(ext) + parsed = await parser.parse(file_content) + + # 4. 生成ID和哈希 + doc_id = shortuuid.uuid() + content_hash = hashlib.sha256(parsed.content.encode()).hexdigest() + + # 5. 保存文件 + file_path = self._save_file(file_content, kb_id, doc_id, ext) + + return UploadResult( + document_id=doc_id, + title=parsed.title, + content=parsed.content, + content_hash=content_hash, + file_size=len(file_content), + file_type=ext, + metadata={ + **parsed.metadata, + "original_filename": filename, + "stored_path": file_path, + }, + ) + + def _get_extension(self, filename: str) -> str: + """获取文件扩展名""" + if "." not in filename: + raise ValueError("File has no extension") + return "." + filename.rsplit(".", 1)[1].lower() + + def _save_file( + self, + content: bytes, + kb_id: str, + doc_id: str, + ext: str + ) -> str: + """保存文件到存储路径""" + # 创建目录 + dir_path = os.path.join(self.storage_path, kb_id) + os.makedirs(dir_path, exist_ok=True) + + # 保存文件 + file_path = os.path.join(dir_path, f"{doc_id}{ext}") + with open(file_path, "wb") as f: + f.write(content) + + return file_path \ No newline at end of file diff --git a/backend/app/services/llm/deepseek_provider.py b/backend/app/services/llm/deepseek_provider.py index 0e41b4e..03ad724 100644 --- a/backend/app/services/llm/deepseek_provider.py +++ b/backend/app/services/llm/deepseek_provider.py @@ -1,12 +1,14 @@ import asyncio import json import os +import time from typing import AsyncIterator import httpx from .base import LLMError, LLMProvider, LLMResponse from .rate_limiter import get_rate_limiter +from app.monitoring.llm_metrics import get_llm_metrics _DEFAULT_MODEL = "deepseek-chat" _DEFAULT_MAX_CONTEXT = 64_000 @@ -75,21 +77,40 @@ class DeepSeekProvider(LLMProvider): if stop: payload["stop"] = stop - data = await self._request_with_retry(payload, stream=False) + start_time = time.perf_counter() + metrics = get_llm_metrics(self.provider_name, self._model) - choice = data["choices"][0] - content = choice["message"]["content"] - usage = data.get("usage", {}) + try: + data = await self._request_with_retry(payload, stream=False) - return LLMResponse( - content=content, - model=data.get("model", self._model), - usage={ - "prompt_tokens": usage.get("prompt_tokens", 0), - "completion_tokens": usage.get("completion_tokens", 0), - "total_tokens": usage.get("total_tokens", 0), - }, - ) + choice = data["choices"][0] + content = choice["message"]["content"] + usage = data.get("usage", {}) + + duration = time.perf_counter() - start_time + metrics.record_request( + status="success", + duration=duration, + prompt_tokens=usage.get("prompt_tokens"), + completion_tokens=usage.get("completion_tokens"), + ) + + return LLMResponse( + content=content, + model=data.get("model", self._model), + usage={ + "prompt_tokens": usage.get("prompt_tokens", 0), + "completion_tokens": usage.get("completion_tokens", 0), + "total_tokens": usage.get("total_tokens", 0), + }, + ) + except Exception as e: + duration = time.perf_counter() - start_time + metrics.record_request( + status="error", + duration=duration, + ) + raise async def chat_stream( self, diff --git a/backend/app/services/llm/openai_provider.py b/backend/app/services/llm/openai_provider.py index ee5da81..5eb604f 100644 --- a/backend/app/services/llm/openai_provider.py +++ b/backend/app/services/llm/openai_provider.py @@ -1,12 +1,14 @@ import asyncio import json import os +import time from typing import AsyncIterator import httpx from .base import LLMError, LLMProvider, LLMResponse from .rate_limiter import get_rate_limiter +from app.monitoring.llm_metrics import get_llm_metrics # 支持的模型及其上下文长度(百炼 Coding Plan + OpenAI) _OPENAI_MODELS: dict[str, int] = { @@ -90,21 +92,40 @@ class OpenAIProvider(LLMProvider): if stop: payload["stop"] = stop - data = await self._request_with_retry(payload, stream=False) + start_time = time.perf_counter() + metrics = get_llm_metrics(self.provider_name, self._model) - choice = data["choices"][0] - content = choice["message"]["content"] - usage = data.get("usage", {}) + try: + data = await self._request_with_retry(payload, stream=False) - return LLMResponse( - content=content, - model=data.get("model", self._model), - usage={ - "prompt_tokens": usage.get("prompt_tokens", 0), - "completion_tokens": usage.get("completion_tokens", 0), - "total_tokens": usage.get("total_tokens", 0), - }, - ) + choice = data["choices"][0] + content = choice["message"]["content"] + usage = data.get("usage", {}) + + duration = time.perf_counter() - start_time + metrics.record_request( + status="success", + duration=duration, + prompt_tokens=usage.get("prompt_tokens"), + completion_tokens=usage.get("completion_tokens"), + ) + + return LLMResponse( + content=content, + model=data.get("model", self._model), + usage={ + "prompt_tokens": usage.get("prompt_tokens", 0), + "completion_tokens": usage.get("completion_tokens", 0), + "total_tokens": usage.get("total_tokens", 0), + }, + ) + except Exception as e: + duration = time.perf_counter() - start_time + metrics.record_request( + status="error", + duration=duration, + ) + raise async def chat_stream( self, diff --git a/backend/app/services/quota_service.py b/backend/app/services/quota_service.py new file mode 100644 index 0000000..4c71fb6 --- /dev/null +++ b/backend/app/services/quota_service.py @@ -0,0 +1,324 @@ +""" +套餐额度预警服务 + +监控用户的API调用额度、查询次数等使用情况,在接近限制时触发预警。 + +功能: +- 额度使用查询: 获取用户当前额度使用情况 +- 使用率计算: 计算额度使用百分比 +- 预警阈值检查: 80%警告、90%严重、100%限制 +- 预警消息生成: 生成带建议操作的预警消息 +- 额度重置: 支持额度重置功能 + +额度类型: +- api_calls: API调用次数 +- queries: 查询次数 +- content_generation: 内容生成次数 +- storage: 存储空间(MB) + +预警阈值: +- warning: 80% - 警告 +- critical: 90% - 严重 +- exhausted: 100% - 耗尽 +""" +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + + +class QuotaType(str, Enum): + """额度类型枚举""" + API_CALLS = "api_calls" + QUERIES = "queries" + CONTENT_GENERATION = "content_generation" + STORAGE = "storage" + + +QUOTA_LIMITS = { + "free": { + "api_calls": 1000, + "queries": 50, + "content_generation": 10, + "storage": 100, + }, + "basic": { + "api_calls": 10000, + "queries": 500, + "content_generation": 100, + "storage": 1000, + }, + "pro": { + "api_calls": 100000, + "queries": 5000, + "content_generation": 1000, + "storage": 10000, + }, + "unlimited": { + "api_calls": -1, + "queries": -1, + "content_generation": -1, + "storage": -1, + }, +} + +WARNING_THRESHOLDS = { + "warning": 0.80, + "critical": 0.90, + "exhausted": 1.00, +} + +WARNING_MESSAGES = { + "warning": "{quota_type}额度已使用{percentage:.0f}%", + "critical": "{quota_type}额度已使用{percentage:.0f}%,即将耗尽", + "exhausted": "{quota_type}额度已使用{percentage:.0f}%,已耗尽", +} + +RECOMMENDED_ACTIONS = { + "warning": "请关注使用情况,考虑升级套餐", + "critical": "额度即将耗尽,建议立即升级套餐", + "exhausted": "额度已耗尽,请升级套餐或等待重置", +} + + +@dataclass +class QuotaUsage: + """额度使用数据结构 + + Attributes: + quota_type: 额度类型 + used: 已使用量 + limit: 额度限制(-1表示无限) + usage_percentage: 使用百分比(0-100,无限时为0) + status: 状态(ok/warning/critical/exhausted/unlimited) + remaining: 剩余额度(-1表示无限) + """ + quota_type: str + used: int + limit: int + usage_percentage: float + status: str + remaining: int + + +@dataclass +class QuotaWarning: + """预警数据结构 + + Attributes: + quota_type: 额度类型 + status: 预警状态 + usage_percentage: 使用百分比 + message: 预警消息 + recommended_action: 建议操作 + """ + quota_type: str + status: str + usage_percentage: float + message: str + recommended_action: str + + +class QuotaService: + """套餐额度预警服务 + + 提供额度查询、使用率计算、预警检查等功能。 + """ + + def calculate_usage_percentage(self, used: int, limit: int) -> float: + """计算额度使用百分比 + + Args: + used: 已使用量 + limit: 额度限制(-1表示无限) + + Returns: + 使用百分比(0-100),无限额度时返回0.0 + """ + if limit == -1 or limit == 0: + return 0.0 + return (used / limit) * 100.0 + + def get_quota_status(self, usage_percentage: float, limit: int = 0) -> str: + """根据使用率获取额度状态 + + Args: + usage_percentage: 使用百分比 + limit: 额度限制(-1表示无限) + + Returns: + 状态字符串: ok/warning/critical/exhausted/unlimited + """ + if limit == -1: + return "unlimited" + if usage_percentage >= 100.0: + return "exhausted" + if usage_percentage >= 90.0: + return "critical" + if usage_percentage >= 80.0: + return "warning" + return "ok" + + def get_quota_limit(self, plan: str, quota_type: QuotaType) -> int: + """获取指定套餐的额度限制 + + Args: + plan: 套餐名称(free/basic/pro/unlimited) + quota_type: 额度类型 + + Returns: + 额度限制值(-1表示无限) + """ + return QUOTA_LIMITS.get(plan, {}).get(quota_type.value, 0) + + def get_remaining(self, used: int, limit: int) -> int: + """计算剩余额度 + + Args: + used: 已使用量 + limit: 额度限制(-1表示无限) + + Returns: + 剩余额度(-1表示无限) + """ + if limit == -1: + return -1 + return limit - used + + def _format_warning_message(self, quota_type: str, status: str, percentage: float) -> str: + """格式化预警消息 + + Args: + quota_type: 额度类型 + status: 预警状态 + percentage: 使用百分比 + + Returns: + 格式化后的预警消息 + """ + template = WARNING_MESSAGES.get(status, "{quota_type}额度使用情况") + return template.format(quota_type=quota_type, percentage=percentage) + + def _get_recommended_action(self, status: str) -> str: + """获取建议操作 + + Args: + status: 预警状态 + + Returns: + 建议操作描述 + """ + return RECOMMENDED_ACTIONS.get(status, "请关注使用情况") + + def generate_warning( + self, + quota_type: QuotaType, + status: str, + usage_percentage: float, + ) -> QuotaWarning: + """生成预警信息 + + Args: + quota_type: 额度类型 + status: 预警状态 + usage_percentage: 使用百分比 + + Returns: + QuotaWarning预警数据对象 + """ + return QuotaWarning( + quota_type=quota_type.value, + status=status, + usage_percentage=usage_percentage, + message=self._format_warning_message(quota_type.value, status, usage_percentage), + recommended_action=self._get_recommended_action(status), + ) + + def check_quota( + self, + plan: str, + quota_type: QuotaType, + used: int, + ) -> QuotaUsage: + """检查指定套餐的额度使用情况 + + Args: + plan: 套餐名称 + quota_type: 额度类型 + used: 已使用量 + + Returns: + QuotaUsage额度使用数据对象 + """ + limit = self.get_quota_limit(plan, quota_type) + percentage = self.calculate_usage_percentage(used, limit) + status = self.get_quota_status(percentage, limit) + remaining = self.get_remaining(used, limit) + + return QuotaUsage( + quota_type=quota_type.value, + used=used, + limit=limit, + usage_percentage=percentage, + status=status, + remaining=remaining, + ) + + def reset_quota(self, used: int, reset_to: int = 0) -> int: + """重置额度使用量 + + Args: + used: 当前使用量(未使用,仅为接口兼容) + reset_to: 重置到的值,默认为0 + + Returns: + 重置后的使用量 + """ + return reset_to + + def get_all_quota_usage( + self, + plan: str, + api_calls_used: int = 0, + queries_used: int = 0, + content_generation_used: int = 0, + storage_used: int = 0, + ) -> list[QuotaUsage]: + """获取所有额度类型的使用情况 + + Args: + plan: 套餐名称 + api_calls_used: API调用已使用量 + queries_used: 查询已使用量 + content_generation_used: 内容生成已使用量 + storage_used: 存储已使用量(MB) + + Returns: + 所有额度类型的使用情况列表 + """ + return [ + self.check_quota(plan, QuotaType.API_CALLS, api_calls_used), + self.check_quota(plan, QuotaType.QUERIES, queries_used), + self.check_quota(plan, QuotaType.CONTENT_GENERATION, content_generation_used), + self.check_quota(plan, QuotaType.STORAGE, storage_used), + ] + + def get_warnings(self, usage_list: list[QuotaUsage]) -> list[QuotaWarning]: + """从使用情况列表中生成预警 + + Args: + usage_list: 额度使用情况列表 + + Returns: + 预警信息列表(仅包含需要预警的项) + """ + warnings = [] + for usage in usage_list: + if usage.status in ["warning", "critical", "exhausted"]: + warning = self.generate_warning( + quota_type=QuotaType(usage.quota_type), + status=usage.status, + usage_percentage=usage.usage_percentage, + ) + warnings.append(warning) + return warnings diff --git a/backend/app/services/seo_diagnosis.py b/backend/app/services/seo_diagnosis.py new file mode 100644 index 0000000..33be79c --- /dev/null +++ b/backend/app/services/seo_diagnosis.py @@ -0,0 +1,1489 @@ +""" +SEO诊断服务 - 5维度检测系统 + +诊断维度(总分100): +- 技术SEO (Technical SEO): 25分 - 索引、爬取、Core Web Vitals等 +- 页面SEO (On-Page SEO): 20分 - Title/Meta、H标签、关键词等 +- 内容质量 (Content Quality): 20分 - 可读性、E-E-A-T、新鲜度等 +- 外链分析 (Backlink Analysis): 15分 - 反向链接质量、毒性信号等 +- 用户体验 (User Experience): 20分 - 移动适配、页面速度、转化路径等 +""" +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +logger = logging.getLogger(__name__) + + +# ============================================================ +# 枚举定义 +# ============================================================ + +class DiagnosisStatus(str, Enum): + """诊断状态""" + PASS = "pass" + WARNING = "warning" + FAIL = "fail" + + +class DimensionName(str, Enum): + """诊断维度名称""" + TECHNICAL_SEO = "技术SEO" + ON_PAGE_SEO = "页面SEO" + CONTENT_QUALITY = "内容质量" + BACKLINK_ANALYSIS = "外链分析" + USER_EXPERIENCE = "用户体验" + + +# ============================================================ +# 数据结构 +# ============================================================ + +@dataclass +class DiagnosisItem: + """单个诊断项""" + name: str # 诊断项名称 + status: DiagnosisStatus # 诊断状态 + description: str # 诊断说明 + suggestion: str # 优化建议 + score: float = 0.0 # 该项得分 (0-1) + details: dict[str, Any] = field(default_factory=dict) # 额外详情 + + +@dataclass +class SEODimensionScore: + """单个维度的诊断结果""" + name: str # 维度名称 + score: float # 维度得分 (0-max_score) + max_score: float # 维度满分 + items: list[DiagnosisItem] # 诊断项列表 + status: DiagnosisStatus # 整体状态 + + @property + def percentage(self) -> float: + """得分率 (0-100)""" + if self.max_score <= 0: + return 0.0 + return round((self.score / self.max_score) * 100, 2) + + def __post_init__(self): + """计算整体状态""" + if not self.items: + self.status = DiagnosisStatus.WARNING + return + + fail_count = sum(1 for item in self.items if item.status == DiagnosisStatus.FAIL) + warning_count = sum(1 for item in self.items if item.status == DiagnosisStatus.WARNING) + total = len(self.items) + + if fail_count > total * 0.3: + self.status = DiagnosisStatus.FAIL + elif warning_count > total * 0.3 or fail_count > 0: + self.status = DiagnosisStatus.WARNING + else: + self.status = DiagnosisStatus.PASS + + +@dataclass +class SEORecommendation: + """优化建议""" + priority: str # high/medium/low + dimension: str # 所属维度 + item_name: str # 诊断项名称 + description: str # 建议描述 + impact: str # 预期影响 + effort: str # 实施难度 easy/medium/hard + + +@dataclass +class SEODiagnosisResult: + """SEO诊断结果""" + overall_score: float # 综合评分 0-100 + dimensions: list[SEODimensionScore] # 各维度得分 + recommendations: list[SEORecommendation] # 优化建议 + health_level: str = "danger" # 健康等级 + + def __post_init__(self): + """计算健康等级""" + self.overall_score = round(min(100.0, max(0.0, self.overall_score)), 2) + + if self.overall_score >= 80: + self.health_level = "excellent" + elif self.overall_score >= 60: + self.health_level = "good" + elif self.overall_score >= 40: + self.health_level = "pass" + else: + self.health_level = "danger" + + def to_dict(self) -> dict: + """转换为字典格式""" + return { + "overall_score": self.overall_score, + "health_level": self.health_level, + "health_level_label": self._get_health_label(), + "dimensions": [ + { + "name": dim.name, + "score": round(dim.score, 2), + "max_score": dim.max_score, + "percentage": dim.percentage, + "status": dim.status.value, + "items": [ + { + "name": item.name, + "status": item.status.value, + "description": item.description, + "suggestion": item.suggestion, + "score": round(item.score, 2), + "details": item.details, + } + for item in dim.items + ], + } + for dim in self.dimensions + ], + "recommendations": [ + { + "priority": rec.priority, + "dimension": rec.dimension, + "item_name": rec.item_name, + "description": rec.description, + "impact": rec.impact, + "effort": rec.effort, + } + for rec in self.recommendations + ], + } + + def _get_health_label(self) -> str: + """获取健康等级中文标签""" + labels = { + "excellent": "优秀", + "good": "良好", + "pass": "及格", + "danger": "危险", + } + return labels.get(self.health_level, "未知") + + +# ============================================================ +# 诊断数据输入 +# ============================================================ + +@dataclass +class TechnicalSEOData: + """技术SEO检测数据""" + is_indexed: bool = True # 是否被索引 + crawl_errors: int = 0 # 爬取错误数 + redirect_chains: int = 0 # 重定向链数 + lcp_seconds: float = 2.0 # Largest Contentful Paint (秒) + fid_ms: float = 50.0 # First Input Delay (毫秒) + cls_score: float = 0.05 # Cumulative Layout Shift + has_robots_txt: bool = True # 是否有robots.txt + robots_txt_blocks_important: bool = False # robots.txt是否阻止重要页面 + has_sitemap: bool = True # 是否有sitemap + sitemap_valid: bool = True # sitemap是否有效 + url_structure_normalized: bool = True # URL结构是否规范 + + +@dataclass +class OnPageSEOData: + """页面SEO检测数据""" + has_title: bool = True # 是否有Title标签 + title_length: int = 50 # Title长度 + title_keyword_stuffing: bool = False # Title是否关键词堆砌 + has_meta_description: bool = True # 是否有Meta Description + meta_description_length: int = 140 # Meta Description长度 + h1_count: int = 1 # H1标签数量 + h_structure_valid: bool = True # H标签结构是否合理 + keyword_density: float = 2.0 # 关键词密度 (%) + internal_links: int = 10 # 内链数量 + broken_internal_links: int = 0 # 死链数量 + images_without_alt: int = 0 # 缺少Alt文本的图片数 + total_images: int = 5 # 总图片数 + + +@dataclass +class ContentQualityData: + """内容质量检测数据""" + readability_score: float = 70.0 # 可读性评分 (0-100) + word_count: int = 1500 # 字数 + topic_coverage: float = 0.8 # 主题覆盖率 (0-1) + has_author_info: bool = True # 是否有作者信息 + has_publication_date: bool = True # 是否有发布日期 + last_updated_days: int = 30 # 最后更新天数 + has_citations: bool = True # 是否有引用/参考 + citation_authority: float = 0.8 # 引用权威性 (0-1) + duplicate_content_ratio: float = 0.05 # 重复内容比例 (0-1) + has_expert_review: bool = False # 是否有专家审核 + + +@dataclass +class BacklinkData: + """外链检测数据""" + total_backlinks: int = 100 # 总反向链接数 + referring_domains: int = 20 # 引用域名数 + high_authority_links: int = 10 # 高权威链接数 + toxic_links: int = 2 # 毒性链接数 + nofollow_ratio: float = 0.3 # Nofollow比例 + anchor_text_diversity: float = 0.8 # 锚文本多样性 (0-1) + exact_match_anchor_ratio: float = 0.2 # 精确匹配锚文本比例 + + +@dataclass +class UserExperienceData: + """用户体验检测数据""" + is_mobile_friendly: bool = True # 是否移动友好 + mobile_viewport_set: bool = True # 是否设置viewport + page_load_time: float = 2.5 # 页面加载时间 (秒) + has_https: bool = True # 是否使用HTTPS + has_breadcrumbs: bool = True # 是否有面包屑导航 + conversion_path_clear: bool = True # 转化路径是否清晰 + has_cta: bool = True # 是否有明确的CTA + form_usability: float = 0.9 # 表单可用性 (0-1) + has_search: bool = True # 是否有站内搜索 + + +# ============================================================ +# 维度诊断函数 +# ============================================================ + +def diagnose_technical_seo(data: TechnicalSEOData) -> SEODimensionScore: + """ + 技术SEO诊断 (满分25分) + + 评分项: + - 索引状态 (4分) + - 爬取错误 (4分) + - Core Web Vitals (6分) + - URL结构 (3分) + - robots.txt (4分) + - sitemap (4分) + """ + max_score = 25.0 + items: list[DiagnosisItem] = [] + total_score = 0.0 + + # 1. 索引状态检查 (4分) + if data.is_indexed: + items.append(DiagnosisItem( + name="索引状态", + status=DiagnosisStatus.PASS, + description="网站已被搜索引擎正确索引", + suggestion="保持当前索引状态", + score=1.0, + )) + total_score += 4.0 + else: + items.append(DiagnosisItem( + name="索引状态", + status=DiagnosisStatus.FAIL, + description="网站未被搜索引擎索引", + suggestion="检查Search Console,提交sitemap,确保没有被noindex", + score=0.0, + )) + + # 2. 爬取错误检测 (4分) + if data.crawl_errors == 0: + items.append(DiagnosisItem( + name="爬取错误", + status=DiagnosisStatus.PASS, + description="未发现爬取错误", + suggestion="定期检查Search Console的爬取错误报告", + score=1.0, + )) + total_score += 4.0 + elif data.crawl_errors <= 5: + items.append(DiagnosisItem( + name="爬取错误", + status=DiagnosisStatus.WARNING, + description=f"发现{data.crawl_errors}个爬取错误", + suggestion="修复404页面,检查5xx服务器错误,优化重定向链", + score=0.5, + details={"error_count": data.crawl_errors}, + )) + total_score += 2.0 + else: + items.append(DiagnosisItem( + name="爬取错误", + status=DiagnosisStatus.FAIL, + description=f"发现{data.crawl_errors}个爬取错误,数量过多", + suggestion="立即修复所有爬取错误,特别是5xx服务器错误", + score=0.0, + details={"error_count": data.crawl_errors}, + )) + + # 3. Core Web Vitals评估 (6分) + cwv_score = 0.0 + cwv_items = [] + + # LCP评估 (2分) + if data.lcp_seconds <= 2.5: + cwv_score += 2.0 + cwv_items.append(DiagnosisItem( + name="LCP", + status=DiagnosisStatus.PASS, + description=f"LCP为{data.lcp_seconds}s,符合<2.5s标准", + suggestion="保持当前性能水平", + score=1.0, + details={"value": data.lcp_seconds, "threshold": 2.5}, + )) + elif data.lcp_seconds <= 4.0: + cwv_score += 1.0 + cwv_items.append(DiagnosisItem( + name="LCP", + status=DiagnosisStatus.WARNING, + description=f"LCP为{data.lcp_seconds}s,超过2.5s标准", + suggestion="优化图片加载,使用CDN,减少服务器响应时间", + score=0.5, + details={"value": data.lcp_seconds, "threshold": 2.5}, + )) + else: + cwv_items.append(DiagnosisItem( + name="LCP", + status=DiagnosisStatus.FAIL, + description=f"LCP为{data.lcp_seconds}s,严重超过标准", + suggestion="立即优化页面加载性能", + score=0.0, + details={"value": data.lcp_seconds, "threshold": 2.5}, + )) + + # FID评估 (2分) + if data.fid_ms <= 100: + cwv_score += 2.0 + cwv_items.append(DiagnosisItem( + name="FID", + status=DiagnosisStatus.PASS, + description=f"FID为{data.fid_ms}ms,符合<100ms标准", + suggestion="保持当前交互性能", + score=1.0, + details={"value": data.fid_ms, "threshold": 100}, + )) + elif data.fid_ms <= 300: + cwv_score += 1.0 + cwv_items.append(DiagnosisItem( + name="FID", + status=DiagnosisStatus.WARNING, + description=f"FID为{data.fid_ms}ms,超过100ms标准", + suggestion="减少JavaScript执行时间,优化主线程工作", + score=0.5, + details={"value": data.fid_ms, "threshold": 100}, + )) + else: + cwv_items.append(DiagnosisItem( + name="FID", + status=DiagnosisStatus.FAIL, + description=f"FID为{data.fid_ms}ms,严重超过标准", + suggestion="立即优化JavaScript,减少主线程阻塞", + score=0.0, + details={"value": data.fid_ms, "threshold": 100}, + )) + + # CLS评估 (2分) + if data.cls_score <= 0.1: + cwv_score += 2.0 + cwv_items.append(DiagnosisItem( + name="CLS", + status=DiagnosisStatus.PASS, + description=f"CLS为{data.cls_score},符合<0.1标准", + suggestion="保持当前视觉稳定性", + score=1.0, + details={"value": data.cls_score, "threshold": 0.1}, + )) + elif data.cls_score <= 0.25: + cwv_score += 1.0 + cwv_items.append(DiagnosisItem( + name="CLS", + status=DiagnosisStatus.WARNING, + description=f"CLS为{data.cls_score},超过0.1标准", + suggestion="为图片和广告预留空间,避免动态插入内容", + score=0.5, + details={"value": data.cls_score, "threshold": 0.1}, + )) + else: + cwv_items.append(DiagnosisItem( + name="CLS", + status=DiagnosisStatus.FAIL, + description=f"CLS为{data.cls_score},严重超过标准", + suggestion="立即修复布局偏移问题", + score=0.0, + details={"value": data.cls_score, "threshold": 0.1}, + )) + + items.extend(cwv_items) + total_score += cwv_score + + # 4. URL结构规范化 (3分) + if data.url_structure_normalized: + items.append(DiagnosisItem( + name="URL结构", + status=DiagnosisStatus.PASS, + description="URL结构规范,无重复URL问题", + suggestion="保持当前URL结构", + score=1.0, + )) + total_score += 3.0 + else: + items.append(DiagnosisItem( + name="URL结构", + status=DiagnosisStatus.WARNING, + description="URL结构存在问题,可能有重复URL", + suggestion="使用canonical标签,统一URL格式", + score=0.5, + )) + total_score += 1.5 + + # 5. robots.txt配置检查 (4分) + if data.has_robots_txt and not data.robots_txt_blocks_important: + items.append(DiagnosisItem( + name="robots.txt", + status=DiagnosisStatus.PASS, + description="robots.txt配置正确,未阻止重要页面", + suggestion="定期检查robots.txt配置", + score=1.0, + )) + total_score += 4.0 + elif data.has_robots_txt: + items.append(DiagnosisItem( + name="robots.txt", + status=DiagnosisStatus.FAIL, + description="robots.txt阻止了重要页面", + suggestion="检查并修改robots.txt,确保重要页面可被爬取", + score=0.0, + )) + else: + items.append(DiagnosisItem( + name="robots.txt", + status=DiagnosisStatus.WARNING, + description="未找到robots.txt文件", + suggestion="创建robots.txt文件,明确指定允许和禁止爬取的路径", + score=0.5, + )) + total_score += 2.0 + + # 6. sitemap完整性验证 (4分) + if data.has_sitemap and data.sitemap_valid: + items.append(DiagnosisItem( + name="sitemap", + status=DiagnosisStatus.PASS, + description="sitemap存在且有效", + suggestion="定期更新sitemap,确保包含所有重要页面", + score=1.0, + )) + total_score += 4.0 + elif data.has_sitemap: + items.append(DiagnosisItem( + name="sitemap", + status=DiagnosisStatus.WARNING, + description="sitemap存在但可能无效", + suggestion="验证sitemap格式,确保所有URL可访问", + score=0.5, + )) + total_score += 2.0 + else: + items.append(DiagnosisItem( + name="sitemap", + status=DiagnosisStatus.FAIL, + description="未找到sitemap", + suggestion="创建并提交sitemap到Search Console", + score=0.0, + )) + + return SEODimensionScore( + name=DimensionName.TECHNICAL_SEO, + score=total_score, + max_score=max_score, + items=items, + status=DiagnosisStatus.PASS, # 会在__post_init__中重新计算 + ) + + +def diagnose_on_page_seo(data: OnPageSEOData) -> SEODimensionScore: + """ + 页面SEO诊断 (满分20分) + + 评分项: + - Title/Meta标签 (5分) + - H标签结构 (4分) + - 关键词密度 (4分) + - 内链结构 (4分) + - 图片Alt文本 (3分) + """ + max_score = 20.0 + items: list[DiagnosisItem] = [] + total_score = 0.0 + + # 1. Title/Meta标签完整性 (5分) + title_score = 0.0 + meta_score = 0.0 + + # Title检查 (2.5分) + if data.has_title: + if 30 <= data.title_length <= 60: + title_score += 2.5 + if not data.title_keyword_stuffing: + items.append(DiagnosisItem( + name="Title标签", + status=DiagnosisStatus.PASS, + description=f"Title长度{data.title_length}字符,格式规范", + suggestion="保持当前Title优化", + score=1.0, + details={"length": data.title_length}, + )) + else: + title_score -= 1.0 + items.append(DiagnosisItem( + name="Title标签", + status=DiagnosisStatus.WARNING, + description="Title存在关键词堆砌问题", + suggestion="简化Title,自然使用关键词", + score=0.5, + )) + else: + title_score += 1.0 + items.append(DiagnosisItem( + name="Title标签", + status=DiagnosisStatus.WARNING, + description=f"Title长度{data.title_length}字符,建议30-60字符", + suggestion="调整Title长度到推荐范围", + score=0.5, + details={"length": data.title_length}, + )) + else: + items.append(DiagnosisItem( + name="Title标签", + status=DiagnosisStatus.FAIL, + description="页面缺少Title标签", + suggestion="为每个页面添加唯一且描述性的Title", + score=0.0, + )) + + # Meta Description检查 (2.5分) + if data.has_meta_description: + if 120 <= data.meta_description_length <= 160: + meta_score += 2.5 + items.append(DiagnosisItem( + name="Meta Description", + status=DiagnosisStatus.PASS, + description=f"Meta Description长度{data.meta_description_length}字符,格式规范", + suggestion="保持当前Meta Description优化", + score=1.0, + details={"length": data.meta_description_length}, + )) + else: + meta_score += 1.0 + items.append(DiagnosisItem( + name="Meta Description", + status=DiagnosisStatus.WARNING, + description=f"Meta Description长度{data.meta_description_length}字符,建议120-160字符", + suggestion="调整Meta Description长度", + score=0.5, + details={"length": data.meta_description_length}, + )) + else: + items.append(DiagnosisItem( + name="Meta Description", + status=DiagnosisStatus.FAIL, + description="页面缺少Meta Description", + suggestion="为每个页面添加描述性的Meta Description", + score=0.0, + )) + + total_score += title_score + meta_score + + # 2. H标签结构层级 (4分) + if data.h1_count == 1 and data.h_structure_valid: + items.append(DiagnosisItem( + name="H标签结构", + status=DiagnosisStatus.PASS, + description="H标签结构清晰,有且仅有1个H1", + suggestion="保持当前H标签结构", + score=1.0, + details={"h1_count": data.h1_count}, + )) + total_score += 4.0 + elif data.h1_count == 1: + items.append(DiagnosisItem( + name="H标签结构", + status=DiagnosisStatus.WARNING, + description="有1个H1,但H标签层级可能不规范", + suggestion="确保H标签层级正确,不跳过级别", + score=0.5, + details={"h1_count": data.h1_count}, + )) + total_score += 2.0 + elif data.h1_count == 0: + items.append(DiagnosisItem( + name="H标签结构", + status=DiagnosisStatus.FAIL, + description="页面缺少H1标签", + suggestion="添加唯一的H1标签,包含主要关键词", + score=0.0, + details={"h1_count": data.h1_count}, + )) + else: + items.append(DiagnosisItem( + name="H标签结构", + status=DiagnosisStatus.WARNING, + description=f"页面有{data.h1_count}个H1标签,建议只有1个", + suggestion="确保每个页面只有1个H1标签", + score=0.5, + details={"h1_count": data.h1_count}, + )) + total_score += 2.0 + + # 3. 关键词密度合理性 (4分) + if 1.0 <= data.keyword_density <= 3.0: + items.append(DiagnosisItem( + name="关键词密度", + status=DiagnosisStatus.PASS, + description=f"关键词密度{data.keyword_density}%,在合理范围内", + suggestion="保持当前关键词使用频率", + score=1.0, + details={"density": data.keyword_density}, + )) + total_score += 4.0 + elif 0.5 <= data.keyword_density < 1.0 or 3.0 < data.keyword_density <= 5.0: + items.append(DiagnosisItem( + name="关键词密度", + status=DiagnosisStatus.WARNING, + description=f"关键词密度{data.keyword_density}%,建议1-3%", + suggestion="调整关键词使用频率到推荐范围", + score=0.5, + details={"density": data.keyword_density}, + )) + total_score += 2.0 + else: + status = DiagnosisStatus.FAIL if data.keyword_density > 5.0 else DiagnosisStatus.WARNING + items.append(DiagnosisItem( + name="关键词密度", + status=status, + description=f"关键词密度{data.keyword_density}%,不合理", + suggestion="优化关键词使用,避免堆砌或过少", + score=0.0, + details={"density": data.keyword_density}, + )) + + # 4. 内链结构 (4分) + if data.internal_links > 0: + if data.broken_internal_links == 0: + items.append(DiagnosisItem( + name="内链结构", + status=DiagnosisStatus.PASS, + description=f"内链结构良好,共{data.internal_links}个内链,无死链", + suggestion="保持内链更新,定期检查死链", + score=1.0, + details={"total": data.internal_links, "broken": 0}, + )) + total_score += 4.0 + elif data.broken_internal_links <= 3: + items.append(DiagnosisItem( + name="内链结构", + status=DiagnosisStatus.WARNING, + description=f"发现{data.broken_internal_links}个死链", + suggestion="修复所有死链,更新或移除无效链接", + score=0.5, + details={"total": data.internal_links, "broken": data.broken_internal_links}, + )) + total_score += 2.0 + else: + items.append(DiagnosisItem( + name="内链结构", + status=DiagnosisStatus.FAIL, + description=f"发现{data.broken_internal_links}个死链,数量过多", + suggestion="立即修复所有死链", + score=0.0, + details={"total": data.internal_links, "broken": data.broken_internal_links}, + )) + else: + items.append(DiagnosisItem( + name="内链结构", + status=DiagnosisStatus.FAIL, + description="页面没有内链", + suggestion="添加相关页面的内链,提升网站结构", + score=0.0, + )) + + # 5. 图片Alt文本 (3分) + if data.total_images == 0: + items.append(DiagnosisItem( + name="图片Alt文本", + status=DiagnosisStatus.PASS, + description="页面无图片", + suggestion="考虑添加相关图片并设置Alt文本", + score=1.0, + )) + total_score += 3.0 + elif data.images_without_alt == 0: + items.append(DiagnosisItem( + name="图片Alt文本", + status=DiagnosisStatus.PASS, + description=f"所有{data.total_images}张图片都有Alt文本", + suggestion="保持为所有图片添加描述性Alt文本", + score=1.0, + details={"total": data.total_images, "without_alt": 0}, + )) + total_score += 3.0 + elif data.images_without_alt <= data.total_images * 0.3: + items.append(DiagnosisItem( + name="图片Alt文本", + status=DiagnosisStatus.WARNING, + description=f"{data.images_without_alt}/{data.total_images}张图片缺少Alt文本", + suggestion="为所有图片添加描述性Alt文本", + score=0.5, + details={"total": data.total_images, "without_alt": data.images_without_alt}, + )) + total_score += 1.5 + else: + items.append(DiagnosisItem( + name="图片Alt文本", + status=DiagnosisStatus.FAIL, + description=f"{data.images_without_alt}/{data.total_images}张图片缺少Alt文本", + suggestion="立即为所有图片添加Alt文本", + score=0.0, + details={"total": data.total_images, "without_alt": data.images_without_alt}, + )) + + return SEODimensionScore( + name=DimensionName.ON_PAGE_SEO, + score=total_score, + max_score=max_score, + items=items, + status=DiagnosisStatus.PASS, + ) + + +def diagnose_content_quality(data: ContentQualityData) -> SEODimensionScore: + """ + 内容质量诊断 (满分20分) + + 评分项: + - 可读性评分 (4分) + - 信息深度 (4分) + - E-E-A-T信号 (5分) + - 内容新鲜度 (4分) + - 重复内容检测 (3分) + """ + max_score = 20.0 + items: list[DiagnosisItem] = [] + total_score = 0.0 + + # 1. 可读性评分 (4分) + if data.readability_score >= 70: + items.append(DiagnosisItem( + name="可读性", + status=DiagnosisStatus.PASS, + description=f"可读性评分{data.readability_score},内容易于理解", + suggestion="保持当前内容质量", + score=1.0, + details={"score": data.readability_score}, + )) + total_score += 4.0 + elif data.readability_score >= 50: + items.append(DiagnosisItem( + name="可读性", + status=DiagnosisStatus.WARNING, + description=f"可读性评分{data.readability_score},内容较难理解", + suggestion="简化语言,使用短句,增加段落分隔", + score=0.5, + details={"score": data.readability_score}, + )) + total_score += 2.0 + else: + items.append(DiagnosisItem( + name="可读性", + status=DiagnosisStatus.FAIL, + description=f"可读性评分{data.readability_score},内容难以理解", + suggestion="大幅简化内容,使用更通俗的语言", + score=0.0, + details={"score": data.readability_score}, + )) + + # 2. 信息深度评估 (4分) + if data.word_count >= 1500 and data.topic_coverage >= 0.8: + items.append(DiagnosisItem( + name="信息深度", + status=DiagnosisStatus.PASS, + description=f"内容深度良好,{data.word_count}字,主题覆盖率{data.topic_coverage*100:.0f}%", + suggestion="保持当前内容深度", + score=1.0, + details={"word_count": data.word_count, "coverage": data.topic_coverage}, + )) + total_score += 4.0 + elif data.word_count >= 800 and data.topic_coverage >= 0.6: + items.append(DiagnosisItem( + name="信息深度", + status=DiagnosisStatus.WARNING, + description=f"内容深度一般,{data.word_count}字,主题覆盖率{data.topic_coverage*100:.0f}%", + suggestion="扩展内容深度,覆盖更多相关子话题", + score=0.5, + details={"word_count": data.word_count, "coverage": data.topic_coverage}, + )) + total_score += 2.0 + else: + items.append(DiagnosisItem( + name="信息深度", + status=DiagnosisStatus.FAIL, + description=f"内容深度不足,{data.word_count}字,主题覆盖率{data.topic_coverage*100:.0f}%", + suggestion="大幅扩展内容,全面覆盖主题", + score=0.0, + details={"word_count": data.word_count, "coverage": data.topic_coverage}, + )) + + # 3. E-E-A-T信号检测 (5分) + eeat_score = 0.0 + + # 作者信息 (1.5分) + if data.has_author_info: + eeat_score += 1.5 + items.append(DiagnosisItem( + name="作者资质", + status=DiagnosisStatus.PASS, + description="内容包含作者信息", + suggestion="保持展示作者资质", + score=1.0, + )) + else: + items.append(DiagnosisItem( + name="作者资质", + status=DiagnosisStatus.WARNING, + description="内容缺少作者信息", + suggestion="添加作者信息和专业背景", + score=0.0, + )) + + # 专业认证/专家审核 (1.5分) + if data.has_expert_review: + eeat_score += 1.5 + items.append(DiagnosisItem( + name="专家审核", + status=DiagnosisStatus.PASS, + description="内容经过专家审核", + suggestion="保持专家审核流程", + score=1.0, + )) + else: + items.append(DiagnosisItem( + name="专家审核", + status=DiagnosisStatus.WARNING, + description="内容未经专家审核", + suggestion="考虑邀请行业专家审核重要内容", + score=0.0, + )) + + # 数据来源权威性 (2分) + if data.has_citations and data.citation_authority >= 0.7: + eeat_score += 2.0 + items.append(DiagnosisItem( + name="数据来源", + status=DiagnosisStatus.PASS, + description=f"引用权威数据源,权威性评分{data.citation_authority:.2f}", + suggestion="保持引用高质量数据源", + score=1.0, + details={"authority": data.citation_authority}, + )) + elif data.has_citations: + eeat_score += 1.0 + items.append(DiagnosisItem( + name="数据来源", + status=DiagnosisStatus.WARNING, + description=f"有引用但数据源权威性一般,评分{data.citation_authority:.2f}", + suggestion="引用更权威的数据源", + score=0.5, + details={"authority": data.citation_authority}, + )) + else: + items.append(DiagnosisItem( + name="数据来源", + status=DiagnosisStatus.FAIL, + description="内容未引用任何数据源", + suggestion="引用权威数据支持内容观点", + score=0.0, + )) + + total_score += eeat_score + + # 4. 内容新鲜度 (4分) + if data.has_publication_date and data.last_updated_days <= 30: + items.append(DiagnosisItem( + name="内容新鲜度", + status=DiagnosisStatus.PASS, + description=f"内容更新于{data.last_updated_days}天前", + suggestion="保持定期更新内容", + score=1.0, + details={"last_updated_days": data.last_updated_days}, + )) + total_score += 4.0 + elif data.has_publication_date and data.last_updated_days <= 90: + items.append(DiagnosisItem( + name="内容新鲜度", + status=DiagnosisStatus.WARNING, + description=f"内容更新于{data.last_updated_days}天前,建议30天内更新", + suggestion="定期更新内容,保持信息时效性", + score=0.5, + details={"last_updated_days": data.last_updated_days}, + )) + total_score += 2.0 + else: + items.append(DiagnosisItem( + name="内容新鲜度", + status=DiagnosisStatus.FAIL, + description="内容长时间未更新或缺少发布日期", + suggestion="更新内容并显示发布/更新日期", + score=0.0, + details={"last_updated_days": data.last_updated_days}, + )) + + # 5. 重复内容检测 (3分) + if data.duplicate_content_ratio <= 0.1: + items.append(DiagnosisItem( + name="重复内容", + status=DiagnosisStatus.PASS, + description=f"重复内容比例{data.duplicate_content_ratio*100:.0f}%,在可接受范围内", + suggestion="保持内容原创性", + score=1.0, + details={"duplicate_ratio": data.duplicate_content_ratio}, + )) + total_score += 3.0 + elif data.duplicate_content_ratio <= 0.3: + items.append(DiagnosisItem( + name="重复内容", + status=DiagnosisStatus.WARNING, + description=f"重复内容比例{data.duplicate_content_ratio*100:.0f}%", + suggestion="减少重复内容,使用canonical标签", + score=0.5, + details={"duplicate_ratio": data.duplicate_content_ratio}, + )) + total_score += 1.5 + else: + items.append(DiagnosisItem( + name="重复内容", + status=DiagnosisStatus.FAIL, + description=f"重复内容比例{data.duplicate_content_ratio*100:.0f}%,过高", + suggestion="重写重复内容,确保每页独特价值", + score=0.0, + details={"duplicate_ratio": data.duplicate_content_ratio}, + )) + + return SEODimensionScore( + name=DimensionName.CONTENT_QUALITY, + score=total_score, + max_score=max_score, + items=items, + status=DiagnosisStatus.PASS, + ) + + +def diagnose_backlinks(data: BacklinkData) -> SEODimensionScore: + """ + 外链分析 (满分15分) + + 评分项: + - 反向链接质量 (6分) + - 毒性信号检测 (5分) + - 锚文本分布 (4分) + """ + max_score = 15.0 + items: list[DiagnosisItem] = [] + total_score = 0.0 + + # 1. 反向链接质量 (6分) + quality_score = 0.0 + + # 引用域名数 (2分) + if data.referring_domains >= 20: + quality_score += 2.0 + items.append(DiagnosisItem( + name="引用域名", + status=DiagnosisStatus.PASS, + description=f"有{data.referring_domains}个引用域名", + suggestion="继续增加高质量外链", + score=1.0, + details={"referring_domains": data.referring_domains}, + )) + elif data.referring_domains >= 10: + quality_score += 1.0 + items.append(DiagnosisItem( + name="引用域名", + status=DiagnosisStatus.WARNING, + description=f"有{data.referring_domains}个引用域名,建议增加", + suggestion="通过内容营销获取更多外链", + score=0.5, + details={"referring_domains": data.referring_domains}, + )) + else: + items.append(DiagnosisItem( + name="引用域名", + status=DiagnosisStatus.FAIL, + description=f"仅有{data.referring_domains}个引用域名", + suggestion="积极开展外链建设", + score=0.0, + details={"referring_domains": data.referring_domains}, + )) + + # 高权威链接 (2分) + if data.high_authority_links >= 5: + quality_score += 2.0 + items.append(DiagnosisItem( + name="高权威链接", + status=DiagnosisStatus.PASS, + description=f"有{data.high_authority_links}个高权威外链", + suggestion="保持获取高质量外链", + score=1.0, + details={"high_authority_links": data.high_authority_links}, + )) + elif data.high_authority_links >= 2: + quality_score += 1.0 + items.append(DiagnosisItem( + name="高权威链接", + status=DiagnosisStatus.WARNING, + description=f"有{data.high_authority_links}个高权威外链", + suggestion="争取更多权威网站的外链", + score=0.5, + details={"high_authority_links": data.high_authority_links}, + )) + else: + items.append(DiagnosisItem( + name="高权威链接", + status=DiagnosisStatus.FAIL, + description=f"仅有{data.high_authority_links}个高权威外链", + suggestion="重点获取权威网站外链", + score=0.0, + details={"high_authority_links": data.high_authority_links}, + )) + + # Nofollow比例 (2分) + if 0.2 <= data.nofollow_ratio <= 0.6: + quality_score += 2.0 + items.append(DiagnosisItem( + name="Nofollow比例", + status=DiagnosisStatus.PASS, + description=f"Nofollow比例{data.nofollow_ratio*100:.0f}%,自然合理", + suggestion="保持自然的外链结构", + score=1.0, + details={"nofollow_ratio": data.nofollow_ratio}, + )) + else: + quality_score += 1.0 + items.append(DiagnosisItem( + name="Nofollow比例", + status=DiagnosisStatus.WARNING, + description=f"Nofollow比例{data.nofollow_ratio*100:.0f}%,可能不自然", + suggestion="确保外链结构自然多样", + score=0.5, + details={"nofollow_ratio": data.nofollow_ratio}, + )) + + total_score += quality_score + + # 2. 毒性信号检测 (5分) + toxic_ratio = data.toxic_links / data.total_backlinks if data.total_backlinks > 0 else 0 + + if data.toxic_links == 0: + items.append(DiagnosisItem( + name="毒性链接", + status=DiagnosisStatus.PASS, + description="未发现毒性外链", + suggestion="定期监控外链质量", + score=1.0, + details={"toxic_count": 0}, + )) + total_score += 5.0 + elif toxic_ratio <= 0.05: + items.append(DiagnosisItem( + name="毒性链接", + status=DiagnosisStatus.WARNING, + description=f"发现{data.toxic_links}个毒性外链", + suggestion="使用Disavow工具拒绝毒性外链", + score=0.5, + details={"toxic_count": data.toxic_links, "ratio": toxic_ratio}, + )) + total_score += 2.5 + else: + items.append(DiagnosisItem( + name="毒性链接", + status=DiagnosisStatus.FAIL, + description=f"发现{data.toxic_links}个毒性外链,比例过高", + suggestion="立即使用Disavow工具拒绝所有毒性外链", + score=0.0, + details={"toxic_count": data.toxic_links, "ratio": toxic_ratio}, + )) + + # 3. 锚文本分布 (4分) + if data.anchor_text_diversity >= 0.7 and data.exact_match_anchor_ratio <= 0.3: + items.append(DiagnosisItem( + name="锚文本分布", + status=DiagnosisStatus.PASS, + description=f"锚文本多样性{data.anchor_text_diversity:.2f},精确匹配比例{data.exact_match_anchor_ratio*100:.0f}%", + suggestion="保持自然的锚文本结构", + score=1.0, + details={ + "diversity": data.anchor_text_diversity, + "exact_match_ratio": data.exact_match_anchor_ratio, + }, + )) + total_score += 4.0 + elif data.anchor_text_diversity >= 0.5: + items.append(DiagnosisItem( + name="锚文本分布", + status=DiagnosisStatus.WARNING, + description=f"锚文本多样性{data.anchor_text_diversity:.2f},建议增加多样性", + suggestion="使用更多样化的锚文本", + score=0.5, + details={ + "diversity": data.anchor_text_diversity, + "exact_match_ratio": data.exact_match_anchor_ratio, + }, + )) + total_score += 2.0 + else: + items.append(DiagnosisItem( + name="锚文本分布", + status=DiagnosisStatus.FAIL, + description=f"锚文本多样性{data.anchor_text_diversity:.2f},过于单一", + suggestion="大幅增加锚文本多样性", + score=0.0, + details={ + "diversity": data.anchor_text_diversity, + "exact_match_ratio": data.exact_match_anchor_ratio, + }, + )) + + return SEODimensionScore( + name=DimensionName.BACKLINK_ANALYSIS, + score=total_score, + max_score=max_score, + items=items, + status=DiagnosisStatus.PASS, + ) + + +def diagnose_user_experience(data: UserExperienceData) -> SEODimensionScore: + """ + 用户体验诊断 (满分20分) + + 评分项: + - 移动适配检查 (6分) + - 页面速度评估 (5分) + - 转化路径分析 (5分) + - 基础体验 (4分) + """ + max_score = 20.0 + items: list[DiagnosisItem] = [] + total_score = 0.0 + + # 1. 移动适配检查 (6分) + if data.is_mobile_friendly and data.mobile_viewport_set: + items.append(DiagnosisItem( + name="移动适配", + status=DiagnosisStatus.PASS, + description="页面移动适配良好", + suggestion="保持移动端优化", + score=1.0, + )) + total_score += 6.0 + elif data.is_mobile_friendly: + items.append(DiagnosisItem( + name="移动适配", + status=DiagnosisStatus.WARNING, + description="页面基本适配移动端,但缺少viewport设置", + suggestion="添加viewport meta标签", + score=0.5, + )) + total_score += 3.0 + else: + items.append(DiagnosisItem( + name="移动适配", + status=DiagnosisStatus.FAIL, + description="页面未适配移动端", + suggestion="立即实现响应式设计或移动版本", + score=0.0, + )) + + # 2. 页面速度评估 (5分) + if data.page_load_time <= 2.0: + items.append(DiagnosisItem( + name="页面速度", + status=DiagnosisStatus.PASS, + description=f"页面加载时间{data.page_load_time}s,性能优秀", + suggestion="保持当前性能水平", + score=1.0, + details={"load_time": data.page_load_time}, + )) + total_score += 5.0 + elif data.page_load_time <= 3.0: + items.append(DiagnosisItem( + name="页面速度", + status=DiagnosisStatus.WARNING, + description=f"页面加载时间{data.page_load_time}s,建议优化到2s内", + suggestion="优化图片、启用缓存、使用CDN", + score=0.5, + details={"load_time": data.page_load_time}, + )) + total_score += 2.5 + else: + items.append(DiagnosisItem( + name="页面速度", + status=DiagnosisStatus.FAIL, + description=f"页面加载时间{data.page_load_time}s,严重超时", + suggestion="立即优化页面加载性能", + score=0.0, + details={"load_time": data.page_load_time}, + )) + + # 3. 转化路径分析 (5分) + conversion_score = 0.0 + + # CTA检查 (2分) + if data.has_cta: + conversion_score += 2.0 + items.append(DiagnosisItem( + name="CTA", + status=DiagnosisStatus.PASS, + description="页面有明确的行动号召", + suggestion="保持清晰的CTA", + score=1.0, + )) + else: + items.append(DiagnosisItem( + name="CTA", + status=DiagnosisStatus.WARNING, + description="页面缺少明确的行动号召", + suggestion="添加清晰的CTA按钮", + score=0.0, + )) + + # 转化路径清晰度 (2分) + if data.conversion_path_clear: + conversion_score += 2.0 + items.append(DiagnosisItem( + name="转化路径", + status=DiagnosisStatus.PASS, + description="转化路径清晰", + suggestion="保持当前转化流程", + score=1.0, + )) + else: + items.append(DiagnosisItem( + name="转化路径", + status=DiagnosisStatus.WARNING, + description="转化路径不够清晰", + suggestion="简化转化流程,减少步骤", + score=0.0, + )) + + # 表单可用性 (1分) + if data.form_usability >= 0.8: + conversion_score += 1.0 + items.append(DiagnosisItem( + name="表单可用性", + status=DiagnosisStatus.PASS, + description=f"表单可用性{data.form_usability*100:.0f}%", + suggestion="保持表单体验", + score=1.0, + details={"usability": data.form_usability}, + )) + else: + items.append(DiagnosisItem( + name="表单可用性", + status=DiagnosisStatus.WARNING, + description=f"表单可用性{data.form_usability*100:.0f}%,需要优化", + suggestion="简化表单,减少必填字段", + score=0.0, + details={"usability": data.form_usability}, + )) + + total_score += conversion_score + + # 4. 基础体验 (4分) + # HTTPS检查 (2分) + if data.has_https: + items.append(DiagnosisItem( + name="HTTPS", + status=DiagnosisStatus.PASS, + description="网站使用HTTPS", + suggestion="保持HTTPS配置", + score=1.0, + )) + total_score += 2.0 + else: + items.append(DiagnosisItem( + name="HTTPS", + status=DiagnosisStatus.FAIL, + description="网站未使用HTTPS", + suggestion="立即启用HTTPS", + score=0.0, + )) + + # 面包屑导航 (1分) + if data.has_breadcrumbs: + items.append(DiagnosisItem( + name="面包屑导航", + status=DiagnosisStatus.PASS, + description="页面有面包屑导航", + suggestion="保持面包屑导航", + score=1.0, + )) + total_score += 1.0 + else: + items.append(DiagnosisItem( + name="面包屑导航", + status=DiagnosisStatus.WARNING, + description="页面缺少面包屑导航", + suggestion="添加面包屑导航提升用户体验", + score=0.0, + )) + + # 站内搜索 (1分) + if data.has_search: + items.append(DiagnosisItem( + name="站内搜索", + status=DiagnosisStatus.PASS, + description="网站有站内搜索功能", + suggestion="保持搜索功能优化", + score=1.0, + )) + total_score += 1.0 + else: + items.append(DiagnosisItem( + name="站内搜索", + status=DiagnosisStatus.WARNING, + description="网站缺少站内搜索", + suggestion="添加站内搜索功能", + score=0.0, + )) + + return SEODimensionScore( + name=DimensionName.USER_EXPERIENCE, + score=total_score, + max_score=max_score, + items=items, + status=DiagnosisStatus.PASS, + ) + + +# ============================================================ +# 建议生成 +# ============================================================ + +def generate_recommendations(result: SEODiagnosisResult) -> list[SEORecommendation]: + """ + 根据诊断结果生成优化建议 + + 优先级规则: + - FAIL状态 -> high priority + - WARNING状态 -> medium priority + - 影响大的建议 -> high priority + """ + recommendations: list[SEORecommendation] = [] + + for dimension in result.dimensions: + for item in dimension.items: + if item.status == DiagnosisStatus.FAIL: + recommendations.append(SEORecommendation( + priority="high", + dimension=dimension.name, + item_name=item.name, + description=item.suggestion, + impact="修复后可显著提升SEO表现", + effort="medium", + )) + elif item.status == DiagnosisStatus.WARNING: + recommendations.append(SEORecommendation( + priority="medium", + dimension=dimension.name, + item_name=item.name, + description=item.suggestion, + impact="优化后可改善SEO表现", + effort="easy", + )) + + # 按优先级排序 + priority_order = {"high": 0, "medium": 1, "low": 2} + recommendations.sort(key=lambda r: priority_order.get(r.priority, 3)) + + return recommendations + + +# ============================================================ +# 主诊断服务 +# ============================================================ + +class SEODiagnosisService: + """SEO诊断服务""" + + def diagnose( + self, + technical_data: TechnicalSEOData | None = None, + on_page_data: OnPageSEOData | None = None, + content_data: ContentQualityData | None = None, + backlink_data: BacklinkData | None = None, + ux_data: UserExperienceData | None = None, + ) -> SEODiagnosisResult: + """ + 执行完整SEO诊断 + + Args: + technical_data: 技术SEO检测数据 + on_page_data: 页面SEO检测数据 + content_data: 内容质量检测数据 + backlink_data: 外链检测数据 + ux_data: 用户体验检测数据 + + Returns: + SEODiagnosisResult: 诊断结果 + """ + # 使用默认数据(模拟数据) + technical_data = technical_data or TechnicalSEOData() + on_page_data = on_page_data or OnPageSEOData() + content_data = content_data or ContentQualityData() + backlink_data = backlink_data or BacklinkData() + ux_data = ux_data or UserExperienceData() + + # 执行5维度诊断 + dimensions = [ + diagnose_technical_seo(technical_data), + diagnose_on_page_seo(on_page_data), + diagnose_content_quality(content_data), + diagnose_backlinks(backlink_data), + diagnose_user_experience(ux_data), + ] + + # 计算综合评分 + overall_score = sum(dim.score for dim in dimensions) + + # 创建初步结果 + result = SEODiagnosisResult( + overall_score=overall_score, + dimensions=dimensions, + recommendations=[], + ) + + # 生成优化建议 + result.recommendations = generate_recommendations(result) + + return result + + def diagnose_technical_only(self, data: TechnicalSEOData | None = None) -> SEODimensionScore: + """仅执行技术SEO诊断""" + return diagnose_technical_seo(data or TechnicalSEOData()) + + def diagnose_on_page_only(self, data: OnPageSEOData | None = None) -> SEODimensionScore: + """仅执行页面SEO诊断""" + return diagnose_on_page_seo(data or OnPageSEOData()) + + def diagnose_content_only(self, data: ContentQualityData | None = None) -> SEODimensionScore: + """仅执行内容质量诊断""" + return diagnose_content_quality(data or ContentQualityData()) + + def diagnose_backlinks_only(self, data: BacklinkData | None = None) -> SEODimensionScore: + """仅执行外链分析""" + return diagnose_backlinks(data or BacklinkData()) + + def diagnose_ux_only(self, data: UserExperienceData | None = None) -> SEODimensionScore: + """仅执行用户体验诊断""" + return diagnose_user_experience(data or UserExperienceData()) diff --git a/backend/requirements.txt b/backend/requirements.txt index 42b454a..ea34e02 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -40,3 +40,11 @@ aiosqlite # PDF生成 fpdf2>=2.7 + +# 监控 +prometheus-client>=0.19.0 + +# 文档解析 +PyMuPDF>=1.23.0 +python-docx>=1.1.0 +shortuuid>=1.0.0 diff --git a/backend/tests/test_agent_framework/test_agent_base.py b/backend/tests/test_agent_framework/test_agent_base.py new file mode 100644 index 0000000..7630821 --- /dev/null +++ b/backend/tests/test_agent_framework/test_agent_base.py @@ -0,0 +1,227 @@ +"""Agent基类测试""" +import pytest +from datetime import datetime, timezone + +from app.agent_framework.base import BaseAgent +from app.agent_framework.protocol import ( + AgentCapability, + AgentStatus, + TaskMessage, + TaskResult, + TaskStatus, +) + + +class ConcreteTestAgent(BaseAgent): + """用于测试的BaseAgent实现""" + + def __init__(self): + super().__init__( + name="concrete_test_agent", + agent_type="test_type", + version="1.0.0", + ) + self._execute_called = False + self._execute_task_data = None + + def get_capabilities(self) -> AgentCapability: + return AgentCapability( + agent_name=self.name, + agent_type=self.agent_type, + version=self.version, + supported_tasks=["test_task"], + max_concurrency=3, + description="测试用Agent", + ) + + async def execute(self, task: TaskMessage) -> TaskResult: + self._execute_called = True + self._execute_task_data = task + return TaskResult( + task_id=task.task_id, + agent_name=self.name, + status=TaskStatus.COMPLETED, + output_data={"result": "success"}, + error_message=None, + started_at=datetime.now(timezone.utc), + completed_at=datetime.now(timezone.utc), + metrics={"test": True}, + ) + + +class TestBaseAgent: + """Agent基类测试""" + + def test_agent_initialization(self): + """测试Agent初始化""" + agent = ConcreteTestAgent() + + assert agent.name == "concrete_test_agent" + assert agent.agent_type == "test_type" + assert agent.version == "1.0.0" + assert agent.status == AgentStatus.OFFLINE + + def test_get_capabilities(self): + """测试获取Agent能力""" + agent = ConcreteTestAgent() + capability = agent.get_capabilities() + + assert isinstance(capability, AgentCapability) + assert capability.agent_name == "concrete_test_agent" + assert "test_task" in capability.supported_tasks + + def test_agent_status_transitions(self): + """测试Agent状态转换""" + agent = ConcreteTestAgent() + + # 初始状态 + assert agent.status == AgentStatus.OFFLINE + + # 模拟设置状态 + agent._status = AgentStatus.ONLINE + assert agent.status == AgentStatus.ONLINE + + agent._status = AgentStatus.BUSY + assert agent.status == AgentStatus.BUSY + + def test_agent_running_tasks_tracking(self): + """测试运行任务跟踪""" + agent = ConcreteTestAgent() + + assert len(agent._running_tasks) == 0 + + # 模拟添加任务 + agent._running_tasks.add("task-1") + agent._running_tasks.add("task-2") + assert len(agent._running_tasks) == 2 + + # 模拟移除任务 + agent._running_tasks.discard("task-1") + assert len(agent._running_tasks) == 1 + assert "task-1" not in agent._running_tasks + assert "task-2" in agent._running_tasks + + def test_agent_semaphore_initialization(self): + """测试信号量初始化""" + agent = ConcreteTestAgent() + capability = agent.get_capabilities() + max_concurrency = capability.max_concurrency + + # 初始化信号量 + import asyncio + agent._semaphore = asyncio.Semaphore(max_concurrency) + + assert agent._semaphore is not None + assert agent._semaphore._value == max_concurrency + + def test_is_idle_property(self): + """测试空闲状态判断""" + agent = ConcreteTestAgent() + + # OFFLINE 或 ONLINE 状态应该被视为空闲 + agent._status = AgentStatus.OFFLINE + # 这里假设有is_idle属性或方法 + # 根据实际实现检查 + + # BUSY状态不是空闲 + agent._status = AgentStatus.BUSY + + +class TestTaskMessage: + """TaskMessage测试""" + + def test_task_message_creation(self): + """测试TaskMessage创建""" + task = TaskMessage( + task_id=str(uuid.uuid4()), + agent_name="test_agent", + task_type="test_task", + priority=5, + input_data={"key": "value"}, + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + assert task.task_id is not None + assert task.agent_name == "test_agent" + assert task.task_type == "test_task" + assert task.priority == 5 + assert task.input_data == {"key": "value"} + + def test_task_message_to_dict(self): + """测试TaskMessage序列化""" + task = TaskMessage( + task_id="test-uuid", + agent_name="test_agent", + task_type="test_task", + priority=5, + input_data={"key": "value"}, + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + data = task.to_dict() + + assert data["task_id"] == "test-uuid" + assert data["agent_name"] == "test_agent" + assert data["task_type"] == "test_task" + + def test_task_message_from_dict(self): + """测试TaskMessage反序列化""" + data = { + "task_id": "test-uuid", + "agent_name": "test_agent", + "task_type": "test_task", + "priority": 5, + "input_data": {"key": "value"}, + "callback_url": None, + "created_at": datetime.now(timezone.utc).isoformat(), + } + + task = TaskMessage.from_dict(data) + + assert task.task_id == "test-uuid" + assert task.agent_name == "test_agent" + + +class TestTaskResult: + """TaskResult测试""" + + def test_task_result_creation(self): + """测试TaskResult创建""" + result = TaskResult( + task_id="test-uuid", + agent_name="test_agent", + status=TaskStatus.COMPLETED, + output_data={"result": "success"}, + error_message=None, + started_at=datetime.now(timezone.utc), + completed_at=datetime.now(timezone.utc), + metrics={"elapsed": 1.5}, + ) + + assert result.task_id == "test-uuid" + assert result.agent_name == "test_agent" + assert result.status == TaskStatus.COMPLETED + + def test_task_result_to_dict(self): + """测试TaskResult序列化""" + result = TaskResult( + task_id="test-uuid", + agent_name="test_agent", + status=TaskStatus.COMPLETED, + output_data={"result": "success"}, + error_message=None, + started_at=datetime.now(timezone.utc), + completed_at=datetime.now(timezone.utc), + metrics={"elapsed": 1.5}, + ) + + data = result.to_dict() + + assert data["task_id"] == "test-uuid" + assert data["status"] == TaskStatus.COMPLETED + assert data["output_data"] == {"result": "success"} + + +import uuid diff --git a/backend/tests/test_agent_framework/test_agent_dispatcher.py b/backend/tests/test_agent_framework/test_agent_dispatcher.py new file mode 100644 index 0000000..6891281 --- /dev/null +++ b/backend/tests/test_agent_framework/test_agent_dispatcher.py @@ -0,0 +1,214 @@ +"""任务分发器测试""" +import pytest +import uuid +from datetime import datetime, timezone + +from app.agent_framework.dispatcher import TaskDispatcher +from app.agent_framework.registry import AgentRegistry +from app.agent_framework.protocol import ( + AgentCapability, + TaskMessage, +) +from app.config import settings + + +def is_database_available(): + """检查数据库是否可用(同步方式)""" + try: + from sqlalchemy import create_engine, text + from app.config import settings + + # 从URL创建同步引擎进行测试 + sync_url = settings.DATABASE_URL.replace('+aiosqlite', '').replace('+asyncpg', '') + if 'sqlite' in sync_url: + engine = create_engine(sync_url) + else: + engine = create_engine(sync_url, connect_args={"connect_timeout": 1}) + + with engine.connect() as conn: + conn.execute(text("SELECT 1")) + engine.dispose() + return True + except Exception: + return False + + +# 检查服务是否可用 +_db_available = None + +def check_db(): + global _db_available + if _db_available is None: + try: + _db_available = is_database_available() + except Exception: + _db_available = False + return _db_available + + +def is_redis_available(): + """检查Redis是否可用""" + import redis + try: + r = redis.Redis.from_url(settings.REDIS_URL) + r.ping() + return True + except Exception: + return False + + +_redis_available = None + +def check_redis(): + global _redis_available + if _redis_available is None: + try: + _redis_available = is_redis_available() + except Exception: + _redis_available = False + return _redis_available + + +class TestTaskDispatcher: + """任务分发器测试""" + + @pytest.fixture + def dispatcher(self): + """创建分发器实例""" + return TaskDispatcher(settings.REDIS_URL) + + @pytest.mark.asyncio + async def test_dispatcher_initialization(self, dispatcher): + """测试分发器初始化""" + assert dispatcher is not None + assert dispatcher._redis_url == settings.REDIS_URL + assert dispatcher._redis is None + + @pytest.mark.asyncio + async def test_get_task_status_not_found(self, dispatcher): + """测试获取不存在的任务状态""" + if not check_db(): + pytest.skip("数据库不可用,跳过此测试") + + non_existent_id = str(uuid.uuid4()) + + from app.agent_framework.exceptions import TaskNotFoundError + with pytest.raises(TaskNotFoundError): + await dispatcher.get_task_status(non_existent_id) + + @pytest.mark.asyncio + async def test_dispatch_without_agent(self, dispatcher): + """测试分发任务到不存在的Agent""" + if not check_db(): + pytest.skip("数据库不可用,跳过此测试") + + task = TaskMessage( + task_id=str(uuid.uuid4()), + agent_name="non_existent_agent", + task_type="test_task", + priority=5, + input_data={}, + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + from app.agent_framework.exceptions import TaskDispatchError + with pytest.raises(TaskDispatchError): + await dispatcher.dispatch(task) + + @pytest.mark.asyncio + async def test_cancel_task_not_found(self, dispatcher): + """测试取消不存在的任务""" + if not check_db(): + pytest.skip("数据库不可用,跳过此测试") + + non_existent_id = str(uuid.uuid4()) + + from app.agent_framework.exceptions import TaskNotFoundError + with pytest.raises(TaskNotFoundError): + await dispatcher.cancel_task(non_existent_id) + + @pytest.mark.asyncio + async def test_handle_progress(self, dispatcher): + """测试处理进度上报""" + if not check_db(): + pytest.skip("数据库不可用,跳过此测试") + + from app.agent_framework.protocol import TaskProgress + + # 创建一个假的progress对象 + progress = TaskProgress( + task_id=str(uuid.uuid4()), + agent_name="non_existent", + progress=0.5, + message="测试进度", + updated_at=datetime.now(timezone.utc), + ) + + # 不应抛出异常 + await dispatcher.handle_progress(progress) + + @pytest.mark.asyncio + async def test_close_dispatcher(self, dispatcher): + """测试关闭分发器""" + if not check_redis(): + pytest.skip("Redis不可用,跳过此测试") + + # 先获取redis连接 + await dispatcher._get_redis() + assert dispatcher._redis is not None + + # 关闭 + await dispatcher.close() + assert dispatcher._redis is None + + @pytest.mark.asyncio + async def test_dispatch_and_query_flow(self, dispatcher): + """测试完整分发和查询流程""" + if not check_db() or not check_redis(): + pytest.skip("数据库或Redis不可用,跳过此测试") + + # 1. 注册一个测试Agent + registry = AgentRegistry() + agent_name = f"test_dispatch_agent_{uuid.uuid4().hex[:8]}" + capability = AgentCapability( + agent_name=agent_name, + agent_type="test_type", + version="1.0.0", + supported_tasks=["test_task"], + max_concurrency=3, + description="测试Agent", + ) + await registry.register(capability, endpoint=f"agent:{agent_name}") + + # 2. 尝试分发任务 + task = TaskMessage( + task_id=str(uuid.uuid4()), + agent_name=agent_name, + task_type="test_task", + priority=5, + input_data={"test": "data"}, + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + # Agent虽然注册了但可能不在线,这里只验证方法能正常执行 + try: + task_id = await dispatcher.dispatch(task) + assert task_id is not None + except Exception: + # Agent可能不在线,这是预期行为 + pass + + # 清理 + await registry.unregister(agent_name) + + @pytest.mark.asyncio + async def test_retry_failed_tasks_empty(self, dispatcher): + """测试重试失败任务(无失败任务)""" + if not check_db(): + pytest.skip("数据库不可用,跳过此测试") + + result = await dispatcher.retry_failed_tasks(max_retries=3) + # 无失败任务时不应抛出异常 + assert result is None or result == [] or isinstance(result, int) diff --git a/backend/tests/test_agent_framework/test_agent_registry.py b/backend/tests/test_agent_framework/test_agent_registry.py new file mode 100644 index 0000000..17522e1 --- /dev/null +++ b/backend/tests/test_agent_framework/test_agent_registry.py @@ -0,0 +1,228 @@ +"""Agent注册表测试""" +import pytest +import uuid + +from app.agent_framework.registry import AgentRegistry +from app.agent_framework.protocol import AgentCapability, AgentStatus + + +def is_database_available(): + """检查数据库是否可用(同步方式)""" + try: + # 直接尝试同步方式检查 + from sqlalchemy import create_engine, text + from app.config import settings + + # 从URL创建同步引擎进行测试 + sync_url = settings.DATABASE_URL.replace('+aiosqlite', '').replace('+asyncpg', '') + if 'sqlite' in sync_url: + engine = create_engine(sync_url) + else: + engine = create_engine(sync_url, connect_args={"connect_timeout": 1}) + + with engine.connect() as conn: + conn.execute(text("SELECT 1")) + engine.dispose() + return True + except Exception: + return False + + +# 检查服务是否可用 +_db_available = None + +def check_db(): + global _db_available + if _db_available is None: + try: + _db_available = is_database_available() + except Exception: + _db_available = False + return _db_available + + +class TestAgentRegistry: + """Agent注册表测试""" + + @pytest.mark.asyncio + async def test_register_agent(self): + """测试Agent注册""" + if not check_db(): + pytest.skip("数据库不可用,跳过此测试") + + registry = AgentRegistry() + agent_name = f"test_agent_{uuid.uuid4().hex[:8]}" + capability = AgentCapability( + agent_name=agent_name, + agent_type="citation_detector", + version="1.0.0", + supported_tasks=["citation_detect", "citation_detect_single"], + max_concurrency=3, + description="测试Agent", + ) + + agent_id = await registry.register(capability, endpoint=f"agent:{agent_name}") + + # 验证注册成功 + assert agent_id is not None + assert len(agent_id) > 0 + + # 清理 + await registry.unregister(agent_name) + + @pytest.mark.asyncio + async def test_get_registered_agent(self): + """测试获取已注册的Agent""" + if not check_db(): + pytest.skip("数据库不可用,跳过此测试") + + registry = AgentRegistry() + agent_name = f"test_agent_{uuid.uuid4().hex[:8]}" + capability = AgentCapability( + agent_name=agent_name, + agent_type="citation_detector", + version="1.0.0", + supported_tasks=["citation_detect"], + max_concurrency=3, + description="测试Agent", + ) + + await registry.register(capability, endpoint=f"agent:{agent_name}") + + retrieved = await registry.get_agent(agent_name) + + assert retrieved is not None + assert retrieved["name"] == agent_name + assert retrieved["agent_type"] == "citation_detector" + + # 清理 + await registry.unregister(agent_name) + + @pytest.mark.asyncio + async def test_list_all_agents(self): + """测试列出所有Agent""" + if not check_db(): + pytest.skip("数据库不可用,跳过此测试") + + registry = AgentRegistry() + agents = await registry.list_agents() + + assert isinstance(agents, list) + + @pytest.mark.asyncio + async def test_unregister_agent(self): + """测试取消注册""" + if not check_db(): + pytest.skip("数据库不可用,跳过此测试") + + registry = AgentRegistry() + agent_name = f"test_agent_{uuid.uuid4().hex[:8]}" + capability = AgentCapability( + agent_name=agent_name, + agent_type="citation_detector", + version="1.0.0", + supported_tasks=["citation_detect"], + max_concurrency=3, + description="测试Agent", + ) + + await registry.register(capability, endpoint=f"agent:{agent_name}") + await registry.unregister(agent_name) + + retrieved = await registry.get_agent(agent_name) + # 注销后状态应为OFFLINE或None + assert retrieved is None or retrieved["status"] == AgentStatus.OFFLINE + + @pytest.mark.asyncio + async def test_update_heartbeat(self): + """测试心跳更新""" + if not check_db(): + pytest.skip("数据库不可用,跳过此测试") + + registry = AgentRegistry() + agent_name = f"test_agent_{uuid.uuid4().hex[:8]}" + capability = AgentCapability( + agent_name=agent_name, + agent_type="citation_detector", + version="1.0.0", + supported_tasks=["citation_detect"], + max_concurrency=3, + description="测试Agent", + ) + + await registry.register(capability, endpoint=f"agent:{agent_name}") + await registry.update_heartbeat(agent_name) + + agent = await registry.get_agent(agent_name) + assert agent is not None + assert agent["last_heartbeat"] is not None + + # 清理 + await registry.unregister(agent_name) + + @pytest.mark.asyncio + async def test_get_available_agent(self): + """测试根据任务类型获取可用Agent""" + if not check_db(): + pytest.skip("数据库不可用,跳过此测试") + + registry = AgentRegistry() + agent_name = f"test_agent_{uuid.uuid4().hex[:8]}" + capability = AgentCapability( + agent_name=agent_name, + agent_type="citation_detector", + version="1.0.0", + supported_tasks=["citation_detect", "citation_detect_single"], + max_concurrency=3, + description="测试Agent", + ) + + await registry.register(capability, endpoint=f"agent:{agent_name}") + + available = await registry.get_available_agent("citation_detect") + # 可能返回None因为Agent状态不是ONLINE + # 这里只验证方法能正常执行 + assert available is None or isinstance(available, str) + + # 清理 + await registry.unregister(agent_name) + + @pytest.mark.asyncio + async def test_agent_reregistration(self): + """测试Agent重复注册(应该更新而非报错)""" + if not check_db(): + pytest.skip("数据库不可用,跳过此测试") + + registry = AgentRegistry() + agent_name = f"test_agent_{uuid.uuid4().hex[:8]}" + capability1 = AgentCapability( + agent_name=agent_name, + agent_type="citation_detector", + version="1.0.0", + supported_tasks=["citation_detect"], + max_concurrency=3, + description="测试AgentV1", + ) + capability2 = AgentCapability( + agent_name=agent_name, + agent_type="citation_detector", + version="2.0.0", + supported_tasks=["citation_detect", "new_task"], + max_concurrency=5, + description="测试AgentV2", + ) + + await registry.register(capability1, endpoint=f"agent:{agent_name}") + + # 验证第一次注册成功 + agent_data = await registry.get_agent(agent_name) + assert agent_data is not None + + # 重新注册同名Agent + agent_id2 = await registry.register(capability2, endpoint=f"agent:{agent_name}") + + # 应该成功且不报错 + assert agent_id2 is not None + + # 清理 + await registry.unregister(agent_name) diff --git a/backend/tests/test_agent_framework/test_agents_integration.py b/backend/tests/test_agent_framework/test_agents_integration.py new file mode 100644 index 0000000..e5da82f --- /dev/null +++ b/backend/tests/test_agent_framework/test_agents_integration.py @@ -0,0 +1,224 @@ +"""Agent集成测试""" +import pytest +import uuid +from datetime import datetime, timezone + +from app.agent_framework.agents.citation_detector import CitationDetectorAgent +from app.agent_framework.agents.content_generator_agent import ContentGeneratorAgent +from app.agent_framework.protocol import ( + TaskMessage, + TaskStatus, +) + + +class TestCitationDetectorAgent: + """引用检测Agent测试""" + + def test_agent_initialization(self): + """测试Agent初始化""" + agent = CitationDetectorAgent() + + assert agent.name == "citation_detector" + assert agent.agent_type.value == "citation_detector" + assert agent.version == "1.0.0" + + def test_get_capabilities(self): + """测试获取Agent能力""" + agent = CitationDetectorAgent() + capability = agent.get_capabilities() + + assert capability.agent_name == "citation_detector" + assert "citation_detect" in capability.supported_tasks + assert "citation_detect_single" in capability.supported_tasks + assert capability.max_concurrency == 3 + + @pytest.mark.asyncio + async def test_execute_with_invalid_task_type(self): + """测试执行不支持的任务类型""" + agent = CitationDetectorAgent() + task = TaskMessage( + task_id=str(uuid.uuid4()), + agent_name="citation_detector", + task_type="invalid_task_type", + priority=5, + input_data={}, + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + result = await agent.execute(task) + + assert result.status == TaskStatus.FAILED + assert result.error_message is not None + assert "Unsupported task type" in result.error_message + + @pytest.mark.asyncio + async def test_execute_single_detect_missing_params(self): + """测试单平台检测缺少必需参数""" + agent = CitationDetectorAgent() + task = TaskMessage( + task_id=str(uuid.uuid4()), + agent_name="citation_detector", + task_type="citation_detect_single", + priority=5, + input_data={}, # 缺少keyword, platform, target_brand + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + result = await agent.execute(task) + + assert result.status == TaskStatus.FAILED + assert result.error_message is not None + + @pytest.mark.asyncio + async def test_execute_full_detect_missing_query_id(self): + """测试完整检测缺少query_id""" + agent = CitationDetectorAgent() + task = TaskMessage( + task_id=str(uuid.uuid4()), + agent_name="citation_detector", + task_type="citation_detect", + priority=5, + input_data={}, # 缺少query_id + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + result = await agent.execute(task) + + assert result.status == TaskStatus.FAILED + assert "query_id" in result.error_message or "must contain" in result.error_message + + def test_compatibility_methods_exist(self): + """测试向后兼容方法存在""" + agent = CitationDetectorAgent() + + assert hasattr(agent, 'execute_query_compat') + assert hasattr(agent, 'execute_single_platform_compat') + assert callable(agent.execute_query_compat) + assert callable(agent.execute_single_platform_compat) + + +class TestContentGeneratorAgent: + """内容生成Agent测试""" + + def test_agent_initialization(self): + """测试Agent初始化""" + agent = ContentGeneratorAgent() + + assert agent.name == "content_generator" + assert agent.agent_type.value == "content_generator" + assert agent.version == "1.0.0" + + def test_get_capabilities(self): + """测试获取Agent能力""" + agent = ContentGeneratorAgent() + capability = agent.get_capabilities() + + assert capability.agent_name == "content_generator" + assert "generate_topics" in capability.supported_tasks + assert "generate_article" in capability.supported_tasks + assert capability.max_concurrency == 2 + + @pytest.mark.asyncio + async def test_execute_with_invalid_task_type(self): + """测试执行不支持的任务类型""" + agent = ContentGeneratorAgent() + task = TaskMessage( + task_id=str(uuid.uuid4()), + agent_name="content_generator", + task_type="invalid_task_type", + priority=5, + input_data={}, + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + result = await agent.execute(task) + + assert result.status == TaskStatus.FAILED + assert "Unsupported task type" in result.error_message + + @pytest.mark.asyncio + async def test_generate_topics_missing_keyword(self): + """测试生成选题缺少关键词""" + agent = ContentGeneratorAgent() + task = TaskMessage( + task_id=str(uuid.uuid4()), + agent_name="content_generator", + task_type="generate_topics", + priority=5, + input_data={}, # 缺少target_keyword + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + # 由于没有真实LLM调用和知识库,这个测试会调用LLM + # 我们主要验证方法能正常执行 + result = await agent.execute(task) + + # 结果可能是FAILED(因为缺少必要参数或LLM调用失败) + assert result is not None + assert result.status in [TaskStatus.COMPLETED, TaskStatus.FAILED] + + @pytest.mark.asyncio + async def test_generate_article_missing_keyword(self): + """测试生成文章缺少关键词""" + agent = ContentGeneratorAgent() + task = TaskMessage( + task_id=str(uuid.uuid4()), + agent_name="content_generator", + task_type="generate_article", + priority=5, + input_data={}, # 缺少target_keyword + callback_url=None, + created_at=datetime.now(timezone.utc), + ) + + result = await agent.execute(task) + + assert result is not None + # 缺少必要参数可能导致失败 + assert result.status in [TaskStatus.COMPLETED, TaskStatus.FAILED] + + def test_extract_json_method(self): + """测试JSON提取方法""" + agent = ContentGeneratorAgent() + + # 测试普通JSON + json_text = '{"title": "测试标题", "reason": "测试原因"}' + extracted = agent._extract_json(json_text) + assert "title" in extracted + + # 测试被markdown包裹的JSON + md_text = '```json\n{"title": "测试"}\n```' + extracted = agent._extract_json(md_text) + assert "title" in extracted + + +class TestAgentProtocol: + """Agent协议测试""" + + def test_agent_type_enum_values(self): + """测试AgentType枚举值""" + from app.agent_framework.protocol import AgentType + + assert AgentType.CITATION_DETECTOR.value == "citation_detector" + assert AgentType.CONTENT_GENERATOR.value == "content_generator" + + def test_task_status_enum_values(self): + """测试TaskStatus枚举值""" + assert TaskStatus.PENDING.value == "pending" + assert TaskStatus.RUNNING.value == "running" + assert TaskStatus.COMPLETED.value == "completed" + assert TaskStatus.FAILED.value == "failed" + assert TaskStatus.CANCELLED.value == "cancelled" + + def test_agent_status_enum_values(self): + """测试AgentStatus枚举值""" + from app.agent_framework.protocol import AgentStatus + + assert AgentStatus.ONLINE.value == "online" + assert AgentStatus.OFFLINE.value == "offline" + assert AgentStatus.BUSY.value == "busy" diff --git a/backend/tests/test_api/test_alert_settings_api.py b/backend/tests/test_api/test_alert_settings_api.py new file mode 100644 index 0000000..0c65819 --- /dev/null +++ b/backend/tests/test_api/test_alert_settings_api.py @@ -0,0 +1,329 @@ +"""告警设置API测试""" +import uuid + +import pytest +import pytest_asyncio +from httpx import AsyncClient, ASGITransport +from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.database import Base +from app.main import app +from app.models.user import User +from app.models.brand import Brand +from app.models.alert_setting import AlertSetting +from app.api.deps import get_current_user, get_db +from app.services.auth import hash_password, create_access_token + + +# ==================== Fixtures ==================== + +@pytest_asyncio.fixture +async def async_engine(): + """创建测试用SQLite异步引擎""" + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + """创建测试用异步数据库会话""" + async_session_maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with async_session_maker() as session: + yield session + + +@pytest_asyncio.fixture +async def test_user(async_session): + """创建测试用户""" + user = User( + id=uuid.uuid4(), + email="test@example.com", + password_hash=hash_password("Test@123456"), + name="Test User", + plan="free", + max_queries=5, + is_active=True, + email_verified=True, + ) + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def test_brand(async_session, test_user): + """创建测试品牌""" + brand = Brand( + id=uuid.uuid4(), + user_id=test_user.id, + name="Test Brand", + aliases=["TestBrand", "TB"], + website="https://testbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def test_alert_setting(async_session, test_user, test_brand): + """创建测试告警设置""" + setting = AlertSetting( + id=uuid.uuid4(), + brand_id=test_brand.id, + user_id=test_user.id, + alert_type="score_drop", + enabled=True, + threshold=5.0, + ) + async_session.add(setting) + await async_session.commit() + await async_session.refresh(setting) + return setting + + +@pytest_asyncio.fixture +async def async_client(async_session, test_user): + """创建异步HTTP客户端用于API测试""" + + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return test_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +@pytest.fixture +def auth_headers(test_user): + """创建认证请求头""" + token = create_access_token(data={"sub": str(test_user.id)}) + return {"Authorization": f"Bearer {token}"} + + +# ==================== 测试类 ==================== + +class TestAlertSettingsAPI: + """告警设置API测试""" + + @pytest.mark.asyncio + async def test_get_alert_settings_success(self, async_client, test_alert_setting): + """测试获取告警设置 - 成功返回设置列表""" + response = await async_client.get("/api/v1/alerts/settings") + + assert response.status_code == 200 + data = response.json() + assert "items" in data + assert "total" in data + assert data["total"] >= 1 + assert len(data["items"]) >= 1 + + first_item = data["items"][0] + assert "id" in first_item + assert "brand_id" in first_item + assert "alert_type" in first_item + assert "enabled" in first_item + assert "threshold" in first_item + + @pytest.mark.asyncio + async def test_get_alert_settings_by_brand(self, async_client, test_brand, test_alert_setting): + """测试按品牌筛选告警设置""" + response = await async_client.get( + f"/api/v1/alerts/settings?brand_id={test_brand.id}" + ) + + assert response.status_code == 200 + data = response.json() + assert data["total"] >= 1 + for item in data["items"]: + assert item["brand_id"] == str(test_brand.id) + + @pytest.mark.asyncio + async def test_update_alert_settings_success(self, async_client, test_brand): + """测试更新告警设置 - 成功更新并返回新设置""" + update_data = { + "settings": [ + { + "brand_id": str(test_brand.id), + "alert_type": "score_drop", + "enabled": True, + "threshold": 20.0, + } + ] + } + + response = await async_client.put( + "/api/v1/alerts/settings", json=update_data + ) + + assert response.status_code == 200 + data = response.json() + assert "items" in data + assert len(data["items"]) >= 1 + + updated_setting = data["items"][0] + assert updated_setting["alert_type"] == "score_drop" + assert updated_setting["threshold"] == 20.0 + assert updated_setting["enabled"] is True + + @pytest.mark.asyncio + async def test_create_alert_setting_success(self, async_client, test_brand): + """测试创建告警设置 - 为新品牌创建默认设置""" + create_data = { + "brand_id": str(test_brand.id), + "alert_type": "negative_sentiment", + "enabled": True, + "threshold": 1.0, + } + + response = await async_client.post( + "/api/v1/alerts/settings", json=create_data + ) + + assert response.status_code == 201 + data = response.json() + assert data["alert_type"] == "negative_sentiment" + assert data["threshold"] == 1.0 + assert data["enabled"] is True + assert "id" in data + + @pytest.mark.asyncio + async def test_delete_alert_setting_success(self, async_client, test_alert_setting): + """测试删除告警设置 - 成功删除""" + response = await async_client.delete( + f"/api/v1/alerts/settings/{test_alert_setting.id}" + ) + + assert response.status_code == 204 + + get_response = await async_client.get("/api/v1/alerts/settings") + data = get_response.json() + deleted_ids = [item["id"] for item in data["items"]] + assert str(test_alert_setting.id) not in deleted_ids + + @pytest.mark.asyncio + async def test_delete_alert_setting_not_found(self, async_client): + """测试删除不存在的告警设置""" + non_existent_id = uuid.uuid4() + response = await async_client.delete( + f"/api/v1/alerts/settings/{non_existent_id}" + ) + + assert response.status_code == 404 + + +class TestAlertSettingsValidation: + """告警设置验证测试""" + + @pytest.mark.asyncio + async def test_brand_not_found_returns_404(self, async_client): + """测试品牌不存在时返回404""" + non_existent_brand_id = uuid.uuid4() + create_data = { + "brand_id": str(non_existent_brand_id), + "alert_type": "score_drop", + "enabled": True, + "threshold": 5.0, + } + + response = await async_client.post( + "/api/v1/alerts/settings", json=create_data + ) + + assert response.status_code == 404 + + @pytest.mark.asyncio + async def test_unauthorized_returns_401(self, async_session): + """测试未认证时返回401""" + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": "Bearer invalid_token"} + + response = await client.get( + "/api/v1/alerts/settings", + headers=headers + ) + assert response.status_code == 401 + + app.dependency_overrides.clear() + + @pytest.mark.asyncio + async def test_invalid_alert_type_returns_422(self, async_client, test_brand): + """测试无效的告警类型返回422""" + create_data = { + "brand_id": str(test_brand.id), + "alert_type": "invalid_type", + "enabled": True, + "threshold": 5.0, + } + + response = await async_client.post( + "/api/v1/alerts/settings", json=create_data + ) + + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_negative_threshold_returns_422(self, async_client, test_brand): + """测试负数阈值返回422""" + create_data = { + "brand_id": str(test_brand.id), + "alert_type": "score_drop", + "enabled": True, + "threshold": -10.0, + } + + response = await async_client.post( + "/api/v1/alerts/settings", json=create_data + ) + + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_threshold_over_100_returns_422(self, async_client, test_brand): + """测试阈值超过100返回422""" + create_data = { + "brand_id": str(test_brand.id), + "alert_type": "score_drop", + "enabled": True, + "threshold": 150.0, + } + + response = await async_client.post( + "/api/v1/alerts/settings", json=create_data + ) + + assert response.status_code == 422 diff --git a/backend/tests/test_api/test_auth_api.py b/backend/tests/test_api/test_auth_api.py new file mode 100644 index 0000000..fcfb748 --- /dev/null +++ b/backend/tests/test_api/test_auth_api.py @@ -0,0 +1,451 @@ +"""认证API测试""" +import uuid + +import pytest +import pytest_asyncio +from httpx import AsyncClient, ASGITransport +from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.database import Base +from app.main import app +from app.models.user import User +from app.api.deps import get_current_user, get_db +from app.services.auth import hash_password, create_access_token + + +# ==================== Fixtures ==================== + +@pytest_asyncio.fixture +async def async_engine(): + """创建测试用SQLite异步引擎""" + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + """创建测试用异步数据库会话""" + async_session_maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with async_session_maker() as session: + yield session + + +@pytest_asyncio.fixture +async def test_user(async_session): + """创建测试用户""" + user = User( + id=uuid.uuid4(), + email="test@example.com", + password_hash=hash_password("Test@123456"), + name="Test User", + plan="free", + max_queries=5, + is_active=True, + email_verified=True, + ) + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def async_client(async_session, test_user): + """创建异步HTTP客户端用于API测试""" + + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return test_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +@pytest.fixture +def auth_headers(test_user): + """创建认证请求头""" + token = create_access_token(data={"sub": str(test_user.id)}) + return {"Authorization": f"Bearer {token}"} + + +# ==================== 测试类 ==================== + +class TestAuthAPI: + """认证API测试""" + + @pytest.mark.asyncio + async def test_register_success(self, async_session): + """测试用户注册成功""" + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/api/v1/auth/register", + json={ + "email": f"test_{uuid.uuid4()}@example.com", + "name": "Test User", + "password": "Test@123456" + } + ) + + assert response.status_code == 201 + data = response.json() + assert "id" in data + assert data["email"] is not None + assert data["name"] == "Test User" + + app.dependency_overrides.clear() + + @pytest.mark.asyncio + async def test_register_duplicate_email(self, async_session): + """测试重复邮箱注册失败""" + email = f"test_{uuid.uuid4()}@example.com" + + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + # 第一次注册 + response1 = await client.post( + "/api/v1/auth/register", + json={ + "email": email, + "name": "Test User 1", + "password": "Test@123456" + } + ) + assert response1.status_code == 201 + + # 第二次使用相同邮箱注册 + response2 = await client.post( + "/api/v1/auth/register", + json={ + "email": email, + "name": "Test User 2", + "password": "Test@123456" + } + ) + assert response2.status_code == 400 + + app.dependency_overrides.clear() + + @pytest.mark.asyncio + async def test_login_success(self, async_session): + """测试用户登录成功""" + email = f"test_{uuid.uuid4()}@example.com" + password = "Test@123456" + + # 先创建用户 + user = User( + id=uuid.uuid4(), + email=email, + password_hash=hash_password(password), + name="Test User", + plan="free", + max_queries=5, + is_active=True, + email_verified=True, + ) + async_session.add(user) + await async_session.commit() + + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/api/v1/auth/login", + json={ + "email": email, + "password": password + } + ) + + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + assert "refresh_token" in data + assert data["token_type"] == "bearer" + assert "user" in data + + app.dependency_overrides.clear() + + @pytest.mark.asyncio + async def test_login_wrong_password(self, async_session): + """测试错误密码登录""" + email = f"test_{uuid.uuid4()}@example.com" + + # 创建用户 + user = User( + id=uuid.uuid4(), + email=email, + password_hash=hash_password("Correct@123456"), + name="Test User", + plan="free", + max_queries=5, + is_active=True, + email_verified=True, + ) + async_session.add(user) + await async_session.commit() + + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/api/v1/auth/login", + json={ + "email": email, + "password": "WrongPassword" + } + ) + + assert response.status_code == 401 + + app.dependency_overrides.clear() + + @pytest.mark.asyncio + async def test_login_nonexistent_user(self, async_session): + """测试不存在的用户登录""" + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/api/v1/auth/login", + json={ + "email": "nonexistent@example.com", + "password": "Test@123456" + } + ) + + # 统一错误消息,不区分用户不存在和密码错误 + # 可能返回401(认证失败)或429(速率限制),都是有效响应 + assert response.status_code in [401, 429] + + app.dependency_overrides.clear() + + @pytest.mark.asyncio + async def test_get_current_user(self, async_client, test_user): + """测试获取当前用户信息""" + response = await async_client.get("/api/v1/auth/me") + + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(test_user.id) + assert data["email"] == test_user.email + + @pytest.mark.asyncio + async def test_change_password_success(self, async_client, async_session, test_user): + """测试修改密码成功""" + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + # 使用token获取用户 + token = create_access_token(data={"sub": str(test_user.id)}) + headers = {"Authorization": f"Bearer {token}"} + + response = await client.put( + "/api/v1/auth/change-password", + headers=headers, + json={ + "old_password": "Test@123456", + "new_password": "NewPass@123456" + } + ) + + assert response.status_code == 200 + data = response.json() + assert data["message"] == "密码修改成功" + + app.dependency_overrides.clear() + + @pytest.mark.asyncio + async def test_change_password_wrong_old(self, async_client, test_user): + """测试旧密码错误""" + token = create_access_token(data={"sub": str(test_user.id)}) + headers = {"Authorization": f"Bearer {token}"} + + response = await async_client.put( + "/api/v1/auth/change-password", + headers=headers, + json={ + "old_password": "WrongOldPass", + "new_password": "NewPass@123456" + } + ) + + assert response.status_code == 400 + assert "旧密码错误" in response.json()["detail"] + + @pytest.mark.asyncio + async def test_update_profile(self, async_client, test_user): + """测试更新用户资料""" + token = create_access_token(data={"sub": str(test_user.id)}) + headers = {"Authorization": f"Bearer {token}"} + + response = await async_client.put( + "/api/v1/auth/profile", + headers=headers, + json={ + "name": "Updated Name", + "avatar_url": "https://example.com/avatar.png" + } + ) + + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Updated Name" + + @pytest.mark.asyncio + async def test_refresh_token(self, async_session, test_user): + """测试刷新令牌""" + from app.services.auth import create_refresh_token + + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + # 创建refresh token + refresh_token = create_refresh_token(data={"sub": str(test_user.id)}) + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/api/v1/auth/refresh", + json={ + "refresh_token": refresh_token + } + ) + + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + assert "refresh_token" in data + + app.dependency_overrides.clear() + + @pytest.mark.asyncio + async def test_forgot_password(self, async_session): + """测试忘记密码请求""" + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/api/v1/auth/forgot-password", + json={ + "email": "test@example.com" + } + ) + + # 无论邮箱是否存在都返回成功,防止用户枚举 + assert response.status_code == 200 + assert "message" in response.json() + + app.dependency_overrides.clear() + + @pytest.mark.asyncio + async def test_resend_verification(self, async_session): + """测试重新发送验证码""" + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/api/v1/auth/resend-verification", + json={ + "email": "test@example.com" + } + ) + + assert response.status_code == 200 + assert "message" in response.json() + + app.dependency_overrides.clear() + + @pytest.mark.asyncio + async def test_unauthorized_access(self, async_session): + """测试未授权访问受保护端点""" + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + # 不带token访问受保护端点 + response = await client.get("/api/v1/auth/me") + + assert response.status_code == 401 + + app.dependency_overrides.clear() + + @pytest.mark.asyncio + async def test_invalid_token(self, async_session): + """测试无效令牌访问""" + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get( + "/api/v1/auth/me", + headers={"Authorization": "Bearer invalid_token"} + ) + + assert response.status_code == 401 + + app.dependency_overrides.clear() diff --git a/backend/tests/test_api/test_brands_api.py b/backend/tests/test_api/test_brands_api.py new file mode 100644 index 0000000..fe6de25 --- /dev/null +++ b/backend/tests/test_api/test_brands_api.py @@ -0,0 +1,321 @@ +"""品牌API测试""" +import uuid + +import pytest +import pytest_asyncio +from httpx import AsyncClient, ASGITransport +from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.database import Base +from app.main import app +from app.models.user import User +from app.models.brand import Brand +from app.api.deps import get_current_user, get_db +from app.services.auth import hash_password, create_access_token + + +# ==================== Fixtures ==================== + +@pytest_asyncio.fixture +async def async_engine(): + """创建测试用SQLite异步引擎""" + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + """创建测试用异步数据库会话""" + async_session_maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with async_session_maker() as session: + yield session + + +@pytest_asyncio.fixture +async def test_user(async_session): + """创建测试用户""" + user = User( + id=uuid.uuid4(), + email="test@example.com", + password_hash=hash_password("Test@123456"), + name="Test User", + plan="free", + max_queries=5, + is_active=True, + email_verified=True, + ) + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def test_brand(async_session, test_user): + """创建测试品牌""" + brand = Brand( + id=uuid.uuid4(), + user_id=test_user.id, + name="Test Brand", + aliases=["TestBrand", "TB"], + website="https://testbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def async_client(async_session, test_user): + """创建异步HTTP客户端用于API测试""" + + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return test_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +@pytest.fixture +def auth_headers(test_user): + """创建认证请求头""" + token = create_access_token(data={"sub": str(test_user.id)}) + return {"Authorization": f"Bearer {token}"} + + +# ==================== 测试类 ==================== + +class TestBrandsAPI: + """品牌API测试""" + + @pytest.mark.asyncio + async def test_list_brands_empty(self, async_client): + """测试获取空品牌列表""" + response = await async_client.get("/api/v1/brands/") + + assert response.status_code == 200 + data = response.json() + assert data["items"] == [] + assert data["total"] == 0 + + @pytest.mark.asyncio + async def test_create_brand(self, async_client, async_session, test_user): + """测试创建品牌""" + brand_data = { + "name": "华为", + "aliases": ["Huawei", "HW"], + "website": "https://www.huawei.com", + "industry": "technology", + "platforms": ["wenxin", "kimi", "doubao"], + "frequency": "daily", + } + response = await async_client.post("/api/v1/brands/", json=brand_data) + + assert response.status_code == 201 + data = response.json() + assert data["name"] == "华为" + assert data["aliases"] == ["Huawei", "HW"] + assert data["website"] == "https://www.huawei.com" + assert data["industry"] == "technology" + assert data["platforms"] == ["wenxin", "kimi", "doubao"] + assert data["frequency"] == "daily" + assert data["status"] == "active" + assert data["user_id"] == str(test_user.id) + assert "id" in data + + @pytest.mark.asyncio + async def test_create_brand_minimal(self, async_client): + """测试创建品牌(最小数据)""" + brand_data = { + "name": "minimal_brand", + } + response = await async_client.post("/api/v1/brands/", json=brand_data) + + assert response.status_code == 201 + data = response.json() + assert data["name"] == "minimal_brand" + assert data["aliases"] == [] + assert data["platforms"] == ["wenxin", "kimi"] # 默认值 + + @pytest.mark.asyncio + async def test_list_brands(self, async_client, async_session, test_user): + """测试列出多个品牌""" + # 创建多个品牌 + for i in range(3): + brand = Brand( + user_id=test_user.id, + name=f"Brand {i}", + platforms=["wenxin"], + ) + async_session.add(brand) + await async_session.commit() + + response = await async_client.get("/api/v1/brands/") + + assert response.status_code == 200 + data = response.json() + assert len(data["items"]) == 3 + assert data["total"] == 3 + + @pytest.mark.asyncio + async def test_get_brand_by_id(self, async_client, test_brand): + """测试通过ID获取品牌""" + response = await async_client.get(f"/api/v1/brands/{test_brand.id}/") + + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(test_brand.id) + assert data["name"] == "Test Brand" + assert data["aliases"] == ["TestBrand", "TB"] + + @pytest.mark.asyncio + async def test_get_brand_not_found(self, async_client): + """测试获取不存在的品牌""" + non_existent_id = uuid.uuid4() + response = await async_client.get(f"/api/v1/brands/{non_existent_id}/") + + assert response.status_code == 404 + assert "品牌不存在" in response.json()["detail"] + + @pytest.mark.asyncio + async def test_update_brand(self, async_client, test_brand): + """测试更新品牌""" + # 注意:BrandUpdate schema 不允许更新 name 字段 + update_data = { + "aliases": ["Updated", "Alias"], + "frequency": "daily", + } + response = await async_client.put( + f"/api/v1/brands/{test_brand.id}/", json=update_data + ) + + assert response.status_code == 200 + data = response.json() + assert data["aliases"] == ["Updated", "Alias"] + assert data["frequency"] == "daily" + assert data["name"] == "Test Brand" # name 保持不变 + + @pytest.mark.asyncio + async def test_update_brand_partial(self, async_client, test_brand): + """测试部分更新品牌""" + update_data = { + "frequency": "monthly", + } + response = await async_client.put( + f"/api/v1/brands/{test_brand.id}/", json=update_data + ) + + assert response.status_code == 200 + data = response.json() + # 只更新frequency,其他字段保持不变 + assert data["frequency"] == "monthly" + assert data["name"] == "Test Brand" + assert data["aliases"] == ["TestBrand", "TB"] + + @pytest.mark.asyncio + async def test_update_brand_not_found(self, async_client): + """测试更新不存在的品牌""" + non_existent_id = uuid.uuid4() + response = await async_client.put( + f"/api/v1/brands/{non_existent_id}/", json={"name": "New Name"} + ) + + assert response.status_code == 404 + + @pytest.mark.asyncio + async def test_delete_brand(self, async_client, test_brand): + """测试删除品牌""" + response = await async_client.delete(f"/api/v1/brands/{test_brand.id}/") + + assert response.status_code == 204 + + # 验证品牌已删除 + get_response = await async_client.get(f"/api/v1/brands/{test_brand.id}/") + assert get_response.status_code == 404 + + @pytest.mark.asyncio + async def test_delete_brand_not_found(self, async_client): + """测试删除不存在的品牌""" + non_existent_id = uuid.uuid4() + response = await async_client.delete(f"/api/v1/brands/{non_existent_id}/") + + assert response.status_code == 404 + + @pytest.mark.asyncio + async def test_brand_unauthorized_access(self, async_session): + """测试未授权访问品牌API(无效token)""" + # 使用无效的token访问品牌API + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + # 使用无效的token + headers = {"Authorization": "Bearer invalid_token"} + + # 尝试访问品牌列表 + response = await client.get( + "/api/v1/brands/", + headers=headers + ) + # 无效token应该返回401 + assert response.status_code == 401 + + app.dependency_overrides.clear() + + +class TestBrandsValidation: + """品牌API验证测试""" + + @pytest.mark.asyncio + async def test_create_brand_name_too_short(self, async_client): + """测试品牌名称过短""" + brand_data = { + "name": "A", # 最小长度为2 + } + response = await async_client.post("/api/v1/brands/", json=brand_data) + + assert response.status_code == 422 # 验证错误 + + @pytest.mark.asyncio + async def test_create_brand_invalid_frequency(self, async_client): + """测试无效的更新频率""" + brand_data = { + "name": "Valid Brand", + "frequency": "invalid_frequency", + } + response = await async_client.post("/api/v1/brands/", json=brand_data) + + # frequency字段在BrandCreate中没有严格验证,但应该是有效值 + # 这里主要测试API不会崩溃 + assert response.status_code in [201, 422] diff --git a/backend/tests/test_api/test_content_api.py b/backend/tests/test_api/test_content_api.py new file mode 100644 index 0000000..76dc40c --- /dev/null +++ b/backend/tests/test_api/test_content_api.py @@ -0,0 +1,459 @@ +"""内容API测试""" +import uuid + +import pytest +import pytest_asyncio +from httpx import AsyncClient, ASGITransport +from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.database import Base +from app.main import app +from app.models.user import User +from app.models.brand import Brand +from app.api.deps import get_current_user, get_db +from app.services.auth import hash_password, create_access_token + + +# ==================== Fixtures ==================== + +@pytest_asyncio.fixture +async def async_engine(): + """创建测试用SQLite异步引擎""" + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + """创建测试用异步数据库会话""" + async_session_maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with async_session_maker() as session: + yield session + + +@pytest_asyncio.fixture +async def test_user(async_session): + """创建测试用户""" + user = User( + id=uuid.uuid4(), + email="test@example.com", + password_hash=hash_password("Test@123456"), + name="Test User", + plan="free", + max_queries=5, + is_active=True, + email_verified=True, + organization_id=uuid.uuid4(), # 需要organization_id用于内容API + ) + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def test_brand(async_session, test_user): + """创建测试品牌""" + brand = Brand( + id=uuid.uuid4(), + user_id=test_user.id, + name="Test Brand", + aliases=["TestBrand", "TB"], + website="https://testbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def async_client(async_session, test_user): + """创建异步HTTP客户端用于API测试""" + + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return test_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +@pytest.fixture +def auth_headers(test_user): + """创建认证请求头""" + token = create_access_token(data={"sub": str(test_user.id)}) + return {"Authorization": f"Bearer {token}"} + + +# ==================== 测试类 ==================== + +class TestContentAPI: + """内容管理API测试""" + + @pytest.mark.asyncio + async def test_list_contents_empty(self, async_client): + """测试获取空内容列表""" + response = await async_client.get("/api/v1/contents/") + + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + assert len(data) == 0 + + @pytest.mark.asyncio + async def test_create_content(self, async_client, test_user): + """测试创建内容""" + content_data = { + "title": "测试文章标题", + "body": "这是测试文章的内容", + "content_type": "article", + "tags": ["测试", "文章"], + } + response = await async_client.post("/api/v1/contents/", json=content_data) + + assert response.status_code == 201 + data = response.json() + assert data["title"] == "测试文章标题" + assert data["body"] == "这是测试文章的内容" + assert data["content_type"] == "article" + assert data["status"] == "draft" + assert "id" in data + + @pytest.mark.asyncio + async def test_get_content_by_id(self, async_client, test_user): + """测试通过ID获取内容""" + # 先创建内容 + create_response = await async_client.post( + "/api/v1/contents/", + json={ + "title": "测试内容", + "body": "测试内容正文", + "content_type": "article", + } + ) + content_id = create_response.json()["id"] + + # 获取内容 + response = await async_client.get(f"/api/v1/contents/{content_id}") + + assert response.status_code == 200 + data = response.json() + assert data["id"] == content_id + assert data["title"] == "测试内容" + + @pytest.mark.asyncio + async def test_get_content_not_found(self, async_client): + """测试获取不存在的内容""" + non_existent_id = uuid.uuid4() + response = await async_client.get(f"/api/v1/contents/{non_existent_id}") + + assert response.status_code == 404 + + @pytest.mark.asyncio + async def test_update_content(self, async_client, test_user): + """测试更新内容""" + # 先创建内容 + create_response = await async_client.post( + "/api/v1/contents/", + json={ + "title": "原始标题", + "body": "原始内容", + "content_type": "article", + } + ) + content_id = create_response.json()["id"] + + # 更新内容 + update_data = { + "title": "更新后的标题", + "body": "更新后的内容", + "status": "published", + } + response = await async_client.put( + f"/api/v1/contents/{content_id}", json=update_data + ) + + assert response.status_code == 200 + data = response.json() + assert data["title"] == "更新后的标题" + assert data["body"] == "更新后的内容" + + @pytest.mark.asyncio + async def test_update_content_partial(self, async_client, test_user): + """测试部分更新内容""" + # 先创建内容 + create_response = await async_client.post( + "/api/v1/contents/", + json={ + "title": "原始标题", + "body": "原始内容", + } + ) + content_id = create_response.json()["id"] + + # 只更新标题 + response = await async_client.put( + f"/api/v1/contents/{content_id}", + json={"title": "新标题"} + ) + + assert response.status_code == 200 + data = response.json() + assert data["title"] == "新标题" + assert data["body"] == "原始内容" # 保持不变 + + @pytest.mark.asyncio + async def test_delete_content(self, async_client, test_user): + """测试删除内容""" + # 先创建内容 + create_response = await async_client.post( + "/api/v1/contents/", + json={ + "title": "待删除内容", + "body": "内容正文", + } + ) + content_id = create_response.json()["id"] + + # 删除内容 + response = await async_client.delete(f"/api/v1/contents/{content_id}") + assert response.status_code == 204 + + # 验证已删除 + get_response = await async_client.get(f"/api/v1/contents/{content_id}") + assert get_response.status_code == 404 + + @pytest.mark.asyncio + async def test_publish_content(self, async_client, test_user): + """测试发布内容""" + # 先创建内容 + create_response = await async_client.post( + "/api/v1/contents/", + json={ + "title": "待发布内容", + "body": "内容正文", + } + ) + content_id = create_response.json()["id"] + + # 发布内容 + response = await async_client.post(f"/api/v1/contents/{content_id}/publish") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "published" + assert data["published_at"] is not None + + @pytest.mark.asyncio + async def test_list_contents_with_filter(self, async_client, test_user): + """测试内容列表过滤""" + # 创建多个不同类型的内容 + for i, content_type in enumerate(["article", "post", "article"]): + await async_client.post( + "/api/v1/contents/", + json={ + "title": f"内容 {i}", + "body": "正文", + "content_type": content_type, + } + ) + + # 按类型过滤 + response = await async_client.get("/api/v1/contents/?content_type=article") + + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + for item in data: + assert item["content_type"] == "article" + + @pytest.mark.asyncio + async def test_list_contents_with_status_filter(self, async_client, test_user): + """测试内容列表按状态过滤""" + # 创建一个草稿和一个已发布 + draft_response = await async_client.post( + "/api/v1/contents/", + json={"title": "草稿", "body": "正文", "content_type": "article"} + ) + + published_response = await async_client.post( + "/api/v1/contents/", + json={"title": "已发布", "body": "正文", "content_type": "article"} + ) + await async_client.post(f"/api/v1/contents/{published_response.json()['id']}/publish") + + # 只获取草稿 + response = await async_client.get("/api/v1/contents/?status=draft") + + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + for item in data: + assert item["status"] == "draft" + + +class TestContentGenerationAPI: + """内容生成API测试""" + + @pytest.mark.asyncio + async def test_list_topics(self, async_client): + """测试获取母题库列表""" + response = await async_client.get("/api/v1/content/topics") + + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + # 应该有多个母题 + assert len(data) > 0 + + # 验证母题结构 + for topic in data: + assert "id" in topic + assert "name" in topic + assert "description" in topic + + @pytest.mark.asyncio + async def test_get_topic_detail(self, async_client): + """测试获取母题详情""" + # 获取第一个可用的topic id + topics_response = await async_client.get("/api/v1/content/topics") + topics = topics_response.json() + + if len(topics) > 0: + topic_id = topics[0]["id"] + response = await async_client.get(f"/api/v1/content/topics/{topic_id}") + + assert response.status_code == 200 + data = response.json() + assert data["id"] == topic_id + assert "prompt_template" in data + + @pytest.mark.asyncio + async def test_get_topic_not_found(self, async_client): + """测试获取不存在的母题""" + response = await async_client.get("/api/v1/content/topics/nonexistent_topic") + + assert response.status_code == 404 + + +class TestBrandKnowledgeAPI: + """品牌知识库API测试""" + + @pytest.mark.asyncio + async def test_list_brand_knowledge_empty(self, async_client): + """测试获取空品牌知识库""" + response = await async_client.get("/api/v1/contents/knowledge/") + + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + assert len(data) == 0 + + @pytest.mark.asyncio + async def test_create_brand_knowledge(self, async_client, test_user): + """测试创建品牌知识条目""" + knowledge_data = { + "category": "product", + "title": "产品介绍", + "body": "这是产品的详细介绍", + "source": "官网", + } + response = await async_client.post( + "/api/v1/contents/knowledge/", json=knowledge_data + ) + + assert response.status_code == 201 + data = response.json() + assert data["title"] == "产品介绍" + assert data["category"] == "product" + assert data["body"] == "这是产品的详细介绍" + assert data["source"] == "官网" + assert "id" in data + + @pytest.mark.asyncio + async def test_list_brand_knowledge_with_category(self, async_client, test_user): + """测试按分类获取品牌知识""" + # 创建不同分类的知识 + categories = ["product", "technology", "product"] + for i, cat in enumerate(categories): + await async_client.post( + "/api/v1/contents/knowledge/", + json={"category": cat, "title": f"知识 {i}", "body": "正文"} + ) + + # 按product分类过滤 + response = await async_client.get("/api/v1/contents/knowledge/?category=product") + + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + for item in data: + assert item["category"] == "product" + + @pytest.mark.asyncio + async def test_update_brand_knowledge(self, async_client, test_user): + """测试更新品牌知识""" + # 先创建知识 + create_response = await async_client.post( + "/api/v1/contents/knowledge/", + json={"category": "test", "title": "原始标题", "body": "原始内容"} + ) + knowledge_id = create_response.json()["id"] + + # 更新 + response = await async_client.put( + f"/api/v1/contents/knowledge/{knowledge_id}", + json={"title": "新标题", "body": "新内容"} + ) + + assert response.status_code == 200 + data = response.json() + assert data["title"] == "新标题" + assert data["body"] == "新内容" + + @pytest.mark.asyncio + async def test_delete_brand_knowledge(self, async_client, test_user): + """测试删除品牌知识""" + # 先创建知识 + create_response = await async_client.post( + "/api/v1/contents/knowledge/", + json={"category": "test", "title": "待删除", "body": "正文"} + ) + knowledge_id = create_response.json()["id"] + + # 删除 + response = await async_client.delete(f"/api/v1/contents/knowledge/{knowledge_id}") + + assert response.status_code == 204 + + # 验证删除 - 通过列出知识来确认 + list_response = await async_client.get("/api/v1/contents/knowledge/") + knowledge_ids = [item["id"] for item in list_response.json()] + assert knowledge_id not in knowledge_ids diff --git a/backend/tests/test_api/test_diagnosis_api.py b/backend/tests/test_api/test_diagnosis_api.py new file mode 100644 index 0000000..1876909 --- /dev/null +++ b/backend/tests/test_api/test_diagnosis_api.py @@ -0,0 +1,221 @@ +"""诊断API测试""" +import uuid + +import pytest +import pytest_asyncio +from httpx import AsyncClient, ASGITransport +from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.database import Base +from app.main import app +from app.models.user import User +from app.models.brand import Brand +from app.api.deps import get_current_user, get_db +from app.services.auth import hash_password + + +@pytest_asyncio.fixture +async def async_engine(): + """创建测试用SQLite异步引擎""" + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + """创建测试用异步数据库会话""" + async_session_maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with async_session_maker() as session: + yield session + + +@pytest_asyncio.fixture +async def test_user(async_session): + """创建测试用户""" + user = User( + id=uuid.uuid4(), + email="test@example.com", + password_hash=hash_password("Test@123456"), + name="Test User", + plan="free", + max_queries=5, + is_active=True, + email_verified=True, + ) + async_session.add(user) + await async_session.commit() + await async_session.refresh(user) + return user + + +@pytest_asyncio.fixture +async def test_brand(async_session, test_user): + """创建测试品牌""" + brand = Brand( + id=uuid.uuid4(), + user_id=test_user.id, + name="Test Brand", + aliases=["TestBrand", "TB"], + website="https://testbrand.com", + industry="technology", + platforms=["wenxin", "kimi"], + frequency="weekly", + status="active", + ) + async_session.add(brand) + await async_session.commit() + await async_session.refresh(brand) + return brand + + +@pytest_asyncio.fixture +async def async_client(async_session, test_user): + """创建异步HTTP客户端用于API测试""" + + async def override_get_db(): + yield async_session + + async def override_get_current_user(): + return test_user + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_get_current_user + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +class TestDiagnosisAPI: + """诊断API测试""" + + @pytest.mark.asyncio + async def test_seo_diagnosis_success(self, async_client, test_brand): + """测试SEO诊断端点成功返回""" + response = await async_client.get(f"/api/v1/diagnosis/seo/{test_brand.id}") + + assert response.status_code == 200 + data = response.json() + assert "overall_score" in data + assert "health_level" in data + assert "dimensions" in data + assert "recommendations" in data + assert isinstance(data["overall_score"], (int, float)) + assert isinstance(data["dimensions"], list) + assert isinstance(data["recommendations"], list) + + @pytest.mark.asyncio + async def test_geo_diagnosis_success(self, async_client, test_brand): + """测试GEO诊断端点成功返回""" + response = await async_client.get(f"/api/v1/diagnosis/geo/{test_brand.id}") + + assert response.status_code == 200 + data = response.json() + assert "overall_score" in data + assert "health_level" in data + assert "dimensions" in data + assert "recommendations" in data + assert isinstance(data["overall_score"], (int, float)) + assert isinstance(data["dimensions"], list) + assert isinstance(data["recommendations"], list) + + @pytest.mark.asyncio + async def test_combined_diagnosis_success(self, async_client, test_brand): + """测试综合诊断端点成功返回""" + response = await async_client.get(f"/api/v1/diagnosis/combined/{test_brand.id}") + + assert response.status_code == 200 + data = response.json() + assert "seo_score" in data + assert "geo_score" in data + assert "combined_score" in data + assert "seo_diagnosis" in data + assert "geo_diagnosis" in data + assert isinstance(data["seo_score"], (int, float)) + assert isinstance(data["geo_score"], (int, float)) + assert isinstance(data["combined_score"], (int, float)) + + @pytest.mark.asyncio + async def test_diagnosis_brand_not_found(self, async_client): + """测试品牌不存在时返回404""" + non_existent_id = uuid.uuid4() + + seo_response = await async_client.get(f"/api/v1/diagnosis/seo/{non_existent_id}") + assert seo_response.status_code == 404 + + geo_response = await async_client.get(f"/api/v1/diagnosis/geo/{non_existent_id}") + assert geo_response.status_code == 404 + + combined_response = await async_client.get(f"/api/v1/diagnosis/combined/{non_existent_id}") + assert combined_response.status_code == 404 + + @pytest.mark.asyncio + async def test_diagnosis_unauthorized_access(self, async_session): + """测试未认证时返回401""" + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": "Bearer invalid_token"} + + seo_response = await client.get(f"/api/v1/diagnosis/seo/{uuid.uuid4()}", headers=headers) + assert seo_response.status_code == 401 + + geo_response = await client.get(f"/api/v1/diagnosis/geo/{uuid.uuid4()}", headers=headers) + assert geo_response.status_code == 401 + + combined_response = await client.get(f"/api/v1/diagnosis/combined/{uuid.uuid4()}", headers=headers) + assert combined_response.status_code == 401 + + app.dependency_overrides.clear() + + @pytest.mark.asyncio + async def test_diagnosis_result_format(self, async_client, test_brand): + """测试诊断结果格式正确""" + response = await async_client.get(f"/api/v1/diagnosis/seo/{test_brand.id}") + + assert response.status_code == 200 + data = response.json() + + assert 0 <= data["overall_score"] <= 100 + assert data["health_level"] in ["excellent", "good", "pass", "danger"] + + for dimension in data["dimensions"]: + assert "name" in dimension + assert "score" in dimension + assert "max_score" in dimension + assert "percentage" in dimension + assert "status" in dimension + assert "items" in dimension + assert isinstance(dimension["items"], list) + + for item in dimension["items"]: + assert "name" in item + assert "status" in item + assert "description" in item + assert "suggestion" in item + assert "score" in item + + for rec in data["recommendations"]: + assert "priority" in rec + assert "dimension" in rec + assert "description" in rec diff --git a/backend/tests/test_api/test_health_api.py b/backend/tests/test_api/test_health_api.py new file mode 100644 index 0000000..16e31ad --- /dev/null +++ b/backend/tests/test_api/test_health_api.py @@ -0,0 +1,204 @@ +"""健康检查API测试""" +import pytest +import pytest_asyncio +from httpx import AsyncClient, ASGITransport +from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine +from sqlalchemy.pool import StaticPool + +from app.database import Base +from app.main import app +from app.api.deps import get_db + + +# ==================== Fixtures ==================== + +@pytest_asyncio.fixture +async def async_engine(): + """创建测试用SQLite异步引擎""" + engine = create_async_engine( + "sqlite+aiosqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def async_session(async_engine): + """创建测试用异步数据库会话""" + async_session_maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + async with async_session_maker() as session: + yield session + + +@pytest_asyncio.fixture +async def async_client(async_session): + """创建异步HTTP客户端用于API测试""" + + async def override_get_db(): + yield async_session + + app.dependency_overrides[get_db] = override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + yield client + + app.dependency_overrides.clear() + + +# ==================== 测试类 ==================== + +class TestHealthAPI: + """健康检查API测试""" + + @pytest.mark.asyncio + async def test_basic_health(self): + """测试基本健康检查""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/health") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "timestamp" in data + + @pytest.mark.asyncio + async def test_liveness(self): + """测试存活探针""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/health/liveness") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "alive" + + @pytest.mark.asyncio + async def test_ready_endpoint(self): + """测试就绪端点""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/ready") + + # 返回200或503取决于依赖服务状态 + assert response.status_code in [200, 503] + data = response.json() + assert "status" in data + assert "checks" in data + assert "database" in data["checks"] + assert "redis" in data["checks"] + + @pytest.mark.asyncio + async def test_metrics(self): + """测试Prometheus指标端点""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/metrics") + + assert response.status_code == 200 + # Prometheus指标返回文本格式 + assert "text/plain" in response.headers["content-type"] or \ + "text/plain" in str(response.headers.get("content-type", "")) + + @pytest.mark.asyncio + async def test_detailed_health(self, async_client): + """测试详细健康检查""" + response = await async_client.get("/health/detailed") + + assert response.status_code == 200 + data = response.json() + # 检查返回结构 + assert "checks" in data + assert "app" in data + + @pytest.mark.asyncio + async def test_readiness_probe(self, async_client): + """测试就绪探针""" + response = await async_client.get("/health/readiness") + + # 健康状态应该是200,unhealthy是503 + assert response.status_code in [200, 503] + data = response.json() + assert "status" in data + assert "checks" in data + + +class TestHealthEndpointsStructure: + """健康检查端点结构测试""" + + @pytest.mark.asyncio + async def test_health_endpoint_returns_json(self): + """测试健康端点返回JSON格式""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/health") + + assert response.status_code == 200 + assert "application/json" in response.headers.get("content-type", "") + + @pytest.mark.asyncio + async def test_detailed_health_has_required_fields(self, async_client): + """测试详细健康检查包含必需字段""" + response = await async_client.get("/health/detailed") + + assert response.status_code == 200 + data = response.json() + + # 检查app信息 + if "app" in data: + app_info = data["app"] + assert isinstance(app_info, dict) + + @pytest.mark.asyncio + async def test_ready_checks_database_and_redis(self, async_client): + """测试就绪检查包含数据库和Redis检查""" + response = await async_client.get("/ready") + + data = response.json() + checks = data.get("checks", {}) + + assert "database" in checks + assert "redis" in checks + + +class TestHealthCheckIndependence: + """健康检查独立性测试""" + + @pytest.mark.asyncio + async def test_health_no_auth_required(self): + """测试健康检查不需要认证""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + # 不带token访问 + response = await client.get("/health") + + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_liveness_no_auth_required(self): + """测试存活探针不需要认证""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/health/liveness") + + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_metrics_no_auth_required(self): + """测试指标端点不需要认证""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/metrics") + + assert response.status_code == 200 diff --git a/backend/tests/test_content_pipeline/__init__.py b/backend/tests/test_content_pipeline/__init__.py new file mode 100644 index 0000000..7385d53 --- /dev/null +++ b/backend/tests/test_content_pipeline/__init__.py @@ -0,0 +1 @@ +"""内容生成Pipeline测试包""" diff --git a/backend/tests/test_content_pipeline/test_content_pipeline.py b/backend/tests/test_content_pipeline/test_content_pipeline.py new file mode 100644 index 0000000..720d31f --- /dev/null +++ b/backend/tests/test_content_pipeline/test_content_pipeline.py @@ -0,0 +1,207 @@ +"""内容生成Pipeline测试""" +import pytest +from app.services.content.content_pipeline import ContentPipeline, PipelineResponse + + +class TestContentPipeline: + """内容Pipeline测试""" + + @pytest.mark.asyncio + async def test_pipeline_complete_run(self): + """测试完整Pipeline执行""" + pipeline = ContentPipeline() + request = { + "content": "这是一篇关于华为手机评测的深度文章," + "详细介绍了华为Mate系列的特点和性能。", + "title": "华为手机评测", + "platform": "zhihu", + "optimize_for": ["validation", "sensitive", "seo"], + "output_formats": ["html", "markdown", "plain"], + "keyword": "华为手机" + } + + result = await pipeline.run(request) + + # 验证返回类型 + assert isinstance(result, PipelineResponse) + assert result.stages is not None + assert len(result.stages) > 0 + + # 验证输出 + assert result.outputs is not None + assert result.outputs.html is not None or result.outputs.markdown is not None + + @pytest.mark.asyncio + async def test_pipeline_validation_only(self): + """测试仅校验模式""" + pipeline = ContentPipeline() + request = { + "content": "这是一篇测试文章内容", + "title": "测试标题", + "platform": "zhihu", + "optimize_for": ["validation"] + } + + result = await pipeline.run(request) + + assert result.stages is not None + # 仅校验模式不生成输出 + validation_stage = next((s for s in result.stages if s.name == "validation"), None) + assert validation_stage is not None + + @pytest.mark.asyncio + async def test_pipeline_sensitive_filter_only(self): + """测试仅敏感词过滤模式""" + pipeline = ContentPipeline() + request = { + "content": "这是一篇正常内容的文章", + "title": "测试标题", + "platform": "zhihu", + "optimize_for": ["sensitive"] + } + + result = await pipeline.run(request) + + assert result.stages is not None + assert len(result.stages) >= 1 + + @pytest.mark.asyncio + async def test_pipeline_seo_only(self): + """测试仅SEO优化模式""" + pipeline = ContentPipeline() + request = { + "content": "华为手机是知名的手机品牌", + "title": "华为手机评测", + "platform": "zhihu", + "optimize_for": ["seo"], + "keyword": "华为手机" + } + + result = await pipeline.run(request) + + assert result.stages is not None + + @pytest.mark.asyncio + async def test_pipeline_with_validation_fail(self): + """测试校验失败场景""" + pipeline = ContentPipeline() + request = { + "content": "内容", + "title": "这个标题太长了超过了三十个字符的限制了哈哈哈啊", + "platform": "wechat", + "optimize_for": ["validation"] + } + + result = await pipeline.run(request) + + # 校验失败时不应继续执行后续阶段 + validation_stage = next((s for s in result.stages if s.name == "validation"), None) + assert validation_stage is not None + assert validation_stage.passed is False + + @pytest.mark.asyncio + async def test_pipeline_stage_timing(self): + """测试各阶段耗时记录""" + pipeline = ContentPipeline() + request = { + "content": "这是一篇测试文章内容", + "title": "测试标题", + "platform": "zhihu", + "optimize_for": ["validation", "sensitive", "seo"] + } + + result = await pipeline.run(request) + + for stage in result.stages: + assert stage.name is not None + assert hasattr(stage, 'duration') + assert stage.duration >= 0 + + @pytest.mark.asyncio + async def test_pipeline_multi_platform(self): + """测试多平台适配""" + pipeline = ContentPipeline() + + platforms = ["zhihu", "wechat", "xiaohongshu"] + results = [] + + for platform in platforms: + result = await pipeline.run({ + "content": "

测试内容

", + "title": "测试标题", + "platform": platform, + "optimize_for": ["validation", "sensitive"] + }) + results.append(result) + assert result.stages is not None + + # 各平台结果都应成功 + assert len(results) == len(platforms) + + @pytest.mark.asyncio + async def test_pipeline_output_formats(self): + """测试不同输出格式""" + pipeline = ContentPipeline() + request = { + "content": "测试内容", + "title": "测试标题", + "platform": "zhihu", + "optimize_for": [], + "output_formats": ["html", "markdown", "plain"] + } + + result = await pipeline.run(request) + + assert result.outputs is not None + # HTML生成是默认的 + assert result.outputs.html is not None + + @pytest.mark.asyncio + async def test_pipeline_empty_optimize_for(self): + """测试空的optimize_for参数""" + pipeline = ContentPipeline() + request = { + "content": "测试内容", + "title": "测试标题", + "platform": "zhihu", + "optimize_for": [], + "output_formats": ["html"] + } + + result = await pipeline.run(request) + + # 无优化阶段但应有HTML生成阶段 + assert result.outputs is not None + + @pytest.mark.asyncio + async def test_pipeline_error_handling(self): + """测试无效平台错误处理""" + pipeline = ContentPipeline() + + try: + result = await pipeline.run({ + "content": "内容", + "title": "标题", + "platform": "invalid_platform" + }) + # 应该返回错误 + assert result.error is not None or len(result.stages) > 0 + except ValueError as e: + assert "不支持的平台" in str(e) + + @pytest.mark.asyncio + async def test_pipeline_validate_only_method(self): + """测试validate_only方法""" + pipeline = ContentPipeline() + + result = await pipeline.validate_only( + content="测试内容", + title="测试标题", + platform="zhihu" + ) + + assert result is not None + assert hasattr(result, 'is_valid') + assert hasattr(result, 'score') + assert hasattr(result, 'issues') + assert hasattr(result, 'passed') diff --git a/backend/tests/test_content_pipeline/test_html_generator.py b/backend/tests/test_content_pipeline/test_html_generator.py new file mode 100644 index 0000000..67ba848 --- /dev/null +++ b/backend/tests/test_content_pipeline/test_html_generator.py @@ -0,0 +1,265 @@ +"""HTML生成器测试""" +import pytest +from app.services.content.html_generator import HTMLGenerator + + +class TestHTMLGenerator: + """HTML生成器测试""" + + def test_generate_basic_html(self): + """基础HTML生成""" + generator = HTMLGenerator() + html = generator.generate( + content="

这是测试内容

", + platform="zhihu" + ) + + assert html is not None + assert isinstance(html, str) + + def test_generate_for_zhihu(self): + """知乎平台HTML生成""" + generator = HTMLGenerator() + html = generator.generate( + content="

华为手机评测

这是一篇关于华为手机的详细评测文章。

", + platform="zhihu" + ) + + assert html is not None + assert "华为手机评测" in html + + def test_generate_for_wechat(self): + """微信公众号HTML生成""" + generator = HTMLGenerator() + html = generator.generate( + content="

华为手机非常好用

", + platform="wechat" + ) + + assert html is not None + + def test_generate_for_xiaohongshu(self): + """小红书HTML生成""" + generator = HTMLGenerator() + html = generator.generate( + content="

种草笔记内容

", + platform="xiaohongshu" + ) + + assert html is not None + + def test_filter_banned_tags(self): + """禁用标签过滤""" + generator = HTMLGenerator() + html = generator.generate( + content="

正常内容

", + platform="zhihu" + ) + + # script标签应被移除 + assert "

内容

") + + assert "