374 lines
13 KiB
Python
374 lines
13 KiB
Python
"""U6 测试: BaseAgent v2 集成 — LLM Gateway + Skill + Quality Gate + ReAct"""
|
||
|
||
import json
|
||
from datetime import datetime, timezone
|
||
from typing import Any
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.base import BaseAgent
|
||
from agentkit.core.protocol import (
|
||
AgentCapability,
|
||
AgentStatus,
|
||
TaskMessage,
|
||
TaskResult,
|
||
TaskStatus,
|
||
)
|
||
from agentkit.llm.gateway import LLMGateway
|
||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||
from agentkit.quality.gate import QualityGate, QualityResult, QualityCheck
|
||
from agentkit.quality.output import OutputStandardizer, StandardOutput
|
||
from agentkit.skills.base import Skill, SkillConfig, QualityGateConfig, IntentConfig
|
||
|
||
|
||
# ── Helpers ──────────────────────────────────────────────
|
||
|
||
|
||
def _make_task(task_type: str = "echo", input_data: dict | None = None) -> TaskMessage:
    """Build a ``TaskMessage`` fixture with fixed identifiers.

    Args:
        task_type: Forwarded unchanged as the message's ``task_type``.
        input_data: Payload for the message. Any falsy value (``None`` or an
            empty dict) is replaced by a fresh empty dict.

    Returns:
        A ``TaskMessage`` with id ``"test-001"``, agent name ``"test_agent"``,
        priority ``0``, no callback URL and a timezone-aware UTC timestamp.
    """
    payload = input_data if input_data else {}
    message_fields = {
        "task_id": "test-001",
        "agent_name": "test_agent",
        "task_type": task_type,
        "priority": 0,
        "input_data": payload,
        "callback_url": None,
        "created_at": datetime.now(timezone.utc),
    }
    return TaskMessage(**message_fields)
|
||
|
||
|
||
def _make_skill_config(
    name: str = "test_skill",
    execution_mode: str = "react",
    quality_gate: dict | None = None,
    prompt: dict | None = None,
) -> SkillConfig:
    """Build a ``SkillConfig`` fixture for the tests in this module.

    Args:
        name: Skill name passed through to the config.
        execution_mode: Execution mode passed through to the config.
        quality_gate: Optional quality-gate settings, passed through as given.
        prompt: Prompt mapping; a falsy value is replaced by a small default
            prompt with ``identity`` and ``instructions`` keys.

    Returns:
        A ``SkillConfig`` whose ``agent_type`` is ``"test"`` and whose
        ``task_mode`` is ``"llm_generate"``.
    """
    if not prompt:
        prompt = {"identity": "Test skill", "instructions": "Do test things"}
    return SkillConfig(
        name=name,
        agent_type="test",
        task_mode="llm_generate",
        prompt=prompt,
        execution_mode=execution_mode,
        quality_gate=quality_gate,
    )
|
||
|
||
|
||
class SimpleV2Agent(BaseAgent):
    """Minimal concrete ``BaseAgent`` used as the v2 test double.

    It remembers the most recent task and the most recent quality feedback so
    that tests can tell which handler was invoked.
    """

    def __init__(self):
        super().__init__(name="v2_agent", agent_type="test", version="2.0.0")
        # Both stay ``None`` until the corresponding handler has run.
        self.last_feedback = None
        self.last_task = None

    async def handle_task(self, task: TaskMessage) -> dict:
        """Record ``task`` and return a fixed success payload echoing its type."""
        self.last_task = task
        response = {"result": "ok", "task_type": task.task_type}
        return response

    async def handle_task_with_feedback(self, task: TaskMessage, feedback: str) -> dict:
        """Record ``feedback`` and return a payload that echoes it back."""
        self.last_feedback = feedback
        response = {"result": "retry_ok", "feedback": feedback}
        return response

    def get_capabilities(self) -> AgentCapability:
        """Describe this agent: one supported task (``"echo"``), concurrency 1."""
        capability = AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["echo"],
            max_concurrency=1,
            description="V2 test agent",
        )
        return capability
|
||
|
||
|
||
# ── BaseAgent v2 属性测试 ────────────────────────────────
|
||
|
||
|
||
class TestBaseAgentV2Properties:
    """Check the properties that v2 adds to ``BaseAgent``."""

    def test_llm_gateway_property_default_none(self):
        """A freshly built agent has no LLM gateway attached."""
        assert SimpleV2Agent().llm_gateway is None

    def test_llm_gateway_setter(self):
        """Assigning a gateway makes the property return that same object."""
        subject = SimpleV2Agent()
        gw = LLMGateway()
        subject.llm_gateway = gw
        assert subject.llm_gateway is gw

    def test_skill_property_default_none(self):
        """A freshly built agent has no skill attached."""
        assert SimpleV2Agent().skill is None

    def test_skill_setter(self):
        """Assigning a skill makes the property return that same object."""
        subject = SimpleV2Agent()
        attached = Skill(config=_make_skill_config())
        subject.skill = attached
        assert subject.skill is attached
        assert subject.skill.name == "test_skill"

    def test_quality_gate_property_default(self):
        """The quality gate is available by default as a ``QualityGate``."""
        gate = SimpleV2Agent().quality_gate
        assert gate is not None
        assert isinstance(gate, QualityGate)
|
||
|
||
|
||
# ── Quality Gate 集成测试 ────────────────────────────────
|
||
|
||
|
||
class TestQualityGateIntegration:
    """Exercise the Quality Gate handling performed inside ``execute()``."""

    @pytest.mark.asyncio
    async def test_quality_passes_no_retry(self):
        """A passing Quality Gate must not trigger a retry."""
        agent = SimpleV2Agent()
        gate_settings = {"required_fields": ["result"], "max_retries": 2}
        agent.skill = Skill(config=_make_skill_config(quality_gate=gate_settings))

        outcome = await agent.execute(_make_task())

        assert outcome.status == TaskStatus.COMPLETED
        assert outcome.output_data == {"result": "ok", "task_type": "echo"}
        # handle_task ran only once: no feedback was recorded, so no retry.
        assert agent.last_feedback is None

    @pytest.mark.asyncio
    async def test_quality_fails_triggers_retry(self):
        """A failing Quality Gate triggers the retry path."""
        agent = SimpleV2Agent()
        gate_settings = {"required_fields": ["missing_field"], "max_retries": 2}
        agent.skill = Skill(config=_make_skill_config(quality_gate=gate_settings))

        outcome = await agent.execute(_make_task())

        # Even when the quality check fails, execute still returns a result
        # (the output may keep failing after the retries).
        assert outcome.status == TaskStatus.COMPLETED
        # handle_task_with_feedback should have been invoked.
        assert agent.last_feedback is not None

    @pytest.mark.asyncio
    async def test_quality_retry_stops_on_pass(self):
        """Retrying stops as soon as the Quality Gate passes."""
        passing_text = (
            "this is a longer response that meets the minimum word count requirement"
        )

        class RetryAgent(BaseAgent):
            """Agent whose first answer is too short and later answers are long."""

            def __init__(self):
                super().__init__(name="retry_agent", agent_type="test", version="1.0.0")
                self.call_count = 0

            async def handle_task(self, task: TaskMessage) -> dict:
                is_first_call = self.call_count == 0
                self.call_count += 1
                # First attempt: too few words for the gate.
                text = "short" if is_first_call else passing_text
                return {"content": text}

            async def handle_task_with_feedback(self, task: TaskMessage, feedback: str) -> dict:
                self.call_count += 1
                return {"content": passing_text}

            def get_capabilities(self) -> AgentCapability:
                capability = AgentCapability(
                    agent_name=self.name,
                    agent_type=self.agent_type,
                    version=self.version,
                    supported_tasks=["test"],
                    max_concurrency=1,
                    description="Retry test agent",
                )
                return capability

        agent = RetryAgent()
        gate_settings = {"min_word_count": 5, "max_retries": 3}
        agent.skill = Skill(config=_make_skill_config(quality_gate=gate_settings))

        outcome = await agent.execute(_make_task())

        assert outcome.status == TaskStatus.COMPLETED
        # Expect handle_task once + handle_task_with_feedback once = 2 calls.
        assert agent.call_count == 2

    @pytest.mark.asyncio
    async def test_quality_no_retry_when_max_retries_zero(self):
        """With ``max_retries=0`` no retry is attempted."""
        agent = SimpleV2Agent()
        gate_settings = {"required_fields": ["missing_field"], "max_retries": 0}
        agent.skill = Skill(config=_make_skill_config(quality_gate=gate_settings))

        outcome = await agent.execute(_make_task())

        assert outcome.status == TaskStatus.COMPLETED
        assert agent.last_feedback is None  # no retry happened

    @pytest.mark.asyncio
    async def test_no_quality_check_without_skill(self):
        """Without a Skill the Quality Gate is not applied."""
        agent = SimpleV2Agent()
        # Deliberately leave agent.skill unset.
        outcome = await agent.execute(_make_task())

        assert outcome.status == TaskStatus.COMPLETED
        assert outcome.output_data == {"result": "ok", "task_type": "echo"}
|
||
|
||
|
||
# ── handle_task_with_feedback 测试 ───────────────────────
|
||
|
||
|
||
class TestHandleTaskWithFeedback:
    """Check the default behaviour of ``handle_task_with_feedback``."""

    @pytest.mark.asyncio
    async def test_default_handle_task_with_feedback(self):
        """The default ``handle_task_with_feedback`` falls back to ``handle_task``."""

        class DefaultFeedbackAgent(BaseAgent):
            """Agent that does not override ``handle_task_with_feedback``."""

            def __init__(self):
                super().__init__(name="fb_agent", agent_type="test", version="1.0.0")

            async def handle_task(self, task: TaskMessage) -> dict:
                return {"result": "default"}

            def get_capabilities(self) -> AgentCapability:
                capability = AgentCapability(
                    agent_name=self.name,
                    agent_type=self.agent_type,
                    version=self.version,
                    supported_tasks=["test"],
                    max_concurrency=1,
                    description="Feedback test agent",
                )
                return capability

        subject = DefaultFeedbackAgent()
        outcome = await subject.handle_task_with_feedback(_make_task(), "quality feedback")
        # The result is exactly what handle_task returns.
        assert outcome == {"result": "default"}
|
||
|
||
|
||
# ── _build_quality_feedback 测试 ─────────────────────────
|
||
|
||
|
||
class TestBuildQualityFeedback:
    """Check how quality feedback text is assembled."""

    @pytest.mark.asyncio
    async def test_build_quality_feedback(self):
        """``_build_quality_feedback`` produces the expected feedback string."""
        agent = SimpleV2Agent()
        failed_checks = [
            QualityCheck(
                name="required_field:title",
                passed=False,
                message="Field 'title' is missing",
            ),
            QualityCheck(
                name="min_word_count",
                passed=False,
                message="Word count 2 < minimum 10",
            ),
        ]
        failing_result = QualityResult(passed=False, checks=failed_checks, can_retry=True)

        feedback = agent._build_quality_feedback(failing_result)

        # Both check messages and the failure banner must appear in the text.
        for fragment in ("title", "minimum 10", "Quality check failed"):
            assert fragment in feedback
|
||
|
||
|
||
# ── Backward Compatibility 测试 ──────────────────────────
|
||
|
||
|
||
class TestBackwardCompatibility:
    """Check that v1 behaviour is preserved."""

    @pytest.mark.asyncio
    async def test_execute_without_v2_features(self):
        """With no v2 features in use, ``execute`` behaves as in v1."""
        agent = SimpleV2Agent()
        outcome = await agent.execute(_make_task("echo", {"msg": "hello"}))

        assert outcome.status == TaskStatus.COMPLETED
        assert outcome.output_data == {"result": "ok", "task_type": "echo"}
        assert outcome.error_message is None
        assert outcome.metrics["task_type"] == "echo"

    @pytest.mark.asyncio
    async def test_execute_failure_still_works(self):
        """The v1 failure path still works."""

        class FailAgent(BaseAgent):
            """Agent whose ``handle_task`` always raises."""

            def __init__(self):
                super().__init__(name="fail_agent", agent_type="test", version="1.0.0")

            async def handle_task(self, task: TaskMessage) -> dict:
                raise ValueError("intentional failure")

            def get_capabilities(self) -> AgentCapability:
                capability = AgentCapability(
                    agent_name=self.name,
                    agent_type=self.agent_type,
                    version=self.version,
                    supported_tasks=["test"],
                    max_concurrency=1,
                    description="Fail test agent",
                )
                return capability

        agent = FailAgent()
        outcome = await agent.execute(_make_task())

        assert outcome.status == TaskStatus.FAILED
        assert outcome.error_message == "intentional failure"

    @pytest.mark.asyncio
    async def test_lifecycle_hooks_still_work(self):
        """The v1 lifecycle hooks still work."""

        class HookAgent(BaseAgent):
            """Agent that flags which lifecycle hooks were called."""

            def __init__(self):
                super().__init__(name="hook_agent", agent_type="test", version="1.0.0")
                self.started = False
                self.completed = False
                self.failed = False

            async def handle_task(self, task: TaskMessage) -> dict:
                return {"ok": True}

            async def on_task_start(self, task):
                self.started = True

            async def on_task_complete(self, task, output):
                self.completed = True

            async def on_task_failed(self, task, error):
                self.failed = True

            def get_capabilities(self) -> AgentCapability:
                capability = AgentCapability(
                    agent_name=self.name,
                    agent_type=self.agent_type,
                    version=self.version,
                    supported_tasks=["test"],
                    max_concurrency=1,
                    description="Hook test agent",
                )
                return capability

        agent = HookAgent()
        await agent.execute(_make_task())

        # A successful run fires start and complete, never failed.
        assert agent.started is True
        assert agent.completed is True
        assert agent.failed is False
|