feat(react): add loop detection to prevent repeated identical tool calls

U1: Sliding-window hash detection in the ReAct loop. When the same tool is
called with identical arguments >= threshold times (default 2), the engine
first injects a correction message, then raises LoopDetectedError if the LLM
still doesn't change strategy. Covers both _execute_loop and execute_stream.
This commit is contained in:
chiguyong 2026-06-24 20:12:35 +08:00
parent a312e584ae
commit 018b342d96
3 changed files with 239 additions and 15 deletions

View File

@ -60,6 +60,22 @@ class TaskCancelledError(AgentFrameworkError):
super().__init__(f"Task {task_id} was cancelled") super().__init__(f"Task {task_id} was cancelled")
class LoopDetectedError(AgentFrameworkError):
    """ReAct loop-detection error.

    Raised when the LLM keeps calling the same tool with the same arguments
    and does not change strategy after being corrected.

    ponytail: sliding-window hash detection; window size and threshold are
    meant to be configurable. Upgrade path: semantic similarity detection
    (embedding distance) could replace the exact hash match.
    """

    def __init__(self, tool_name: str, repetitions: int):
        # Expose the details as attributes so handlers need not parse the message.
        self.tool_name = tool_name
        self.repetitions = repetitions
        message = (
            f"Loop detected: tool '{tool_name}' called {repetitions} times "
            f"with identical arguments after correction"
        )
        super().__init__(message)
class NoAvailableAgentError(AgentFrameworkError): class NoAvailableAgentError(AgentFrameworkError):
def __init__(self, task_type: str): def __init__(self, task_type: str):
self.task_type = task_type self.task_type = task_type

View File

@ -9,11 +9,12 @@ import json
import logging import logging
import re import re
import time import time
from collections import Counter, deque
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from agentkit.core.exceptions import TaskCancelledError, TaskTimeoutError from agentkit.core.exceptions import LoopDetectedError, TaskCancelledError, TaskTimeoutError
from agentkit.core.protocol import CancellationToken from agentkit.core.protocol import CancellationToken
from agentkit.llm.gateway import LLMGateway from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse from agentkit.llm.protocol import LLMResponse
@ -187,6 +188,12 @@ class ReActEngine:
self._compressor = ContextCompressor(llm_gateway=llm_gateway, keep_recent=10) self._compressor = ContextCompressor(llm_gateway=llm_gateway, keep_recent=10)
# Loop detection: sliding window of (tool_name, args_hash) to catch
# repeated identical tool calls. ponytail: hash-based, not semantic.
self._loop_window: deque[str] = deque(maxlen=5)
self._loop_threshold: int = 2
self._loop_corrected: bool = False
def reset(self) -> None: def reset(self) -> None:
"""Reset internal state for reuse across conversations. """Reset internal state for reuse across conversations.
@ -196,7 +203,31 @@ class ReActEngine:
# ReActEngine is stateless between calls — conversation history, # ReActEngine is stateless between calls — conversation history,
# step counts, and trajectory are local to each execute call. # step counts, and trajectory are local to each execute call.
# This method exists for API clarity and future stateful extensions. # This method exists for API clarity and future stateful extensions.
pass self._loop_window.clear()
self._loop_corrected = False
def _check_tool_loop(self, tool_calls: list[Any]) -> str | None:
"""检测重复工具调用模式。
将当前步的工具调用 hash 加入滑动窗口若同一 hash 在窗口内出现
>= threshold 返回对应的 tool_name否则返回 None
ponytail: 精确 hash 匹配不做语义相似度
"""
for tc in tool_calls:
args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
h = hash(f"{tc.name}:{args_str}")
self._loop_window.append(str(h))
counts = Counter(self._loop_window)
for h, count in counts.items():
if count >= self._loop_threshold:
# Find the tool name for this hash
for tc in tool_calls:
args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
if str(hash(f"{tc.name}:{args_str}")) == h:
return tc.name
return None
async def execute( async def execute(
self, self,
@ -407,6 +438,33 @@ class ReActEngine:
# 检查是否有 Function Calling 的 tool_calls # 检查是否有 Function Calling 的 tool_calls
if response.has_tool_calls: if response.has_tool_calls:
# 循环检测:检查是否重复调用相同工具+参数
looped_tool = self._check_tool_loop(response.tool_calls)
if looped_tool is not None:
if not self._loop_corrected:
# 第一次检测:注入纠正消息,给 LLM 改变策略的机会
logger.warning(
f"Loop detected: tool '{looped_tool}' repeated, "
f"injecting correction at step {step}"
)
correction_msg = {
"role": "user",
"content": (
f"You are repeatedly calling tool '{looped_tool}' "
f"with the same arguments. This indicates a loop. "
f"Please change your strategy or provide a final answer."
),
}
conversation.append(correction_msg)
self._loop_corrected = True
continue
else:
# 第二次检测:纠正后仍未改变,强制中断
raise LoopDetectedError(
tool_name=looped_tool,
repetitions=self._loop_threshold + 1,
)
# 记录 LLM 调用步骤 # 记录 LLM 调用步骤
if trace_recorder is not None: if trace_recorder is not None:
trace_recorder.record_step( trace_recorder.record_step(
@ -1014,14 +1072,41 @@ class ReActEngine:
total_tokens += step_tokens total_tokens += step_tokens
if response.has_tool_calls: if response.has_tool_calls:
# 循环检测:检查是否重复调用相同工具+参数
looped_tool = self._check_tool_loop(response.tool_calls)
if looped_tool is not None:
if not self._loop_corrected:
logger.warning(
f"Loop detected (stream): tool '{looped_tool}' repeated, "
f"injecting correction at step {step}"
)
correction_msg = {
"role": "user",
"content": (
f"You are repeatedly calling tool '{looped_tool}' "
f"with the same arguments. This indicates a loop. "
f"Please change your strategy or provide a final answer."
),
}
conversation.append(correction_msg)
self._loop_corrected = True
yield ReActEvent(
event_type="step",
step=step,
data={
"message": f"Loop detected: tool '{looped_tool}' repeated. Correction injected.",
"loop_detected": True,
"tool_name": looped_tool,
},
)
continue
else:
raise LoopDetectedError(
tool_name=looped_tool,
repetitions=self._loop_threshold + 1,
)
# 记录 LLM 调用步骤 # 记录 LLM 调用步骤
if trace_recorder is not None:
trace_recorder.record_step(
step=step,
action="llm_call",
duration_ms=llm_duration_ms,
tokens_used=step_tokens,
)
# Record assistant message # Record assistant message
assistant_msg: dict[str, Any] = { assistant_msg: dict[str, Any] = {

View File

@ -165,12 +165,15 @@ class TestReActMaxSteps:
tool = FakeTool(name="search", result={"results": ["data"]}) tool = FakeTool(name="search", result={"results": ["data"]})
# LLM 一直返回 tool_calls不会给出 final answer # LLM 一直返回 tool_calls参数递增以避免循环检测不会给出 final answer
always_tool_response = make_response( responses = [
content="Thinking...", make_response(
tool_calls=[ToolCall(id="tc_loop", name="search", arguments={"query": "more"})], content="Thinking...",
) tool_calls=[ToolCall(id=f"tc_{i}", name="search", arguments={"query": f"attempt_{i}"})],
gateway = make_mock_gateway([always_tool_response] * 20) )
for i in range(20)
]
gateway = make_mock_gateway(responses)
engine = ReActEngine(llm_gateway=gateway, max_steps=3) engine = ReActEngine(llm_gateway=gateway, max_steps=3)
result = await engine.execute( result = await engine.execute(
@ -653,3 +656,123 @@ class TestReActCancellation:
) )
assert result.output == "Answer" assert result.output == "Answer"
assert result.status == "success" assert result.status == "success"
class TestLoopDetection:
    """Loop detection: sliding-window hash check for repeated tool calls in the ReAct loop."""

    @staticmethod
    def _identical_search(call_id: str):
        """Build an LLM response calling ``search`` with the fixed args ``{"q": "test"}``."""
        return make_response(
            tool_calls=[ToolCall(id=call_id, name="search", arguments={"q": "test"})],
        )

    async def test_normal_different_tools_no_detection(self):
        """Calls to different tools do not trigger detection."""
        from agentkit.core.react import ReActEngine

        search_tool = FakeTool(name="search", result={"results": ["a"]})
        calc_tool = FakeTool(name="calculator", result={"value": 42})
        scripted = [
            make_response(
                tool_calls=[ToolCall(id="tc_1", name="search", arguments={"q": "test"})],
            ),
            make_response(
                tool_calls=[ToolCall(id="tc_2", name="calculator", arguments={"expr": "6*7"})],
            ),
            make_response(content="Done"),
        ]
        engine = ReActEngine(llm_gateway=make_mock_gateway(scripted))

        outcome = await engine.execute(
            messages=[{"role": "user", "content": "Search and calculate"}],
            tools=[search_tool, calc_tool],
        )

        assert outcome.status == "success"
        assert outcome.total_steps == 3

    async def test_same_tool_different_args_no_detection(self):
        """The same tool with different arguments does not trigger detection."""
        from agentkit.core.react import ReActEngine

        search_tool = FakeTool(name="search", result={"results": []})
        scripted = [
            make_response(
                tool_calls=[ToolCall(id="tc_1", name="search", arguments={"q": "hello"})],
            ),
            make_response(
                tool_calls=[ToolCall(id="tc_2", name="search", arguments={"q": "world"})],
            ),
            make_response(content="Done"),
        ]
        engine = ReActEngine(llm_gateway=make_mock_gateway(scripted))

        outcome = await engine.execute(
            messages=[{"role": "user", "content": "Search twice"}],
            tools=[search_tool],
        )

        assert outcome.status == "success"
        assert outcome.total_steps == 3

    async def test_loop_detected_injects_correction_then_raises(self):
        """Repeated identical calls: first a correction is injected, then LoopDetectedError is raised."""
        from agentkit.core.react import ReActEngine
        from agentkit.core.exceptions import LoopDetectedError

        search_tool = FakeTool(name="search", result={"results": []})
        # Step 1: tool call (executed, window=[hash])
        # Step 2: same tool call (detected, correction injected, continue)
        # Step 3: same tool call again (detected, already corrected -> raise)
        scripted = [
            self._identical_search("tc_1"),
            self._identical_search("tc_2"),
            self._identical_search("tc_3"),
        ]
        engine = ReActEngine(llm_gateway=make_mock_gateway(scripted), max_steps=10)

        with pytest.raises(LoopDetectedError) as exc_info:
            await engine.execute(
                messages=[{"role": "user", "content": "Search"}],
                tools=[search_tool],
            )

        assert "search" in str(exc_info.value)

    async def test_loop_correction_allows_recovery(self):
        """After the correction is injected, a changed strategy lets the run finish normally."""
        from agentkit.core.react import ReActEngine

        search_tool = FakeTool(name="search", result={"results": []})
        # Step 1: tool call (executed)
        # Step 2: same tool call (detected, correction injected)
        # Step 3: LLM changes strategy -> final answer
        scripted = [
            self._identical_search("tc_1"),
            self._identical_search("tc_2"),
            make_response(content="I found the answer after changing strategy"),
        ]
        engine = ReActEngine(llm_gateway=make_mock_gateway(scripted), max_steps=10)

        outcome = await engine.execute(
            messages=[{"role": "user", "content": "Search"}],
            tools=[search_tool],
        )

        assert outcome.status == "success"
        assert "changing strategy" in outcome.output

    async def test_reset_clears_loop_state(self):
        """reset() clears the loop-detection state."""
        from agentkit.core.react import ReActEngine

        engine = ReActEngine(
            llm_gateway=make_mock_gateway([make_response(content="Done")]),
        )
        # Dirty both pieces of loop state before resetting.
        engine._loop_window.append("some_hash")
        engine._loop_corrected = True

        engine.reset()

        assert not engine._loop_window
        assert engine._loop_corrected is False