225 lines
7.7 KiB
Python
225 lines
7.7 KiB
Python
"""Agent集成测试"""
|
||
import pytest
|
||
import uuid
|
||
from datetime import datetime, timezone
|
||
|
||
from app.agent_framework.agents.citation_detector import CitationDetectorAgent
|
||
from app.agent_framework.agents.content_generator_agent import ContentGeneratorAgent
|
||
from app.agent_framework.protocol import (
|
||
TaskMessage,
|
||
TaskStatus,
|
||
)
|
||
|
||
|
||
class TestCitationDetectorAgent:
    """Tests for the citation-detector agent (CitationDetectorAgent)."""

    @staticmethod
    def _make_task(task_type, input_data=None):
        """Build a TaskMessage addressed to the ``citation_detector`` agent.

        Args:
            task_type: Task type string passed through to ``TaskMessage``.
            input_data: Task payload; a fresh empty dict when omitted.

        Returns:
            A ``TaskMessage`` with a random UUID task id, priority 5, no
            callback URL and a timezone-aware UTC creation timestamp.
        """
        return TaskMessage(
            task_id=str(uuid.uuid4()),
            agent_name="citation_detector",
            task_type=task_type,
            priority=5,
            # Fresh dict per call so tests never share a mutable payload.
            input_data={} if input_data is None else input_data,
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )

    def test_agent_initialization(self):
        """The agent exposes its expected name, type and version."""
        agent = CitationDetectorAgent()

        assert agent.name == "citation_detector"
        assert agent.agent_type.value == "citation_detector"
        assert agent.version == "1.0.0"

    def test_get_capabilities(self):
        """The capability descriptor lists both detect tasks and concurrency 3."""
        agent = CitationDetectorAgent()
        capability = agent.get_capabilities()

        assert capability.agent_name == "citation_detector"
        assert "citation_detect" in capability.supported_tasks
        assert "citation_detect_single" in capability.supported_tasks
        assert capability.max_concurrency == 3

    @pytest.mark.asyncio
    async def test_execute_with_invalid_task_type(self):
        """An unsupported task type yields a FAILED result with a message."""
        agent = CitationDetectorAgent()
        task = self._make_task("invalid_task_type")

        result = await agent.execute(task)

        assert result.status == TaskStatus.FAILED
        assert result.error_message is not None
        assert "Unsupported task type" in result.error_message

    @pytest.mark.asyncio
    async def test_execute_single_detect_missing_params(self):
        """Single-platform detection fails when required inputs are missing."""
        agent = CitationDetectorAgent()
        # Empty payload: keyword, platform and target_brand are all absent.
        task = self._make_task("citation_detect_single")

        result = await agent.execute(task)

        assert result.status == TaskStatus.FAILED
        assert result.error_message is not None

    @pytest.mark.asyncio
    async def test_execute_full_detect_missing_query_id(self):
        """Full detection fails when ``query_id`` is missing from the payload."""
        agent = CitationDetectorAgent()
        # Empty payload: query_id is absent.
        task = self._make_task("citation_detect")

        result = await agent.execute(task)

        assert result.status == TaskStatus.FAILED
        # Check for None first so a missing message is reported as an
        # assertion failure rather than a TypeError from the `in` test.
        assert result.error_message is not None
        assert (
            "query_id" in result.error_message
            or "must contain" in result.error_message
        )

    def test_compatibility_methods_exist(self):
        """The backward-compatible entry points exist and are callable."""
        agent = CitationDetectorAgent()

        assert hasattr(agent, 'execute_query_compat')
        assert hasattr(agent, 'execute_single_platform_compat')
        assert callable(agent.execute_query_compat)
        assert callable(agent.execute_single_platform_compat)
|
||
|
||
|
||
class TestContentGeneratorAgent:
    """Tests for the content-generator agent (ContentGeneratorAgent)."""

    @staticmethod
    def _make_task(task_type, input_data=None):
        """Build a TaskMessage addressed to the ``content_generator`` agent.

        Args:
            task_type: Task type string passed through to ``TaskMessage``.
            input_data: Task payload; a fresh empty dict when omitted.

        Returns:
            A ``TaskMessage`` with a random UUID task id, priority 5, no
            callback URL and a timezone-aware UTC creation timestamp.
        """
        return TaskMessage(
            task_id=str(uuid.uuid4()),
            agent_name="content_generator",
            task_type=task_type,
            priority=5,
            # Fresh dict per call so tests never share a mutable payload.
            input_data={} if input_data is None else input_data,
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )

    def test_agent_initialization(self):
        """The agent exposes its expected name, type and version."""
        agent = ContentGeneratorAgent()

        assert agent.name == "content_generator"
        assert agent.agent_type.value == "content_generator"
        assert agent.version == "1.0.0"

    def test_get_capabilities(self):
        """The capability descriptor lists both generation tasks and concurrency 2."""
        agent = ContentGeneratorAgent()
        capability = agent.get_capabilities()

        assert capability.agent_name == "content_generator"
        assert "generate_topics" in capability.supported_tasks
        assert "generate_article" in capability.supported_tasks
        assert capability.max_concurrency == 2

    @pytest.mark.asyncio
    async def test_execute_with_invalid_task_type(self):
        """An unsupported task type yields a FAILED result with a message."""
        agent = ContentGeneratorAgent()
        task = self._make_task("invalid_task_type")

        result = await agent.execute(task)

        assert result.status == TaskStatus.FAILED
        # Check for None first so a missing message is reported as an
        # assertion failure rather than a TypeError from the `in` test.
        assert result.error_message is not None
        assert "Unsupported task type" in result.error_message

    @pytest.mark.asyncio
    async def test_generate_topics_missing_keyword(self):
        """Topic generation without ``target_keyword`` returns a terminal result."""
        agent = ContentGeneratorAgent()
        # Empty payload: target_keyword is absent.
        task = self._make_task("generate_topics")

        # NOTE(review): nothing in this test stubs the LLM or knowledge base,
        # so depending on configuration execute() may reach a real backend.
        # The test only checks that execution completes with a terminal status.
        result = await agent.execute(task)

        # FAILED is plausible here (missing parameters or LLM call failure).
        assert result is not None
        assert result.status in (TaskStatus.COMPLETED, TaskStatus.FAILED)

    @pytest.mark.asyncio
    async def test_generate_article_missing_keyword(self):
        """Article generation without ``target_keyword`` returns a terminal result."""
        agent = ContentGeneratorAgent()
        # Empty payload: target_keyword is absent.
        task = self._make_task("generate_article")

        # NOTE(review): as above, this may reach a real LLM backend; consider
        # mocking it to make the outcome deterministic.
        result = await agent.execute(task)

        # Missing required parameters may lead to FAILED.
        assert result is not None
        assert result.status in (TaskStatus.COMPLETED, TaskStatus.FAILED)

    def test_extract_json_method(self):
        """extract_json handles plain JSON and markdown-fenced JSON.

        The helper was moved out of the agent into ``app.utils.json_extractor``.
        """
        from app.utils.json_extractor import extract_json

        # Plain JSON text.
        json_text = '{"title": "测试标题", "reason": "测试原因"}'
        extracted = extract_json(json_text)
        assert "title" in extracted

        # JSON wrapped in a ```json markdown fence.
        md_text = '```json\n{"title": "测试"}\n```'
        extracted = extract_json(md_text)
        assert "title" in extracted
|
||
|
||
|
||
class TestAgentProtocol:
    """Tests pinning the serialized string values of the protocol enums."""

    @staticmethod
    def _assert_enum_values(enum_cls, expected_pairs):
        """Check that each named member of ``enum_cls`` carries the given value.

        ``expected_pairs`` is an ordered sequence of ``(member_name, value)``;
        members are looked up and compared one at a time, in that order.
        """
        for member_name, expected_value in expected_pairs:
            assert getattr(enum_cls, member_name).value == expected_value

    def test_agent_type_enum_values(self):
        """AgentType values match the agent names used elsewhere in the tests."""
        from app.agent_framework.protocol import AgentType

        self._assert_enum_values(
            AgentType,
            (
                ("CITATION_DETECTOR", "citation_detector"),
                ("CONTENT_GENERATOR", "content_generator"),
            ),
        )

    def test_task_status_enum_values(self):
        """TaskStatus exposes every lifecycle state with its lowercase value."""
        self._assert_enum_values(
            TaskStatus,
            (
                ("PENDING", "pending"),
                ("RUNNING", "running"),
                ("COMPLETED", "completed"),
                ("FAILED", "failed"),
                ("CANCELLED", "cancelled"),
            ),
        )

    def test_agent_status_enum_values(self):
        """AgentStatus exposes the online/offline/busy states."""
        from app.agent_framework.protocol import AgentStatus

        self._assert_enum_values(
            AgentStatus,
            (
                ("ONLINE", "online"),
                ("OFFLINE", "offline"),
                ("BUSY", "busy"),
            ),
        )
|