fischer-agentkit/tests/unit/test_headroom_retrieve_tool.py

196 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U5 测试: HeadroomRetrieveTool - CCR 可逆压缩检索工具
测试 headroom_retrieve 工具的构造、执行和注册逻辑。
"""
from unittest.mock import MagicMock, patch
import pytest
from agentkit.tools.headroom_retrieve import HeadroomRetrieveTool
from agentkit.tools.registry import ToolRegistry
# ── TestHeadroomRetrieveToolConstruction ────────────────────
class TestHeadroomRetrieveToolConstruction:
    """Construction-time tests for HeadroomRetrieveTool: name, description, input schema."""

    @staticmethod
    def _build_tool():
        """Return a HeadroomRetrieveTool wired to a throwaway MagicMock compressor."""
        return HeadroomRetrieveTool(compressor=MagicMock())

    def test_name_and_description(self):
        """The tool is named ``headroom_retrieve`` and its description mentions CCR."""
        built = self._build_tool()
        assert built.name == "headroom_retrieve"
        assert "CCR" in built.description

    def test_input_schema_has_ccr_hash(self):
        """``input_schema`` declares a string-typed ``ccr_hash`` property."""
        props = self._build_tool().input_schema["properties"]
        assert "ccr_hash" in props
        assert props["ccr_hash"]["type"] == "string"

    def test_input_schema_has_query(self):
        """``input_schema`` declares a string-typed ``query`` property."""
        props = self._build_tool().input_schema["properties"]
        assert "query" in props
        assert props["query"]["type"] == "string"

    def test_input_schema_requires_at_least_one(self):
        """``input_schema`` uses ``anyOf`` so at least one of ``ccr_hash`` / ``query`` is required."""
        schema = self._build_tool().input_schema
        assert "anyOf" in schema
        # One alternative must require ccr_hash, another must require query.
        alternatives = [branch["required"] for branch in schema["anyOf"]]
        assert ["ccr_hash"] in alternatives
        assert ["query"] in alternatives
# ── TestHeadroomRetrieveToolExecute ────────────────────────
class TestHeadroomRetrieveToolExecute:
    """Execution tests for ``HeadroomRetrieveTool.execute``.

    The compressor is a MagicMock whose ``retrieve`` is a plain (non-async)
    mock. The tests check both the arguments forwarded to it and the result
    dict the tool returns.

    NOTE(review): these are bare ``async def`` tests with no asyncio marker,
    so they presumably rely on pytest-asyncio auto mode (or an equivalent
    plugin) being configured elsewhere — confirm.
    """

    async def test_retrieve_by_hash(self):
        """Retrieval by ``ccr_hash`` forwards it (with ``query=None``) to compressor.retrieve."""
        compressor = MagicMock()
        compressor.retrieve.return_value = {
            "content": "original data",
            "ccr_hash": "abc123",
            "success": True,
        }
        tool = HeadroomRetrieveTool(compressor=compressor)
        result = await tool.execute(ccr_hash="abc123")
        compressor.retrieve.assert_called_once_with(ccr_hash="abc123", query=None)
        assert result["success"] is True
        assert result["content"] == "original data"

    async def test_retrieve_by_query(self):
        """Retrieval by ``query`` forwards it (with ``ccr_hash=None``) to compressor.retrieve."""
        compressor = MagicMock()
        compressor.retrieve.return_value = {
            "results": [{"ccr_hash": "h1", "content": "matched data"}],
            "success": True,
        }
        tool = HeadroomRetrieveTool(compressor=compressor)
        result = await tool.execute(query="search term")
        compressor.retrieve.assert_called_once_with(ccr_hash=None, query="search term")
        assert result["success"] is True
        assert len(result["results"]) == 1

    async def test_retrieve_both(self):
        """When both ``ccr_hash`` and ``query`` are given, both are forwarded to compressor.retrieve."""
        compressor = MagicMock()
        compressor.retrieve.return_value = {
            "content": "original data",
            "ccr_hash": "abc123",
            "success": True,
        }
        tool = HeadroomRetrieveTool(compressor=compressor)
        result = await tool.execute(ccr_hash="abc123", query="search term")
        compressor.retrieve.assert_called_once_with(ccr_hash="abc123", query="search term")
        assert result["success"] is True
        # The retrieved content must be surfaced, not just the success flag.
        assert result["content"] == "original data"

    async def test_missing_both_params(self):
        """With neither ``ccr_hash`` nor ``query`` the tool returns an error result."""
        compressor = MagicMock()
        tool = HeadroomRetrieveTool(compressor=compressor)
        result = await tool.execute()
        assert result["success"] is False
        assert "error" in result
        assert "ccr_hash" in result["error"] or "query" in result["error"]
        # Argument validation must short-circuit before the compressor is consulted.
        compressor.retrieve.assert_not_called()

    async def test_retrieve_failure(self):
        """An exception from compressor.retrieve is turned into an error result."""
        compressor = MagicMock()
        compressor.retrieve.side_effect = RuntimeError("cache corrupted")
        tool = HeadroomRetrieveTool(compressor=compressor)
        result = await tool.execute(ccr_hash="abc123")
        # Make sure the error really came from the retrieve call.
        compressor.retrieve.assert_called_once_with(ccr_hash="abc123", query=None)
        assert result["success"] is False
        assert "error" in result
        assert "cache corrupted" in result["error"]

    async def test_successful_retrieval(self):
        """A successful retrieval returns the content, the hash and ``success=True``."""
        compressor = MagicMock()
        compressor.retrieve.return_value = {
            "content": "This is the original uncompressed data that was cached",
            "ccr_hash": "deadbeef1234",
            "success": True,
        }
        tool = HeadroomRetrieveTool(compressor=compressor)
        result = await tool.execute(ccr_hash="deadbeef1234")
        assert result["success"] is True
        assert result["content"] == "This is the original uncompressed data that was cached"
        assert result["ccr_hash"] == "deadbeef1234"
# ── TestHeadroomRetrieveToolRegistration ────────────────────
class TestHeadroomRetrieveToolRegistration:
    """Registration tests for HeadroomRetrieveTool.

    Each test drives the same gate that app.py is described as using: the tool
    is registered only when the compressor is a HeadroomCompressor whose
    ``is_available()`` returns True.
    """

    @staticmethod
    def _register_if_headroom(registry, compressor):
        """Register a HeadroomRetrieveTool on *registry* iff *compressor* passes the gate.

        This simulates the app.py registration logic rather than importing it.
        All tests share this one copy instead of repeating the condition inline.
        """
        # Imported lazily, matching how the tests import compressor classes.
        from agentkit.core.headroom_compressor import HeadroomCompressor

        if isinstance(compressor, HeadroomCompressor) and compressor.is_available():
            registry.register(HeadroomRetrieveTool(compressor=compressor))

    def test_not_registered_when_no_compressor(self):
        """With no compressor (None) the tool is not registered."""
        registry = ToolRegistry()
        # Drive the gate with None so the test exercises the None branch
        # rather than merely inspecting an untouched, empty registry.
        self._register_if_headroom(registry, None)
        assert not registry.has_tool("headroom_retrieve")

    def test_not_registered_when_context_compressor(self):
        """A ContextCompressor (not a HeadroomCompressor) does not trigger registration."""
        from agentkit.core.compressor import ContextCompressor

        registry = ToolRegistry()
        self._register_if_headroom(registry, ContextCompressor())
        assert not registry.has_tool("headroom_retrieve")

    def test_not_registered_when_headroom_unavailable(self):
        """A HeadroomCompressor whose is_available() is False does not trigger registration."""
        from agentkit.core.headroom_compressor import HeadroomCompressor

        registry = ToolRegistry()
        with patch.object(HeadroomCompressor, "is_available", return_value=False):
            compressor = HeadroomCompressor(config={})
            self._register_if_headroom(registry, compressor)
        assert not registry.has_tool("headroom_retrieve")

    def test_registered_when_headroom_compressor(self):
        """A HeadroomCompressor whose is_available() is True gets the tool registered."""
        from agentkit.core.headroom_compressor import HeadroomCompressor

        registry = ToolRegistry()
        # Use a real HeadroomCompressor instance; only is_available is mocked.
        with patch.object(HeadroomCompressor, "is_available", return_value=True):
            compressor = HeadroomCompressor(config={})
            self._register_if_headroom(registry, compressor)
        assert registry.has_tool("headroom_retrieve")
        registered_tool = registry.get("headroom_retrieve")
        assert isinstance(registered_tool, HeadroomRetrieveTool)