196 lines
7.4 KiB
Python
196 lines
7.4 KiB
Python
"""U5 tests: HeadroomRetrieveTool - the CCR reversible-compression retrieval tool.

Covers construction, execution, and registration logic of the headroom_retrieve tool.
"""
|
||
|
||
from unittest.mock import MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from agentkit.tools.headroom_retrieve import HeadroomRetrieveTool
|
||
from agentkit.tools.registry import ToolRegistry
|
||
|
||
|
||
# ── TestHeadroomRetrieveToolConstruction ────────────────────
|
||
|
||
|
||
class TestHeadroomRetrieveToolConstruction:
    """Construction-time checks for HeadroomRetrieveTool: name, description, input schema."""

    @staticmethod
    def _build_tool():
        """Return a HeadroomRetrieveTool wired to a throwaway MagicMock compressor."""
        return HeadroomRetrieveTool(compressor=MagicMock())

    def test_name_and_description(self):
        """The tool is named headroom_retrieve and its description mentions CCR."""
        tool = self._build_tool()

        assert tool.name == "headroom_retrieve"
        assert "CCR" in tool.description

    def test_input_schema_has_ccr_hash(self):
        """input_schema declares a string-typed ccr_hash property."""
        props = self._build_tool().input_schema["properties"]

        assert "ccr_hash" in props
        assert props["ccr_hash"]["type"] == "string"

    def test_input_schema_has_query(self):
        """input_schema declares a string-typed query property."""
        props = self._build_tool().input_schema["properties"]

        assert "query" in props
        assert props["query"]["type"] == "string"

    def test_input_schema_requires_at_least_one(self):
        """input_schema uses anyOf so that at least one of ccr_hash / query is required."""
        schema = self._build_tool().input_schema

        assert "anyOf" in schema
        # One alternative must require ccr_hash, another must require query.
        required_lists = [branch["required"] for branch in schema["anyOf"]]
        assert ["ccr_hash"] in required_lists
        assert ["query"] in required_lists
|
||
|
||
|
||
# ── TestHeadroomRetrieveToolExecute ────────────────────────
|
||
|
||
|
||
class TestHeadroomRetrieveToolExecute:
    """Execution tests for HeadroomRetrieveTool.

    The compressor is a MagicMock with a stubbed ``retrieve``; these tests check how
    ``execute`` forwards ``ccr_hash`` / ``query`` to it and what result dict comes back.

    NOTE(review): these are ``async def`` tests with no explicit asyncio marker, so they
    rely on the project's pytest async configuration (e.g. pytest-asyncio auto mode) —
    confirm that configuration exists, otherwise they will not run.
    """

    @staticmethod
    def _compressor_returning(payload):
        """Build a mock compressor whose ``retrieve`` returns *payload*."""
        compressor = MagicMock()
        compressor.retrieve.return_value = payload
        return compressor

    async def test_retrieve_by_hash(self):
        """Retrieval by ccr_hash calls compressor.retrieve with query=None."""
        compressor = self._compressor_returning(
            {"content": "original data", "ccr_hash": "abc123", "success": True}
        )
        tool = HeadroomRetrieveTool(compressor=compressor)

        result = await tool.execute(ccr_hash="abc123")

        compressor.retrieve.assert_called_once_with(ccr_hash="abc123", query=None)
        assert result["success"] is True
        assert result["content"] == "original data"

    async def test_retrieve_by_query(self):
        """Retrieval by query calls compressor.retrieve with ccr_hash=None."""
        compressor = self._compressor_returning(
            {
                "results": [{"ccr_hash": "h1", "content": "matched data"}],
                "success": True,
            }
        )
        tool = HeadroomRetrieveTool(compressor=compressor)

        result = await tool.execute(query="search term")

        compressor.retrieve.assert_called_once_with(ccr_hash=None, query="search term")
        assert result["success"] is True
        assert len(result["results"]) == 1
        # Pin the matched entry itself, not just the count.
        assert result["results"][0]["ccr_hash"] == "h1"
        assert result["results"][0]["content"] == "matched data"

    async def test_retrieve_both(self):
        """When both ccr_hash and query are given, both are forwarded to retrieve."""
        compressor = self._compressor_returning(
            {"content": "original data", "ccr_hash": "abc123", "success": True}
        )
        tool = HeadroomRetrieveTool(compressor=compressor)

        result = await tool.execute(ccr_hash="abc123", query="search term")

        compressor.retrieve.assert_called_once_with(ccr_hash="abc123", query="search term")
        assert result["success"] is True
        assert result["content"] == "original data"

    async def test_missing_both_params(self):
        """Neither ccr_hash nor query: an error result is returned without calling retrieve."""
        compressor = MagicMock()
        tool = HeadroomRetrieveTool(compressor=compressor)

        result = await tool.execute()

        assert result["success"] is False
        assert "error" in result
        assert "ccr_hash" in result["error"] or "query" in result["error"]
        # Argument validation must short-circuit before touching the compressor.
        compressor.retrieve.assert_not_called()

    async def test_retrieve_failure(self):
        """An exception raised by compressor.retrieve is turned into an error result."""
        compressor = MagicMock()
        compressor.retrieve.side_effect = RuntimeError("cache corrupted")
        tool = HeadroomRetrieveTool(compressor=compressor)

        result = await tool.execute(ccr_hash="abc123")

        # Make sure the error really originated from the retrieve call.
        compressor.retrieve.assert_called_once_with(ccr_hash="abc123", query=None)
        assert result["success"] is False
        assert "error" in result
        assert "cache corrupted" in result["error"]

    async def test_successful_retrieval(self):
        """A successful retrieval returns content, ccr_hash and success=True."""
        original = "This is the original uncompressed data that was cached"
        compressor = self._compressor_returning(
            {"content": original, "ccr_hash": "deadbeef1234", "success": True}
        )
        tool = HeadroomRetrieveTool(compressor=compressor)

        result = await tool.execute(ccr_hash="deadbeef1234")

        compressor.retrieve.assert_called_once_with(ccr_hash="deadbeef1234", query=None)
        assert result["success"] is True
        assert result["content"] == original
        assert result["ccr_hash"] == "deadbeef1234"
|
||
|
||
|
||
# ── TestHeadroomRetrieveToolRegistration ────────────────────
|
||
|
||
|
||
class TestHeadroomRetrieveToolRegistration:
    """Registration tests for HeadroomRetrieveTool.

    Every scenario is fed through one helper that mirrors the app.py registration rule:
    register the tool only when the compressor is a HeadroomCompressor whose
    ``is_available()`` returns True.
    """

    @staticmethod
    def _register_if_supported(registry, compressor):
        """Mirror the app.py gating: register headroom_retrieve only for an available
        HeadroomCompressor; do nothing for any other compressor, including ``None``."""
        from agentkit.core.headroom_compressor import HeadroomCompressor

        # isinstance() short-circuits for None / ContextCompressor, so is_available()
        # is only consulted on actual HeadroomCompressor instances.
        if isinstance(compressor, HeadroomCompressor) and compressor.is_available():
            registry.register(HeadroomRetrieveTool(compressor=compressor))

    def test_not_registered_when_no_compressor(self):
        """With no compressor (None) the tool is not registered."""
        registry = ToolRegistry()

        # Run a None compressor through the gating rule rather than merely asserting
        # on an untouched registry, so the test exercises the real decision.
        self._register_if_supported(registry, None)

        assert not registry.has_tool("headroom_retrieve")

    def test_not_registered_when_context_compressor(self):
        """A ContextCompressor (not a HeadroomCompressor) does not get the tool registered."""
        from agentkit.core.compressor import ContextCompressor

        registry = ToolRegistry()

        self._register_if_supported(registry, ContextCompressor())

        assert not registry.has_tool("headroom_retrieve")

    def test_registered_when_headroom_compressor(self):
        """A HeadroomCompressor whose is_available() is True gets the tool registered."""
        from agentkit.core.headroom_compressor import HeadroomCompressor

        registry = ToolRegistry()

        # Real HeadroomCompressor instance; only is_available() is patched, and the
        # gating check must run while the patch is active.
        with patch.object(HeadroomCompressor, "is_available", return_value=True):
            compressor = HeadroomCompressor(config={})
            self._register_if_supported(registry, compressor)

        assert registry.has_tool("headroom_retrieve")
        registered_tool = registry.get("headroom_retrieve")
        assert isinstance(registered_tool, HeadroomRetrieveTool)
|