309 lines
9.8 KiB
Python
309 lines
9.8 KiB
Python
import uuid
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from httpx import AsyncClient, ASGITransport
|
|
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.database import Base
|
|
from app.main import app
|
|
from app.models.brand import Brand
|
|
from app.models.content import Content
|
|
from app.models.diagnosis_record import DiagnosisRecord
|
|
from app.models.user import User
|
|
from app.api.deps import get_current_user, get_db
|
|
from app.services.auth import hash_password
|
|
|
|
|
|
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
|
if isinstance(value, uuid.UUID):
|
|
return value
|
|
return uuid.UUID(str(value))
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async SQLAlchemy engine with all ``Base`` tables created.

    The engine points at ``sqlite+aiosqlite:///:memory:`` and uses
    ``StaticPool``; per SQLAlchemy's documentation that pool keeps a single
    connection, which is what lets every session see the same in-memory
    database. The engine is disposed once the test is finished with it.
    """
    db_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )

    # Create the schema described by Base.metadata before handing the engine out.
    async with db_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield db_engine

    await db_engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield one ``AsyncSession`` bound to the ``async_engine`` fixture.

    The session is created with ``expire_on_commit=False``, ``autoflush=False``
    and ``autocommit=False`` and is closed by its context manager on teardown.
    """
    session_factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with session_factory() as db_session:
        yield db_session
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user(async_session):
    """Insert a ``User`` row through ``async_session`` and return it.

    The user gets a random string id, a random ``organization_id`` and a
    password produced by ``hash_password``; the row is committed and refreshed
    before being returned.
    """
    user_fields = {
        "id": str(uuid.uuid4()),
        "email": "test_dist@example.com",
        "password": hash_password("Test@123456"),
        "firstName": "Test",
        "lastName": "User",
        "isActive": True,
        "emailVerified": True,
        "organization_id": uuid.uuid4(),
    }
    account = User(**user_fields)

    async_session.add(account)
    await async_session.commit()
    await async_session.refresh(account)
    return account
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Insert a ``Brand`` row owned by ``test_user`` and return it.

    ``test_user.id`` is passed through ``_to_uuid`` before being stored as the
    brand's ``user_id``. The row is committed and refreshed before returning.
    """
    brand_fields = {
        "id": uuid.uuid4(),
        "user_id": _to_uuid(test_user.id),
        "name": "Test Brand",
        "aliases": ["TestBrand"],
        "website": "https://testbrand.com",
        "industry": "technology",
        "platforms": ["wenxin", "kimi"],
        "frequency": "weekly",
        "status": "active",
    }
    record = Brand(**brand_fields)

    async_session.add(record)
    await async_session.commit()
    await async_session.refresh(record)
    return record
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_content(async_session, test_user):
    """Insert a draft ``Content`` row and return it.

    The row reuses ``test_user.organization_id`` and records ``test_user.id``
    as ``created_by``. It is committed and refreshed before being returned.
    """
    content_fields = {
        "id": uuid.uuid4(),
        "organization_id": test_user.organization_id,
        "title": "测试文章",
        "content_type": "article",
        "body": "这是一篇测试文章的内容,用于发布测试。",
        "status": "draft",
        "target_platforms": ["zhihu", "toutiao"],
        "keywords": ["测试"],
        "created_by": test_user.id,
        "current_version": 1,
    }
    article = Content(**content_fields)

    async_session.add(article)
    await async_session.commit()
    await async_session.refresh(article)
    return article
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_diagnosis(async_session, test_brand):
    """Insert a completed ``DiagnosisRecord`` for ``test_brand`` and return it.

    The record carries an overall score of 55.0 and a ``result_json`` payload
    with three scored dimensions (visibility, authority, relevance). It is
    committed and refreshed before being returned.
    """
    dimension_scores = {
        "visibility": {"score": 40, "details": "low"},
        "authority": {"score": 70, "details": "ok"},
        "relevance": {"score": 50, "details": "low"},
    }
    record = DiagnosisRecord(
        id=uuid.uuid4(),
        brand_id=test_brand.id,
        user_id=_to_uuid(test_brand.user_id),
        diagnosis_type="geo",
        status="completed",
        overall_score=55.0,
        result_json={"dimensions": dimension_scores},
    )

    async_session.add(record)
    await async_session.commit()
    await async_session.refresh(record)
    return record
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an ``httpx.AsyncClient`` that talks to ``app`` in-process.

    While the fixture is active, two dependencies of ``app`` are overridden:

    * ``get_db`` yields the ``async_session`` fixture, so requests share the
      database session the other fixtures wrote to.
    * ``get_current_user`` returns ``test_user``, so requests are made as that
      user without going through real authentication.

    The overrides are cleared in a ``finally`` block so they are removed from
    the shared ``app`` object even if building or closing the client raises.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user

    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # ``app`` is module-global: always undo the overrides so a failure here
        # cannot leak them into tests that run afterwards.
        app.dependency_overrides.clear()
|
|
|
|
|
|
class TestGEOContentGeneration:
    """Tests for ``POST /api/v1/content/generate-geo``."""

    @pytest.mark.asyncio
    async def test_generate_geo_returns_201(self, async_client, test_brand, test_diagnosis):
        """A request for an existing brand returns 201 and echoes the generated fields.

        ``ContentGenerationService.generate_content`` is replaced by an
        ``AsyncMock`` whose return value supplies the content, the optimized
        content, the SEO score and a content id.
        """
        fake_generation = {
            "content": "生成的内容",
            "optimized_content": "优化后的内容",
            "seo_score": 85,
            "content_id": str(uuid.uuid4()),
            "pipeline_stages": [{"stage": "content_generation", "status": "success"}],
        }
        request_body = {
            "brand_id": str(test_brand.id),
            "target_keywords": ["AI优化", "品牌曝光"],
            "platform": "通用",
            "content_style": "专业严谨",
            "word_count": 2000,
        }

        with patch(
            "app.services.content.content_generation_service.ContentGenerationService.generate_content",
            new_callable=AsyncMock,
        ) as mock_gen:
            mock_gen.return_value = fake_generation
            response = await async_client.post(
                "/api/v1/content/generate-geo",
                json=request_body,
            )

        assert response.status_code == 201
        payload = response.json()
        assert payload["content_id"] is not None
        assert payload["content"] == "生成的内容"
        assert payload["optimized_content"] == "优化后的内容"
        assert payload["seo_score"] == 85

    @pytest.mark.asyncio
    async def test_generate_geo_with_invalid_brand_returns_404(self, async_client):
        """A randomly generated brand id that was never inserted yields 404."""
        request_body = {
            "brand_id": str(uuid.uuid4()),
            "target_keywords": ["测试"],
        }
        response = await async_client.post(
            "/api/v1/content/generate-geo",
            json=request_body,
        )

        assert response.status_code == 404
|
|
|
|
|
|
class TestPublishAPI:
    """Tests for the ``/api/v1/distribution/publish`` endpoints."""

    @pytest.mark.asyncio
    async def test_publish_to_mock_platforms(self, async_client, test_content):
        """Publishing to two platforms returns 200 with one successful result each."""
        response = await async_client.post(
            "/api/v1/distribution/publish",
            json={
                "content_id": str(test_content.id),
                "platforms": ["zhihu", "toutiao"],
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert "results" in data
        assert len(data["results"]) == 2
        for r in data["results"]:
            assert r["success"] is True
            assert r["article_id"] is not None

    @pytest.mark.asyncio
    async def test_publish_with_invalid_content_id_returns_404(self, async_client):
        """A randomly generated content id that was never inserted yields 404."""
        fake_content_id = str(uuid.uuid4())
        response = await async_client.post(
            "/api/v1/distribution/publish",
            json={
                "content_id": fake_content_id,
                "platforms": ["zhihu"],
            },
        )

        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_get_publish_status(self, async_client, test_content):
        """After publishing to zhihu, the status endpoint lists zhihu first."""
        publish_response = await async_client.post(
            "/api/v1/distribution/publish",
            json={
                "content_id": str(test_content.id),
                "platforms": ["zhihu"],
            },
        )
        # Verify the setup step: without this, a failed publish would only show
        # up below as a confusing failure of the status endpoint.
        assert publish_response.status_code == 200

        response = await async_client.get(
            f"/api/v1/distribution/publish/{test_content.id}/status",
        )

        assert response.status_code == 200
        data = response.json()
        assert "platforms" in data
        assert len(data["platforms"]) >= 1
        assert data["platforms"][0]["platform"] == "zhihu"
|
|
|
|
|
|
class TestPublishers:
    """Direct tests of the publisher classes and the ``get_publisher`` factory."""

    @pytest.mark.asyncio
    async def test_zhihu_publisher_mock_mode(self):
        """A default ``ZhihuPublisher`` reports unconfigured yet publishes successfully."""
        from app.services.distribution.publishers.zhihu_publisher import ZhihuPublisher

        publisher = ZhihuPublisher()
        assert publisher.is_configured() is False

        outcome = await publisher.publish(title="测试标题", content="测试内容")
        assert outcome.success is True
        assert outcome.platform == "zhihu"
        assert outcome.article_id is not None

    @pytest.mark.asyncio
    async def test_wechat_publisher_returns_formatted_content_in_mock_mode(self):
        """A default ``WeChatPublisher`` returns formatted content and instructions."""
        from app.services.distribution.publishers.wechat_publisher import WeChatPublisher

        publisher = WeChatPublisher()
        assert publisher.is_configured() is False

        outcome = await publisher.publish(title="微信测试", content="## 标题\n微信内容")
        assert outcome.success is True
        assert outcome.platform == "wechat"
        assert "formatted_content" in outcome.raw_response
        assert "instructions" in outcome.raw_response

    @pytest.mark.asyncio
    async def test_publisher_factory_returns_mock_by_default(self):
        """``get_publisher("zhihu")`` returns a ``MockPublisher`` tagged ``zhihu``."""
        from app.services.distribution.publishers import get_publisher
        from app.services.distribution.publishers.mock_publisher import MockPublisher

        publisher = get_publisher("zhihu")
        assert isinstance(publisher, MockPublisher)
        assert publisher.platform == "zhihu"

    @pytest.mark.asyncio
    async def test_publisher_factory_returns_mock_for_unknown_platform(self):
        """``get_publisher`` returns a ``MockPublisher`` for an unrecognized name."""
        from app.services.distribution.publishers import get_publisher
        from app.services.distribution.publishers.mock_publisher import MockPublisher

        publisher = get_publisher("unknown_platform")
        assert isinstance(publisher, MockPublisher)

    @pytest.mark.asyncio
    async def test_mock_publisher_verify_credentials(self):
        """``MockPublisher.verify_credentials`` resolves to ``True``."""
        from app.services.distribution.publishers.mock_publisher import MockPublisher

        publisher = MockPublisher(platform="test")
        assert await publisher.verify_credentials() is True

    @pytest.mark.asyncio
    async def test_mock_publisher_get_article_status(self):
        """``MockPublisher.get_article_status`` reports a published, mock article."""
        from app.services.distribution.publishers.mock_publisher import MockPublisher

        publisher = MockPublisher(platform="test")
        article_state = await publisher.get_article_status("article_123")
        assert article_state["status"] == "published"
        assert article_state["mock"] is True
|