"""Tests for ContentGenerationService - extracted from api/content.py

TDD RED phase: tests for the service that extracts the 3-stage content
generation flow (generate -> de-AI -> GEO optimize) out of the API handler.
"""

import uuid
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.services.llm.base import LLMError, LLMResponse


def _make_llm_response(content: str) -> LLMResponse:
    """Helper to create LLMResponse objects for mocking."""
    return LLMResponse(content=content, model="test-model")


def _make_provider(*contents: str) -> AsyncMock:
    """Return a mock LLM provider whose successive ``chat`` calls yield *contents*.

    Each string is wrapped in an ``LLMResponse``; the n-th awaited
    ``provider.chat(...)`` returns the n-th response.
    """
    provider = AsyncMock()
    provider.chat.side_effect = [_make_llm_response(text) for text in contents]
    return provider


def _new_service():
    """Instantiate ContentGenerationService, importing it lazily.

    The import is function-local so that a missing or broken service module
    fails each test individually instead of breaking collection of this file.
    """
    from app.services.content.content_generation_service import (
        ContentGenerationService,
    )

    return ContentGenerationService()


class TestContentGenerationService:
    """ContentGenerationService unit tests."""

    @pytest.mark.asyncio
    async def test_generate_content_basic_three_stages(self):
        """Test basic 3-stage content generation (generate -> de-AI -> GEO optimize)."""
        service = _new_service()
        mock_provider = _make_provider(
            "Raw generated content",
            "De-AIed content",
            "GEO optimized content",
        )

        with patch.object(service, "_get_provider", return_value=mock_provider):
            result = await service.generate_content(
                keyword="AI搜索",
                brand_name="Brand X",
                platform="wenxin",
                content_style="专业严谨",
                word_count=2000,
                run_deai=True,
                run_geo=True,
            )

        assert result is not None
        # "content" holds the de-AI output; "optimized_content" the GEO output.
        assert result["content"] == "De-AIed content"
        assert result["optimized_content"] == "GEO optimized content"
        assert result["pipeline_stages"] is not None
        assert [stage["stage"] for stage in result["pipeline_stages"]] == [
            "content_generation",
            "deai",
            "geo_optimization",
        ]

    @pytest.mark.asyncio
    async def test_generate_content_skip_deai(self):
        """Test generation with run_deai=False skips the de-AI stage."""
        service = _new_service()
        mock_provider = _make_provider(
            "Raw generated content",
            "GEO optimized content",
        )

        with patch.object(service, "_get_provider", return_value=mock_provider):
            result = await service.generate_content(
                keyword="test",
                brand_name="Brand X",
                platform="wenxin",
                run_deai=False,
                run_geo=True,
            )

        # Without de-AI, the raw generation becomes the "content" field.
        assert result["content"] == "Raw generated content"
        assert result["optimized_content"] == "GEO optimized content"
        assert len(result["pipeline_stages"]) == 2
        stage_names = [stage["stage"] for stage in result["pipeline_stages"]]
        assert "deai" not in stage_names

    @pytest.mark.asyncio
    async def test_generate_content_skip_geo(self):
        """Test generation with run_geo=False skips the GEO optimization stage."""
        service = _new_service()
        mock_provider = _make_provider(
            "Raw generated content",
            "De-AIed content",
        )

        with patch.object(service, "_get_provider", return_value=mock_provider):
            result = await service.generate_content(
                keyword="test",
                brand_name="Brand X",
                platform="wenxin",
                run_deai=True,
                run_geo=False,
            )

        # Without GEO, "optimized_content" falls back to the de-AI output.
        assert result["content"] == "De-AIed content"
        assert result["optimized_content"] == "De-AIed content"
        assert len(result["pipeline_stages"]) == 2
        stage_names = [stage["stage"] for stage in result["pipeline_stages"]]
        assert "geo_optimization" not in stage_names

    @pytest.mark.asyncio
    async def test_generate_content_skip_both_stages(self):
        """Test generation with both run_deai=False and run_geo=False."""
        service = _new_service()
        mock_provider = _make_provider("Raw generated content")

        with patch.object(service, "_get_provider", return_value=mock_provider):
            result = await service.generate_content(
                keyword="test",
                brand_name="Brand X",
                platform="wenxin",
                run_deai=False,
                run_geo=False,
            )

        assert result["content"] == "Raw generated content"
        assert result["optimized_content"] == "Raw generated content"
        assert len(result["pipeline_stages"]) == 1

    @pytest.mark.asyncio
    async def test_generate_content_with_knowledge_context(self):
        """Test that knowledge_context is passed through to the generation prompt."""
        service = _new_service()
        mock_provider = _make_provider("Raw content", "De-AIed", "Optimized")

        with patch.object(service, "_get_provider", return_value=mock_provider):
            await service.generate_content(
                keyword="test keyword",
                brand_name="Brand X",
                platform="wenxin",
                knowledge_context="Some knowledge base content",
            )

        # The first provider.chat call is the generation stage; its first
        # positional argument is the rendered message list.
        first_call_args = mock_provider.chat.call_args_list[0]
        messages = first_call_args[0][0]
        all_content = " ".join(str(message) for message in messages)
        assert "Some knowledge base content" in all_content

    @pytest.mark.asyncio
    async def test_generate_content_saves_to_database(self):
        """Test that generated content is saved to database when db and user_id are provided."""
        service = _new_service()
        mock_provider = _make_provider("Raw content", "De-AIed", "Optimized")

        mock_db = AsyncMock()
        # NOTE(review): assumes db behaves like an SQLAlchemy AsyncSession, whose
        # add() is synchronous. A plain MagicMock keeps add() from returning
        # coroutines that are never awaited (RuntimeWarning on every call).
        mock_db.add = MagicMock()

        with patch.object(service, "_get_provider", return_value=mock_provider):
            await service.generate_content(
                keyword="test",
                brand_name="Brand X",
                platform="wenxin",
                db=mock_db,
                user_id="test-user-123",
                org_id=str(uuid.uuid4()),
            )

        # One add() for the Content row and one for its ContentVersion.
        assert mock_db.add.call_count == 2
        # assert_awaited_* (rather than assert_called_*) also catches a commit()
        # that is called but never awaited, which would silently skip the write.
        mock_db.commit.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_generate_content_no_db_no_save(self):
        """Test that when db is not provided, no database operations occur."""
        service = _new_service()
        mock_provider = _make_provider("Raw", "De-AIed", "Optimized")

        with patch.object(service, "_get_provider", return_value=mock_provider):
            result = await service.generate_content(
                keyword="test",
                brand_name="Brand X",
                platform="wenxin",
            )

        # No content_id when db is not provided
        assert result.get("content_id") is None

    @pytest.mark.asyncio
    async def test_generate_content_llm_error_propagates(self):
        """Test that LLMError from the provider propagates correctly."""
        service = _new_service()
        mock_provider = AsyncMock()
        mock_provider.chat.side_effect = LLMError("API rate limit", provider="openai")

        with patch.object(service, "_get_provider", return_value=mock_provider):
            with pytest.raises(LLMError, match="API rate limit"):
                await service.generate_content(
                    keyword="test",
                    brand_name="Brand X",
                    platform="wenxin",
                )

    @pytest.mark.asyncio
    async def test_get_knowledge_context_with_ids(self):
        """Test _get_knowledge_context retrieves context when knowledge_base_ids are provided."""
        service = _new_service()
        mock_db = AsyncMock()
        mock_rag = MagicMock()
        mock_rag.search = AsyncMock(
            return_value=[
                {"content": "Knowledge chunk 1", "document_title": "Doc A"},
                {"content": "Knowledge chunk 2", "document_title": "Doc B"},
            ]
        )

        # NOTE(review): patching the defining module only takes effect if the
        # service looks RAGService up at call time (e.g. a function-level
        # import) — confirm against the service implementation.
        with patch(
            "app.services.knowledge.rag_service.RAGService",
            return_value=mock_rag,
        ):
            context = await service._get_knowledge_context(
                db=mock_db,
                brand_name="Brand X",
                knowledge_base_ids=["kb-1"],
                target_keyword="test keyword",
            )

        assert "Knowledge chunk 1" in context
        assert "Knowledge chunk 2" in context
        assert "Doc A" in context

    @pytest.mark.asyncio
    async def test_get_knowledge_context_empty_ids(self):
        """Test _get_knowledge_context returns empty string when no IDs provided."""
        service = _new_service()
        mock_db = AsyncMock()

        context = await service._get_knowledge_context(
            db=mock_db,
            brand_name="Brand X",
            knowledge_base_ids=[],
            target_keyword="test",
        )

        assert context == ""

    @pytest.mark.asyncio
    async def test_get_knowledge_context_rag_failure_returns_empty(self):
        """Test _get_knowledge_context returns empty string when RAG search fails."""
        service = _new_service()
        mock_db = AsyncMock()
        mock_rag = MagicMock()
        mock_rag.search = AsyncMock(side_effect=Exception("RAG service down"))

        with patch(
            "app.services.knowledge.rag_service.RAGService",
            return_value=mock_rag,
        ):
            context = await service._get_knowledge_context(
                db=mock_db,
                brand_name="Brand X",
                knowledge_base_ids=["kb-1"],
                target_keyword="test",
            )

        # RAG failures are treated as "no extra context", not as errors.
        assert context == ""

    @pytest.mark.asyncio
    async def test_generate_content_passes_correct_prompt_variables(self):
        """Test that the service passes correct variables to each prompt template stage."""
        service = _new_service()
        mock_provider = _make_provider(
            "Raw content",
            "De-AIed content",
            "Optimized content",
        )

        with patch.object(service, "_get_provider", return_value=mock_provider):
            with patch(
                "app.services.content.content_generation_service.CONTENT_GENERATOR_TEMPLATE"
            ) as mock_gen_template, patch(
                "app.services.content.content_generation_service.DEAI_TEMPLATE"
            ) as mock_deai_template, patch(
                "app.services.content.content_generation_service.GEO_OPTIMIZER_TEMPLATE"
            ) as mock_geo_template:
                mock_gen_template.render.return_value = [
                    {"role": "user", "content": "gen prompt"}
                ]
                mock_deai_template.render.return_value = [
                    {"role": "user", "content": "deai prompt"}
                ]
                mock_geo_template.render.return_value = [
                    {"role": "user", "content": "geo prompt"}
                ]

                await service.generate_content(
                    keyword="AI搜索",
                    brand_name="Brand X",
                    platform="微信公众号",
                    content_style="专业严谨",
                    word_count=3000,
                    knowledge_context="Some context",
                )

        # Each template's render() receives its variables dict as the first
        # positional argument.
        gen_vars = mock_gen_template.render.call_args[0][0]
        assert gen_vars["topic_title"] == "AI搜索"
        assert gen_vars["target_keyword"] == "AI搜索"
        assert gen_vars["target_platform"] == "微信公众号"
        assert gen_vars["content_style"] == "专业严谨"
        # word_count is passed to the template as a string.
        assert gen_vars["word_count"] == "3000"
        assert gen_vars["brand_name"] == "Brand X"
        assert gen_vars["knowledge_context"] == "Some context"

        # The de-AI stage rewrites the raw generated content.
        deai_vars = mock_deai_template.render.call_args[0][0]
        assert deai_vars["original_content"] == "Raw content"

        # The GEO stage optimizes the de-AIed content.
        geo_vars = mock_geo_template.render.call_args[0][0]
        assert geo_vars["original_content"] == "De-AIed content"
        assert geo_vars["target_keywords"] == "AI搜索"
        assert geo_vars["target_platform"] == "微信公众号"

    @pytest.mark.asyncio
    async def test_generate_content_default_parameters(self):
        """Test that omitted optional parameters fall back to their defaults.

        Only the run_deai/run_geo defaults (both True) are observable here,
        through the three pipeline stages; the platform, content_style and
        word_count defaults are not asserted by this test.
        """
        service = _new_service()
        mock_provider = _make_provider("Raw", "De-AIed", "Optimized")

        with patch.object(service, "_get_provider", return_value=mock_provider):
            result = await service.generate_content(
                keyword="test",
                brand_name="Brand X",
            )

        # Defaults: platform="通用", content_style="专业严谨", word_count=2000
        # run_deai=True, run_geo=True
        assert result is not None
        assert len(result["pipeline_stages"]) == 3