296 lines
10 KiB
Python
296 lines
10 KiB
Python
"""Tests for LLMReflector - LLM 驱动的执行反思器"""
|
||
|
||
import json
|
||
from datetime import datetime, timezone
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||
from agentkit.core.trace import ExecutionTrace, TraceStep
|
||
from agentkit.evolution.llm_reflector import LLMReflector
|
||
from agentkit.evolution.reflector import Reflection, Reflector, RuleBasedReflector
|
||
from agentkit.evolution.lifecycle import EvolutionMixin
|
||
from agentkit.skills.base import EvolutionConfig
|
||
|
||
|
||
# ── 辅助函数 ──────────────────────────────────────────────────
|
||
|
||
|
||
def _make_task() -> TaskMessage:
    """Build the ``TaskMessage`` fixture shared by the tests in this module.

    Returns:
        A message for task ``test-001`` addressed to ``test_agent`` with task
        type ``echo``, priority ``0``, no callback URL, and a creation
        timestamp taken from the current UTC time.
    """
    created = datetime.now(timezone.utc)
    payload = {"query": "hello"}
    message = TaskMessage(
        task_id="test-001",
        agent_name="test_agent",
        task_type="echo",
        priority=0,
        input_data=payload,
        callback_url=None,
        created_at=created,
    )
    return message
|
||
|
||
|
||
def _make_result(status: str = TaskStatus.COMPLETED) -> TaskResult:
    """Build a ``TaskResult`` fixture for task ``test-001``.

    Args:
        status: Status stored on the result; defaults to
            ``TaskStatus.COMPLETED``.

    Returns:
        A result for ``test_agent`` with a fixed output payload, no error
        message, an ``elapsed_seconds`` metric of ``5.0`` and start/completion
        timestamps taken from the current UTC time.
    """
    # Two separate clock reads, start first, exactly as the fixture expects.
    began = datetime.now(timezone.utc)
    finished = datetime.now(timezone.utc)
    outcome = TaskResult(
        task_id="test-001",
        agent_name="test_agent",
        status=status,
        output_data={"key": "value"},
        error_message=None,
        started_at=began,
        completed_at=finished,
        metrics={"elapsed_seconds": 5.0},
    )
    return outcome
|
||
|
||
|
||
def _make_trace() -> ExecutionTrace:
    """Build a three-step ``ExecutionTrace`` fixture with outcome ``success``.

    The steps are an LLM call, a ``search`` tool call and a final answer; the
    per-step token counts (100 + 50 + 80) add up to the trace total of 230.

    Returns:
        The trace for task ``test-001`` run by ``test_agent``.
    """
    llm_step = TraceStep(step=1, action="llm_call", tokens_used=100)
    tool_step = TraceStep(
        step=2,
        action="tool_call",
        tool_name="search",
        duration_ms=200,
        tokens_used=50,
    )
    answer_step = TraceStep(step=3, action="final_answer", tokens_used=80)

    return ExecutionTrace(
        task_id="test-001",
        agent_name="test_agent",
        steps=[llm_step, tool_step, answer_step],
        total_duration_ms=500,
        total_tokens=230,
        outcome="success",
    )
|
||
|
||
|
||
def _make_mock_gateway(response_content: str) -> MagicMock:
|
||
"""创建返回指定内容的 mock LLMGateway"""
|
||
gateway = MagicMock()
|
||
mock_response = MagicMock()
|
||
mock_response.content = response_content
|
||
gateway.chat = AsyncMock(return_value=mock_response)
|
||
return gateway
|
||
|
||
|
||
# ── LLMReflector 基础功能 ──────────────────────────────────────
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_llm_reflector_parses_json_in_code_block():
    """LLMReflector builds a Reflection from JSON wrapped in a fenced code block."""
    task = _make_task()
    result = _make_result()

    payload = {
        "outcome": "success",
        "quality_score": 0.85,
        "patterns": ["fast_execution"],
        "insights": ["Task completed efficiently"],
        "suggestions": ["Consider caching results"],
    }
    fenced = f"```json\n{json.dumps(payload)}\n```"
    reflector = LLMReflector(
        llm_gateway=_make_mock_gateway(fenced),
        model="test-model",
    )

    reflection = await reflector.reflect(task, result)

    assert isinstance(reflection, Reflection)
    # Fields parsed out of the LLM's JSON payload.
    assert reflection.outcome == "success"
    assert reflection.quality_score == 0.85
    assert reflection.patterns == ["fast_execution"]
    assert reflection.insights == ["Task completed efficiently"]
    assert reflection.suggestions == ["Consider caching results"]
    # Identity fields matching the task/result fixtures.
    assert reflection.task_id == "test-001"
    assert reflection.agent_name == "test_agent"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_llm_reflector_parses_raw_json():
    """LLMReflector builds a Reflection from a raw (unfenced) JSON response."""
    payload = {
        "outcome": "failure",
        "quality_score": 0.2,
        "patterns": ["slow_execution", "error_type:TimeoutError"],
        "insights": ["Timeout occurred"],
        "suggestions": ["Increase timeout"],
    }
    raw_response = json.dumps(payload)
    reflector = LLMReflector(
        llm_gateway=_make_mock_gateway(raw_response),
        model="test-model",
    )

    failed_result = _make_result(status=TaskStatus.FAILED)
    reflection = await reflector.reflect(_make_task(), failed_result)

    assert reflection.outcome == "failure"
    assert reflection.quality_score == 0.2
    assert "slow_execution" in reflection.patterns
    assert "Increase timeout" in reflection.suggestions
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_llm_reflector_handles_unparseable_response():
    """LLMReflector degrades to a ``partial`` reflection when the reply is not JSON."""
    plain_text_gateway = _make_mock_gateway(
        "This is not JSON at all, just plain text."
    )
    reflector = LLMReflector(llm_gateway=plain_text_gateway, model="test-model")

    reflection = await reflector.reflect(_make_task(), _make_result())

    assert isinstance(reflection, Reflection)
    assert reflection.outcome == "partial"
    assert reflection.quality_score == 0.5
    assert (
        "LLM response could not be parsed as structured reflection"
        in reflection.insights
    )
    assert "Review LLM output format" in reflection.suggestions
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_llm_reflector_handles_llm_call_failure():
    """LLMReflector returns a ``failure`` reflection when the gateway call raises."""
    broken_gateway = MagicMock()
    broken_gateway.chat = AsyncMock(
        side_effect=Exception("LLM service unavailable")
    )
    reflector = LLMReflector(llm_gateway=broken_gateway, model="test-model")

    reflection = await reflector.reflect(_make_task(), _make_result())

    assert isinstance(reflection, Reflection)
    assert reflection.outcome == "failure"
    assert reflection.quality_score == 0.0
    assert any(
        "LLM reflection failed" in insight for insight in reflection.insights
    )
    assert (
        "Consider using rule-based reflector as fallback"
        in reflection.suggestions
    )
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_llm_reflector_uses_execution_trace():
    """LLMReflector includes ExecutionTrace details in the prompt it sends."""
    gateway = _make_mock_gateway('{"outcome": "success", "quality_score": 0.9}')
    reflector = LLMReflector(llm_gateway=gateway, model="test-model")

    reflection = await reflector.reflect(
        _make_task(),
        _make_result(),
        trace=_make_trace(),
    )

    # The gateway must have been called, and the first message of that call
    # must carry the trace summary.
    prompt = gateway.chat.call_args.kwargs["messages"][0]["content"]
    expected_fragments = (
        "Total Steps: 3",
        "Total Duration: 500ms",
        "Total Tokens: 230",
        "Tool: search",
    )
    for fragment in expected_fragments:
        assert fragment in prompt
    assert reflection.outcome == "success"
|
||
|
||
|
||
# ── Auto 模式 ──────────────────────────────────────────────────
|
||
|
||
|
||
def test_auto_mode_with_llm_available():
    """Auto mode: an LLMReflector is selected when a gateway is supplied."""
    evolution = EvolutionMixin(reflector_type="auto", llm_gateway=MagicMock())
    chosen = evolution._reflector
    assert isinstance(chosen, LLMReflector)
|
||
|
||
|
||
def test_auto_mode_without_llm_falls_back():
    """Auto mode: falls back to RuleBasedReflector when no gateway is supplied."""
    evolution = EvolutionMixin(reflector_type="auto", llm_gateway=None)
    chosen = evolution._reflector
    assert isinstance(chosen, RuleBasedReflector)
|
||
|
||
|
||
def test_rule_mode_always_uses_rule_based():
    """Rule mode: RuleBasedReflector is used even when a gateway is supplied."""
    evolution = EvolutionMixin(reflector_type="rule", llm_gateway=MagicMock())
    chosen = evolution._reflector
    assert isinstance(chosen, RuleBasedReflector)
|
||
|
||
|
||
def test_llm_mode_without_gateway_falls_back():
    """LLM mode: falls back to RuleBasedReflector when no gateway is supplied."""
    evolution = EvolutionMixin(reflector_type="llm", llm_gateway=None)
    chosen = evolution._reflector
    assert isinstance(chosen, RuleBasedReflector)
|
||
|
||
|
||
def test_llm_mode_with_gateway():
    """LLM mode: an LLMReflector is selected when a gateway is supplied."""
    evolution = EvolutionMixin(reflector_type="llm", llm_gateway=MagicMock())
    chosen = evolution._reflector
    assert isinstance(chosen, LLMReflector)
|
||
|
||
|
||
def test_explicit_reflector_overrides_type():
    """An explicitly passed reflector takes precedence over ``reflector_type``."""
    explicit = RuleBasedReflector()
    evolution = EvolutionMixin(
        reflector=explicit,
        reflector_type="llm",
        llm_gateway=MagicMock(),
    )
    # Identity check: the very same instance must be kept, not a replacement.
    assert evolution._reflector is explicit
|
||
|
||
|
||
def test_auxiliary_model_passed_to_llm_reflector():
    """``auxiliary_model`` is forwarded to the LLMReflector that gets created."""
    evolution = EvolutionMixin(
        reflector_type="llm",
        llm_gateway=MagicMock(),
        auxiliary_model="gpt-4o-mini",
    )
    chosen = evolution._reflector
    assert isinstance(chosen, LLMReflector)
    assert chosen._model == "gpt-4o-mini"
|
||
|
||
|
||
def test_no_reflector_type_defaults_to_none():
    """With no arguments the reflector stays ``None`` (backward compatibility)."""
    evolution = EvolutionMixin()
    chosen = evolution._reflector
    assert chosen is None
|
||
|
||
|
||
# ── EvolutionConfig 新字段 ──────────────────────────────────────
|
||
|
||
|
||
def test_evolution_config_default_values():
    """EvolutionConfig defaults: ``reflector_type`` is ``auto``, no auxiliary model."""
    defaults = EvolutionConfig()
    assert defaults.reflector_type == "auto"
    assert defaults.auxiliary_model is None
|
||
|
||
|
||
def test_evolution_config_custom_values():
    """EvolutionConfig stores explicitly supplied reflector settings."""
    overrides = {
        "enabled": True,
        "reflector_type": "llm",
        "auxiliary_model": "gpt-4o-mini",
    }
    custom = EvolutionConfig(**overrides)
    assert custom.reflector_type == "llm"
    assert custom.auxiliary_model == "gpt-4o-mini"
|
||
|
||
|
||
# ── 向后兼容 ──────────────────────────────────────────────────
|
||
|
||
|
||
def test_reflector_alias_still_works():
    """The ``Reflector`` alias still refers to RuleBasedReflector."""
    assert Reflector is RuleBasedReflector
    # Instantiating through the alias must yield the rule-based class.
    assert isinstance(Reflector(), RuleBasedReflector)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_reflector_alias_produces_same_reflection():
    """The ``Reflector`` alias and RuleBasedReflector agree on outcome and score."""
    task = _make_task()
    result = _make_result()

    via_alias = Reflector()
    via_class = RuleBasedReflector()

    # Both reflectors see the exact same task/result pair.
    alias_reflection = await via_alias.reflect(task, result)
    class_reflection = await via_class.reflect(task, result)

    assert alias_reflection.outcome == class_reflection.outcome
    assert alias_reflection.quality_score == class_reflection.quality_score
|