fischer-agentkit/tests/unit/test_base_agent_v2.py

374 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U6 测试: BaseAgent v2 集成 — LLM Gateway + Skill + Quality Gate + ReAct"""
import json
from datetime import datetime, timezone
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.core.base import BaseAgent
from agentkit.core.protocol import (
AgentCapability,
AgentStatus,
TaskMessage,
TaskResult,
TaskStatus,
)
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
from agentkit.quality.gate import QualityGate, QualityResult, QualityCheck
from agentkit.quality.output import OutputStandardizer, StandardOutput
from agentkit.skills.base import Skill, SkillConfig, QualityGateConfig, IntentConfig
# ── Helpers ──────────────────────────────────────────────
def _make_task(task_type: str = "echo", input_data: dict | None = None) -> TaskMessage:
    """Build a ``TaskMessage`` with fixed metadata for use in tests.

    Args:
        task_type: Value stored in ``TaskMessage.task_type``.
        input_data: Task payload. ``None`` (or any other falsy value) is
            replaced by a fresh empty dict.

    Returns:
        A task with id ``"test-001"`` addressed to ``"test_agent"``, priority 0,
        no callback URL, and ``created_at`` set to the current UTC time.
    """
    payload = input_data if input_data else {}
    message_fields = dict(
        task_id="test-001",
        agent_name="test_agent",
        task_type=task_type,
        priority=0,
        input_data=payload,
        callback_url=None,
        created_at=datetime.now(timezone.utc),
    )
    return TaskMessage(**message_fields)
def _make_skill_config(
    name: str = "test_skill",
    execution_mode: str = "react",
    quality_gate: dict | None = None,
    prompt: dict | None = None,
) -> SkillConfig:
    """Build a minimal ``SkillConfig`` for tests.

    The config always uses ``agent_type="test"`` and ``task_mode="llm_generate"``.

    Args:
        name: Skill name.
        execution_mode: Execution mode string passed straight through.
        quality_gate: Quality-gate settings passed straight through (``None`` included).
        prompt: Prompt sections. A falsy value (``None`` or ``{}``) is replaced
            by a small default identity/instructions prompt.
    """
    fallback_prompt = {"identity": "Test skill", "instructions": "Do test things"}
    chosen_prompt = prompt if prompt else fallback_prompt
    return SkillConfig(
        name=name,
        agent_type="test",
        task_mode="llm_generate",
        prompt=chosen_prompt,
        execution_mode=execution_mode,
        quality_gate=quality_gate,
    )
class SimpleV2Agent(BaseAgent):
    """Minimal v2 agent shared by the tests in this module.

    It remembers the last task given to ``handle_task`` and the last feedback
    given to ``handle_task_with_feedback``. Tests use these to tell whether
    the quality-gate retry path was taken.
    """

    def __init__(self):
        super().__init__(name="v2_agent", agent_type="test", version="2.0.0")
        # Both remain None until the corresponding handler runs.
        self.last_task = None
        self.last_feedback = None

    async def handle_task(self, task: TaskMessage) -> dict:
        """Record *task* and answer with ``"ok"`` plus the task's type."""
        self.last_task = task
        reply = {"result": "ok"}
        reply["task_type"] = task.task_type
        return reply

    async def handle_task_with_feedback(self, task: TaskMessage, feedback: str) -> dict:
        """Record *feedback* and echo it back with a ``"retry_ok"`` result."""
        self.last_feedback = feedback
        return dict(result="retry_ok", feedback=feedback)

    def get_capabilities(self) -> AgentCapability:
        """Advertise support for the ``"echo"`` task only, one at a time."""
        capability_fields = {
            "agent_name": self.name,
            "agent_type": self.agent_type,
            "version": self.version,
            "supported_tasks": ["echo"],
            "max_concurrency": 1,
            "description": "V2 test agent",
        }
        return AgentCapability(**capability_fields)
# ── BaseAgent v2 属性测试 ────────────────────────────────
class TestBaseAgentV2Properties:
    """Check the v2 attributes on BaseAgent: ``llm_gateway``, ``skill``, ``quality_gate``."""

    def test_llm_gateway_property_default_none(self):
        subject = SimpleV2Agent()
        assert subject.llm_gateway is None

    def test_llm_gateway_setter(self):
        subject = SimpleV2Agent()
        new_gateway = LLMGateway()
        subject.llm_gateway = new_gateway
        assert subject.llm_gateway is new_gateway

    def test_skill_property_default_none(self):
        subject = SimpleV2Agent()
        assert subject.skill is None

    def test_skill_setter(self):
        subject = SimpleV2Agent()
        attached = Skill(config=_make_skill_config())
        subject.skill = attached
        assert subject.skill is attached
        # The helper's default skill name must surface through the property.
        assert subject.skill.name == "test_skill"

    def test_quality_gate_property_default(self):
        subject = SimpleV2Agent()
        gate = subject.quality_gate
        # Unlike llm_gateway/skill, a quality gate is available by default.
        assert gate is not None
        assert isinstance(gate, QualityGate)
# ── Quality Gate 集成测试 ────────────────────────────────
class TestQualityGateIntegration:
    """Exercise the Quality Gate step performed inside ``execute()``."""

    @staticmethod
    def _attach_gate(agent, **gate_settings):
        """Attach a Skill whose ``quality_gate`` config is *gate_settings*; return *agent*."""
        agent.skill = Skill(config=_make_skill_config(quality_gate=gate_settings))
        return agent

    @pytest.mark.asyncio
    async def test_quality_passes_no_retry(self):
        """A passing gate keeps the first output and never retries."""
        agent = self._attach_gate(SimpleV2Agent(), required_fields=["result"], max_retries=2)

        outcome = await agent.execute(_make_task())

        assert outcome.status == TaskStatus.COMPLETED
        assert outcome.output_data == {"result": "ok", "task_type": "echo"}
        # The feedback handler was never reached: only handle_task ran (no retry).
        assert agent.last_feedback is None

    @pytest.mark.asyncio
    async def test_quality_fails_triggers_retry(self):
        """A failing gate triggers a retry through ``handle_task_with_feedback``."""
        agent = self._attach_gate(
            SimpleV2Agent(), required_fields=["missing_field"], max_retries=2
        )

        outcome = await agent.execute(_make_task())

        # Even if the retries still fail the gate, execute() returns a COMPLETED result.
        assert outcome.status == TaskStatus.COMPLETED
        # The retry path must have delivered feedback to the agent.
        assert agent.last_feedback is not None

    @pytest.mark.asyncio
    async def test_quality_retry_stops_on_pass(self):
        """Retrying stops as soon as the gate passes."""
        long_text = "this is a longer response that meets the minimum word count requirement"

        class RetryAgent(BaseAgent):
            def __init__(self):
                super().__init__(name="retry_agent", agent_type="test", version="1.0.0")
                self.call_count = 0

            async def handle_task(self, task: TaskMessage) -> dict:
                self.call_count += 1
                # Only the very first attempt is too short for min_word_count=5.
                text = "short" if self.call_count == 1 else long_text
                return {"content": text}

            async def handle_task_with_feedback(self, task: TaskMessage, feedback: str) -> dict:
                self.call_count += 1
                return {"content": long_text}

            def get_capabilities(self) -> AgentCapability:
                return AgentCapability(
                    agent_name=self.name,
                    agent_type=self.agent_type,
                    version=self.version,
                    supported_tasks=["test"],
                    max_concurrency=1,
                    description="Retry test agent",
                )

        agent = self._attach_gate(RetryAgent(), min_word_count=5, max_retries=3)

        outcome = await agent.execute(_make_task())

        assert outcome.status == TaskStatus.COMPLETED
        # Exactly one handle_task call plus one handle_task_with_feedback call.
        assert agent.call_count == 2

    @pytest.mark.asyncio
    async def test_quality_no_retry_when_max_retries_zero(self):
        """With ``max_retries=0`` a failing gate is not retried."""
        agent = self._attach_gate(
            SimpleV2Agent(), required_fields=["missing_field"], max_retries=0
        )

        outcome = await agent.execute(_make_task())

        assert outcome.status == TaskStatus.COMPLETED
        assert agent.last_feedback is None  # no retry happened

    @pytest.mark.asyncio
    async def test_no_quality_check_without_skill(self):
        """Without a Skill attached, the Quality Gate step is skipped."""
        agent = SimpleV2Agent()  # intentionally no skill

        outcome = await agent.execute(_make_task())

        assert outcome.status == TaskStatus.COMPLETED
        assert outcome.output_data == {"result": "ok", "task_type": "echo"}
# ── handle_task_with_feedback 测试 ───────────────────────
class TestHandleTaskWithFeedback:
    """Default ``handle_task_with_feedback`` behaviour inherited from BaseAgent."""

    @pytest.mark.asyncio
    async def test_default_handle_task_with_feedback(self):
        """Without an override, the feedback variant falls back to ``handle_task``."""

        class DefaultFeedbackAgent(BaseAgent):
            # Deliberately does NOT override handle_task_with_feedback.
            def __init__(self):
                super().__init__(name="fb_agent", agent_type="test", version="1.0.0")

            async def handle_task(self, task: TaskMessage) -> dict:
                return {"result": "default"}

            def get_capabilities(self) -> AgentCapability:
                fields = dict(
                    agent_name=self.name,
                    agent_type=self.agent_type,
                    version=self.version,
                    supported_tasks=["test"],
                    max_concurrency=1,
                    description="Feedback test agent",
                )
                return AgentCapability(**fields)

        response = await DefaultFeedbackAgent().handle_task_with_feedback(
            _make_task(), "quality feedback"
        )

        assert response == {"result": "default"}
# ── _build_quality_feedback 测试 ─────────────────────────
class TestBuildQualityFeedback:
    """Tests for turning a failed ``QualityResult`` into retry feedback text."""

    def test_build_quality_feedback(self):
        """``_build_quality_feedback`` builds a string naming every failed check.

        ``_build_quality_feedback`` is a plain synchronous call (its result is used
        directly, without ``await``). This test therefore does not need to be a
        coroutine or depend on the ``pytest.mark.asyncio`` plugin marker.
        """
        agent = SimpleV2Agent()
        quality_result = QualityResult(
            passed=False,
            checks=[
                QualityCheck(name="required_field:title", passed=False, message="Field 'title' is missing"),
                QualityCheck(name="min_word_count", passed=False, message="Word count 2 < minimum 10"),
            ],
            can_retry=True,
        )

        feedback = agent._build_quality_feedback(quality_result)

        # The feedback is handed to handle_task_with_feedback(feedback: str).
        assert isinstance(feedback, str)
        # Each failed check's message must be reflected in the text...
        assert "title" in feedback
        assert "minimum 10" in feedback
        # ...under a recognisable failure header.
        assert "Quality check failed" in feedback
# ── Backward Compatibility 测试 ──────────────────────────
class TestBackwardCompatibility:
    """The v1 behaviour of BaseAgent must survive the v2 additions."""

    @pytest.mark.asyncio
    async def test_execute_without_v2_features(self):
        """With no v2 features configured, ``execute()`` behaves exactly as in v1."""
        agent = SimpleV2Agent()

        outcome = await agent.execute(_make_task("echo", {"msg": "hello"}))

        assert outcome.status == TaskStatus.COMPLETED
        assert outcome.output_data == {"result": "ok", "task_type": "echo"}
        assert outcome.error_message is None
        assert outcome.metrics["task_type"] == "echo"

    @pytest.mark.asyncio
    async def test_execute_failure_still_works(self):
        """The v1 failure path still reports the raised exception."""

        class FailAgent(BaseAgent):
            def __init__(self):
                super().__init__(name="fail_agent", agent_type="test", version="1.0.0")

            async def handle_task(self, task: TaskMessage) -> dict:
                raise ValueError("intentional failure")

            def get_capabilities(self) -> AgentCapability:
                fields = dict(
                    agent_name=self.name,
                    agent_type=self.agent_type,
                    version=self.version,
                    supported_tasks=["test"],
                    max_concurrency=1,
                    description="Fail test agent",
                )
                return AgentCapability(**fields)

        agent = FailAgent()

        outcome = await agent.execute(_make_task())

        # The exception is turned into a FAILED result, not propagated.
        assert outcome.status == TaskStatus.FAILED
        assert outcome.error_message == "intentional failure"

    @pytest.mark.asyncio
    async def test_lifecycle_hooks_still_work(self):
        """The v1 lifecycle hooks still fire around a successful run."""

        class HookAgent(BaseAgent):
            def __init__(self):
                super().__init__(name="hook_agent", agent_type="test", version="1.0.0")
                # Each flag is flipped to True by its matching hook.
                self.started = self.completed = self.failed = False

            async def handle_task(self, task: TaskMessage) -> dict:
                return {"ok": True}

            async def on_task_start(self, task):
                self.started = True

            async def on_task_complete(self, task, output):
                self.completed = True

            async def on_task_failed(self, task, error):
                self.failed = True

            def get_capabilities(self) -> AgentCapability:
                fields = dict(
                    agent_name=self.name,
                    agent_type=self.agent_type,
                    version=self.version,
                    supported_tasks=["test"],
                    max_concurrency=1,
                    description="Hook test agent",
                )
                return AgentCapability(**fields)

        agent = HookAgent()

        await agent.execute(_make_task())

        assert agent.started is True
        assert agent.completed is True
        # A successful run must not invoke the failure hook.
        assert agent.failed is False