277 lines
10 KiB
Python
277 lines
10 KiB
Python
"""Unit tests for verification defaults (U3, R2/R3) + sandbox integration.
|
|
|
|
Covers:
|
|
- default_policy(workspace_root) — coding-task detection sets verification_commands
|
|
- PhasePolicy.verification_commands field — default None, to_dict() round-trip
|
|
- PlanExecEngine — verification_enabled defaults True (R2), thread-through
|
|
- TeamOrchestrator — verification_enabled defaults True (R2)
|
|
- ReActEngine — verification_commands inherited from phase_policy; default
|
|
verification_enabled stays False (RV2 — DIRECT_CHAT/REACT do not verify)
|
|
- ReActEngine._execute_tool — sandbox blocks network during VERIFICATION,
|
|
no block in other phases or when sandbox is None
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
from agentkit.core.phase import PhasePolicy, PhaseState, WILDCARD, default_policy
|
|
from agentkit.core.plan_exec_engine import PlanExecEngine, ReActStepExecutor
|
|
from agentkit.core.react import ReActEngine
|
|
from agentkit.core.sandbox import WorkspaceSandbox
|
|
from agentkit.tools.base import Tool
|
|
|
|
|
|
# ── helpers ───────────────────────────────────────────────────────────
|
|
|
|
|
|
def make_mock_gateway() -> MagicMock:
    """Build a bare-bones stand-in for ``LLMGateway``.

    Returns:
        A ``MagicMock`` spec'd against ``LLMGateway`` whose ``chat`` attribute
        is an ``AsyncMock`` resolving to a plain ``MagicMock``. Used here only
        to satisfy ``ReActEngine`` construction.
    """
    # Imported lazily, inside the helper, exactly as the module has always done.
    from agentkit.llm.gateway import LLMGateway

    chat_stub = AsyncMock(return_value=MagicMock())
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = chat_stub
    return mock_gateway
|
|
|
|
|
|
class _NetworkTool(Tool):
    """Test-only tool whose ``execute`` tries to open a TCP connection.

    It exists to check whether the sandbox network block is active during
    VERIFICATION.

    ``OSError`` (e.g. ``ConnectionRefusedError``) is caught, so with the
    sandbox inactive the tool returns an ordinary result dict. With the
    sandbox active, ``SandboxNetworkBlockedError`` (a ``RuntimeError``, not an
    ``OSError``) is not caught here and propagates to ``_execute_tool``'s
    dedicated handler.
    """

    def __init__(self) -> None:
        """Register the tool as ``net_tool`` with an empty, closed input schema."""
        empty_schema = {"type": "object", "properties": {}, "additionalProperties": False}
        super().__init__(
            name="net_tool",
            description="test tool that connects a socket",
            input_schema=empty_schema,
        )

    async def execute(self, **kwargs) -> dict[str, object]:
        """Attempt a TCP connect to ``127.0.0.1:1`` and report the outcome.

        Returns:
            ``{"ok": True}`` if the connect succeeds, otherwise
            ``{"ok": False, "error": <exception class name>}`` for any
            ``OSError``. Non-``OSError`` exceptions propagate to the caller.
        """
        import socket

        target = ("127.0.0.1", 1)
        conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            conn.connect(target)
        except OSError as exc:
            # An ordinary connection failure (no listener) — evidence that the
            # sandbox did NOT intercept the connect.
            outcome: dict[str, object] = {"ok": False, "error": type(exc).__name__}
        else:
            outcome = {"ok": True}
        finally:
            # Always release the socket, including when a non-OSError escapes.
            conn.close()
        return outcome
|
|
|
|
|
|
# ── default_policy + PhasePolicy.verification_commands ────────────────
|
|
|
|
|
|
def test_default_policy_no_workspace_has_none_commands() -> None:
    """``default_policy()`` called without a workspace has no verification commands."""
    assert default_policy().verification_commands is None
|
|
|
|
|
|
def test_default_policy_coding_workspace_forces_pytest_ruff(tmp_path: Path) -> None:
    """A workspace holding a ``pyproject.toml`` gets the pytest + ruff commands."""
    project_file = tmp_path / "pyproject.toml"
    project_file.write_text("[project]\nname='x'\n")

    detected = default_policy(workspace_root=tmp_path).verification_commands

    assert detected == ["pytest -x -q", "ruff check src/"]
|
|
|
|
|
|
def test_default_policy_non_coding_workspace_has_none_commands(tmp_path: Path) -> None:
    """A workspace containing only a ``README.md`` gets no verification commands."""
    readme = tmp_path / "README.md"
    readme.write_text("# docs only")

    detected = default_policy(workspace_root=tmp_path).verification_commands

    assert detected is None
|
|
|
|
|
|
def test_default_policy_empty_workspace_has_none_commands(tmp_path: Path) -> None:
    """An empty workspace directory gets no verification commands."""
    assert default_policy(workspace_root=tmp_path).verification_commands is None
|
|
|
|
|
|
def test_phase_policy_verification_commands_defaults_none() -> None:
    """A ``PhasePolicy`` built without commands has ``verification_commands`` of None."""
    planning_only = {PhaseState.PLANNING: frozenset({WILDCARD})}
    bare_policy = PhasePolicy(whitelist=planning_only)
    assert bare_policy.verification_commands is None
|
|
|
|
|
|
def test_phase_policy_to_dict_includes_verification_commands() -> None:
    """``to_dict()`` exposes the configured commands under ``verification_commands``."""
    serialized = PhasePolicy(
        whitelist={PhaseState.PLANNING: frozenset({WILDCARD})},
        verification_commands=["pytest -x -q"],
    ).to_dict()
    assert serialized["verification_commands"] == ["pytest -x -q"]
|
|
|
|
|
|
# ── PlanExecEngine defaults (R2) ──────────────────────────────────────
|
|
|
|
|
|
def test_plan_exec_engine_verification_enabled_defaults_true() -> None:
    """R2: a ``PlanExecEngine`` built with defaults has verification switched on."""
    default_engine = PlanExecEngine(llm_gateway=None)
    assert default_engine._verification_enabled is True
|
|
|
|
|
|
def test_plan_exec_engine_verification_enabled_can_be_disabled() -> None:
    """Passing ``verification_enabled=False`` turns verification off on the engine."""
    opted_out = PlanExecEngine(llm_gateway=None, verification_enabled=False)
    assert opted_out._verification_enabled is False
|
|
|
|
|
|
def test_plan_exec_engine_verification_commands_threaded() -> None:
    """Commands given to the constructor are stored on ``_verification_commands``."""
    supplied = ["pytest -x -q", "ruff check src/"]
    stored = PlanExecEngine(
        llm_gateway=None,
        verification_commands=supplied,
    )._verification_commands
    # Compared against the very list that was passed in, as the test always has.
    assert stored == supplied
|
|
|
|
|
|
def test_react_step_executor_threads_verification_params() -> None:
    """``ReActStepExecutor`` keeps both verification settings it was constructed with."""
    step_executor = ReActStepExecutor(
        verification_commands=["pytest"],
        verification_enabled=True,
    )
    assert step_executor._verification_enabled is True
    assert step_executor._verification_commands == ["pytest"]
|
|
|
|
|
|
# ── TeamOrchestrator defaults (R2) ────────────────────────────────────
|
|
|
|
|
|
def test_team_orchestrator_verification_enabled_defaults_true() -> None:
    """R2: a ``TeamOrchestrator`` built with defaults has verification switched on."""
    from agentkit.experts.orchestrator import TeamOrchestrator
    from agentkit.experts.team import ExpertTeam

    orchestrator = TeamOrchestrator(team=MagicMock(spec=ExpertTeam))
    assert orchestrator._verification_enabled is True
|
|
|
|
|
|
def test_team_orchestrator_verification_can_be_disabled() -> None:
    """Passing ``verification_enabled=False`` turns verification off on the orchestrator."""
    from agentkit.experts.orchestrator import TeamOrchestrator
    from agentkit.experts.team import ExpertTeam

    orchestrator = TeamOrchestrator(
        team=MagicMock(spec=ExpertTeam),
        verification_enabled=False,
    )
    assert orchestrator._verification_enabled is False
|
|
|
|
|
|
def test_team_orchestrator_detects_commands_from_workspace(tmp_path: Path) -> None:
    """A workspace with ``pyproject.toml`` gives the orchestrator pytest + ruff commands."""
    from agentkit.experts.orchestrator import TeamOrchestrator
    from agentkit.experts.team import ExpertTeam

    project_file = tmp_path / "pyproject.toml"
    project_file.write_text("[project]\nname='x'\n")

    # The orchestrator is handed the workspace root as a string, not a Path.
    orchestrator = TeamOrchestrator(
        team=MagicMock(spec=ExpertTeam),
        workspace_root=str(tmp_path),
    )

    assert orchestrator._verification_commands == ["pytest -x -q", "ruff check src/"]
|
|
|
|
|
|
# ── ReActEngine: verification_commands inheritance + default (RV2) ────
|
|
|
|
|
|
def test_react_engine_default_verification_enabled_stays_false() -> None:
    """RV2: a default ``ReActEngine`` (DIRECT_CHAT/REACT) does not verify."""
    gateway = make_mock_gateway()
    default_engine = ReActEngine(llm_gateway=gateway)
    assert default_engine._verification_enabled is False
|
|
|
|
|
|
def test_react_engine_inherits_verification_commands_from_phase_policy() -> None:
    """With no explicit commands, the engine takes them from its ``phase_policy``."""
    policy_with_commands = PhasePolicy(
        verification_commands=["pytest -x -q", "ruff check src/"],
        whitelist={PhaseState.PLANNING: frozenset({WILDCARD})},
    )

    engine = ReActEngine(phase_policy=policy_with_commands, llm_gateway=make_mock_gateway())

    assert engine._verification_commands == ["pytest -x -q", "ruff check src/"]
|
|
|
|
|
|
def test_react_engine_explicit_commands_override_phase_policy() -> None:
    """Commands passed to the engine take precedence over the policy's commands."""
    policy_with_commands = PhasePolicy(
        verification_commands=["pytest -x -q", "ruff check src/"],
        whitelist={PhaseState.PLANNING: frozenset({WILDCARD})},
    )

    engine = ReActEngine(
        verification_commands=["echo custom"],
        phase_policy=policy_with_commands,
        llm_gateway=make_mock_gateway(),
    )

    assert engine._verification_commands == ["echo custom"]
|
|
|
|
|
|
def test_react_engine_no_policy_no_commands() -> None:
    """With neither a policy nor explicit commands, ``_verification_commands`` is None."""
    gateway = make_mock_gateway()
    assert ReActEngine(llm_gateway=gateway)._verification_commands is None
|
|
|
|
|
|
# ── ReActEngine._execute_tool sandbox integration (RV3) ───────────────
|
|
|
|
|
|
async def test_execute_tool_blocks_network_in_verification_phase(tmp_path: Path) -> None:
    """Sandbox blocks a tool's network call during VERIFICATION phase and
    returns a structured error instead of raising.

    The sandbox is rooted at pytest's per-test ``tmp_path`` rather than a
    hard-coded ``/tmp``: this keeps the test portable to platforms without
    ``/tmp``, avoids pointing the sandbox at a shared global directory, and
    matches the other workspace-based tests in this module.
    """
    policy = PhasePolicy(
        whitelist={
            PhaseState.VERIFICATION: frozenset({"net_tool"}),
            PhaseState.PLANNING: frozenset({WILDCARD}),
        },
        # Start directly in VERIFICATION so the very first tool call is sandboxed.
        start_phase=PhaseState.VERIFICATION,
    )
    sandbox = WorkspaceSandbox(workspace_root=tmp_path)
    engine = ReActEngine(
        llm_gateway=make_mock_gateway(),
        phase_policy=policy,
        sandbox=sandbox,
    )
    tool = _NetworkTool()

    result = await engine._execute_tool("net_tool", {}, [tool])

    # The block surfaces as a structured result dict, not as an exception.
    assert result["error_code"] == "sandbox_network_blocked"
    assert result["current_phase"] == "verification"
    assert result["tool"] == "net_tool"
|
|
|
|
|
|
async def test_execute_tool_no_block_outside_verification(tmp_path: Path) -> None:
    """Sandbox does not block tool calls in non-VERIFICATION phases.

    The sandbox is rooted at pytest's per-test ``tmp_path`` rather than a
    hard-coded ``/tmp``: this keeps the test portable to platforms without
    ``/tmp``, avoids pointing the sandbox at a shared global directory, and
    matches the other workspace-based tests in this module.
    """
    policy = PhasePolicy(
        whitelist={
            PhaseState.PLANNING: frozenset({"net_tool"}),
            PhaseState.VERIFICATION: frozenset({WILDCARD}),
        },
        start_phase=PhaseState.PLANNING,
    )
    sandbox = WorkspaceSandbox(workspace_root=tmp_path)
    engine = ReActEngine(
        llm_gateway=make_mock_gateway(),
        phase_policy=policy,
        sandbox=sandbox,
    )
    tool = _NetworkTool()

    # In PLANNING phase, the tool should attempt the connect and fail with
    # a connection error (not sandbox block). The connect to port 1 on
    # localhost will fail with ECONNREFUSED — we just assert it's NOT the
    # sandbox error code.
    result = await engine._execute_tool("net_tool", {}, [tool])

    assert result.get("error_code") != "sandbox_network_blocked"
|
|
|
|
|
|
async def test_execute_tool_no_sandbox_no_block() -> None:
    """With ``sandbox=None`` nothing is network-blocked, even in VERIFICATION."""
    phase_whitelist = {
        PhaseState.VERIFICATION: frozenset({"net_tool"}),
        PhaseState.PLANNING: frozenset({WILDCARD}),
    }
    verification_first = PhasePolicy(
        whitelist=phase_whitelist,
        start_phase=PhaseState.VERIFICATION,
    )
    unsandboxed_engine = ReActEngine(
        llm_gateway=make_mock_gateway(),
        phase_policy=verification_first,
        sandbox=None,
    )

    outcome = await unsandboxed_engine._execute_tool("net_tool", {}, [_NetworkTool()])

    # Only the absence of the sandbox error code is checked, nothing more.
    assert outcome.get("error_code") != "sandbox_network_blocked"
|