460 lines
15 KiB
Python
460 lines
15 KiB
Python
"""内容API测试"""
|
|
import uuid
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from httpx import AsyncClient, ASGITransport
|
|
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.database import Base
|
|
from app.main import app
|
|
from app.models.user import User
|
|
from app.models.brand import Brand
|
|
from app.api.deps import get_current_user, get_db
|
|
from app.services.auth import hash_password, create_access_token
|
|
|
|
|
|
# ==================== Fixtures ====================
|
|
|
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async SQLite in-memory engine with all ORM tables created.

    StaticPool hands out the same single connection for every checkout, so the
    ``:memory:`` database is shared by everything that uses this engine.
    ``check_same_thread=False`` lets that connection be used from a thread other
    than the one that created it.

    The engine is disposed in ``finally``, so it is released even when
    ``create_all`` fails during setup.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        await engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield a single AsyncSession bound to the test engine.

    ``expire_on_commit=False`` keeps attributes of committed objects loaded, so
    fixture-created objects stay readable after their commit.
    """
    session_options = {
        "class_": AsyncSession,
        "expire_on_commit": False,
        "autoflush": False,
        "autocommit": False,
    }
    session_factory = async_sessionmaker(async_engine, **session_options)

    async with session_factory() as session:
        yield session
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified user on the "free" plan.

    The user is given a random ``organization_id`` because the content API needs
    one. No matching organization row is created here.
    """
    user_id, org_id = uuid.uuid4(), uuid.uuid4()
    user = User(
        id=user_id,
        email="test@example.com",
        password_hash=hash_password("Test@123456"),
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
        organization_id=org_id,
    )

    async_session.add(user)
    await async_session.commit()
    # Reload server-side state (e.g. defaults) onto the instance.
    await async_session.refresh(user)
    return user
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an "active" brand owned by ``test_user``.

    The brand has two aliases, platforms ``["wenxin", "kimi"]`` and frequency
    ``"weekly"``.
    """
    brand_fields = {
        "id": uuid.uuid4(),
        "user_id": test_user.id,
        "name": "Test Brand",
        "aliases": ["TestBrand", "TB"],
        "website": "https://testbrand.com",
        "industry": "technology",
        "platforms": ["wenxin", "kimi"],
        "frequency": "weekly",
        "status": "active",
    }
    brand = Brand(**brand_fields)

    async_session.add(brand)
    await async_session.commit()
    await async_session.refresh(brand)
    return brand
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an httpx AsyncClient that talks to the FastAPI app in-process.

    The ``get_db`` and ``get_current_user`` dependencies are overridden:
    - every request uses the test session;
    - every request is authenticated as ``test_user``.

    On teardown, only the overrides installed here are removed, and any override
    that existed beforehand for those dependencies is restored. Overrides that
    other fixtures installed are left untouched. Cleanup runs in ``finally``, so
    the overrides cannot leak into later tests if client setup fails.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    overrides = {
        get_db: override_get_db,
        get_current_user: override_get_current_user,
    }
    # Remember pre-existing overrides for these dependencies so they can be restored.
    previous = {
        dependency: app.dependency_overrides[dependency]
        for dependency in overrides
        if dependency in app.dependency_overrides
    }
    app.dependency_overrides.update(overrides)

    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        for dependency in overrides:
            app.dependency_overrides.pop(dependency, None)
        app.dependency_overrides.update(previous)
|
|
|
|
|
|
@pytest.fixture
def auth_headers(test_user):
    """Return request headers carrying a bearer token whose ``sub`` is ``test_user``'s id."""
    claims = {"sub": str(test_user.id)}
    access_token = create_access_token(data=claims)
    return dict(Authorization=f"Bearer {access_token}")
|
|
|
|
|
|
# ==================== 测试类 ====================
|
|
|
|
class TestContentAPI:
    """Tests for the content management endpoints under /api/v1/contents/."""

    @staticmethod
    async def _create_content(client, **fields):
        """POST a content item built from ``fields`` and return the response JSON.

        Asserts the create call returned 201, so a failed setup step surfaces as
        a clear assertion (with the response body) instead of a later KeyError
        on ``"id"``.
        """
        response = await client.post("/api/v1/contents/", json=fields)
        assert response.status_code == 201, response.text
        return response.json()

    @pytest.mark.asyncio
    async def test_list_contents_empty(self, async_client):
        """Listing contents on a fresh database returns an empty list."""
        response = await async_client.get("/api/v1/contents/")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        assert len(data) == 0

    @pytest.mark.asyncio
    async def test_create_content(self, async_client, test_user):
        """Creating content returns 201 and echoes the fields, with status "draft"."""
        content_data = {
            "title": "测试文章标题",
            "body": "这是测试文章的内容",
            "content_type": "article",
            "tags": ["测试", "文章"],
        }
        response = await async_client.post("/api/v1/contents/", json=content_data)

        assert response.status_code == 201
        data = response.json()
        assert data["title"] == "测试文章标题"
        assert data["body"] == "这是测试文章的内容"
        assert data["content_type"] == "article"
        assert data["status"] == "draft"
        assert "id" in data

    @pytest.mark.asyncio
    async def test_get_content_by_id(self, async_client, test_user):
        """A created content item can be fetched by its id."""
        created = await self._create_content(
            async_client,
            title="测试内容",
            body="测试内容正文",
            content_type="article",
        )
        content_id = created["id"]

        response = await async_client.get(f"/api/v1/contents/{content_id}")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == content_id
        assert data["title"] == "测试内容"

    @pytest.mark.asyncio
    async def test_get_content_not_found(self, async_client):
        """Fetching an unknown content id returns 404."""
        non_existent_id = uuid.uuid4()
        response = await async_client.get(f"/api/v1/contents/{non_existent_id}")

        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_update_content(self, async_client, test_user):
        """PUT replaces the title and body of an existing content item."""
        created = await self._create_content(
            async_client,
            title="原始标题",
            body="原始内容",
            content_type="article",
        )
        content_id = created["id"]

        update_data = {
            "title": "更新后的标题",
            "body": "更新后的内容",
            "status": "published",
        }
        response = await async_client.put(
            f"/api/v1/contents/{content_id}", json=update_data
        )

        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "更新后的标题"
        assert data["body"] == "更新后的内容"

    @pytest.mark.asyncio
    async def test_update_content_partial(self, async_client, test_user):
        """PUT with only a title changes the title and leaves the body unchanged."""
        created = await self._create_content(
            async_client, title="原始标题", body="原始内容"
        )
        content_id = created["id"]

        response = await async_client.put(
            f"/api/v1/contents/{content_id}",
            json={"title": "新标题"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "新标题"
        assert data["body"] == "原始内容"  # unchanged

    @pytest.mark.asyncio
    async def test_delete_content(self, async_client, test_user):
        """DELETE returns 204, and the item is no longer retrievable."""
        created = await self._create_content(
            async_client, title="待删除内容", body="内容正文"
        )
        content_id = created["id"]

        response = await async_client.delete(f"/api/v1/contents/{content_id}")
        assert response.status_code == 204

        get_response = await async_client.get(f"/api/v1/contents/{content_id}")
        assert get_response.status_code == 404

    @pytest.mark.asyncio
    async def test_publish_content(self, async_client, test_user):
        """Publishing sets status to "published" and fills in published_at."""
        created = await self._create_content(
            async_client, title="待发布内容", body="内容正文"
        )
        content_id = created["id"]

        response = await async_client.post(f"/api/v1/contents/{content_id}/publish")

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "published"
        assert data["published_at"] is not None

    @pytest.mark.asyncio
    async def test_list_contents_with_filter(self, async_client, test_user):
        """Filtering by content_type returns the matching items and only those."""
        article_ids = set()
        for i, content_type in enumerate(["article", "post", "article"]):
            created = await self._create_content(
                async_client,
                title=f"内容 {i}",
                body="正文",
                content_type=content_type,
            )
            if content_type == "article":
                article_ids.add(created["id"])

        response = await async_client.get("/api/v1/contents/?content_type=article")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # Without this, an empty result would satisfy the per-item check below.
        assert article_ids <= {item["id"] for item in data}
        for item in data:
            assert item["content_type"] == "article"

    @pytest.mark.asyncio
    async def test_list_contents_with_status_filter(self, async_client, test_user):
        """Filtering by status=draft includes the draft and excludes the published item."""
        draft = await self._create_content(
            async_client, title="草稿", body="正文", content_type="article"
        )
        published = await self._create_content(
            async_client, title="已发布", body="正文", content_type="article"
        )
        publish_response = await async_client.post(
            f"/api/v1/contents/{published['id']}/publish"
        )
        assert publish_response.status_code == 200

        response = await async_client.get("/api/v1/contents/?status=draft")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        returned_ids = {item["id"] for item in data}
        assert draft["id"] in returned_ids
        assert published["id"] not in returned_ids
        for item in data:
            assert item["status"] == "draft"
|
|
|
|
|
|
class TestContentGenerationAPI:
    """Tests for the content-generation topic library endpoints under /api/v1/content/topics."""

    @pytest.mark.asyncio
    async def test_list_topics(self, async_client):
        """The topic list is non-empty, and each topic exposes id, name and description."""
        response = await async_client.get("/api/v1/content/topics")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # The topic library is expected to ship with entries.
        assert len(data) > 0

        for topic in data:
            assert "id" in topic
            assert "name" in topic
            assert "description" in topic

    @pytest.mark.asyncio
    async def test_get_topic_detail(self, async_client):
        """Fetching a topic by id returns that topic, including its prompt_template."""
        topics_response = await async_client.get("/api/v1/content/topics")
        assert topics_response.status_code == 200
        topics = topics_response.json()
        # An empty list would leave nothing to fetch. Fail rather than pass without
        # exercising the detail endpoint (test_list_topics also requires entries).
        assert topics, "topic library is empty; cannot exercise the detail endpoint"

        topic_id = topics[0]["id"]
        response = await async_client.get(f"/api/v1/content/topics/{topic_id}")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == topic_id
        assert "prompt_template" in data

    @pytest.mark.asyncio
    async def test_get_topic_not_found(self, async_client):
        """Fetching an unknown topic id returns 404."""
        response = await async_client.get("/api/v1/content/topics/nonexistent_topic")

        assert response.status_code == 404
|
|
|
|
|
|
class TestBrandKnowledgeAPI:
    """Tests for the brand knowledge-base endpoints under /api/v1/contents/knowledge/."""

    @staticmethod
    async def _create_knowledge(client, **fields):
        """POST a knowledge entry built from ``fields`` and return the response JSON.

        Asserts the create call returned 201, so a failed setup step is reported
        directly instead of as a later KeyError on ``"id"``.
        """
        response = await client.post("/api/v1/contents/knowledge/", json=fields)
        assert response.status_code == 201, response.text
        return response.json()

    @pytest.mark.asyncio
    async def test_list_brand_knowledge_empty(self, async_client):
        """Listing knowledge on a fresh database returns an empty list."""
        response = await async_client.get("/api/v1/contents/knowledge/")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        assert len(data) == 0

    @pytest.mark.asyncio
    async def test_create_brand_knowledge(self, async_client, test_user):
        """Creating a knowledge entry returns 201 and echoes the submitted fields."""
        knowledge_data = {
            "category": "product",
            "title": "产品介绍",
            "body": "这是产品的详细介绍",
            "source": "官网",
        }
        response = await async_client.post(
            "/api/v1/contents/knowledge/", json=knowledge_data
        )

        assert response.status_code == 201
        data = response.json()
        assert data["title"] == "产品介绍"
        assert data["category"] == "product"
        assert data["body"] == "这是产品的详细介绍"
        assert data["source"] == "官网"
        assert "id" in data

    @pytest.mark.asyncio
    async def test_list_brand_knowledge_with_category(self, async_client, test_user):
        """Filtering by category returns the matching entries and only those."""
        product_ids = set()
        for i, category in enumerate(["product", "technology", "product"]):
            created = await self._create_knowledge(
                async_client, category=category, title=f"知识 {i}", body="正文"
            )
            if category == "product":
                product_ids.add(created["id"])

        response = await async_client.get("/api/v1/contents/knowledge/?category=product")

        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # Without this, an empty result would satisfy the per-item check below.
        assert product_ids <= {item["id"] for item in data}
        for item in data:
            assert item["category"] == "product"

    @pytest.mark.asyncio
    async def test_update_brand_knowledge(self, async_client, test_user):
        """PUT updates the title and body of an existing knowledge entry."""
        created = await self._create_knowledge(
            async_client, category="test", title="原始标题", body="原始内容"
        )
        knowledge_id = created["id"]

        response = await async_client.put(
            f"/api/v1/contents/knowledge/{knowledge_id}",
            json={"title": "新标题", "body": "新内容"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "新标题"
        assert data["body"] == "新内容"

    @pytest.mark.asyncio
    async def test_delete_brand_knowledge(self, async_client, test_user):
        """DELETE returns 204, and the entry no longer appears in the list."""
        created = await self._create_knowledge(
            async_client, category="test", title="待删除", body="正文"
        )
        knowledge_id = created["id"]

        response = await async_client.delete(f"/api/v1/contents/knowledge/{knowledge_id}")

        assert response.status_code == 204

        # Confirm the deletion through the list endpoint.
        list_response = await async_client.get("/api/v1/contents/knowledge/")
        assert list_response.status_code == 200
        knowledge_ids = [item["id"] for item in list_response.json()]
        assert knowledge_id not in knowledge_ids
|