fix: 消除所有Mock/Stub/假数据,确保业务流程使用真实数据

M1-引用检测核心:
- 删除llm_adapter._get_mock_result()方法
- ENABLE_LLM=False时抛出LLMAdapterError而非返回随机数据
- ENABLE_LLM默认值改为True
- 修复旧测试适配新行为

M2-知识库RAG:
- knowledge.py不再默认使用MockEmbedder
- 动态从APIKeyManager获取OpenAI Key
- 无Key时返回503+明确错误信息
- 有Key时使用OpenAIEmbedder

M3-AI引擎页面:
- 删除MOCK_AI_ENGINES_RESPONSE fallback
- 查询失败时显示错误状态

M4-组织管理页面:
- 删除MOCK_ORG_INFO和MOCK_MEMBERS
- API返回空时显示空状态

M5-首页Agent卡片:
- 删除MOCK_AGENTS硬编码
- 替换为功能开发中占位

M6-平台规则历史:
- 实现PlatformRuleVersion模型
- 实现版本对比API (diff)
- 实现历史记录查询API (history)
- 删除2个TODO注释

M7-知识图谱批量构建:
- 实现批量创建实体API
- 空输入验证+批量大小限制
- 删除TODO注释
- 修复路由双重前缀问题

测试: 643 passed (核心)
This commit is contained in:
chiguyong 2026-05-25 21:51:48 +08:00
parent 4cc8f73bb4
commit fe4ba39514
17 changed files with 1084 additions and 253 deletions

View File

@ -34,16 +34,27 @@ from app.schemas.knowledge import (
SearchResultItem,
UpdateDocumentRequest,
)
from app.services.knowledge import MockEmbedder, RAGService
from app.services.knowledge import RAGService
from app.services.knowledge.embedder import OpenAIEmbedder
from app.services.knowledge.enhanced_rag import EnhancedRAG
from app.services.knowledge.incremental_index import IncrementalIndexService
from app.services.knowledge.chunker import ChunkerFactory
from app.services.api_key_manager import APIKeyManager
logger = logging.getLogger(__name__)
router = APIRouter()
# Shared RAG service instance (MockEmbedder by default; swap in OpenAIEmbedder via DI later)
_rag_service = RAGService(embedder=MockEmbedder())
_key_manager = APIKeyManager()
def _get_rag_service() -> RAGService:
api_key = _key_manager.get_key("chatgpt")
if api_key:
return RAGService(embedder=OpenAIEmbedder(api_key=api_key))
raise HTTPException(
status_code=503,
detail="知识库功能需要配置OpenAI API Key。请在设置页面添加OpenAI API Key。",
)
# ---------------------------------------------------------------------------
@ -280,8 +291,11 @@ async def upload_document(
# Asynchronously ingest (same request; background task optimization later)
try:
await _rag_service.ingest_document(db, str(doc.id))
rag_service = _get_rag_service()
await rag_service.ingest_document(db, str(doc.id))
await db.refresh(doc)
except HTTPException:
raise
except Exception as exc:
logger.error(f"Ingest failed for document {doc.id}: {exc}")
# Status already set to 'failed' by ingest_document on exception
@ -359,7 +373,7 @@ async def delete_document(
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found")
# Delete chunks first (cascade also handles this, but explicit for clarity)
await _rag_service.delete_document_chunks(db, str(doc.id))
await _get_rag_service().delete_document_chunks(db, str(doc.id))
await db.delete(doc)
@ -466,7 +480,7 @@ async def knowledge_search(
t0 = time.monotonic()
raw_results = await _rag_service.search(
raw_results = await _get_rag_service().search(
db,
query=body.query,
knowledge_base_ids=body.knowledge_base_ids,
@ -559,7 +573,7 @@ async def reindex_document(
await _get_kb(db, kb_id, org_id)
index_service = IncrementalIndexService(_rag_service)
index_service = IncrementalIndexService(_get_rag_service())
result = await index_service.add_document(
db, str(kb_id), str(doc_id)
)
@ -581,7 +595,7 @@ async def update_document_content(
await _get_kb(db, kb_id, org_id)
index_service = IncrementalIndexService(_rag_service)
index_service = IncrementalIndexService(_get_rag_service())
result = await index_service.update_document(
db, str(doc_id), request.content
)
@ -602,7 +616,7 @@ async def delete_document_incremental(
await _get_kb(db, kb_id, org_id)
index_service = IncrementalIndexService(_rag_service)
index_service = IncrementalIndexService(_get_rag_service())
result = await index_service.delete_document(db, str(doc_id))
return result
@ -621,7 +635,7 @@ async def rebuild_knowledge_base(
await _get_kb(db, kb_id, org_id)
index_service = IncrementalIndexService(_rag_service)
index_service = IncrementalIndexService(_get_rag_service())
result = await index_service.rebuild_knowledge_base(
db, str(kb_id), force
)
@ -642,7 +656,8 @@ async def enhanced_retrieve(
await _get_kb(db, kb_id, org_id)
enhanced_rag = EnhancedRAG(_rag_service, _rag_service.embedder)
rag_service = _get_rag_service()
enhanced_rag = EnhancedRAG(rag_service, rag_service.embedder)
results = await enhanced_rag.retrieve_with_rerank(
db,
request.query,

View File

@ -3,9 +3,12 @@ from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db, get_current_user
from app.models.knowledge import KnowledgeBase
from app.models.knowledge_graph import KnowledgeEntity, EntityType
from app.models.user import User
from app.services.knowledge.graph_builder import GraphBuilder
from app.services.knowledge.graph_query import GraphQuery
@ -13,6 +16,61 @@ from app.services.knowledge.graph_query import GraphQuery
router = APIRouter(prefix="/knowledge-bases", tags=["知识图谱"])
class EntityCreateRequest(BaseModel):
name: str = Field(..., max_length=500)
entity_type: str
description: Optional[str] = None
properties: Optional[dict] = None
def _entity_to_dict(entity: KnowledgeEntity) -> dict:
return {
"id": str(entity.id),
"name": entity.name,
"entity_type": entity.entity_type.value,
"description": entity.description,
"properties": entity.properties,
"confidence": entity.confidence,
}
@router.post("/{kb_id}/entities/batch")
async def batch_create_entities(
kb_id: UUID,
entities: list[EntityCreateRequest],
db: AsyncSession = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""批量创建知识图谱实体"""
if not entities:
raise HTTPException(status_code=400, detail="实体列表不能为空")
if len(entities) > 100:
raise HTTPException(status_code=400, detail="单次批量创建不能超过100个实体")
kb = await db.get(KnowledgeBase, kb_id)
if not kb:
raise HTTPException(status_code=404, detail="知识库不存在")
created = []
for entity_req in entities:
entity = KnowledgeEntity(
knowledge_base_id=kb_id,
name=entity_req.name,
entity_type=EntityType(entity_req.entity_type),
description=entity_req.description,
properties=entity_req.properties or {},
)
db.add(entity)
created.append(entity)
await db.commit()
for entity in created:
await db.refresh(entity)
return {"created_count": len(created), "entities": [_entity_to_dict(e) for e in created]}
@router.post("/{kb_id}/graph/build")
async def build_graph(
kb_id: UUID,
@ -24,8 +82,6 @@ async def build_graph(
对知识库中的所有Chunks执行实体和关系抽取
"""
# TODO: 实现批量构建
# 目前先实现单个Chunk的构建
return {"message": "Use /graph/build-chunk to build from specific chunk"}

View File

@ -2,12 +2,16 @@
import logging
from datetime import datetime
from typing import Optional
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import ValidationError
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.platform_rule_version import PlatformRuleVersion
from app.models.user import User
from app.services.distribution.platform_rules import (
PLATFORM_RULES,
@ -48,6 +52,46 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/platforms", tags=["平台规则管理"])
async def _get_rule_version(
db: AsyncSession, rule_id: str, version: int
) -> PlatformRuleVersion | None:
stmt = select(PlatformRuleVersion).where(
PlatformRuleVersion.rule_id == rule_id,
PlatformRuleVersion.version == version,
)
result = await db.execute(stmt)
return result.scalar_one_or_none()
def _compute_diff(
old_data: dict, new_data: dict, prefix: str = ""
) -> list[RuleDiff]:
diffs: list[RuleDiff] = []
all_keys = set(old_data.keys()) | set(new_data.keys())
for key in sorted(all_keys):
field = f"{prefix}{key}" if not prefix else f"{prefix}.{key}"
old_val = old_data.get(key)
new_val = new_data.get(key)
if isinstance(old_val, dict) and isinstance(new_val, dict):
diffs.extend(_compute_diff(old_val, new_val, field))
elif old_val != new_val:
diffs.append(RuleDiff(field=field, old_value=old_val, new_value=new_val))
return diffs
def _version_to_dict(v: PlatformRuleVersion) -> dict:
return {
"id": v.id,
"rule_id": v.rule_id,
"platform": v.platform,
"version": v.version,
"rule_data": v.rule_data,
"change_summary": v.change_summary,
"created_by": v.created_by,
"created_at": v.created_at.isoformat() if v.created_at else None,
}
def _convert_rule_to_schema(rules: dict) -> dict:
"""将规则字典转换为 Schema 格式"""
if not rules:
@ -179,13 +223,16 @@ async def update_platform_rules(
@router.get("/{platform_id}/rules/diff", response_model=RuleDiffResponse)
async def compare_rule_changes(
platform_id: str,
change_id: Optional[int] = Query(None, description="变更记录ID用于对比历史版本"),
from_version: int = Query(..., description="起始版本号"),
to_version: int = Query(..., description="目标版本号"),
db: AsyncSession = Depends(get_db),
):
"""对比规则变更
Args:
platform_id: 平台标识
change_id: 变更记录ID可选
from_version: 起始版本号
to_version: 目标版本号
"""
if platform_id not in PLATFORM_RULES:
raise HTTPException(
@ -195,13 +242,21 @@ async def compare_rule_changes(
current_rules = PLATFORM_RULES[platform_id]
# TODO: 从数据库获取历史版本进行对比
# 目前返回空差异
from_rule = await _get_rule_version(db, platform_id, from_version)
to_rule = await _get_rule_version(db, platform_id, to_version)
if not from_rule or not to_rule:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="版本不存在",
)
diffs = _compute_diff(from_rule.rule_data, to_rule.rule_data)
return RuleDiffResponse(
platform_id=platform_id,
platform_name=current_rules.get("name", ""),
diffs=[],
total_changes=0,
diffs=diffs,
total_changes=len(diffs),
)
@ -209,6 +264,7 @@ async def compare_rule_changes(
async def get_rule_history(
platform_id: str,
limit: int = Query(20, ge=1, le=100, description="返回记录数"),
db: AsyncSession = Depends(get_db),
):
"""获取规则变更历史
@ -222,11 +278,39 @@ async def get_rule_history(
detail=f"平台不存在: {platform_id}",
)
# TODO: 从数据库获取历史记录
# 目前返回空列表
count_stmt = select(func.count()).select_from(PlatformRuleVersion).where(
PlatformRuleVersion.rule_id == platform_id
)
total = (await db.execute(count_stmt)).scalar() or 0
stmt = (
select(PlatformRuleVersion)
.where(PlatformRuleVersion.rule_id == platform_id)
.order_by(PlatformRuleVersion.version.desc())
.limit(limit)
)
result = await db.execute(stmt)
versions = result.scalars().all()
history = [
RuleChangeHistory(
id=v.version,
version=v.version,
platform_id=v.rule_id,
platform_name=v.platform,
changed_by=v.created_by or "",
change_summary=v.change_summary or "",
change_type="update",
previous_rules=None,
new_rules=v.rule_data,
created_at=v.created_at,
)
for v in versions
]
return RuleChangeHistoryResponse(
history=[],
total=0,
history=history,
total=total,
)

View File

@ -23,7 +23,7 @@ class Settings(BaseSettings):
SECRET_KEY: Optional[str] = None
PLAYWRIGHT_BROWSERS_PATH: str = "/ms-playwright"
ENABLE_LLM: bool = False
ENABLE_LLM: bool = True
ZHIPU_API_KEY: str = ""
TONGYI_API_KEY: str = ""
CORS_ORIGINS: str = "http://localhost:3000,http://localhost:3001"

View File

@ -168,7 +168,7 @@ app.include_router(onboarding_router, prefix="/api/v1")
app.include_router(platforms_router, prefix="/api/v1")
app.include_router(platform_rules_router)
app.include_router(image_router, prefix="/api/v1")
app.include_router(knowledge_graph_router, prefix="/api/v1/knowledge-bases")
app.include_router(knowledge_graph_router, prefix="/api/v1")
app.include_router(ai_engines_router, prefix="/api/v1/ai-engines", tags=["AI引擎查询"])
app.include_router(detection_router, prefix="/api/v1/detection", tags=["定时检测任务"])
app.include_router(api_keys_router, prefix="/api/v1/api-keys", tags=["API Key管理"])

View File

@ -9,6 +9,7 @@ from app.models.lifecycle import LifecycleProject, ProjectStage
from app.models.agent import AgentRegistry, AgentConfig, AgentTask, AgentTaskLog
from app.models.content import Content, ContentVersion, ContentReview
from app.models.platform_rule import PlatformRule
from app.models.platform_rule_version import PlatformRuleVersion
from app.models.brand_knowledge import BrandKnowledge, Keyword
from app.models.knowledge import (
KnowledgeBase,
@ -52,6 +53,7 @@ __all__ = [
"ContentVersion",
"ContentReview",
"PlatformRule",
"PlatformRuleVersion",
"BrandKnowledge",
"Keyword",
"KnowledgeBase",

View File

@ -0,0 +1,30 @@
import uuid
from datetime import datetime
from sqlalchemy import String, Integer, Index, func
from sqlalchemy import Uuid
from sqlalchemy.orm import Mapped, mapped_column
from app.database import Base, JSONType
class PlatformRuleVersion(Base):
__tablename__ = "platform_rule_versions"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
rule_id: Mapped[str] = mapped_column(String(100), nullable=False, index=True)
platform: Mapped[str] = mapped_column(String(50), nullable=False)
version: Mapped[int] = mapped_column(Integer, nullable=False)
rule_data: Mapped[dict] = mapped_column(JSONType, nullable=False)
change_summary: Mapped[str | None] = mapped_column(String(500), nullable=True)
created_by: Mapped[str | None] = mapped_column(String(100), nullable=True)
created_at: Mapped[datetime] = mapped_column(server_default=func.now())
__table_args__ = (
Index("idx_rule_versions_rule_id", "rule_id"),
Index("idx_rule_versions_platform", "platform"),
)

View File

@ -1,7 +1,7 @@
"""平台规则管理 Schema - 定义规则管理的请求响应结构"""
from datetime import datetime
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel, Field
@ -225,6 +225,7 @@ class PlatformRuleUpdateResponse(BaseModel):
class RuleChangeHistory(BaseModel):
"""规则变更历史"""
id: int
version: int = 0
platform_id: str
platform_name: str
changed_by: str
@ -305,3 +306,7 @@ class DeAIContentResponse(BaseModel):
processed_word_count: int
detected_ai_patterns: list[str] = []
replaced_patterns: dict[str, str] = {}
RuleDiff.model_rebuild()
RuleDiffResponse.model_rebuild()

View File

@ -4,7 +4,6 @@ LLM适配器 - 使用DeepSeek LLM API检测品牌引用
import asyncio
import json
import logging
import random
import re
from typing import Optional
@ -104,8 +103,14 @@ class LLMAdapter:
LLMAdapterError: API调用或解析失败
"""
if not settings.ENABLE_LLM:
logger.info("LLM调用已禁用 (ENABLE_LLM=False),返回模拟数据")
return self._get_mock_result(keyword, brand_name, brand_aliases)
raise LLMAdapterError(
"LLM引用检测未启用。请在环境变量中设置 ENABLE_LLM=True 并配置 DEEPSEEK_API_KEY"
)
if not self.api_key:
raise LLMAdapterError(
"未配置DeepSeek API Key。请设置 DEEPSEEK_API_KEY 环境变量"
)
prompt = self._build_prompt(keyword, brand_name, brand_aliases)
@ -123,36 +128,6 @@ class LLMAdapter:
raise LLMAdapterError(f"LLM API调用失败已重试{self.max_retries}次: {last_error}")
def _get_mock_result(
self,
keyword: str,
brand_name: str,
brand_aliases: list[str]
) -> CitationResult:
"""
生成模拟结果当LLM禁用时使用
随机决定是否引用模拟真实场景的数据分布
"""
cited = random.random() < 0.6
sentiment_options = ["positive", "neutral", "negative"]
sentiment = random.choice(sentiment_options)
if cited:
position = random.randint(1, 10)
citation_text = f'模拟引用:在搜索"{keyword}"时,提到了{brand_name}品牌及其相关产品。'
else:
position = None
citation_text = ""
return CitationResult(
cited=cited,
position=position,
citation_text=citation_text,
sentiment=sentiment,
confidence=round(random.uniform(0.7, 0.99), 2)
)
async def _call_deepseek(self, prompt: str) -> dict:
"""
调用DeepSeek API

View File

@ -0,0 +1,278 @@
"""知识图谱批量构建API测试"""
import uuid
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.main import app
from app.api.deps import get_db, get_current_user
from app.models.user import User
from app.models.knowledge import KnowledgeBase
from app.models.organization import Organization
@pytest_asyncio.fixture
async def async_engine():
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
async_session_maker = async_sessionmaker(
async_engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)
async with async_session_maker() as session:
yield session
@pytest_asyncio.fixture
async def test_user(async_session):
user = User(
id=uuid.uuid4(),
email="test@example.com",
password_hash="hashed_password",
name="Test User",
plan="free",
max_queries=5,
is_active=True,
email_verified=True,
)
async_session.add(user)
await async_session.commit()
await async_session.refresh(user)
return user
@pytest_asyncio.fixture
async def test_org(async_session):
org = Organization(
id=uuid.uuid4(),
name="Test Org",
slug="test-org",
)
async_session.add(org)
await async_session.commit()
await async_session.refresh(org)
return org
@pytest_asyncio.fixture
async def test_kb(async_session, test_org):
kb = KnowledgeBase(
id=uuid.uuid4(),
organization_id=test_org.id,
name="Test KB",
type="industry",
description="Test knowledge base",
)
async_session.add(kb)
await async_session.commit()
await async_session.refresh(kb)
return kb
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
async def override_get_db():
yield async_session
async def override_get_current_user():
return test_user
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_get_current_user
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
yield client
app.dependency_overrides.clear()
BATCH_URL = "/api/v1/knowledge-bases/{kb_id}/entities/batch"
class TestBatchCreateEntitiesEmptyInput:
"""空输入验证测试"""
@pytest.mark.asyncio
async def test_empty_entities_list_returns_400(self, async_client, test_kb):
response = await async_client.post(
BATCH_URL.format(kb_id=test_kb.id),
json=[],
)
assert response.status_code == 400
data = response.json()
assert "detail" in data
assert "不能为空" in data["detail"]
class TestBatchCreateEntitiesSizeLimit:
"""批量大小限制测试"""
@pytest.mark.asyncio
async def test_over_100_entities_returns_400(self, async_client, test_kb):
entities = [
{
"name": f"Entity {i}",
"entity_type": "CONCEPT",
"description": f"Test entity {i}",
}
for i in range(101)
]
response = await async_client.post(
BATCH_URL.format(kb_id=test_kb.id),
json=entities,
)
assert response.status_code == 400
data = response.json()
assert "detail" in data
assert "100" in data["detail"]
class TestBatchCreateEntitiesKBNotFound:
"""知识库不存在测试"""
@pytest.mark.asyncio
async def test_nonexistent_kb_returns_404(self, async_client):
fake_kb_id = str(uuid.uuid4())
entities = [
{
"name": "Entity 1",
"entity_type": "CONCEPT",
"description": "Test",
}
]
response = await async_client.post(
BATCH_URL.format(kb_id=fake_kb_id),
json=entities,
)
assert response.status_code == 404
data = response.json()
assert "detail" in data
assert "不存在" in data["detail"]
class TestBatchCreateEntitiesSuccess:
"""批量创建成功测试"""
@pytest.mark.asyncio
async def test_batch_create_entities_success(self, async_client, test_kb):
entities = [
{
"name": "公司A",
"entity_type": "ORGANIZATION",
"description": "测试公司A",
"properties": {"industry": "科技"},
},
{
"name": "产品B",
"entity_type": "PRODUCT",
"description": "测试产品B",
},
]
response = await async_client.post(
BATCH_URL.format(kb_id=test_kb.id),
json=entities,
)
assert response.status_code == 200
data = response.json()
assert data["created_count"] == 2
assert len(data["entities"]) == 2
for entity_data in data["entities"]:
assert "id" in entity_data
assert "name" in entity_data
assert "entity_type" in entity_data
@pytest.mark.asyncio
async def test_batch_create_single_entity(self, async_client, test_kb):
entities = [
{
"name": "单个实体",
"entity_type": "PERSON",
"description": "测试单个实体",
}
]
response = await async_client.post(
BATCH_URL.format(kb_id=test_kb.id),
json=entities,
)
assert response.status_code == 200
data = response.json()
assert data["created_count"] == 1
assert len(data["entities"]) == 1
assert data["entities"][0]["name"] == "单个实体"
assert data["entities"][0]["entity_type"] == "PERSON"
@pytest.mark.asyncio
async def test_batch_create_with_properties(self, async_client, test_kb):
entities = [
{
"name": "带属性实体",
"entity_type": "TECHNOLOGY",
"description": "测试带属性",
"properties": {"version": "1.0", "category": "AI"},
}
]
response = await async_client.post(
BATCH_URL.format(kb_id=test_kb.id),
json=entities,
)
assert response.status_code == 200
data = response.json()
assert data["entities"][0]["properties"]["version"] == "1.0"
assert data["entities"][0]["properties"]["category"] == "AI"
@pytest.mark.asyncio
async def test_batch_create_without_properties_defaults_to_empty(
self, async_client, test_kb
):
entities = [
{
"name": "无属性实体",
"entity_type": "BRAND",
"description": "测试无属性",
}
]
response = await async_client.post(
BATCH_URL.format(kb_id=test_kb.id),
json=entities,
)
assert response.status_code == 200
data = response.json()
assert data["entities"][0]["properties"] == {}
@pytest.mark.asyncio
async def test_batch_create_exactly_100_entities(self, async_client, test_kb):
entities = [
{
"name": f"Entity {i}",
"entity_type": "CONCEPT",
"description": f"Test entity {i}",
}
for i in range(100)
]
response = await async_client.post(
BATCH_URL.format(kb_id=test_kb.id),
json=entities,
)
assert response.status_code == 200
data = response.json()
assert data["created_count"] == 100

View File

@ -0,0 +1,95 @@
"""
测试Knowledge API不再默认使用MockEmbedder
- 无OpenAI Key时API返回503+明确错误信息
- 有OpenAI Key时使用OpenAIEmbedder
- MockEmbedder不再作为默认选择
"""
import pytest
from unittest.mock import patch, MagicMock
from app.services.knowledge.embedder import MockEmbedder, OpenAIEmbedder
from app.services.knowledge.rag_service import RAGService
from app.services.api_key_manager import APIKeyManager
class TestKnowledgeAPINoMockEmbedder:
"""验证knowledge.py不再默认使用MockEmbedder"""
def test_get_rag_service_raises_without_openai_key(self):
"""无OpenAI Key时_get_rag_service必须抛出HTTPException"""
from app.api.knowledge import _get_rag_service
from fastapi import HTTPException
key_manager = APIKeyManager()
with patch("app.api.knowledge._key_manager", key_manager):
with pytest.raises(HTTPException) as exc_info:
_get_rag_service()
assert exc_info.value.status_code == 503
assert "OpenAI API Key" in exc_info.value.detail
def test_get_rag_service_returns_openai_embedder_with_key(self):
"""有OpenAI Key时_get_rag_service必须返回使用OpenAIEmbedder的RAGService"""
from app.api.knowledge import _get_rag_service
key_manager = APIKeyManager()
key_manager.add_key("chatgpt", "sk-test-key-1234567890", source="system")
with patch("app.api.knowledge._key_manager", key_manager):
rag_service = _get_rag_service()
assert isinstance(rag_service, RAGService)
assert isinstance(rag_service.embedder, OpenAIEmbedder)
def test_get_rag_service_never_returns_mock_embedder(self):
"""_get_rag_service绝不能返回MockEmbedder"""
from app.api.knowledge import _get_rag_service
from fastapi import HTTPException
key_manager = APIKeyManager()
with patch("app.api.knowledge._key_manager", key_manager):
with pytest.raises(HTTPException):
_get_rag_service()
key_manager_with_key = APIKeyManager()
key_manager_with_key.add_key("chatgpt", "sk-test-key-1234567890", source="system")
with patch("app.api.knowledge._key_manager", key_manager_with_key):
rag_service = _get_rag_service()
assert not isinstance(rag_service.embedder, MockEmbedder)
def test_no_module_level_mock_rag_service(self):
"""模块级别不再存在使用MockEmbedder的_rag_service变量"""
import app.api.knowledge as knowledge_module
assert not hasattr(knowledge_module, "_rag_service"), (
"_rag_service模块级变量仍然存在必须删除"
)
def test_error_message_contains_configuration_guidance(self):
"""503错误信息必须包含配置指引"""
from app.api.knowledge import _get_rag_service
from fastapi import HTTPException
key_manager = APIKeyManager()
with patch("app.api.knowledge._key_manager", key_manager):
with pytest.raises(HTTPException) as exc_info:
_get_rag_service()
detail = exc_info.value.detail
assert "OpenAI API Key" in detail
assert "设置" in detail or "配置" in detail
def test_mock_embedder_class_still_exists(self):
"""MockEmbedder类必须保留仅用于测试"""
assert MockEmbedder is not None
embedder = MockEmbedder()
assert isinstance(embedder, MockEmbedder)
def test_get_rag_service_uses_api_key_manager(self):
"""_get_rag_service必须使用APIKeyManager获取Key"""
from app.api.knowledge import _get_rag_service
mock_km = MagicMock(spec=APIKeyManager)
mock_km.get_key.return_value = "sk-test-key-1234567890"
with patch("app.api.knowledge._key_manager", mock_km):
rag_service = _get_rag_service()
mock_km.get_key.assert_called_once_with("chatgpt")
assert isinstance(rag_service.embedder, OpenAIEmbedder)

View File

@ -0,0 +1,252 @@
"""平台规则历史版本API测试"""
import uuid
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.main import app
from app.models.platform_rule_version import PlatformRuleVersion
from app.api.deps import get_db
@pytest_asyncio.fixture
async def async_engine():
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
async_session_maker = async_sessionmaker(
async_engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)
async with async_session_maker() as session:
yield session
@pytest_asyncio.fixture
async def async_client(async_session):
async def override_get_db():
yield async_session
app.dependency_overrides[get_db] = override_get_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
yield client
app.dependency_overrides.clear()
@pytest_asyncio.fixture
async def seed_versions(async_session):
v1 = PlatformRuleVersion(
id=uuid.uuid4(),
rule_id="zhihu",
platform="zhihu",
version=1,
rule_data={
"content_length": {"min": 800, "max": 3000, "recommended": 1500},
"title_rules": {"min_length": 5, "max_length": 50},
},
change_summary="初始版本",
created_by="admin",
)
v2 = PlatformRuleVersion(
id=uuid.uuid4(),
rule_id="zhihu",
platform="zhihu",
version=2,
rule_data={
"content_length": {"min": 1000, "max": 5000, "recommended": 2000},
"title_rules": {"min_length": 5, "max_length": 50},
},
change_summary="调整内容长度规则",
created_by="admin",
)
v3 = PlatformRuleVersion(
id=uuid.uuid4(),
rule_id="zhihu",
platform="zhihu",
version=3,
rule_data={
"content_length": {"min": 1000, "max": 5000, "recommended": 2000},
"title_rules": {"min_length": 8, "max_length": 60},
},
change_summary="调整标题长度规则",
created_by="editor",
)
async_session.add_all([v1, v2, v3])
await async_session.commit()
return [v1, v2, v3]
class TestRuleVersionDiff:
"""历史版本对比API测试"""
@pytest.mark.asyncio
async def test_diff_returns_differences_between_versions(
self, async_client, seed_versions
):
response = await async_client.get(
"/api/v1/platforms/zhihu/rules/diff",
params={"from_version": 1, "to_version": 2},
)
assert response.status_code == 200
data = response.json()
assert data["platform_id"] == "zhihu"
assert isinstance(data["diffs"], list)
assert len(data["diffs"]) > 0
assert data["total_changes"] > 0
diff_fields = {d["field"] for d in data["diffs"]}
assert "content_length.min" in diff_fields
assert "content_length.max" in diff_fields
assert "content_length.recommended" in diff_fields
@pytest.mark.asyncio
async def test_diff_no_changes_returns_empty(self, async_client, seed_versions):
response = await async_client.get(
"/api/v1/platforms/zhihu/rules/diff",
params={"from_version": 2, "to_version": 2},
)
assert response.status_code == 200
data = response.json()
assert data["diffs"] == []
assert data["total_changes"] == 0
@pytest.mark.asyncio
async def test_diff_nonexistent_version_returns_404(
self, async_client, seed_versions
):
response = await async_client.get(
"/api/v1/platforms/zhihu/rules/diff",
params={"from_version": 1, "to_version": 99},
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_diff_nonexistent_platform_returns_404(self, async_client):
response = await async_client.get(
"/api/v1/platforms/nonexistent_platform/rules/diff",
params={"from_version": 1, "to_version": 2},
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_diff_missing_params_returns_error(self, async_client, seed_versions):
response = await async_client.get(
"/api/v1/platforms/zhihu/rules/diff",
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_diff_nested_field_change(self, async_client, seed_versions):
response = await async_client.get(
"/api/v1/platforms/zhihu/rules/diff",
params={"from_version": 2, "to_version": 3},
)
assert response.status_code == 200
data = response.json()
diff_fields = {d["field"] for d in data["diffs"]}
assert "title_rules.min_length" in diff_fields
assert "title_rules.max_length" in diff_fields
class TestRuleHistory:
"""历史记录查询API测试"""
@pytest.mark.asyncio
async def test_history_returns_versions(
self, async_client, seed_versions
):
response = await async_client.get(
"/api/v1/platforms/zhihu/rules/history",
)
assert response.status_code == 200
data = response.json()
assert data["total"] == 3
assert len(data["history"]) == 3
@pytest.mark.asyncio
async def test_history_ordered_by_version_desc(
self, async_client, seed_versions
):
response = await async_client.get(
"/api/v1/platforms/zhihu/rules/history",
)
assert response.status_code == 200
data = response.json()
versions = [h["version"] for h in data["history"]]
assert versions == sorted(versions, reverse=True)
@pytest.mark.asyncio
async def test_history_respects_limit(self, async_client, seed_versions):
response = await async_client.get(
"/api/v1/platforms/zhihu/rules/history",
params={"limit": 2},
)
assert response.status_code == 200
data = response.json()
assert len(data["history"]) == 2
assert data["total"] == 3
@pytest.mark.asyncio
async def test_history_empty_when_no_versions(self, async_client):
response = await async_client.get(
"/api/v1/platforms/zhihu/rules/history",
)
assert response.status_code == 200
data = response.json()
assert data["total"] == 0
assert data["history"] == []
@pytest.mark.asyncio
async def test_history_nonexistent_platform_returns_404(self, async_client):
response = await async_client.get(
"/api/v1/platforms/nonexistent_platform/rules/history",
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_history_version_has_required_fields(
self, async_client, seed_versions
):
response = await async_client.get(
"/api/v1/platforms/zhihu/rules/history",
)
assert response.status_code == 200
data = response.json()
first = data["history"][0]
assert "version" in first
assert "new_rules" in first
assert "change_summary" in first
assert "changed_by" in first
assert "created_at" in first

View File

@ -22,6 +22,8 @@ class TestLLMAdapter:
"confidence": 0.95
}
with patch("app.workers.llm_adapter.settings") as mock_settings:
mock_settings.ENABLE_LLM = True
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
mock_call.return_value = mock_response
@ -48,6 +50,8 @@ class TestLLMAdapter:
"confidence": 0.90
}
with patch("app.workers.llm_adapter.settings") as mock_settings:
mock_settings.ENABLE_LLM = True
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
mock_call.return_value = mock_response
@ -73,6 +77,8 @@ class TestLLMAdapter:
"confidence": 0.92
}
with patch("app.workers.llm_adapter.settings") as mock_settings:
mock_settings.ENABLE_LLM = True
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
mock_call.return_value = mock_response
@ -95,6 +101,8 @@ class TestLLMAdapter:
"confidence": 0.88
}
with patch("app.workers.llm_adapter.settings") as mock_settings:
mock_settings.ENABLE_LLM = True
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
mock_call.return_value = mock_response
@ -117,8 +125,9 @@ class TestLLMAdapter:
"confidence": 0.90
}
with patch("app.workers.llm_adapter.settings") as mock_settings:
mock_settings.ENABLE_LLM = True
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
# 模拟前两次失败,第三次成功
mock_call.side_effect = [
Exception("API调用失败"),
Exception("API调用失败"),
@ -137,6 +146,8 @@ class TestLLMAdapter:
@pytest.mark.asyncio
async def test_llm_adapter_parse_error(self, llm_adapter):
"""测试响应解析错误"""
with patch("app.workers.llm_adapter.settings") as mock_settings:
mock_settings.ENABLE_LLM = True
with patch.object(llm_adapter, '_call_deepseek', new_callable=AsyncMock) as mock_call:
mock_call.return_value = {"invalid": "response"}
@ -147,7 +158,6 @@ class TestLLMAdapter:
brand_aliases=[]
)
# 错误消息应该包含字段缺失或解析失败相关提示
error_msg = str(exc_info.value)
assert "响应缺少必需字段" in error_msg or "解析响应失败" in error_msg

View File

@ -0,0 +1,121 @@
import pytest
from unittest.mock import AsyncMock, patch, PropertyMock
from app.workers.llm_adapter import LLMAdapter, LLMAdapterError
class TestLLMAdapterNoMock:
"""验证LLMAdapter不再返回Mock数据而是抛出明确错误"""
@pytest.fixture
def adapter(self):
return LLMAdapter()
@pytest.mark.asyncio
async def test_enable_llm_false_raises_error(self, adapter):
"""ENABLE_LLM=False时必须抛出LLMAdapterError而非返回Mock数据"""
with patch("app.workers.llm_adapter.settings") as mock_settings:
mock_settings.ENABLE_LLM = False
mock_settings.DEEPSEEK_API_KEY = "test-key"
with pytest.raises(LLMAdapterError) as exc_info:
await adapter.query_brand_citation(
keyword="AI搜索",
brand_name="测试品牌",
brand_aliases=["别名1"],
)
error_msg = str(exc_info.value)
assert "ENABLE_LLM" in error_msg
assert "未启用" in error_msg
@pytest.mark.asyncio
async def test_enable_llm_true_no_api_key_raises_error(self, adapter):
"""ENABLE_LLM=True但无API Key时必须抛出LLMAdapterError"""
adapter.api_key = None
with patch("app.workers.llm_adapter.settings") as mock_settings:
mock_settings.ENABLE_LLM = True
with pytest.raises(LLMAdapterError) as exc_info:
await adapter.query_brand_citation(
keyword="AI搜索",
brand_name="测试品牌",
brand_aliases=[],
)
error_msg = str(exc_info.value)
assert "API Key" in error_msg or "DEEPSEEK_API_KEY" in error_msg
@pytest.mark.asyncio
async def test_enable_llm_true_with_key_calls_api(self, adapter):
"""ENABLE_LLM=True且有Key时正常调用API"""
mock_response = {
"cited": True,
"position": 1,
"citation_text": "测试引用",
"sentiment": "positive",
"confidence": 0.95,
}
with patch("app.workers.llm_adapter.settings") as mock_settings:
mock_settings.ENABLE_LLM = True
mock_settings.OPENAI_API_KEY = None
mock_settings.DEEPSEEK_API_KEY = "sk-test-key"
with patch.object(
adapter, "_call_deepseek", new_callable=AsyncMock
) as mock_call:
mock_call.return_value = mock_response
result = await adapter.query_brand_citation(
keyword="AI搜索",
brand_name="测试品牌",
brand_aliases=[],
)
assert result.cited is True
assert result.position == 1
assert result.sentiment == "positive"
def test_get_mock_result_method_removed(self):
"""_get_mock_result方法必须已被删除"""
assert not hasattr(LLMAdapter, "_get_mock_result"), (
"_get_mock_result方法仍然存在必须删除"
)
@pytest.mark.asyncio
async def test_error_message_user_friendly(self, adapter):
"""错误信息必须对用户友好,包含配置指引"""
with patch("app.workers.llm_adapter.settings") as mock_settings:
mock_settings.ENABLE_LLM = False
mock_settings.DEEPSEEK_API_KEY = "test-key"
with pytest.raises(LLMAdapterError) as exc_info:
await adapter.query_brand_citation(
keyword="AI搜索",
brand_name="测试品牌",
brand_aliases=[],
)
error_msg = str(exc_info.value)
assert "ENABLE_LLM=True" in error_msg
assert "DEEPSEEK_API_KEY" in error_msg
@pytest.mark.asyncio
async def test_no_api_key_error_message_user_friendly(self, adapter):
"""无API Key时错误信息必须包含配置指引"""
adapter.api_key = None
with patch("app.workers.llm_adapter.settings") as mock_settings:
mock_settings.ENABLE_LLM = True
with pytest.raises(LLMAdapterError) as exc_info:
await adapter.query_brand_citation(
keyword="AI搜索",
brand_name="测试品牌",
brand_aliases=[],
)
error_msg = str(exc_info.value)
assert "DEEPSEEK_API_KEY" in error_msg

View File

@ -31,7 +31,7 @@ import {
Zap,
} from "lucide-react";
import { useApi, useApiMutation } from "@/lib/hooks/use-api";
import { MOCK_AI_ENGINES_RESPONSE } from "@/lib/api/ai-engines";
import type {
AIEngineType,
AIQueryResult,
@ -446,10 +446,12 @@ export default function AIEnginesPage() {
if (result) {
setQueryResults(result);
} else {
setQueryResults(MOCK_AI_ENGINES_RESPONSE);
setQueryError("查询返回空结果请检查API Key配置");
setQueryResults(null);
}
} catch {
setQueryResults(MOCK_AI_ENGINES_RESPONSE);
} catch (err) {
setQueryError(err instanceof Error ? err.message : "查询失败请检查API Key配置");
setQueryResults(null);
}
}, [selectedBrandId, queryText, selectedEngines, queryMutation]);

View File

@ -47,72 +47,6 @@ import {
DropdownMenuTrigger,
} from "@/components/ui/dropdown-menu";
const MOCK_ORG_INFO: OrganizationInfo = {
id: "org-1",
name: "GEO科技有限公司",
member_count: 5,
created_at: "2024-01-15T08:00:00Z",
updated_at: "2024-01-15T08:00:00Z",
};
const MOCK_MEMBERS: OrganizationMember[] = [
{
id: "member-1",
user_id: "user-1",
name: "张三",
email: "zhangsan@example.com",
role: "admin",
status: "active",
joined_at: "2024-01-15T08:00:00Z",
created_at: "2024-01-15T08:00:00Z",
updated_at: "2024-01-15T08:00:00Z",
},
{
id: "member-2",
user_id: "user-2",
name: "李四",
email: "lisi@example.com",
role: "member",
status: "active",
joined_at: "2024-02-10T10:30:00Z",
created_at: "2024-02-10T10:30:00Z",
updated_at: "2024-02-10T10:30:00Z",
},
{
id: "member-3",
user_id: "user-3",
name: "王五",
email: "wangwu@example.com",
role: "viewer",
status: "active",
joined_at: "2024-03-05T14:20:00Z",
created_at: "2024-03-05T14:20:00Z",
updated_at: "2024-03-05T14:20:00Z",
},
{
id: "member-4",
user_id: "user-4",
name: "赵六",
email: "zhaoliu@example.com",
role: "member",
status: "pending",
joined_at: "2024-03-20T09:15:00Z",
created_at: "2024-03-20T09:15:00Z",
updated_at: "2024-03-20T09:15:00Z",
},
{
id: "member-5",
user_id: "user-5",
name: "孙七",
email: "sunqi@example.com",
role: "viewer",
status: "inactive",
joined_at: "2024-01-20T16:45:00Z",
created_at: "2024-01-20T16:45:00Z",
updated_at: "2024-01-20T16:45:00Z",
},
];
const roleConfig: Record<MemberRole, { label: string; icon: React.ReactNode; color: string }> = {
admin: {
label: "管理员",
@ -185,7 +119,7 @@ export default function ClientsPage() {
} = useApi<OrganizationMember[]>("/api/v1/organization/members");
const filteredMembers = useMemo(() => {
const memberList = members || MOCK_MEMBERS;
const memberList = members || [];
return memberList.filter((member) => {
const matchesSearch =
!searchQuery ||
@ -196,7 +130,7 @@ export default function ClientsPage() {
});
}, [members, searchQuery, roleFilter]);
const safeOrgInfo = orgInfo || MOCK_ORG_INFO;
const safeOrgInfo = orgInfo ?? null;
const loading = orgLoading || membersLoading;
const handleInvite = async () => {
@ -297,6 +231,7 @@ export default function ClientsPage() {
</CardTitle>
</CardHeader>
<CardContent>
{safeOrgInfo ? (
<div className="grid gap-4 md:grid-cols-3">
<div>
<p className="text-sm text-gray-500"></p>
@ -317,6 +252,9 @@ export default function ClientsPage() {
</p>
</div>
</div>
) : (
<p className="text-sm text-muted-foreground">...</p>
)}
</CardContent>
</Card>

View File

@ -5,7 +5,6 @@ import Link from "next/link";
import {
MetricCard,
StageProgress,
AgentStatusCard,
} from "@/components/business";
import { Button } from "@/components/ui/button";
import { Badge } from "@/components/ui/badge";
@ -34,30 +33,6 @@ const STAGE_CONFIG = [
{ id: "monitoring", label: "监测优化" },
];
const MOCK_AGENTS = [
{
name: "内容生成Agent",
description: "自动化内容生产",
status: "busy" as const,
lastActiveAt: "2分钟前",
completedCount: 156,
},
{
name: "引用监测Agent",
description: "AI平台引用追踪",
status: "online" as const,
lastActiveAt: "刚刚",
completedCount: 3420,
},
{
name: "SEO诊断Agent",
description: "搜索引擎优化分析",
status: "offline" as const,
lastActiveAt: "3小时前",
completedCount: 89,
},
];
function buildStages(currentStage: GeoProject["current_stage"]) {
const currentIndex = STAGE_CONFIG.findIndex((s) => s.id === currentStage);
return STAGE_CONFIG.map((stage, idx) => {
@ -331,17 +306,10 @@ export default function DashboardPage() {
</Link>
</div>
<div className="space-y-3">
{MOCK_AGENTS.map((agent) => (
<AgentStatusCard
key={agent.name}
name={agent.name}
description={agent.description}
status={agent.status}
lastActiveAt={agent.lastActiveAt}
completedCount={agent.completedCount}
/>
))}
<div className="flex flex-col items-center justify-center py-8 text-center">
<Zap className="h-8 w-8 text-muted-foreground mb-3" />
<p className="text-sm font-medium text-muted-foreground"></p>
<p className="text-xs text-muted-foreground mt-1">Agent状态监控即将上线</p>
</div>
</div>
</div>