From 9c04362dba0404078de2517fef78375779e901d7 Mon Sep 17 00:00:00 2001
From: chiguyong
Date: Sun, 7 Jun 2026 18:20:17 +0800
Subject: [PATCH] feat(compression): U5 HeadroomRetrieveTool for CCR cache
 retrieval

Add HeadroomRetrieveTool that allows LLM to retrieve original uncompressed
data from CCR cache via Function Calling. Auto-registered when
HeadroomCompressor is active and available.
---
 src/agentkit/tools/__init__.py            |   7 +
 src/agentkit/tools/headroom_retrieve.py   |  88 ++++++++++
 tests/unit/test_headroom_retrieve_tool.py | 195 ++++++++++++++++++++++
 3 files changed, 290 insertions(+)
 create mode 100644 src/agentkit/tools/headroom_retrieve.py
 create mode 100644 tests/unit/test_headroom_retrieve_tool.py

diff --git a/src/agentkit/tools/__init__.py b/src/agentkit/tools/__init__.py
index 7ad2fa2..3aef0be 100644
--- a/src/agentkit/tools/__init__.py
+++ b/src/agentkit/tools/__init__.py
@@ -10,6 +10,12 @@ from agentkit.tools.web_crawl import WebCrawlTool
 from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
 from agentkit.tools.baidu_search import BaiduSearchTool
 
+# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
+try:
+    from agentkit.tools.headroom_retrieve import HeadroomRetrieveTool
+except ImportError:
+    HeadroomRetrieveTool = None  # type: ignore[misc,assignment]
+
 __all__ = [
     "Tool",
     "FunctionTool",
@@ -23,4 +29,5 @@ __all__ = [
     "SchemaExtractTool",
     "SchemaGenerateTool",
     "BaiduSearchTool",
+    "HeadroomRetrieveTool",
 ]
diff --git a/src/agentkit/tools/headroom_retrieve.py b/src/agentkit/tools/headroom_retrieve.py
new file mode 100644
index 0000000..71c6bd3
--- /dev/null
+++ b/src/agentkit/tools/headroom_retrieve.py
@@ -0,0 +1,88 @@
+"""HeadroomRetrieveTool: retrieval tool for CCR reversible compression.
+
+When HeadroomCompressor is enabled, the LLM can use this tool to fetch the
+original data behind compressed content from the CCR cache. A marker in the
+tool output indicates which content can be retrieved.
+"""
+
+import logging
+from typing import Any
+
+from agentkit.tools.base import Tool
+
+logger = logging.getLogger(__name__)
+
+
+class HeadroomRetrieveTool(Tool):
+    """Retrieve original, uncompressed data from the CCR cache.
+
+    After Headroom compresses a tool output, the LLM can call this tool to get
+    the original data back. The compressed content contains a marker whose
+    hash value the LLM can use for the lookup.
+    """
+
+    def __init__(self, compressor: Any):
+        """Pass the tool name, description and input schema to ``Tool``.
+
+        Args:
+            compressor: Object whose ``retrieve(ccr_hash=..., query=...)``
+                method is called by ``execute``.
+        """
+        super().__init__(
+            name="headroom_retrieve",
+            description=(
+                "Retrieve original uncompressed data from the CCR (Compress-Cache-Retrieve) cache. "
+                "Use this tool when you see a marker in compressed content "
+                "and need the full original data. Pass the hash value or a search query."
+            ),
+            input_schema={
+                "type": "object",
+                "properties": {
+                    "ccr_hash": {
+                        "type": "string",
+                        "description": "The CCR hash from a marker. Use this for direct lookup.",
+                    },
+                    "query": {
+                        "type": "string",
+                        "description": "Search query to find matching cached content. Used when hash is not available.",
+                    },
+                },
+                # At least one of the two lookup keys must be supplied.
+                "anyOf": [
+                    {"required": ["ccr_hash"]},
+                    {"required": ["query"]},
+                ],
+            },
+        )
+        self._compressor = compressor
+
+    async def execute(self, **kwargs) -> dict:
+        """Retrieve original data from the CCR cache.
+
+        Args:
+            **kwargs: Tool arguments. ``ccr_hash`` and ``query`` are read; at
+                least one of them must be non-empty.
+
+        Returns:
+            The value returned by ``compressor.retrieve`` on success, otherwise
+            a dict with ``"success": False`` and an ``"error"`` message.
+        """
+        ccr_hash = kwargs.get("ccr_hash")
+        query = kwargs.get("query")
+
+        if not ccr_hash and not query:
+            return {
+                "error": "Either ccr_hash or query must be provided",
+                "success": False,
+            }
+
+        try:
+            return self._compressor.retrieve(ccr_hash=ccr_hash, query=query)
+        except Exception as e:
+            # Tool boundary: report the failure in the result instead of
+            # raising; logger.exception also records the traceback.
+            logger.exception("CCR retrieval failed: %s", e)
+            return {
+                "error": f"CCR retrieval failed: {e}",
+                "success": False,
+            }
diff --git a/tests/unit/test_headroom_retrieve_tool.py b/tests/unit/test_headroom_retrieve_tool.py
new file mode 100644
index 0000000..e620e41
--- /dev/null
+++ b/tests/unit/test_headroom_retrieve_tool.py
@@ -0,0 +1,195 @@
+"""U5 tests: HeadroomRetrieveTool, the CCR reversible-compression retrieval tool.
+
+Covers construction, execution and registration logic of the headroom_retrieve tool.
+"""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from agentkit.tools.headroom_retrieve import HeadroomRetrieveTool
+from agentkit.tools.registry import ToolRegistry
+
+
+# ── TestHeadroomRetrieveToolConstruction ────────────────────
+
+
+class TestHeadroomRetrieveToolConstruction:
+    """Construction tests for HeadroomRetrieveTool."""
+
+    def test_name_and_description(self):
+        """The tool is named headroom_retrieve and its description mentions CCR."""
+        compressor = MagicMock()
+        tool = HeadroomRetrieveTool(compressor=compressor)
+
+        assert tool.name == "headroom_retrieve"
+        assert "CCR" in tool.description
+
+    def test_input_schema_has_ccr_hash(self):
+        """input_schema contains the ccr_hash property."""
+        compressor = MagicMock()
+        tool = HeadroomRetrieveTool(compressor=compressor)
+
+        assert "ccr_hash" in tool.input_schema["properties"]
+        assert tool.input_schema["properties"]["ccr_hash"]["type"] == "string"
+
+    def test_input_schema_has_query(self):
+        """input_schema contains the query property."""
+        compressor = MagicMock()
+        tool = HeadroomRetrieveTool(compressor=compressor)
+
+        assert "query" in tool.input_schema["properties"]
+        assert tool.input_schema["properties"]["query"]["type"] == "string"
+
+    def test_input_schema_requires_at_least_one(self):
+        """input_schema uses anyOf to require at least one of ccr_hash or query."""
+        compressor = MagicMock()
+        tool = HeadroomRetrieveTool(compressor=compressor)
+
+        assert "anyOf" in tool.input_schema
+        any_of = tool.input_schema["anyOf"]
+        # One entry requires ccr_hash, the other requires query
+        required_sets = [item["required"] for item in any_of]
+        assert ["ccr_hash"] in required_sets
+        assert ["query"] in required_sets
+
+
+# ── TestHeadroomRetrieveToolExecute ────────────────────────
+
+
+class TestHeadroomRetrieveToolExecute:
+    """Execution tests for HeadroomRetrieveTool."""
+
+    async def test_retrieve_by_hash(self):
+        """Retrieving by ccr_hash calls compressor.retrieve."""
+        compressor = MagicMock()
+        compressor.retrieve.return_value = {
+            "content": "original data",
+            "ccr_hash": "abc123",
+            "success": True,
+        }
+        tool = HeadroomRetrieveTool(compressor=compressor)
+
+        result = await tool.execute(ccr_hash="abc123")
+
+        compressor.retrieve.assert_called_once_with(ccr_hash="abc123", query=None)
+        assert result["success"] is True
+        assert result["content"] == "original data"
+
+    async def test_retrieve_by_query(self):
+        """Retrieving by query calls compressor.retrieve."""
+        compressor = MagicMock()
+        compressor.retrieve.return_value = {
+            "results": [{"ccr_hash": "h1", "content": "matched data"}],
+            "success": True,
+        }
+        tool = HeadroomRetrieveTool(compressor=compressor)
+
+        result = await tool.execute(query="search term")
+
+        compressor.retrieve.assert_called_once_with(ccr_hash=None, query="search term")
+        assert result["success"] is True
+        assert len(result["results"]) == 1
+
+    async def test_retrieve_both(self):
+        """When both ccr_hash and query are given, both are passed to compressor.retrieve."""
+        compressor = MagicMock()
+        compressor.retrieve.return_value = {
+            "content": "original data",
+            "ccr_hash": "abc123",
+            "success": True,
+        }
+        tool = HeadroomRetrieveTool(compressor=compressor)
+
+        result = await tool.execute(ccr_hash="abc123", query="search term")
+
+        compressor.retrieve.assert_called_once_with(ccr_hash="abc123", query="search term")
+        assert result["success"] is True
+
+    async def test_missing_both_params(self):
+        """An error is returned when neither ccr_hash nor query is given."""
+        compressor = MagicMock()
+        tool = HeadroomRetrieveTool(compressor=compressor)
+
+        result = await tool.execute()
+
+        assert result["success"] is False
+        assert "error" in result
+        assert "ccr_hash" in result["error"] or "query" in result["error"]
+
+    async def test_retrieve_failure(self):
+        """An error result is returned when compressor.retrieve raises."""
+        compressor = MagicMock()
+        compressor.retrieve.side_effect = RuntimeError("cache corrupted")
+        tool = HeadroomRetrieveTool(compressor=compressor)
+
+        result = await tool.execute(ccr_hash="abc123")
+
+        assert result["success"] is False
+        assert "error" in result
+        assert "cache corrupted" in result["error"]
+
+    async def test_successful_retrieval(self):
+        """A successful retrieval returns content and success=True."""
+        compressor = MagicMock()
+        compressor.retrieve.return_value = {
+            "content": "This is the original uncompressed data that was cached",
+            "ccr_hash": "deadbeef1234",
+            "success": True,
+        }
+        tool = HeadroomRetrieveTool(compressor=compressor)
+
+        result = await tool.execute(ccr_hash="deadbeef1234")
+
+        assert result["success"] is True
+        assert result["content"] == "This is the original uncompressed data that was cached"
+        assert result["ccr_hash"] == "deadbeef1234"
+
+
+# ── TestHeadroomRetrieveToolRegistration ────────────────────
+
+
+class TestHeadroomRetrieveToolRegistration:
+    """Registration tests for HeadroomRetrieveTool."""
+
+    def test_not_registered_when_no_compressor(self):
+        """The tool is not registered when there is no compressor."""
+        registry = ToolRegistry()
+
+        # Simulate: compressor is None → no registration
+        # (no tool should be registered)
+        assert not registry.has_tool("headroom_retrieve")
+
+    def test_not_registered_when_context_compressor(self):
+        """Not registered for a ContextCompressor (not a HeadroomCompressor)."""
+        from agentkit.core.compressor import ContextCompressor
+
+        registry = ToolRegistry()
+        # Create a ContextCompressor (not HeadroomCompressor)
+        compressor = ContextCompressor()
+
+        # Simulate the app.py logic: only register if HeadroomCompressor + is_available
+        from agentkit.core.headroom_compressor import HeadroomCompressor
+        if isinstance(compressor, HeadroomCompressor) and compressor.is_available():
+            tool = HeadroomRetrieveTool(compressor=compressor)
+            registry.register(tool)
+
+        assert not registry.has_tool("headroom_retrieve")
+
+    def test_registered_when_headroom_compressor(self):
+        """Registered for a HeadroomCompressor whose is_available() is True."""
+        from agentkit.core.headroom_compressor import HeadroomCompressor
+
+        registry = ToolRegistry()
+
+        # Create a real HeadroomCompressor instance but mock is_available
+        with patch.object(HeadroomCompressor, "is_available", return_value=True):
+            compressor = HeadroomCompressor(config={})
+            # Simulate the app.py logic
+            if isinstance(compressor, HeadroomCompressor) and compressor.is_available():
+                tool = HeadroomRetrieveTool(compressor=compressor)
+                registry.register(tool)
+
+        assert registry.has_tool("headroom_retrieve")
+        registered_tool = registry.get("headroom_retrieve")
+        assert isinstance(registered_tool, HeadroomRetrieveTool)