feat(react): add loop detection to prevent repeated identical tool calls
U1: Sliding window hash detection in ReAct loop. When the same tool is called with identical arguments >= threshold times (default 2), injects a correction message first, then raises LoopDetectedError if the LLM doesn't change strategy. Covers both _execute_loop and execute_stream.
This commit is contained in:
parent
a312e584ae
commit
018b342d96
|
|
@ -60,6 +60,22 @@ class TaskCancelledError(AgentFrameworkError):
|
||||||
super().__init__(f"Task {task_id} was cancelled")
|
super().__init__(f"Task {task_id} was cancelled")
|
||||||
|
|
||||||
|
|
||||||
|
class LoopDetectedError(AgentFrameworkError):
    """Raised when the ReAct loop keeps repeating one identical tool call.

    Signals that the LLM called the same tool with the same arguments again
    and did not change strategy after being corrected.

    ponytail: sliding-window hash detection with configurable window size
    and threshold.
    Upgrade path: semantic similarity detection (embedding distance) could
    replace the exact hash match.
    """

    def __init__(self, tool_name: str, repetitions: int):
        """Store the offending tool and build the error message.

        Args:
            tool_name: Name of the tool that was called repeatedly.
            repetitions: Number of identical calls reported in the message.
        """
        self.tool_name = tool_name
        self.repetitions = repetitions
        message = (
            f"Loop detected: tool '{tool_name}' called {repetitions} times "
            "with identical arguments after correction"
        )
        super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class NoAvailableAgentError(AgentFrameworkError):
|
class NoAvailableAgentError(AgentFrameworkError):
|
||||||
def __init__(self, task_type: str):
|
def __init__(self, task_type: str):
|
||||||
self.task_type = task_type
|
self.task_type = task_type
|
||||||
|
|
|
||||||
|
|
@ -9,11 +9,12 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
from collections import Counter, deque
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from agentkit.core.exceptions import TaskCancelledError, TaskTimeoutError
|
from agentkit.core.exceptions import LoopDetectedError, TaskCancelledError, TaskTimeoutError
|
||||||
from agentkit.core.protocol import CancellationToken
|
from agentkit.core.protocol import CancellationToken
|
||||||
from agentkit.llm.gateway import LLMGateway
|
from agentkit.llm.gateway import LLMGateway
|
||||||
from agentkit.llm.protocol import LLMResponse
|
from agentkit.llm.protocol import LLMResponse
|
||||||
|
|
@ -187,6 +188,12 @@ class ReActEngine:
|
||||||
|
|
||||||
self._compressor = ContextCompressor(llm_gateway=llm_gateway, keep_recent=10)
|
self._compressor = ContextCompressor(llm_gateway=llm_gateway, keep_recent=10)
|
||||||
|
|
||||||
|
# Loop detection: sliding window of (tool_name, args_hash) to catch
|
||||||
|
# repeated identical tool calls. ponytail: hash-based, not semantic.
|
||||||
|
self._loop_window: deque[str] = deque(maxlen=5)
|
||||||
|
self._loop_threshold: int = 2
|
||||||
|
self._loop_corrected: bool = False
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Reset internal state for reuse across conversations.
|
"""Reset internal state for reuse across conversations.
|
||||||
|
|
||||||
|
|
@ -196,7 +203,31 @@ class ReActEngine:
|
||||||
# ReActEngine is stateless between calls — conversation history,
|
# ReActEngine is stateless between calls — conversation history,
|
||||||
# step counts, and trajectory are local to each execute call.
|
# step counts, and trajectory are local to each execute call.
|
||||||
# This method exists for API clarity and future stateful extensions.
|
# This method exists for API clarity and future stateful extensions.
|
||||||
pass
|
self._loop_window.clear()
|
||||||
|
self._loop_corrected = False
|
||||||
|
|
||||||
|
def _check_tool_loop(self, tool_calls: list[Any]) -> str | None:
|
||||||
|
"""检测重复工具调用模式。
|
||||||
|
|
||||||
|
将当前步的工具调用 hash 加入滑动窗口,若同一 hash 在窗口内出现
|
||||||
|
>= threshold 次,返回对应的 tool_name;否则返回 None。
|
||||||
|
|
||||||
|
ponytail: 精确 hash 匹配,不做语义相似度。
|
||||||
|
"""
|
||||||
|
for tc in tool_calls:
|
||||||
|
args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
|
||||||
|
h = hash(f"{tc.name}:{args_str}")
|
||||||
|
self._loop_window.append(str(h))
|
||||||
|
|
||||||
|
counts = Counter(self._loop_window)
|
||||||
|
for h, count in counts.items():
|
||||||
|
if count >= self._loop_threshold:
|
||||||
|
# Find the tool name for this hash
|
||||||
|
for tc in tool_calls:
|
||||||
|
args_str = json.dumps(tc.arguments, sort_keys=True, default=str)
|
||||||
|
if str(hash(f"{tc.name}:{args_str}")) == h:
|
||||||
|
return tc.name
|
||||||
|
return None
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
|
|
@ -407,6 +438,33 @@ class ReActEngine:
|
||||||
|
|
||||||
# 检查是否有 Function Calling 的 tool_calls
|
# 检查是否有 Function Calling 的 tool_calls
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
|
# 循环检测:检查是否重复调用相同工具+参数
|
||||||
|
looped_tool = self._check_tool_loop(response.tool_calls)
|
||||||
|
if looped_tool is not None:
|
||||||
|
if not self._loop_corrected:
|
||||||
|
# 第一次检测:注入纠正消息,给 LLM 改变策略的机会
|
||||||
|
logger.warning(
|
||||||
|
f"Loop detected: tool '{looped_tool}' repeated, "
|
||||||
|
f"injecting correction at step {step}"
|
||||||
|
)
|
||||||
|
correction_msg = {
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
f"You are repeatedly calling tool '{looped_tool}' "
|
||||||
|
f"with the same arguments. This indicates a loop. "
|
||||||
|
f"Please change your strategy or provide a final answer."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
conversation.append(correction_msg)
|
||||||
|
self._loop_corrected = True
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# 第二次检测:纠正后仍未改变,强制中断
|
||||||
|
raise LoopDetectedError(
|
||||||
|
tool_name=looped_tool,
|
||||||
|
repetitions=self._loop_threshold + 1,
|
||||||
|
)
|
||||||
|
|
||||||
# 记录 LLM 调用步骤
|
# 记录 LLM 调用步骤
|
||||||
if trace_recorder is not None:
|
if trace_recorder is not None:
|
||||||
trace_recorder.record_step(
|
trace_recorder.record_step(
|
||||||
|
|
@ -1014,14 +1072,41 @@ class ReActEngine:
|
||||||
total_tokens += step_tokens
|
total_tokens += step_tokens
|
||||||
|
|
||||||
if response.has_tool_calls:
|
if response.has_tool_calls:
|
||||||
|
# 循环检测:检查是否重复调用相同工具+参数
|
||||||
|
looped_tool = self._check_tool_loop(response.tool_calls)
|
||||||
|
if looped_tool is not None:
|
||||||
|
if not self._loop_corrected:
|
||||||
|
logger.warning(
|
||||||
|
f"Loop detected (stream): tool '{looped_tool}' repeated, "
|
||||||
|
f"injecting correction at step {step}"
|
||||||
|
)
|
||||||
|
correction_msg = {
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
f"You are repeatedly calling tool '{looped_tool}' "
|
||||||
|
f"with the same arguments. This indicates a loop. "
|
||||||
|
f"Please change your strategy or provide a final answer."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
conversation.append(correction_msg)
|
||||||
|
self._loop_corrected = True
|
||||||
|
yield ReActEvent(
|
||||||
|
event_type="step",
|
||||||
|
step=step,
|
||||||
|
data={
|
||||||
|
"message": f"Loop detected: tool '{looped_tool}' repeated. Correction injected.",
|
||||||
|
"loop_detected": True,
|
||||||
|
"tool_name": looped_tool,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise LoopDetectedError(
|
||||||
|
tool_name=looped_tool,
|
||||||
|
repetitions=self._loop_threshold + 1,
|
||||||
|
)
|
||||||
|
|
||||||
# 记录 LLM 调用步骤
|
# 记录 LLM 调用步骤
|
||||||
if trace_recorder is not None:
|
|
||||||
trace_recorder.record_step(
|
|
||||||
step=step,
|
|
||||||
action="llm_call",
|
|
||||||
duration_ms=llm_duration_ms,
|
|
||||||
tokens_used=step_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Record assistant message
|
# Record assistant message
|
||||||
assistant_msg: dict[str, Any] = {
|
assistant_msg: dict[str, Any] = {
|
||||||
|
|
|
||||||
|
|
@ -165,12 +165,15 @@ class TestReActMaxSteps:
|
||||||
|
|
||||||
tool = FakeTool(name="search", result={"results": ["data"]})
|
tool = FakeTool(name="search", result={"results": ["data"]})
|
||||||
|
|
||||||
# LLM 一直返回 tool_calls,不会给出 final answer
|
# LLM 一直返回 tool_calls(参数递增以避免循环检测),不会给出 final answer
|
||||||
always_tool_response = make_response(
|
responses = [
|
||||||
content="Thinking...",
|
make_response(
|
||||||
tool_calls=[ToolCall(id="tc_loop", name="search", arguments={"query": "more"})],
|
content="Thinking...",
|
||||||
)
|
tool_calls=[ToolCall(id=f"tc_{i}", name="search", arguments={"query": f"attempt_{i}"})],
|
||||||
gateway = make_mock_gateway([always_tool_response] * 20)
|
)
|
||||||
|
for i in range(20)
|
||||||
|
]
|
||||||
|
gateway = make_mock_gateway(responses)
|
||||||
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
|
engine = ReActEngine(llm_gateway=gateway, max_steps=3)
|
||||||
|
|
||||||
result = await engine.execute(
|
result = await engine.execute(
|
||||||
|
|
@ -653,3 +656,123 @@ class TestReActCancellation:
|
||||||
)
|
)
|
||||||
assert result.output == "Answer"
|
assert result.output == "Answer"
|
||||||
assert result.status == "success"
|
assert result.status == "success"
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoopDetection:
    """Loop detection: sliding-window hash check for repeated tool calls in the ReAct loop."""

    async def test_normal_different_tools_no_detection(self):
        """Calls to different tools do not trigger detection."""
        from agentkit.core.react import ReActEngine

        search_tool = FakeTool(name="search", result={"results": ["a"]})
        calc_tool = FakeTool(name="calculator", result={"value": 42})
        search_call = ToolCall(id="tc_1", name="search", arguments={"q": "test"})
        calc_call = ToolCall(id="tc_2", name="calculator", arguments={"expr": "6*7"})
        scripted = [
            make_response(tool_calls=[search_call]),
            make_response(tool_calls=[calc_call]),
            make_response(content="Done"),
        ]
        engine = ReActEngine(llm_gateway=make_mock_gateway(scripted))

        outcome = await engine.execute(
            messages=[{"role": "user", "content": "Search and calculate"}],
            tools=[search_tool, calc_tool],
        )

        assert outcome.status == "success"
        assert outcome.total_steps == 3

    async def test_same_tool_different_args_no_detection(self):
        """The same tool with different arguments does not trigger detection."""
        from agentkit.core.react import ReActEngine

        search_tool = FakeTool(name="search", result={"results": []})
        hello_call = ToolCall(id="tc_1", name="search", arguments={"q": "hello"})
        world_call = ToolCall(id="tc_2", name="search", arguments={"q": "world"})
        scripted = [
            make_response(tool_calls=[hello_call]),
            make_response(tool_calls=[world_call]),
            make_response(content="Done"),
        ]
        engine = ReActEngine(llm_gateway=make_mock_gateway(scripted))

        outcome = await engine.execute(
            messages=[{"role": "user", "content": "Search twice"}],
            tools=[search_tool],
        )

        assert outcome.status == "success"
        assert outcome.total_steps == 3

    async def test_loop_detected_injects_correction_then_raises(self):
        """Identical repeats: the first detection injects a correction, the second raises."""
        from agentkit.core.exceptions import LoopDetectedError
        from agentkit.core.react import ReActEngine

        search_tool = FakeTool(name="search", result={"results": []})
        # Step 1: tool call is executed and recorded in the window.
        # Step 2: same call again -> detected, correction injected, loop continues.
        # Step 3: same call once more -> already corrected, so the engine raises.
        scripted = [
            make_response(
                tool_calls=[ToolCall(id=f"tc_{n}", name="search", arguments={"q": "test"})],
            )
            for n in (1, 2, 3)
        ]
        engine = ReActEngine(llm_gateway=make_mock_gateway(scripted), max_steps=10)

        with pytest.raises(LoopDetectedError) as exc_info:
            await engine.execute(
                messages=[{"role": "user", "content": "Search"}],
                tools=[search_tool],
            )

        assert "search" in str(exc_info.value)

    async def test_loop_correction_allows_recovery(self):
        """After the correction is injected, a changed strategy completes normally."""
        from agentkit.core.react import ReActEngine

        search_tool = FakeTool(name="search", result={"results": []})
        # Step 1: tool call is executed.
        # Step 2: same call again -> detected, correction injected.
        # Step 3: the LLM changes strategy and returns a final answer.
        repeated = [
            make_response(
                tool_calls=[ToolCall(id=f"tc_{n}", name="search", arguments={"q": "test"})],
            )
            for n in (1, 2)
        ]
        final = make_response(content="I found the answer after changing strategy")
        engine = ReActEngine(
            llm_gateway=make_mock_gateway([*repeated, final]),
            max_steps=10,
        )

        outcome = await engine.execute(
            messages=[{"role": "user", "content": "Search"}],
            tools=[search_tool],
        )

        assert outcome.status == "success"
        assert "changing strategy" in outcome.output

    async def test_reset_clears_loop_state(self):
        """reset() clears the loop-detection state."""
        from agentkit.core.react import ReActEngine

        engine = ReActEngine(
            llm_gateway=make_mock_gateway([make_response(content="Done")]),
        )
        # Dirty both pieces of loop state so reset() has something to clear.
        engine._loop_corrected = True
        engine._loop_window.append("some_hash")

        engine.reset()

        assert len(engine._loop_window) == 0
        assert engine._loop_corrected is False
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue