feat(react): add loop detection to prevent repeated identical tool calls
U1: Sliding-window hash detection in the ReAct loop. When the same tool is called with identical arguments at least `threshold` times (default 2) within the window of the last 5 recorded tool calls, a correction message is injected first; if the LLM then repeats the call instead of changing strategy, LoopDetectedError is raised. Covers both _execute_loop and execute_stream.
This commit is contained in:
parent
a312e584ae
commit
018b342d96
|
|
@ -60,6 +60,22 @@ class TaskCancelledError(AgentFrameworkError):
|
|||
super().__init__(f"Task {task_id} was cancelled")
|
||||
|
||||
|
||||
class LoopDetectedError(AgentFrameworkError):
    """Raised when the ReAct loop keeps issuing the same tool call.

    Signals that the LLM repeated an identical tool + arguments call and did
    not change its strategy after a correction had been issued.

    ponytail: sliding-window hash detection; window size and threshold are
    configurable.
    Upgrade path: semantic-similarity detection (embedding distance) could
    replace the exact hash match.
    """

    def __init__(self, tool_name: str, repetitions: int):
        # Expose the details as attributes so handlers need not parse the text.
        self.tool_name = tool_name
        self.repetitions = repetitions
        message = (
            f"Loop detected: tool '{tool_name}' called {repetitions} times "
            "with identical arguments after correction"
        )
        super().__init__(message)
|
||||
|
||||
|
||||
class NoAvailableAgentError(AgentFrameworkError):
|
||||
def __init__(self, task_type: str):
|
||||
self.task_type = task_type
|
||||
|
|
|
|||
|
|
@ -9,11 +9,12 @@ import json
|
|||
import logging
|
||||
import re
|
||||
import time
|
||||
from collections import Counter, deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agentkit.core.exceptions import TaskCancelledError, TaskTimeoutError
|
||||
from agentkit.core.exceptions import LoopDetectedError, TaskCancelledError, TaskTimeoutError
|
||||
from agentkit.core.protocol import CancellationToken
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import LLMResponse
|
||||
|
|
@ -187,6 +188,12 @@ class ReActEngine:
|
|||
|
||||
self._compressor = ContextCompressor(llm_gateway=llm_gateway, keep_recent=10)
|
||||
|
||||
# Loop detection: sliding window of (tool_name, args_hash) to catch
|
||||
# repeated identical tool calls. ponytail: hash-based, not semantic.
|
||||
self._loop_window: deque[str] = deque(maxlen=5)
|
||||
self._loop_threshold: int = 2
|
||||
self._loop_corrected: bool = False
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset internal state for reuse across conversations.
|
||||
|
||||
|
|
@ -196,7 +203,31 @@ class ReActEngine:
|
|||
# ReActEngine is stateless between calls — conversation history,
|
||||
# step counts, and trajectory are local to each execute call.
|
||||
# Loop-detection state (the sliding window and the correction flag) does
# persist on the instance, so it is cleared below.
|
||||
pass
|
||||
self._loop_window.clear()
|
||||
self._loop_corrected = False
|
||||
|
||||
def _check_tool_loop(self, tool_calls: list[Any]) -> str | None:
|
||||
"""检测重复工具调用模式。
|
||||
|
||||
将当前步的工具调用 hash 加入滑动窗口,若同一 hash 在窗口内出现
|
||||
>= threshold 次,返回对应的 tool_name;否则返回 None。
|
||||
|
||||
ponytail: 精确 hash 匹配,不做语义相似度。
|
||||
"""
|
||||
for tc in tool_calls:
|
||||
args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
|
||||
h = hash(f"{tc.name}:{args_str}")
|
||||
self._loop_window.append(str(h))
|
||||
|
||||
counts = Counter(self._loop_window)
|
||||
for h, count in counts.items():
|
||||
if count >= self._loop_threshold:
|
||||
# Find the tool name for this hash
|
||||
for tc in tool_calls:
|
||||
args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
|
||||
if str(hash(f"{tc.name}:{args_str}")) == h:
|
||||
return tc.name
|
||||
return None
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
|
|
@ -407,6 +438,33 @@ class ReActEngine:
|
|||
|
||||
# 检查是否有 Function Calling 的 tool_calls
|
||||
if response.has_tool_calls:
|
||||
# 循环检测:检查是否重复调用相同工具+参数
|
||||
looped_tool = self._check_tool_loop(response.tool_calls)
|
||||
if looped_tool is not None:
|
||||
if not self._loop_corrected:
|
||||
# 第一次检测:注入纠正消息,给 LLM 改变策略的机会
|
||||
logger.warning(
|
||||
f"Loop detected: tool '{looped_tool}' repeated, "
|
||||
f"injecting correction at step {step}"
|
||||
)
|
||||
correction_msg = {
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"You are repeatedly calling tool '{looped_tool}' "
|
||||
f"with the same arguments. This indicates a loop. "
|
||||
f"Please change your strategy or provide a final answer."
|
||||
),
|
||||
}
|
||||
conversation.append(correction_msg)
|
||||
self._loop_corrected = True
|
||||
continue
|
||||
else:
|
||||
# 第二次检测:纠正后仍未改变,强制中断
|
||||
raise LoopDetectedError(
|
||||
tool_name=looped_tool,
|
||||
repetitions=self._loop_threshold + 1,
|
||||
)
|
||||
|
||||
# 记录 LLM 调用步骤
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
|
|
@ -1014,14 +1072,41 @@ class ReActEngine:
|
|||
total_tokens += step_tokens
|
||||
|
||||
if response.has_tool_calls:
|
||||
# 循环检测:检查是否重复调用相同工具+参数
|
||||
looped_tool = self._check_tool_loop(response.tool_calls)
|
||||
if looped_tool is not None:
|
||||
if not self._loop_corrected:
|
||||
logger.warning(
|
||||
f"Loop detected (stream): tool '{looped_tool}' repeated, "
|
||||
f"injecting correction at step {step}"
|
||||
)
|
||||
correction_msg = {
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"You are repeatedly calling tool '{looped_tool}' "
|
||||
f"with the same arguments. This indicates a loop. "
|
||||
f"Please change your strategy or provide a final answer."
|
||||
),
|
||||
}
|
||||
conversation.append(correction_msg)
|
||||
self._loop_corrected = True
|
||||
yield ReActEvent(
|
||||
event_type="step",
|
||||
step=step,
|
||||
data={
|
||||
"message": f"Loop detected: tool '{looped_tool}' repeated. Correction injected.",
|
||||
"loop_detected": True,
|
||||
"tool_name": looped_tool,
|
||||
},
|
||||
)
|
||||
continue
|
||||
else:
|
||||
raise LoopDetectedError(
|
||||
tool_name=looped_tool,
|
||||
repetitions=self._loop_threshold + 1,
|
||||
)
|
||||
|
||||
# 记录 LLM 调用步骤
|
||||
if trace_recorder is not None:
|
||||
trace_recorder.record_step(
|
||||
step=step,
|
||||
action="llm_call",
|
||||
duration_ms=llm_duration_ms,
|
||||
tokens_used=step_tokens,
|
||||
)
|
||||
|
||||
# Record assistant message
|
||||
assistant_msg: dict[str, Any] = {
|
||||
|
|
|
|||
|
|
@ -165,12 +165,15 @@ class TestReActMaxSteps:
|
|||
|
||||
tool = FakeTool(name="search", result={"results": ["data"]})
|
||||
|
||||
# LLM 一直返回 tool_calls,不会给出 final answer
|
||||
always_tool_response = make_response(
|
||||
content="Thinking...",
|
||||
tool_calls=[ToolCall(id="tc_loop", name="search", arguments={"query": "more"})],
|
||||
)
|
||||
gateway = make_mock_gateway([always_tool_response] * 20)
|
||||
# LLM 一直返回 tool_calls(参数递增以避免循环检测),不会给出 final answer
|
||||
responses = [
|
||||
make_response(
|
||||
content="Thinking...",
|
||||
tool_calls=[ToolCall(id=f"tc_{i}", name="search", arguments={"query": f"attempt_{i}"})],
|
||||
)
|
||||
for i in range(20)
|
||||
]
|
||||
gateway = make_mock_gateway(responses)
|
||||
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
|
||||
|
||||
result = await engine.execute(
|
||||
|
|
@ -653,3 +656,123 @@ class TestReActCancellation:
|
|||
)
|
||||
assert result.output == "Answer"
|
||||
assert result.status == "success"
|
||||
|
||||
|
||||
class TestLoopDetection:
    """Loop detection: sliding-window hash check for repeated tool calls in the ReAct loop."""

    @staticmethod
    def _call_response(call_id, tool_name, arguments):
        """Build an LLM response that carries exactly one tool call."""
        return make_response(
            tool_calls=[ToolCall(id=call_id, name=tool_name, arguments=arguments)],
        )

    async def test_normal_different_tools_no_detection(self):
        """Calls to different tools must not trigger detection."""
        from agentkit.core.react import ReActEngine

        search_tool = FakeTool(name="search", result={"results": ["a"]})
        calc_tool = FakeTool(name="calculator", result={"value": 42})
        scripted = [
            self._call_response("tc_1", "search", {"q": "test"}),
            self._call_response("tc_2", "calculator", {"expr": "6*7"}),
            make_response(content="Done"),
        ]
        engine = ReActEngine(llm_gateway=make_mock_gateway(scripted))

        result = await engine.execute(
            messages=[{"role": "user", "content": "Search and calculate"}],
            tools=[search_tool, calc_tool],
        )

        assert result.status == "success"
        assert result.total_steps == 3

    async def test_same_tool_different_args_no_detection(self):
        """The same tool with different arguments must not trigger detection."""
        from agentkit.core.react import ReActEngine

        search_tool = FakeTool(name="search", result={"results": []})
        scripted = [
            self._call_response("tc_1", "search", {"q": "hello"}),
            self._call_response("tc_2", "search", {"q": "world"}),
            make_response(content="Done"),
        ]
        engine = ReActEngine(llm_gateway=make_mock_gateway(scripted))

        result = await engine.execute(
            messages=[{"role": "user", "content": "Search twice"}],
            tools=[search_tool],
        )

        assert result.status == "success"
        assert result.total_steps == 3

    async def test_loop_detected_injects_correction_then_raises(self):
        """Identical repeated calls: first a correction is injected, then LoopDetectedError."""
        from agentkit.core.react import ReActEngine
        from agentkit.core.exceptions import LoopDetectedError

        search_tool = FakeTool(name="search", result={"results": []})
        # Step 1: tool call is executed (window holds one hash).
        # Step 2: identical call is detected; correction injected, loop continues.
        # Step 3: identical call again; already corrected, so the engine raises.
        scripted = [
            self._call_response(call_id, "search", {"q": "test"})
            for call_id in ("tc_1", "tc_2", "tc_3")
        ]
        engine = ReActEngine(llm_gateway=make_mock_gateway(scripted), max_steps=10)

        with pytest.raises(LoopDetectedError) as exc_info:
            await engine.execute(
                messages=[{"role": "user", "content": "Search"}],
                tools=[search_tool],
            )
        assert "search" in str(exc_info.value)

    async def test_loop_correction_allows_recovery(self):
        """After the correction, an LLM that changes strategy finishes normally."""
        from agentkit.core.react import ReActEngine

        search_tool = FakeTool(name="search", result={"results": []})
        # Step 1: tool call is executed.
        # Step 2: identical call is detected; correction injected.
        # Step 3: the LLM changes strategy and returns a final answer.
        scripted = [
            self._call_response("tc_1", "search", {"q": "test"}),
            self._call_response("tc_2", "search", {"q": "test"}),
            make_response(content="I found the answer after changing strategy"),
        ]
        engine = ReActEngine(llm_gateway=make_mock_gateway(scripted), max_steps=10)

        result = await engine.execute(
            messages=[{"role": "user", "content": "Search"}],
            tools=[search_tool],
        )

        assert result.status == "success"
        assert "changing strategy" in result.output

    async def test_reset_clears_loop_state(self):
        """reset() clears the loop-detection window and the correction flag."""
        from agentkit.core.react import ReActEngine

        engine = ReActEngine(
            llm_gateway=make_mock_gateway([make_response(content="Done")]),
        )
        # Dirty the state by hand so the effect of reset() is observable.
        engine._loop_window.append("some_hash")
        engine._loop_corrected = True

        engine.reset()

        assert len(engine._loop_window) == 0
        assert engine._loop_corrected is False
|
||||
|
|
|
|||
Loading…
Reference in New Issue