# Source path: geo/backend/tests/test_agent_framework/test_agents_integration.py
# (Repository web view reported: 225 lines, 7.7 KiB, Python.)
# Viewer note: this file contains Unicode characters that might be confused
# with other characters; if that is intentional, the warning can be ignored.

"""Agent集成测试"""
import pytest
import uuid
from datetime import datetime, timezone
from app.agent_framework.agents.citation_detector import CitationDetectorAgent
from app.agent_framework.agents.content_generator_agent import ContentGeneratorAgent
from app.agent_framework.protocol import (
TaskMessage,
TaskStatus,
)
class TestCitationDetectorAgent:
    """Tests for the citation-detector agent."""

    @staticmethod
    def _make_task(task_type, input_data=None):
        """Build a TaskMessage addressed to the ``citation_detector`` agent.

        Args:
            task_type: Task type string handed to the agent.
            input_data: Task payload; a fresh empty dict when omitted.

        Returns:
            A TaskMessage with a random task_id, priority 5, no callback URL
            and a timezone-aware UTC creation timestamp.
        """
        return TaskMessage(
            task_id=str(uuid.uuid4()),
            agent_name="citation_detector",
            task_type=task_type,
            priority=5,
            input_data={} if input_data is None else input_data,
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )

    def test_agent_initialization(self):
        """The agent exposes its name, type and version after construction."""
        agent = CitationDetectorAgent()
        assert agent.name == "citation_detector"
        assert agent.agent_type.value == "citation_detector"
        assert agent.version == "1.0.0"

    def test_get_capabilities(self):
        """The capability descriptor lists both citation task types."""
        agent = CitationDetectorAgent()
        capability = agent.get_capabilities()
        assert capability.agent_name == "citation_detector"
        assert "citation_detect" in capability.supported_tasks
        assert "citation_detect_single" in capability.supported_tasks
        assert capability.max_concurrency == 3

    @pytest.mark.asyncio
    async def test_execute_with_invalid_task_type(self):
        """An unsupported task type yields a FAILED result with a message."""
        agent = CitationDetectorAgent()
        result = await agent.execute(self._make_task("invalid_task_type"))
        assert result.status == TaskStatus.FAILED
        assert result.error_message is not None
        assert "Unsupported task type" in result.error_message

    @pytest.mark.asyncio
    async def test_execute_single_detect_missing_params(self):
        """Single-platform detection fails when required inputs are absent."""
        agent = CitationDetectorAgent()
        # Empty payload: keyword, platform and target_brand are all missing.
        result = await agent.execute(self._make_task("citation_detect_single"))
        assert result.status == TaskStatus.FAILED
        assert result.error_message is not None

    @pytest.mark.asyncio
    async def test_execute_full_detect_missing_query_id(self):
        """Full detection fails when query_id is missing from the payload."""
        agent = CitationDetectorAgent()
        # Empty payload: query_id is missing.
        result = await agent.execute(self._make_task("citation_detect"))
        assert result.status == TaskStatus.FAILED
        # Guard against None first so a missing message is reported as an
        # assertion failure rather than a TypeError from the `in` check.
        assert result.error_message is not None
        assert (
            "query_id" in result.error_message
            or "must contain" in result.error_message
        )

    def test_compatibility_methods_exist(self):
        """The backward-compatible entry points are present and callable."""
        agent = CitationDetectorAgent()
        compat_methods = (
            "execute_query_compat",
            "execute_single_platform_compat",
        )
        for method_name in compat_methods:
            assert hasattr(agent, method_name)
            assert callable(getattr(agent, method_name))
class TestContentGeneratorAgent:
    """Tests for the content-generator agent."""

    @staticmethod
    def _make_task(task_type, input_data=None):
        """Build a TaskMessage addressed to the ``content_generator`` agent.

        Args:
            task_type: Task type string handed to the agent.
            input_data: Task payload; a fresh empty dict when omitted.

        Returns:
            A TaskMessage with a random task_id, priority 5, no callback URL
            and a timezone-aware UTC creation timestamp.
        """
        return TaskMessage(
            task_id=str(uuid.uuid4()),
            agent_name="content_generator",
            task_type=task_type,
            priority=5,
            input_data={} if input_data is None else input_data,
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )

    def test_agent_initialization(self):
        """The agent exposes its name, type and version after construction."""
        agent = ContentGeneratorAgent()
        assert agent.name == "content_generator"
        assert agent.agent_type.value == "content_generator"
        assert agent.version == "1.0.0"

    def test_get_capabilities(self):
        """The capability descriptor lists the topic and article tasks."""
        agent = ContentGeneratorAgent()
        capability = agent.get_capabilities()
        assert capability.agent_name == "content_generator"
        assert "generate_topics" in capability.supported_tasks
        assert "generate_article" in capability.supported_tasks
        assert capability.max_concurrency == 2

    @pytest.mark.asyncio
    async def test_execute_with_invalid_task_type(self):
        """An unsupported task type yields a FAILED result with a message."""
        agent = ContentGeneratorAgent()
        result = await agent.execute(self._make_task("invalid_task_type"))
        assert result.status == TaskStatus.FAILED
        # Guard against None first so a missing message is reported as an
        # assertion failure rather than a TypeError from the `in` check.
        assert result.error_message is not None
        assert "Unsupported task type" in result.error_message

    @pytest.mark.asyncio
    async def test_generate_topics_missing_keyword(self):
        """Topic generation without target_keyword still returns a result."""
        agent = ContentGeneratorAgent()
        # NOTE(review): nothing in this test mocks the LLM or knowledge base,
        # so execute() may reach the real LLM client. The test only checks
        # that a terminal status is returned instead of an exception. FAILED
        # is acceptable because target_keyword is missing, or because the LLM
        # call itself may fail.
        result = await agent.execute(self._make_task("generate_topics"))
        assert result is not None
        assert result.status in (TaskStatus.COMPLETED, TaskStatus.FAILED)

    @pytest.mark.asyncio
    async def test_generate_article_missing_keyword(self):
        """Article generation without target_keyword still returns a result."""
        agent = ContentGeneratorAgent()
        # Empty payload: target_keyword is missing, which may cause FAILED.
        result = await agent.execute(self._make_task("generate_article"))
        assert result is not None
        assert result.status in (TaskStatus.COMPLETED, TaskStatus.FAILED)

    def test_extract_json_method(self):
        """extract_json (now in app.utils.json_extractor) handles raw and fenced JSON."""
        from app.utils.json_extractor import extract_json

        # Plain JSON text.
        json_text = '{"title": "测试标题", "reason": "测试原因"}'
        extracted = extract_json(json_text)
        assert "title" in extracted

        # JSON wrapped in a markdown ```json code fence.
        md_text = '```json\n{"title": "测试"}\n```'
        extracted = extract_json(md_text)
        assert "title" in extracted
class TestAgentProtocol:
    """Checks the wire values of the agent-protocol enums."""

    def test_agent_type_enum_values(self):
        """Each AgentType member carries its agent name as its value."""
        from app.agent_framework.protocol import AgentType

        cases = (
            (AgentType.CITATION_DETECTOR, "citation_detector"),
            (AgentType.CONTENT_GENERATOR, "content_generator"),
        )
        for member, wire_value in cases:
            assert member.value == wire_value

    def test_task_status_enum_values(self):
        """Each TaskStatus member carries its lowercase name as its value."""
        cases = (
            (TaskStatus.PENDING, "pending"),
            (TaskStatus.RUNNING, "running"),
            (TaskStatus.COMPLETED, "completed"),
            (TaskStatus.FAILED, "failed"),
            (TaskStatus.CANCELLED, "cancelled"),
        )
        for member, wire_value in cases:
            assert member.value == wire_value

    def test_agent_status_enum_values(self):
        """Each AgentStatus member carries its lowercase name as its value."""
        from app.agent_framework.protocol import AgentStatus

        cases = (
            (AgentStatus.ONLINE, "online"),
            (AgentStatus.OFFLINE, "offline"),
            (AgentStatus.BUSY, "busy"),
        )
        for member, wire_value in cases:
            assert member.value == wire_value