401 lines
15 KiB
Python
401 lines
15 KiB
Python
"""G4/U1 — Auxiliary LLM routing in ContextCompressor.

Verifies:

- auxiliary_model routes _summarize through the cheaper model first
- empty or whitespace-only content (Finding 4 anti-pattern) triggers fallback to main model
- auxiliary exception triggers fallback to main model
- both auxiliary and main failing falls through to _simple_summary
- auxiliary_model=None preserves existing single-model behavior (characterization)
- config wiring (LLMConfig.from_dict, ServerConfig._build_llm_config)
"""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
from agentkit.core.compressor import ContextCompressor
|
|
from agentkit.llm.config import LLMConfig
|
|
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
|
|
|
|
|
# ── Helpers ──────────────────────────────────────────
|
|
|
|
|
|
def make_gateway_with_response(content: str, model: str = "test") -> MagicMock:
    """Build a mock ``LLMGateway`` whose ``chat`` always resolves to one response.

    Every awaited ``chat(...)`` call returns the same ``LLMResponse`` carrying
    *content* and *model* (with 10 prompt / 10 completion tokens), whatever
    arguments are passed.
    """
    from agentkit.llm.gateway import LLMGateway

    canned = LLMResponse(
        content=content,
        model=model,
        usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
    )
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = AsyncMock(return_value=canned)
    return mock_gateway
|
|
|
|
|
|
def make_gateway_side_effect(
    responses_by_model: dict[str, list[LLMResponse | Exception]],
) -> MagicMock:
    """Mock LLMGateway returning different responses (or raising) keyed by model name.

    Args:
        responses_by_model: Maps each expected model name to the ordered
            outcomes for successive ``chat(model=...)`` calls with that model.
            An ``LLMResponse`` item is returned; an ``Exception`` item is raised.

    Each call to ``gateway.chat(model=X)`` consumes the next outcome queued for X,
    so repeated calls to the same model can return different values. A call for a
    model with no entry, or whose queue is already empty, raises ``ValueError``.
    The input lists are copied, so the caller's data is never mutated.

    The fake ``chat`` accepts ``messages`` and ``model`` only as keyword arguments.
    """
    from collections import deque

    from agentkit.llm.gateway import LLMGateway

    gateway = MagicMock(spec=LLMGateway)
    # Copy each outcome list into a deque: popleft() is O(1), unlike list.pop(0).
    queues = {m: deque(rs) for m, rs in responses_by_model.items()}

    async def chat_side_effect(*, messages, model, **kwargs):
        queue = queues.get(model)
        if queue is None:
            raise ValueError(f"unexpected model={model}")
        if not queue:
            raise ValueError(f"queue for model={model} exhausted")
        item = queue.popleft()
        if isinstance(item, Exception):
            raise item
        return item

    gateway.chat = AsyncMock(side_effect=chat_side_effect)
    return gateway
|
|
|
|
|
|
def make_long_messages(count: int = 4, content_length: int = 2000) -> list[dict]:
    """Return a chat transcript meant to exceed the compressor's token budget.

    Layout: one system prompt, then *count* user/assistant pairs whose content
    is *content_length* filler characters plus a short index suffix, then one
    short user/assistant exchange at the tail.
    """
    user_pad = "x" * content_length
    assistant_pad = "y" * content_length
    history = [
        turn
        for i in range(count)
        for turn in (
            {"role": "user", "content": f"{user_pad} m{i}"},
            {"role": "assistant", "content": f"{assistant_pad} r{i}"},
        )
    ]
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        *history,
        {"role": "user", "content": "recent question"},
        {"role": "assistant", "content": "recent answer"},
    ]
|
|
|
|
|
|
# ── Characterization: auxiliary_model=None preserves existing behavior ──
|
|
|
|
|
|
class TestAuxiliaryNoneCharacterization:
    """auxiliary_model=None (default) — single model call, existing behavior."""

    @staticmethod
    def _summary_messages(result: list[dict]) -> list[dict]:
        """Return the system messages in *result* whose content has "Conversation Summary"."""
        return [
            m
            for m in result
            if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
        ]

    async def test_no_auxiliary_calls_main_once(self):
        gateway = make_gateway_with_response("main summary")
        compressor = ContextCompressor(
            llm_gateway=gateway,
            max_tokens=100,
            keep_recent=2,
            model="main",
            # auxiliary_model omitted → None
        )
        result = await compressor.compress(make_long_messages())

        gateway.chat.assert_awaited_once()
        # The call used the main model
        assert gateway.chat.await_args.kwargs.get("model") == "main"
        # Summary surfaced in result
        assert any("main summary" in m["content"] for m in self._summary_messages(result))

    async def test_main_failure_falls_to_simple_summary(self):
        from agentkit.llm.gateway import LLMGateway

        # spec= makes the mock reject attributes LLMGateway does not define,
        # consistent with the module's gateway helpers.
        gateway = MagicMock(spec=LLMGateway)
        gateway.chat = AsyncMock(side_effect=Exception("main LLM error"))
        compressor = ContextCompressor(
            llm_gateway=gateway,
            max_tokens=100,
            keep_recent=2,
            model="main",
        )
        result = await compressor.compress(make_long_messages())

        # With no auxiliary model configured, every attempt targets the main model.
        # (The count is not pinned: the failure path may retry summarization.)
        assert gateway.chat.await_count >= 1
        assert all(c.kwargs.get("model") == "main" for c in gateway.chat.await_args_list)
        # _simple_summary produces truncated messages with "..."
        summary_msgs = self._summary_messages(result)
        assert len(summary_msgs) == 1
        assert "..." in summary_msgs[0]["content"]
|
|
|
|
|
|
# ── New behavior: auxiliary routing ──────────────────
|
|
|
|
|
|
class TestAuxiliaryRouting:
    """auxiliary_model set and differs from main → auxiliary tried first."""

    # ── Private helpers (no ``test_`` prefix, so pytest does not collect them) ──

    @staticmethod
    def _response(content: str, model: str, completion_tokens: int = 1) -> LLMResponse:
        """Build an ``LLMResponse`` with 1 prompt token and *completion_tokens*."""
        return LLMResponse(
            content=content,
            model=model,
            usage=TokenUsage(prompt_tokens=1, completion_tokens=completion_tokens),
        )

    @staticmethod
    def _make_compressor(gateway, auxiliary_model: str = "fast") -> ContextCompressor:
        """Compressor with main model "main", the given auxiliary model, and a tiny budget."""
        return ContextCompressor(
            llm_gateway=gateway,
            max_tokens=100,
            keep_recent=2,
            model="main",
            auxiliary_model=auxiliary_model,
        )

    @staticmethod
    def _calls_for(gateway, model: str) -> list:
        """Return the recorded ``chat`` awaits whose ``model`` kwarg equals *model*."""
        return [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == model]

    @staticmethod
    def _summary_messages(result: list[dict]) -> list[dict]:
        """Return the system messages in *result* whose content has "Conversation Summary"."""
        return [
            m
            for m in result
            if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
        ]

    # ── Tests ──

    async def test_auxiliary_success_returns_auxiliary_content(self):
        gateway = make_gateway_side_effect(
            {
                "fast": [self._response("aux summary", "fast")],
                "main": [self._response("MAIN SHOULD NOT BE USED", "main")],
            }
        )
        result = await self._make_compressor(gateway).compress(make_long_messages())

        # Auxiliary called; main NOT called
        assert len(self._calls_for(gateway, "fast")) == 1
        assert len(self._calls_for(gateway, "main")) == 0
        # Result contains auxiliary summary
        assert any("aux summary" in m["content"] for m in self._summary_messages(result))

    async def test_empty_content_triggers_main_fallback(self):
        """Finding 4 anti-pattern: empty content is a failure, not a success."""
        gateway = make_gateway_side_effect(
            {
                "fast": [self._response("", "fast", completion_tokens=0)],
                "main": [self._response("main summary", "main")],
            }
        )
        result = await self._make_compressor(gateway).compress(make_long_messages())

        # Auxiliary called once (returned empty); main called once (fallback)
        assert len(self._calls_for(gateway, "fast")) == 1
        assert len(self._calls_for(gateway, "main")) == 1
        # Result contains main summary (not the empty auxiliary)
        assert any("main summary" in m["content"] for m in self._summary_messages(result))

    async def test_whitespace_content_triggers_main_fallback(self):
        """Whitespace-only content also counts as empty (Finding 4)."""
        gateway = make_gateway_side_effect(
            {
                "fast": [self._response(" \n ", "fast", completion_tokens=0)],
                "main": [self._response("main summary", "main")],
            }
        )
        result = await self._make_compressor(gateway).compress(make_long_messages())

        # Both auxiliary and main called
        assert len(self._calls_for(gateway, "fast")) == 1
        assert len(self._calls_for(gateway, "main")) == 1
        # The main model's summary, not the whitespace, ends up in the result
        assert any("main summary" in m["content"] for m in self._summary_messages(result))

    async def test_auxiliary_exception_triggers_main_fallback(self):
        from agentkit.core.exceptions import LLMProviderError

        gateway = make_gateway_side_effect(
            {
                "fast": [LLMProviderError("aux", "provider down")],
                "main": [self._response("main summary", "main")],
            }
        )
        result = await self._make_compressor(gateway).compress(make_long_messages())

        # Both called; main succeeded
        assert len(self._calls_for(gateway, "fast")) == 1
        assert len(self._calls_for(gateway, "main")) == 1
        assert any("main summary" in m["content"] for m in self._summary_messages(result))

    async def test_both_fail_falls_to_simple_summary(self):
        """Auxiliary raises, main raises → existing _simple_summary degradation."""
        # Note: aggressive compression path may invoke _summarize multiple times.
        # Queue provides enough responses to handle that without raising queue-exhausted.
        # Fresh exception objects per slot: re-raising one shared instance would
        # keep accumulating traceback state on it.
        gateway = make_gateway_side_effect(
            {
                "fast": [Exception("aux boom") for _ in range(5)],
                "main": [Exception("main boom") for _ in range(5)],
            }
        )
        result = await self._make_compressor(gateway).compress(make_long_messages())

        # Both called at least once
        assert len(self._calls_for(gateway, "fast")) >= 1
        assert len(self._calls_for(gateway, "main")) >= 1
        # _simple_summary output has "..." truncation markers
        summary_msgs = self._summary_messages(result)
        assert len(summary_msgs) == 1
        assert "..." in summary_msgs[0]["content"]

    async def test_auxiliary_equal_to_main_skipped(self):
        """auxiliary_model == model → no auxiliary routing (single call to main)."""
        gateway = make_gateway_with_response("main summary")
        compressor = self._make_compressor(gateway, auxiliary_model="main")  # same as main
        await compressor.compress(make_long_messages())

        # Only one call (to main); auxiliary block skipped
        assert gateway.chat.await_count == 1
        assert gateway.chat.await_args.kwargs.get("model") == "main"

    async def test_audit_fields_preserved(self):
        """Auxiliary call uses agent_name='compressor', task_type='summarization'."""
        # make_gateway_with_response answers every model with the same non-empty
        # content, so the auxiliary attempt succeeds and no fallback happens.
        gateway = make_gateway_with_response("aux summary")
        await self._make_compressor(gateway).compress(make_long_messages())

        # Exactly one call, and it went to the auxiliary model: otherwise the
        # audit-field checks below might be inspecting a main-model call.
        gateway.chat.assert_awaited_once()
        call_kwargs = gateway.chat.await_args.kwargs
        assert call_kwargs.get("model") == "fast"
        assert call_kwargs.get("agent_name") == "compressor"
        assert call_kwargs.get("task_type") == "summarization"
|
|
|
|
|
|
# ── Config wiring ────────────────────────────────────
|
|
|
|
|
|
class TestConfigWiring:
    """Check that auxiliary_model flows from raw config dicts into LLMConfig.

    Covers both ``LLMConfig.from_dict`` and ``ServerConfig._build_llm_config``.
    """

    def test_llm_config_from_dict_reads_auxiliary_model(self):
        raw = {
            "providers": {},
            "model_aliases": {"fast": "p/m"},
            "auxiliary_model": "fast",
        }
        assert LLMConfig.from_dict(raw).auxiliary_model == "fast"

    def test_llm_config_from_dict_auxiliary_none_when_absent(self):
        assert LLMConfig.from_dict({"providers": {}}).auxiliary_model is None

    def test_llm_config_default_auxiliary_none(self):
        assert LLMConfig().auxiliary_model is None

    def test_server_config_build_llm_config_reads_auxiliary_model(self):
        from agentkit.server.config import ServerConfig

        provider = {
            "type": "openai",
            "api_key": "k",
            "base_url": "http://x",
            "models": {"m": {"alias": "fast"}},
        }
        built = ServerConfig._build_llm_config(
            {"providers": {"p": provider}, "auxiliary_model": "fast"}
        )
        assert built.auxiliary_model == "fast"
        # The alias table must still map "fast" to "<provider>/<model>".
        assert built.model_aliases.get("fast") == "p/m"

    def test_server_config_build_llm_config_auxiliary_none_when_absent(self):
        from agentkit.server.config import ServerConfig

        built = ServerConfig._build_llm_config({"providers": {}})
        assert built.auxiliary_model is None
|