# geo/backend/tests/test_services/test_content_generation.py
#
# 421 lines
# 16 KiB
# Python

"""Tests for ContentGenerationService, whose logic was extracted from api/content.py.

TDD RED phase: these tests specify the service that moves the 3-stage
content generation flow (generate -> de-AI -> GEO optimize) out of the
API handler.
"""
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.llm.base import LLMError, LLMResponse
def _make_llm_response(content: str) -> LLMResponse:
    """Wrap *content* in an ``LLMResponse`` tagged with the dummy model name.

    Used to build the canned replies returned by mocked ``provider.chat`` calls.
    """
    model_name = "test-model"
    return LLMResponse(model=model_name, content=content)
class TestContentGenerationService:
    """ContentGenerationService unit tests.

    Each test replaces the service's ``_get_provider`` with an ``AsyncMock``
    whose ``chat`` coroutine returns canned responses in call order, so the
    pipeline runs without contacting a real LLM.
    """

    @pytest.fixture
    def service(self):
        """Return a fresh ``ContentGenerationService``.

        The import is deferred to fixture time so this module can still be
        collected while the service module does not exist yet (TDD RED phase).
        """
        from app.services.content.content_generation_service import (
            ContentGenerationService,
        )

        return ContentGenerationService()

    @staticmethod
    def _mock_provider(*contents: str) -> AsyncMock:
        """Return a provider mock whose successive ``chat`` calls yield *contents*."""
        provider = AsyncMock()
        provider.chat.side_effect = [_make_llm_response(c) for c in contents]
        return provider

    @pytest.mark.asyncio
    async def test_generate_content_basic_three_stages(self, service):
        """Test basic 3-stage content generation (generate -> de-AI -> GEO optimize)."""
        mock_provider = self._mock_provider(
            "Raw generated content",
            "De-AIed content",
            "GEO optimized content",
        )
        with patch.object(service, "_get_provider", return_value=mock_provider):
            result = await service.generate_content(
                keyword="AI搜索",
                brand_name="Brand X",
                platform="wenxin",
                content_style="专业严谨",
                word_count=2000,
                run_deai=True,
                run_geo=True,
            )
        assert result is not None
        # "content" holds the de-AIed text; the GEO output is kept separately.
        assert result["content"] == "De-AIed content"
        assert result["optimized_content"] == "GEO optimized content"
        assert mock_provider.chat.await_count == 3
        stages = result["pipeline_stages"]
        assert stages is not None
        assert [s["stage"] for s in stages] == [
            "content_generation",
            "deai",
            "geo_optimization",
        ]

    @pytest.mark.asyncio
    async def test_generate_content_skip_deai(self, service):
        """Test generation with run_deai=False skips the de-AI stage."""
        mock_provider = self._mock_provider(
            "Raw generated content",
            "GEO optimized content",
        )
        with patch.object(service, "_get_provider", return_value=mock_provider):
            result = await service.generate_content(
                keyword="test",
                brand_name="Brand X",
                platform="wenxin",
                run_deai=False,
                run_geo=True,
            )
        assert result["content"] == "Raw generated content"
        assert result["optimized_content"] == "GEO optimized content"
        # A skipped stage must not cost an LLM call.
        assert mock_provider.chat.await_count == 2
        stage_names = [s["stage"] for s in result["pipeline_stages"]]
        assert len(stage_names) == 2
        assert "deai" not in stage_names

    @pytest.mark.asyncio
    async def test_generate_content_skip_geo(self, service):
        """Test generation with run_geo=False skips the GEO optimization stage."""
        mock_provider = self._mock_provider(
            "Raw generated content",
            "De-AIed content",
        )
        with patch.object(service, "_get_provider", return_value=mock_provider):
            result = await service.generate_content(
                keyword="test",
                brand_name="Brand X",
                platform="wenxin",
                run_deai=True,
                run_geo=False,
            )
        assert result["content"] == "De-AIed content"
        # Without GEO, optimized_content falls back to the de-AIed text.
        assert result["optimized_content"] == "De-AIed content"
        assert mock_provider.chat.await_count == 2
        stage_names = [s["stage"] for s in result["pipeline_stages"]]
        assert len(stage_names) == 2
        assert "geo_optimization" not in stage_names

    @pytest.mark.asyncio
    async def test_generate_content_skip_both_stages(self, service):
        """Test generation with both run_deai=False and run_geo=False."""
        mock_provider = self._mock_provider("Raw generated content")
        with patch.object(service, "_get_provider", return_value=mock_provider):
            result = await service.generate_content(
                keyword="test",
                brand_name="Brand X",
                platform="wenxin",
                run_deai=False,
                run_geo=False,
            )
        assert result["content"] == "Raw generated content"
        assert result["optimized_content"] == "Raw generated content"
        assert mock_provider.chat.await_count == 1
        assert len(result["pipeline_stages"]) == 1

    @pytest.mark.asyncio
    async def test_generate_content_with_knowledge_context(self, service):
        """Test that knowledge_context is passed through to the generation prompt."""
        mock_provider = self._mock_provider("Raw content", "De-AIed", "Optimized")
        with patch.object(service, "_get_provider", return_value=mock_provider):
            await service.generate_content(
                keyword="test keyword",
                brand_name="Brand X",
                platform="wenxin",
                knowledge_context="Some knowledge base content",
            )
        # The first chat call (generation stage) is expected to receive the
        # rendered messages as its first positional argument.
        messages = mock_provider.chat.call_args_list[0].args[0]
        all_content = " ".join(str(m) for m in messages)
        assert "Some knowledge base content" in all_content

    @pytest.mark.asyncio
    async def test_generate_content_saves_to_database(self, service):
        """Test that generated content is saved to database when db and user_id are provided."""
        mock_provider = self._mock_provider("Raw content", "De-AIed", "Optimized")
        mock_db = AsyncMock()
        # The db is presumably an SQLAlchemy AsyncSession, whose add() is a
        # plain synchronous method. Left as an AsyncMock child, each add() call
        # would return a never-awaited coroutine (RuntimeWarning).
        mock_db.add = MagicMock()
        with patch.object(service, "_get_provider", return_value=mock_provider):
            await service.generate_content(
                keyword="test",
                brand_name="Brand X",
                platform="wenxin",
                db=mock_db,
                user_id="test-user-123",
                org_id=str(uuid.uuid4()),
            )
        # Content + ContentVersion rows are added, then committed exactly once.
        assert mock_db.add.call_count == 2
        # assert_awaited_once (not assert_called_once) catches a missing `await`.
        mock_db.commit.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_generate_content_no_db_no_save(self, service):
        """Test that when db is not provided, no database operations occur."""
        mock_provider = self._mock_provider("Raw", "De-AIed", "Optimized")
        with patch.object(service, "_get_provider", return_value=mock_provider):
            result = await service.generate_content(
                keyword="test",
                brand_name="Brand X",
                platform="wenxin",
            )
        # No content_id when db is not provided
        assert result.get("content_id") is None

    @pytest.mark.asyncio
    async def test_generate_content_llm_error_propagates(self, service):
        """Test that LLMError from the provider propagates correctly."""
        mock_provider = AsyncMock()
        mock_provider.chat.side_effect = LLMError(
            "API rate limit", provider="openai"
        )
        with patch.object(service, "_get_provider", return_value=mock_provider):
            with pytest.raises(LLMError, match="API rate limit"):
                await service.generate_content(
                    keyword="test",
                    brand_name="Brand X",
                    platform="wenxin",
                )

    @pytest.mark.asyncio
    async def test_get_knowledge_context_with_ids(self, service):
        """Test _get_knowledge_context retrieves context when knowledge_base_ids are provided."""
        mock_db = AsyncMock()
        mock_rag = MagicMock()
        mock_rag.search = AsyncMock(return_value=[
            {"content": "Knowledge chunk 1", "document_title": "Doc A"},
            {"content": "Knowledge chunk 2", "document_title": "Doc B"},
        ])
        with patch(
            "app.services.knowledge.rag_service.RAGService",
            return_value=mock_rag,
        ):
            context = await service._get_knowledge_context(
                db=mock_db,
                brand_name="Brand X",
                knowledge_base_ids=["kb-1"],
                target_keyword="test keyword",
            )
        assert "Knowledge chunk 1" in context
        assert "Knowledge chunk 2" in context
        assert "Doc A" in context

    @pytest.mark.asyncio
    async def test_get_knowledge_context_empty_ids(self, service):
        """Test _get_knowledge_context returns empty string when no IDs provided."""
        mock_db = AsyncMock()
        context = await service._get_knowledge_context(
            db=mock_db,
            brand_name="Brand X",
            knowledge_base_ids=[],
            target_keyword="test",
        )
        assert context == ""

    @pytest.mark.asyncio
    async def test_get_knowledge_context_rag_failure_returns_empty(self, service):
        """Test _get_knowledge_context returns empty string when RAG search fails."""
        mock_db = AsyncMock()
        mock_rag = MagicMock()
        mock_rag.search = AsyncMock(side_effect=Exception("RAG service down"))
        with patch(
            "app.services.knowledge.rag_service.RAGService",
            return_value=mock_rag,
        ):
            context = await service._get_knowledge_context(
                db=mock_db,
                brand_name="Brand X",
                knowledge_base_ids=["kb-1"],
                target_keyword="test",
            )
        assert context == ""

    @pytest.mark.asyncio
    async def test_generate_content_passes_correct_prompt_variables(self, service):
        """Test that the service passes correct variables to each prompt template stage."""
        mock_provider = self._mock_provider(
            "Raw content",
            "De-AIed content",
            "Optimized content",
        )
        with patch.object(service, "_get_provider", return_value=mock_provider):
            with patch(
                "app.services.content.content_generation_service.CONTENT_GENERATOR_TEMPLATE"
            ) as mock_gen_template, patch(
                "app.services.content.content_generation_service.DEAI_TEMPLATE"
            ) as mock_deai_template, patch(
                "app.services.content.content_generation_service.GEO_OPTIMIZER_TEMPLATE"
            ) as mock_geo_template:
                mock_gen_template.render.return_value = [
                    {"role": "user", "content": "gen prompt"}
                ]
                mock_deai_template.render.return_value = [
                    {"role": "user", "content": "deai prompt"}
                ]
                mock_geo_template.render.return_value = [
                    {"role": "user", "content": "geo prompt"}
                ]
                await service.generate_content(
                    keyword="AI搜索",
                    brand_name="Brand X",
                    platform="微信公众号",
                    content_style="专业严谨",
                    word_count=3000,
                    knowledge_context="Some context",
                )
        # Each template's render() receives a single positional dict of variables.
        gen_vars = mock_gen_template.render.call_args.args[0]
        assert gen_vars["topic_title"] == "AI搜索"
        assert gen_vars["target_keyword"] == "AI搜索"
        assert gen_vars["target_platform"] == "微信公众号"
        assert gen_vars["content_style"] == "专业严谨"
        # word_count is expected to be rendered as a string.
        assert gen_vars["word_count"] == "3000"
        assert gen_vars["brand_name"] == "Brand X"
        assert gen_vars["knowledge_context"] == "Some context"
        # De-AI stage rewrites the raw generated content.
        deai_vars = mock_deai_template.render.call_args.args[0]
        assert deai_vars["original_content"] == "Raw content"
        # GEO stage optimizes the de-AIed content.
        geo_vars = mock_geo_template.render.call_args.args[0]
        assert geo_vars["original_content"] == "De-AIed content"
        assert geo_vars["target_keywords"] == "AI搜索"
        assert geo_vars["target_platform"] == "微信公众号"

    @pytest.mark.asyncio
    async def test_generate_content_default_parameters(self, service):
        """Test that default parameter values are applied correctly."""
        mock_provider = self._mock_provider("Raw", "De-AIed", "Optimized")
        with patch.object(service, "_get_provider", return_value=mock_provider):
            result = await service.generate_content(
                keyword="test",
                brand_name="Brand X",
            )
        # Expected defaults: platform="通用", content_style="专业严谨",
        # word_count=2000, run_deai=True, run_geo=True.
        assert result is not None
        assert mock_provider.chat.await_count == 3
        assert len(result["pipeline_stages"]) == 3