fix: 消除所有Mock/Stub/假数据,确保业务流程使用真实数据
M1-引用检测核心: - 删除llm_adapter._get_mock_result()方法 - ENABLE_LLM=False时抛出LLMAdapterError而非返回随机数据 - ENABLE_LLM默认值改为True - 修复旧测试适配新行为 M2-知识库RAG: - knowledge.py不再默认使用MockEmbedder - 动态从APIKeyManager获取OpenAI Key - 无Key时返回503+明确错误信息 - 有Key时使用OpenAIEmbedder M3-AI引擎页面: - 删除MOCK_AI_ENGINES_RESPONSE fallback - 查询失败时显示错误状态 M4-组织管理页面: - 删除MOCK_ORG_INFO和MOCK_MEMBERS - API返回空时显示空状态 M5-首页Agent卡片: - 删除MOCK_AGENTS硬编码 - 替换为功能开发中占位 M6-平台规则历史: - 实现PlatformRuleVersion模型 - 实现版本对比API (diff) - 实现历史记录查询API (history) - 删除2个TODO注释 M7-知识图谱批量构建: - 实现批量创建实体API - 空输入验证+批量大小限制 - 删除TODO注释 - 修复路由双重前缀问题 测试: 643 passed (核心)
This commit is contained in:
parent
4cc8f73bb4
commit
fe4ba39514
|
|
@ -34,16 +34,27 @@ from app.schemas.knowledge import (
|
||||||
SearchResultItem,
|
SearchResultItem,
|
||||||
UpdateDocumentRequest,
|
UpdateDocumentRequest,
|
||||||
)
|
)
|
||||||
from app.services.knowledge import MockEmbedder, RAGService
|
from app.services.knowledge import RAGService
|
||||||
|
from app.services.knowledge.embedder import OpenAIEmbedder
|
||||||
from app.services.knowledge.enhanced_rag import EnhancedRAG
|
from app.services.knowledge.enhanced_rag import EnhancedRAG
|
||||||
from app.services.knowledge.incremental_index import IncrementalIndexService
|
from app.services.knowledge.incremental_index import IncrementalIndexService
|
||||||
from app.services.knowledge.chunker import ChunkerFactory
|
from app.services.knowledge.chunker import ChunkerFactory
|
||||||
|
from app.services.api_key_manager import APIKeyManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
# Shared RAG service instance (MockEmbedder by default; swap in OpenAIEmbedder via DI later)
|
_key_manager = APIKeyManager()
|
||||||
_rag_service = RAGService(embedder=MockEmbedder())
|
|
||||||
|
|
||||||
|
def _get_rag_service() -> RAGService:
|
||||||
|
api_key = _key_manager.get_key("chatgpt")
|
||||||
|
if api_key:
|
||||||
|
return RAGService(embedder=OpenAIEmbedder(api_key=api_key))
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503,
|
||||||
|
detail="知识库功能需要配置OpenAI API Key。请在设置页面添加OpenAI API Key。",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -280,8 +291,11 @@ async def upload_document(
|
||||||
|
|
||||||
# Asynchronously ingest (same request; background task optimization later)
|
# Asynchronously ingest (same request; background task optimization later)
|
||||||
try:
|
try:
|
||||||
await _rag_service.ingest_document(db, str(doc.id))
|
rag_service = _get_rag_service()
|
||||||
|
await rag_service.ingest_document(db, str(doc.id))
|
||||||
await db.refresh(doc)
|
await db.refresh(doc)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(f"Ingest failed for document {doc.id}: {exc}")
|
logger.error(f"Ingest failed for document {doc.id}: {exc}")
|
||||||
# Status already set to 'failed' by ingest_document on exception
|
# Status already set to 'failed' by ingest_document on exception
|
||||||
|
|
@ -359,7 +373,7 @@ async def delete_document(
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")
|
||||||
|
|
||||||
# Delete chunks first (cascade also handles this, but explicit for clarity)
|
# Delete chunks first (cascade also handles this, but explicit for clarity)
|
||||||
await _rag_service.delete_document_chunks(db, str(doc.id))
|
await _get_rag_service().delete_document_chunks(db, str(doc.id))
|
||||||
|
|
||||||
await db.delete(doc)
|
await db.delete(doc)
|
||||||
|
|
||||||
|
|
@ -466,7 +480,7 @@ async def knowledge_search(
|
||||||
|
|
||||||
t0 = time.monotonic()
|
t0 = time.monotonic()
|
||||||
|
|
||||||
raw_results = await _rag_service.search(
|
raw_results = await _get_rag_service().search(
|
||||||
db,
|
db,
|
||||||
query=body.query,
|
query=body.query,
|
||||||
knowledge_base_ids=body.knowledge_base_ids,
|
knowledge_base_ids=body.knowledge_base_ids,
|
||||||
|
|
@ -559,7 +573,7 @@ async def reindex_document(
|
||||||
|
|
||||||
await _get_kb(db, kb_id, org_id)
|
await _get_kb(db, kb_id, org_id)
|
||||||
|
|
||||||
index_service = IncrementalIndexService(_rag_service)
|
index_service = IncrementalIndexService(_get_rag_service())
|
||||||
result = await index_service.add_document(
|
result = await index_service.add_document(
|
||||||
db, str(kb_id), str(doc_id)
|
db, str(kb_id), str(doc_id)
|
||||||
)
|
)
|
||||||
|
|
@ -581,7 +595,7 @@ async def update_document_content(
|
||||||
|
|
||||||
await _get_kb(db, kb_id, org_id)
|
await _get_kb(db, kb_id, org_id)
|
||||||
|
|
||||||
index_service = IncrementalIndexService(_rag_service)
|
index_service = IncrementalIndexService(_get_rag_service())
|
||||||
result = await index_service.update_document(
|
result = await index_service.update_document(
|
||||||
db, str(doc_id), request.content
|
db, str(doc_id), request.content
|
||||||
)
|
)
|
||||||
|
|
@ -602,7 +616,7 @@ async def delete_document_incremental(
|
||||||
|
|
||||||
await _get_kb(db, kb_id, org_id)
|
await _get_kb(db, kb_id, org_id)
|
||||||
|
|
||||||
index_service = IncrementalIndexService(_rag_service)
|
index_service = IncrementalIndexService(_get_rag_service())
|
||||||
result = await index_service.delete_document(db, str(doc_id))
|
result = await index_service.delete_document(db, str(doc_id))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -621,7 +635,7 @@ async def rebuild_knowledge_base(
|
||||||
|
|
||||||
await _get_kb(db, kb_id, org_id)
|
await _get_kb(db, kb_id, org_id)
|
||||||
|
|
||||||
index_service = IncrementalIndexService(_rag_service)
|
index_service = IncrementalIndexService(_get_rag_service())
|
||||||
result = await index_service.rebuild_knowledge_base(
|
result = await index_service.rebuild_knowledge_base(
|
||||||
db, str(kb_id), force
|
db, str(kb_id), force
|
||||||
)
|
)
|
||||||
|
|
@ -642,7 +656,8 @@ async def enhanced_retrieve(
|
||||||
|
|
||||||
await _get_kb(db, kb_id, org_id)
|
await _get_kb(db, kb_id, org_id)
|
||||||
|
|
||||||
enhanced_rag = EnhancedRAG(_rag_service, _rag_service.embedder)
|
rag_service = _get_rag_service()
|
||||||
|
enhanced_rag = EnhancedRAG(rag_service, rag_service.embedder)
|
||||||
results = await enhanced_rag.retrieve_with_rerank(
|
results = await enhanced_rag.retrieve_with_rerank(
|
||||||
db,
|
db,
|
||||||
request.query,
|
request.query,
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,12 @@ from typing import Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_db, get_current_user
|
from app.api.deps import get_db, get_current_user
|
||||||
|
from app.models.knowledge import KnowledgeBase
|
||||||
|
from app.models.knowledge_graph import KnowledgeEntity, EntityType
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.services.knowledge.graph_builder import GraphBuilder
|
from app.services.knowledge.graph_builder import GraphBuilder
|
||||||
from app.services.knowledge.graph_query import GraphQuery
|
from app.services.knowledge.graph_query import GraphQuery
|
||||||
|
|
@ -13,6 +16,61 @@ from app.services.knowledge.graph_query import GraphQuery
|
||||||
router = APIRouter(prefix="/knowledge-bases", tags=["知识图谱"])
|
router = APIRouter(prefix="/knowledge-bases", tags=["知识图谱"])
|
||||||
|
|
||||||
|
|
||||||
|
class EntityCreateRequest(BaseModel):
|
||||||
|
name: str = Field(..., max_length=500)
|
||||||
|
entity_type: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
properties: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _entity_to_dict(entity: KnowledgeEntity) -> dict:
|
||||||
|
return {
|
||||||
|
"id": str(entity.id),
|
||||||
|
"name": entity.name,
|
||||||
|
"entity_type": entity.entity_type.value,
|
||||||
|
"description": entity.description,
|
||||||
|
"properties": entity.properties,
|
||||||
|
"confidence": entity.confidence,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{kb_id}/entities/batch")
|
||||||
|
async def batch_create_entities(
|
||||||
|
kb_id: UUID,
|
||||||
|
entities: list[EntityCreateRequest],
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""批量创建知识图谱实体"""
|
||||||
|
if not entities:
|
||||||
|
raise HTTPException(status_code=400, detail="实体列表不能为空")
|
||||||
|
|
||||||
|
if len(entities) > 100:
|
||||||
|
raise HTTPException(status_code=400, detail="单次批量创建不能超过100个实体")
|
||||||
|
|
||||||
|
kb = await db.get(KnowledgeBase, kb_id)
|
||||||
|
if not kb:
|
||||||
|
raise HTTPException(status_code=404, detail="知识库不存在")
|
||||||
|
|
||||||
|
created = []
|
||||||
|
for entity_req in entities:
|
||||||
|
entity = KnowledgeEntity(
|
||||||
|
knowledge_base_id=kb_id,
|
||||||
|
name=entity_req.name,
|
||||||
|
entity_type=EntityType(entity_req.entity_type),
|
||||||
|
description=entity_req.description,
|
||||||
|
properties=entity_req.properties or {},
|
||||||
|
)
|
||||||
|
db.add(entity)
|
||||||
|
created.append(entity)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
for entity in created:
|
||||||
|
await db.refresh(entity)
|
||||||
|
|
||||||
|
return {"created_count": len(created), "entities": [_entity_to_dict(e) for e in created]}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{kb_id}/graph/build")
|
@router.post("/{kb_id}/graph/build")
|
||||||
async def build_graph(
|
async def build_graph(
|
||||||
kb_id: UUID,
|
kb_id: UUID,
|
||||||
|
|
@ -24,8 +82,6 @@ async def build_graph(
|
||||||
|
|
||||||
对知识库中的所有Chunks执行实体和关系抽取
|
对知识库中的所有Chunks执行实体和关系抽取
|
||||||
"""
|
"""
|
||||||
# TODO: 实现批量构建
|
|
||||||
# 目前先实现单个Chunk的构建
|
|
||||||
return {"message": "Use /graph/build-chunk to build from specific chunk"}
|
return {"message": "Use /graph/build-chunk to build from specific chunk"}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,16 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
|
from app.database import get_db
|
||||||
|
from app.models.platform_rule_version import PlatformRuleVersion
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
from app.services.distribution.platform_rules import (
|
from app.services.distribution.platform_rules import (
|
||||||
PLATFORM_RULES,
|
PLATFORM_RULES,
|
||||||
|
|
@ -48,6 +52,46 @@ logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/v1/platforms", tags=["平台规则管理"])
|
router = APIRouter(prefix="/api/v1/platforms", tags=["平台规则管理"])
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_rule_version(
|
||||||
|
db: AsyncSession, rule_id: str, version: int
|
||||||
|
) -> PlatformRuleVersion | None:
|
||||||
|
stmt = select(PlatformRuleVersion).where(
|
||||||
|
PlatformRuleVersion.rule_id == rule_id,
|
||||||
|
PlatformRuleVersion.version == version,
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_diff(
|
||||||
|
old_data: dict, new_data: dict, prefix: str = ""
|
||||||
|
) -> list[RuleDiff]:
|
||||||
|
diffs: list[RuleDiff] = []
|
||||||
|
all_keys = set(old_data.keys()) | set(new_data.keys())
|
||||||
|
for key in sorted(all_keys):
|
||||||
|
field = f"{prefix}{key}" if not prefix else f"{prefix}.{key}"
|
||||||
|
old_val = old_data.get(key)
|
||||||
|
new_val = new_data.get(key)
|
||||||
|
if isinstance(old_val, dict) and isinstance(new_val, dict):
|
||||||
|
diffs.extend(_compute_diff(old_val, new_val, field))
|
||||||
|
elif old_val != new_val:
|
||||||
|
diffs.append(RuleDiff(field=field, old_value=old_val, new_value=new_val))
|
||||||
|
return diffs
|
||||||
|
|
||||||
|
|
||||||
|
def _version_to_dict(v: PlatformRuleVersion) -> dict:
|
||||||
|
return {
|
||||||
|
"id": v.id,
|
||||||
|
"rule_id": v.rule_id,
|
||||||
|
"platform": v.platform,
|
||||||
|
"version": v.version,
|
||||||
|
"rule_data": v.rule_data,
|
||||||
|
"change_summary": v.change_summary,
|
||||||
|
"created_by": v.created_by,
|
||||||
|
"created_at": v.created_at.isoformat() if v.created_at else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _convert_rule_to_schema(rules: dict) -> dict:
|
def _convert_rule_to_schema(rules: dict) -> dict:
|
||||||
"""将规则字典转换为 Schema 格式"""
|
"""将规则字典转换为 Schema 格式"""
|
||||||
if not rules:
|
if not rules:
|
||||||
|
|
@ -179,13 +223,16 @@ async def update_platform_rules(
|
||||||
@router.get("/{platform_id}/rules/diff", response_model=RuleDiffResponse)
|
@router.get("/{platform_id}/rules/diff", response_model=RuleDiffResponse)
|
||||||
async def compare_rule_changes(
|
async def compare_rule_changes(
|
||||||
platform_id: str,
|
platform_id: str,
|
||||||
change_id: Optional[int] = Query(None, description="变更记录ID,用于对比历史版本"),
|
from_version: int = Query(..., description="起始版本号"),
|
||||||
|
to_version: int = Query(..., description="目标版本号"),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""对比规则变更
|
"""对比规则变更
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
platform_id: 平台标识
|
platform_id: 平台标识
|
||||||
change_id: 变更记录ID(可选)
|
from_version: 起始版本号
|
||||||
|
to_version: 目标版本号
|
||||||
"""
|
"""
|
||||||
if platform_id not in PLATFORM_RULES:
|
if platform_id not in PLATFORM_RULES:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -195,13 +242,21 @@ async def compare_rule_changes(
|
||||||
|
|
||||||
current_rules = PLATFORM_RULES[platform_id]
|
current_rules = PLATFORM_RULES[platform_id]
|
||||||
|
|
||||||
# TODO: 从数据库获取历史版本进行对比
|
from_rule = await _get_rule_version(db, platform_id, from_version)
|
||||||
# 目前返回空差异
|
to_rule = await _get_rule_version(db, platform_id, to_version)
|
||||||
|
|
||||||
|
if not from_rule or not to_rule:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="版本不存在",
|
||||||
|
)
|
||||||
|
|
||||||
|
diffs = _compute_diff(from_rule.rule_data, to_rule.rule_data)
|
||||||
return RuleDiffResponse(
|
return RuleDiffResponse(
|
||||||
platform_id=platform_id,
|
platform_id=platform_id,
|
||||||
platform_name=current_rules.get("name", ""),
|
platform_name=current_rules.get("name", ""),
|
||||||
diffs=[],
|
diffs=diffs,
|
||||||
total_changes=0,
|
total_changes=len(diffs),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -209,6 +264,7 @@ async def compare_rule_changes(
|
||||||
async def get_rule_history(
|
async def get_rule_history(
|
||||||
platform_id: str,
|
platform_id: str,
|
||||||
limit: int = Query(20, ge=1, le=100, description="返回记录数"),
|
limit: int = Query(20, ge=1, le=100, description="返回记录数"),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""获取规则变更历史
|
"""获取规则变更历史
|
||||||
|
|
||||||
|
|
@ -222,11 +278,39 @@ async def get_rule_history(
|
||||||
detail=f"平台不存在: {platform_id}",
|
detail=f"平台不存在: {platform_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: 从数据库获取历史记录
|
count_stmt = select(func.count()).select_from(PlatformRuleVersion).where(
|
||||||
# 目前返回空列表
|
PlatformRuleVersion.rule_id == platform_id
|
||||||
|
)
|
||||||
|
total = (await db.execute(count_stmt)).scalar() or 0
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(PlatformRuleVersion)
|
||||||
|
.where(PlatformRuleVersion.rule_id == platform_id)
|
||||||
|
.order_by(PlatformRuleVersion.version.desc())
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
result = await db.execute(stmt)
|
||||||
|
versions = result.scalars().all()
|
||||||
|
|
||||||
|
history = [
|
||||||
|
RuleChangeHistory(
|
||||||
|
id=v.version,
|
||||||
|
version=v.version,
|
||||||
|
platform_id=v.rule_id,
|
||||||
|
platform_name=v.platform,
|
||||||
|
changed_by=v.created_by or "",
|
||||||
|
change_summary=v.change_summary or "",
|
||||||
|
change_type="update",
|
||||||
|
previous_rules=None,
|
||||||
|
new_rules=v.rule_data,
|
||||||
|
created_at=v.created_at,
|
||||||
|
)
|
||||||
|
for v in versions
|
||||||
|
]
|
||||||
|
|
||||||
return RuleChangeHistoryResponse(
|
return RuleChangeHistoryResponse(
|
||||||
history=[],
|
history=history,
|
||||||
total=0,
|
total=total,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ class Settings(BaseSettings):
|
||||||
SECRET_KEY: Optional[str] = None
|
SECRET_KEY: Optional[str] = None
|
||||||
|
|
||||||
PLAYWRIGHT_BROWSERS_PATH: str = "/ms-playwright"
|
PLAYWRIGHT_BROWSERS_PATH: str = "/ms-playwright"
|
||||||
ENABLE_LLM: bool = False
|
ENABLE_LLM: bool = True
|
||||||
ZHIPU_API_KEY: str = ""
|
ZHIPU_API_KEY: str = ""
|
||||||
TONGYI_API_KEY: str = ""
|
TONGYI_API_KEY: str = ""
|
||||||
CORS_ORIGINS: str = "http://localhost:3000,http://localhost:3001"
|
CORS_ORIGINS: str = "http://localhost:3000,http://localhost:3001"
|
||||||
|
|
|
||||||
|
|
@ -168,7 +168,7 @@ app.include_router(onboarding_router, prefix="/api/v1")
|
||||||
app.include_router(platforms_router, prefix="/api/v1")
|
app.include_router(platforms_router, prefix="/api/v1")
|
||||||
app.include_router(platform_rules_router)
|
app.include_router(platform_rules_router)
|
||||||
app.include_router(image_router, prefix="/api/v1")
|
app.include_router(image_router, prefix="/api/v1")
|
||||||
app.include_router(knowledge_graph_router, prefix="/api/v1/knowledge-bases")
|
app.include_router(knowledge_graph_router, prefix="/api/v1")
|
||||||
app.include_router(ai_engines_router, prefix="/api/v1/ai-engines", tags=["AI引擎查询"])
|
app.include_router(ai_engines_router, prefix="/api/v1/ai-engines", tags=["AI引擎查询"])
|
||||||
app.include_router(detection_router, prefix="/api/v1/detection", tags=["定时检测任务"])
|
app.include_router(detection_router, prefix="/api/v1/detection", tags=["定时检测任务"])
|
||||||
app.include_router(api_keys_router, prefix="/api/v1/api-keys", tags=["API Key管理"])
|
app.include_router(api_keys_router, prefix="/api/v1/api-keys", tags=["API Key管理"])
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ from app.models.lifecycle import LifecycleProject, ProjectStage
|
||||||
from app.models.agent import AgentRegistry, AgentConfig, AgentTask, AgentTaskLog
|
from app.models.agent import AgentRegistry, AgentConfig, AgentTask, AgentTaskLog
|
||||||
from app.models.content import Content, ContentVersion, ContentReview
|
from app.models.content import Content, ContentVersion, ContentReview
|
||||||
from app.models.platform_rule import PlatformRule
|
from app.models.platform_rule import PlatformRule
|
||||||
|
from app.models.platform_rule_version import PlatformRuleVersion
|
||||||
from app.models.brand_knowledge import BrandKnowledge, Keyword
|
from app.models.brand_knowledge import BrandKnowledge, Keyword
|
||||||
from app.models.knowledge import (
|
from app.models.knowledge import (
|
||||||
KnowledgeBase,
|
KnowledgeBase,
|
||||||
|
|
@ -52,6 +53,7 @@ __all__ = [
|
||||||
"ContentVersion",
|
"ContentVersion",
|
||||||
"ContentReview",
|
"ContentReview",
|
||||||
"PlatformRule",
|
"PlatformRule",
|
||||||
|
"PlatformRuleVersion",
|
||||||
"BrandKnowledge",
|
"BrandKnowledge",
|
||||||
"Keyword",
|
"Keyword",
|
||||||
"KnowledgeBase",
|
"KnowledgeBase",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import String, Integer, Index, func
|
||||||
|
from sqlalchemy import Uuid
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from app.database import Base, JSONType
|
||||||
|
|
||||||
|
|
||||||
|
class PlatformRuleVersion(Base):
|
||||||
|
__tablename__ = "platform_rule_versions"
|
||||||
|
|
||||||
|
id: Mapped[uuid.UUID] = mapped_column(
|
||||||
|
Uuid(as_uuid=True),
|
||||||
|
primary_key=True,
|
||||||
|
default=uuid.uuid4,
|
||||||
|
)
|
||||||
|
rule_id: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
|
||||||
|
platform: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
|
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
rule_data: Mapped[dict] = mapped_column(JSONType, nullable=False)
|
||||||
|
change_summary: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
created_by: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(server_default=func.now())
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_rule_versions_rule_id", "rule_id"),
|
||||||
|
Index("idx_rule_versions_platform", "platform"),
|
||||||
|
)
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""平台规则管理 Schema - 定义规则管理的请求响应结构"""
|
"""平台规则管理 Schema - 定义规则管理的请求响应结构"""
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -225,6 +225,7 @@ class PlatformRuleUpdateResponse(BaseModel):
|
||||||
class RuleChangeHistory(BaseModel):
|
class RuleChangeHistory(BaseModel):
|
||||||
"""规则变更历史"""
|
"""规则变更历史"""
|
||||||
id: int
|
id: int
|
||||||
|
version: int = 0
|
||||||
platform_id: str
|
platform_id: str
|
||||||
platform_name: str
|
platform_name: str
|
||||||
changed_by: str
|
changed_by: str
|
||||||
|
|
@ -305,3 +306,7 @@ class DeAIContentResponse(BaseModel):
|
||||||
processed_word_count: int
|
processed_word_count: int
|
||||||
detected_ai_patterns: list[str] = []
|
detected_ai_patterns: list[str] = []
|
||||||
replaced_patterns: dict[str, str] = {}
|
replaced_patterns: dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
|
RuleDiff.model_rebuild()
|
||||||
|
RuleDiffResponse.model_rebuild()
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ LLM适配器 - 使用DeepSeek LLM API检测品牌引用
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import random
|
|
||||||
import re
|
import re
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
@ -104,8 +103,14 @@ class LLMAdapter:
|
||||||
LLMAdapterError: API调用或解析失败
|
LLMAdapterError: API调用或解析失败
|
||||||
"""
|
"""
|
||||||
if not settings.ENABLE_LLM:
|
if not settings.ENABLE_LLM:
|
||||||
logger.info("LLM调用已禁用 (ENABLE_LLM=False),返回模拟数据")
|
raise LLMAdapterError(
|
||||||
return self._get_mock_result(keyword, brand_name, brand_aliases)
|
"LLM引用检测未启用。请在环境变量中设置 ENABLE_LLM=True 并配置 DEEPSEEK_API_KEY"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.api_key:
|
||||||
|
raise LLMAdapterError(
|
||||||
|
"未配置DeepSeek API Key。请设置 DEEPSEEK_API_KEY 环境变量"
|
||||||
|
)
|
||||||
|
|
||||||
prompt = self._build_prompt(keyword, brand_name, brand_aliases)
|
prompt = self._build_prompt(keyword, brand_name, brand_aliases)
|
||||||
|
|
||||||
|
|
@ -123,36 +128,6 @@ class LLMAdapter:
|
||||||
|
|
||||||
raise LLMAdapterError(f"LLM API调用失败,已重试{self.max_retries}次: {last_error}")
|
raise LLMAdapterError(f"LLM API调用失败,已重试{self.max_retries}次: {last_error}")
|
||||||
|
|
||||||
def _get_mock_result(
|
|
||||||
self,
|
|
||||||
keyword: str,
|
|
||||||
brand_name: str,
|
|
||||||
brand_aliases: list[str]
|
|
||||||
) -> CitationResult:
|
|
||||||
"""
|
|
||||||
生成模拟结果(当LLM禁用时使用)
|
|
||||||
|
|
||||||
随机决定是否引用,模拟真实场景的数据分布
|
|
||||||
"""
|
|
||||||
cited = random.random() < 0.6
|
|
||||||
sentiment_options = ["positive", "neutral", "negative"]
|
|
||||||
sentiment = random.choice(sentiment_options)
|
|
||||||
|
|
||||||
if cited:
|
|
||||||
position = random.randint(1, 10)
|
|
||||||
citation_text = f'模拟引用:在搜索"{keyword}"时,提到了{brand_name}品牌及其相关产品。'
|
|
||||||
else:
|
|
||||||
position = None
|
|
||||||
citation_text = ""
|
|
||||||
|
|
||||||
return CitationResult(
|
|
||||||
cited=cited,
|
|
||||||
position=position,
|
|
||||||
citation_text=citation_text,
|
|
||||||
sentiment=sentiment,
|
|
||||||
confidence=round(random.uniform(0.7, 0.99), 2)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _call_deepseek(self, prompt: str) -> dict:
|
async def _call_deepseek(self, prompt: str) -> dict:
|
||||||
"""
|
"""
|
||||||
调用DeepSeek API
|
调用DeepSeek API
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,278 @@
|
||||||
|
"""知识图谱批量构建API测试"""
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
from app.main import app
|
||||||
|
from app.api.deps import get_db, get_current_user
|
||||||
|
from app.models.user import User
|
||||||
|
from app.models.knowledge import KnowledgeBase
|
||||||
|
from app.models.organization import Organization
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def async_engine():
|
||||||
|
engine = create_async_engine(
|
||||||
|
"sqlite+aiosqlite:///:memory:",
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool,
|
||||||
|
)
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
yield engine
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def async_session(async_engine):
|
||||||
|
async_session_maker = async_sessionmaker(
|
||||||
|
async_engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
autoflush=False,
|
||||||
|
autocommit=False,
|
||||||
|
)
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_user(async_session):
|
||||||
|
user = User(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
email="test@example.com",
|
||||||
|
password_hash="hashed_password",
|
||||||
|
name="Test User",
|
||||||
|
plan="free",
|
||||||
|
max_queries=5,
|
||||||
|
is_active=True,
|
||||||
|
email_verified=True,
|
||||||
|
)
|
||||||
|
async_session.add(user)
|
||||||
|
await async_session.commit()
|
||||||
|
await async_session.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_org(async_session):
|
||||||
|
org = Organization(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
name="Test Org",
|
||||||
|
slug="test-org",
|
||||||
|
)
|
||||||
|
async_session.add(org)
|
||||||
|
await async_session.commit()
|
||||||
|
await async_session.refresh(org)
|
||||||
|
return org
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_kb(async_session, test_org):
|
||||||
|
kb = KnowledgeBase(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
organization_id=test_org.id,
|
||||||
|
name="Test KB",
|
||||||
|
type="industry",
|
||||||
|
description="Test knowledge base",
|
||||||
|
)
|
||||||
|
async_session.add(kb)
|
||||||
|
await async_session.commit()
|
||||||
|
await async_session.refresh(kb)
|
||||||
|
return kb
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def async_client(async_session, test_user):
|
||||||
|
async def override_get_db():
|
||||||
|
yield async_session
|
||||||
|
|
||||||
|
async def override_get_current_user():
|
||||||
|
return test_user
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
|
app.dependency_overrides[get_current_user] = override_get_current_user
|
||||||
|
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
BATCH_URL = "/api/v1/knowledge-bases/{kb_id}/entities/batch"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchCreateEntitiesEmptyInput:
|
||||||
|
"""空输入验证测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_entities_list_returns_400(self, async_client, test_kb):
|
||||||
|
response = await async_client.post(
|
||||||
|
BATCH_URL.format(kb_id=test_kb.id),
|
||||||
|
json=[],
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
data = response.json()
|
||||||
|
assert "detail" in data
|
||||||
|
assert "不能为空" in data["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchCreateEntitiesSizeLimit:
|
||||||
|
"""批量大小限制测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_over_100_entities_returns_400(self, async_client, test_kb):
|
||||||
|
entities = [
|
||||||
|
{
|
||||||
|
"name": f"Entity {i}",
|
||||||
|
"entity_type": "CONCEPT",
|
||||||
|
"description": f"Test entity {i}",
|
||||||
|
}
|
||||||
|
for i in range(101)
|
||||||
|
]
|
||||||
|
response = await async_client.post(
|
||||||
|
BATCH_URL.format(kb_id=test_kb.id),
|
||||||
|
json=entities,
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
data = response.json()
|
||||||
|
assert "detail" in data
|
||||||
|
assert "100" in data["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchCreateEntitiesKBNotFound:
|
||||||
|
"""知识库不存在测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_nonexistent_kb_returns_404(self, async_client):
|
||||||
|
fake_kb_id = str(uuid.uuid4())
|
||||||
|
entities = [
|
||||||
|
{
|
||||||
|
"name": "Entity 1",
|
||||||
|
"entity_type": "CONCEPT",
|
||||||
|
"description": "Test",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
response = await async_client.post(
|
||||||
|
BATCH_URL.format(kb_id=fake_kb_id),
|
||||||
|
json=entities,
|
||||||
|
)
|
||||||
|
assert response.status_code == 404
|
||||||
|
data = response.json()
|
||||||
|
assert "detail" in data
|
||||||
|
assert "不存在" in data["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchCreateEntitiesSuccess:
|
||||||
|
"""批量创建成功测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_create_entities_success(self, async_client, test_kb):
|
||||||
|
entities = [
|
||||||
|
{
|
||||||
|
"name": "公司A",
|
||||||
|
"entity_type": "ORGANIZATION",
|
||||||
|
"description": "测试公司A",
|
||||||
|
"properties": {"industry": "科技"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "产品B",
|
||||||
|
"entity_type": "PRODUCT",
|
||||||
|
"description": "测试产品B",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
response = await async_client.post(
|
||||||
|
BATCH_URL.format(kb_id=test_kb.id),
|
||||||
|
json=entities,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["created_count"] == 2
|
||||||
|
assert len(data["entities"]) == 2
|
||||||
|
|
||||||
|
for entity_data in data["entities"]:
|
||||||
|
assert "id" in entity_data
|
||||||
|
assert "name" in entity_data
|
||||||
|
assert "entity_type" in entity_data
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_create_single_entity(self, async_client, test_kb):
|
||||||
|
entities = [
|
||||||
|
{
|
||||||
|
"name": "单个实体",
|
||||||
|
"entity_type": "PERSON",
|
||||||
|
"description": "测试单个实体",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
response = await async_client.post(
|
||||||
|
BATCH_URL.format(kb_id=test_kb.id),
|
||||||
|
json=entities,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["created_count"] == 1
|
||||||
|
assert len(data["entities"]) == 1
|
||||||
|
assert data["entities"][0]["name"] == "单个实体"
|
||||||
|
assert data["entities"][0]["entity_type"] == "PERSON"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_create_with_properties(self, async_client, test_kb):
|
||||||
|
entities = [
|
||||||
|
{
|
||||||
|
"name": "带属性实体",
|
||||||
|
"entity_type": "TECHNOLOGY",
|
||||||
|
"description": "测试带属性",
|
||||||
|
"properties": {"version": "1.0", "category": "AI"},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
response = await async_client.post(
|
||||||
|
BATCH_URL.format(kb_id=test_kb.id),
|
||||||
|
json=entities,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["entities"][0]["properties"]["version"] == "1.0"
|
||||||
|
assert data["entities"][0]["properties"]["category"] == "AI"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_create_without_properties_defaults_to_empty(
|
||||||
|
self, async_client, test_kb
|
||||||
|
):
|
||||||
|
entities = [
|
||||||
|
{
|
||||||
|
"name": "无属性实体",
|
||||||
|
"entity_type": "BRAND",
|
||||||
|
"description": "测试无属性",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
response = await async_client.post(
|
||||||
|
BATCH_URL.format(kb_id=test_kb.id),
|
||||||
|
json=entities,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["entities"][0]["properties"] == {}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_create_exactly_100_entities(self, async_client, test_kb):
|
||||||
|
entities = [
|
||||||
|
{
|
||||||
|
"name": f"Entity {i}",
|
||||||
|
"entity_type": "CONCEPT",
|
||||||
|
"description": f"Test entity {i}",
|
||||||
|
}
|
||||||
|
for i in range(100)
|
||||||
|
]
|
||||||
|
response = await async_client.post(
|
||||||
|
BATCH_URL.format(kb_id=test_kb.id),
|
||||||
|
json=entities,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["created_count"] == 100
|
||||||
|
|
@ -0,0 +1,95 @@
|
||||||
|
"""
|
||||||
|
测试Knowledge API不再默认使用MockEmbedder
|
||||||
|
- 无OpenAI Key时API返回503+明确错误信息
|
||||||
|
- 有OpenAI Key时使用OpenAIEmbedder
|
||||||
|
- MockEmbedder不再作为默认选择
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from app.services.knowledge.embedder import MockEmbedder, OpenAIEmbedder
|
||||||
|
from app.services.knowledge.rag_service import RAGService
|
||||||
|
from app.services.api_key_manager import APIKeyManager
|
||||||
|
|
||||||
|
|
||||||
|
class TestKnowledgeAPINoMockEmbedder:
|
||||||
|
"""验证knowledge.py不再默认使用MockEmbedder"""
|
||||||
|
|
||||||
|
def test_get_rag_service_raises_without_openai_key(self):
|
||||||
|
"""无OpenAI Key时_get_rag_service必须抛出HTTPException"""
|
||||||
|
from app.api.knowledge import _get_rag_service
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
key_manager = APIKeyManager()
|
||||||
|
with patch("app.api.knowledge._key_manager", key_manager):
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
_get_rag_service()
|
||||||
|
assert exc_info.value.status_code == 503
|
||||||
|
assert "OpenAI API Key" in exc_info.value.detail
|
||||||
|
|
||||||
|
def test_get_rag_service_returns_openai_embedder_with_key(self):
|
||||||
|
"""有OpenAI Key时_get_rag_service必须返回使用OpenAIEmbedder的RAGService"""
|
||||||
|
from app.api.knowledge import _get_rag_service
|
||||||
|
|
||||||
|
key_manager = APIKeyManager()
|
||||||
|
key_manager.add_key("chatgpt", "sk-test-key-1234567890", source="system")
|
||||||
|
|
||||||
|
with patch("app.api.knowledge._key_manager", key_manager):
|
||||||
|
rag_service = _get_rag_service()
|
||||||
|
assert isinstance(rag_service, RAGService)
|
||||||
|
assert isinstance(rag_service.embedder, OpenAIEmbedder)
|
||||||
|
|
||||||
|
def test_get_rag_service_never_returns_mock_embedder(self):
|
||||||
|
"""_get_rag_service绝不能返回MockEmbedder"""
|
||||||
|
from app.api.knowledge import _get_rag_service
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
key_manager = APIKeyManager()
|
||||||
|
with patch("app.api.knowledge._key_manager", key_manager):
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
_get_rag_service()
|
||||||
|
|
||||||
|
key_manager_with_key = APIKeyManager()
|
||||||
|
key_manager_with_key.add_key("chatgpt", "sk-test-key-1234567890", source="system")
|
||||||
|
with patch("app.api.knowledge._key_manager", key_manager_with_key):
|
||||||
|
rag_service = _get_rag_service()
|
||||||
|
assert not isinstance(rag_service.embedder, MockEmbedder)
|
||||||
|
|
||||||
|
def test_no_module_level_mock_rag_service(self):
|
||||||
|
"""模块级别不再存在使用MockEmbedder的_rag_service变量"""
|
||||||
|
import app.api.knowledge as knowledge_module
|
||||||
|
|
||||||
|
assert not hasattr(knowledge_module, "_rag_service"), (
|
||||||
|
"_rag_service模块级变量仍然存在,必须删除"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_error_message_contains_configuration_guidance(self):
|
||||||
|
"""503错误信息必须包含配置指引"""
|
||||||
|
from app.api.knowledge import _get_rag_service
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
key_manager = APIKeyManager()
|
||||||
|
with patch("app.api.knowledge._key_manager", key_manager):
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
_get_rag_service()
|
||||||
|
detail = exc_info.value.detail
|
||||||
|
assert "OpenAI API Key" in detail
|
||||||
|
assert "设置" in detail or "配置" in detail
|
||||||
|
|
||||||
|
def test_mock_embedder_class_still_exists(self):
|
||||||
|
"""MockEmbedder类必须保留(仅用于测试)"""
|
||||||
|
assert MockEmbedder is not None
|
||||||
|
embedder = MockEmbedder()
|
||||||
|
assert isinstance(embedder, MockEmbedder)
|
||||||
|
|
||||||
|
def test_get_rag_service_uses_api_key_manager(self):
|
||||||
|
"""_get_rag_service必须使用APIKeyManager获取Key"""
|
||||||
|
from app.api.knowledge import _get_rag_service
|
||||||
|
|
||||||
|
mock_km = MagicMock(spec=APIKeyManager)
|
||||||
|
mock_km.get_key.return_value = "sk-test-key-1234567890"
|
||||||
|
|
||||||
|
with patch("app.api.knowledge._key_manager", mock_km):
|
||||||
|
rag_service = _get_rag_service()
|
||||||
|
mock_km.get_key.assert_called_once_with("chatgpt")
|
||||||
|
assert isinstance(rag_service.embedder, OpenAIEmbedder)
|
||||||
|
|
@ -0,0 +1,252 @@
|
||||||
|
"""平台规则历史版本API测试"""
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from httpx import AsyncClient, ASGITransport
|
||||||
|
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from app.database import Base
|
||||||
|
from app.main import app
|
||||||
|
from app.models.platform_rule_version import PlatformRuleVersion
|
||||||
|
from app.api.deps import get_db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def async_engine():
|
||||||
|
engine = create_async_engine(
|
||||||
|
"sqlite+aiosqlite:///:memory:",
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool,
|
||||||
|
)
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
yield engine
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def async_session(async_engine):
|
||||||
|
async_session_maker = async_sessionmaker(
|
||||||
|
async_engine,
|
||||||
|
class_=AsyncSession,
|
||||||
|
expire_on_commit=False,
|
||||||
|
autoflush=False,
|
||||||
|
autocommit=False,
|
||||||
|
)
|
||||||
|
async with async_session_maker() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def async_client(async_session):
|
||||||
|
async def override_get_db():
|
||||||
|
yield async_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_db] = override_get_db
|
||||||
|
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def seed_versions(async_session):
|
||||||
|
v1 = PlatformRuleVersion(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
rule_id="zhihu",
|
||||||
|
platform="zhihu",
|
||||||
|
version=1,
|
||||||
|
rule_data={
|
||||||
|
"content_length": {"min": 800, "max": 3000, "recommended": 1500},
|
||||||
|
"title_rules": {"min_length": 5, "max_length": 50},
|
||||||
|
},
|
||||||
|
change_summary="初始版本",
|
||||||
|
created_by="admin",
|
||||||
|
)
|
||||||
|
v2 = PlatformRuleVersion(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
rule_id="zhihu",
|
||||||
|
platform="zhihu",
|
||||||
|
version=2,
|
||||||
|
rule_data={
|
||||||
|
"content_length": {"min": 1000, "max": 5000, "recommended": 2000},
|
||||||
|
"title_rules": {"min_length": 5, "max_length": 50},
|
||||||
|
},
|
||||||
|
change_summary="调整内容长度规则",
|
||||||
|
created_by="admin",
|
||||||
|
)
|
||||||
|
v3 = PlatformRuleVersion(
|
||||||
|
id=uuid.uuid4(),
|
||||||
|
rule_id="zhihu",
|
||||||
|
platform="zhihu",
|
||||||
|
version=3,
|
||||||
|
rule_data={
|
||||||
|
"content_length": {"min": 1000, "max": 5000, "recommended": 2000},
|
||||||
|
"title_rules": {"min_length": 8, "max_length": 60},
|
||||||
|
},
|
||||||
|
change_summary="调整标题长度规则",
|
||||||
|
created_by="editor",
|
||||||
|
)
|
||||||
|
async_session.add_all([v1, v2, v3])
|
||||||
|
await async_session.commit()
|
||||||
|
return [v1, v2, v3]
|
||||||
|
|
||||||
|
|
||||||
|
class TestRuleVersionDiff:
|
||||||
|
"""历史版本对比API测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_diff_returns_differences_between_versions(
|
||||||
|
self, async_client, seed_versions
|
||||||
|
):
|
||||||
|
response = await async_client.get(
|
||||||
|
"/api/v1/platforms/zhihu/rules/diff",
|
||||||
|
params={"from_version": 1, "to_version": 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["platform_id"] == "zhihu"
|
||||||
|
assert isinstance(data["diffs"], list)
|
||||||
|
assert len(data["diffs"]) > 0
|
||||||
|
assert data["total_changes"] > 0
|
||||||
|
|
||||||
|
diff_fields = {d["field"] for d in data["diffs"]}
|
||||||
|
assert "content_length.min" in diff_fields
|
||||||
|
assert "content_length.max" in diff_fields
|
||||||
|
assert "content_length.recommended" in diff_fields
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_diff_no_changes_returns_empty(self, async_client, seed_versions):
|
||||||
|
response = await async_client.get(
|
||||||
|
"/api/v1/platforms/zhihu/rules/diff",
|
||||||
|
params={"from_version": 2, "to_version": 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["diffs"] == []
|
||||||
|
assert data["total_changes"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_diff_nonexistent_version_returns_404(
|
||||||
|
self, async_client, seed_versions
|
||||||
|
):
|
||||||
|
response = await async_client.get(
|
||||||
|
"/api/v1/platforms/zhihu/rules/diff",
|
||||||
|
params={"from_version": 1, "to_version": 99},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_diff_nonexistent_platform_returns_404(self, async_client):
|
||||||
|
response = await async_client.get(
|
||||||
|
"/api/v1/platforms/nonexistent_platform/rules/diff",
|
||||||
|
params={"from_version": 1, "to_version": 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_diff_missing_params_returns_error(self, async_client, seed_versions):
|
||||||
|
response = await async_client.get(
|
||||||
|
"/api/v1/platforms/zhihu/rules/diff",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_diff_nested_field_change(self, async_client, seed_versions):
|
||||||
|
response = await async_client.get(
|
||||||
|
"/api/v1/platforms/zhihu/rules/diff",
|
||||||
|
params={"from_version": 2, "to_version": 3},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
diff_fields = {d["field"] for d in data["diffs"]}
|
||||||
|
assert "title_rules.min_length" in diff_fields
|
||||||
|
assert "title_rules.max_length" in diff_fields
|
||||||
|
|
||||||
|
|
||||||
|
class TestRuleHistory:
|
||||||
|
"""历史记录查询API测试"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_history_returns_versions(
|
||||||
|
self, async_client, seed_versions
|
||||||
|
):
|
||||||
|
response = await async_client.get(
|
||||||
|
"/api/v1/platforms/zhihu/rules/history",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 3
|
||||||
|
assert len(data["history"]) == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_history_ordered_by_version_desc(
|
||||||
|
self, async_client, seed_versions
|
||||||
|
):
|
||||||
|
response = await async_client.get(
|
||||||
|
"/api/v1/platforms/zhihu/rules/history",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
versions = [h["version"] for h in data["history"]]
|
||||||
|
assert versions == sorted(versions, reverse=True)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_history_respects_limit(self, async_client, seed_versions):
|
||||||
|
response = await async_client.get(
|
||||||
|
"/api/v1/platforms/zhihu/rules/history",
|
||||||
|
params={"limit": 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert len(data["history"]) == 2
|
||||||
|
assert data["total"] == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_history_empty_when_no_versions(self, async_client):
|
||||||
|
response = await async_client.get(
|
||||||
|
"/api/v1/platforms/zhihu/rules/history",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["total"] == 0
|
||||||
|
assert data["history"] == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_history_nonexistent_platform_returns_404(self, async_client):
|
||||||
|
response = await async_client.get(
|
||||||
|
"/api/v1/platforms/nonexistent_platform/rules/history",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_history_version_has_required_fields(
|
||||||
|
self, async_client, seed_versions
|
||||||
|
):
|
||||||
|
response = await async_client.get(
|
||||||
|
"/api/v1/platforms/zhihu/rules/history",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
first = data["history"][0]
|
||||||
|
assert "version" in first
|
||||||
|
assert "new_rules" in first
|
||||||
|
assert "change_summary" in first
|
||||||
|
assert "changed_by" in first
|
||||||
|
assert "created_at" in first
|
||||||
|
|
@ -22,20 +22,22 @@ class TestLLMAdapter:
|
||||||
"confidence": 0.95
|
"confidence": 0.95
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||||
mock_call.return_value = mock_response
|
mock_settings.ENABLE_LLM = True
|
||||||
|
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||||
|
mock_call.return_value = mock_response
|
||||||
|
|
||||||
result = await llm_adapter.query_brand_citation(
|
result = await llm_adapter.query_brand_citation(
|
||||||
keyword="AI搜索",
|
keyword="AI搜索",
|
||||||
brand_name="XXX",
|
brand_name="XXX",
|
||||||
brand_aliases=["品牌别名1", "品牌别名2"]
|
brand_aliases=["品牌别名1", "品牌别名2"]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.cited is True
|
assert result.cited is True
|
||||||
assert result.position == 1
|
assert result.position == 1
|
||||||
assert result.citation_text == "XXX是一款非常优秀的品牌产品"
|
assert result.citation_text == "XXX是一款非常优秀的品牌产品"
|
||||||
assert result.sentiment == "positive"
|
assert result.sentiment == "positive"
|
||||||
assert result.confidence == 0.95
|
assert result.confidence == 0.95
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_llm_adapter_not_cited(self, llm_adapter):
|
async def test_llm_adapter_not_cited(self, llm_adapter):
|
||||||
|
|
@ -48,19 +50,21 @@ class TestLLMAdapter:
|
||||||
"confidence": 0.90
|
"confidence": 0.90
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||||
mock_call.return_value = mock_response
|
mock_settings.ENABLE_LLM = True
|
||||||
|
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||||
|
mock_call.return_value = mock_response
|
||||||
|
|
||||||
result = await llm_adapter.query_brand_citation(
|
result = await llm_adapter.query_brand_citation(
|
||||||
keyword="AI搜索",
|
keyword="AI搜索",
|
||||||
brand_name="YYY",
|
brand_name="YYY",
|
||||||
brand_aliases=[]
|
brand_aliases=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.cited is False
|
assert result.cited is False
|
||||||
assert result.position is None
|
assert result.position is None
|
||||||
assert result.citation_text is None
|
assert result.citation_text is None
|
||||||
assert result.sentiment == "neutral"
|
assert result.sentiment == "neutral"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_llm_adapter_sentiment_positive(self, llm_adapter):
|
async def test_llm_adapter_sentiment_positive(self, llm_adapter):
|
||||||
|
|
@ -73,16 +77,18 @@ class TestLLMAdapter:
|
||||||
"confidence": 0.92
|
"confidence": 0.92
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||||
mock_call.return_value = mock_response
|
mock_settings.ENABLE_LLM = True
|
||||||
|
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||||
|
mock_call.return_value = mock_response
|
||||||
|
|
||||||
result = await llm_adapter.query_brand_citation(
|
result = await llm_adapter.query_brand_citation(
|
||||||
keyword="AI搜索",
|
keyword="AI搜索",
|
||||||
brand_name="YYY",
|
brand_name="YYY",
|
||||||
brand_aliases=[]
|
brand_aliases=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.sentiment == "positive"
|
assert result.sentiment == "positive"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_llm_adapter_sentiment_negative(self, llm_adapter):
|
async def test_llm_adapter_sentiment_negative(self, llm_adapter):
|
||||||
|
|
@ -95,16 +101,18 @@ class TestLLMAdapter:
|
||||||
"confidence": 0.88
|
"confidence": 0.88
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||||
mock_call.return_value = mock_response
|
mock_settings.ENABLE_LLM = True
|
||||||
|
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||||
|
mock_call.return_value = mock_response
|
||||||
|
|
||||||
result = await llm_adapter.query_brand_citation(
|
result = await llm_adapter.query_brand_citation(
|
||||||
keyword="AI搜索",
|
keyword="AI搜索",
|
||||||
brand_name="ZZZ",
|
brand_name="ZZZ",
|
||||||
brand_aliases=[]
|
brand_aliases=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.sentiment == "negative"
|
assert result.sentiment == "negative"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_llm_adapter_api_error_retry(self, llm_adapter):
|
async def test_llm_adapter_api_error_retry(self, llm_adapter):
|
||||||
|
|
@ -117,39 +125,41 @@ class TestLLMAdapter:
|
||||||
"confidence": 0.90
|
"confidence": 0.90
|
||||||
}
|
}
|
||||||
|
|
||||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||||
# 模拟前两次失败,第三次成功
|
mock_settings.ENABLE_LLM = True
|
||||||
mock_call.side_effect = [
|
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||||
Exception("API调用失败"),
|
mock_call.side_effect = [
|
||||||
Exception("API调用失败"),
|
Exception("API调用失败"),
|
||||||
mock_success_response
|
Exception("API调用失败"),
|
||||||
]
|
mock_success_response
|
||||||
|
]
|
||||||
|
|
||||||
result = await llm_adapter.query_brand_citation(
|
result = await llm_adapter.query_brand_citation(
|
||||||
keyword="AI搜索",
|
|
||||||
brand_name="测试品牌",
|
|
||||||
brand_aliases=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.cited is True
|
|
||||||
assert mock_call.call_count == 3
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_llm_adapter_parse_error(self, llm_adapter):
|
|
||||||
"""测试响应解析错误"""
|
|
||||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
|
||||||
mock_call.return_value = {"invalid": "response"}
|
|
||||||
|
|
||||||
with pytest.raises(LLMAdapterError) as exc_info:
|
|
||||||
await llm_adapter.query_brand_citation(
|
|
||||||
keyword="AI搜索",
|
keyword="AI搜索",
|
||||||
brand_name="测试品牌",
|
brand_name="测试品牌",
|
||||||
brand_aliases=[]
|
brand_aliases=[]
|
||||||
)
|
)
|
||||||
|
|
||||||
# 错误消息应该包含字段缺失或解析失败相关提示
|
assert result.cited is True
|
||||||
error_msg = str(exc_info.value)
|
assert mock_call.call_count == 3
|
||||||
assert "响应缺少必需字段" in error_msg or "解析响应失败" in error_msg
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_llm_adapter_parse_error(self, llm_adapter):
|
||||||
|
"""测试响应解析错误"""
|
||||||
|
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||||
|
mock_settings.ENABLE_LLM = True
|
||||||
|
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||||
|
mock_call.return_value = {"invalid": "response"}
|
||||||
|
|
||||||
|
with pytest.raises(LLMAdapterError) as exc_info:
|
||||||
|
await llm_adapter.query_brand_citation(
|
||||||
|
keyword="AI搜索",
|
||||||
|
brand_name="测试品牌",
|
||||||
|
brand_aliases=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
error_msg = str(exc_info.value)
|
||||||
|
assert "响应缺少必需字段" in error_msg or "解析响应失败" in error_msg
|
||||||
|
|
||||||
def test_build_prompt(self, llm_adapter):
|
def test_build_prompt(self, llm_adapter):
|
||||||
"""测试Prompt构建"""
|
"""测试Prompt构建"""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,121 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, patch, PropertyMock
|
||||||
|
|
||||||
|
from app.workers.llm_adapter import LLMAdapter, LLMAdapterError
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMAdapterNoMock:
|
||||||
|
"""验证LLMAdapter不再返回Mock数据,而是抛出明确错误"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def adapter(self):
|
||||||
|
return LLMAdapter()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enable_llm_false_raises_error(self, adapter):
|
||||||
|
"""ENABLE_LLM=False时必须抛出LLMAdapterError,而非返回Mock数据"""
|
||||||
|
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||||
|
mock_settings.ENABLE_LLM = False
|
||||||
|
mock_settings.DEEPSEEK_API_KEY = "test-key"
|
||||||
|
|
||||||
|
with pytest.raises(LLMAdapterError) as exc_info:
|
||||||
|
await adapter.query_brand_citation(
|
||||||
|
keyword="AI搜索",
|
||||||
|
brand_name="测试品牌",
|
||||||
|
brand_aliases=["别名1"],
|
||||||
|
)
|
||||||
|
|
||||||
|
error_msg = str(exc_info.value)
|
||||||
|
assert "ENABLE_LLM" in error_msg
|
||||||
|
assert "未启用" in error_msg
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enable_llm_true_no_api_key_raises_error(self, adapter):
|
||||||
|
"""ENABLE_LLM=True但无API Key时必须抛出LLMAdapterError"""
|
||||||
|
adapter.api_key = None
|
||||||
|
|
||||||
|
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||||
|
mock_settings.ENABLE_LLM = True
|
||||||
|
|
||||||
|
with pytest.raises(LLMAdapterError) as exc_info:
|
||||||
|
await adapter.query_brand_citation(
|
||||||
|
keyword="AI搜索",
|
||||||
|
brand_name="测试品牌",
|
||||||
|
brand_aliases=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
error_msg = str(exc_info.value)
|
||||||
|
assert "API Key" in error_msg or "DEEPSEEK_API_KEY" in error_msg
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enable_llm_true_with_key_calls_api(self, adapter):
|
||||||
|
"""ENABLE_LLM=True且有Key时正常调用API"""
|
||||||
|
mock_response = {
|
||||||
|
"cited": True,
|
||||||
|
"position": 1,
|
||||||
|
"citation_text": "测试引用",
|
||||||
|
"sentiment": "positive",
|
||||||
|
"confidence": 0.95,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||||
|
mock_settings.ENABLE_LLM = True
|
||||||
|
mock_settings.OPENAI_API_KEY = None
|
||||||
|
mock_settings.DEEPSEEK_API_KEY = "sk-test-key"
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
adapter, "_call_deepseek", new_callable=AsyncMock
|
||||||
|
) as mock_call:
|
||||||
|
mock_call.return_value = mock_response
|
||||||
|
|
||||||
|
result = await adapter.query_brand_citation(
|
||||||
|
keyword="AI搜索",
|
||||||
|
brand_name="测试品牌",
|
||||||
|
brand_aliases=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.cited is True
|
||||||
|
assert result.position == 1
|
||||||
|
assert result.sentiment == "positive"
|
||||||
|
|
||||||
|
def test_get_mock_result_method_removed(self):
|
||||||
|
"""_get_mock_result方法必须已被删除"""
|
||||||
|
assert not hasattr(LLMAdapter, "_get_mock_result"), (
|
||||||
|
"_get_mock_result方法仍然存在,必须删除"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_message_user_friendly(self, adapter):
|
||||||
|
"""错误信息必须对用户友好,包含配置指引"""
|
||||||
|
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||||
|
mock_settings.ENABLE_LLM = False
|
||||||
|
mock_settings.DEEPSEEK_API_KEY = "test-key"
|
||||||
|
|
||||||
|
with pytest.raises(LLMAdapterError) as exc_info:
|
||||||
|
await adapter.query_brand_citation(
|
||||||
|
keyword="AI搜索",
|
||||||
|
brand_name="测试品牌",
|
||||||
|
brand_aliases=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
error_msg = str(exc_info.value)
|
||||||
|
assert "ENABLE_LLM=True" in error_msg
|
||||||
|
assert "DEEPSEEK_API_KEY" in error_msg
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_api_key_error_message_user_friendly(self, adapter):
|
||||||
|
"""无API Key时错误信息必须包含配置指引"""
|
||||||
|
adapter.api_key = None
|
||||||
|
|
||||||
|
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||||
|
mock_settings.ENABLE_LLM = True
|
||||||
|
|
||||||
|
with pytest.raises(LLMAdapterError) as exc_info:
|
||||||
|
await adapter.query_brand_citation(
|
||||||
|
keyword="AI搜索",
|
||||||
|
brand_name="测试品牌",
|
||||||
|
brand_aliases=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
error_msg = str(exc_info.value)
|
||||||
|
assert "DEEPSEEK_API_KEY" in error_msg
|
||||||
|
|
@ -31,7 +31,7 @@ import {
|
||||||
Zap,
|
Zap,
|
||||||
} from "lucide-react";
|
} from "lucide-react";
|
||||||
import { useApi, useApiMutation } from "@/lib/hooks/use-api";
|
import { useApi, useApiMutation } from "@/lib/hooks/use-api";
|
||||||
import { MOCK_AI_ENGINES_RESPONSE } from "@/lib/api/ai-engines";
|
|
||||||
import type {
|
import type {
|
||||||
AIEngineType,
|
AIEngineType,
|
||||||
AIQueryResult,
|
AIQueryResult,
|
||||||
|
|
@ -446,10 +446,12 @@ export default function AIEnginesPage() {
|
||||||
if (result) {
|
if (result) {
|
||||||
setQueryResults(result);
|
setQueryResults(result);
|
||||||
} else {
|
} else {
|
||||||
setQueryResults(MOCK_AI_ENGINES_RESPONSE);
|
setQueryError("查询返回空结果,请检查API Key配置");
|
||||||
|
setQueryResults(null);
|
||||||
}
|
}
|
||||||
} catch {
|
} catch (err) {
|
||||||
setQueryResults(MOCK_AI_ENGINES_RESPONSE);
|
setQueryError(err instanceof Error ? err.message : "查询失败,请检查API Key配置");
|
||||||
|
setQueryResults(null);
|
||||||
}
|
}
|
||||||
}, [selectedBrandId, queryText, selectedEngines, queryMutation]);
|
}, [selectedBrandId, queryText, selectedEngines, queryMutation]);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -47,72 +47,6 @@ import {
|
||||||
DropdownMenuTrigger,
|
DropdownMenuTrigger,
|
||||||
} from "@/components/ui/dropdown-menu";
|
} from "@/components/ui/dropdown-menu";
|
||||||
|
|
||||||
const MOCK_ORG_INFO: OrganizationInfo = {
|
|
||||||
id: "org-1",
|
|
||||||
name: "GEO科技有限公司",
|
|
||||||
member_count: 5,
|
|
||||||
created_at: "2024-01-15T08:00:00Z",
|
|
||||||
updated_at: "2024-01-15T08:00:00Z",
|
|
||||||
};
|
|
||||||
|
|
||||||
const MOCK_MEMBERS: OrganizationMember[] = [
|
|
||||||
{
|
|
||||||
id: "member-1",
|
|
||||||
user_id: "user-1",
|
|
||||||
name: "张三",
|
|
||||||
email: "zhangsan@example.com",
|
|
||||||
role: "admin",
|
|
||||||
status: "active",
|
|
||||||
joined_at: "2024-01-15T08:00:00Z",
|
|
||||||
created_at: "2024-01-15T08:00:00Z",
|
|
||||||
updated_at: "2024-01-15T08:00:00Z",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: "member-2",
|
|
||||||
user_id: "user-2",
|
|
||||||
name: "李四",
|
|
||||||
email: "lisi@example.com",
|
|
||||||
role: "member",
|
|
||||||
status: "active",
|
|
||||||
joined_at: "2024-02-10T10:30:00Z",
|
|
||||||
created_at: "2024-02-10T10:30:00Z",
|
|
||||||
updated_at: "2024-02-10T10:30:00Z",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: "member-3",
|
|
||||||
user_id: "user-3",
|
|
||||||
name: "王五",
|
|
||||||
email: "wangwu@example.com",
|
|
||||||
role: "viewer",
|
|
||||||
status: "active",
|
|
||||||
joined_at: "2024-03-05T14:20:00Z",
|
|
||||||
created_at: "2024-03-05T14:20:00Z",
|
|
||||||
updated_at: "2024-03-05T14:20:00Z",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: "member-4",
|
|
||||||
user_id: "user-4",
|
|
||||||
name: "赵六",
|
|
||||||
email: "zhaoliu@example.com",
|
|
||||||
role: "member",
|
|
||||||
status: "pending",
|
|
||||||
joined_at: "2024-03-20T09:15:00Z",
|
|
||||||
created_at: "2024-03-20T09:15:00Z",
|
|
||||||
updated_at: "2024-03-20T09:15:00Z",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: "member-5",
|
|
||||||
user_id: "user-5",
|
|
||||||
name: "孙七",
|
|
||||||
email: "sunqi@example.com",
|
|
||||||
role: "viewer",
|
|
||||||
status: "inactive",
|
|
||||||
joined_at: "2024-01-20T16:45:00Z",
|
|
||||||
created_at: "2024-01-20T16:45:00Z",
|
|
||||||
updated_at: "2024-01-20T16:45:00Z",
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
const roleConfig: Record<MemberRole, { label: string; icon: React.ReactNode; color: string }> = {
|
const roleConfig: Record<MemberRole, { label: string; icon: React.ReactNode; color: string }> = {
|
||||||
admin: {
|
admin: {
|
||||||
label: "管理员",
|
label: "管理员",
|
||||||
|
|
@ -185,7 +119,7 @@ export default function ClientsPage() {
|
||||||
} = useApi<OrganizationMember[]>("/api/v1/organization/members");
|
} = useApi<OrganizationMember[]>("/api/v1/organization/members");
|
||||||
|
|
||||||
const filteredMembers = useMemo(() => {
|
const filteredMembers = useMemo(() => {
|
||||||
const memberList = members || MOCK_MEMBERS;
|
const memberList = members || [];
|
||||||
return memberList.filter((member) => {
|
return memberList.filter((member) => {
|
||||||
const matchesSearch =
|
const matchesSearch =
|
||||||
!searchQuery ||
|
!searchQuery ||
|
||||||
|
|
@ -196,7 +130,7 @@ export default function ClientsPage() {
|
||||||
});
|
});
|
||||||
}, [members, searchQuery, roleFilter]);
|
}, [members, searchQuery, roleFilter]);
|
||||||
|
|
||||||
const safeOrgInfo = orgInfo || MOCK_ORG_INFO;
|
const safeOrgInfo = orgInfo ?? null;
|
||||||
const loading = orgLoading || membersLoading;
|
const loading = orgLoading || membersLoading;
|
||||||
|
|
||||||
const handleInvite = async () => {
|
const handleInvite = async () => {
|
||||||
|
|
@ -297,26 +231,30 @@ export default function ClientsPage() {
|
||||||
</CardTitle>
|
</CardTitle>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent>
|
<CardContent>
|
||||||
<div className="grid gap-4 md:grid-cols-3">
|
{safeOrgInfo ? (
|
||||||
<div>
|
<div className="grid gap-4 md:grid-cols-3">
|
||||||
<p className="text-sm text-gray-500">组织名称</p>
|
<div>
|
||||||
<p className="mt-1 text-lg font-semibold text-gray-900">
|
<p className="text-sm text-gray-500">组织名称</p>
|
||||||
{safeOrgInfo.name}
|
<p className="mt-1 text-lg font-semibold text-gray-900">
|
||||||
</p>
|
{safeOrgInfo.name}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<p className="text-sm text-gray-500">成员数量</p>
|
||||||
|
<p className="mt-1 text-lg font-semibold text-gray-900">
|
||||||
|
{safeOrgInfo.member_count} 人
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<p className="text-sm text-gray-500">创建时间</p>
|
||||||
|
<p className="mt-1 text-lg font-semibold text-gray-900">
|
||||||
|
{formatDate(safeOrgInfo.created_at)}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div>
|
) : (
|
||||||
<p className="text-sm text-gray-500">成员数量</p>
|
<p className="text-sm text-muted-foreground">组织信息加载中...</p>
|
||||||
<p className="mt-1 text-lg font-semibold text-gray-900">
|
)}
|
||||||
{safeOrgInfo.member_count} 人
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<p className="text-sm text-gray-500">创建时间</p>
|
|
||||||
<p className="mt-1 text-lg font-semibold text-gray-900">
|
|
||||||
{formatDate(safeOrgInfo.created_at)}
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ import Link from "next/link";
|
||||||
import {
|
import {
|
||||||
MetricCard,
|
MetricCard,
|
||||||
StageProgress,
|
StageProgress,
|
||||||
AgentStatusCard,
|
|
||||||
} from "@/components/business";
|
} from "@/components/business";
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { Badge } from "@/components/ui/badge";
|
import { Badge } from "@/components/ui/badge";
|
||||||
|
|
@ -34,30 +33,6 @@ const STAGE_CONFIG = [
|
||||||
{ id: "monitoring", label: "监测优化" },
|
{ id: "monitoring", label: "监测优化" },
|
||||||
];
|
];
|
||||||
|
|
||||||
const MOCK_AGENTS = [
|
|
||||||
{
|
|
||||||
name: "内容生成Agent",
|
|
||||||
description: "自动化内容生产",
|
|
||||||
status: "busy" as const,
|
|
||||||
lastActiveAt: "2分钟前",
|
|
||||||
completedCount: 156,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "引用监测Agent",
|
|
||||||
description: "AI平台引用追踪",
|
|
||||||
status: "online" as const,
|
|
||||||
lastActiveAt: "刚刚",
|
|
||||||
completedCount: 3420,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "SEO诊断Agent",
|
|
||||||
description: "搜索引擎优化分析",
|
|
||||||
status: "offline" as const,
|
|
||||||
lastActiveAt: "3小时前",
|
|
||||||
completedCount: 89,
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
function buildStages(currentStage: GeoProject["current_stage"]) {
|
function buildStages(currentStage: GeoProject["current_stage"]) {
|
||||||
const currentIndex = STAGE_CONFIG.findIndex((s) => s.id === currentStage);
|
const currentIndex = STAGE_CONFIG.findIndex((s) => s.id === currentStage);
|
||||||
return STAGE_CONFIG.map((stage, idx) => {
|
return STAGE_CONFIG.map((stage, idx) => {
|
||||||
|
|
@ -331,17 +306,10 @@ export default function DashboardPage() {
|
||||||
查看全部
|
查看全部
|
||||||
</Link>
|
</Link>
|
||||||
</div>
|
</div>
|
||||||
<div className="space-y-3">
|
<div className="flex flex-col items-center justify-center py-8 text-center">
|
||||||
{MOCK_AGENTS.map((agent) => (
|
<Zap className="h-8 w-8 text-muted-foreground mb-3" />
|
||||||
<AgentStatusCard
|
<p className="text-sm font-medium text-muted-foreground">功能开发中</p>
|
||||||
key={agent.name}
|
<p className="text-xs text-muted-foreground mt-1">Agent状态监控即将上线</p>
|
||||||
name={agent.name}
|
|
||||||
description={agent.description}
|
|
||||||
status={agent.status}
|
|
||||||
lastActiveAt={agent.lastActiveAt}
|
|
||||||
completedCount={agent.completedCount}
|
|
||||||
/>
|
|
||||||
))}
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue