358 lines
13 KiB
Python
358 lines
13 KiB
Python
"""Unit tests for agent execution logic: ContentGeneratorAgent / DeAIAgent / GEOOptimizerAgent.

Testing strategy:
- Use FakeLLMProvider to mock LLM calls, avoiding real network requests.
- Patch report_progress (BaseAgent) on each agent to avoid Redis / database dependencies.
- Patch RAGService / AsyncSessionLocal to avoid real database access.
"""
|
||
|
||
import json
|
||
import uuid
|
||
from datetime import datetime, timezone
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from app.agent_framework.agents.content_generator_agent import ContentGeneratorAgent
|
||
from app.agent_framework.agents.deai_agent import DeAIAgent
|
||
from app.agent_framework.agents.geo_optimizer_agent import GEOOptimizerAgent
|
||
from app.agent_framework.protocol import TaskMessage
|
||
from app.services.llm import LLMProvider, LLMResponse, LLMError
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# FakeLLMProvider - fake LLM used by the tests
# ---------------------------------------------------------------------------
class FakeLLMProvider(LLMProvider):
    """In-memory ``LLMProvider`` stub that always answers with a canned string.

    ``chat`` wraps the canned text in an ``LLMResponse`` with fixed token
    usage. ``chat_stream`` yields it one whitespace-separated word at a time,
    each word followed by a single space.
    """

    def __init__(self, response_content: str = "fake response"):
        # Text returned verbatim by chat() and split into chunks by chat_stream().
        self._response = response_content

    @property
    def provider_name(self) -> str:
        return "fake"

    @property
    def model_name(self) -> str:
        return "fake-model"

    @property
    def max_context_length(self) -> int:
        return 4096

    async def chat(self, messages, **kwargs) -> LLMResponse:
        """Ignore *messages* / *kwargs* and return the canned response."""
        token_usage = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
        return LLMResponse(
            content=self._response,
            model="fake-model",
            usage=token_usage,
        )

    async def chat_stream(self, messages, **kwargs):
        """Yield the canned response word by word, each word suffixed with a space."""
        for chunk in self._response.split():
            yield f"{chunk} "
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Test helpers
# ---------------------------------------------------------------------------
def _make_task(task_type: str, input_data: dict) -> TaskMessage:
    """Build a ``TaskMessage`` of *task_type* carrying *input_data*.

    Each call gets a fresh random UUID4 ``task_id`` and a timezone-aware UTC
    ``created_at``. The other fields are fixed test defaults: agent
    ``"test_agent"``, priority 1, no callback URL, 300-second timeout.
    """
    fields = dict(
        task_id=str(uuid.uuid4()),
        agent_name="test_agent",
        task_type=task_type,
        priority=1,
        input_data=input_data,
        callback_url=None,
        created_at=datetime.now(timezone.utc),
        timeout_seconds=300,
    )
    return TaskMessage(**fields)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# ContentGeneratorAgent tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_generate_topics_returns_parsed_json():
    """A JSON array returned by the LLM is parsed into ``output_data["topics"]``."""
    agent = ContentGeneratorAgent()
    llm_stub = FakeLLMProvider(
        response_content='[{"title": "AI营销趋势", "reason": "热门话题"}]'
    )

    # Build the patchers up front; they are only activated by the `with` below.
    progress_patch = patch.object(agent, "report_progress", new_callable=AsyncMock)
    llm_patch = patch(
        "app.agent_framework.agents.content_generator_agent.LLMFactory.get_default",
        return_value=llm_stub,
    )

    with progress_patch, llm_patch:
        result = await agent.execute(
            _make_task("generate_topics", {"target_keyword": "AI营销"})
        )

    assert result.status == "completed"
    assert "topics" in result.output_data
    parsed_topics = result.output_data["topics"]
    assert isinstance(parsed_topics, list)
    first_topic = parsed_topics[0]
    assert first_topic["title"] == "AI营销趋势"
    assert first_topic["reason"] == "热门话题"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_generate_article_success():
    """``generate_article`` returns the LLM text as ``content``, with word count and usage."""
    article_text = "这是一篇测试文章"
    agent = ContentGeneratorAgent()
    llm_stub = FakeLLMProvider(response_content=article_text)

    progress_patch = patch.object(agent, "report_progress", new_callable=AsyncMock)
    llm_patch = patch(
        "app.agent_framework.agents.content_generator_agent.LLMFactory.get_default",
        return_value=llm_stub,
    )

    with progress_patch, llm_patch:
        result = await agent.execute(
            _make_task("generate_article", {"target_keyword": "AI营销"})
        )

    assert result.status == "completed"
    output = result.output_data
    assert output["content"] == article_text
    assert output["word_count"] == len(article_text)
    assert "usage" in output
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_generate_with_rag_context():
    """RAG results are searched for the keyword and injected into the LLM prompt.

    ``RAGService`` and ``AsyncSessionLocal`` are patched to keep the test off
    the database. A recording LLM captures every ``messages`` argument passed
    to ``chat()``, so the test can check that the retrieved knowledge reached
    the prompt, not just that ``search`` was called.
    """
    agent = ContentGeneratorAgent()

    class RecordingLLM(FakeLLMProvider):
        """FakeLLMProvider that remembers every ``messages`` argument sent to chat()."""

        def __init__(self, response_content: str):
            super().__init__(response_content)
            self.received_messages = []

        async def chat(self, messages, **kwargs) -> LLMResponse:
            self.received_messages.append(messages)
            return await super().chat(messages, **kwargs)

    fake_llm = RecordingLLM(
        response_content='[{"title": "RAG测试选题", "reason": "测试"}]'
    )

    # Mock the AsyncSessionLocal() async context manager.
    mock_session = AsyncMock()
    mock_session.__aenter__ = AsyncMock(return_value=mock_session)
    mock_session.__aexit__ = AsyncMock(return_value=False)

    mock_local = MagicMock()
    mock_local.return_value.__aenter__ = AsyncMock(return_value=mock_session)
    mock_local.return_value.__aexit__ = AsyncMock(return_value=False)

    with patch.object(agent, "report_progress", new_callable=AsyncMock):
        with patch("app.database.AsyncSessionLocal", mock_local):
            with patch("app.services.knowledge.rag_service.RAGService") as MockRAG:
                mock_rag = MockRAG.return_value
                mock_rag.search = AsyncMock(
                    return_value=[
                        {"document_title": "知识库文档", "content": "相关知识内容"}
                    ]
                )
                with patch(
                    "app.agent_framework.agents.content_generator_agent.LLMFactory.get_default",
                    return_value=fake_llm,
                ):
                    task = _make_task(
                        "generate_topics",
                        {"target_keyword": "AI营销", "knowledge_base_ids": ["kb-1"]},
                    )
                    result = await agent.execute(task)

    assert result.status == "completed"
    mock_rag.search.assert_awaited_once()
    # Verify the arguments passed to search.
    call_kwargs = mock_rag.search.call_args.kwargs
    assert call_kwargs["query"] == "AI营销"
    assert call_kwargs["knowledge_base_ids"] == ["kb-1"]

    # The retrieved knowledge must actually reach the LLM prompt.
    assert fake_llm.received_messages, "LLM chat() was never called"
    # repr() keeps CJK characters unescaped and covers dict, dataclass and
    # pydantic message types alike.
    prompt_dump = repr(fake_llm.received_messages)
    assert "相关知识内容" in prompt_dump or "知识库文档" in prompt_dump
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_llm_error_returns_failed():
    """An ``LLMError`` raised by the provider yields a ``failed`` result."""

    class FailingLLM(FakeLLMProvider):
        """Provider whose chat() always raises an LLMError (HTTP 500)."""

        async def chat(self, messages, **kwargs) -> LLMResponse:
            raise LLMError("API错误", provider="fake", status_code=500)

    agent = ContentGeneratorAgent()

    progress_patch = patch.object(agent, "report_progress", new_callable=AsyncMock)
    llm_patch = patch(
        "app.agent_framework.agents.content_generator_agent.LLMFactory.get_default",
        return_value=FailingLLM(),
    )

    with progress_patch, llm_patch:
        result = await agent.execute(
            _make_task("generate_topics", {"target_keyword": "AI营销"})
        )

    assert result.status == "failed"
    assert "LLM调用失败" in result.error_message
|
||
|
||
|
||
def test_extract_json_from_code_block():
    """``_extract_json`` strips a ```json fenced block down to the raw JSON text.

    ``_extract_json`` is called synchronously and nothing here is awaited, so
    this is a plain sync test. The previous ``async def`` with
    ``@pytest.mark.asyncio`` only added event-loop overhead.
    """
    agent = ContentGeneratorAgent()
    fenced = '```json\n[{"title": "测试"}]\n```'
    assert agent._extract_json(fenced) == '[{"title": "测试"}]'
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# DeAIAgent tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_deai_success():
    """A normal de-AI run completes and reports original/processed word counts."""
    source_text = "原始的AI生成内容"
    rewritten_text = "去AI化后的内容"
    agent = DeAIAgent()
    llm_stub = FakeLLMProvider(response_content=rewritten_text)

    progress_patch = patch.object(agent, "report_progress", new_callable=AsyncMock)
    llm_patch = patch(
        "app.agent_framework.agents.deai_agent.LLMFactory.get_default",
        return_value=llm_stub,
    )

    with progress_patch, llm_patch:
        result = await agent.execute(_make_task("deai_process", {"content": source_text}))

    assert result.status == "completed"
    output = result.output_data
    assert output["content"] == rewritten_text
    assert output["original_word_count"] == len(source_text)
    assert output["processed_word_count"] == len(rewritten_text)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_deai_empty_content_fails():
    """Empty ``content`` input produces a ``failed`` result."""
    agent = DeAIAgent()
    llm_stub = FakeLLMProvider(response_content="something")

    progress_patch = patch.object(agent, "report_progress", new_callable=AsyncMock)
    llm_patch = patch(
        "app.agent_framework.agents.deai_agent.LLMFactory.get_default",
        return_value=llm_stub,
    )

    with progress_patch, llm_patch:
        result = await agent.execute(_make_task("deai_process", {"content": ""}))

    assert result.status == "failed"
    # Expected: the agent's outer except catches the ValueError raised for the
    # empty input, so error_message carries the original exception text.
    lowered = result.error_message.lower()
    assert "content" in lowered or "input_data" in lowered
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_deai_temperature_is_high():
    """DeAIAgent calls the LLM exactly once, with ``temperature=0.9``."""
    agent = DeAIAgent()

    canned = LLMResponse(
        content="processed",
        model="fake",
        usage={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
    )
    chat_mock = AsyncMock(return_value=canned)
    provider = AsyncMock()
    provider.chat = chat_mock

    progress_patch = patch.object(agent, "report_progress", new_callable=AsyncMock)
    llm_patch = patch(
        "app.agent_framework.agents.deai_agent.LLMFactory.get_default",
        return_value=provider,
    )

    with progress_patch, llm_patch:
        await agent.execute(_make_task("deai_process", {"content": "some content"}))

    chat_mock.assert_awaited_once()
    assert chat_mock.call_args.kwargs.get("temperature") == 0.9
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# GEOOptimizerAgent tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_geo_optimize_json_response():
    """A well-formed JSON reply from the LLM is parsed field by field into output_data."""
    agent = GEOOptimizerAgent()
    expected = {
        "optimized_content": "优化后的文章",
        "seo_score": 85,
        "changes": ["优化了标题"],
    }
    llm_stub = FakeLLMProvider(response_content=json.dumps(expected))

    progress_patch = patch.object(agent, "report_progress", new_callable=AsyncMock)
    llm_patch = patch(
        "app.agent_framework.agents.geo_optimizer_agent.LLMFactory.get_default",
        return_value=llm_stub,
    )

    with progress_patch, llm_patch:
        result = await agent.execute(
            _make_task(
                "geo_optimize",
                {"content": "原始文章", "target_keywords": ["SEO"]},
            )
        )

    assert result.status == "completed"
    output = result.output_data
    # Each field of the JSON reply must be copied through unchanged.
    for field_name, field_value in expected.items():
        assert output[field_name] == field_value
    assert "usage" in output
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_geo_optimize_fallback():
    """A plain-text (non-JSON) LLM reply falls back to returning the raw text."""
    raw_reply = "这不是JSON格式"
    agent = GEOOptimizerAgent()
    llm_stub = FakeLLMProvider(response_content=raw_reply)

    progress_patch = patch.object(agent, "report_progress", new_callable=AsyncMock)
    llm_patch = patch(
        "app.agent_framework.agents.geo_optimizer_agent.LLMFactory.get_default",
        return_value=llm_stub,
    )

    with progress_patch, llm_patch:
        result = await agent.execute(
            _make_task(
                "geo_optimize",
                {"content": "原始文章", "target_keywords": ["SEO"]},
            )
        )

    assert result.status == "completed"
    output = result.output_data
    assert output["optimized_content"] == raw_reply
    assert output["seo_score"] is None
    assert output["changes"] == ["LLM输出非标准格式,已返回原始优化结果"]
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_geo_optimize_keywords_in_prompt():
    """Every target keyword appears in the variables passed to GEO_OPTIMIZER_TEMPLATE.render()."""
    agent = GEOOptimizerAgent()
    keywords = ("SEO", "GEO优化")

    provider = AsyncMock()
    provider.chat = AsyncMock(
        return_value=LLMResponse(
            content=json.dumps({"optimized_content": "test"}),
            model="fake",
            usage={},
        )
    )

    progress_patch = patch.object(agent, "report_progress", new_callable=AsyncMock)
    render_patch = patch(
        "app.agent_framework.agents.geo_optimizer_agent.GEO_OPTIMIZER_TEMPLATE.render"
    )
    factory_patch = patch(
        "app.agent_framework.agents.geo_optimizer_agent.LLMFactory.get_default"
    )

    with progress_patch, render_patch as mock_render, factory_patch as mock_factory:
        mock_render.return_value = [{"role": "system", "content": "test prompt"}]
        mock_factory.return_value = provider
        await agent.execute(
            _make_task(
                "geo_optimize",
                {"content": "原始文章", "target_keywords": list(keywords)},
            )
        )

    mock_render.assert_called_once()
    # render() is expected to receive the template variables as its first positional arg.
    template_vars = mock_render.call_args[0][0]
    for keyword in keywords:
        assert keyword in template_vars["target_keywords"]
|