fischer-agentkit/tests/unit/test_trace_recorder.py

483 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""TraceRecorder 单元测试"""
import time
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.core.trace import ExecutionTrace, TraceRecorder, TraceStep
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
from agentkit.tools.base import Tool
# ── Test Helpers ──────────────────────────────────────────
class FakeTool(Tool):
    """Fake Tool used in tests: ``execute`` always returns a fixed result.

    Args:
        name: Tool name forwarded to ``Tool.__init__``.
        description: Tool description forwarded to ``Tool.__init__``.
        result: Dict returned by every ``execute`` call. ``None`` selects the
            default ``{"status": "ok"}``; any other value (including an empty
            dict) is returned exactly as given.
    """

    def __init__(
        self,
        name: str = "fake_tool",
        description: str = "A fake tool for testing",
        result: dict | None = None,
    ):
        super().__init__(name=name, description=description)
        # Explicit ``is None`` check: ``result or {...}`` would silently swap a
        # deliberately empty ``{}`` result for the default payload.
        self._result = {"status": "ok"} if result is None else result

    async def execute(self, **kwargs) -> dict:
        """Ignore the call arguments and return the configured result."""
        return self._result
def make_mock_gateway(responses: list[LLMResponse]) -> LLMGateway:
    """Build a mocked ``LLMGateway`` whose ``chat`` replays *responses* in order.

    Each awaited ``chat`` call returns the next item of *responses*. Once the
    list is exhausted, ``AsyncMock`` raises ``StopAsyncIteration``.
    """
    chat = AsyncMock(side_effect=responses)
    mocked = MagicMock(spec=LLMGateway)
    mocked.chat = chat
    return mocked
def make_response(
    content: str = "",
    tool_calls: list[ToolCall] | None = None,
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Construct an ``LLMResponse`` for model ``"test-model"`` from the given fields.

    A falsy ``tool_calls`` (e.g. ``None``) is replaced with an empty list.
    """
    token_usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    calls = tool_calls if tool_calls else []
    return LLMResponse(
        content=content,
        model="test-model",
        usage=token_usage,
        tool_calls=calls,
    )
# ── TraceStep Tests ──────────────────────────────────────
class TestTraceStep:
    """Unit tests for the ``TraceStep`` dataclass."""

    def test_to_dict_with_all_fields(self):
        # Every populated field is serialized; an ``error`` of None is dropped.
        step = TraceStep(
            step=1,
            action="tool_call",
            tool_name="search",
            input_data={"query": "test"},
            output_data={"results": ["found"]},
            duration_ms=100,
            tokens_used=50,
            error=None,
        )
        expected = {
            "step": 1,
            "action": "tool_call",
            "tool_name": "search",
            "input_data": {"query": "test"},
            "output_data": {"results": ["found"]},
            "duration_ms": 100,
            "tokens_used": 50,
        }
        serialized = step.to_dict()
        for key, value in expected.items():
            assert serialized[key] == value
        assert "error" not in serialized

    def test_to_dict_omits_none_fields(self):
        # Optional fields that were never set must be absent from the dict.
        serialized = TraceStep(
            step=1, action="llm_call", duration_ms=50, tokens_used=30
        ).to_dict()
        for absent_key in ("tool_name", "input_data", "output_data", "error"):
            assert absent_key not in serialized

    def test_to_dict_includes_error_when_present(self):
        failed = TraceStep(step=1, action="tool_call", error="Tool not found")
        assert failed.to_dict()["error"] == "Tool not found"
# ── ExecutionTrace Tests ─────────────────────────────────
class TestExecutionTrace:
    """Unit tests for the ``ExecutionTrace`` dataclass."""

    def test_to_dict(self):
        # One LLM call plus one tool call, both belonging to ReAct step 1.
        recorded_steps = [
            TraceStep(step=1, action="llm_call", duration_ms=50, tokens_used=30),
            TraceStep(step=1, action="tool_call", tool_name="search", duration_ms=100, tokens_used=0),
        ]
        serialized = ExecutionTrace(
            task_id="t1",
            agent_name="agent1",
            skill_name="search_skill",
            steps=recorded_steps,
            total_duration_ms=150,
            total_tokens=30,
            outcome="success",
            quality_score=0.9,
        ).to_dict()

        expected_scalars = {
            "task_id": "t1",
            "agent_name": "agent1",
            "skill_name": "search_skill",
            "total_duration_ms": 150,
            "total_tokens": 30,
            "outcome": "success",
            "quality_score": 0.9,
        }
        for key, value in expected_scalars.items():
            assert serialized[key] == value
        assert len(serialized["steps"]) == 2
# ── TraceRecorder Happy Path Tests ───────────────────────
class TestTraceRecorderHappyPath:
    """Normal-flow tests for ``TraceRecorder``."""

    def test_start_record_end_returns_trace(self):
        recorder = TraceRecorder()
        recorder.start_trace(task_id="t1", agent_name="agent1")
        # ReAct step 1: one LLM call followed by one tool call.
        recorder.record_step(step=1, action="llm_call", duration_ms=50, tokens_used=30)
        recorder.record_step(
            step=1,
            action="tool_call",
            tool_name="search",
            input_data={"query": "test"},
            output_data={"results": ["found"]},
            duration_ms=100,
        )

        trace = recorder.end_trace(outcome="success", quality_score=0.9)

        assert isinstance(trace, ExecutionTrace)
        assert (trace.task_id, trace.agent_name) == ("t1", "agent1")
        assert trace.outcome == "success"
        assert trace.quality_score == 0.9
        assert [s.action for s in trace.steps] == ["llm_call", "tool_call"]
        assert trace.steps[1].tool_name == "search"

    def test_multiple_steps_recorded_in_order(self):
        recorder = TraceRecorder()
        recorder.start_trace(task_id="t2", agent_name="agent2")
        recorder.record_step(step=1, action="llm_call", tokens_used=100)
        recorder.record_step(step=1, action="tool_call", tool_name="calc", tokens_used=0)
        recorder.record_step(step=2, action="llm_call", tokens_used=80)
        recorder.record_step(step=2, action="final_answer", tokens_used=0)

        trace = recorder.end_trace()

        actions = [s.action for s in trace.steps]
        assert actions == ["llm_call", "tool_call", "llm_call", "final_answer"]
        # Total tokens is the sum over all recorded steps.
        assert trace.total_tokens == 100 + 0 + 80 + 0

    def test_total_duration_calculated(self):
        recorder = TraceRecorder()
        recorder.start_trace(task_id="t3", agent_name="agent3")
        recorder.record_step(step=1, action="llm_call", duration_ms=50)
        recorder.record_step(step=1, action="tool_call", duration_ms=100)

        elapsed_total = recorder.end_trace().total_duration_ms

        # total_duration_ms should reflect actual elapsed wall time (>= 0).
        assert elapsed_total >= 0

    def test_constructor_with_params_auto_starts(self):
        # Supplying identifiers to the constructor starts the trace implicitly.
        recorder = TraceRecorder(task_id="t4", agent_name="agent4", skill_name="skill1")
        recorder.record_step(step=1, action="llm_call", duration_ms=10)

        trace = recorder.end_trace()

        assert (trace.task_id, trace.agent_name, trace.skill_name) == ("t4", "agent4", "skill1")
        assert len(trace.steps) == 1

    def test_start_trace_generates_uuid_when_no_task_id(self):
        recorder = TraceRecorder()
        recorder.start_trace(agent_name="agent5")

        generated_id = recorder.end_trace().task_id

        # A task id (expected to be a UUID) must be generated automatically.
        assert generated_id
        assert len(generated_id) > 0
# ── TraceRecorder Edge Case Tests ────────────────────────
class TestTraceRecorderEdgeCases:
    """Edge-case tests for ``TraceRecorder``."""

    def test_end_trace_without_start_returns_default(self):
        # end_trace() without a prior start_trace() still yields a placeholder trace.
        trace = TraceRecorder().end_trace(outcome="failure")
        assert isinstance(trace, ExecutionTrace)
        assert trace.task_id == "unknown"
        assert trace.agent_name == ""
        assert trace.outcome == "failure"
        assert len(trace.steps) == 0

    def test_get_trace_returns_trace_after_start(self):
        recorder = TraceRecorder()
        recorder.start_trace(task_id="t1", agent_name="a1")
        in_progress = recorder.get_trace()
        assert in_progress is not None
        assert in_progress.task_id == "t1"

    def test_get_trace_returns_none_before_start(self):
        assert TraceRecorder().get_trace() is None

    def test_record_step_without_start_does_nothing(self):
        recorder = TraceRecorder()
        # Must not raise even though no trace has been started.
        recorder.record_step(step=1, action="llm_call")
        assert len(recorder.end_trace().steps) == 0

    def test_elapsed_ms_without_timer_returns_zero(self):
        assert TraceRecorder().elapsed_ms() == 0

    def test_start_step_timer_and_elapsed_ms(self):
        recorder = TraceRecorder()
        recorder.start_step_timer()
        time.sleep(0.01)  # sleep ~10 ms
        # Allow 2 ms of slack for clock resolution.
        assert recorder.elapsed_ms() >= 8
# ── Integration: TraceRecorder with ReActEngine ──────────
class TestTraceRecorderWithReActEngine:
    """Integration tests: ``TraceRecorder`` wired into ``ReActEngine.execute``."""

    async def test_single_step_with_recorder(self):
        from agentkit.core.react import ReActEngine

        gateway = make_mock_gateway([
            make_response(content="The answer is 42"),
        ])
        engine = ReActEngine(llm_gateway=gateway)
        recorder = TraceRecorder()

        result = await engine.execute(
            messages=[{"role": "user", "content": "What is the answer?"}],
            trace_recorder=recorder,
        )

        # Attaching a recorder must not change the engine's answer
        # (same setup as test_without_recorder_backward_compatible).
        assert result.output == "The answer is 42"
        trace = recorder.get_trace()
        assert trace is not None
        assert trace.outcome == "success"
        assert len(trace.steps) == 1
        assert trace.steps[0].action == "final_answer"
        assert trace.steps[0].tokens_used > 0

    async def test_two_step_with_recorder(self):
        from agentkit.core.react import ReActEngine

        tool = FakeTool(name="calculator", result={"value": 42})
        gateway = make_mock_gateway([
            make_response(
                content="",
                tool_calls=[ToolCall(id="tc_1", name="calculator", arguments={"expr": "6*7"})],
            ),
            make_response(content="The result is 42"),
        ])
        engine = ReActEngine(llm_gateway=gateway)
        recorder = TraceRecorder()

        await engine.execute(
            messages=[{"role": "user", "content": "Calculate 6*7"}],
            tools=[tool],
            trace_recorder=recorder,
        )

        trace = recorder.get_trace()
        assert trace is not None
        assert trace.outcome == "success"
        # Expected: llm_call (step 1) + tool_call (step 1) + final_answer (step 2).
        # In the final_answer branch the LLM call and the final answer are
        # merged into a single trace step.
        assert len(trace.steps) == 3
        assert trace.steps[0].action == "llm_call"
        assert trace.steps[1].action == "tool_call"
        assert trace.steps[1].tool_name == "calculator"
        assert trace.steps[1].input_data == {"expr": "6*7"}
        assert trace.steps[1].output_data == {"value": 42}
        assert trace.steps[2].action == "final_answer"

    async def test_max_steps_outcome_is_partial(self):
        from agentkit.core.react import ReActEngine

        tool = FakeTool(name="search", result={"results": ["data"]})
        # The LLM never produces a final answer, so max_steps must cut the loop.
        always_tool_response = make_response(
            content="Thinking...",
            tool_calls=[ToolCall(id="tc_loop", name="search", arguments={"query": "more"})],
        )
        gateway = make_mock_gateway([always_tool_response] * 20)
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        recorder = TraceRecorder()

        await engine.execute(
            messages=[{"role": "user", "content": "Keep searching"}],
            tools=[tool],
            trace_recorder=recorder,
        )

        trace = recorder.get_trace()
        assert trace is not None
        assert trace.outcome == "partial"

    async def test_without_recorder_backward_compatible(self):
        """Behavior is unchanged when no trace_recorder is passed."""
        from agentkit.core.react import ReActEngine

        gateway = make_mock_gateway([
            make_response(content="Direct answer"),
        ])
        engine = ReActEngine(llm_gateway=gateway)

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
        )

        assert result.output == "Direct answer"
        assert result.total_steps == 1

    async def test_tool_error_recorded_in_trace(self):
        from agentkit.core.react import ReActEngine

        gateway = make_mock_gateway([
            make_response(
                content="",
                tool_calls=[ToolCall(id="tc_1", name="nonexistent_tool", arguments={})],
            ),
            make_response(content="Tool not found, here is my answer"),
        ])
        engine = ReActEngine(llm_gateway=gateway)
        recorder = TraceRecorder()

        await engine.execute(
            messages=[{"role": "user", "content": "Use unknown tool"}],
            tools=[],
            trace_recorder=recorder,
        )

        trace = recorder.get_trace()
        assert trace is not None
        # Locate the tool_call step; it must carry the lookup error.
        tool_steps = [s for s in trace.steps if s.action == "tool_call"]
        assert len(tool_steps) == 1
        assert tool_steps[0].error is not None

    async def test_trace_total_tokens(self):
        from agentkit.core.react import ReActEngine

        tool = FakeTool(name="search", result={"results": ["data"]})
        gateway = make_mock_gateway([
            make_response(
                content="",
                tool_calls=[ToolCall(id="tc_1", name="search", arguments={"q": "test"})],
                prompt_tokens=100,
                completion_tokens=50,
            ),
            make_response(
                content="Final answer",
                prompt_tokens=200,
                completion_tokens=30,
            ),
        ])
        engine = ReActEngine(llm_gateway=gateway)
        recorder = TraceRecorder()

        await engine.execute(
            messages=[{"role": "user", "content": "Search"}],
            tools=[tool],
            trace_recorder=recorder,
        )

        trace = recorder.get_trace()
        assert trace is not None
        assert trace.total_tokens == 380  # (100 + 50) + (200 + 30)

    async def test_agent_name_and_skill_name_in_trace(self):
        from agentkit.core.react import ReActEngine

        gateway = make_mock_gateway([
            make_response(content="Done"),
        ])
        engine = ReActEngine(llm_gateway=gateway)
        recorder = TraceRecorder()

        await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            agent_name="test_agent",
            task_type="search_task",
            trace_recorder=recorder,
        )

        trace = recorder.get_trace()
        # Guard first so a missing trace fails clearly instead of AttributeError.
        assert trace is not None
        assert trace.agent_name == "test_agent"
        # task_type is surfaced on the trace as skill_name.
        assert trace.skill_name == "search_task"
# ── Integration: TraceRecorder with execute_stream ───────
class TestTraceRecorderWithExecuteStream:
    """Integration tests: ``TraceRecorder`` wired into ``ReActEngine.execute_stream``."""

    async def test_stream_with_recorder(self):
        from agentkit.core.react import ReActEngine

        tool = FakeTool(name="search", result={"results": ["data"]})
        gateway = make_mock_gateway([
            make_response(
                content="",
                tool_calls=[ToolCall(id="tc_1", name="search", arguments={"q": "test"})],
            ),
            make_response(content="Final answer"),
        ])
        engine = ReActEngine(llm_gateway=gateway)
        recorder = TraceRecorder()

        events = [
            event
            async for event in engine.execute_stream(
                messages=[{"role": "user", "content": "Search"}],
                tools=[tool],
                trace_recorder=recorder,
            )
        ]

        # The stream itself must still end with a final answer event.
        assert any(e.event_type == "final_answer" for e in events)
        trace = recorder.get_trace()
        assert trace is not None
        assert trace.outcome == "success"
        # llm_call (step 1) + tool_call (step 1) + final_answer (step 2),
        # mirroring the non-streaming execute() trace.
        assert [s.action for s in trace.steps] == ["llm_call", "tool_call", "final_answer"]

    async def test_stream_without_recorder_backward_compatible(self):
        """Streaming works unchanged when no trace_recorder is passed."""
        from agentkit.core.react import ReActEngine

        gateway = make_mock_gateway([
            make_response(content="Direct answer"),
        ])
        engine = ReActEngine(llm_gateway=gateway)

        events = [
            event
            async for event in engine.execute_stream(
                messages=[{"role": "user", "content": "Hello"}],
            )
        ]

        assert any(e.event_type == "final_answer" for e in events)