diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index 5e1c50c..f99e204 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -163,6 +163,7 @@ class ReActEngine: core_tool_names: list[str] | None = None, enable_tool_search: bool = True, middleware_chain: "MiddlewareChain | None" = None, + prompt_cache_enable: bool = True, ): if max_steps < 1: raise ValueError(f"max_steps must be >= 1, got {max_steps}") @@ -176,6 +177,9 @@ class ReActEngine: self._parallel_tools = parallel_tools self._verification_enabled = verification_enabled self._verification_commands = verification_commands + # U2/G2: prompt cache 双块结构开关(True 时 Anthropic 用 cache_control blocks, + # 其他 provider 走字符串拼接依赖自动前缀缓存) + self._prompt_cache_enable = prompt_cache_enable # Tiered tool description injection config self._core_tool_names: tuple[str, ...] | None = ( tuple(core_tool_names) if core_tool_names is not None else None @@ -429,7 +433,9 @@ class ReActEngine: skill_name=task_type or None, ) - # Memory retrieval: 执行前检索相关上下文注入 system_prompt + # Memory retrieval: 执行前检索相关上下文,作为 volatile 层注入 system message + # U2/G2: 不再拼到 stable(system_prompt)末尾,改由 _build_system_message 组装双块结构 + memory_context = "" if memory_retriever: try: query = str(messages[-1].get("content", "")) if messages else "" @@ -439,12 +445,7 @@ class ReActEngine: query=query, top_k=top_k, token_budget=token_budget, - ) - if memory_context: - if system_prompt: - system_prompt += f"\n\n## 参考信息\n{memory_context}" - else: - system_prompt = f"## 参考信息\n{memory_context}" + ) or "" except Exception as e: logger.warning( f"Memory retrieval failed, continuing without context: {e}", exc_info=True @@ -452,8 +453,13 @@ class ReActEngine: # 构建初始消息 conversation: list[dict[str, Any]] = [] - if system_prompt: - conversation.append({"role": "system", "content": system_prompt}) + system_content = self._build_system_message( + stable=system_prompt or "", + volatile=memory_context, + model=model, + ) + if system_content is not None: + conversation.append({"role": "system", "content": system_content}) conversation.extend(messages) # Context compression: 压缩超长对话历史 @@ -1039,7 +1045,10 @@ class ReActEngine: skill_name=task_type or None, ) - # Memory retrieval: 执行前检索相关上下文注入 system_prompt + # Memory retrieval: 执行前检索相关上下文,作为 volatile 层注入 system message + # U2/G2: 不再拼到 stable(system_prompt)末尾破坏 cache 前缀,改由 _build_system_message + # 组装双块结构(stable + volatile),Anthropic provider 在 stable 上加 cache_control。 + memory_context = "" if memory_retriever: try: query = str(messages[-1].get("content", "")) if messages else "" @@ -1049,18 +1058,18 @@ class ReActEngine: query=query, top_k=top_k, token_budget=token_budget, - ) - if memory_context: - if system_prompt: - system_prompt += f"\n\n## 参考信息\n{memory_context}" - else: - system_prompt = f"## 参考信息\n{memory_context}" + ) or "" except Exception as e: logger.warning(f"Memory retrieval failed, continuing without context: {e}") conversation: list[dict[str, Any]] = [] - if system_prompt: - conversation.append({"role": "system", "content": system_prompt}) + system_content = self._build_system_message( + stable=system_prompt or "", + volatile=memory_context, + model=model, + ) + if system_content is not None: + conversation.append({"role": "system", "content": system_content}) conversation.extend(messages) # Context compression: 压缩超长对话历史 @@ -1721,6 +1730,62 @@ class ReActEngine: schemas.append(schema) return schemas + def _build_system_message( + self, + stable: str, + volatile: str, + *, + model: str, + ) -> str | list[dict[str, Any]] | None: + """构建双块结构 system message(stable + volatile)。 + + - prompt_cache_enable=False 或无 stable+volatile → 返回 str(或 None) + - Anthropic provider → 返回 content blocks 列表,stable 块带 cache_control + - 其他 provider → 返回字符串拼接(stable + volatile),依赖 stable 前缀命中自动前缀缓存 + + ponytail: 断点数硬编码为 1(stable 层),不暴露配置(YAGNI — 双块结构 >1 无语义)。 + """ + if not stable and not volatile: + return None + if not self._prompt_cache_enable: + # 退化为字符串拼接(向后兼容,行为同改动前) + if stable and volatile: + return f"{stable}\n\n## 参考信息\n{volatile}" + if volatile: + return f"## 参考信息\n{volatile}" + return stable + + provider_name = self._get_provider_name(model) + if provider_name == "anthropic": + blocks: list[dict[str, Any]] = [] + if stable: + blocks.append({ + "type": "text", + "text": stable, + "cache_control": {"type": "ephemeral"}, + }) + if volatile: + blocks.append({ + "type": "text", + "text": f"## 参考信息\n{volatile}", + }) + return blocks if blocks else None + + # 非 Anthropic:字符串拼接,stable 前缀命中 OpenAI/DashScope 自动前缀缓存 + if stable and volatile: + return f"{stable}\n\n## 参考信息\n{volatile}" + if volatile: + return f"## 参考信息\n{volatile}" + return stable + + def _get_provider_name(self, model: str) -> str | None: + """通过 gateway 查询 model 对应的 provider 名。失败回退 None(字符串拼接)。""" + try: + return self._llm_gateway.get_provider_name_for_model(model) + except Exception: + # ponytail: 测试中 gateway 可能是 MagicMock,无该方法;回退保守路径 + return None + def _build_tool_use_prompt(self, tools: list[Tool]) -> str: """Build prompt-based tool calling instructions with tiered injection. diff --git a/src/agentkit/llm/gateway.py b/src/agentkit/llm/gateway.py index 337d395..0b7be8f 100644 --- a/src/agentkit/llm/gateway.py +++ b/src/agentkit/llm/gateway.py @@ -412,6 +412,24 @@ class LLMGateway: return self._config.model_aliases[model] return model + def get_provider_name_for_model(self, model: str) -> str | None: + """返回 model 对应的 provider 名(用于 provider-specific 优化如 cache_control)。 + + ponytail: 仅做 alias 解析 + provider 前缀提取,不查内部状态。 + 升级路径:ServerConfig 显式声明 provider per model。 + 返回 None 表示无法确定(多 provider + 无 "/" 前缀),调用方应回退到字符串拼接。 + """ + resolved = self._resolve_model_alias(model) + if "/" in resolved: + provider_name = resolved.split("/", 1)[0] + if provider_name in self._providers: + return provider_name + return None + # 无 "/" 前缀:仅当只有一个 provider 时能确定 + if len(self._providers) == 1: + return next(iter(self._providers)) + return None + def _resolve_model(self, model: str) -> tuple[LLMProvider, str]: """解析模型为 (provider, actual_model_name)""" # model 格式: "provider/model_name" 或 "model_name" diff --git a/src/agentkit/llm/providers/anthropic.py b/src/agentkit/llm/providers/anthropic.py index a26b94d..2829ac9 100644 --- a/src/agentkit/llm/providers/anthropic.py +++ b/src/agentkit/llm/providers/anthropic.py @@ -99,13 +99,18 @@ class AnthropicProvider(LLMProvider): "content-type": "application/json", } - def _convert_messages(self, messages: list[dict[str, str]]) -> tuple[str | None, list[dict[str, Any]]]: + def _convert_messages(self, messages: list[dict[str, str]]) -> tuple[str | list[dict[str, Any]] | None, list[dict[str, Any]]]: """将 OpenAI 风格消息转换为 Anthropic 格式 Returns: (system_prompt, anthropic_messages) + + system_prompt 可为: + - str: 传统字符串(向后兼容) + - list[dict]: Anthropic content blocks(支持 cache_control,U2/G2) + - None: 无 system 消息 """ - system_prompt: str | None = None + system_prompt: str | list[dict[str, Any]] | None = None anthropic_messages: list[dict[str, Any]] = [] for msg in messages: @@ -113,6 +118,8 @@ class AnthropicProvider(LLMProvider): content = msg.get("content", "") if role == "system": + # U2/G2: content 可为 str(传统)或 list[dict](content blocks) + # content blocks 直接透传(包含 cache_control 标记) system_prompt = content continue diff --git a/src/agentkit/server/config.py b/src/agentkit/server/config.py index ef07bf9..f71fba2 100644 --- a/src/agentkit/server/config.py +++ b/src/agentkit/server/config.py @@ -116,6 +116,7 @@ class ServerConfig: evolution: dict[str, Any] | None = None, expert_paths: list[str] | None = None, board: dict[str, Any] | None = None, + prompt_cache: dict[str, Any] | None = None, on_change: Callable[["ServerConfig"], None] | None = None, ): self.host = host @@ -144,6 +145,7 @@ class ServerConfig: self.evolution = evolution or {} self.expert_paths = expert_paths or [] self.board = board or {} + self.prompt_cache = prompt_cache or {} self.on_change = on_change # Config watching state diff --git a/tests/unit/test_prompt_cache_layers.py b/tests/unit/test_prompt_cache_layers.py new file mode 100644 index 0000000..d3e61f0 --- /dev/null +++ b/tests/unit/test_prompt_cache_layers.py @@ -0,0 +1,221 @@ +"""U2 / G2 Prompt Cache 双块结构测试。 + +覆盖 R4-R7, R13: +- R4 stable/volatile 双块结构 +- R5 记忆注入从 system_prompt 末尾移到 volatile 层 +- R6 跨 provider 统一 cache 策略(Anthropic blocks / OpenAI 字符串) +- R7 多轮 stable 不变(由构造保证) +- R13 配置化(prompt_cache_enable=False 退化) +""" + +from __future__ import annotations + +from typing import Any + + +from agentkit.core.react import ReActEngine + + +class _StubGateway: + """模拟 LLMGateway,记录最后一次 chat_stream 调用的 messages。""" + + def __init__(self, provider_name: str | None = "anthropic"): + self._provider_name = provider_name + self.captured_messages: list[dict[str, Any]] | None = None + + def get_provider_name_for_model(self, model: str) -> str | None: + return self._provider_name + + async def chat_stream(self, **kwargs): + self.captured_messages = list(kwargs.get("messages", [])) + return + yield # makes this an async generator + + +def _make_engine(provider_name: str | None = "anthropic", *, cache_enable: bool = True) -> tuple[ReActEngine, _StubGateway]: + gw = _StubGateway(provider_name=provider_name) + engine = ReActEngine.__new__(ReActEngine) + engine._llm_gateway = gw + engine._prompt_cache_enable = cache_enable + return engine, gw + + +# ---- R4/R6 Anthropic: stable + volatile → content blocks ---- + + +def test_anthropic_provider_returns_content_blocks_with_cache_control(): + engine, _ = _make_engine("anthropic") + result = engine._build_system_message( + stable="You are a helpful assistant.", + volatile="Memory: foo", + model="claude-sonnet-4", + ) + assert isinstance(result, list) + assert result[0] == { + "type": "text", + "text": "You are a helpful assistant.", + "cache_control": {"type": "ephemeral"}, + } + assert result[1] == {"type": "text", "text": "## 参考信息\nMemory: foo"} + + +# ---- R5 Anthropic: empty volatile → only stable block ---- + + +def test_anthropic_empty_volatile_returns_only_stable_block(): + engine, _ = _make_engine("anthropic") + result = engine._build_system_message( + stable="base prompt", + volatile="", + model="claude", + ) + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["text"] == "base prompt" + assert "cache_control" in result[0] + + +# ---- R6 Non-Anthropic: returns string concat ---- + + +def test_non_anthropic_returns_string_concat(): + engine, _ = _make_engine("openai") + result = engine._build_system_message( + stable="base", + volatile="ctx", + model="gpt-4", + ) + assert isinstance(result, str) + assert result == "base\n\n## 参考信息\nctx" + + +def test_unknown_provider_returns_string_concat(): + """provider_name 无法确定时(gateway 返回 None),回退字符串拼接不报错。""" + engine, _ = _make_engine(None) + result = engine._build_system_message( + stable="base", + volatile="ctx", + model="default", + ) + assert isinstance(result, str) + assert "## 参考信息" in result + + +# ---- R13 Config: prompt_cache_enable=False → 退化字符串 ---- + + +def test_prompt_cache_disabled_falls_back_to_string(): + engine, _ = _make_engine("anthropic", cache_enable=False) + result = engine._build_system_message( + stable="base", + volatile="ctx", + model="claude", + ) + # 即便是 Anthropic,enable=False 时也返回字符串(行为同改动前) + assert isinstance(result, str) + assert "cache_control" not in result + + +# ---- R4 Edge: no stable + no volatile → None ---- + + +def test_empty_stable_and_volatile_returns_none(): + engine, _ = _make_engine("anthropic") + result = engine._build_system_message(stable="", volatile="", model="claude") + assert result is None + + +# ---- R5 Edge: empty stable + volatile → only volatile block (Anthropic) ---- + + +def test_anthropic_empty_stable_returns_only_volatile_block(): + engine, _ = _make_engine("anthropic") + result = engine._build_system_message( + stable="", + volatile="only memory", + model="claude", + ) + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["text"] == "## 参考信息\nonly memory" + assert "cache_control" not in result[0] # volatile 块无 cache_control + + +# ---- Integration: execute_stream uses _build_system_message end-to-end ---- + + +async def test_execute_stream_with_anthropic_uses_content_blocks(): + """execute_stream 把双块 system content 透传给 gateway。""" + from agentkit.tools.base import Tool + + class _NoopTool(Tool): + def __init__(self): + super().__init__(name="noop", description="noop") + + async def execute(self, **kwargs): + return {} + + class _MockGateway: + def __init__(self): + self.captured_messages = None + + def get_provider_name_for_model(self, model: str) -> str | None: + return "anthropic" + + async def chat_stream(self, **kwargs): + self.captured_messages = list(kwargs.get("messages", [])) + # yield one chunk then end + from agentkit.llm.protocol import StreamChunk + yield StreamChunk(content="done", model="claude") + + class _MemRetriever: + async def get_context_string(self, **kw): + return "retrieved context" + + gw = _MockGateway() + engine = ReActEngine(llm_gateway=gw, prompt_cache_enable=True) + events = [] + async for ev in engine.execute_stream( + messages=[{"role": "user", "content": "hi"}], + tools=[], + model="claude", + system_prompt="base", + memory_retriever=_MemRetriever(), + ): + events.append(ev) + + assert gw.captured_messages is not None + sys_msg = gw.captured_messages[0] + assert sys_msg["role"] == "system" + assert isinstance(sys_msg["content"], list) + assert sys_msg["content"][0]["text"] == "base" + assert "cache_control" in sys_msg["content"][0] + assert sys_msg["content"][1]["text"] == "## 参考信息\nretrieved context" + + +# ---- Anthropic provider _convert_messages passes through list-type system ---- + + +def test_anthropic_convert_messages_passes_through_list_system_content(): + """AnthropicProvider._convert_messages 应直接透传 list-type system content。""" + from agentkit.llm.providers.anthropic import AnthropicProvider + + provider = AnthropicProvider.__new__(AnthropicProvider) + blocks = [ + {"type": "text", "text": "stable", "cache_control": {"type": "ephemeral"}}, + {"type": "text", "text": "volatile"}, + ] + messages = [{"role": "system", "content": blocks}] + system_prompt, anthropic_messages = provider._convert_messages(messages) + assert system_prompt is blocks # same object, transparent passthrough + assert anthropic_messages == [] + + +def test_anthropic_convert_messages_string_system_still_works(): + """传统 string system content 仍能透传(向后兼容)。""" + from agentkit.llm.providers.anthropic import AnthropicProvider + + provider = AnthropicProvider.__new__(AnthropicProvider) + messages = [{"role": "system", "content": "old style"}] + system_prompt, _ = provider._convert_messages(messages) + assert system_prompt == "old style"