222 lines
7.1 KiB
Python
222 lines
7.1 KiB
Python
"""U2 / G2 Prompt Cache 双块结构测试。
|
|
|
|
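Covers R4-R7 and R13:
- R4: stable/volatile two-block structure
- R5: memory injection moved from the end of system_prompt to the volatile layer
- R6: unified cache strategy across providers (Anthropic blocks / OpenAI string)
- R7: the stable block does not change across turns (guaranteed by construction)
- R13: configurable (prompt_cache_enable=False falls back to the old behavior)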
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
|
|
from agentkit.core.react import ReActEngine
|
|
|
|
|
|
class _StubGateway:
|
|
"""模拟 LLMGateway,记录最后一次 chat_stream 调用的 messages。"""
|
|
|
|
def __init__(self, provider_name: str | None = "anthropic"):
|
|
self._provider_name = provider_name
|
|
self.captured_messages: list[dict[str, Any]] | None = None
|
|
|
|
def get_provider_name_for_model(self, model: str) -> str | None:
|
|
return self._provider_name
|
|
|
|
async def chat_stream(self, **kwargs):
|
|
self.captured_messages = list(kwargs.get("messages", []))
|
|
return
|
|
yield # makes this an async generator
|
|
|
|
|
|
def _make_engine(provider_name: str | None = "anthropic", *, cache_enable: bool = True) -> tuple[ReActEngine, _StubGateway]:
    """Build a bare ReActEngine wired to a fresh stub gateway.

    ``ReActEngine.__new__`` is used so that ``__init__`` is bypassed; only the
    two attributes assigned here are set on the engine.

    Returns:
        ``(engine, gateway)`` so tests can drive the engine and inspect the stub.
    """
    stub = _StubGateway(provider_name=provider_name)
    bare_engine = ReActEngine.__new__(ReActEngine)
    bare_engine._llm_gateway = stub
    bare_engine._prompt_cache_enable = cache_enable
    return bare_engine, stub
|
|
|
|
|
|
# ---- R4/R6 Anthropic: stable + volatile → content blocks ----
|
|
|
|
|
|
def test_anthropic_provider_returns_content_blocks_with_cache_control():
    """Anthropic with stable and volatile text yields two content blocks.

    The first block must carry ``cache_control``; the second holds the
    volatile text under the reference-info heading and has no such key.
    """
    expected_stable_block = {
        "type": "text",
        "text": "You are a helpful assistant.",
        "cache_control": {"type": "ephemeral"},
    }
    expected_volatile_block = {"type": "text", "text": "## 参考信息\nMemory: foo"}

    engine, _gateway = _make_engine("anthropic")
    system_content = engine._build_system_message(
        stable="You are a helpful assistant.",
        volatile="Memory: foo",
        model="claude-sonnet-4",
    )

    assert isinstance(system_content, list)
    assert system_content[0] == expected_stable_block
    assert system_content[1] == expected_volatile_block
|
|
|
|
|
|
# ---- R5 Anthropic: empty volatile → only stable block ----
|
|
|
|
|
|
def test_anthropic_empty_volatile_returns_only_stable_block():
    """Anthropic with an empty volatile part yields a single stable block."""
    engine, _gateway = _make_engine("anthropic")
    call_args = {"stable": "base prompt", "volatile": "", "model": "claude"}

    system_content = engine._build_system_message(**call_args)

    assert isinstance(system_content, list)
    assert len(system_content) == 1
    only_block = system_content[0]
    assert only_block["text"] == "base prompt"
    # The lone stable block must still be marked for caching.
    assert "cache_control" in only_block
|
|
|
|
|
|
# ---- R6 Non-Anthropic: returns string concat ----
|
|
|
|
|
|
def test_non_anthropic_returns_string_concat():
    """A non-Anthropic provider gets stable and volatile joined into one string."""
    expected_text = "base\n\n## 参考信息\nctx"

    engine, _gateway = _make_engine("openai")
    system_content = engine._build_system_message(
        stable="base", volatile="ctx", model="gpt-4"
    )

    assert isinstance(system_content, str)
    assert system_content == expected_text
|
|
|
|
|
|
def test_unknown_provider_returns_string_concat():
    """When the gateway returns None for the provider name, fall back to a string.

    The call must not raise, and the result still contains the
    reference-info heading.
    """
    engine, _gateway = _make_engine(None)
    call_args = {"stable": "base", "volatile": "ctx", "model": "default"}

    system_content = engine._build_system_message(**call_args)

    assert isinstance(system_content, str)
    assert "## 参考信息" in system_content
|
|
|
|
|
|
# ---- R13 Config: prompt_cache_enable=False → falls back to a plain string ----
|
|
|
|
|
|
def test_prompt_cache_disabled_falls_back_to_string():
    """With cache_enable=False the system message is a plain string, even for Anthropic."""
    engine, _gateway = _make_engine("anthropic", cache_enable=False)

    system_content = engine._build_system_message(
        stable="base", volatile="ctx", model="claude"
    )

    # Even for Anthropic, enable=False returns a string (same behavior as
    # before the change), so no cache_control marker may appear in the text.
    assert isinstance(system_content, str)
    assert "cache_control" not in system_content
|
|
|
|
|
|
# ---- R4 Edge: no stable + no volatile → None ----
|
|
|
|
|
|
def test_empty_stable_and_volatile_returns_none():
    """With both stable and volatile empty, no system message is built (None)."""
    engine, _gateway = _make_engine("anthropic")
    call_args = {"stable": "", "volatile": "", "model": "claude"}

    system_content = engine._build_system_message(**call_args)

    assert system_content is None
|
|
|
|
|
|
# ---- R5 Edge: empty stable + volatile → only volatile block (Anthropic) ----
|
|
|
|
|
|
def test_anthropic_empty_stable_returns_only_volatile_block():
    """Anthropic with an empty stable part yields a single volatile block."""
    engine, _gateway = _make_engine("anthropic")

    system_content = engine._build_system_message(
        stable="", volatile="only memory", model="claude"
    )

    assert isinstance(system_content, list)
    assert len(system_content) == 1
    only_block = system_content[0]
    assert only_block["text"] == "## 参考信息\nonly memory"
    # The volatile block carries no cache_control.
    assert "cache_control" not in only_block
|
|
|
|
|
|
# ---- Integration: execute_stream uses _build_system_message end-to-end ----
|
|
|
|
|
|
async def test_execute_stream_with_anthropic_uses_content_blocks():
    """execute_stream passes the two-block system content through to the gateway.

    NOTE(review): as a coroutine test this presumably relies on the project's
    pytest async configuration (e.g. an asyncio auto mode) — confirm.
    """
    from agentkit.llm.protocol import StreamChunk

    class _MockGateway:
        """Gateway double: always reports "anthropic" and records the messages."""

        def __init__(self):
            self.captured_messages = None

        def get_provider_name_for_model(self, model: str) -> str | None:
            return "anthropic"

        async def chat_stream(self, **kwargs):
            self.captured_messages = list(kwargs.get("messages", []))
            # Yield one chunk, then end the stream.
            yield StreamChunk(content="done", model="claude")

    class _MemRetriever:
        """Memory-retriever double that always returns the same context string."""

        async def get_context_string(self, **kw):
            return "retrieved context"

    gw = _MockGateway()
    engine = ReActEngine(llm_gateway=gw, prompt_cache_enable=True)

    # Drain the stream so the gateway gets called; the emitted events
    # themselves are not inspected by this test.
    async for _event in engine.execute_stream(
        messages=[{"role": "user", "content": "hi"}],
        tools=[],
        model="claude",
        system_prompt="base",
        memory_retriever=_MemRetriever(),
    ):
        pass

    assert gw.captured_messages is not None
    sys_msg = gw.captured_messages[0]
    assert sys_msg["role"] == "system"
    assert isinstance(sys_msg["content"], list)
    # First block: the stable system prompt, marked with cache_control.
    assert sys_msg["content"][0]["text"] == "base"
    assert "cache_control" in sys_msg["content"][0]
    # Second block: the retrieved memory under the reference-info heading.
    assert sys_msg["content"][1]["text"] == "## 参考信息\nretrieved context"
|
|
|
|
|
|
# ---- Anthropic provider _convert_messages passes through list-type system ----
|
|
|
|
|
|
def test_anthropic_convert_messages_passes_through_list_system_content():
    """AnthropicProvider._convert_messages should pass list-type system content straight through."""
    from agentkit.llm.providers.anthropic import AnthropicProvider

    system_blocks = [
        {"type": "text", "text": "stable", "cache_control": {"type": "ephemeral"}},
        {"type": "text", "text": "volatile"},
    ]
    # __new__ bypasses __init__, giving a bare provider instance.
    bare_provider = AnthropicProvider.__new__(AnthropicProvider)

    system_prompt, anthropic_messages = bare_provider._convert_messages(
        [{"role": "system", "content": system_blocks}]
    )

    # Identity, not just equality: the very same list object must come back.
    assert system_prompt is system_blocks
    assert anthropic_messages == []
|
|
|
|
|
|
def test_anthropic_convert_messages_string_system_still_works():
    """Traditional string system content is still passed through (backward compatibility)."""
    from agentkit.llm.providers.anthropic import AnthropicProvider

    # __new__ bypasses __init__, giving a bare provider instance.
    bare_provider = AnthropicProvider.__new__(AnthropicProvider)
    legacy_messages = [{"role": "system", "content": "old style"}]

    system_prompt, _rest = bare_provider._convert_messages(legacy_messages)

    assert system_prompt == "old style"
|