242 lines
8.4 KiB
Python
242 lines
8.4 KiB
Python
"""Unit tests for Memory API routes"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from fastapi.testclient import TestClient
|
|
|
|
from agentkit.llm.gateway import LLMGateway
|
|
from agentkit.memory.retriever import MemoryRetriever
|
|
from agentkit.memory.base import MemoryItem
|
|
from agentkit.skills.registry import SkillRegistry
|
|
from agentkit.tools.registry import ToolRegistry
|
|
from agentkit.server.app import create_app
|
|
|
|
|
|
@pytest.fixture
def mock_llm_gateway():
    """Provide an ``LLMGateway`` with one stubbed provider registered as ``"test"``.

    The provider is an ``AsyncMock`` whose ``chat`` coroutine resolves to a
    fixed ``LLMResponse`` (content ``'{"result": "mocked"}'``, model
    ``"test-model"``, 10 prompt / 20 completion tokens). Routes that reach
    the LLM therefore get a deterministic answer instead of calling a real
    provider.
    """
    from agentkit.llm.protocol import LLMResponse, TokenUsage

    gw = LLMGateway()

    canned_usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
    canned_reply = LLMResponse(
        content='{"result": "mocked"}',
        model="test-model",
        usage=canned_usage,
    )

    stub_provider = AsyncMock()
    stub_provider.chat.return_value = canned_reply

    gw.register_provider("test", stub_provider)
    return gw
|
|
|
|
|
|
@pytest.fixture
def mock_episodic():
    """Stand-in for the episodic memory backend.

    An ``AsyncMock``: attributes such as ``search`` and ``delete`` are
    awaitable mocks. Each test sets ``return_value`` on the calls it exercises.
    """
    return AsyncMock()
|
|
|
|
|
|
@pytest.fixture
def mock_semantic():
    """Stand-in for the semantic (search) memory backend.

    An ``AsyncMock`` whose awaitable ``search`` attribute is configured
    per-test through ``return_value``.
    """
    return AsyncMock()
|
|
|
|
|
|
@pytest.fixture
def memory_retriever(mock_episodic, mock_semantic):
    """Build a ``MemoryRetriever`` backed by the two mocked memory stores."""
    backends = {
        "episodic_memory": mock_episodic,
        "semantic_memory": mock_semantic,
    }
    return MemoryRetriever(**backends)
|
|
|
|
|
|
@pytest.fixture
def app(mock_llm_gateway, memory_retriever):
    """FastAPI application under test.

    Built by ``create_app`` with the mocked LLM gateway and empty skill/tool
    registries. The mocked ``memory_retriever`` is then attached to
    ``app.state``. Presumably this is where the memory routes look it up:
    the 503 tests set this same attribute to ``None``.
    """
    web_app = create_app(
        llm_gateway=mock_llm_gateway,
        skill_registry=SkillRegistry(),
        tool_registry=ToolRegistry(),
    )
    web_app.state.memory_retriever = memory_retriever
    return web_app
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Synchronous ``TestClient`` bound to the ``app`` fixture.

    The client is not entered as a context manager. Per Starlette's TestClient
    docs, that means lifespan/startup handlers are not run.
    """
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
class TestSearchEpisodicMemory:
    """Tests for ``GET /api/v1/memory/episodic`` (episodic memory search)."""

    @staticmethod
    def _client_with_retriever(llm_gateway, retriever):
        """Return a TestClient for a fresh app whose ``memory_retriever`` is *retriever*.

        The 503 tests need an app configured differently from the shared
        ``app`` fixture, so they build their own through this helper.
        """
        app = create_app(
            llm_gateway=llm_gateway,
            skill_registry=SkillRegistry(),
            tool_registry=ToolRegistry(),
        )
        app.state.memory_retriever = retriever
        return TestClient(app)

    def test_search_returns_results(self, client, mock_episodic):
        mock_episodic.search.return_value = [
            MemoryItem(
                key="ep-1",
                value={"input_summary": "test input", "output_summary": "test output"},
                score=0.85,
                metadata={"source": "episodic", "agent_name": "test_agent"},
            ),
        ]

        response = client.get("/api/v1/memory/episodic?query=test")
        assert response.status_code == 200
        data = response.json()
        assert data["query"] == "test"
        assert data["total"] == 1
        assert data["results"][0]["key"] == "ep-1"
        assert data["results"][0]["score"] == 0.85

    def test_search_with_agent_name_filter(self, client, mock_episodic):
        mock_episodic.search.return_value = []

        response = client.get("/api/v1/memory/episodic?query=test&agent_name=my_agent")
        assert response.status_code == 200
        mock_episodic.search.assert_called_once()
        call = mock_episodic.search.call_args
        # ``filters`` may be forwarded by keyword or as the third positional
        # argument. Look it up without raising so the positional form is
        # actually checked; indexing ``kwargs["filters"]`` directly would
        # raise KeyError before any fallback could run.
        if "filters" in call.kwargs:
            filters = call.kwargs["filters"]
        elif len(call.args) > 2:
            filters = call.args[2]
        else:
            filters = None
        assert filters == {"agent_name": "my_agent"}

    def test_search_with_top_k(self, client, mock_episodic):
        mock_episodic.search.return_value = []

        response = client.get("/api/v1/memory/episodic?query=test&top_k=10")
        assert response.status_code == 200
        mock_episodic.search.assert_called_once()
        call = mock_episodic.search.call_args
        # The requested top_k must reach the backend, whether it is passed
        # positionally or by keyword.
        assert 10 in call.args or 10 in call.kwargs.values()

    def test_search_returns_empty_results(self, client, mock_episodic):
        mock_episodic.search.return_value = []

        response = client.get("/api/v1/memory/episodic?query=nonexistent")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 0
        assert data["results"] == []

    def test_returns_503_when_retriever_not_configured(self, mock_llm_gateway):
        client = self._client_with_retriever(mock_llm_gateway, None)
        response = client.get("/api/v1/memory/episodic?query=test")
        assert response.status_code == 503

    def test_returns_503_when_episodic_not_configured(self, mock_llm_gateway):
        retriever = MemoryRetriever(episodic_memory=None, semantic_memory=None)
        client = self._client_with_retriever(mock_llm_gateway, retriever)
        response = client.get("/api/v1/memory/episodic?query=test")
        assert response.status_code == 503
|
|
|
|
|
|
class TestSearchSemanticMemory:
    """Tests for ``GET /api/v1/memory/semantic/search`` (semantic memory search)."""

    @staticmethod
    def _client_with_retriever(llm_gateway, retriever):
        """Return a TestClient for a fresh app whose ``memory_retriever`` is *retriever*.

        The 503 tests need an app configured differently from the shared
        ``app`` fixture, so they build their own through this helper.
        """
        app = create_app(
            llm_gateway=llm_gateway,
            skill_registry=SkillRegistry(),
            tool_registry=ToolRegistry(),
        )
        app.state.memory_retriever = retriever
        return TestClient(app)

    def test_search_returns_results(self, client, mock_semantic):
        mock_semantic.search.return_value = [
            MemoryItem(
                key="doc-1",
                value="Relevant document content",
                score=0.92,
                metadata={"source": "rag", "knowledge_base_id": "kb-1"},
            ),
        ]

        response = client.get("/api/v1/memory/semantic/search?query=hello")
        assert response.status_code == 200
        data = response.json()
        assert data["query"] == "hello"
        assert data["total"] == 1
        assert data["results"][0]["key"] == "doc-1"

    def test_search_with_knowledge_base_ids(self, client, mock_semantic):
        mock_semantic.search.return_value = []

        response = client.get("/api/v1/memory/semantic/search?query=test&knowledge_base_ids=kb1,kb2")
        assert response.status_code == 200
        mock_semantic.search.assert_called_once()
        # ``filters`` is passed as a keyword argument. ``call_args.kwargs`` is
        # the same mapping as ``call_args[1]``, so a single lookup suffices.
        filters = mock_semantic.search.call_args.kwargs.get("filters")
        assert filters is not None
        assert "knowledge_base_ids" in filters
        # The comma-separated query parameter should arrive split into a list.
        assert filters["knowledge_base_ids"] == ["kb1", "kb2"]

    def test_search_returns_empty_results(self, client, mock_semantic):
        mock_semantic.search.return_value = []

        response = client.get("/api/v1/memory/semantic/search?query=nonexistent")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 0
        assert data["results"] == []

    def test_returns_503_when_retriever_not_configured(self, mock_llm_gateway):
        client = self._client_with_retriever(mock_llm_gateway, None)
        response = client.get("/api/v1/memory/semantic/search?query=test")
        assert response.status_code == 503

    def test_returns_503_when_semantic_not_configured(self, mock_llm_gateway):
        retriever = MemoryRetriever(episodic_memory=None, semantic_memory=None)
        client = self._client_with_retriever(mock_llm_gateway, retriever)
        response = client.get("/api/v1/memory/semantic/search?query=test")
        assert response.status_code == 503
|
|
|
|
|
|
class TestDeleteEpisodicMemory:
    """Tests for ``DELETE /api/v1/memory/episodic/{key}``."""

    @staticmethod
    def _client_with_retriever(llm_gateway, retriever):
        """Return a TestClient for a fresh app whose ``memory_retriever`` is *retriever*.

        The 503 tests need an app configured differently from the shared
        ``app`` fixture, so they build their own through this helper.
        """
        app = create_app(
            llm_gateway=llm_gateway,
            skill_registry=SkillRegistry(),
            tool_registry=ToolRegistry(),
        )
        app.state.memory_retriever = retriever
        return TestClient(app)

    def test_delete_succeeds(self, client, mock_episodic):
        mock_episodic.delete.return_value = True

        response = client.delete("/api/v1/memory/episodic/ep-123")
        assert response.status_code == 200
        data = response.json()
        assert data["key"] == "ep-123"
        assert data["deleted"] is True
        # The key from the URL path must be the one handed to the backend,
        # whether it is passed positionally or by keyword.
        mock_episodic.delete.assert_called_once()
        call = mock_episodic.delete.call_args
        assert "ep-123" in call.args or "ep-123" in call.kwargs.values()

    def test_delete_returns_404_when_not_found(self, client, mock_episodic):
        mock_episodic.delete.return_value = False

        response = client.delete("/api/v1/memory/episodic/nonexistent")
        assert response.status_code == 404

    def test_returns_503_when_retriever_not_configured(self, mock_llm_gateway):
        client = self._client_with_retriever(mock_llm_gateway, None)
        response = client.delete("/api/v1/memory/episodic/ep-1")
        assert response.status_code == 503

    def test_returns_503_when_episodic_not_configured(self, mock_llm_gateway):
        retriever = MemoryRetriever(episodic_memory=None, semantic_memory=None)
        client = self._client_with_retriever(mock_llm_gateway, retriever)
        response = client.delete("/api/v1/memory/episodic/ep-1")
        assert response.status_code == 503
|