fischer-agentkit/tests/unit/test_memory_api.py

242 lines
8.4 KiB
Python

"""Unit tests for Memory API routes"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient
from agentkit.llm.gateway import LLMGateway
from agentkit.memory.retriever import MemoryRetriever
from agentkit.memory.base import MemoryItem
from agentkit.skills.registry import SkillRegistry
from agentkit.tools.registry import ToolRegistry
from agentkit.server.app import create_app
@pytest.fixture
def mock_llm_gateway():
    """Provide an ``LLMGateway`` with a single mocked provider named ``"test"``.

    The provider is an ``AsyncMock`` whose ``chat`` return value is a fixed
    ``LLMResponse``.
    """
    llm_gateway = LLMGateway()
    provider = AsyncMock()

    # Imported inside the fixture, as the original code does.
    from agentkit.llm.protocol import LLMResponse, TokenUsage

    token_usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
    canned_response = LLMResponse(
        content='{"result": "mocked"}',
        model="test-model",
        usage=token_usage,
    )
    provider.chat.return_value = canned_response

    llm_gateway.register_provider("test", provider)
    return llm_gateway
@pytest.fixture
def mock_episodic():
    """Provide an ``AsyncMock`` used in place of the episodic memory backend."""
    return AsyncMock()
@pytest.fixture
def mock_semantic():
    """Provide an ``AsyncMock`` used in place of the semantic memory backend."""
    return AsyncMock()
@pytest.fixture
def memory_retriever(mock_episodic, mock_semantic):
    """Build a ``MemoryRetriever`` wired to the two mocked memory backends."""
    retriever = MemoryRetriever(
        episodic_memory=mock_episodic,
        semantic_memory=mock_semantic,
    )
    return retriever
@pytest.fixture
def app(mock_llm_gateway, memory_retriever):
    """Create the application under test and attach the mocked retriever.

    The app is built with the mocked gateway and empty skill/tool registries;
    the retriever is then stored on ``app.state.memory_retriever``.
    """
    application = create_app(
        llm_gateway=mock_llm_gateway,
        skill_registry=SkillRegistry(),
        tool_registry=ToolRegistry(),
    )
    application.state.memory_retriever = memory_retriever
    return application
@pytest.fixture
def client(app):
    """Wrap the ``app`` fixture in a ``TestClient``."""
    test_client = TestClient(app)
    return test_client
class TestSearchEpisodicMemory:
    """Tests for ``GET /api/v1/memory/episodic``."""

    def test_search_returns_results(self, client, mock_episodic):
        """A single stored item is echoed back with query, total, key and score."""
        mock_episodic.search.return_value = [
            MemoryItem(
                key="ep-1",
                value={"input_summary": "test input", "output_summary": "test output"},
                score=0.85,
                metadata={"source": "episodic", "agent_name": "test_agent"},
            ),
        ]
        response = client.get("/api/v1/memory/episodic?query=test")
        assert response.status_code == 200
        data = response.json()
        assert data["query"] == "test"
        assert data["total"] == 1
        assert data["results"][0]["key"] == "ep-1"
        assert data["results"][0]["score"] == 0.85

    def test_search_with_agent_name_filter(self, client, mock_episodic):
        """The ``agent_name`` query parameter reaches ``search`` as a filter."""
        mock_episodic.search.return_value = []
        response = client.get("/api/v1/memory/episodic?query=test&agent_name=my_agent")
        assert response.status_code == 200
        mock_episodic.search.assert_called_once()

        call = mock_episodic.search.call_args
        expected_filters = {"agent_name": "my_agent"}
        # Check for the keyword first: indexing kwargs["filters"] directly would
        # raise KeyError when the filters are passed positionally, so the
        # positional fallback could never be reached.
        if "filters" in call.kwargs:
            assert call.kwargs["filters"] == expected_filters
        else:
            # Positional form: filters expected as the third positional argument.
            assert len(call.args) > 2 and call.args[2] == expected_filters

    def test_search_with_top_k(self, client, mock_episodic):
        """A request carrying ``top_k`` succeeds and triggers exactly one search."""
        mock_episodic.search.return_value = []
        response = client.get("/api/v1/memory/episodic?query=test&top_k=10")
        assert response.status_code == 200
        # NOTE(review): the value forwarded for top_k is not asserted here;
        # confirm how the route passes it before tightening this check.
        mock_episodic.search.assert_called_once()

    def test_search_returns_empty_results(self, client, mock_episodic):
        """No matches yields ``total == 0`` and an empty ``results`` list."""
        mock_episodic.search.return_value = []
        response = client.get("/api/v1/memory/episodic?query=nonexistent")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 0
        assert data["results"] == []

    def test_returns_503_when_retriever_not_configured(self, mock_llm_gateway):
        """With ``app.state.memory_retriever`` set to ``None`` the route answers 503."""
        app = create_app(
            llm_gateway=mock_llm_gateway,
            skill_registry=SkillRegistry(),
            tool_registry=ToolRegistry(),
        )
        app.state.memory_retriever = None
        client = TestClient(app)
        response = client.get("/api/v1/memory/episodic?query=test")
        assert response.status_code == 503

    def test_returns_503_when_episodic_not_configured(self, mock_llm_gateway):
        """A retriever built without an episodic backend makes the route answer 503."""
        retriever = MemoryRetriever(episodic_memory=None, semantic_memory=None)
        app = create_app(
            llm_gateway=mock_llm_gateway,
            skill_registry=SkillRegistry(),
            tool_registry=ToolRegistry(),
        )
        app.state.memory_retriever = retriever
        client = TestClient(app)
        response = client.get("/api/v1/memory/episodic?query=test")
        assert response.status_code == 503
class TestSearchSemanticMemory:
    """Tests for ``GET /api/v1/memory/semantic/search``."""

    SEARCH_URL = "/api/v1/memory/semantic/search"

    @staticmethod
    def _client_with_retriever(llm_gateway, retriever):
        """Return a ``TestClient`` for a fresh app whose retriever state is *retriever*."""
        fresh_app = create_app(
            llm_gateway=llm_gateway,
            skill_registry=SkillRegistry(),
            tool_registry=ToolRegistry(),
        )
        fresh_app.state.memory_retriever = retriever
        return TestClient(fresh_app)

    def test_search_returns_results(self, client, mock_semantic):
        """A single stored item is echoed back with query, total and key."""
        hit = MemoryItem(
            key="doc-1",
            value="Relevant document content",
            score=0.92,
            metadata={"source": "rag", "knowledge_base_id": "kb-1"},
        )
        mock_semantic.search.return_value = [hit]

        resp = client.get(self.SEARCH_URL + "?query=hello")

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["query"] == "hello"
        assert payload["total"] == 1
        assert payload["results"][0]["key"] == "doc-1"

    def test_search_with_knowledge_base_ids(self, client, mock_semantic):
        """A comma-separated ``knowledge_base_ids`` value arrives as a list filter."""
        mock_semantic.search.return_value = []

        resp = client.get(self.SEARCH_URL + "?query=test&knowledge_base_ids=kb1,kb2")

        assert resp.status_code == 200
        mock_semantic.search.assert_called_once()
        # filters is read from the keyword arguments of the recorded call.
        passed_filters = mock_semantic.search.call_args.kwargs.get("filters")
        assert passed_filters is not None
        assert "knowledge_base_ids" in passed_filters
        assert passed_filters["knowledge_base_ids"] == ["kb1", "kb2"]

    def test_search_returns_empty_results(self, client, mock_semantic):
        """No matches yields ``total == 0``."""
        mock_semantic.search.return_value = []

        resp = client.get(self.SEARCH_URL + "?query=nonexistent")

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["total"] == 0

    def test_returns_503_when_retriever_not_configured(self, mock_llm_gateway):
        """With ``app.state.memory_retriever`` set to ``None`` the route answers 503."""
        bare_client = self._client_with_retriever(mock_llm_gateway, None)
        resp = bare_client.get(self.SEARCH_URL + "?query=test")
        assert resp.status_code == 503

    def test_returns_503_when_semantic_not_configured(self, mock_llm_gateway):
        """A retriever built without a semantic backend makes the route answer 503."""
        empty_retriever = MemoryRetriever(episodic_memory=None, semantic_memory=None)
        bare_client = self._client_with_retriever(mock_llm_gateway, empty_retriever)
        resp = bare_client.get(self.SEARCH_URL + "?query=test")
        assert resp.status_code == 503
class TestDeleteEpisodicMemory:
    """Tests for ``DELETE /api/v1/memory/episodic/{key}``."""

    def test_delete_succeeds(self, client, mock_episodic):
        """When the backend reports success, the key and ``deleted: true`` come back."""
        mock_episodic.delete.return_value = True

        resp = client.delete("/api/v1/memory/episodic/ep-123")

        assert resp.status_code == 200
        body = resp.json()
        assert body["key"] == "ep-123"
        assert body["deleted"] is True

    def test_delete_returns_404_when_not_found(self, client, mock_episodic):
        """When the backend's ``delete`` returns ``False`` the route answers 404."""
        mock_episodic.delete.return_value = False

        resp = client.delete("/api/v1/memory/episodic/nonexistent")

        assert resp.status_code == 404

    def test_returns_503_when_retriever_not_configured(self, mock_llm_gateway):
        """With ``app.state.memory_retriever`` set to ``None`` the route answers 503."""
        bare_app = create_app(
            llm_gateway=mock_llm_gateway,
            skill_registry=SkillRegistry(),
            tool_registry=ToolRegistry(),
        )
        bare_app.state.memory_retriever = None

        resp = TestClient(bare_app).delete("/api/v1/memory/episodic/ep-1")

        assert resp.status_code == 503

    def test_returns_503_when_episodic_not_configured(self, mock_llm_gateway):
        """A retriever built without an episodic backend makes the route answer 503."""
        empty_retriever = MemoryRetriever(episodic_memory=None, semantic_memory=None)
        bare_app = create_app(
            llm_gateway=mock_llm_gateway,
            skill_registry=SkillRegistry(),
            tool_registry=ToolRegistry(),
        )
        bare_app.state.memory_retriever = empty_retriever

        resp = TestClient(bare_app).delete("/api/v1/memory/episodic/ep-1")

        assert resp.status_code == 503