geo/backend/tests/test_api/test_content_api.py

461 lines
15 KiB
Python

"""内容API测试"""
import uuid
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.main import app
from app.models.user import User
from app.models.brand import Brand
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token
from tests.fixtures.auth import _to_uuid
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with all ORM tables created.

    ``StaticPool`` makes every checkout reuse a single connection, which is
    what lets an in-memory SQLite database be shared by the sessions built
    on this engine. The engine is disposed on teardown.
    """
    sqlite_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    async with sqlite_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
    yield sqlite_engine
    await sqlite_engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the test engine.

    ``expire_on_commit=False`` keeps attributes of committed instances
    readable, so fixtures can return ORM objects after committing them.
    Autoflush is disabled.
    """
    session_factory = async_sessionmaker(
        bind=async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with session_factory() as db_session:
        yield db_session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified user on the ``free`` plan."""
    db = async_session
    new_user = User(
        id=str(uuid.uuid4()),
        # The content API needs the user to belong to an organization.
        organization_id=uuid.uuid4(),
        email="test@example.com",
        password=hash_password("Test@123456"),
        firstName="Test User",
        plan="free",
        max_queries=5,
        isActive=True,
        emailVerified=True,
    )
    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)
    return new_user
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an ``active`` brand owned by ``test_user``.

    The brand is configured with ``frequency="weekly"`` and the
    ``wenxin``/``kimi`` platforms.
    """
    # The user fixture assigns a string id; normalise it for the FK column.
    owner_id = _to_uuid(test_user.id)
    db = async_session
    new_brand = Brand(
        id=uuid.uuid4(),
        user_id=owner_id,
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    db.add(new_brand)
    await db.commit()
    await db.refresh(new_brand)
    return new_brand
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an ``httpx.AsyncClient`` that calls the FastAPI app in-process.

    ``get_db`` is overridden to yield the test session, and
    ``get_current_user`` is overridden to return ``test_user``. Requests
    therefore run as that user without an auth header.

    On teardown only these two overrides are removed, and any value they
    replaced is put back. Other overrides on the app are left alone. The
    restore runs in ``finally``, so it also happens if the client fails to
    start.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    overrides = {
        get_db: override_get_db,
        get_current_user: override_get_current_user,
    }
    # Snapshot anything already installed for these dependencies so teardown
    # can restore it instead of wiping every override on the shared app.
    replaced = {
        dependency: app.dependency_overrides[dependency]
        for dependency in overrides
        if dependency in app.dependency_overrides
    }
    app.dependency_overrides.update(overrides)
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        for dependency in overrides:
            app.dependency_overrides.pop(dependency, None)
        app.dependency_overrides.update(replaced)
@pytest.fixture
def auth_headers(test_user):
    """Return a ``Bearer`` Authorization header carrying a token for ``test_user``."""
    subject = str(test_user.id)
    access_token = create_access_token(data={"sub": subject})
    return {"Authorization": f"Bearer {access_token}"}
# ==================== 测试类 ====================
class TestContentAPI:
    """Tests for content CRUD, publishing and list filtering on ``/api/v1/contents/``."""

    @staticmethod
    async def _create_content(client, payload):
        """Create a content item through the API and return the response JSON.

        The 201 status is asserted first. A broken setup step then fails with
        a clear status mismatch instead of a ``KeyError`` on ``["id"]``.
        """
        response = await client.post("/api/v1/contents/", json=payload)
        assert response.status_code == 201, response.text
        return response.json()

    @pytest.mark.asyncio
    async def test_list_contents_empty(self, async_client):
        """A fresh database lists no content."""
        response = await async_client.get("/api/v1/contents/")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        assert len(data) == 0

    @pytest.mark.asyncio
    async def test_create_content(self, async_client, test_user):
        """Creating content returns 201 with the submitted fields and ``draft`` status."""
        content_data = {
            "title": "测试文章标题",
            "body": "这是测试文章的内容",
            "content_type": "article",
            "tags": ["测试", "文章"],
        }
        response = await async_client.post("/api/v1/contents/", json=content_data)
        assert response.status_code == 201
        data = response.json()
        assert data["title"] == "测试文章标题"
        assert data["body"] == "这是测试文章的内容"
        assert data["content_type"] == "article"
        assert data["status"] == "draft"
        assert "id" in data

    @pytest.mark.asyncio
    async def test_get_content_by_id(self, async_client, test_user):
        """A created item can be fetched by its id."""
        created = await self._create_content(
            async_client,
            {
                "title": "测试内容",
                "body": "测试内容正文",
                "content_type": "article",
            },
        )
        content_id = created["id"]

        response = await async_client.get(f"/api/v1/contents/{content_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == content_id
        assert data["title"] == "测试内容"

    @pytest.mark.asyncio
    async def test_get_content_not_found(self, async_client):
        """Fetching an unknown id returns 404."""
        non_existent_id = uuid.uuid4()
        response = await async_client.get(f"/api/v1/contents/{non_existent_id}")
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_update_content(self, async_client, test_user):
        """A full update replaces the title and body."""
        created = await self._create_content(
            async_client,
            {
                "title": "原始标题",
                "body": "原始内容",
                "content_type": "article",
            },
        )
        content_id = created["id"]

        update_data = {
            "title": "更新后的标题",
            "body": "更新后的内容",
            "status": "published",
        }
        response = await async_client.put(
            f"/api/v1/contents/{content_id}", json=update_data
        )
        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "更新后的标题"
        assert data["body"] == "更新后的内容"

    @pytest.mark.asyncio
    async def test_update_content_partial(self, async_client, test_user):
        """Updating only the title leaves the body unchanged."""
        created = await self._create_content(
            async_client,
            {
                "title": "原始标题",
                "body": "原始内容",
            },
        )
        content_id = created["id"]

        response = await async_client.put(
            f"/api/v1/contents/{content_id}",
            json={"title": "新标题"},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "新标题"
        assert data["body"] == "原始内容"  # untouched by the partial update

    @pytest.mark.asyncio
    async def test_delete_content(self, async_client, test_user):
        """Deleting returns 204 and the item is no longer retrievable."""
        created = await self._create_content(
            async_client,
            {
                "title": "待删除内容",
                "body": "内容正文",
            },
        )
        content_id = created["id"]

        response = await async_client.delete(f"/api/v1/contents/{content_id}")
        assert response.status_code == 204

        get_response = await async_client.get(f"/api/v1/contents/{content_id}")
        assert get_response.status_code == 404

    @pytest.mark.asyncio
    async def test_publish_content(self, async_client, test_user):
        """Publishing sets ``status`` to ``published`` and fills ``published_at``."""
        created = await self._create_content(
            async_client,
            {
                "title": "待发布内容",
                "body": "内容正文",
            },
        )
        content_id = created["id"]

        response = await async_client.post(f"/api/v1/contents/{content_id}/publish")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "published"
        assert data["published_at"] is not None

    @pytest.mark.asyncio
    async def test_list_contents_with_filter(self, async_client, test_user):
        """Filtering by ``content_type`` returns exactly the matching items."""
        content_types = ["article", "post", "article"]
        for i, content_type in enumerate(content_types):
            await async_client.post(
                "/api/v1/contents/",
                json={
                    "title": f"内容 {i}",
                    "body": "正文",
                    "content_type": content_type,
                },
            )

        response = await async_client.get("/api/v1/contents/?content_type=article")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # Without a count check an empty response would pass the loop below.
        assert len(data) == content_types.count("article")
        for item in data:
            assert item["content_type"] == "article"

    @pytest.mark.asyncio
    async def test_list_contents_with_status_filter(self, async_client, test_user):
        """Filtering by ``status=draft`` includes the draft and excludes the published item."""
        draft = await self._create_content(
            async_client,
            {"title": "草稿", "body": "正文", "content_type": "article"},
        )
        published = await self._create_content(
            async_client,
            {"title": "已发布", "body": "正文", "content_type": "article"},
        )
        publish_response = await async_client.post(
            f"/api/v1/contents/{published['id']}/publish"
        )
        assert publish_response.status_code == 200

        response = await async_client.get("/api/v1/contents/?status=draft")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        for item in data:
            assert item["status"] == "draft"
        # Pin membership so an empty or unfiltered list cannot pass vacuously.
        listed_ids = {item["id"] for item in data}
        assert draft["id"] in listed_ids
        assert published["id"] not in listed_ids
class TestContentGenerationAPI:
    """Tests for the topic library endpoints under ``/api/v1/content/topics``."""

    @pytest.mark.asyncio
    async def test_list_topics(self, async_client):
        """The topic list is non-empty and every entry has id, name and description."""
        response = await async_client.get("/api/v1/content/topics")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # At least one topic is expected to be available.
        assert len(data) > 0
        for topic in data:
            assert "id" in topic
            assert "name" in topic
            assert "description" in topic

    @pytest.mark.asyncio
    async def test_get_topic_detail(self, async_client):
        """A listed topic can be fetched by id and includes its prompt template."""
        topics_response = await async_client.get("/api/v1/content/topics")
        assert topics_response.status_code == 200
        topics = topics_response.json()
        # Fail loudly rather than silently skipping every assertion when the
        # listing is empty (test_list_topics also requires at least one topic).
        assert topics, "expected at least one topic to fetch"

        topic_id = topics[0]["id"]
        response = await async_client.get(f"/api/v1/content/topics/{topic_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == topic_id
        assert "prompt_template" in data

    @pytest.mark.asyncio
    async def test_get_topic_not_found(self, async_client):
        """Fetching an unknown topic id returns 404."""
        response = await async_client.get("/api/v1/content/topics/nonexistent_topic")
        assert response.status_code == 404
class TestBrandKnowledgeAPI:
    """Tests for the brand knowledge endpoints under ``/api/v1/contents/knowledge/``."""

    @staticmethod
    async def _create_knowledge(client, payload):
        """Create a knowledge entry through the API and return the response JSON.

        The 201 status is asserted first, so setup failures show the status
        mismatch instead of a ``KeyError`` on ``["id"]``.
        """
        response = await client.post("/api/v1/contents/knowledge/", json=payload)
        assert response.status_code == 201, response.text
        return response.json()

    @pytest.mark.asyncio
    async def test_list_brand_knowledge_empty(self, async_client):
        """A fresh database has no knowledge entries."""
        response = await async_client.get("/api/v1/contents/knowledge/")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        assert len(data) == 0

    @pytest.mark.asyncio
    async def test_create_brand_knowledge(self, async_client, test_user):
        """Creating an entry returns 201 with all submitted fields."""
        knowledge_data = {
            "category": "product",
            "title": "产品介绍",
            "body": "这是产品的详细介绍",
            "source": "官网",
        }
        response = await async_client.post(
            "/api/v1/contents/knowledge/", json=knowledge_data
        )
        assert response.status_code == 201
        data = response.json()
        assert data["title"] == "产品介绍"
        assert data["category"] == "product"
        assert data["body"] == "这是产品的详细介绍"
        assert data["source"] == "官网"
        assert "id" in data

    @pytest.mark.asyncio
    async def test_list_brand_knowledge_with_category(self, async_client, test_user):
        """Filtering by ``category`` returns exactly the matching entries."""
        categories = ["product", "technology", "product"]
        for i, cat in enumerate(categories):
            await self._create_knowledge(
                async_client,
                {"category": cat, "title": f"知识 {i}", "body": "正文"},
            )

        response = await async_client.get("/api/v1/contents/knowledge/?category=product")
        assert response.status_code == 200
        data = response.json()
        assert isinstance(data, list)
        # Without a count check an empty response would pass the loop below.
        assert len(data) == categories.count("product")
        for item in data:
            assert item["category"] == "product"

    @pytest.mark.asyncio
    async def test_update_brand_knowledge(self, async_client, test_user):
        """Updating an entry replaces its title and body."""
        created = await self._create_knowledge(
            async_client,
            {"category": "test", "title": "原始标题", "body": "原始内容"},
        )
        knowledge_id = created["id"]

        response = await async_client.put(
            f"/api/v1/contents/knowledge/{knowledge_id}",
            json={"title": "新标题", "body": "新内容"},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "新标题"
        assert data["body"] == "新内容"

    @pytest.mark.asyncio
    async def test_delete_brand_knowledge(self, async_client, test_user):
        """Deleting returns 204 and the entry disappears from the listing."""
        created = await self._create_knowledge(
            async_client,
            {"category": "test", "title": "待删除", "body": "正文"},
        )
        knowledge_id = created["id"]

        response = await async_client.delete(f"/api/v1/contents/knowledge/{knowledge_id}")
        assert response.status_code == 204

        # No single-entry GET is exercised here, so confirm via the listing.
        list_response = await async_client.get("/api/v1/contents/knowledge/")
        assert list_response.status_code == 200
        knowledge_ids = [item["id"] for item in list_response.json()]
        assert knowledge_id not in knowledge_ids