# fischer-agentkit/tests/unit/server/test_portal_routes.py
#
# (source page metadata: 443 lines, 16 KiB, Python)

"""Tests for Portal API routes"""
from __future__ import annotations
from unittest.mock import AsyncMock
# Note: AGENTKIT_WS_TIMEOUT=0 is set in tests/conftest.py (before portal import)
import pytest
from fastapi.testclient import TestClient
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage
from agentkit.server.app import create_app
from agentkit.chat.sqlite_conversation_store import SqliteConversationStore
from agentkit.server.routes.portal import CAPABILITY_CATEGORIES
from agentkit.skills.base import Skill, SkillConfig
from agentkit.skills.registry import SkillRegistry
from agentkit.tools.registry import ToolRegistry
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_llm_gateway():
    """Build an LLMGateway with an AsyncMock provider registered as "test".

    The mocked provider's ``chat`` coroutine resolves to a canned LLMResponse
    whose content is a JSON routing decision naming ``chat_skill`` with
    confidence 0.9, so routing code under test sees a deterministic answer.
    """
    llm = LLMGateway()
    fake_provider = AsyncMock()
    fake_provider.chat.return_value = LLMResponse(
        content='{"skill": "chat_skill", "confidence": 0.9}',
        model="test-model",
        usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
    )
    llm.register_provider("test", fake_provider)
    return llm
@pytest.fixture
def skill_registry():
    """Provide a new SkillRegistry per test.

    The same instance is injected into the ``app`` fixture, so skills a test
    registers via ``_register_skill`` are visible to the routes under test.
    """
    registry = SkillRegistry()
    return registry
@pytest.fixture
def tool_registry():
    """Provide a new ToolRegistry per test, injected into the ``app`` fixture."""
    registry = ToolRegistry()
    return registry
@pytest.fixture
def app(mock_llm_gateway, skill_registry, tool_registry):
    """Build the server app via ``create_app`` using the per-test test doubles."""
    wiring = dict(
        llm_gateway=mock_llm_gateway,
        skill_registry=skill_registry,
        tool_registry=tool_registry,
    )
    return create_app(**wiring)
@pytest.fixture
def client(app):
    """Wrap the per-test app in a synchronous TestClient.

    The client is returned directly rather than entered as a context manager.
    Per the Starlette docs, this means the app's lifespan (startup/shutdown)
    handlers are not run.
    """
    http = TestClient(app)
    return http
def _register_skill(registry: SkillRegistry, name: str = "chat_skill", **kwargs) -> Skill:
    """Create a skill called *name*, register it in *registry*, and return it.

    The skill uses agent_type "chat_type", task_mode "llm_generate", a minimal
    prompt, and intent keywords "chat"/"hello". Any extra keyword arguments
    (e.g. ``capabilities=[...]``) are passed straight to SkillConfig.
    """
    skill_config = SkillConfig(
        name=name,
        agent_type="chat_type",
        task_mode="llm_generate",
        prompt={"identity": "Chat Skill", "instructions": "Handle chat"},
        intent={"keywords": ["chat", "hello"], "description": "A chat skill"},
        **kwargs,
    )
    new_skill = Skill(config=skill_config)
    registry.register(new_skill)
    return new_skill
# ---------------------------------------------------------------------------
# ConversationStore unit tests
# ---------------------------------------------------------------------------
class TestConversationStore:
    """Tests for SqliteConversationStore (async API, SQLite file under pytest's tmp_path)."""

    @pytest.fixture
    def store(self, tmp_path):
        """Fresh store backed by a per-test SQLite file, so tests never share data."""
        return SqliteConversationStore(db_path=str(tmp_path / "test.db"))

    @pytest.mark.asyncio
    async def test_get_or_create_new(self, store):
        conv = await store.get_or_create()
        assert conv.id is not None
        assert conv.messages == []

    @pytest.mark.asyncio
    async def test_get_or_create_with_id(self, store):
        conv = await store.get_or_create("test-id-123")
        assert conv.id == "test-id-123"

    @pytest.mark.asyncio
    async def test_get_or_create_reuse(self, store):
        await store.get_or_create("reuse-id")
        await store.add_message("reuse-id", "user", "hello")
        conv2 = await store.get_or_create("reuse-id")
        assert conv2.id == "reuse-id"
        assert len(conv2.messages) == 1

    @pytest.mark.asyncio
    async def test_add_message(self, store):
        conv = await store.get_or_create("msg-id")
        msg = await store.add_message("msg-id", "user", "hello")
        assert msg.role == "user"
        assert msg.content == "hello"
        # The object returned earlier by get_or_create is expected to observe
        # the appended message (the store hands back its cached instance).
        assert len(conv.messages) == 1

    @pytest.mark.asyncio
    async def test_add_message_not_found(self, store):
        with pytest.raises(KeyError):
            await store.add_message("nonexistent", "user", "hello")

    @pytest.mark.asyncio
    async def test_get_history(self, store):
        await store.get_or_create("hist-id")
        await store.add_message("hist-id", "user", "msg1")
        await store.add_message("hist-id", "assistant", "msg2")
        history = await store.get_history("hist-id")
        assert len(history) == 2
        assert history[0].role == "user"
        assert history[1].role == "assistant"

    @pytest.mark.asyncio
    async def test_get_history_limit(self, store):
        await store.get_or_create("limit-id")
        for i in range(10):
            await store.add_message("limit-id", "user", f"msg{i}")
        history = await store.get_history("limit-id", limit=3)
        # limit keeps the most recent messages, returned oldest-first.
        assert [m.content for m in history] == ["msg7", "msg8", "msg9"]

    @pytest.mark.asyncio
    async def test_get_history_nonexistent(self, store):
        history = await store.get_history("no-such-id")
        assert history == []

    @pytest.mark.asyncio
    async def test_list_conversations(self, store):
        await store.get_or_create("conv-a")
        await store.get_or_create("conv-b")
        convs = await store.list_conversations()
        assert len(convs) == 2

    @pytest.mark.asyncio
    async def test_list_conversations_limit(self, store):
        for i in range(5):
            await store.get_or_create(f"conv-{i}")
        convs = await store.list_conversations(limit=2)
        assert len(convs) == 2

    @pytest.mark.asyncio
    async def test_max_conversations_eviction(self, tmp_path):
        store = SqliteConversationStore(
            db_path=str(tmp_path / "evict.db"), max_conversations=3
        )
        for i in range(5):
            await store.get_or_create(f"evict-{i}")
        # White-box check on the private in-process cache: it must not grow
        # past max_conversations even after more conversations are created.
        assert len(store._cache) <= 3
# ---------------------------------------------------------------------------
# POST /portal/chat
# ---------------------------------------------------------------------------
class TestPortalChat:
    """Tests for POST /api/v1/portal/chat."""

    def test_chat_with_skill_name(self, client, skill_registry):
        # An explicit skill_name is expected to be honoured directly
        # (routing_method "skill_prefix", confidence 1.0).
        _register_skill(skill_registry, "chat_skill")
        response = client.post(
            "/api/v1/portal/chat",
            json={"message": "hello", "skill_name": "chat_skill"},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["conversation_id"] is not None
        assert data["matched_skill"] == "chat_skill"
        assert data["routing_method"] == "skill_prefix"
        assert data["confidence"] == 1.0
        assert data["status"] == "completed"

    def test_chat_with_intent_routing(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")
        response = client.post(
            "/api/v1/portal/chat",
            json={"message": "hello chat"},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["matched_skill"] is not None
        assert data["routing_method"] is not None
        assert data["conversation_id"] is not None

    def test_chat_no_skills_available(self, client):
        """Greeting fast-path works even without skills (DIRECT_CHAT mode)."""
        response = client.post(
            "/api/v1/portal/chat",
            json={"message": "hello"},
        )
        # Greeting regex fast-path: no skill needed, returns 200
        assert response.status_code == 200

    def test_chat_skill_not_found(self, client):
        response = client.post(
            "/api/v1/portal/chat",
            json={"message": "hello", "skill_name": "nonexistent_skill"},
        )
        assert response.status_code == 404

    def test_chat_with_conversation_id(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")
        response1 = client.post(
            "/api/v1/portal/chat",
            json={"message": "hello", "skill_name": "chat_skill"},
        )
        # Check the setup turn explicitly, so a failure shows up as a status
        # mismatch rather than a KeyError or a None id further down.
        assert response1.status_code == 200
        conv_id = response1.json()["conversation_id"]
        assert conv_id is not None

        response2 = client.post(
            "/api/v1/portal/chat",
            json={"message": "follow up", "skill_name": "chat_skill", "conversation_id": conv_id},
        )
        assert response2.status_code == 200
        assert response2.json()["conversation_id"] == conv_id

    def test_chat_with_sources(self, client, skill_registry):
        # Only checks that the optional "sources" field is accepted.
        _register_skill(skill_registry, "chat_skill")
        response = client.post(
            "/api/v1/portal/chat",
            json={
                "message": "search docs",
                "skill_name": "chat_skill",
                "sources": ["wiki", "docs"],
            },
        )
        assert response.status_code == 200
# ---------------------------------------------------------------------------
# GET /portal/capabilities
# ---------------------------------------------------------------------------
class TestPortalCapabilities:
    """Tests for GET /api/v1/portal/capabilities."""

    @staticmethod
    def _get_capabilities(client):
        """GET the capabilities endpoint, assert success, and return the capability list."""
        response = client.get("/api/v1/portal/capabilities")
        assert response.status_code == 200
        return response.json()["capabilities"]

    @staticmethod
    def _find(caps, name):
        """Return the capability entry called *name*, failing with a clear message if absent."""
        match = next((c for c in caps if c["name"] == name), None)
        assert match is not None, (
            f"capability {name!r} missing; got {[c['name'] for c in caps]}"
        )
        return match

    def test_capabilities_returns_list(self, client):
        response = client.get("/api/v1/portal/capabilities")
        assert response.status_code == 200
        data = response.json()
        assert "capabilities" in data
        caps = data["capabilities"]
        # One entry per category declared by the portal module.
        assert len(caps) == len(CAPABILITY_CATEGORIES)

    def test_capabilities_structure(self, client):
        required = ("name", "display_name", "description", "icon", "enabled", "skill_count")
        for cap in self._get_capabilities(client):
            for key in required:
                assert key in cap, f"capability {cap.get('name')!r} lacks {key!r}"

    def test_capabilities_with_skills(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill", capabilities=["chat"])
        chat_cap = self._find(self._get_capabilities(client), "chat")
        assert chat_cap["skill_count"] >= 1

    def test_capabilities_skills_category_always_increments(self, client, skill_registry):
        # A skill registered without explicit capabilities should still be
        # counted under the generic "skills" category.
        _register_skill(skill_registry, "some_skill")
        skills_cap = self._find(self._get_capabilities(client), "skills")
        assert skills_cap["skill_count"] >= 1
# ---------------------------------------------------------------------------
# GET /portal/conversations
# ---------------------------------------------------------------------------
class TestPortalConversations:
    """Tests for GET /api/v1/portal/conversations."""

    @staticmethod
    def _chat(client, message):
        """POST *message* to the registered ``chat_skill`` and assert the turn succeeded."""
        response = client.post(
            "/api/v1/portal/chat",
            json={"message": message, "skill_name": "chat_skill"},
        )
        # Without this, a failing chat would leave the listing tests passing vacuously.
        assert response.status_code == 200
        return response

    def test_list_conversations_empty(self, client):
        # Only the response shape is checked; the test does not assume the
        # conversation store starts out empty.
        response = client.get("/api/v1/portal/conversations")
        assert response.status_code == 200
        assert isinstance(response.json(), list)

    def test_list_conversations_after_chat(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")
        self._chat(client, "hello")
        response = client.get("/api/v1/portal/conversations")
        assert response.status_code == 200
        data = response.json()
        assert len(data) >= 1
        assert "id" in data[0]
        assert "message_count" in data[0]

    def test_list_conversations_limit(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")
        # Each chat sent without a conversation_id starts a new conversation,
        # so at least three exist before listing.
        for i in range(3):
            self._chat(client, f"msg{i}")
        response = client.get("/api/v1/portal/conversations?limit=2")
        assert response.status_code == 200
        # Exactly `limit` entries: `<= 2` would also accept an empty list.
        assert len(response.json()) == 2
# ---------------------------------------------------------------------------
# GET /portal/conversations/{id}
# ---------------------------------------------------------------------------
class TestPortalConversationHistory:
    """Tests for GET /api/v1/portal/conversations/{id}."""

    @staticmethod
    def _start_conversation(client):
        """Send one chat turn to ``chat_skill``, assert success, and return its conversation id."""
        chat_resp = client.post(
            "/api/v1/portal/chat",
            json={"message": "hello", "skill_name": "chat_skill"},
        )
        assert chat_resp.status_code == 200
        conv_id = chat_resp.json()["conversation_id"]
        assert conv_id is not None
        return conv_id

    def test_get_conversation_history(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")
        conv_id = self._start_conversation(client)
        response = client.get(f"/api/v1/portal/conversations/{conv_id}")
        assert response.status_code == 200
        data = response.json()
        # Response is now an IConversation object, not a bare array
        assert "id" in data
        assert "messages" in data
        assert len(data["messages"]) >= 1
        first = data["messages"][0]
        assert first["role"] in ("user", "assistant")
        assert "content" in first
        assert "timestamp" in first

    def test_get_conversation_not_found(self, client):
        response = client.get("/api/v1/portal/conversations/nonexistent-id")
        assert response.status_code == 404

    def test_get_conversation_history_limit(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")
        conv_id = self._start_conversation(client)
        response = client.get(f"/api/v1/portal/conversations/{conv_id}?limit=1")
        assert response.status_code == 200
        data = response.json()
        # Response is now an IConversation object. The chat turn stored at
        # least one message (asserted in test_get_conversation_history), so
        # limit=1 must return exactly one; `<= 1` would also accept zero.
        assert len(data["messages"]) == 1
# ---------------------------------------------------------------------------
# WebSocket /portal/ws
# ---------------------------------------------------------------------------
class TestPortalWebSocket:
    """Tests for the /api/v1/portal/ws WebSocket endpoint (currently all skipped)."""

    # NOTE: Starlette TestClient's sync WS client does not properly trigger
    # server-side disconnect when the `with` block exits, causing the server's
    # `receive_text()` to hang indefinitely. These tests are skipped until
    # we migrate to async WS testing (e.g., httpx-async or pytest-asyncio).
    # A class-level `pytestmark` applies the skip to every test in the class.
    pytestmark = pytest.mark.skip(reason="Starlette TestClient WS hangs on disconnect")

    # Upper bound on frames read while waiting for a terminal message. It
    # guards against an endless stream of non-terminal frames. It does NOT
    # bound a single blocking receive_json() call.
    _MAX_FRAMES = 100

    def test_ws_connect(self, client):
        with client.websocket_connect("/api/v1/portal/ws") as ws:
            data = ws.receive_json()
            assert data["type"] == "connected"
            assert "conversation_id" in data

    def test_ws_chat_flow(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")
        with client.websocket_connect("/api/v1/portal/ws") as ws:
            # Receive connected message
            connected = ws.receive_json()
            assert connected["type"] == "connected"
            # Send chat message
            ws.send_json({"type": "chat", "message": "hello chat"})
            # Should receive routing info
            routing = ws.receive_json()
            assert routing["type"] == "routing"
            assert "skill" in routing
            # Then read step messages until a terminal result/error frame arrives.
            messages = []
            for _ in range(self._MAX_FRAMES):
                msg = ws.receive_json()
                messages.append(msg)
                if msg["type"] in ("result", "error"):
                    break
            else:
                pytest.fail(
                    f"no result/error frame within {self._MAX_FRAMES} messages"
                )
            assert messages[-1]["type"] in ("result", "error")

    def test_ws_cancel(self, client):
        with client.websocket_connect("/api/v1/portal/ws") as ws:
            connected = ws.receive_json()
            assert connected["type"] == "connected"
            ws.send_json({"type": "cancel"})
            result = ws.receive_json()
            assert result["type"] == "result"
            assert result["data"]["status"] == "cancelled"

    def test_ws_no_skills_error(self, client):
        # NOTE(review): the HTTP endpoint answers "hello" via a greeting
        # fast-path even with no skills (test_chat_no_skills_available). It is
        # unconfirmed whether the WS path still reports "No skills available";
        # re-check when these tests are unskipped.
        with client.websocket_connect("/api/v1/portal/ws") as ws:
            connected = ws.receive_json()
            assert connected["type"] == "connected"
            ws.send_json({"type": "chat", "message": "hello"})
            msg = ws.receive_json()
            assert msg["type"] == "error"
            assert "No skills available" in msg["data"]["message"]