geo/tests/test_content_agents.py

358 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Agent执行逻辑单元测试 - ContentGeneratorAgent / DeAIAgent / GEOOptimizerAgent
测试策略:
- 使用 FakeLLMProvider mock LLM 调用,避免真实网络请求
- patch BaseAgent.report_progress 避免 Redis / 数据库依赖
- patch RAGService / AsyncSessionLocal 避免真实数据库访问
"""
import json
import uuid
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.agent_framework.agents.content_generator_agent import ContentGeneratorAgent
from app.agent_framework.agents.deai_agent import DeAIAgent
from app.agent_framework.agents.geo_optimizer_agent import GEOOptimizerAgent
from app.agent_framework.protocol import TaskMessage
from app.services.llm import LLMProvider, LLMResponse, LLMError
# ---------------------------------------------------------------------------
# FakeLLMProvider - 测试用假LLM
# ---------------------------------------------------------------------------
class FakeLLMProvider(LLMProvider):
    """In-memory LLM stand-in for tests that always answers with a canned reply.

    ``chat`` wraps the configured text in an ``LLMResponse`` with fixed token
    usage; ``chat_stream`` yields the same text one whitespace-separated word
    at a time, each followed by a single space. No network access happens.
    """

    def __init__(self, response_content: str = "fake response"):
        # Canned text returned by every chat / chat_stream call.
        self._response = response_content

    @property
    def provider_name(self) -> str:
        return "fake"

    @property
    def model_name(self) -> str:
        return "fake-model"

    @property
    def max_context_length(self) -> int:
        return 4096

    async def chat(self, messages, **kwargs) -> LLMResponse:
        """Ignore ``messages``/``kwargs`` and return the canned reply."""
        token_usage = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
        return LLMResponse(
            content=self._response,
            model="fake-model",
            usage=token_usage,
        )

    async def chat_stream(self, messages, **kwargs):
        """Yield the canned reply word by word, each chunk suffixed with a space."""
        words = self._response.split()
        for chunk in words:
            yield f"{chunk} "
# ---------------------------------------------------------------------------
# 测试辅助函数
# ---------------------------------------------------------------------------
def _make_task(task_type: str, input_data: dict) -> TaskMessage:
    """Build a TaskMessage for ``task_type`` carrying ``input_data``.

    Each call gets a fresh random UUID as ``task_id`` and the current UTC time
    as ``created_at``; every other field is a constant shared by all tests.
    """
    fields = dict(
        task_id=str(uuid.uuid4()),
        agent_name="test_agent",
        task_type=task_type,
        priority=1,
        input_data=input_data,
        callback_url=None,
        created_at=datetime.now(timezone.utc),
        timeout_seconds=300,
    )
    return TaskMessage(**fields)
# ---------------------------------------------------------------------------
# ContentGeneratorAgent 测试
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_generate_topics_returns_parsed_json():
    """A JSON-array LLM reply for ``generate_topics`` is parsed into ``output_data["topics"]``."""
    llm_reply = '[{"title": "AI营销趋势", "reason": "热门话题"}]'
    agent = ContentGeneratorAgent()
    fake_llm = FakeLLMProvider(response_content=llm_reply)
    task = _make_task("generate_topics", {"target_keyword": "AI营销"})

    with patch.object(agent, "report_progress", new_callable=AsyncMock), patch(
        "app.agent_framework.agents.content_generator_agent.LLMFactory.get_default",
        return_value=fake_llm,
    ):
        result = await agent.execute(task)

    assert result.status == "completed"
    assert "topics" in result.output_data
    topics = result.output_data["topics"]
    assert isinstance(topics, list)
    first = topics[0]
    assert first["title"] == "AI营销趋势"
    assert first["reason"] == "热门话题"
@pytest.mark.asyncio
async def test_generate_article_success():
    """``generate_article`` returns the LLM text as ``content``, its length and the usage."""
    article = "这是一篇测试文章"
    agent = ContentGeneratorAgent()
    llm_target = "app.agent_framework.agents.content_generator_agent.LLMFactory.get_default"

    with patch.object(agent, "report_progress", new_callable=AsyncMock):
        with patch(llm_target, return_value=FakeLLMProvider(response_content=article)):
            result = await agent.execute(
                _make_task("generate_article", {"target_keyword": "AI营销"})
            )

    output = result.output_data
    assert result.status == "completed"
    assert output["content"] == article
    assert output["word_count"] == len(article)
    assert "usage" in output
@pytest.mark.asyncio
async def test_generate_with_rag_context():
    """``knowledge_base_ids`` in the input trigger one RAG search with the expected arguments.

    ``AsyncSessionLocal`` and ``RAGService`` are patched at their defining
    modules, which presumably matches a function-local import inside the
    agent — TODO confirm against ContentGeneratorAgent. Only the search call
    is checked; the prompt actually sent to the LLM is not inspected.
    """
    agent = ContentGeneratorAgent()
    fake_llm = FakeLLMProvider(
        response_content='[{"title": "RAG测试选题", "reason": "测试"}]'
    )
    # ``AsyncSessionLocal()`` returns ``mock_local.return_value``. That object is
    # the async context manager and yields ``mock_session``. AsyncMock already
    # supports ``async with`` natively, so the session itself needs no manual
    # ``__aenter__``/``__aexit__`` wiring.
    mock_session = AsyncMock()
    mock_local = MagicMock()
    mock_local.return_value.__aenter__ = AsyncMock(return_value=mock_session)
    mock_local.return_value.__aexit__ = AsyncMock(return_value=False)
    with patch.object(agent, "report_progress", new_callable=AsyncMock):
        with patch("app.database.AsyncSessionLocal", mock_local):
            with patch("app.services.knowledge.rag_service.RAGService") as MockRAG:
                mock_rag = MockRAG.return_value
                mock_rag.search = AsyncMock(
                    return_value=[
                        {"document_title": "知识库文档", "content": "相关知识内容"}
                    ]
                )
                with patch(
                    "app.agent_framework.agents.content_generator_agent.LLMFactory.get_default",
                    return_value=fake_llm,
                ):
                    task = _make_task(
                        "generate_topics",
                        {"target_keyword": "AI营销", "knowledge_base_ids": ["kb-1"]},
                    )
                    result = await agent.execute(task)
    assert result.status == "completed"
    mock_rag.search.assert_awaited_once()
    # The search must use the task's keyword and the requested knowledge bases.
    call_kwargs = mock_rag.search.call_args.kwargs
    assert call_kwargs["query"] == "AI营销"
    assert call_kwargs["knowledge_base_ids"] == ["kb-1"]
@pytest.mark.asyncio
async def test_llm_error_returns_failed():
    """An ``LLMError`` raised by the provider yields a ``failed`` result, not an exception."""

    class ErrorLLM(FakeLLMProvider):
        """Provider whose ``chat`` always fails with a simulated HTTP 500."""

        async def chat(self, messages, **kwargs) -> LLMResponse:
            raise LLMError("API错误", provider="fake", status_code=500)

    agent = ContentGeneratorAgent()
    get_default = "app.agent_framework.agents.content_generator_agent.LLMFactory.get_default"
    with patch.object(agent, "report_progress", new_callable=AsyncMock), patch(
        get_default, return_value=ErrorLLM()
    ):
        result = await agent.execute(
            _make_task("generate_topics", {"target_keyword": "AI营销"})
        )

    assert result.status == "failed"
    assert "LLM调用失败" in result.error_message
def test_extract_json_from_code_block():
    """``_extract_json`` strips a ```json fenced block down to the bare JSON text.

    ``_extract_json`` is called without ``await``, so this is a plain
    synchronous test and needs neither an event loop nor the asyncio marker.
    """
    agent = ContentGeneratorAgent()
    fenced = '```json\n[{"title": "测试"}]\n```'
    extracted = agent._extract_json(fenced)
    assert extracted == '[{"title": "测试"}]'
    # The extracted text must also be valid JSON with the original structure.
    assert json.loads(extracted) == [{"title": "测试"}]
# ---------------------------------------------------------------------------
# DeAIAgent 测试
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_deai_success():
    """``deai_process`` returns the rewritten text and word counts for input and output."""
    source_text = "原始的AI生成内容"
    rewritten = "去AI化后的内容"
    agent = DeAIAgent()

    with patch.object(agent, "report_progress", new_callable=AsyncMock), patch(
        "app.agent_framework.agents.deai_agent.LLMFactory.get_default",
        return_value=FakeLLMProvider(response_content=rewritten),
    ):
        result = await agent.execute(_make_task("deai_process", {"content": source_text}))

    assert result.status == "completed"
    data = result.output_data
    assert data["content"] == rewritten
    assert data["original_word_count"] == len(source_text)
    assert data["processed_word_count"] == len(rewritten)
@pytest.mark.asyncio
async def test_deai_empty_content_fails():
    """Empty ``content`` makes ``deai_process`` report ``failed`` with a descriptive message."""
    agent = DeAIAgent()
    fake_llm = FakeLLMProvider(response_content="something")
    with patch.object(agent, "report_progress", new_callable=AsyncMock):
        with patch(
            "app.agent_framework.agents.deai_agent.LLMFactory.get_default",
            return_value=fake_llm,
        ):
            task = _make_task("deai_process", {"content": ""})
            result = await agent.execute(task)
    assert result.status == "failed"
    # Per the original note, the validation ValueError is caught by the agent's
    # outer except and its text ends up in error_message. Check the message
    # exists first, so a missing message fails here with a readable assertion
    # instead of an AttributeError on ``.lower()``.
    assert result.error_message, "failed result should carry an error_message"
    message = result.error_message.lower()
    assert "content" in message or "input_data" in message
@pytest.mark.asyncio
async def test_deai_temperature_is_high():
    """The DeAI agent awaits ``chat`` exactly once with ``temperature=0.9``."""
    canned = LLMResponse(
        content="processed",
        model="fake",
        usage={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    )
    mock_provider = AsyncMock()
    mock_provider.chat = AsyncMock(return_value=canned)
    agent = DeAIAgent()

    with patch.object(agent, "report_progress", new_callable=AsyncMock), patch(
        "app.agent_framework.agents.deai_agent.LLMFactory.get_default",
        return_value=mock_provider,
    ):
        await agent.execute(_make_task("deai_process", {"content": "some content"}))

    mock_provider.chat.assert_awaited_once()
    assert mock_provider.chat.call_args.kwargs.get("temperature") == 0.9
# ---------------------------------------------------------------------------
# GEOOptimizerAgent 测试
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_geo_optimize_json_response():
    """A well-formed JSON reply is unpacked field by field into ``output_data``."""
    payload = {
        "optimized_content": "优化后的文章",
        "seo_score": 85,
        "changes": ["优化了标题"],
    }
    agent = GEOOptimizerAgent()
    fake_llm = FakeLLMProvider(response_content=json.dumps(payload))
    task = _make_task(
        "geo_optimize",
        {"content": "原始文章", "target_keywords": ["SEO"]},
    )

    with patch.object(agent, "report_progress", new_callable=AsyncMock), patch(
        "app.agent_framework.agents.geo_optimizer_agent.LLMFactory.get_default",
        return_value=fake_llm,
    ):
        result = await agent.execute(task)

    assert result.status == "completed"
    for key in ("optimized_content", "seo_score", "changes"):
        assert result.output_data[key] == payload[key]
    assert "usage" in result.output_data
@pytest.mark.asyncio
async def test_geo_optimize_fallback():
    """A non-JSON reply degrades gracefully: raw text kept, score None, note added to changes."""
    plain_reply = "这不是JSON格式"
    agent = GEOOptimizerAgent()
    llm_target = "app.agent_framework.agents.geo_optimizer_agent.LLMFactory.get_default"

    with patch.object(agent, "report_progress", new_callable=AsyncMock):
        with patch(llm_target, return_value=FakeLLMProvider(response_content=plain_reply)):
            result = await agent.execute(
                _make_task(
                    "geo_optimize",
                    {"content": "原始文章", "target_keywords": ["SEO"]},
                )
            )

    data = result.output_data
    assert result.status == "completed"
    assert data["optimized_content"] == plain_reply
    assert data["seo_score"] is None
    assert data["changes"] == ["LLM输出非标准格式已返回原始优化结果"]
@pytest.mark.asyncio
async def test_geo_optimize_keywords_in_prompt():
    """Every target keyword reaches the variables passed to the GEO prompt template."""
    agent = GEOOptimizerAgent()
    stub_provider = AsyncMock()
    stub_provider.chat = AsyncMock(
        return_value=LLMResponse(
            content=json.dumps({"optimized_content": "test"}),
            model="fake",
            usage={},
        )
    )
    task = _make_task(
        "geo_optimize",
        {"content": "原始文章", "target_keywords": ["SEO", "GEO优化"]},
    )

    with patch.object(agent, "report_progress", new_callable=AsyncMock), patch(
        "app.agent_framework.agents.geo_optimizer_agent.GEO_OPTIMIZER_TEMPLATE.render",
        return_value=[{"role": "system", "content": "test prompt"}],
    ) as mock_render, patch(
        "app.agent_framework.agents.geo_optimizer_agent.LLMFactory.get_default",
        return_value=stub_provider,
    ):
        await agent.execute(task)

    mock_render.assert_called_once()
    # The template variables are the first positional argument to ``render``.
    rendered_vars = mock_render.call_args.args[0]
    for keyword in ("SEO", "GEO优化"):
        assert keyword in rendered_vars["target_keywords"]