fix: 消除所有Mock/Stub/假数据,确保业务流程使用真实数据
M1-引用检测核心: - 删除llm_adapter._get_mock_result()方法 - ENABLE_LLM=False时抛出LLMAdapterError而非返回随机数据 - ENABLE_LLM默认值改为True - 修复旧测试适配新行为 M2-知识库RAG: - knowledge.py不再默认使用MockEmbedder - 动态从APIKeyManager获取OpenAI Key - 无Key时返回503+明确错误信息 - 有Key时使用OpenAIEmbedder M3-AI引擎页面: - 删除MOCK_AI_ENGINES_RESPONSE fallback - 查询失败时显示错误状态 M4-组织管理页面: - 删除MOCK_ORG_INFO和MOCK_MEMBERS - API返回空时显示空状态 M5-首页Agent卡片: - 删除MOCK_AGENTS硬编码 - 替换为功能开发中占位 M6-平台规则历史: - 实现PlatformRuleVersion模型 - 实现版本对比API (diff) - 实现历史记录查询API (history) - 删除2个TODO注释 M7-知识图谱批量构建: - 实现批量创建实体API - 空输入验证+批量大小限制 - 删除TODO注释 - 修复路由双重前缀问题 测试: 643 passed (核心)
This commit is contained in:
parent
4cc8f73bb4
commit
fe4ba39514
|
|
@ -34,16 +34,27 @@ from app.schemas.knowledge import (
|
|||
SearchResultItem,
|
||||
UpdateDocumentRequest,
|
||||
)
|
||||
from app.services.knowledge import MockEmbedder, RAGService
|
||||
from app.services.knowledge import RAGService
|
||||
from app.services.knowledge.embedder import OpenAIEmbedder
|
||||
from app.services.knowledge.enhanced_rag import EnhancedRAG
|
||||
from app.services.knowledge.incremental_index import IncrementalIndexService
|
||||
from app.services.knowledge.chunker import ChunkerFactory
|
||||
from app.services.api_key_manager import APIKeyManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# Shared RAG service instance (MockEmbedder by default; swap in OpenAIEmbedder via DI later)
|
||||
_rag_service = RAGService(embedder=MockEmbedder())
|
||||
_key_manager = APIKeyManager()
|
||||
|
||||
|
||||
def _get_rag_service() -> RAGService:
|
||||
api_key = _key_manager.get_key("chatgpt")
|
||||
if api_key:
|
||||
return RAGService(embedder=OpenAIEmbedder(api_key=api_key))
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="知识库功能需要配置OpenAI API Key。请在设置页面添加OpenAI API Key。",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -280,8 +291,11 @@ async def upload_document(
|
|||
|
||||
# Asynchronously ingest (same request; background task optimization later)
|
||||
try:
|
||||
await _rag_service.ingest_document(db, str(doc.id))
|
||||
rag_service = _get_rag_service()
|
||||
await rag_service.ingest_document(db, str(doc.id))
|
||||
await db.refresh(doc)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Ingest failed for document {doc.id}: {exc}")
|
||||
# Status already set to 'failed' by ingest_document on exception
|
||||
|
|
@ -359,7 +373,7 @@ async def delete_document(
|
|||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")
|
||||
|
||||
# Delete chunks first (cascade also handles this, but explicit for clarity)
|
||||
await _rag_service.delete_document_chunks(db, str(doc.id))
|
||||
await _get_rag_service().delete_document_chunks(db, str(doc.id))
|
||||
|
||||
await db.delete(doc)
|
||||
|
||||
|
|
@ -466,7 +480,7 @@ async def knowledge_search(
|
|||
|
||||
t0 = time.monotonic()
|
||||
|
||||
raw_results = await _rag_service.search(
|
||||
raw_results = await _get_rag_service().search(
|
||||
db,
|
||||
query=body.query,
|
||||
knowledge_base_ids=body.knowledge_base_ids,
|
||||
|
|
@ -559,7 +573,7 @@ async def reindex_document(
|
|||
|
||||
await _get_kb(db, kb_id, org_id)
|
||||
|
||||
index_service = IncrementalIndexService(_rag_service)
|
||||
index_service = IncrementalIndexService(_get_rag_service())
|
||||
result = await index_service.add_document(
|
||||
db, str(kb_id), str(doc_id)
|
||||
)
|
||||
|
|
@ -581,7 +595,7 @@ async def update_document_content(
|
|||
|
||||
await _get_kb(db, kb_id, org_id)
|
||||
|
||||
index_service = IncrementalIndexService(_rag_service)
|
||||
index_service = IncrementalIndexService(_get_rag_service())
|
||||
result = await index_service.update_document(
|
||||
db, str(doc_id), request.content
|
||||
)
|
||||
|
|
@ -602,7 +616,7 @@ async def delete_document_incremental(
|
|||
|
||||
await _get_kb(db, kb_id, org_id)
|
||||
|
||||
index_service = IncrementalIndexService(_rag_service)
|
||||
index_service = IncrementalIndexService(_get_rag_service())
|
||||
result = await index_service.delete_document(db, str(doc_id))
|
||||
return result
|
||||
|
||||
|
|
@ -621,7 +635,7 @@ async def rebuild_knowledge_base(
|
|||
|
||||
await _get_kb(db, kb_id, org_id)
|
||||
|
||||
index_service = IncrementalIndexService(_rag_service)
|
||||
index_service = IncrementalIndexService(_get_rag_service())
|
||||
result = await index_service.rebuild_knowledge_base(
|
||||
db, str(kb_id), force
|
||||
)
|
||||
|
|
@ -642,7 +656,8 @@ async def enhanced_retrieve(
|
|||
|
||||
await _get_kb(db, kb_id, org_id)
|
||||
|
||||
enhanced_rag = EnhancedRAG(_rag_service, _rag_service.embedder)
|
||||
rag_service = _get_rag_service()
|
||||
enhanced_rag = EnhancedRAG(rag_service, rag_service.embedder)
|
||||
results = await enhanced_rag.retrieve_with_rerank(
|
||||
db,
|
||||
request.query,
|
||||
|
|
|
|||
|
|
@ -3,9 +3,12 @@ from typing import Optional
|
|||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_db, get_current_user
|
||||
from app.models.knowledge import KnowledgeBase
|
||||
from app.models.knowledge_graph import KnowledgeEntity, EntityType
|
||||
from app.models.user import User
|
||||
from app.services.knowledge.graph_builder import GraphBuilder
|
||||
from app.services.knowledge.graph_query import GraphQuery
|
||||
|
|
@ -13,6 +16,61 @@ from app.services.knowledge.graph_query import GraphQuery
|
|||
router = APIRouter(prefix="/knowledge-bases", tags=["知识图谱"])
|
||||
|
||||
|
||||
class EntityCreateRequest(BaseModel):
|
||||
name: str = Field(..., max_length=500)
|
||||
entity_type: str
|
||||
description: Optional[str] = None
|
||||
properties: Optional[dict] = None
|
||||
|
||||
|
||||
def _entity_to_dict(entity: KnowledgeEntity) -> dict:
|
||||
return {
|
||||
"id": str(entity.id),
|
||||
"name": entity.name,
|
||||
"entity_type": entity.entity_type.value,
|
||||
"description": entity.description,
|
||||
"properties": entity.properties,
|
||||
"confidence": entity.confidence,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{kb_id}/entities/batch")
|
||||
async def batch_create_entities(
|
||||
kb_id: UUID,
|
||||
entities: list[EntityCreateRequest],
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""批量创建知识图谱实体"""
|
||||
if not entities:
|
||||
raise HTTPException(status_code=400, detail="实体列表不能为空")
|
||||
|
||||
if len(entities) > 100:
|
||||
raise HTTPException(status_code=400, detail="单次批量创建不能超过100个实体")
|
||||
|
||||
kb = await db.get(KnowledgeBase, kb_id)
|
||||
if not kb:
|
||||
raise HTTPException(status_code=404, detail="知识库不存在")
|
||||
|
||||
created = []
|
||||
for entity_req in entities:
|
||||
entity = KnowledgeEntity(
|
||||
knowledge_base_id=kb_id,
|
||||
name=entity_req.name,
|
||||
entity_type=EntityType(entity_req.entity_type),
|
||||
description=entity_req.description,
|
||||
properties=entity_req.properties or {},
|
||||
)
|
||||
db.add(entity)
|
||||
created.append(entity)
|
||||
|
||||
await db.commit()
|
||||
for entity in created:
|
||||
await db.refresh(entity)
|
||||
|
||||
return {"created_count": len(created), "entities": [_entity_to_dict(e) for e in created]}
|
||||
|
||||
|
||||
@router.post("/{kb_id}/graph/build")
|
||||
async def build_graph(
|
||||
kb_id: UUID,
|
||||
|
|
@ -24,8 +82,6 @@ async def build_graph(
|
|||
|
||||
对知识库中的所有Chunks执行实体和关系抽取
|
||||
"""
|
||||
# TODO: 实现批量构建
|
||||
# 目前先实现单个Chunk的构建
|
||||
return {"message": "Use /graph/build-chunk to build from specific chunk"}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,12 +2,16 @@
|
|||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.platform_rule_version import PlatformRuleVersion
|
||||
from app.models.user import User
|
||||
from app.services.distribution.platform_rules import (
|
||||
PLATFORM_RULES,
|
||||
|
|
@ -48,6 +52,46 @@ logger = logging.getLogger(__name__)
|
|||
router = APIRouter(prefix="/api/v1/platforms", tags=["平台规则管理"])
|
||||
|
||||
|
||||
async def _get_rule_version(
|
||||
db: AsyncSession, rule_id: str, version: int
|
||||
) -> PlatformRuleVersion | None:
|
||||
stmt = select(PlatformRuleVersion).where(
|
||||
PlatformRuleVersion.rule_id == rule_id,
|
||||
PlatformRuleVersion.version == version,
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
def _compute_diff(
|
||||
old_data: dict, new_data: dict, prefix: str = ""
|
||||
) -> list[RuleDiff]:
|
||||
diffs: list[RuleDiff] = []
|
||||
all_keys = set(old_data.keys()) | set(new_data.keys())
|
||||
for key in sorted(all_keys):
|
||||
field = f"{prefix}{key}" if not prefix else f"{prefix}.{key}"
|
||||
old_val = old_data.get(key)
|
||||
new_val = new_data.get(key)
|
||||
if isinstance(old_val, dict) and isinstance(new_val, dict):
|
||||
diffs.extend(_compute_diff(old_val, new_val, field))
|
||||
elif old_val != new_val:
|
||||
diffs.append(RuleDiff(field=field, old_value=old_val, new_value=new_val))
|
||||
return diffs
|
||||
|
||||
|
||||
def _version_to_dict(v: PlatformRuleVersion) -> dict:
|
||||
return {
|
||||
"id": v.id,
|
||||
"rule_id": v.rule_id,
|
||||
"platform": v.platform,
|
||||
"version": v.version,
|
||||
"rule_data": v.rule_data,
|
||||
"change_summary": v.change_summary,
|
||||
"created_by": v.created_by,
|
||||
"created_at": v.created_at.isoformat() if v.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _convert_rule_to_schema(rules: dict) -> dict:
|
||||
"""将规则字典转换为 Schema 格式"""
|
||||
if not rules:
|
||||
|
|
@ -179,13 +223,16 @@ async def update_platform_rules(
|
|||
@router.get("/{platform_id}/rules/diff", response_model=RuleDiffResponse)
|
||||
async def compare_rule_changes(
|
||||
platform_id: str,
|
||||
change_id: Optional[int] = Query(None, description="变更记录ID,用于对比历史版本"),
|
||||
from_version: int = Query(..., description="起始版本号"),
|
||||
to_version: int = Query(..., description="目标版本号"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""对比规则变更
|
||||
|
||||
Args:
|
||||
platform_id: 平台标识
|
||||
change_id: 变更记录ID(可选)
|
||||
from_version: 起始版本号
|
||||
to_version: 目标版本号
|
||||
"""
|
||||
if platform_id not in PLATFORM_RULES:
|
||||
raise HTTPException(
|
||||
|
|
@ -195,13 +242,21 @@ async def compare_rule_changes(
|
|||
|
||||
current_rules = PLATFORM_RULES[platform_id]
|
||||
|
||||
# TODO: 从数据库获取历史版本进行对比
|
||||
# 目前返回空差异
|
||||
from_rule = await _get_rule_version(db, platform_id, from_version)
|
||||
to_rule = await _get_rule_version(db, platform_id, to_version)
|
||||
|
||||
if not from_rule or not to_rule:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="版本不存在",
|
||||
)
|
||||
|
||||
diffs = _compute_diff(from_rule.rule_data, to_rule.rule_data)
|
||||
return RuleDiffResponse(
|
||||
platform_id=platform_id,
|
||||
platform_name=current_rules.get("name", ""),
|
||||
diffs=[],
|
||||
total_changes=0,
|
||||
diffs=diffs,
|
||||
total_changes=len(diffs),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -209,6 +264,7 @@ async def compare_rule_changes(
|
|||
async def get_rule_history(
|
||||
platform_id: str,
|
||||
limit: int = Query(20, ge=1, le=100, description="返回记录数"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取规则变更历史
|
||||
|
||||
|
|
@ -222,11 +278,39 @@ async def get_rule_history(
|
|||
detail=f"平台不存在: {platform_id}",
|
||||
)
|
||||
|
||||
# TODO: 从数据库获取历史记录
|
||||
# 目前返回空列表
|
||||
count_stmt = select(func.count()).select_from(PlatformRuleVersion).where(
|
||||
PlatformRuleVersion.rule_id == platform_id
|
||||
)
|
||||
total = (await db.execute(count_stmt)).scalar() or 0
|
||||
|
||||
stmt = (
|
||||
select(PlatformRuleVersion)
|
||||
.where(PlatformRuleVersion.rule_id == platform_id)
|
||||
.order_by(PlatformRuleVersion.version.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
versions = result.scalars().all()
|
||||
|
||||
history = [
|
||||
RuleChangeHistory(
|
||||
id=v.version,
|
||||
version=v.version,
|
||||
platform_id=v.rule_id,
|
||||
platform_name=v.platform,
|
||||
changed_by=v.created_by or "",
|
||||
change_summary=v.change_summary or "",
|
||||
change_type="update",
|
||||
previous_rules=None,
|
||||
new_rules=v.rule_data,
|
||||
created_at=v.created_at,
|
||||
)
|
||||
for v in versions
|
||||
]
|
||||
|
||||
return RuleChangeHistoryResponse(
|
||||
history=[],
|
||||
total=0,
|
||||
history=history,
|
||||
total=total,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class Settings(BaseSettings):
|
|||
SECRET_KEY: Optional[str] = None
|
||||
|
||||
PLAYWRIGHT_BROWSERS_PATH: str = "/ms-playwright"
|
||||
ENABLE_LLM: bool = False
|
||||
ENABLE_LLM: bool = True
|
||||
ZHIPU_API_KEY: str = ""
|
||||
TONGYI_API_KEY: str = ""
|
||||
CORS_ORIGINS: str = "http://localhost:3000,http://localhost:3001"
|
||||
|
|
|
|||
|
|
@ -168,7 +168,7 @@ app.include_router(onboarding_router, prefix="/api/v1")
|
|||
app.include_router(platforms_router, prefix="/api/v1")
|
||||
app.include_router(platform_rules_router)
|
||||
app.include_router(image_router, prefix="/api/v1")
|
||||
app.include_router(knowledge_graph_router, prefix="/api/v1/knowledge-bases")
|
||||
app.include_router(knowledge_graph_router, prefix="/api/v1")
|
||||
app.include_router(ai_engines_router, prefix="/api/v1/ai-engines", tags=["AI引擎查询"])
|
||||
app.include_router(detection_router, prefix="/api/v1/detection", tags=["定时检测任务"])
|
||||
app.include_router(api_keys_router, prefix="/api/v1/api-keys", tags=["API Key管理"])
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from app.models.lifecycle import LifecycleProject, ProjectStage
|
|||
from app.models.agent import AgentRegistry, AgentConfig, AgentTask, AgentTaskLog
|
||||
from app.models.content import Content, ContentVersion, ContentReview
|
||||
from app.models.platform_rule import PlatformRule
|
||||
from app.models.platform_rule_version import PlatformRuleVersion
|
||||
from app.models.brand_knowledge import BrandKnowledge, Keyword
|
||||
from app.models.knowledge import (
|
||||
KnowledgeBase,
|
||||
|
|
@ -52,6 +53,7 @@ __all__ = [
|
|||
"ContentVersion",
|
||||
"ContentReview",
|
||||
"PlatformRule",
|
||||
"PlatformRuleVersion",
|
||||
"BrandKnowledge",
|
||||
"Keyword",
|
||||
"KnowledgeBase",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,30 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, Integer, Index, func
|
||||
from sqlalchemy import Uuid
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.database import Base, JSONType
|
||||
|
||||
|
||||
class PlatformRuleVersion(Base):
|
||||
__tablename__ = "platform_rule_versions"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
rule_id: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
|
||||
platform: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
rule_data: Mapped[dict] = mapped_column(JSONType, nullable=False)
|
||||
change_summary: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
created_by: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(server_default=func.now())
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_rule_versions_rule_id", "rule_id"),
|
||||
Index("idx_rule_versions_platform", "platform"),
|
||||
)
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"""平台规则管理 Schema - 定义规则管理的请求响应结构"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
|
|
@ -225,6 +225,7 @@ class PlatformRuleUpdateResponse(BaseModel):
|
|||
class RuleChangeHistory(BaseModel):
|
||||
"""规则变更历史"""
|
||||
id: int
|
||||
version: int = 0
|
||||
platform_id: str
|
||||
platform_name: str
|
||||
changed_by: str
|
||||
|
|
@ -305,3 +306,7 @@ class DeAIContentResponse(BaseModel):
|
|||
processed_word_count: int
|
||||
detected_ai_patterns: list[str] = []
|
||||
replaced_patterns: dict[str, str] = {}
|
||||
|
||||
|
||||
RuleDiff.model_rebuild()
|
||||
RuleDiffResponse.model_rebuild()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ LLM适配器 - 使用DeepSeek LLM API检测品牌引用
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -104,8 +103,14 @@ class LLMAdapter:
|
|||
LLMAdapterError: API调用或解析失败
|
||||
"""
|
||||
if not settings.ENABLE_LLM:
|
||||
logger.info("LLM调用已禁用 (ENABLE_LLM=False),返回模拟数据")
|
||||
return self._get_mock_result(keyword, brand_name, brand_aliases)
|
||||
raise LLMAdapterError(
|
||||
"LLM引用检测未启用。请在环境变量中设置 ENABLE_LLM=True 并配置 DEEPSEEK_API_KEY"
|
||||
)
|
||||
|
||||
if not self.api_key:
|
||||
raise LLMAdapterError(
|
||||
"未配置DeepSeek API Key。请设置 DEEPSEEK_API_KEY 环境变量"
|
||||
)
|
||||
|
||||
prompt = self._build_prompt(keyword, brand_name, brand_aliases)
|
||||
|
||||
|
|
@ -123,36 +128,6 @@ class LLMAdapter:
|
|||
|
||||
raise LLMAdapterError(f"LLM API调用失败,已重试{self.max_retries}次: {last_error}")
|
||||
|
||||
def _get_mock_result(
|
||||
self,
|
||||
keyword: str,
|
||||
brand_name: str,
|
||||
brand_aliases: list[str]
|
||||
) -> CitationResult:
|
||||
"""
|
||||
生成模拟结果(当LLM禁用时使用)
|
||||
|
||||
随机决定是否引用,模拟真实场景的数据分布
|
||||
"""
|
||||
cited = random.random() < 0.6
|
||||
sentiment_options = ["positive", "neutral", "negative"]
|
||||
sentiment = random.choice(sentiment_options)
|
||||
|
||||
if cited:
|
||||
position = random.randint(1, 10)
|
||||
citation_text = f'模拟引用:在搜索"{keyword}"时,提到了{brand_name}品牌及其相关产品。'
|
||||
else:
|
||||
position = None
|
||||
citation_text = ""
|
||||
|
||||
return CitationResult(
|
||||
cited=cited,
|
||||
position=position,
|
||||
citation_text=citation_text,
|
||||
sentiment=sentiment,
|
||||
confidence=round(random.uniform(0.7, 0.99), 2)
|
||||
)
|
||||
|
||||
async def _call_deepseek(self, prompt: str) -> dict:
|
||||
"""
|
||||
调用DeepSeek API
|
||||
|
|
|
|||
|
|
@ -0,0 +1,278 @@
|
|||
"""知识图谱批量构建API测试"""
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database import Base
|
||||
from app.main import app
|
||||
from app.api.deps import get_db, get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.knowledge import KnowledgeBase
|
||||
from app.models.organization import Organization
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_engine():
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_session(async_engine):
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user(async_session):
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="test@example.com",
|
||||
password_hash="hashed_password",
|
||||
name="Test User",
|
||||
plan="free",
|
||||
max_queries=5,
|
||||
is_active=True,
|
||||
email_verified=True,
|
||||
)
|
||||
async_session.add(user)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_org(async_session):
|
||||
org = Organization(
|
||||
id=uuid.uuid4(),
|
||||
name="Test Org",
|
||||
slug="test-org",
|
||||
)
|
||||
async_session.add(org)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(org)
|
||||
return org
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_kb(async_session, test_org):
|
||||
kb = KnowledgeBase(
|
||||
id=uuid.uuid4(),
|
||||
organization_id=test_org.id,
|
||||
name="Test KB",
|
||||
type="industry",
|
||||
description="Test knowledge base",
|
||||
)
|
||||
async_session.add(kb)
|
||||
await async_session.commit()
|
||||
await async_session.refresh(kb)
|
||||
return kb
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_client(async_session, test_user):
|
||||
async def override_get_db():
|
||||
yield async_session
|
||||
|
||||
async def override_get_current_user():
|
||||
return test_user
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_get_current_user
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
BATCH_URL = "/api/v1/knowledge-bases/{kb_id}/entities/batch"
|
||||
|
||||
|
||||
class TestBatchCreateEntitiesEmptyInput:
|
||||
"""空输入验证测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_entities_list_returns_400(self, async_client, test_kb):
|
||||
response = await async_client.post(
|
||||
BATCH_URL.format(kb_id=test_kb.id),
|
||||
json=[],
|
||||
)
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
assert "detail" in data
|
||||
assert "不能为空" in data["detail"]
|
||||
|
||||
|
||||
class TestBatchCreateEntitiesSizeLimit:
|
||||
"""批量大小限制测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_over_100_entities_returns_400(self, async_client, test_kb):
|
||||
entities = [
|
||||
{
|
||||
"name": f"Entity {i}",
|
||||
"entity_type": "CONCEPT",
|
||||
"description": f"Test entity {i}",
|
||||
}
|
||||
for i in range(101)
|
||||
]
|
||||
response = await async_client.post(
|
||||
BATCH_URL.format(kb_id=test_kb.id),
|
||||
json=entities,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
data = response.json()
|
||||
assert "detail" in data
|
||||
assert "100" in data["detail"]
|
||||
|
||||
|
||||
class TestBatchCreateEntitiesKBNotFound:
|
||||
"""知识库不存在测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonexistent_kb_returns_404(self, async_client):
|
||||
fake_kb_id = str(uuid.uuid4())
|
||||
entities = [
|
||||
{
|
||||
"name": "Entity 1",
|
||||
"entity_type": "CONCEPT",
|
||||
"description": "Test",
|
||||
}
|
||||
]
|
||||
response = await async_client.post(
|
||||
BATCH_URL.format(kb_id=fake_kb_id),
|
||||
json=entities,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert "detail" in data
|
||||
assert "不存在" in data["detail"]
|
||||
|
||||
|
||||
class TestBatchCreateEntitiesSuccess:
|
||||
"""批量创建成功测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_entities_success(self, async_client, test_kb):
|
||||
entities = [
|
||||
{
|
||||
"name": "公司A",
|
||||
"entity_type": "ORGANIZATION",
|
||||
"description": "测试公司A",
|
||||
"properties": {"industry": "科技"},
|
||||
},
|
||||
{
|
||||
"name": "产品B",
|
||||
"entity_type": "PRODUCT",
|
||||
"description": "测试产品B",
|
||||
},
|
||||
]
|
||||
response = await async_client.post(
|
||||
BATCH_URL.format(kb_id=test_kb.id),
|
||||
json=entities,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created_count"] == 2
|
||||
assert len(data["entities"]) == 2
|
||||
|
||||
for entity_data in data["entities"]:
|
||||
assert "id" in entity_data
|
||||
assert "name" in entity_data
|
||||
assert "entity_type" in entity_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_single_entity(self, async_client, test_kb):
|
||||
entities = [
|
||||
{
|
||||
"name": "单个实体",
|
||||
"entity_type": "PERSON",
|
||||
"description": "测试单个实体",
|
||||
}
|
||||
]
|
||||
response = await async_client.post(
|
||||
BATCH_URL.format(kb_id=test_kb.id),
|
||||
json=entities,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created_count"] == 1
|
||||
assert len(data["entities"]) == 1
|
||||
assert data["entities"][0]["name"] == "单个实体"
|
||||
assert data["entities"][0]["entity_type"] == "PERSON"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_with_properties(self, async_client, test_kb):
|
||||
entities = [
|
||||
{
|
||||
"name": "带属性实体",
|
||||
"entity_type": "TECHNOLOGY",
|
||||
"description": "测试带属性",
|
||||
"properties": {"version": "1.0", "category": "AI"},
|
||||
}
|
||||
]
|
||||
response = await async_client.post(
|
||||
BATCH_URL.format(kb_id=test_kb.id),
|
||||
json=entities,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["entities"][0]["properties"]["version"] == "1.0"
|
||||
assert data["entities"][0]["properties"]["category"] == "AI"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_without_properties_defaults_to_empty(
|
||||
self, async_client, test_kb
|
||||
):
|
||||
entities = [
|
||||
{
|
||||
"name": "无属性实体",
|
||||
"entity_type": "BRAND",
|
||||
"description": "测试无属性",
|
||||
}
|
||||
]
|
||||
response = await async_client.post(
|
||||
BATCH_URL.format(kb_id=test_kb.id),
|
||||
json=entities,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["entities"][0]["properties"] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_exactly_100_entities(self, async_client, test_kb):
|
||||
entities = [
|
||||
{
|
||||
"name": f"Entity {i}",
|
||||
"entity_type": "CONCEPT",
|
||||
"description": f"Test entity {i}",
|
||||
}
|
||||
for i in range(100)
|
||||
]
|
||||
response = await async_client.post(
|
||||
BATCH_URL.format(kb_id=test_kb.id),
|
||||
json=entities,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["created_count"] == 100
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
"""
|
||||
测试Knowledge API不再默认使用MockEmbedder
|
||||
- 无OpenAI Key时API返回503+明确错误信息
|
||||
- 有OpenAI Key时使用OpenAIEmbedder
|
||||
- MockEmbedder不再作为默认选择
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.services.knowledge.embedder import MockEmbedder, OpenAIEmbedder
|
||||
from app.services.knowledge.rag_service import RAGService
|
||||
from app.services.api_key_manager import APIKeyManager
|
||||
|
||||
|
||||
class TestKnowledgeAPINoMockEmbedder:
|
||||
"""验证knowledge.py不再默认使用MockEmbedder"""
|
||||
|
||||
def test_get_rag_service_raises_without_openai_key(self):
|
||||
"""无OpenAI Key时_get_rag_service必须抛出HTTPException"""
|
||||
from app.api.knowledge import _get_rag_service
|
||||
from fastapi import HTTPException
|
||||
|
||||
key_manager = APIKeyManager()
|
||||
with patch("app.api.knowledge._key_manager", key_manager):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
_get_rag_service()
|
||||
assert exc_info.value.status_code == 503
|
||||
assert "OpenAI API Key" in exc_info.value.detail
|
||||
|
||||
def test_get_rag_service_returns_openai_embedder_with_key(self):
|
||||
"""有OpenAI Key时_get_rag_service必须返回使用OpenAIEmbedder的RAGService"""
|
||||
from app.api.knowledge import _get_rag_service
|
||||
|
||||
key_manager = APIKeyManager()
|
||||
key_manager.add_key("chatgpt", "sk-test-key-1234567890", source="system")
|
||||
|
||||
with patch("app.api.knowledge._key_manager", key_manager):
|
||||
rag_service = _get_rag_service()
|
||||
assert isinstance(rag_service, RAGService)
|
||||
assert isinstance(rag_service.embedder, OpenAIEmbedder)
|
||||
|
||||
def test_get_rag_service_never_returns_mock_embedder(self):
|
||||
"""_get_rag_service绝不能返回MockEmbedder"""
|
||||
from app.api.knowledge import _get_rag_service
|
||||
from fastapi import HTTPException
|
||||
|
||||
key_manager = APIKeyManager()
|
||||
with patch("app.api.knowledge._key_manager", key_manager):
|
||||
with pytest.raises(HTTPException):
|
||||
_get_rag_service()
|
||||
|
||||
key_manager_with_key = APIKeyManager()
|
||||
key_manager_with_key.add_key("chatgpt", "sk-test-key-1234567890", source="system")
|
||||
with patch("app.api.knowledge._key_manager", key_manager_with_key):
|
||||
rag_service = _get_rag_service()
|
||||
assert not isinstance(rag_service.embedder, MockEmbedder)
|
||||
|
||||
def test_no_module_level_mock_rag_service(self):
|
||||
"""模块级别不再存在使用MockEmbedder的_rag_service变量"""
|
||||
import app.api.knowledge as knowledge_module
|
||||
|
||||
assert not hasattr(knowledge_module, "_rag_service"), (
|
||||
"_rag_service模块级变量仍然存在,必须删除"
|
||||
)
|
||||
|
||||
def test_error_message_contains_configuration_guidance(self):
|
||||
"""503错误信息必须包含配置指引"""
|
||||
from app.api.knowledge import _get_rag_service
|
||||
from fastapi import HTTPException
|
||||
|
||||
key_manager = APIKeyManager()
|
||||
with patch("app.api.knowledge._key_manager", key_manager):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
_get_rag_service()
|
||||
detail = exc_info.value.detail
|
||||
assert "OpenAI API Key" in detail
|
||||
assert "设置" in detail or "配置" in detail
|
||||
|
||||
def test_mock_embedder_class_still_exists(self):
|
||||
"""MockEmbedder类必须保留(仅用于测试)"""
|
||||
assert MockEmbedder is not None
|
||||
embedder = MockEmbedder()
|
||||
assert isinstance(embedder, MockEmbedder)
|
||||
|
||||
def test_get_rag_service_uses_api_key_manager(self):
|
||||
"""_get_rag_service必须使用APIKeyManager获取Key"""
|
||||
from app.api.knowledge import _get_rag_service
|
||||
|
||||
mock_km = MagicMock(spec=APIKeyManager)
|
||||
mock_km.get_key.return_value = "sk-test-key-1234567890"
|
||||
|
||||
with patch("app.api.knowledge._key_manager", mock_km):
|
||||
rag_service = _get_rag_service()
|
||||
mock_km.get_key.assert_called_once_with("chatgpt")
|
||||
assert isinstance(rag_service.embedder, OpenAIEmbedder)
|
||||
|
|
@ -0,0 +1,252 @@
|
|||
"""平台规则历史版本API测试"""
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database import Base
|
||||
from app.main import app
|
||||
from app.models.platform_rule_version import PlatformRuleVersion
|
||||
from app.api.deps import get_db
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_engine():
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_session(async_engine):
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_client(async_session):
|
||||
async def override_get_db():
|
||||
yield async_session
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
yield client
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def seed_versions(async_session):
|
||||
v1 = PlatformRuleVersion(
|
||||
id=uuid.uuid4(),
|
||||
rule_id="zhihu",
|
||||
platform="zhihu",
|
||||
version=1,
|
||||
rule_data={
|
||||
"content_length": {"min": 800, "max": 3000, "recommended": 1500},
|
||||
"title_rules": {"min_length": 5, "max_length": 50},
|
||||
},
|
||||
change_summary="初始版本",
|
||||
created_by="admin",
|
||||
)
|
||||
v2 = PlatformRuleVersion(
|
||||
id=uuid.uuid4(),
|
||||
rule_id="zhihu",
|
||||
platform="zhihu",
|
||||
version=2,
|
||||
rule_data={
|
||||
"content_length": {"min": 1000, "max": 5000, "recommended": 2000},
|
||||
"title_rules": {"min_length": 5, "max_length": 50},
|
||||
},
|
||||
change_summary="调整内容长度规则",
|
||||
created_by="admin",
|
||||
)
|
||||
v3 = PlatformRuleVersion(
|
||||
id=uuid.uuid4(),
|
||||
rule_id="zhihu",
|
||||
platform="zhihu",
|
||||
version=3,
|
||||
rule_data={
|
||||
"content_length": {"min": 1000, "max": 5000, "recommended": 2000},
|
||||
"title_rules": {"min_length": 8, "max_length": 60},
|
||||
},
|
||||
change_summary="调整标题长度规则",
|
||||
created_by="editor",
|
||||
)
|
||||
async_session.add_all([v1, v2, v3])
|
||||
await async_session.commit()
|
||||
return [v1, v2, v3]
|
||||
|
||||
|
||||
class TestRuleVersionDiff:
|
||||
"""历史版本对比API测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_returns_differences_between_versions(
|
||||
self, async_client, seed_versions
|
||||
):
|
||||
response = await async_client.get(
|
||||
"/api/v1/platforms/zhihu/rules/diff",
|
||||
params={"from_version": 1, "to_version": 2},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["platform_id"] == "zhihu"
|
||||
assert isinstance(data["diffs"], list)
|
||||
assert len(data["diffs"]) > 0
|
||||
assert data["total_changes"] > 0
|
||||
|
||||
diff_fields = {d["field"] for d in data["diffs"]}
|
||||
assert "content_length.min" in diff_fields
|
||||
assert "content_length.max" in diff_fields
|
||||
assert "content_length.recommended" in diff_fields
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_no_changes_returns_empty(self, async_client, seed_versions):
|
||||
response = await async_client.get(
|
||||
"/api/v1/platforms/zhihu/rules/diff",
|
||||
params={"from_version": 2, "to_version": 2},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["diffs"] == []
|
||||
assert data["total_changes"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_nonexistent_version_returns_404(
|
||||
self, async_client, seed_versions
|
||||
):
|
||||
response = await async_client.get(
|
||||
"/api/v1/platforms/zhihu/rules/diff",
|
||||
params={"from_version": 1, "to_version": 99},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_nonexistent_platform_returns_404(self, async_client):
|
||||
response = await async_client.get(
|
||||
"/api/v1/platforms/nonexistent_platform/rules/diff",
|
||||
params={"from_version": 1, "to_version": 2},
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_missing_params_returns_error(self, async_client, seed_versions):
|
||||
response = await async_client.get(
|
||||
"/api/v1/platforms/zhihu/rules/diff",
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_nested_field_change(self, async_client, seed_versions):
|
||||
response = await async_client.get(
|
||||
"/api/v1/platforms/zhihu/rules/diff",
|
||||
params={"from_version": 2, "to_version": 3},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
diff_fields = {d["field"] for d in data["diffs"]}
|
||||
assert "title_rules.min_length" in diff_fields
|
||||
assert "title_rules.max_length" in diff_fields
|
||||
|
||||
|
||||
class TestRuleHistory:
|
||||
"""历史记录查询API测试"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_returns_versions(
|
||||
self, async_client, seed_versions
|
||||
):
|
||||
response = await async_client.get(
|
||||
"/api/v1/platforms/zhihu/rules/history",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 3
|
||||
assert len(data["history"]) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_ordered_by_version_desc(
|
||||
self, async_client, seed_versions
|
||||
):
|
||||
response = await async_client.get(
|
||||
"/api/v1/platforms/zhihu/rules/history",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
versions = [h["version"] for h in data["history"]]
|
||||
assert versions == sorted(versions, reverse=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_respects_limit(self, async_client, seed_versions):
|
||||
response = await async_client.get(
|
||||
"/api/v1/platforms/zhihu/rules/history",
|
||||
params={"limit": 2},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["history"]) == 2
|
||||
assert data["total"] == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_empty_when_no_versions(self, async_client):
|
||||
response = await async_client.get(
|
||||
"/api/v1/platforms/zhihu/rules/history",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total"] == 0
|
||||
assert data["history"] == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_nonexistent_platform_returns_404(self, async_client):
|
||||
response = await async_client.get(
|
||||
"/api/v1/platforms/nonexistent_platform/rules/history",
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_version_has_required_fields(
|
||||
self, async_client, seed_versions
|
||||
):
|
||||
response = await async_client.get(
|
||||
"/api/v1/platforms/zhihu/rules/history",
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
first = data["history"][0]
|
||||
assert "version" in first
|
||||
assert "new_rules" in first
|
||||
assert "change_summary" in first
|
||||
assert "changed_by" in first
|
||||
assert "created_at" in first
|
||||
|
|
@ -22,20 +22,22 @@ class TestLLMAdapter:
|
|||
"confidence": 0.95
|
||||
}
|
||||
|
||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||
mock_call.return_value = mock_response
|
||||
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||
mock_settings.ENABLE_LLM = True
|
||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||
mock_call.return_value = mock_response
|
||||
|
||||
result = await llm_adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="XXX",
|
||||
brand_aliases=["品牌别名1", "品牌别名2"]
|
||||
)
|
||||
result = await llm_adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="XXX",
|
||||
brand_aliases=["品牌别名1", "品牌别名2"]
|
||||
)
|
||||
|
||||
assert result.cited is True
|
||||
assert result.position == 1
|
||||
assert result.citation_text == "XXX是一款非常优秀的品牌产品"
|
||||
assert result.sentiment == "positive"
|
||||
assert result.confidence == 0.95
|
||||
assert result.cited is True
|
||||
assert result.position == 1
|
||||
assert result.citation_text == "XXX是一款非常优秀的品牌产品"
|
||||
assert result.sentiment == "positive"
|
||||
assert result.confidence == 0.95
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_adapter_not_cited(self, llm_adapter):
|
||||
|
|
@ -48,19 +50,21 @@ class TestLLMAdapter:
|
|||
"confidence": 0.90
|
||||
}
|
||||
|
||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||
mock_call.return_value = mock_response
|
||||
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||
mock_settings.ENABLE_LLM = True
|
||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||
mock_call.return_value = mock_response
|
||||
|
||||
result = await llm_adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="YYY",
|
||||
brand_aliases=[]
|
||||
)
|
||||
result = await llm_adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="YYY",
|
||||
brand_aliases=[]
|
||||
)
|
||||
|
||||
assert result.cited is False
|
||||
assert result.position is None
|
||||
assert result.citation_text is None
|
||||
assert result.sentiment == "neutral"
|
||||
assert result.cited is False
|
||||
assert result.position is None
|
||||
assert result.citation_text is None
|
||||
assert result.sentiment == "neutral"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_adapter_sentiment_positive(self, llm_adapter):
|
||||
|
|
@ -73,16 +77,18 @@ class TestLLMAdapter:
|
|||
"confidence": 0.92
|
||||
}
|
||||
|
||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||
mock_call.return_value = mock_response
|
||||
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||
mock_settings.ENABLE_LLM = True
|
||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||
mock_call.return_value = mock_response
|
||||
|
||||
result = await llm_adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="YYY",
|
||||
brand_aliases=[]
|
||||
)
|
||||
result = await llm_adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="YYY",
|
||||
brand_aliases=[]
|
||||
)
|
||||
|
||||
assert result.sentiment == "positive"
|
||||
assert result.sentiment == "positive"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_adapter_sentiment_negative(self, llm_adapter):
|
||||
|
|
@ -95,16 +101,18 @@ class TestLLMAdapter:
|
|||
"confidence": 0.88
|
||||
}
|
||||
|
||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||
mock_call.return_value = mock_response
|
||||
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||
mock_settings.ENABLE_LLM = True
|
||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||
mock_call.return_value = mock_response
|
||||
|
||||
result = await llm_adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="ZZZ",
|
||||
brand_aliases=[]
|
||||
)
|
||||
result = await llm_adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="ZZZ",
|
||||
brand_aliases=[]
|
||||
)
|
||||
|
||||
assert result.sentiment == "negative"
|
||||
assert result.sentiment == "negative"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_adapter_api_error_retry(self, llm_adapter):
|
||||
|
|
@ -117,39 +125,41 @@ class TestLLMAdapter:
|
|||
"confidence": 0.90
|
||||
}
|
||||
|
||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||
# 模拟前两次失败,第三次成功
|
||||
mock_call.side_effect = [
|
||||
Exception("API调用失败"),
|
||||
Exception("API调用失败"),
|
||||
mock_success_response
|
||||
]
|
||||
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||
mock_settings.ENABLE_LLM = True
|
||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||
mock_call.side_effect = [
|
||||
Exception("API调用失败"),
|
||||
Exception("API调用失败"),
|
||||
mock_success_response
|
||||
]
|
||||
|
||||
result = await llm_adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="测试品牌",
|
||||
brand_aliases=[]
|
||||
)
|
||||
|
||||
assert result.cited is True
|
||||
assert mock_call.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_adapter_parse_error(self, llm_adapter):
|
||||
"""测试响应解析错误"""
|
||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||
mock_call.return_value = {"invalid": "response"}
|
||||
|
||||
with pytest.raises(LLMAdapterError) as exc_info:
|
||||
await llm_adapter.query_brand_citation(
|
||||
result = await llm_adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="测试品牌",
|
||||
brand_aliases=[]
|
||||
)
|
||||
|
||||
# 错误消息应该包含字段缺失或解析失败相关提示
|
||||
error_msg = str(exc_info.value)
|
||||
assert "响应缺少必需字段" in error_msg or "解析响应失败" in error_msg
|
||||
assert result.cited is True
|
||||
assert mock_call.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_adapter_parse_error(self, llm_adapter):
|
||||
"""测试响应解析错误"""
|
||||
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||
mock_settings.ENABLE_LLM = True
|
||||
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
|
||||
mock_call.return_value = {"invalid": "response"}
|
||||
|
||||
with pytest.raises(LLMAdapterError) as exc_info:
|
||||
await llm_adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="测试品牌",
|
||||
brand_aliases=[]
|
||||
)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "响应缺少必需字段" in error_msg or "解析响应失败" in error_msg
|
||||
|
||||
def test_build_prompt(self, llm_adapter):
|
||||
"""测试Prompt构建"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,121 @@
|
|||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, PropertyMock
|
||||
|
||||
from app.workers.llm_adapter import LLMAdapter, LLMAdapterError
|
||||
|
||||
|
||||
class TestLLMAdapterNoMock:
|
||||
"""验证LLMAdapter不再返回Mock数据,而是抛出明确错误"""
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(self):
|
||||
return LLMAdapter()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enable_llm_false_raises_error(self, adapter):
|
||||
"""ENABLE_LLM=False时必须抛出LLMAdapterError,而非返回Mock数据"""
|
||||
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||
mock_settings.ENABLE_LLM = False
|
||||
mock_settings.DEEPSEEK_API_KEY = "test-key"
|
||||
|
||||
with pytest.raises(LLMAdapterError) as exc_info:
|
||||
await adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="测试品牌",
|
||||
brand_aliases=["别名1"],
|
||||
)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "ENABLE_LLM" in error_msg
|
||||
assert "未启用" in error_msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enable_llm_true_no_api_key_raises_error(self, adapter):
|
||||
"""ENABLE_LLM=True但无API Key时必须抛出LLMAdapterError"""
|
||||
adapter.api_key = None
|
||||
|
||||
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||
mock_settings.ENABLE_LLM = True
|
||||
|
||||
with pytest.raises(LLMAdapterError) as exc_info:
|
||||
await adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="测试品牌",
|
||||
brand_aliases=[],
|
||||
)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "API Key" in error_msg or "DEEPSEEK_API_KEY" in error_msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enable_llm_true_with_key_calls_api(self, adapter):
|
||||
"""ENABLE_LLM=True且有Key时正常调用API"""
|
||||
mock_response = {
|
||||
"cited": True,
|
||||
"position": 1,
|
||||
"citation_text": "测试引用",
|
||||
"sentiment": "positive",
|
||||
"confidence": 0.95,
|
||||
}
|
||||
|
||||
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||
mock_settings.ENABLE_LLM = True
|
||||
mock_settings.OPENAI_API_KEY = None
|
||||
mock_settings.DEEPSEEK_API_KEY = "sk-test-key"
|
||||
|
||||
with patch.object(
|
||||
adapter, "_call_deepseek", new_callable=AsyncMock
|
||||
) as mock_call:
|
||||
mock_call.return_value = mock_response
|
||||
|
||||
result = await adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="测试品牌",
|
||||
brand_aliases=[],
|
||||
)
|
||||
|
||||
assert result.cited is True
|
||||
assert result.position == 1
|
||||
assert result.sentiment == "positive"
|
||||
|
||||
def test_get_mock_result_method_removed(self):
|
||||
"""_get_mock_result方法必须已被删除"""
|
||||
assert not hasattr(LLMAdapter, "_get_mock_result"), (
|
||||
"_get_mock_result方法仍然存在,必须删除"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_message_user_friendly(self, adapter):
|
||||
"""错误信息必须对用户友好,包含配置指引"""
|
||||
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||
mock_settings.ENABLE_LLM = False
|
||||
mock_settings.DEEPSEEK_API_KEY = "test-key"
|
||||
|
||||
with pytest.raises(LLMAdapterError) as exc_info:
|
||||
await adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="测试品牌",
|
||||
brand_aliases=[],
|
||||
)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "ENABLE_LLM=True" in error_msg
|
||||
assert "DEEPSEEK_API_KEY" in error_msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_api_key_error_message_user_friendly(self, adapter):
|
||||
"""无API Key时错误信息必须包含配置指引"""
|
||||
adapter.api_key = None
|
||||
|
||||
with patch("app.workers.llm_adapter.settings") as mock_settings:
|
||||
mock_settings.ENABLE_LLM = True
|
||||
|
||||
with pytest.raises(LLMAdapterError) as exc_info:
|
||||
await adapter.query_brand_citation(
|
||||
keyword="AI搜索",
|
||||
brand_name="测试品牌",
|
||||
brand_aliases=[],
|
||||
)
|
||||
|
||||
error_msg = str(exc_info.value)
|
||||
assert "DEEPSEEK_API_KEY" in error_msg
|
||||
|
|
@ -31,7 +31,7 @@ import {
|
|||
Zap,
|
||||
} from "lucide-react";
|
||||
import { useApi, useApiMutation } from "@/lib/hooks/use-api";
|
||||
import { MOCK_AI_ENGINES_RESPONSE } from "@/lib/api/ai-engines";
|
||||
|
||||
import type {
|
||||
AIEngineType,
|
||||
AIQueryResult,
|
||||
|
|
@ -446,10 +446,12 @@ export default function AIEnginesPage() {
|
|||
if (result) {
|
||||
setQueryResults(result);
|
||||
} else {
|
||||
setQueryResults(MOCK_AI_ENGINES_RESPONSE);
|
||||
setQueryError("查询返回空结果,请检查API Key配置");
|
||||
setQueryResults(null);
|
||||
}
|
||||
} catch {
|
||||
setQueryResults(MOCK_AI_ENGINES_RESPONSE);
|
||||
} catch (err) {
|
||||
setQueryError(err instanceof Error ? err.message : "查询失败,请检查API Key配置");
|
||||
setQueryResults(null);
|
||||
}
|
||||
}, [selectedBrandId, queryText, selectedEngines, queryMutation]);
|
||||
|
||||
|
|
|
|||
|
|
@ -47,72 +47,6 @@ import {
|
|||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
|
||||
const MOCK_ORG_INFO: OrganizationInfo = {
|
||||
id: "org-1",
|
||||
name: "GEO科技有限公司",
|
||||
member_count: 5,
|
||||
created_at: "2024-01-15T08:00:00Z",
|
||||
updated_at: "2024-01-15T08:00:00Z",
|
||||
};
|
||||
|
||||
const MOCK_MEMBERS: OrganizationMember[] = [
|
||||
{
|
||||
id: "member-1",
|
||||
user_id: "user-1",
|
||||
name: "张三",
|
||||
email: "zhangsan@example.com",
|
||||
role: "admin",
|
||||
status: "active",
|
||||
joined_at: "2024-01-15T08:00:00Z",
|
||||
created_at: "2024-01-15T08:00:00Z",
|
||||
updated_at: "2024-01-15T08:00:00Z",
|
||||
},
|
||||
{
|
||||
id: "member-2",
|
||||
user_id: "user-2",
|
||||
name: "李四",
|
||||
email: "lisi@example.com",
|
||||
role: "member",
|
||||
status: "active",
|
||||
joined_at: "2024-02-10T10:30:00Z",
|
||||
created_at: "2024-02-10T10:30:00Z",
|
||||
updated_at: "2024-02-10T10:30:00Z",
|
||||
},
|
||||
{
|
||||
id: "member-3",
|
||||
user_id: "user-3",
|
||||
name: "王五",
|
||||
email: "wangwu@example.com",
|
||||
role: "viewer",
|
||||
status: "active",
|
||||
joined_at: "2024-03-05T14:20:00Z",
|
||||
created_at: "2024-03-05T14:20:00Z",
|
||||
updated_at: "2024-03-05T14:20:00Z",
|
||||
},
|
||||
{
|
||||
id: "member-4",
|
||||
user_id: "user-4",
|
||||
name: "赵六",
|
||||
email: "zhaoliu@example.com",
|
||||
role: "member",
|
||||
status: "pending",
|
||||
joined_at: "2024-03-20T09:15:00Z",
|
||||
created_at: "2024-03-20T09:15:00Z",
|
||||
updated_at: "2024-03-20T09:15:00Z",
|
||||
},
|
||||
{
|
||||
id: "member-5",
|
||||
user_id: "user-5",
|
||||
name: "孙七",
|
||||
email: "sunqi@example.com",
|
||||
role: "viewer",
|
||||
status: "inactive",
|
||||
joined_at: "2024-01-20T16:45:00Z",
|
||||
created_at: "2024-01-20T16:45:00Z",
|
||||
updated_at: "2024-01-20T16:45:00Z",
|
||||
},
|
||||
];
|
||||
|
||||
const roleConfig: Record<MemberRole, { label: string; icon: React.ReactNode; color: string }> = {
|
||||
admin: {
|
||||
label: "管理员",
|
||||
|
|
@ -185,7 +119,7 @@ export default function ClientsPage() {
|
|||
} = useApi<OrganizationMember[]>("/api/v1/organization/members");
|
||||
|
||||
const filteredMembers = useMemo(() => {
|
||||
const memberList = members || MOCK_MEMBERS;
|
||||
const memberList = members || [];
|
||||
return memberList.filter((member) => {
|
||||
const matchesSearch =
|
||||
!searchQuery ||
|
||||
|
|
@ -196,7 +130,7 @@ export default function ClientsPage() {
|
|||
});
|
||||
}, [members, searchQuery, roleFilter]);
|
||||
|
||||
const safeOrgInfo = orgInfo || MOCK_ORG_INFO;
|
||||
const safeOrgInfo = orgInfo ?? null;
|
||||
const loading = orgLoading || membersLoading;
|
||||
|
||||
const handleInvite = async () => {
|
||||
|
|
@ -297,26 +231,30 @@ export default function ClientsPage() {
|
|||
</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="grid gap-4 md:grid-cols-3">
|
||||
<div>
|
||||
<p className="text-sm text-gray-500">组织名称</p>
|
||||
<p className="mt-1 text-lg font-semibold text-gray-900">
|
||||
{safeOrgInfo.name}
|
||||
</p>
|
||||
{safeOrgInfo ? (
|
||||
<div className="grid gap-4 md:grid-cols-3">
|
||||
<div>
|
||||
<p className="text-sm text-gray-500">组织名称</p>
|
||||
<p className="mt-1 text-lg font-semibold text-gray-900">
|
||||
{safeOrgInfo.name}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<p className="text-sm text-gray-500">成员数量</p>
|
||||
<p className="mt-1 text-lg font-semibold text-gray-900">
|
||||
{safeOrgInfo.member_count} 人
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<p className="text-sm text-gray-500">创建时间</p>
|
||||
<p className="mt-1 text-lg font-semibold text-gray-900">
|
||||
{formatDate(safeOrgInfo.created_at)}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<p className="text-sm text-gray-500">成员数量</p>
|
||||
<p className="mt-1 text-lg font-semibold text-gray-900">
|
||||
{safeOrgInfo.member_count} 人
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<p className="text-sm text-gray-500">创建时间</p>
|
||||
<p className="mt-1 text-lg font-semibold text-gray-900">
|
||||
{formatDate(safeOrgInfo.created_at)}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<p className="text-sm text-muted-foreground">组织信息加载中...</p>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import Link from "next/link";
|
|||
import {
|
||||
MetricCard,
|
||||
StageProgress,
|
||||
AgentStatusCard,
|
||||
} from "@/components/business";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
|
|
@ -34,30 +33,6 @@ const STAGE_CONFIG = [
|
|||
{ id: "monitoring", label: "监测优化" },
|
||||
];
|
||||
|
||||
const MOCK_AGENTS = [
|
||||
{
|
||||
name: "内容生成Agent",
|
||||
description: "自动化内容生产",
|
||||
status: "busy" as const,
|
||||
lastActiveAt: "2分钟前",
|
||||
completedCount: 156,
|
||||
},
|
||||
{
|
||||
name: "引用监测Agent",
|
||||
description: "AI平台引用追踪",
|
||||
status: "online" as const,
|
||||
lastActiveAt: "刚刚",
|
||||
completedCount: 3420,
|
||||
},
|
||||
{
|
||||
name: "SEO诊断Agent",
|
||||
description: "搜索引擎优化分析",
|
||||
status: "offline" as const,
|
||||
lastActiveAt: "3小时前",
|
||||
completedCount: 89,
|
||||
},
|
||||
];
|
||||
|
||||
function buildStages(currentStage: GeoProject["current_stage"]) {
|
||||
const currentIndex = STAGE_CONFIG.findIndex((s) => s.id === currentStage);
|
||||
return STAGE_CONFIG.map((stage, idx) => {
|
||||
|
|
@ -331,17 +306,10 @@ export default function DashboardPage() {
|
|||
查看全部
|
||||
</Link>
|
||||
</div>
|
||||
<div className="space-y-3">
|
||||
{MOCK_AGENTS.map((agent) => (
|
||||
<AgentStatusCard
|
||||
key={agent.name}
|
||||
name={agent.name}
|
||||
description={agent.description}
|
||||
status={agent.status}
|
||||
lastActiveAt={agent.lastActiveAt}
|
||||
completedCount={agent.completedCount}
|
||||
/>
|
||||
))}
|
||||
<div className="flex flex-col items-center justify-center py-8 text-center">
|
||||
<Zap className="h-8 w-8 text-muted-foreground mb-3" />
|
||||
<p className="text-sm font-medium text-muted-foreground">功能开发中</p>
|
||||
<p className="text-xs text-muted-foreground mt-1">Agent状态监控即将上线</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
Loading…
Reference in New Issue