feat(compression): U5 HeadroomRetrieveTool for CCR cache retrieval
Add HeadroomRetrieveTool, which allows the LLM to retrieve the original uncompressed data from the CCR cache via function calling. The tool is auto-registered when HeadroomCompressor is active and available.
This commit is contained in:
parent
286804792d
commit
9c04362dba
|
|
@ -10,6 +10,12 @@ from agentkit.tools.web_crawl import WebCrawlTool
|
|||
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
|
||||
from agentkit.tools.baidu_search import BaiduSearchTool
|
||||
|
||||
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
||||
try:
|
||||
from agentkit.tools.headroom_retrieve import HeadroomRetrieveTool
|
||||
except ImportError:
|
||||
HeadroomRetrieveTool = None # type: ignore[misc,assignment]
|
||||
|
||||
__all__ = [
|
||||
"Tool",
|
||||
"FunctionTool",
|
||||
|
|
@ -23,4 +29,5 @@ __all__ = [
|
|||
"SchemaExtractTool",
|
||||
"SchemaGenerateTool",
|
||||
"BaiduSearchTool",
|
||||
"HeadroomRetrieveTool",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,70 @@
|
|||
"""HeadroomRetrieveTool — CCR 可逆压缩检索工具
|
||||
|
||||
当 HeadroomCompressor 启用时,LLM 可通过此工具从 CCR 缓存中
|
||||
取回被压缩的原始数据。工具输出中的 <!-- CCR:hash=xxx --> 标记
|
||||
指示可检索的内容。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HeadroomRetrieveTool(Tool):
    """Tool that retrieves original, uncompressed data from the CCR cache.

    Compressed content carries a ``<!-- CCR:hash=xxx -->`` marker. The LLM can
    pass that hash value (or a search query) to this tool to get the original
    data back from the compressor it was constructed with.
    """

    def __init__(self, compressor: Any):
        """Register the tool metadata and remember the compressor to query.

        Args:
            compressor: Object exposing ``retrieve(ccr_hash=..., query=...)``.
                Presumably a HeadroomCompressor; the type is not enforced here.
        """
        super().__init__(
            name="headroom_retrieve",
            description=(
                "Retrieve original uncompressed data from the CCR (Compress-Cache-Retrieve) cache. "
                "Use this tool when you see a <!-- CCR:hash=xxx --> marker in compressed content "
                "and need the full original data. Pass the hash value or a search query."
            ),
            input_schema={
                "type": "object",
                "properties": {
                    "ccr_hash": {
                        "type": "string",
                        "description": "The CCR hash from a <!-- CCR:hash=xxx --> marker. Use this for direct lookup.",
                    },
                    "query": {
                        "type": "string",
                        "description": "Search query to find matching cached content. Used when hash is not available.",
                    },
                },
                # At least one of the two properties must be supplied.
                # NOTE(review): some function-calling APIs reject a root-level
                # anyOf -- confirm against the providers in use. execute()
                # re-validates the arguments either way.
                "anyOf": [
                    {"required": ["ccr_hash"]},
                    {"required": ["query"]},
                ],
            },
        )
        self._compressor = compressor

    async def execute(self, **kwargs) -> dict:
        """Look up original data in the CCR cache.

        Args:
            **kwargs: May contain ``ccr_hash`` and/or ``query``; both are
                forwarded to ``compressor.retrieve`` (a missing one as None).

        Returns:
            Whatever ``compressor.retrieve`` returns on success. When neither
            argument is given (or both are empty), or when retrieval raises,
            a dict with an ``error`` message and ``success`` set to False.
        """
        ccr_hash = kwargs.get("ccr_hash")
        query = kwargs.get("query")

        # Falsy values (None, "") count as "not provided".
        if not ccr_hash and not query:
            return {
                "error": "Either ccr_hash or query must be provided",
                "success": False,
            }

        try:
            return self._compressor.retrieve(ccr_hash=ccr_hash, query=query)
        except Exception as e:
            # Tool boundary: report the failure to the caller instead of
            # raising, but keep the traceback in the log for diagnosis.
            logger.exception("CCR retrieval failed: %s", e)
            return {
                "error": f"CCR retrieval failed: {e}",
                "success": False,
            }
|
||||
|
|
@ -0,0 +1,195 @@
|
|||
"""U5 测试: HeadroomRetrieveTool - CCR 可逆压缩检索工具
|
||||
|
||||
测试 headroom_retrieve 工具的构造、执行和注册逻辑。
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.tools.headroom_retrieve import HeadroomRetrieveTool
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
||||
|
||||
# ── TestHeadroomRetrieveToolConstruction ────────────────────
|
||||
|
||||
|
||||
class TestHeadroomRetrieveToolConstruction:
    """Construction tests: metadata exposed by a freshly built HeadroomRetrieveTool."""

    @staticmethod
    def _new_tool():
        """Build a tool around a throwaway mock compressor."""
        return HeadroomRetrieveTool(compressor=MagicMock())

    def test_name_and_description(self):
        """The tool is named headroom_retrieve and its description mentions CCR."""
        built = self._new_tool()

        assert built.name == "headroom_retrieve"
        assert "CCR" in built.description

    def test_input_schema_has_ccr_hash(self):
        """input_schema declares a string ccr_hash property."""
        props = self._new_tool().input_schema["properties"]

        assert "ccr_hash" in props
        assert props["ccr_hash"]["type"] == "string"

    def test_input_schema_has_query(self):
        """input_schema declares a string query property."""
        props = self._new_tool().input_schema["properties"]

        assert "query" in props
        assert props["query"]["type"] == "string"

    def test_input_schema_requires_at_least_one(self):
        """input_schema uses anyOf to require ccr_hash or query."""
        schema = self._new_tool().input_schema

        assert "anyOf" in schema
        # One alternative requires ccr_hash, the other requires query.
        alternatives = [branch["required"] for branch in schema["anyOf"]]
        assert ["ccr_hash"] in alternatives
        assert ["query"] in alternatives
|
||||
|
||||
|
||||
# ── TestHeadroomRetrieveToolExecute ────────────────────────
|
||||
|
||||
|
||||
class TestHeadroomRetrieveToolExecute:
    """Execution tests for HeadroomRetrieveTool.execute.

    NOTE(review): these coroutine tests carry no asyncio marker; presumably the
    project runs pytest-asyncio in auto mode -- confirm in the pytest config.
    """

    @staticmethod
    def _tool_returning(payload):
        """Return (tool, mock_compressor) where retrieve() yields *payload*."""
        fake = MagicMock()
        fake.retrieve.return_value = payload
        return HeadroomRetrieveTool(compressor=fake), fake

    async def test_retrieve_by_hash(self):
        """A ccr_hash lookup is forwarded to compressor.retrieve."""
        tool, fake = self._tool_returning(
            {
                "content": "original data",
                "ccr_hash": "abc123",
                "success": True,
            }
        )

        outcome = await tool.execute(ccr_hash="abc123")

        fake.retrieve.assert_called_once_with(ccr_hash="abc123", query=None)
        assert outcome["success"] is True
        assert outcome["content"] == "original data"

    async def test_retrieve_by_query(self):
        """A query lookup is forwarded to compressor.retrieve."""
        tool, fake = self._tool_returning(
            {
                "results": [{"ccr_hash": "h1", "content": "matched data"}],
                "success": True,
            }
        )

        outcome = await tool.execute(query="search term")

        fake.retrieve.assert_called_once_with(ccr_hash=None, query="search term")
        assert outcome["success"] is True
        assert len(outcome["results"]) == 1

    async def test_retrieve_both(self):
        """With both ccr_hash and query given, both reach compressor.retrieve."""
        tool, fake = self._tool_returning(
            {
                "content": "original data",
                "ccr_hash": "abc123",
                "success": True,
            }
        )

        outcome = await tool.execute(ccr_hash="abc123", query="search term")

        fake.retrieve.assert_called_once_with(ccr_hash="abc123", query="search term")
        assert outcome["success"] is True

    async def test_missing_both_params(self):
        """With neither ccr_hash nor query, an error result is returned."""
        tool = HeadroomRetrieveTool(compressor=MagicMock())

        outcome = await tool.execute()

        assert outcome["success"] is False
        assert "error" in outcome
        assert "ccr_hash" in outcome["error"] or "query" in outcome["error"]

    async def test_retrieve_failure(self):
        """An exception from compressor.retrieve becomes an error result."""
        fake = MagicMock()
        fake.retrieve.side_effect = RuntimeError("cache corrupted")
        tool = HeadroomRetrieveTool(compressor=fake)

        outcome = await tool.execute(ccr_hash="abc123")

        assert outcome["success"] is False
        assert "error" in outcome
        assert "cache corrupted" in outcome["error"]

    async def test_successful_retrieval(self):
        """A successful lookup returns the content with success=True."""
        tool, _ = self._tool_returning(
            {
                "content": "This is the original uncompressed data that was cached",
                "ccr_hash": "deadbeef1234",
                "success": True,
            }
        )

        outcome = await tool.execute(ccr_hash="deadbeef1234")

        assert outcome["success"] is True
        assert outcome["content"] == "This is the original uncompressed data that was cached"
        assert outcome["ccr_hash"] == "deadbeef1234"
|
||||
|
||||
|
||||
# ── TestHeadroomRetrieveToolRegistration ────────────────────
|
||||
|
||||
|
||||
class TestHeadroomRetrieveToolRegistration:
    """Registration tests for HeadroomRetrieveTool.

    Each test drives the same guard through ``_register_if_headroom``. The
    test comments describe that guard as mirroring the app.py logic; app.py
    itself is not exercised here.
    """

    @staticmethod
    def _register_if_headroom(registry, compressor):
        """Register the tool only for an available HeadroomCompressor."""
        from agentkit.core.headroom_compressor import HeadroomCompressor

        if isinstance(compressor, HeadroomCompressor) and compressor.is_available():
            registry.register(HeadroomRetrieveTool(compressor=compressor))

    def test_not_registered_when_no_compressor(self):
        """Without a compressor (None) the tool is not registered."""
        registry = ToolRegistry()

        # Run the guard with None so the test exercises the registration
        # decision rather than only inspecting an untouched registry.
        self._register_if_headroom(registry, None)

        assert not registry.has_tool("headroom_retrieve")

    def test_not_registered_when_context_compressor(self):
        """A ContextCompressor (not a HeadroomCompressor) does not register the tool."""
        from agentkit.core.compressor import ContextCompressor

        registry = ToolRegistry()

        self._register_if_headroom(registry, ContextCompressor())

        assert not registry.has_tool("headroom_retrieve")

    def test_registered_when_headroom_compressor(self):
        """A HeadroomCompressor whose is_available() is True registers the tool."""
        from agentkit.core.headroom_compressor import HeadroomCompressor

        registry = ToolRegistry()

        # Real HeadroomCompressor instance, with is_available forced to True.
        with patch.object(HeadroomCompressor, "is_available", return_value=True):
            compressor = HeadroomCompressor(config={})
            self._register_if_headroom(registry, compressor)

        assert registry.has_tool("headroom_retrieve")
        registered_tool = registry.get("headroom_retrieve")
        assert isinstance(registered_tool, HeadroomRetrieveTool)
|
||||
Loading…
Reference in New Issue