483 lines
16 KiB
Python
483 lines
16 KiB
Python
"""TraceRecorder 单元测试"""
|
||
|
||
import time
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.trace import ExecutionTrace, TraceRecorder, TraceStep
|
||
from agentkit.llm.gateway import LLMGateway
|
||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||
from agentkit.tools.base import Tool
|
||
|
||
|
||
# ── Test Helpers ──────────────────────────────────────────
|
||
|
||
|
||
class FakeTool(Tool):
    """Stub ``Tool`` for tests whose ``execute`` returns a preset payload.

    Args:
        name: Tool name forwarded to ``Tool.__init__``.
        description: Tool description forwarded to ``Tool.__init__``.
        result: Payload returned by ``execute``. ``None`` (the default)
            selects ``{"status": "ok"}``; any other value, including an
            empty dict, is returned exactly as given.
    """

    def __init__(
        self,
        name: str = "fake_tool",
        description: str = "A fake tool for testing",
        result: dict | None = None,
    ):
        super().__init__(name=name, description=description)
        # Test against None instead of truthiness so that an explicitly
        # supplied empty dict is not silently swapped for the default.
        self._result = {"status": "ok"} if result is None else result

    async def execute(self, **kwargs) -> dict:
        """Return the configured payload; keyword arguments are ignored."""
        return self._result
|
||
|
||
|
||
def make_mock_gateway(responses: list[LLMResponse]) -> LLMGateway:
    """Build a mocked ``LLMGateway`` whose ``chat`` replays ``responses`` in order."""
    scripted_chat = AsyncMock(side_effect=responses)
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = scripted_chat
    return mock_gateway
|
||
|
||
|
||
def make_response(
    content: str = "",
    tool_calls: list[ToolCall] | None = None,
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Construct an ``LLMResponse`` for tests with minimal boilerplate.

    The model is fixed to ``"test-model"``; a missing or empty ``tool_calls``
    becomes a fresh empty list.
    """
    token_usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    calls = tool_calls if tool_calls else []
    return LLMResponse(
        content=content,
        model="test-model",
        usage=token_usage,
        tool_calls=calls,
    )
|
||
|
||
|
||
# ── TraceStep Tests ──────────────────────────────────────
|
||
|
||
|
||
class TestTraceStep:
    """Tests for the ``TraceStep`` data class."""

    def test_to_dict_with_all_fields(self):
        """Every populated field is serialized; a ``None`` error is left out."""
        populated = TraceStep(
            step=1,
            action="tool_call",
            tool_name="search",
            input_data={"query": "test"},
            output_data={"results": ["found"]},
            duration_ms=100,
            tokens_used=50,
            error=None,
        )
        serialized = populated.to_dict()

        expected = {
            "step": 1,
            "action": "tool_call",
            "tool_name": "search",
            "input_data": {"query": "test"},
            "output_data": {"results": ["found"]},
            "duration_ms": 100,
            "tokens_used": 50,
        }
        for key, value in expected.items():
            assert serialized[key] == value
        assert "error" not in serialized

    def test_to_dict_omits_none_fields(self):
        """Fields that were never set must not appear in the dict."""
        minimal = TraceStep(step=1, action="llm_call", duration_ms=50, tokens_used=30)
        serialized = minimal.to_dict()

        for absent_key in ("tool_name", "input_data", "output_data", "error"):
            assert absent_key not in serialized

    def test_to_dict_includes_error_when_present(self):
        """A non-empty error message is carried through to the dict."""
        failed = TraceStep(step=1, action="tool_call", error="Tool not found")
        serialized = failed.to_dict()

        assert serialized["error"] == "Tool not found"
|
||
|
||
|
||
# ── ExecutionTrace Tests ─────────────────────────────────
|
||
|
||
|
||
class TestExecutionTrace:
    """Tests for the ``ExecutionTrace`` data class."""

    def test_to_dict(self):
        """Top-level fields and the step count survive serialization."""
        llm_step = TraceStep(step=1, action="llm_call", duration_ms=50, tokens_used=30)
        tool_step = TraceStep(
            step=1,
            action="tool_call",
            tool_name="search",
            duration_ms=100,
            tokens_used=0,
        )
        trace = ExecutionTrace(
            task_id="t1",
            agent_name="agent1",
            skill_name="search_skill",
            steps=[llm_step, tool_step],
            total_duration_ms=150,
            total_tokens=30,
            outcome="success",
            quality_score=0.9,
        )

        serialized = trace.to_dict()

        assert serialized["task_id"] == "t1"
        assert serialized["agent_name"] == "agent1"
        assert serialized["skill_name"] == "search_skill"
        assert len(serialized["steps"]) == 2
        assert serialized["total_duration_ms"] == 150
        assert serialized["total_tokens"] == 30
        assert serialized["outcome"] == "success"
        assert serialized["quality_score"] == 0.9
|
||
|
||
|
||
# ── TraceRecorder Happy Path Tests ───────────────────────
|
||
|
||
|
||
class TestTraceRecorderHappyPath:
    """Tests for the normal ``TraceRecorder`` start / record / end flow."""

    def test_start_record_end_returns_trace(self):
        """Starting, recording two steps and ending yields a filled trace."""
        rec = TraceRecorder()
        rec.start_trace(task_id="t1", agent_name="agent1")
        rec.record_step(step=1, action="llm_call", duration_ms=50, tokens_used=30)
        rec.record_step(
            step=1,
            action="tool_call",
            tool_name="search",
            input_data={"query": "test"},
            output_data={"results": ["found"]},
            duration_ms=100,
        )
        finished = rec.end_trace(outcome="success", quality_score=0.9)

        assert isinstance(finished, ExecutionTrace)
        assert finished.task_id == "t1"
        assert finished.agent_name == "agent1"
        assert finished.outcome == "success"
        assert finished.quality_score == 0.9
        assert len(finished.steps) == 2
        first, second = finished.steps[0], finished.steps[1]
        assert first.action == "llm_call"
        assert second.action == "tool_call"
        assert second.tool_name == "search"

    def test_multiple_steps_recorded_in_order(self):
        """Steps keep their insertion order and token counts are summed."""
        rec = TraceRecorder()
        rec.start_trace(task_id="t2", agent_name="agent2")
        rec.record_step(step=1, action="llm_call", tokens_used=100)
        rec.record_step(step=1, action="tool_call", tool_name="calc", tokens_used=0)
        rec.record_step(step=2, action="llm_call", tokens_used=80)
        rec.record_step(step=2, action="final_answer", tokens_used=0)
        finished = rec.end_trace()

        assert len(finished.steps) == 4
        assert finished.steps[0].action == "llm_call"
        assert finished.steps[1].action == "tool_call"
        assert finished.steps[2].action == "llm_call"
        assert finished.steps[3].action == "final_answer"
        assert finished.total_tokens == 180  # 100 + 0 + 80 + 0

    def test_total_duration_calculated(self):
        """The total duration of a finished trace is never negative."""
        rec = TraceRecorder()
        rec.start_trace(task_id="t3", agent_name="agent3")
        rec.record_step(step=1, action="llm_call", duration_ms=50)
        rec.record_step(step=1, action="tool_call", duration_ms=100)
        finished = rec.end_trace()

        # total_duration_ms should reflect actual elapsed time (>= 0).
        assert finished.total_duration_ms >= 0

    def test_constructor_with_params_auto_starts(self):
        """Constructor arguments start a trace without calling ``start_trace``."""
        rec = TraceRecorder(task_id="t4", agent_name="agent4", skill_name="skill1")
        rec.record_step(step=1, action="llm_call", duration_ms=10)
        finished = rec.end_trace()

        assert finished.task_id == "t4"
        assert finished.agent_name == "agent4"
        assert finished.skill_name == "skill1"
        assert len(finished.steps) == 1

    def test_start_trace_generates_uuid_when_no_task_id(self):
        """Omitting ``task_id`` still produces a non-empty identifier."""
        rec = TraceRecorder()
        rec.start_trace(agent_name="agent5")
        finished = rec.end_trace()

        assert finished.task_id  # should hold a value (a UUID)
        assert len(finished.task_id) > 0
|
||
|
||
|
||
# ── TraceRecorder Edge Case Tests ────────────────────────
|
||
|
||
|
||
class TestTraceRecorderEdgeCases:
    """Tests for ``TraceRecorder`` boundary conditions."""

    def test_end_trace_without_start_returns_default(self):
        """Ending a never-started recorder returns a placeholder trace."""
        rec = TraceRecorder()
        placeholder = rec.end_trace(outcome="failure")

        assert isinstance(placeholder, ExecutionTrace)
        assert placeholder.task_id == "unknown"
        assert placeholder.agent_name == ""
        assert placeholder.outcome == "failure"
        assert len(placeholder.steps) == 0

    def test_get_trace_returns_trace_after_start(self):
        """``get_trace`` exposes the in-progress trace once started."""
        rec = TraceRecorder()
        rec.start_trace(task_id="t1", agent_name="a1")
        current = rec.get_trace()

        assert current is not None
        assert current.task_id == "t1"

    def test_get_trace_returns_none_before_start(self):
        """``get_trace`` returns ``None`` while no trace has been started."""
        rec = TraceRecorder()

        assert rec.get_trace() is None

    def test_record_step_without_start_does_nothing(self):
        """Recording before ``start_trace`` is a silent no-op."""
        rec = TraceRecorder()
        # Must not raise.
        rec.record_step(step=1, action="llm_call")
        finished = rec.end_trace()

        assert len(finished.steps) == 0

    def test_elapsed_ms_without_timer_returns_zero(self):
        """``elapsed_ms`` reports 0 when the step timer was never started."""
        rec = TraceRecorder()

        assert rec.elapsed_ms() == 0

    def test_start_step_timer_and_elapsed_ms(self):
        """``elapsed_ms`` measures time since ``start_step_timer``."""
        rec = TraceRecorder()
        rec.start_step_timer()
        time.sleep(0.01)  # 10 ms
        measured = rec.elapsed_ms()

        assert measured >= 8  # at least 8 ms, allowing for timer precision
|
||
|
||
|
||
# ── Integration: TraceRecorder with ReActEngine ──────────
|
||
|
||
|
||
class TestTraceRecorderWithReActEngine:
    """Integration tests for ``TraceRecorder`` driven by ``ReActEngine.execute``.

    NOTE(review): these coroutine tests carry no explicit asyncio marker;
    presumably the project runs pytest-asyncio in auto mode — confirm.
    """

    async def test_single_step_with_recorder(self):
        """A direct answer is traced as one ``final_answer`` step with tokens."""
        from agentkit.core.react import ReActEngine

        gateway = make_mock_gateway([
            make_response(content="The answer is 42"),
        ])
        engine = ReActEngine(llm_gateway=gateway)
        recorder = TraceRecorder()

        # Only the recorder's trace is inspected, so the return value is dropped.
        await engine.execute(
            messages=[{"role": "user", "content": "What is the answer?"}],
            trace_recorder=recorder,
        )

        trace = recorder.get_trace()
        assert trace is not None
        assert trace.outcome == "success"
        assert len(trace.steps) == 1
        assert trace.steps[0].action == "final_answer"
        assert trace.steps[0].tokens_used > 0

    async def test_two_step_with_recorder(self):
        """A tool call followed by an answer is traced as three steps."""
        from agentkit.core.react import ReActEngine

        tool = FakeTool(name="calculator", result={"value": 42})
        gateway = make_mock_gateway([
            make_response(
                content="",
                tool_calls=[ToolCall(id="tc_1", name="calculator", arguments={"expr": "6*7"})],
            ),
            make_response(content="The result is 42"),
        ])
        engine = ReActEngine(llm_gateway=gateway)
        recorder = TraceRecorder()

        await engine.execute(
            messages=[{"role": "user", "content": "Calculate 6*7"}],
            tools=[tool],
            trace_recorder=recorder,
        )

        trace = recorder.get_trace()
        assert trace is not None
        assert trace.outcome == "success"
        # Expected records: llm_call (step 1) + tool_call (step 1) + final_answer (step 2).
        # Note: in the final_answer branch the LLM call and the final answer
        # are merged into a single trace step.
        assert len(trace.steps) == 3
        assert trace.steps[0].action == "llm_call"
        assert trace.steps[1].action == "tool_call"
        assert trace.steps[1].tool_name == "calculator"
        assert trace.steps[1].input_data == {"expr": "6*7"}
        assert trace.steps[1].output_data == {"value": 42}
        assert trace.steps[2].action == "final_answer"

    async def test_max_steps_outcome_is_partial(self):
        """Hitting ``max_steps`` without a final answer gives a ``partial`` outcome."""
        from agentkit.core.react import ReActEngine

        tool = FakeTool(name="search", result={"results": ["data"]})
        always_tool_response = make_response(
            content="Thinking...",
            tool_calls=[ToolCall(id="tc_loop", name="search", arguments={"query": "more"})],
        )
        # Far more scripted responses than max_steps, so the mock cannot run dry.
        gateway = make_mock_gateway([always_tool_response] * 20)
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        recorder = TraceRecorder()

        await engine.execute(
            messages=[{"role": "user", "content": "Keep searching"}],
            tools=[tool],
            trace_recorder=recorder,
        )

        trace = recorder.get_trace()
        assert trace is not None
        assert trace.outcome == "partial"

    async def test_without_recorder_backward_compatible(self):
        """Behavior is unchanged when no ``trace_recorder`` is passed."""
        from agentkit.core.react import ReActEngine

        gateway = make_mock_gateway([
            make_response(content="Direct answer"),
        ])
        engine = ReActEngine(llm_gateway=gateway)

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
        )

        assert result.output == "Direct answer"
        assert result.total_steps == 1

    async def test_tool_error_recorded_in_trace(self):
        """Calling an unregistered tool records an error on the tool step."""
        from agentkit.core.react import ReActEngine

        gateway = make_mock_gateway([
            make_response(
                content="",
                tool_calls=[ToolCall(id="tc_1", name="nonexistent_tool", arguments={})],
            ),
            make_response(content="Tool not found, here is my answer"),
        ])
        engine = ReActEngine(llm_gateway=gateway)
        recorder = TraceRecorder()

        await engine.execute(
            messages=[{"role": "user", "content": "Use unknown tool"}],
            tools=[],
            trace_recorder=recorder,
        )

        trace = recorder.get_trace()
        assert trace is not None
        # Locate the tool_call step(s).
        tool_steps = [s for s in trace.steps if s.action == "tool_call"]
        assert len(tool_steps) == 1
        assert tool_steps[0].error is not None

    async def test_trace_total_tokens(self):
        """``total_tokens`` sums prompt and completion tokens over all LLM calls."""
        from agentkit.core.react import ReActEngine

        tool = FakeTool(name="search", result={"results": ["data"]})
        gateway = make_mock_gateway([
            make_response(
                content="",
                tool_calls=[ToolCall(id="tc_1", name="search", arguments={"q": "test"})],
                prompt_tokens=100,
                completion_tokens=50,
            ),
            make_response(
                content="Final answer",
                prompt_tokens=200,
                completion_tokens=30,
            ),
        ])
        engine = ReActEngine(llm_gateway=gateway)
        recorder = TraceRecorder()

        await engine.execute(
            messages=[{"role": "user", "content": "Search"}],
            tools=[tool],
            trace_recorder=recorder,
        )

        trace = recorder.get_trace()
        assert trace is not None
        assert trace.total_tokens == 380  # (100 + 50) + (200 + 30)

    async def test_agent_name_and_skill_name_in_trace(self):
        """``agent_name`` and ``task_type`` end up as the trace's agent and skill."""
        from agentkit.core.react import ReActEngine

        gateway = make_mock_gateway([
            make_response(content="Done"),
        ])
        engine = ReActEngine(llm_gateway=gateway)
        recorder = TraceRecorder()

        await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            agent_name="test_agent",
            task_type="search_task",
            trace_recorder=recorder,
        )

        trace = recorder.get_trace()
        # Guard first so a missing trace fails as an assertion, matching the
        # sibling tests, rather than as an AttributeError on None.
        assert trace is not None
        assert trace.agent_name == "test_agent"
        assert trace.skill_name == "search_task"
|
||
|
||
|
||
# ── Integration: TraceRecorder with execute_stream ───────
|
||
|
||
|
||
class TestTraceRecorderWithExecuteStream:
    """Integration tests for ``TraceRecorder`` driven by ``execute_stream``.

    NOTE(review): these coroutine tests carry no explicit asyncio marker;
    presumably the project runs pytest-asyncio in auto mode — confirm.
    """

    async def test_stream_with_recorder(self):
        """Streaming a tool call plus an answer is traced as three steps."""
        from agentkit.core.react import ReActEngine

        tool = FakeTool(name="search", result={"results": ["data"]})
        gateway = make_mock_gateway([
            make_response(
                content="",
                tool_calls=[ToolCall(id="tc_1", name="search", arguments={"q": "test"})],
            ),
            make_response(content="Final answer"),
        ])
        engine = ReActEngine(llm_gateway=gateway)
        recorder = TraceRecorder()

        # Exhaust the stream; the events themselves are not inspected here,
        # only the trace the recorder accumulated.
        async for _ in engine.execute_stream(
            messages=[{"role": "user", "content": "Search"}],
            tools=[tool],
            trace_recorder=recorder,
        ):
            pass

        trace = recorder.get_trace()
        assert trace is not None
        assert trace.outcome == "success"
        # llm_call (step 1) + tool_call (step 1) + final_answer (step 2)
        assert len(trace.steps) == 3

    async def test_stream_without_recorder_backward_compatible(self):
        """Streaming without a recorder still emits a ``final_answer`` event."""
        from agentkit.core.react import ReActEngine

        gateway = make_mock_gateway([
            make_response(content="Direct answer"),
        ])
        engine = ReActEngine(llm_gateway=gateway)

        events = [
            event
            async for event in engine.execute_stream(
                messages=[{"role": "user", "content": "Hello"}],
            )
        ]

        assert any(e.event_type == "final_answer" for e in events)
|