269 lines
9.3 KiB
Python
269 lines
9.3 KiB
Python
"""Tests for compression config integration (U4).

Covers:
1. ServerConfig.compression field
2. create_app compression setup
3. ConfigDrivenAgent compressor passthrough
"""
|
|
|
|
import os
|
|
import tempfile
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from agentkit.server.config import ServerConfig
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# 1. ServerConfig.compression
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestServerConfigCompression:
    """Test the ``compression`` field on ServerConfig."""

    def test_default_compression_is_empty_dict(self):
        """A bare ServerConfig has an empty compression mapping."""
        config = ServerConfig()
        assert config.compression == {}

    def test_compression_from_dict(self):
        """from_dict copies the ``compression`` section verbatim."""
        data = {
            "compression": {
                "enabled": True,
                "provider": "headroom",
                "compressors": ["smart_crusher"],
            }
        }
        config = ServerConfig.from_dict(data)
        assert config.compression["enabled"] is True
        assert config.compression["provider"] == "headroom"
        assert config.compression["compressors"] == ["smart_crusher"]

    def test_compression_none_when_not_in_yaml(self):
        """A config without a ``compression`` section yields ``{}`` (not None)."""
        data = {"server": {"host": "0.0.0.0"}}
        config = ServerConfig.from_dict(data)
        assert config.compression == {}

    def test_compression_hot_reload(self):
        """_try_reload_config should update compression."""
        # delete=False lets the file be reopened by name once the handle is
        # closed (reopening an open NamedTemporaryFile fails on Windows).
        # The file is removed explicitly in the ``finally`` below.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".yaml", delete=False, encoding="utf-8"
        ) as f:
            f.write(
                "server:\n host: 0.0.0.0\n port: 8001\n"
                "compression:\n enabled: false\n"
            )
            path = f.name

        try:
            config = ServerConfig.from_yaml(path)
            assert config.compression == {"enabled": False}

            # Overwrite the file with new content, then hot-reload it.
            with open(path, "w", encoding="utf-8") as f2:
                f2.write(
                    "server:\n host: 0.0.0.0\n port: 8001\n"
                    "compression:\n enabled: true\n provider: headroom\n"
                )

            config._try_reload_config(path)
            assert config.compression["enabled"] is True
            assert config.compression["provider"] == "headroom"
        finally:
            os.unlink(path)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# 2. create_app compression setup
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCreateAppCompression:
    """Check how create_app populates ``app.state.compressor``."""

    def test_compressor_created_when_enabled(self):
        """An enabled compression section is handed to the factory and its result stored."""
        from agentkit.server.app import create_app

        fake_compressor = MagicMock()
        with patch(
            "agentkit.core.compressor.create_compressor",
            return_value=fake_compressor,
        ) as factory:
            cfg = ServerConfig(compression={"enabled": True, "provider": "summary"})
            app = create_app(server_config=cfg)

        factory.assert_called_once_with({"enabled": True, "provider": "summary"})
        assert app.state.compressor is fake_compressor

    def test_compressor_none_when_disabled(self):
        """The factory is still consulted when disabled; its None result is stored."""
        from agentkit.server.app import create_app

        # The (mocked) factory signals "disabled" by returning None.
        with patch(
            "agentkit.core.compressor.create_compressor",
            return_value=None,
        ) as factory:
            cfg = ServerConfig(compression={"enabled": False})
            app = create_app(server_config=cfg)

        factory.assert_called_once_with({"enabled": False})
        assert app.state.compressor is None

    def test_compressor_none_when_no_config(self):
        """Without any server_config the factory is never called."""
        from agentkit.server.app import create_app

        # Blank out the config-path env var and make every os.path.exists
        # lookup fail, so create_app cannot auto-discover an agentkit.yaml
        # (e.g. one sitting in the CWD).
        env_patch = patch.dict(
            os.environ, {"AGENTKIT_CONFIG_PATH": ""}, clear=False
        )
        exists_patch = patch("os.path.exists", return_value=False)

        with patch(
            "agentkit.core.compressor.create_compressor",
            return_value=None,
        ) as factory:
            with env_patch, exists_patch:
                app = create_app()

        factory.assert_not_called()
        assert app.state.compressor is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# 3. ConfigDrivenAgent compressor passthrough
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConfigDrivenAgentCompression:
    """Test compressor passthrough from ConfigDrivenAgent to ReActEngine."""

    @pytest.fixture
    def agent_config(self):
        """Minimal llm_generate AgentConfig."""
        from agentkit.core.config_driven import AgentConfig

        return AgentConfig(
            name="test_agent",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "test", "instructions": "test"},
        )

    @pytest.fixture
    def skill_config(self):
        """Minimal SkillConfig using the react execution mode."""
        from agentkit.skills.base import SkillConfig

        return SkillConfig(
            name="test_skill",
            agent_type="test",
            description="test",
            prompt={"identity": "test", "instructions": "test instructions"},
            execution_mode="react",
        )

    @staticmethod
    def _make_react_agent(skill_config, compressor):
        """Build a ConfigDrivenAgent without __init__, backed by a mock ReActEngine.

        Returns ``(agent, mock_engine)``; ``mock_engine.execute`` is an AsyncMock
        whose result has ``output == '{"result": "ok"}'``.
        """
        from agentkit.core.config_driven import ConfigDrivenAgent

        # object.__new__ never runs __init__, so no patching is needed; the
        # attributes below are the ones this test assumes _handle_react reads.
        agent = object.__new__(ConfigDrivenAgent)
        agent._config = skill_config
        agent._skill_config = skill_config
        agent._prompt_template = None
        agent._tools = []
        agent._tool_registry = None
        agent._memory_retriever = None
        agent._compressor = compressor
        agent._evolution_enabled = False
        agent._current_module = None
        agent._active_tokens = {}
        agent.name = "test_agent"

        mock_engine = MagicMock()
        mock_result = MagicMock()
        mock_result.output = '{"result": "ok"}'
        mock_engine.execute = AsyncMock(return_value=mock_result)
        agent._react_engine = mock_engine
        return agent, mock_engine

    @staticmethod
    def _make_task(task_id):
        """Build a minimal TaskMessage addressed to ``test_agent``."""
        from datetime import datetime, timezone

        from agentkit.core.protocol import TaskMessage

        return TaskMessage(
            task_id=task_id,
            agent_name="test_agent",
            task_type="test",
            input_data={"query": "hello"},
            priority=1,
            created_at=datetime.now(timezone.utc),
            callback_url=None,
        )

    def test_compressor_stored_on_agent(self, agent_config):
        from agentkit.core.config_driven import ConfigDrivenAgent

        mock_compressor = MagicMock()
        agent = ConfigDrivenAgent(
            config=agent_config,
            compressor=mock_compressor,
        )
        assert agent._compressor is mock_compressor

    def test_no_compressor_backward_compatible(self, agent_config):
        from agentkit.core.config_driven import ConfigDrivenAgent

        agent = ConfigDrivenAgent(config=agent_config)
        assert agent._compressor is None

    @pytest.mark.asyncio
    async def test_compressor_passed_to_react_engine(self, skill_config):
        mock_compressor = MagicMock()
        agent, mock_engine = self._make_react_agent(skill_config, mock_compressor)

        await agent._handle_react(self._make_task("t1"))

        # Verify execute was awaited exactly once and received the compressor.
        mock_engine.execute.assert_awaited_once()
        call_kwargs = mock_engine.execute.call_args.kwargs
        assert call_kwargs["compressor"] is mock_compressor

    @pytest.mark.asyncio
    async def test_no_compressor_backward_compatible_react(self, skill_config):
        agent, mock_engine = self._make_react_agent(skill_config, None)

        await agent._handle_react(self._make_task("t2"))

        # Assert the call happened first so a missing call fails clearly
        # instead of raising AttributeError on ``call_args`` being None.
        mock_engine.execute.assert_awaited_once()
        call_kwargs = mock_engine.execute.call_args.kwargs
        assert call_kwargs["compressor"] is None
|