421 lines
16 KiB
Python
421 lines
16 KiB
Python
"""Tests for ContentGenerationService - extracted from api/content.py
|
|
|
|
TDD RED phase: tests for the service that extracts the 3-stage
|
|
content generation flow (generate -> de-AI -> GEO optimize)
|
|
out of the API handler.
|
|
"""
|
|
import uuid
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.services.llm.base import LLMError, LLMResponse
|
|
|
|
|
|
def _make_llm_response(content: str) -> LLMResponse:
    """Build the ``LLMResponse`` a mocked provider should hand back.

    Args:
        content: Text to place in the response's ``content`` field.

    Returns:
        An ``LLMResponse`` carrying *content* and the fixed model name
        ``"test-model"``.
    """
    canned_reply = LLMResponse(content=content, model="test-model")
    return canned_reply
|
|
|
|
|
|
class TestContentGenerationService:
    """Unit tests for ``ContentGenerationService``.

    The tests never reach a real LLM: ``_get_provider`` is patched to return
    an ``AsyncMock`` whose ``chat`` replies are scripted per test. One scripted
    reply is consumed per pipeline stage that the test expects to run.
    """

    @staticmethod
    def _new_service():
        """Return a fresh ``ContentGenerationService`` instance.

        The service module is imported here rather than at the top of the
        file, so the import runs when a test executes and not when this test
        module is imported.
        """
        from app.services.content.content_generation_service import (
            ContentGenerationService,
        )

        return ContentGenerationService()

    @staticmethod
    def _provider_replying(*replies: str) -> AsyncMock:
        """Return a mock provider whose ``chat`` yields *replies* in order.

        Args:
            *replies: Response texts, one per expected ``chat`` call.

        Returns:
            An ``AsyncMock`` with ``chat.side_effect`` set to the matching
            list of ``LLMResponse`` objects.
        """
        provider = AsyncMock()
        provider.chat.side_effect = [_make_llm_response(text) for text in replies]
        return provider

    @pytest.mark.asyncio
    async def test_generate_content_basic_three_stages(self):
        """All three stages run in order: generate -> de-AI -> GEO optimize."""
        service = self._new_service()
        provider = self._provider_replying(
            "Raw generated content",
            "De-AIed content",
            "GEO optimized content",
        )

        with patch.object(service, "_get_provider", return_value=provider):
            result = await service.generate_content(
                keyword="AI搜索",
                brand_name="Brand X",
                platform="wenxin",
                content_style="专业严谨",
                word_count=2000,
                run_deai=True,
                run_geo=True,
            )

        assert result is not None
        assert result["content"] == "De-AIed content"
        assert result["optimized_content"] == "GEO optimized content"
        assert result["pipeline_stages"] is not None
        assert len(result["pipeline_stages"]) == 3
        assert result["pipeline_stages"][0]["stage"] == "content_generation"
        assert result["pipeline_stages"][1]["stage"] == "deai"
        assert result["pipeline_stages"][2]["stage"] == "geo_optimization"

    @pytest.mark.asyncio
    async def test_generate_content_skip_deai(self):
        """``run_deai=False`` skips the de-AI stage; raw text is the content."""
        service = self._new_service()
        provider = self._provider_replying(
            "Raw generated content",
            "GEO optimized content",
        )

        with patch.object(service, "_get_provider", return_value=provider):
            result = await service.generate_content(
                keyword="test",
                brand_name="Brand X",
                platform="wenxin",
                run_deai=False,
                run_geo=True,
            )

        assert result["content"] == "Raw generated content"
        assert result["optimized_content"] == "GEO optimized content"
        assert len(result["pipeline_stages"]) == 2
        stage_names = [s["stage"] for s in result["pipeline_stages"]]
        assert "deai" not in stage_names

    @pytest.mark.asyncio
    async def test_generate_content_skip_geo(self):
        """``run_geo=False`` skips GEO optimization; de-AIed text is reused."""
        service = self._new_service()
        provider = self._provider_replying(
            "Raw generated content",
            "De-AIed content",
        )

        with patch.object(service, "_get_provider", return_value=provider):
            result = await service.generate_content(
                keyword="test",
                brand_name="Brand X",
                platform="wenxin",
                run_deai=True,
                run_geo=False,
            )

        assert result["content"] == "De-AIed content"
        assert result["optimized_content"] == "De-AIed content"
        assert len(result["pipeline_stages"]) == 2
        stage_names = [s["stage"] for s in result["pipeline_stages"]]
        assert "geo_optimization" not in stage_names

    @pytest.mark.asyncio
    async def test_generate_content_skip_both_stages(self):
        """With both optional stages off, only the generation stage runs."""
        service = self._new_service()
        provider = self._provider_replying("Raw generated content")

        with patch.object(service, "_get_provider", return_value=provider):
            result = await service.generate_content(
                keyword="test",
                brand_name="Brand X",
                platform="wenxin",
                run_deai=False,
                run_geo=False,
            )

        assert result["content"] == "Raw generated content"
        assert result["optimized_content"] == "Raw generated content"
        assert len(result["pipeline_stages"]) == 1

    @pytest.mark.asyncio
    async def test_generate_content_with_knowledge_context(self):
        """``knowledge_context`` reaches the messages of the first chat call."""
        service = self._new_service()
        provider = self._provider_replying("Raw content", "De-AIed", "Optimized")

        with patch.object(service, "_get_provider", return_value=provider):
            # The return value is irrelevant here; only the outgoing prompt is
            # inspected.
            await service.generate_content(
                keyword="test keyword",
                brand_name="Brand X",
                platform="wenxin",
                knowledge_context="Some knowledge base content",
            )

        # The first chat call is the generation stage; its first positional
        # argument is the rendered message list.
        messages = provider.chat.call_args_list[0].args[0]
        # Stringify every message so the check does not depend on its shape.
        all_content = " ".join(str(m) for m in messages)
        assert "Some knowledge base content" in all_content

    @pytest.mark.asyncio
    async def test_generate_content_saves_to_database(self):
        """Content is persisted when ``db`` and ``user_id`` are supplied."""
        service = self._new_service()
        provider = self._provider_replying("Raw content", "De-AIed", "Optimized")

        mock_db = AsyncMock()
        # NOTE(review): assumes ``db`` mirrors SQLAlchemy's AsyncSession, where
        # ``add()`` is a plain synchronous method - confirm against the
        # service. Left as an AsyncMock child, each un-awaited ``add()`` call
        # would create a coroutine that is never awaited.
        mock_db.add = MagicMock()

        with patch.object(service, "_get_provider", return_value=provider):
            await service.generate_content(
                keyword="test",
                brand_name="Brand X",
                platform="wenxin",
                db=mock_db,
                user_id="test-user-123",
                org_id=str(uuid.uuid4()),
            )

        # Verify db.add was called (Content + ContentVersion)
        assert mock_db.add.call_count == 2
        mock_db.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_content_no_db_no_save(self):
        """Without ``db`` the result carries no ``content_id``."""
        service = self._new_service()
        provider = self._provider_replying("Raw", "De-AIed", "Optimized")

        with patch.object(service, "_get_provider", return_value=provider):
            result = await service.generate_content(
                keyword="test",
                brand_name="Brand X",
                platform="wenxin",
            )

        # No content_id when db is not provided
        assert result.get("content_id") is None

    @pytest.mark.asyncio
    async def test_generate_content_llm_error_propagates(self):
        """An ``LLMError`` raised by the provider is not swallowed."""
        service = self._new_service()

        provider = AsyncMock()
        provider.chat.side_effect = LLMError("API rate limit", provider="openai")

        with patch.object(service, "_get_provider", return_value=provider):
            with pytest.raises(LLMError, match="API rate limit"):
                await service.generate_content(
                    keyword="test",
                    brand_name="Brand X",
                    platform="wenxin",
                )

    @pytest.mark.asyncio
    async def test_get_knowledge_context_with_ids(self):
        """``_get_knowledge_context`` includes RAG hits when IDs are given."""
        service = self._new_service()
        mock_db = AsyncMock()

        mock_rag = MagicMock()
        mock_rag.search = AsyncMock(
            return_value=[
                {"content": "Knowledge chunk 1", "document_title": "Doc A"},
                {"content": "Knowledge chunk 2", "document_title": "Doc B"},
            ]
        )

        with patch(
            "app.services.knowledge.rag_service.RAGService",
            return_value=mock_rag,
        ):
            context = await service._get_knowledge_context(
                db=mock_db,
                brand_name="Brand X",
                knowledge_base_ids=["kb-1"],
                target_keyword="test keyword",
            )

        assert "Knowledge chunk 1" in context
        assert "Knowledge chunk 2" in context
        assert "Doc A" in context

    @pytest.mark.asyncio
    async def test_get_knowledge_context_empty_ids(self):
        """``_get_knowledge_context`` returns ``""`` for an empty ID list."""
        service = self._new_service()
        mock_db = AsyncMock()

        context = await service._get_knowledge_context(
            db=mock_db,
            brand_name="Brand X",
            knowledge_base_ids=[],
            target_keyword="test",
        )

        assert context == ""

    @pytest.mark.asyncio
    async def test_get_knowledge_context_rag_failure_returns_empty(self):
        """``_get_knowledge_context`` returns ``""`` when RAG search raises."""
        service = self._new_service()
        mock_db = AsyncMock()

        mock_rag = MagicMock()
        mock_rag.search = AsyncMock(side_effect=Exception("RAG service down"))

        with patch(
            "app.services.knowledge.rag_service.RAGService",
            return_value=mock_rag,
        ):
            context = await service._get_knowledge_context(
                db=mock_db,
                brand_name="Brand X",
                knowledge_base_ids=["kb-1"],
                target_keyword="test",
            )

        assert context == ""

    @pytest.mark.asyncio
    async def test_generate_content_passes_correct_prompt_variables(self):
        """Each stage's template is rendered with the expected variables."""
        service = self._new_service()
        provider = self._provider_replying(
            "Raw content",
            "De-AIed content",
            "Optimized content",
        )
        module = "app.services.content.content_generation_service"

        with patch.object(service, "_get_provider", return_value=provider):
            with patch(f"{module}.CONTENT_GENERATOR_TEMPLATE") as gen_template, patch(
                f"{module}.DEAI_TEMPLATE"
            ) as deai_template, patch(
                f"{module}.GEO_OPTIMIZER_TEMPLATE"
            ) as geo_template:
                gen_template.render.return_value = [
                    {"role": "user", "content": "gen prompt"}
                ]
                deai_template.render.return_value = [
                    {"role": "user", "content": "deai prompt"}
                ]
                geo_template.render.return_value = [
                    {"role": "user", "content": "geo prompt"}
                ]

                await service.generate_content(
                    keyword="AI搜索",
                    brand_name="Brand X",
                    platform="微信公众号",
                    content_style="专业严谨",
                    word_count=3000,
                    knowledge_context="Some context",
                )

        # Each render() call is read through its first positional argument,
        # which is the mapping of template variables.

        # Verify CONTENT_GENERATOR_TEMPLATE.render called with correct variables
        gen_variables = gen_template.render.call_args.args[0]
        assert gen_variables["topic_title"] == "AI搜索"
        assert gen_variables["target_keyword"] == "AI搜索"
        assert gen_variables["target_platform"] == "微信公众号"
        assert gen_variables["content_style"] == "专业严谨"
        # word_count is expected as a string, not the int that was passed in.
        assert gen_variables["word_count"] == "3000"
        assert gen_variables["brand_name"] == "Brand X"
        assert gen_variables["knowledge_context"] == "Some context"

        # Verify DEAI_TEMPLATE.render called with the generated content
        deai_variables = deai_template.render.call_args.args[0]
        assert deai_variables["original_content"] == "Raw content"

        # Verify GEO_OPTIMIZER_TEMPLATE.render called with de-AIed content
        geo_variables = geo_template.render.call_args.args[0]
        assert geo_variables["original_content"] == "De-AIed content"
        assert geo_variables["target_keywords"] == "AI搜索"
        assert geo_variables["target_platform"] == "微信公众号"

    @pytest.mark.asyncio
    async def test_generate_content_default_parameters(self):
        """Only the required arguments are needed to run all three stages."""
        service = self._new_service()
        provider = self._provider_replying("Raw", "De-AIed", "Optimized")

        with patch.object(service, "_get_provider", return_value=provider):
            result = await service.generate_content(
                keyword="test",
                brand_name="Brand X",
            )

        # Defaults: platform="通用", content_style="专业严谨", word_count=2000
        # run_deai=True, run_geo=True
        # Only the stage count is asserted here; the individual default values
        # listed above are not checked by this test.
        assert result is not None
        assert len(result["pipeline_stages"]) == 3
|