319 lines
11 KiB
Python
319 lines
11 KiB
Python
"""G7/U2 — Emergency-layer rule template and TaskResult extension.

Verifies:
- EmergencyRules.classify maps each exception type to the correct error_code.
- EmergencyRules.classify raises ValueError for TaskCancelledError (callers
  must let the cancellation propagate as-is instead of classifying it).
- EmergencyError.to_dict produces all 5 fields.
- EmergencyError.to_error_message formats suggestions as "建议:1) ... 2) ...".
- Config overrides apply (suggestions, retryable, message).
- TaskResult.error_struct defaults to None, in which case to_dict() output is
  unchanged (no error_struct key), preserving backward compatibility.
- TaskResult round-trip serialization includes error_struct when set.
"""
|
||
|
||
from datetime import datetime, timezone
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.exceptions import (
|
||
LLMProviderError,
|
||
LoopDetectedError,
|
||
TaskCancelledError,
|
||
TaskTimeoutError,
|
||
)
|
||
from agentkit.core.fallback import (
|
||
EMPTY_LLM_RESPONSE,
|
||
MAX_STEPS_REACHED,
|
||
SHELL_NO_OUTPUT,
|
||
EmergencyError,
|
||
EmergencyRules,
|
||
)
|
||
from agentkit.core.protocol import TaskResult
|
||
|
||
|
||
# ── Constants unchanged (contract preservation) ──────
|
||
|
||
|
||
class TestExistingConstantsUnchanged:
    """The three pre-existing fallback constants keep their established text."""

    def test_empty_llm_response_unchanged(self):
        # Both the headline and the suggestion marker must still be present.
        for fragment in ("模型未返回有效内容", "建议"):
            assert fragment in EMPTY_LLM_RESPONSE

    def test_max_steps_reached_unchanged(self):
        expected_fragment = "已达到最大推理步数"
        assert expected_fragment in MAX_STEPS_REACHED

    def test_shell_no_output_unchanged(self):
        # This constant is pinned verbatim, not just by substring.
        expected = "[命令执行成功,无输出内容]"
        assert SHELL_NO_OUTPUT == expected
|
||
|
||
|
||
# ── EmergencyRules.classify ──────────────────────────
|
||
|
||
|
||
class TestEmergencyRulesClassify:
    """classify() maps exception types to EmergencyError.

    Covers each dedicated rule (timeout, loop_detected, llm_failure), the
    generic internal_error fallback, subclass matching, and the special case
    of TaskCancelledError, which classify() refuses with ValueError.
    """

    def test_timeout(self):
        exc = TaskTimeoutError(task_id="t1", timeout_seconds=30)
        err = EmergencyRules.classify(exc)
        assert err.error_code == "timeout"
        assert err.retryable is True
        assert "稍后重试" in err.suggestions
        assert "简化任务范围" in err.suggestions
        assert err.original_error == str(exc)

    def test_loop_detected(self):
        exc = LoopDetectedError(tool_name="shell", repetitions=3)
        err = EmergencyRules.classify(exc)
        assert err.error_code == "loop_detected"
        assert err.retryable is True
        assert "拆分任务" in err.suggestions
        assert "检查工具参数" in err.suggestions

    def test_llm_provider_error(self):
        exc = LLMProviderError("openai", "rate limited")
        err = EmergencyRules.classify(exc)
        assert err.error_code == "llm_failure"
        assert err.retryable is True
        assert "稍后重试" in err.suggestions
        assert "切换模型" in err.suggestions

    def test_llm_error_subclass_also_classified(self):
        """A plain LLMError subclass is NOT classified as an LLM failure.

        The llm_failure rule matches LLMProviderError, not the LLMError base
        class. A sibling subclass of LLMError therefore falls through to the
        generic, non-retryable internal_error rule. Despite the test name, the
        expected outcome is this fall-through.
        """
        from agentkit.core.exceptions import LLMError

        class CustomLLMError(LLMError):
            pass

        err = EmergencyRules.classify(CustomLLMError("custom"))
        # CustomLLMError is NOT a LLMProviderError, so it falls through to generic.
        assert err.error_code == "internal_error"
        # Same retryability as the generic fallback (see the generic test below).
        assert err.retryable is False

    def test_generic_exception_internal_error(self):
        err = EmergencyRules.classify(Exception("unknown boom"))
        assert err.error_code == "internal_error"
        assert err.retryable is False
        assert "联系管理员" in err.suggestions
        assert err.original_error == "unknown boom"

    def test_task_cancelled_raises(self):
        """TaskCancelledError must propagate; classify() raises ValueError."""
        exc = TaskCancelledError(task_id="t1")
        with pytest.raises(ValueError, match="TaskCancelledError"):
            EmergencyRules.classify(exc)

    def test_subclass_of_timeout_classified(self):
        """Subclasses of TaskTimeoutError are classified as timeout."""

        class CustomTimeout(TaskTimeoutError):
            def __init__(self):
                super().__init__(task_id="custom", timeout_seconds=10)

        err = EmergencyRules.classify(CustomTimeout())
        assert err.error_code == "timeout"
|
||
|
||
|
||
# ── EmergencyError serialization ─────────────────────
|
||
|
||
|
||
class TestEmergencyErrorSerialization:
    """Exercise EmergencyError.to_dict() and EmergencyError.to_error_message()."""

    def test_to_dict_produces_all_five_fields(self):
        expected = {
            "error_code": "timeout",
            "message": "任务执行超时。",
            "suggestions": ["稍后重试", "简化任务范围"],
            "retryable": True,
            "original_error": "Task t1 timed out after 30s",
        }
        payload = EmergencyError(**expected).to_dict()

        # Exactly the five documented keys, each with the value passed in.
        assert set(payload) == set(expected)
        assert payload == expected
        # Identity check: `==` alone would also accept 1 for True.
        assert payload["retryable"] is True

    def test_to_dict_suggestions_list_is_copy(self):
        """to_dict hands back a fresh suggestions list, not the internal one."""
        source = ["a", "b"]
        err = EmergencyError(
            error_code="x",
            message="m",
            suggestions=source,
            retryable=False,
            original_error="e",
        )
        exported = err.to_dict()["suggestions"]
        assert exported is not source

        # Mutating the exported list must not leak back into the error.
        exported.append("c")
        assert err.suggestions == ["a", "b"]

    def test_to_error_message_with_suggestions(self):
        rendered = EmergencyError(
            error_code="timeout",
            message="任务执行超时。",
            suggestions=["稍后重试", "简化任务范围"],
            retryable=True,
            original_error="err",
        ).to_error_message()

        assert rendered.startswith("任务执行超时。建议:")
        for numbered in ("1) 稍后重试", "2) 简化任务范围"):
            assert numbered in rendered
        # Format mirrors the EMPTY_LLM_RESPONSE style: terminated by "。".
        assert rendered.endswith("。")

    def test_to_error_message_no_suggestions(self):
        bare = EmergencyError(
            error_code="x",
            message="just a message",
            suggestions=[],
            retryable=False,
            original_error="e",
        )
        # With nothing to suggest, the message is returned untouched.
        assert bare.to_error_message() == "just a message"

    def test_to_error_message_single_suggestion(self):
        single = EmergencyError(
            error_code="x",
            message="msg",
            suggestions=["only one"],
            retryable=False,
            original_error="e",
        )
        expected = "msg建议:1) only one。"
        assert single.to_error_message() == expected
|
||
|
||
|
||
# ── Config override ──────────────────────────────────
|
||
|
||
|
||
class TestConfigOverride:
    """classify() applies per-rule config overrides.

    Each override replaces only the field it names. The remaining fields keep
    the rule's defaults (the same defaults pinned by TestEmergencyRulesClassify),
    and an override keyed to a different rule has no effect.
    """

    def test_override_suggestions(self):
        exc = TaskTimeoutError(task_id="t", timeout_seconds=1)
        cfg = {"timeout": {"suggestions": ["自定义建议 A", "自定义建议 B"]}}
        err = EmergencyRules.classify(exc, config=cfg)
        assert err.suggestions == ["自定义建议 A", "自定义建议 B"]
        # Fields not named in the override keep the rule defaults.
        assert err.error_code == "timeout"
        assert err.retryable is True

    def test_override_retryable(self):
        exc = LLMProviderError("openai", "boom")
        cfg = {"llm_failure": {"retryable": False}}
        err = EmergencyRules.classify(exc, config=cfg)
        assert err.retryable is False
        # Partial override: code and default suggestions are untouched.
        assert err.error_code == "llm_failure"
        assert "切换模型" in err.suggestions

    def test_override_message(self):
        exc = LoopDetectedError(tool_name="x", repetitions=2)
        cfg = {"loop_detected": {"message": "循环啦!"}}
        err = EmergencyRules.classify(exc, config=cfg)
        assert err.message == "循环啦!"
        # Partial override: everything else keeps the loop_detected defaults.
        assert err.error_code == "loop_detected"
        assert err.retryable is True
        assert "拆分任务" in err.suggestions

    def test_override_internal_error_rule(self):
        cfg = {"internal_error": {"suggestions": ["联系客服"]}}
        err = EmergencyRules.classify(Exception("boom"), config=cfg)
        assert err.error_code == "internal_error"
        assert err.suggestions == ["联系客服"]
        # The generic rule's other defaults survive the override.
        assert err.retryable is False
        assert err.original_error == "boom"

    def test_override_for_other_rule_is_ignored(self):
        """An override keyed to a different rule does not leak into this one."""
        cfg = {"llm_failure": {"retryable": False}}
        err = EmergencyRules.classify(
            TaskTimeoutError(task_id="t", timeout_seconds=1), config=cfg
        )
        assert err.error_code == "timeout"
        assert err.retryable is True

    def test_config_none_uses_defaults(self):
        err = EmergencyRules.classify(TaskTimeoutError(task_id="t", timeout_seconds=1))
        assert err.error_code == "timeout"
        assert err.retryable is True
        assert "稍后重试" in err.suggestions

    def test_config_empty_dict_uses_defaults(self):
        err = EmergencyRules.classify(
            TaskTimeoutError(task_id="t", timeout_seconds=1), config={}
        )
        assert err.error_code == "timeout"
        assert err.retryable is True
        assert "稍后重试" in err.suggestions
|
||
|
||
|
||
# ── TaskResult.error_struct extension ────────────────
|
||
|
||
|
||
def _make_task_result(
    error_struct: dict | None = None, error_message: str | None = None
) -> TaskResult:
    """Build a completed TaskResult fixture whose only varying parts are the error fields.

    started_at and completed_at share one timezone-aware UTC timestamp.
    """
    timestamp = datetime.now(timezone.utc)
    fields = {
        "task_id": "t1",
        "agent_name": "a1",
        "status": "completed",
        "output_data": {"k": "v"},
        "error_message": error_message,
        "started_at": timestamp,
        "completed_at": timestamp,
        "metrics": {"m": 1},
        "error_struct": error_struct,
    }
    return TaskResult(**fields)
|
||
|
||
|
||
class TestTaskResultErrorStruct:
    """TaskResult.error_struct field — backward-compatible extension."""

    def test_default_error_struct_is_none(self):
        tr = _make_task_result()
        assert tr.error_struct is None

    def test_to_dict_without_error_struct_preserves_existing_shape(self):
        """error_struct=None → to_dict() emits NO error_struct key.

        The key set must match the pre-extension shape exactly. This checks
        keys only, not value serialization.
        """
        tr = _make_task_result()
        d = tr.to_dict()
        assert "error_struct" not in d
        # Existing keys unchanged
        assert set(d.keys()) == {
            "task_id",
            "agent_name",
            "status",
            "output_data",
            "error_message",
            "started_at",
            "completed_at",
            "metrics",
        }

    def test_to_dict_with_error_struct_includes_key(self):
        struct = {
            "error_code": "timeout",
            "message": "超时",
            "suggestions": ["重试"],
            "retryable": True,
            "original_error": "boom",
        }
        tr = _make_task_result(error_struct=struct, error_message="超时建议:1) 重试。")
        d = tr.to_dict()
        assert d["error_struct"] == struct
        assert d["error_message"] == "超时建议:1) 重试。"

    def test_from_dict_round_trip_with_error_struct(self):
        struct = {
            "error_code": "loop_detected",
            "message": "m",
            "suggestions": [],
            "retryable": True,
            "original_error": "e",
        }
        tr = _make_task_result(error_struct=struct)
        restored = TaskResult.from_dict(tr.to_dict())
        assert restored.error_struct == struct
        # The extension must not disturb the pre-existing fields.
        assert restored.task_id == tr.task_id

    def test_from_dict_without_error_struct_defaults_none(self):
        d = _make_task_result().to_dict()
        # Simulate legacy data: make sure the error_struct key really is absent,
        # rather than relying on to_dict() happening to omit it.
        d.pop("error_struct", None)
        assert "error_struct" not in d
        restored = TaskResult.from_dict(d)
        assert restored.error_struct is None

    def test_error_message_and_error_struct_coexist(self):
        """Both fields can be set simultaneously (parallel contract per KTD2)."""
        struct = {
            "error_code": "timeout",
            "message": "超时",
            "suggestions": ["重试"],
            "retryable": True,
            "original_error": "err",
        }
        tr = _make_task_result(error_struct=struct, error_message="超时建议:1) 重试。")
        d = tr.to_dict()
        assert d["error_message"] == "超时建议:1) 重试。"
        assert d["error_struct"] == struct
|