feat(U1): G4 ContextCompressor 辅助 LLM 路由
_summarize 优先尝试 auxiliary_model(成本优化的廉价模型,如 qwen-turbo);当辅助模型调用失败或返回空内容(Finding 4 反模式)时回退到主模型,主模型失败仍走 _simple_summary 兜底。auxiliary_model=None 或与主模型相同时,保持既有的单模型调用行为。

- ContextCompressor 新增 auxiliary_model 参数
- LLMConfig 新增 auxiliary_model 字段,ServerConfig._build_llm_config 透传
- agentkit.yaml 文档化 llm.auxiliary_model: fast(注释,保留默认行为)
- 测试: 9 场景覆盖成功/空内容/异常/双向失败/aux=main 跳过/审计字段/配置接线
This commit is contained in:
parent
88bfe71d30
commit
8d5ccca604
|
|
@ -28,6 +28,11 @@ llm:
|
||||||
coding: bailian-coding/qwen3-coder-plus
|
coding: bailian-coding/qwen3-coder-plus
|
||||||
chat: deepseek/deepseek-chat
|
chat: deepseek/deepseek-chat
|
||||||
reasoning: deepseek/deepseek-reasoner
|
reasoning: deepseek/deepseek-reasoner
|
||||||
|
# G4/U1: Auxiliary model for cost-sensitive tasks (summarization).
|
||||||
|
# When set, ContextCompressor tries this alias first, falling back to
|
||||||
|
# the main model on failure or empty content. Commented to preserve
|
||||||
|
# default behavior — uncomment to enable.
|
||||||
|
# auxiliary_model: fast
|
||||||
session: {backend: memory}
|
session: {backend: memory}
|
||||||
bus: {backend: memory}
|
bus: {backend: memory}
|
||||||
task_store: {backend: memory}
|
task_store: {backend: memory}
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,7 @@ class ContextCompressor:
|
||||||
model_context_limit: int = 128_000,
|
model_context_limit: int = 128_000,
|
||||||
headroom_threshold: float = 0.8,
|
headroom_threshold: float = 0.8,
|
||||||
min_tokens: int = 8_000,
|
min_tokens: int = 8_000,
|
||||||
|
auxiliary_model: str | None = None,
|
||||||
):
|
):
|
||||||
self._llm_gateway = llm_gateway
|
self._llm_gateway = llm_gateway
|
||||||
self._max_tokens = max_tokens
|
self._max_tokens = max_tokens
|
||||||
|
|
@ -51,6 +52,11 @@ class ContextCompressor:
|
||||||
self._model_context_limit = model_context_limit
|
self._model_context_limit = model_context_limit
|
||||||
self._headroom_threshold = headroom_threshold
|
self._headroom_threshold = headroom_threshold
|
||||||
self._min_tokens = min_tokens
|
self._min_tokens = min_tokens
|
||||||
|
# G4/U1: Auxiliary model for cost-sensitive summarization (e.g. "fast" alias).
|
||||||
|
# When set and differs from main model, _summarize tries auxiliary first,
|
||||||
|
# falls back to main model on failure OR empty content (Finding 4 anti-pattern).
|
||||||
|
# ponytail: ceiling — auxiliary is best-effort; main model is authoritative fallback.
|
||||||
|
self._auxiliary_model = auxiliary_model
|
||||||
|
|
||||||
def should_compress(self, messages: list[dict]) -> bool:
|
def should_compress(self, messages: list[dict]) -> bool:
|
||||||
"""Check if compression should be triggered based on headroom ratio.
|
"""Check if compression should be triggered based on headroom ratio.
|
||||||
|
|
@ -92,8 +98,8 @@ class ContextCompressor:
|
||||||
if len(non_system) <= self._keep_recent:
|
if len(non_system) <= self._keep_recent:
|
||||||
return messages # Not enough messages to compress
|
return messages # Not enough messages to compress
|
||||||
|
|
||||||
old_msgs = non_system[:-self._keep_recent]
|
old_msgs = non_system[: -self._keep_recent]
|
||||||
recent_msgs = non_system[-self._keep_recent:]
|
recent_msgs = non_system[-self._keep_recent :]
|
||||||
|
|
||||||
# Compress old messages
|
# Compress old messages
|
||||||
summary = await self._summarize(old_msgs)
|
summary = await self._summarize(old_msgs)
|
||||||
|
|
@ -101,10 +107,12 @@ class ContextCompressor:
|
||||||
# Build compressed message list
|
# Build compressed message list
|
||||||
compressed = list(system_msgs)
|
compressed = list(system_msgs)
|
||||||
if summary:
|
if summary:
|
||||||
compressed.append({
|
compressed.append(
|
||||||
"role": "system",
|
{
|
||||||
"content": f"## Conversation Summary\n{summary}",
|
"role": "system",
|
||||||
})
|
"content": f"## Conversation Summary\n{summary}",
|
||||||
|
}
|
||||||
|
)
|
||||||
compressed.extend(recent_msgs)
|
compressed.extend(recent_msgs)
|
||||||
|
|
||||||
# Recursive check: if still over budget, compress again
|
# Recursive check: if still over budget, compress again
|
||||||
|
|
@ -114,22 +122,30 @@ class ContextCompressor:
|
||||||
return self._truncate(compressed)
|
return self._truncate(compressed)
|
||||||
if len(recent_msgs) > 1:
|
if len(recent_msgs) > 1:
|
||||||
# Try keeping fewer recent messages
|
# Try keeping fewer recent messages
|
||||||
return await self._compress_aggressive(messages, _compression_depth=_compression_depth + 1)
|
return await self._compress_aggressive(
|
||||||
|
messages, _compression_depth=_compression_depth + 1
|
||||||
|
)
|
||||||
# Last resort: truncate
|
# Last resort: truncate
|
||||||
return self._truncate(compressed)
|
return self._truncate(compressed)
|
||||||
|
|
||||||
return compressed
|
return compressed
|
||||||
|
|
||||||
async def _summarize(self, messages: list[dict], max_input_tokens: int = 3200) -> str:
|
async def _summarize(self, messages: list[dict], max_input_tokens: int = 3200) -> str:
|
||||||
"""Summarize a list of messages using LLM"""
|
"""Summarize a list of messages using LLM.
|
||||||
|
|
||||||
|
G4/U1: When ``auxiliary_model`` is configured and differs from the main
|
||||||
|
model, try auxiliary first (cost-optimization). On auxiliary failure OR
|
||||||
|
empty content (Finding 4 anti-pattern — "did not throw is not succeeded"),
|
||||||
|
fall back to main model. Existing ``_simple_summary`` degradation
|
||||||
|
preserved as the final tier when main model also fails.
|
||||||
|
"""
|
||||||
if not self._llm_gateway:
|
if not self._llm_gateway:
|
||||||
# No LLM available, do simple truncation
|
# No LLM available, do simple truncation
|
||||||
return self._simple_summary(messages)
|
return self._simple_summary(messages)
|
||||||
|
|
||||||
# Build summary prompt
|
# Build summary prompt
|
||||||
conversation_text = "\n".join(
|
conversation_text = "\n".join(
|
||||||
f"[{m.get('role', 'unknown')}]: {m.get('content', '')}"
|
f"[{m.get('role', 'unknown')}]: {m.get('content', '')}" for m in messages
|
||||||
for m in messages
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pre-truncate if conversation_text exceeds safe token threshold
|
# Pre-truncate if conversation_text exceeds safe token threshold
|
||||||
|
|
@ -145,6 +161,25 @@ class ContextCompressor:
|
||||||
f"{conversation_text}"
|
f"{conversation_text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# G4: Try auxiliary model first when configured (cheap route).
|
||||||
|
if self._auxiliary_model and self._auxiliary_model != self._model:
|
||||||
|
try:
|
||||||
|
response = await self._llm_gateway.chat(
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
model=self._auxiliary_model,
|
||||||
|
agent_name="compressor",
|
||||||
|
task_type="summarization",
|
||||||
|
)
|
||||||
|
# Finding 4: empty content is a failure, not a success.
|
||||||
|
if response.content and response.content.strip():
|
||||||
|
return response.content
|
||||||
|
logger.info("Auxiliary model returned empty content, falling back to main model")
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(
|
||||||
|
f"Auxiliary model summarization failed, falling back to main model: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Main model path (or auxiliary fallback).
|
||||||
try:
|
try:
|
||||||
response = await self._llm_gateway.chat(
|
response = await self._llm_gateway.chat(
|
||||||
messages=[{"role": "user", "content": prompt}],
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
|
@ -166,7 +201,9 @@ class ContextCompressor:
|
||||||
parts.append(f"[{role}]: {content}...")
|
parts.append(f"[{role}]: {content}...")
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
async def _compress_aggressive(self, messages: list[dict], _compression_depth: int = 0) -> list[dict]:
|
async def _compress_aggressive(
|
||||||
|
self, messages: list[dict], _compression_depth: int = 0
|
||||||
|
) -> list[dict]:
|
||||||
"""More aggressive compression when standard compression isn't enough"""
|
"""More aggressive compression when standard compression isn't enough"""
|
||||||
system_msgs = [m for m in messages if m.get("role") == "system"]
|
system_msgs = [m for m in messages if m.get("role") == "system"]
|
||||||
non_system = [m for m in messages if m.get("role") != "system"]
|
non_system = [m for m in messages if m.get("role") != "system"]
|
||||||
|
|
@ -176,10 +213,12 @@ class ContextCompressor:
|
||||||
summary = await self._summarize(non_system[:-1])
|
summary = await self._summarize(non_system[:-1])
|
||||||
compressed = list(system_msgs)
|
compressed = list(system_msgs)
|
||||||
if summary:
|
if summary:
|
||||||
compressed.append({
|
compressed.append(
|
||||||
"role": "system",
|
{
|
||||||
"content": f"## Conversation Summary\n{summary}",
|
"role": "system",
|
||||||
})
|
"content": f"## Conversation Summary\n{summary}",
|
||||||
|
}
|
||||||
|
)
|
||||||
compressed.append(non_system[-1])
|
compressed.append(non_system[-1])
|
||||||
return compressed
|
return compressed
|
||||||
|
|
||||||
|
|
@ -191,7 +230,7 @@ class ContextCompressor:
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
content = str(msg.get("content", ""))
|
content = str(msg.get("content", ""))
|
||||||
if len(content) > self._max_tokens * 4:
|
if len(content) > self._max_tokens * 4:
|
||||||
msg = {**msg, "content": content[:self._max_tokens * 4] + "...[truncated]"}
|
msg = {**msg, "content": content[: self._max_tokens * 4] + "...[truncated]"}
|
||||||
result.append(msg)
|
result.append(msg)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
@ -226,6 +265,7 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate
|
||||||
if provider == "headroom":
|
if provider == "headroom":
|
||||||
try:
|
try:
|
||||||
from agentkit.core.headroom_compressor import HeadroomCompressor
|
from agentkit.core.headroom_compressor import HeadroomCompressor
|
||||||
|
|
||||||
compressor = HeadroomCompressor(config)
|
compressor = HeadroomCompressor(config)
|
||||||
if compressor.is_available():
|
if compressor.is_available():
|
||||||
return compressor
|
return compressor
|
||||||
|
|
@ -235,8 +275,7 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"HeadroomCompressor module not available. "
|
"HeadroomCompressor module not available. Falling back to ContextCompressor."
|
||||||
"Falling back to ContextCompressor."
|
|
||||||
)
|
)
|
||||||
# Fallback to summary compressor
|
# Fallback to summary compressor
|
||||||
return ContextCompressor(
|
return ContextCompressor(
|
||||||
|
|
@ -253,11 +292,9 @@ def create_compressor(config: dict[str, Any] | None = None) -> CompressionStrate
|
||||||
|
|
||||||
def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]:
|
def render_cached(template, variables: dict[str, Any] | None = None) -> list[dict[str, str]]:
|
||||||
"""Render PromptTemplate with caching - returns cached result for same variables"""
|
"""Render PromptTemplate with caching - returns cached result for same variables"""
|
||||||
cache_key = hashlib.md5(
|
cache_key = hashlib.md5(json.dumps(variables or {}, sort_keys=True).encode()).hexdigest()
|
||||||
json.dumps(variables or {}, sort_keys=True).encode()
|
|
||||||
).hexdigest()
|
|
||||||
|
|
||||||
if not hasattr(template, '_render_cache'):
|
if not hasattr(template, "_render_cache"):
|
||||||
template._render_cache = {}
|
template._render_cache = {}
|
||||||
|
|
||||||
if cache_key in template._render_cache:
|
if cache_key in template._render_cache:
|
||||||
|
|
@ -270,5 +307,5 @@ def render_cached(template, variables: dict[str, Any] | None = None) -> list[dic
|
||||||
|
|
||||||
def clear_cache(template) -> None:
    """Empty the ``_render_cache`` mapping stored on a template, if there is one.

    Templates that have never had a ``_render_cache`` attribute attached are
    left untouched.
    """
    if not hasattr(template, "_render_cache"):
        # Nothing has been cached on this template yet.
        return
    template._render_cache.clear()
|
||||||
|
|
|
||||||
|
|
@ -203,6 +203,9 @@ class LLMConfig:
|
||||||
model_aliases: dict[str, str] = field(default_factory=dict)
|
model_aliases: dict[str, str] = field(default_factory=dict)
|
||||||
fallbacks: dict[str, list[str]] = field(default_factory=dict)
|
fallbacks: dict[str, list[str]] = field(default_factory=dict)
|
||||||
cache: CacheConfig | None = None
|
cache: CacheConfig | None = None
|
||||||
|
# G4/U1: Auxiliary model alias for cost-sensitive tasks (e.g. summarization).
|
||||||
|
# Resolved via existing model_aliases mechanism. None = use main model only.
|
||||||
|
auxiliary_model: str | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict) -> "LLMConfig":
|
def from_dict(cls, data: dict) -> "LLMConfig":
|
||||||
|
|
@ -254,4 +257,5 @@ class LLMConfig:
|
||||||
model_aliases=data.get("model_aliases", {}),
|
model_aliases=data.get("model_aliases", {}),
|
||||||
fallbacks=data.get("fallbacks", {}),
|
fallbacks=data.get("fallbacks", {}),
|
||||||
cache=cache,
|
cache=cache,
|
||||||
|
auxiliary_model=data.get("auxiliary_model"),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -320,6 +320,9 @@ class ServerConfig:
|
||||||
model_aliases=model_aliases,
|
model_aliases=model_aliases,
|
||||||
fallbacks=data.get("fallbacks", {}),
|
fallbacks=data.get("fallbacks", {}),
|
||||||
cache=cache_config,
|
cache=cache_config,
|
||||||
|
# G4/U1: auxiliary model alias for cost-sensitive summarization.
|
||||||
|
# Resolved via model_aliases; None = no auxiliary routing.
|
||||||
|
auxiliary_model=data.get("auxiliary_model"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,400 @@
|
||||||
|
"""G4/U1 — Auxiliary LLM routing in ContextCompressor.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
- auxiliary_model routes _summarize through the cheaper model first
|
||||||
|
- empty content (Finding 4 anti-pattern) triggers fallback to main model
|
||||||
|
- auxiliary exception triggers fallback to main model
|
||||||
|
- both auxiliary and main failing falls through to _simple_summary
|
||||||
|
- auxiliary_model=None preserves existing single-model behavior (characterization)
|
||||||
|
- config wiring (LLMConfig.from_dict, ServerConfig._build_llm_config)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from agentkit.core.compressor import ContextCompressor
|
||||||
|
from agentkit.llm.config import LLMConfig
|
||||||
|
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def make_gateway_with_response(content: str, model: str = "test") -> MagicMock:
    """Build a mocked LLMGateway whose ``chat`` always resolves to one canned reply.

    The same ``LLMResponse`` (carrying ``content`` and ``model``) is returned
    for every awaited ``chat`` call, regardless of the arguments it receives.
    """
    from agentkit.llm.gateway import LLMGateway

    canned_reply = LLMResponse(
        content=content,
        model=model,
        usage=TokenUsage(prompt_tokens=10, completion_tokens=10),
    )
    mocked = MagicMock(spec=LLMGateway)
    mocked.chat = AsyncMock(return_value=canned_reply)
    return mocked
|
||||||
|
|
||||||
|
|
||||||
|
def make_gateway_side_effect(
    responses_by_model: dict[str, list[LLMResponse | Exception]],
) -> MagicMock:
    """Build a mocked LLMGateway whose ``chat`` replies depend on the requested model.

    Args:
        responses_by_model: Maps a model name to the ordered list of outcomes
            for that model. Each entry is either a response to return or an
            exception instance to raise.

    Returns:
        A ``MagicMock`` (spec'd on ``LLMGateway``) whose ``chat`` is an
        ``AsyncMock``. Each ``chat(model=X)`` call consumes the next queued
        outcome for ``X``, so repeated calls to one model can behave
        differently.

    The mocked ``chat`` raises ``ValueError`` when asked for a model that has
    no queue, or when that model's queue has already been drained.
    """
    from agentkit.llm.gateway import LLMGateway

    gateway = MagicMock(spec=LLMGateway)
    # Copy every list so draining a queue never mutates the caller's fixture data.
    queues = {m: list(rs) for m, rs in responses_by_model.items()}

    async def chat_side_effect(*, messages, model, **kwargs):
        queue = queues.get(model)
        if queue is None:
            raise ValueError(f"unexpected model={model}")
        if not queue:
            raise ValueError(f"queue for model={model} exhausted")
        item = queue.pop(0)
        if isinstance(item, Exception):
            # Queued exceptions simulate a failing provider call.
            raise item
        return item

    gateway.chat = AsyncMock(side_effect=chat_side_effect)
    return gateway
|
||||||
|
|
||||||
|
|
||||||
|
def make_long_messages(count: int = 4, content_length: int = 2000) -> list[dict]:
    """Build a conversation long enough that the tests expect compression to kick in.

    The result is one system message, then ``count`` user/assistant pairs whose
    bodies are ``content_length`` filler characters plus an index suffix, then
    a short final user/assistant exchange.
    """
    history = [{"role": "system", "content": "You are a helpful assistant."}]
    for turn in range(count):
        history += [
            {"role": "user", "content": "x" * content_length + f" m{turn}"},
            {"role": "assistant", "content": "y" * content_length + f" r{turn}"},
        ]
    # Short trailing exchange, appended after all the padded turns.
    history += [
        {"role": "user", "content": "recent question"},
        {"role": "assistant", "content": "recent answer"},
    ]
    return history
|
||||||
|
|
||||||
|
|
||||||
|
# ── Characterization: auxiliary_model=None preserves existing behavior ──
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuxiliaryNoneCharacterization:
    """Characterization tests for the default ``auxiliary_model=None`` path."""

    async def test_no_auxiliary_calls_main_once(self):
        """With no auxiliary model, exactly one chat call is made, using the main model."""
        gateway = make_gateway_with_response("main summary")
        # auxiliary_model is deliberately omitted so it takes its default of None.
        compressor = ContextCompressor(
            llm_gateway=gateway, max_tokens=100, keep_recent=2, model="main"
        )

        compressed = await compressor.compress(make_long_messages())

        gateway.chat.assert_awaited_once()
        assert gateway.chat.await_args.kwargs.get("model") == "main"
        # The gateway's reply must surface inside a "Conversation Summary" system message.
        summaries = [
            msg
            for msg in compressed
            if msg.get("role") == "system" and "Conversation Summary" in msg.get("content", "")
        ]
        assert any("main summary" in msg["content"] for msg in summaries)

    async def test_main_failure_falls_to_simple_summary(self):
        """When the main model raises, a summary message is still produced."""
        failing_gateway = MagicMock()
        failing_gateway.chat = AsyncMock(side_effect=Exception("main LLM error"))
        compressor = ContextCompressor(
            llm_gateway=failing_gateway, max_tokens=100, keep_recent=2, model="main"
        )

        compressed = await compressor.compress(make_long_messages())

        summaries = [
            msg
            for msg in compressed
            if msg.get("role") == "system" and "Conversation Summary" in msg.get("content", "")
        ]
        assert len(summaries) == 1
        # The "..." marker is how this test recognises the _simple_summary fallback output.
        assert "..." in summaries[0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── New behavior: auxiliary routing ──────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuxiliaryRouting:
    """auxiliary_model set and different from the main model → auxiliary is tried first."""

    # ── private helpers (not collected by pytest: no ``test`` prefix) ──

    @staticmethod
    def _reply(content: str, model: str, completion_tokens: int = 1) -> LLMResponse:
        """Build a minimal LLMResponse for the queue-based gateway mock."""
        return LLMResponse(
            content=content,
            model=model,
            usage=TokenUsage(prompt_tokens=1, completion_tokens=completion_tokens),
        )

    @staticmethod
    def _make_compressor(gateway, auxiliary_model: str = "fast") -> ContextCompressor:
        """Build the compressor under test: main model "main", small budget, keep 2 recent."""
        return ContextCompressor(
            llm_gateway=gateway,
            max_tokens=100,
            keep_recent=2,
            model="main",
            auxiliary_model=auxiliary_model,
        )

    @staticmethod
    def _calls_for(gateway, model: str) -> list:
        """Return the awaited ``chat`` calls whose ``model`` keyword equals ``model``."""
        return [c for c in gateway.chat.await_args_list if c.kwargs.get("model") == model]

    @staticmethod
    def _summary_messages(result: list[dict]) -> list[dict]:
        """Return the system messages that carry a "Conversation Summary" section."""
        return [
            m
            for m in result
            if m.get("role") == "system" and "Conversation Summary" in m.get("content", "")
        ]

    # ── tests ──

    async def test_auxiliary_success_returns_auxiliary_content(self):
        """A non-empty auxiliary reply is used and the main model is never called."""
        gateway = make_gateway_side_effect(
            {
                "fast": [self._reply("aux summary", "fast")],
                "main": [self._reply("MAIN SHOULD NOT BE USED", "main")],
            }
        )
        compressor = self._make_compressor(gateway)

        result = await compressor.compress(make_long_messages())

        assert len(self._calls_for(gateway, "fast")) == 1
        assert len(self._calls_for(gateway, "main")) == 0
        assert any("aux summary" in m["content"] for m in self._summary_messages(result))

    async def test_empty_content_triggers_main_fallback(self):
        """Finding 4 anti-pattern: empty content is a failure, not a success."""
        gateway = make_gateway_side_effect(
            {
                "fast": [self._reply("", "fast", completion_tokens=0)],
                "main": [self._reply("main summary", "main")],
            }
        )
        compressor = self._make_compressor(gateway)

        result = await compressor.compress(make_long_messages())

        # Auxiliary is tried once (empty reply), then main is called once as fallback.
        assert len(self._calls_for(gateway, "fast")) == 1
        assert len(self._calls_for(gateway, "main")) == 1
        assert any("main summary" in m["content"] for m in self._summary_messages(result))

    async def test_whitespace_content_triggers_main_fallback(self):
        """Whitespace-only content also counts as empty (Finding 4)."""
        gateway = make_gateway_side_effect(
            {
                "fast": [self._reply(" \n ", "fast", completion_tokens=0)],
                "main": [self._reply("main summary", "main")],
            }
        )
        compressor = self._make_compressor(gateway)

        result = await compressor.compress(make_long_messages())

        assert len(self._calls_for(gateway, "fast")) == 1
        assert len(self._calls_for(gateway, "main")) == 1
        # Same outcome check as the empty-content case: the main reply must be surfaced.
        assert any("main summary" in m["content"] for m in self._summary_messages(result))

    async def test_auxiliary_exception_triggers_main_fallback(self):
        """An exception from the auxiliary call is absorbed and the main model is used."""
        from agentkit.core.exceptions import LLMProviderError

        gateway = make_gateway_side_effect(
            {
                "fast": [LLMProviderError("aux", "provider down")],
                "main": [self._reply("main summary", "main")],
            }
        )
        compressor = self._make_compressor(gateway)

        result = await compressor.compress(make_long_messages())

        assert len(self._calls_for(gateway, "fast")) == 1
        assert len(self._calls_for(gateway, "main")) == 1
        assert any("main summary" in m["content"] for m in self._summary_messages(result))

    async def test_both_fail_falls_to_simple_summary(self):
        """Auxiliary raises, main raises → existing _simple_summary degradation."""
        # The aggressive compression path may invoke _summarize several times, so
        # each queue holds spare failures. A fresh exception object is built per
        # slot so that no single instance is raised more than once.
        gateway = make_gateway_side_effect(
            {
                "fast": [Exception("aux boom") for _ in range(5)],
                "main": [Exception("main boom") for _ in range(5)],
            }
        )
        compressor = self._make_compressor(gateway)

        result = await compressor.compress(make_long_messages())

        assert len(self._calls_for(gateway, "fast")) >= 1
        assert len(self._calls_for(gateway, "main")) >= 1
        # The "..." marker is how this test recognises the _simple_summary fallback output.
        summary_msgs = self._summary_messages(result)
        assert len(summary_msgs) == 1
        assert "..." in summary_msgs[0]["content"]

    async def test_auxiliary_equal_to_main_skipped(self):
        """auxiliary_model == model → no auxiliary routing (single call to main)."""
        gateway = make_gateway_with_response("main summary")
        compressor = self._make_compressor(gateway, auxiliary_model="main")

        await compressor.compress(make_long_messages())

        # Only one call (to main); the auxiliary block is skipped.
        assert gateway.chat.await_count == 1
        assert gateway.chat.await_args.kwargs.get("model") == "main"

    async def test_audit_fields_preserved(self):
        """Auxiliary call uses agent_name='compressor', task_type='summarization'."""
        # This mock returns the same non-empty reply for any model, so the first
        # (auxiliary) call succeeds and no fallback call is expected.
        gateway = make_gateway_with_response("aux summary")
        compressor = self._make_compressor(gateway)

        await compressor.compress(make_long_messages())

        # Pin that the inspected call really is the single auxiliary call, so the
        # audit-field assertions below cannot pass by looking at a main-model call.
        assert gateway.chat.await_count == 1
        call_kwargs = gateway.chat.await_args.kwargs
        assert call_kwargs.get("model") == "fast"
        assert call_kwargs.get("agent_name") == "compressor"
        assert call_kwargs.get("task_type") == "summarization"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Config wiring ────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigWiring:
    """Check that ``auxiliary_model`` is picked up from dict input by both config builders."""

    def test_llm_config_from_dict_reads_auxiliary_model(self):
        """LLMConfig.from_dict copies an explicit ``auxiliary_model`` value."""
        raw = {
            "providers": {},
            "model_aliases": {"fast": "p/m"},
            "auxiliary_model": "fast",
        }
        assert LLMConfig.from_dict(raw).auxiliary_model == "fast"

    def test_llm_config_from_dict_auxiliary_none_when_absent(self):
        """A dict without the key yields ``auxiliary_model is None``."""
        parsed = LLMConfig.from_dict({"providers": {}})
        assert parsed.auxiliary_model is None

    def test_llm_config_default_auxiliary_none(self):
        """A default-constructed LLMConfig has no auxiliary model."""
        assert LLMConfig().auxiliary_model is None

    def test_server_config_build_llm_config_reads_auxiliary_model(self):
        """ServerConfig._build_llm_config passes ``auxiliary_model`` through."""
        from agentkit.server.config import ServerConfig

        provider = {
            "type": "openai",
            "api_key": "k",
            "base_url": "http://x",
            "models": {"m": {"alias": "fast"}},
        }
        built = ServerConfig._build_llm_config(
            {"providers": {"p": provider}, "auxiliary_model": "fast"}
        )

        assert built.auxiliary_model == "fast"
        # The alias table must still be derived from the provider's model entries.
        assert built.model_aliases.get("fast") == "p/m"

    def test_server_config_build_llm_config_auxiliary_none_when_absent(self):
        """Without the key, the built config has ``auxiliary_model is None``."""
        from agentkit.server.config import ServerConfig

        built = ServerConfig._build_llm_config({"providers": {}})
        assert built.auxiliary_model is None
|
||||||
Loading…
Reference in New Issue