feat(compression): U5 HeadroomRetrieveTool for CCR cache retrieval
Add HeadroomRetrieveTool that allows LLM to retrieve original uncompressed data from CCR cache via Function Calling. Auto-registered when HeadroomCompressor is active and available.
This commit is contained in:
parent
286804792d
commit
9c04362dba
|
|
@ -10,6 +10,12 @@ from agentkit.tools.web_crawl import WebCrawlTool
|
||||||
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
|
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
|
||||||
from agentkit.tools.baidu_search import BaiduSearchTool
|
from agentkit.tools.baidu_search import BaiduSearchTool
|
||||||
|
|
||||||
|
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
||||||
|
try:
|
||||||
|
from agentkit.tools.headroom_retrieve import HeadroomRetrieveTool
|
||||||
|
except ImportError:
|
||||||
|
HeadroomRetrieveTool = None # type: ignore[misc,assignment]
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Tool",
|
"Tool",
|
||||||
"FunctionTool",
|
"FunctionTool",
|
||||||
|
|
@ -23,4 +29,5 @@ __all__ = [
|
||||||
"SchemaExtractTool",
|
"SchemaExtractTool",
|
||||||
"SchemaGenerateTool",
|
"SchemaGenerateTool",
|
||||||
"BaiduSearchTool",
|
"BaiduSearchTool",
|
||||||
|
"HeadroomRetrieveTool",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,70 @@
|
||||||
|
"""HeadroomRetrieveTool — CCR 可逆压缩检索工具
|
||||||
|
|
||||||
|
当 HeadroomCompressor 启用时,LLM 可通过此工具从 CCR 缓存中
|
||||||
|
取回被压缩的原始数据。工具输出中的 <!-- CCR:hash=xxx --> 标记
|
||||||
|
指示可检索的内容。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HeadroomRetrieveTool(Tool):
|
||||||
|
"""从 CCR 缓存检索原始未压缩数据
|
||||||
|
|
||||||
|
当 Headroom 压缩工具输出后,LLM 可通过此工具取回原始数据。
|
||||||
|
压缩内容中包含 <!-- CCR:hash=xxx --> 标记,LLM 可使用该哈希值检索。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, compressor: Any):
|
||||||
|
super().__init__(
|
||||||
|
name="headroom_retrieve",
|
||||||
|
description=(
|
||||||
|
"Retrieve original uncompressed data from the CCR (Compress-Cache-Retrieve) cache. "
|
||||||
|
"Use this tool when you see a <!-- CCR:hash=xxx --> marker in compressed content "
|
||||||
|
"and need the full original data. Pass the hash value or a search query."
|
||||||
|
),
|
||||||
|
input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"ccr_hash": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The CCR hash from a <!-- CCR:hash=xxx --> marker. Use this for direct lookup.",
|
||||||
|
},
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Search query to find matching cached content. Used when hash is not available.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"anyOf": [
|
||||||
|
{"required": ["ccr_hash"]},
|
||||||
|
{"required": ["query"]},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self._compressor = compressor
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict:
|
||||||
|
"""从 CCR 缓存检索原始数据"""
|
||||||
|
ccr_hash = kwargs.get("ccr_hash")
|
||||||
|
query = kwargs.get("query")
|
||||||
|
|
||||||
|
if not ccr_hash and not query:
|
||||||
|
return {
|
||||||
|
"error": "Either ccr_hash or query must be provided",
|
||||||
|
"success": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = self._compressor.retrieve(ccr_hash=ccr_hash, query=query)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"CCR retrieval failed: {e}")
|
||||||
|
return {
|
||||||
|
"error": f"CCR retrieval failed: {e}",
|
||||||
|
"success": False,
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,195 @@
|
||||||
|
"""U5 测试: HeadroomRetrieveTool - CCR 可逆压缩检索工具
|
||||||
|
|
||||||
|
测试 headroom_retrieve 工具的构造、执行和注册逻辑。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.tools.headroom_retrieve import HeadroomRetrieveTool
|
||||||
|
from agentkit.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestHeadroomRetrieveToolConstruction ────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestHeadroomRetrieveToolConstruction:
|
||||||
|
"""HeadroomRetrieveTool 构造测试"""
|
||||||
|
|
||||||
|
def test_name_and_description(self):
|
||||||
|
"""工具名称为 headroom_retrieve,描述包含 CCR"""
|
||||||
|
compressor = MagicMock()
|
||||||
|
tool = HeadroomRetrieveTool(compressor=compressor)
|
||||||
|
|
||||||
|
assert tool.name == "headroom_retrieve"
|
||||||
|
assert "CCR" in tool.description
|
||||||
|
|
||||||
|
def test_input_schema_has_ccr_hash(self):
|
||||||
|
"""input_schema 包含 ccr_hash 属性"""
|
||||||
|
compressor = MagicMock()
|
||||||
|
tool = HeadroomRetrieveTool(compressor=compressor)
|
||||||
|
|
||||||
|
assert "ccr_hash" in tool.input_schema["properties"]
|
||||||
|
assert tool.input_schema["properties"]["ccr_hash"]["type"] == "string"
|
||||||
|
|
||||||
|
def test_input_schema_has_query(self):
|
||||||
|
"""input_schema 包含 query 属性"""
|
||||||
|
compressor = MagicMock()
|
||||||
|
tool = HeadroomRetrieveTool(compressor=compressor)
|
||||||
|
|
||||||
|
assert "query" in tool.input_schema["properties"]
|
||||||
|
assert tool.input_schema["properties"]["query"]["type"] == "string"
|
||||||
|
|
||||||
|
def test_input_schema_requires_at_least_one(self):
|
||||||
|
"""input_schema 使用 anyOf 要求至少提供 ccr_hash 或 query"""
|
||||||
|
compressor = MagicMock()
|
||||||
|
tool = HeadroomRetrieveTool(compressor=compressor)
|
||||||
|
|
||||||
|
assert "anyOf" in tool.input_schema
|
||||||
|
any_of = tool.input_schema["anyOf"]
|
||||||
|
# One entry requires ccr_hash, the other requires query
|
||||||
|
required_sets = [item["required"] for item in any_of]
|
||||||
|
assert ["ccr_hash"] in required_sets
|
||||||
|
assert ["query"] in required_sets
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestHeadroomRetrieveToolExecute ────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestHeadroomRetrieveToolExecute:
|
||||||
|
"""HeadroomRetrieveTool 执行测试"""
|
||||||
|
|
||||||
|
async def test_retrieve_by_hash(self):
|
||||||
|
"""通过 ccr_hash 检索,调用 compressor.retrieve"""
|
||||||
|
compressor = MagicMock()
|
||||||
|
compressor.retrieve.return_value = {
|
||||||
|
"content": "original data",
|
||||||
|
"ccr_hash": "abc123",
|
||||||
|
"success": True,
|
||||||
|
}
|
||||||
|
tool = HeadroomRetrieveTool(compressor=compressor)
|
||||||
|
|
||||||
|
result = await tool.execute(ccr_hash="abc123")
|
||||||
|
|
||||||
|
compressor.retrieve.assert_called_once_with(ccr_hash="abc123", query=None)
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["content"] == "original data"
|
||||||
|
|
||||||
|
async def test_retrieve_by_query(self):
|
||||||
|
"""通过 query 检索,调用 compressor.retrieve"""
|
||||||
|
compressor = MagicMock()
|
||||||
|
compressor.retrieve.return_value = {
|
||||||
|
"results": [{"ccr_hash": "h1", "content": "matched data"}],
|
||||||
|
"success": True,
|
||||||
|
}
|
||||||
|
tool = HeadroomRetrieveTool(compressor=compressor)
|
||||||
|
|
||||||
|
result = await tool.execute(query="search term")
|
||||||
|
|
||||||
|
compressor.retrieve.assert_called_once_with(ccr_hash=None, query="search term")
|
||||||
|
assert result["success"] is True
|
||||||
|
assert len(result["results"]) == 1
|
||||||
|
|
||||||
|
async def test_retrieve_both(self):
|
||||||
|
"""同时提供 ccr_hash 和 query,两个参数都传递给 compressor.retrieve"""
|
||||||
|
compressor = MagicMock()
|
||||||
|
compressor.retrieve.return_value = {
|
||||||
|
"content": "original data",
|
||||||
|
"ccr_hash": "abc123",
|
||||||
|
"success": True,
|
||||||
|
}
|
||||||
|
tool = HeadroomRetrieveTool(compressor=compressor)
|
||||||
|
|
||||||
|
result = await tool.execute(ccr_hash="abc123", query="search term")
|
||||||
|
|
||||||
|
compressor.retrieve.assert_called_once_with(ccr_hash="abc123", query="search term")
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
async def test_missing_both_params(self):
|
||||||
|
"""既没有 ccr_hash 也没有 query,返回错误"""
|
||||||
|
compressor = MagicMock()
|
||||||
|
tool = HeadroomRetrieveTool(compressor=compressor)
|
||||||
|
|
||||||
|
result = await tool.execute()
|
||||||
|
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "error" in result
|
||||||
|
assert "ccr_hash" in result["error"] or "query" in result["error"]
|
||||||
|
|
||||||
|
async def test_retrieve_failure(self):
|
||||||
|
"""compressor.retrieve 抛出异常时返回错误结果"""
|
||||||
|
compressor = MagicMock()
|
||||||
|
compressor.retrieve.side_effect = RuntimeError("cache corrupted")
|
||||||
|
tool = HeadroomRetrieveTool(compressor=compressor)
|
||||||
|
|
||||||
|
result = await tool.execute(ccr_hash="abc123")
|
||||||
|
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "error" in result
|
||||||
|
assert "cache corrupted" in result["error"]
|
||||||
|
|
||||||
|
async def test_successful_retrieval(self):
|
||||||
|
"""成功检索返回 content 和 success=True"""
|
||||||
|
compressor = MagicMock()
|
||||||
|
compressor.retrieve.return_value = {
|
||||||
|
"content": "This is the original uncompressed data that was cached",
|
||||||
|
"ccr_hash": "deadbeef1234",
|
||||||
|
"success": True,
|
||||||
|
}
|
||||||
|
tool = HeadroomRetrieveTool(compressor=compressor)
|
||||||
|
|
||||||
|
result = await tool.execute(ccr_hash="deadbeef1234")
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["content"] == "This is the original uncompressed data that was cached"
|
||||||
|
assert result["ccr_hash"] == "deadbeef1234"
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestHeadroomRetrieveToolRegistration ────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestHeadroomRetrieveToolRegistration:
|
||||||
|
"""HeadroomRetrieveTool 注册测试"""
|
||||||
|
|
||||||
|
def test_not_registered_when_no_compressor(self):
|
||||||
|
"""没有 compressor 时工具不注册"""
|
||||||
|
registry = ToolRegistry()
|
||||||
|
|
||||||
|
# Simulate: compressor is None → no registration
|
||||||
|
# (no tool should be registered)
|
||||||
|
assert not registry.has_tool("headroom_retrieve")
|
||||||
|
|
||||||
|
def test_not_registered_when_context_compressor(self):
|
||||||
|
"""ContextCompressor(非 HeadroomCompressor)时不注册"""
|
||||||
|
from agentkit.core.compressor import ContextCompressor
|
||||||
|
|
||||||
|
registry = ToolRegistry()
|
||||||
|
# Create a ContextCompressor (not HeadroomCompressor)
|
||||||
|
compressor = ContextCompressor()
|
||||||
|
|
||||||
|
# Simulate the app.py logic: only register if HeadroomCompressor + is_available
|
||||||
|
from agentkit.core.headroom_compressor import HeadroomCompressor
|
||||||
|
if isinstance(compressor, HeadroomCompressor) and compressor.is_available():
|
||||||
|
tool = HeadroomRetrieveTool(compressor=compressor)
|
||||||
|
registry.register(tool)
|
||||||
|
|
||||||
|
assert not registry.has_tool("headroom_retrieve")
|
||||||
|
|
||||||
|
def test_registered_when_headroom_compressor(self):
|
||||||
|
"""HeadroomCompressor 且 is_available() 为 True 时注册"""
|
||||||
|
from agentkit.core.headroom_compressor import HeadroomCompressor
|
||||||
|
|
||||||
|
registry = ToolRegistry()
|
||||||
|
|
||||||
|
# Create a real HeadroomCompressor instance but mock is_available
|
||||||
|
with patch.object(HeadroomCompressor, "is_available", return_value=True):
|
||||||
|
compressor = HeadroomCompressor(config={})
|
||||||
|
# Simulate the app.py logic
|
||||||
|
if isinstance(compressor, HeadroomCompressor) and compressor.is_available():
|
||||||
|
tool = HeadroomRetrieveTool(compressor=compressor)
|
||||||
|
registry.register(tool)
|
||||||
|
|
||||||
|
assert registry.has_tool("headroom_retrieve")
|
||||||
|
registered_tool = registry.get("headroom_retrieve")
|
||||||
|
assert isinstance(registered_tool, HeadroomRetrieveTool)
|
||||||
Loading…
Reference in New Issue