443 lines
16 KiB
Python
443 lines
16 KiB
Python
"""Tests for Portal API routes"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import AsyncMock
|
|
|
|
# Note: AGENTKIT_WS_TIMEOUT=0 is set in tests/conftest.py (before portal import)
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
from agentkit.llm.gateway import LLMGateway
|
|
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
|
from agentkit.server.app import create_app
|
|
from agentkit.chat.sqlite_conversation_store import SqliteConversationStore
|
|
from agentkit.server.routes.portal import CAPABILITY_CATEGORIES
|
|
from agentkit.skills.base import Skill, SkillConfig
|
|
from agentkit.skills.registry import SkillRegistry
|
|
from agentkit.tools.registry import ToolRegistry
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
def mock_llm_gateway():
    """Provide an LLMGateway with a single AsyncMock provider registered as "test".

    Awaiting the provider's ``chat`` returns one canned LLMResponse whose JSON
    content names ``chat_skill`` with confidence 0.9.
    """
    llm_gateway = LLMGateway()

    canned_usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
    canned_response = LLMResponse(
        content='{"skill": "chat_skill", "confidence": 0.9}',
        model="test-model",
        usage=canned_usage,
    )

    fake_provider = AsyncMock()
    fake_provider.chat.return_value = canned_response

    llm_gateway.register_provider("test", fake_provider)
    return llm_gateway
|
|
|
|
|
|
@pytest.fixture
def skill_registry():
    """Provide a newly constructed SkillRegistry for each test."""
    registry = SkillRegistry()
    return registry
|
|
|
|
|
|
@pytest.fixture
def tool_registry():
    """Provide a newly constructed ToolRegistry for each test."""
    registry = ToolRegistry()
    return registry
|
|
|
|
|
|
@pytest.fixture
def app(mock_llm_gateway, skill_registry, tool_registry):
    """Build the server application wired to the mocked gateway and test registries."""
    wiring = {
        "llm_gateway": mock_llm_gateway,
        "skill_registry": skill_registry,
        "tool_registry": tool_registry,
    }
    return create_app(**wiring)
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Wrap the ``app`` fixture in a synchronous FastAPI TestClient."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
def _register_skill(registry: SkillRegistry, name: str = "chat_skill", **kwargs):
    """Build a Skill with sensible defaults, register it, and return it.

    Args:
        registry: The SkillRegistry the new skill is registered in.
        name: The skill's ``SkillConfig.name``.
        **kwargs: Extra SkillConfig fields (e.g. ``capabilities``). They may
            also override any default below (``agent_type``, ``task_mode``,
            ``prompt``, ``intent``).

    Returns:
        The Skill instance that was passed to ``registry.register``.
    """
    config_fields = {
        "agent_type": "chat_type",
        "task_mode": "llm_generate",
        "prompt": {"identity": "Chat Skill", "instructions": "Handle chat"},
        "intent": {"keywords": ["chat", "hello"], "description": "A chat skill"},
    }
    # Merge rather than pass both, so overriding a default does not raise
    # "got multiple values for keyword argument".
    config_fields.update(kwargs)

    config = SkillConfig(name=name, **config_fields)
    skill = Skill(config=config)
    registry.register(skill)
    return skill
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ConversationStore unit tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConversationStore:
    """Tests for SqliteConversationStore (async API, SQLite file under tmp_path)."""

    @pytest.fixture
    def store(self, tmp_path):
        """Create a store backed by a fresh SQLite file in the per-test tmp dir."""
        return SqliteConversationStore(db_path=str(tmp_path / "test.db"))

    @pytest.mark.asyncio
    async def test_get_or_create_new(self, store):
        conv = await store.get_or_create()
        assert conv.id is not None
        assert conv.messages == []

    @pytest.mark.asyncio
    async def test_get_or_create_with_id(self, store):
        conv = await store.get_or_create("test-id-123")
        assert conv.id == "test-id-123"

    @pytest.mark.asyncio
    async def test_get_or_create_reuse(self, store):
        await store.get_or_create("reuse-id")
        await store.add_message("reuse-id", "user", "hello")
        conv2 = await store.get_or_create("reuse-id")
        assert conv2.id == "reuse-id"
        assert len(conv2.messages) == 1

    @pytest.mark.asyncio
    async def test_add_message(self, store):
        conv = await store.get_or_create("msg-id")
        msg = await store.add_message("msg-id", "user", "hello")
        assert msg.role == "user"
        assert msg.content == "hello"
        # The conversation object returned earlier reflects the new message...
        assert len(conv.messages) == 1
        # ...and so does a fresh read from the store. Without this, the test
        # would pass on in-memory aliasing of the cached object alone.
        history = await store.get_history("msg-id")
        assert [(m.role, m.content) for m in history] == [("user", "hello")]

    @pytest.mark.asyncio
    async def test_add_message_not_found(self, store):
        with pytest.raises(KeyError):
            await store.add_message("nonexistent", "user", "hello")

    @pytest.mark.asyncio
    async def test_get_history(self, store):
        await store.get_or_create("hist-id")
        await store.add_message("hist-id", "user", "msg1")
        await store.add_message("hist-id", "assistant", "msg2")
        history = await store.get_history("hist-id")
        assert len(history) == 2
        assert history[0].role == "user"
        assert history[1].role == "assistant"

    @pytest.mark.asyncio
    async def test_get_history_limit(self, store):
        await store.get_or_create("limit-id")
        for i in range(10):
            await store.add_message("limit-id", "user", f"msg{i}")
        history = await store.get_history("limit-id", limit=3)
        assert len(history) == 3
        assert history[0].content == "msg7"
        # Pin the whole window: the most recent 3 messages, oldest first.
        assert [m.content for m in history] == ["msg7", "msg8", "msg9"]

    @pytest.mark.asyncio
    async def test_get_history_nonexistent(self, store):
        history = await store.get_history("no-such-id")
        assert history == []

    @pytest.mark.asyncio
    async def test_list_conversations(self, store):
        await store.get_or_create("conv-a")
        await store.get_or_create("conv-b")
        convs = await store.list_conversations()
        assert len(convs) == 2

    @pytest.mark.asyncio
    async def test_list_conversations_limit(self, store):
        for i in range(5):
            await store.get_or_create(f"conv-{i}")
        convs = await store.list_conversations(limit=2)
        assert len(convs) == 2

    @pytest.mark.asyncio
    async def test_max_conversations_eviction(self, tmp_path):
        store = SqliteConversationStore(
            db_path=str(tmp_path / "evict.db"), max_conversations=3
        )
        for i in range(5):
            await store.get_or_create(f"evict-{i}")
        # NOTE(review): inspects the private ``_cache``; this is coupled to the
        # store's internals.
        assert len(store._cache) <= 3
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# POST /portal/chat
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPortalChat:
    """Tests for POST /api/v1/portal/chat."""

    def test_chat_with_skill_name(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")

        response = client.post(
            "/api/v1/portal/chat",
            json={"message": "hello", "skill_name": "chat_skill"},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["conversation_id"] is not None
        assert data["matched_skill"] == "chat_skill"
        assert data["routing_method"] == "skill_prefix"
        assert data["confidence"] == 1.0
        assert data["status"] == "completed"

    def test_chat_with_intent_routing(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")

        response = client.post(
            "/api/v1/portal/chat",
            json={"message": "hello chat"},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["matched_skill"] is not None
        assert data["routing_method"] is not None
        assert data["conversation_id"] is not None

    def test_chat_no_skills_available(self, client):
        """Greeting fast-path works even without skills (DIRECT_CHAT mode)."""
        response = client.post(
            "/api/v1/portal/chat",
            json={"message": "hello"},
        )
        # Greeting regex fast-path: no skill needed, returns 200
        assert response.status_code == 200

    def test_chat_skill_not_found(self, client):
        response = client.post(
            "/api/v1/portal/chat",
            json={"message": "hello", "skill_name": "nonexistent_skill"},
        )
        assert response.status_code == 404

    def test_chat_with_conversation_id(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")

        response1 = client.post(
            "/api/v1/portal/chat",
            json={"message": "hello", "skill_name": "chat_skill"},
        )
        # Fail here with a clear status mismatch instead of a KeyError below.
        assert response1.status_code == 200
        conv_id = response1.json()["conversation_id"]

        response2 = client.post(
            "/api/v1/portal/chat",
            json={"message": "follow up", "skill_name": "chat_skill", "conversation_id": conv_id},
        )
        assert response2.status_code == 200
        assert response2.json()["conversation_id"] == conv_id

    def test_chat_with_sources(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")

        response = client.post(
            "/api/v1/portal/chat",
            json={
                "message": "search docs",
                "skill_name": "chat_skill",
                "sources": ["wiki", "docs"],
            },
        )
        assert response.status_code == 200
        # Supplying sources must not disturb explicit skill selection.
        assert response.json()["matched_skill"] == "chat_skill"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /portal/capabilities
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPortalCapabilities:
    """Tests for GET /api/v1/portal/capabilities."""

    def test_capabilities_returns_list(self, client):
        response = client.get("/api/v1/portal/capabilities")
        assert response.status_code == 200
        data = response.json()
        assert "capabilities" in data
        caps = data["capabilities"]
        assert len(caps) == len(CAPABILITY_CATEGORIES)

    def test_capabilities_structure(self, client):
        required_keys = {
            "name",
            "display_name",
            "description",
            "icon",
            "enabled",
            "skill_count",
        }
        response = client.get("/api/v1/portal/capabilities")
        assert response.status_code == 200
        caps = response.json()["capabilities"]
        for cap in caps:
            missing = required_keys - cap.keys()
            assert not missing, f"capability {cap.get('name')!r} missing {sorted(missing)}"

    def test_capabilities_with_skills(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill", capabilities=["chat"])

        response = client.get("/api/v1/portal/capabilities")
        assert response.status_code == 200
        caps = response.json()["capabilities"]
        # Default of None gives a readable assertion failure instead of a bare
        # StopIteration when the category is absent.
        chat_cap = next((c for c in caps if c["name"] == "chat"), None)
        assert chat_cap is not None, "expected a 'chat' capability category"
        assert chat_cap["skill_count"] >= 1

    def test_capabilities_skills_category_always_increments(self, client, skill_registry):
        _register_skill(skill_registry, "some_skill")

        response = client.get("/api/v1/portal/capabilities")
        assert response.status_code == 200
        caps = response.json()["capabilities"]
        skills_cap = next((c for c in caps if c["name"] == "skills"), None)
        assert skills_cap is not None, "expected a 'skills' capability category"
        assert skills_cap["skill_count"] >= 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /portal/conversations
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPortalConversations:
    """Tests for GET /api/v1/portal/conversations."""

    def test_list_conversations_empty(self, client):
        response = client.get("/api/v1/portal/conversations")
        assert response.status_code == 200
        assert isinstance(response.json(), list)

    def test_list_conversations_after_chat(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")
        chat_resp = client.post(
            "/api/v1/portal/chat",
            json={"message": "hello", "skill_name": "chat_skill"},
        )
        # Without this, a failed chat surfaces later as a confusing length error.
        assert chat_resp.status_code == 200

        response = client.get("/api/v1/portal/conversations")
        assert response.status_code == 200
        data = response.json()
        assert len(data) >= 1
        assert "id" in data[0]
        assert "message_count" in data[0]

    def test_list_conversations_limit(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")
        for i in range(3):
            chat_resp = client.post(
                "/api/v1/portal/chat",
                json={"message": f"msg{i}", "skill_name": "chat_skill"},
            )
            assert chat_resp.status_code == 200

        response = client.get("/api/v1/portal/conversations?limit=2")
        assert response.status_code == 200
        # At least three conversations exist now, so the limit must be hit
        # exactly. "<= 2" would also pass on an empty list.
        assert len(response.json()) == 2
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GET /portal/conversations/{id}
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPortalConversationHistory:
    """Tests for GET /api/v1/portal/conversations/{id}."""

    def test_get_conversation_history(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")
        chat_resp = client.post(
            "/api/v1/portal/chat",
            json={"message": "hello", "skill_name": "chat_skill"},
        )
        assert chat_resp.status_code == 200
        conv_id = chat_resp.json()["conversation_id"]

        response = client.get(f"/api/v1/portal/conversations/{conv_id}")
        assert response.status_code == 200
        data = response.json()
        # Response is now an IConversation object, not a bare array
        assert "id" in data
        assert "messages" in data
        assert len(data["messages"]) >= 1
        first = data["messages"][0]
        assert first["role"] in ("user", "assistant")
        assert "content" in first
        assert "timestamp" in first

    def test_get_conversation_not_found(self, client):
        response = client.get("/api/v1/portal/conversations/nonexistent-id")
        assert response.status_code == 404

    def test_get_conversation_history_limit(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")
        chat_resp = client.post(
            "/api/v1/portal/chat",
            json={"message": "hello", "skill_name": "chat_skill"},
        )
        assert chat_resp.status_code == 200
        conv_id = chat_resp.json()["conversation_id"]

        response = client.get(f"/api/v1/portal/conversations/{conv_id}?limit=1")
        assert response.status_code == 200
        data = response.json()
        # Response is now an IConversation object. The chat above records at
        # least one message, so limit=1 must return exactly one; "<= 1" would
        # also pass on an empty list.
        assert len(data["messages"]) == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# WebSocket /portal/ws
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPortalWebSocket:
    """Tests for the /api/v1/portal/ws WebSocket endpoint."""

    # NOTE: Starlette TestClient's sync WS client does not properly trigger
    # server-side disconnect when the `with` block exits, causing the server's
    # `receive_text()` to hang indefinitely. These tests are skipped until
    # we migrate to async WS testing (e.g., httpx-async or pytest-asyncio).

    # Upper bound on frames read while waiting for a terminal message. This
    # guards against an endless stream of non-terminal frames. It cannot
    # unblock a single receive_json() call if the server goes silent.
    _MAX_WS_MESSAGES = 100

    @pytest.mark.skip(reason="Starlette TestClient WS hangs on disconnect")
    def test_ws_connect(self, client):
        with client.websocket_connect("/api/v1/portal/ws") as ws:
            data = ws.receive_json()
            assert data["type"] == "connected"
            assert "conversation_id" in data

    @pytest.mark.skip(reason="Starlette TestClient WS hangs on disconnect")
    def test_ws_chat_flow(self, client, skill_registry):
        _register_skill(skill_registry, "chat_skill")

        with client.websocket_connect("/api/v1/portal/ws") as ws:
            # Receive connected message
            connected = ws.receive_json()
            assert connected["type"] == "connected"

            # Send chat message
            ws.send_json({"type": "chat", "message": "hello chat"})

            # Should receive routing info
            routing = ws.receive_json()
            assert routing["type"] == "routing"
            assert "skill" in routing

            # Then read step/result frames until a terminal result or error,
            # bounded so a misbehaving server cannot loop this test forever.
            messages = []
            for _ in range(self._MAX_WS_MESSAGES):
                msg = ws.receive_json()
                messages.append(msg)
                if msg["type"] in ("result", "error"):
                    break
            else:
                pytest.fail(
                    f"no 'result' or 'error' frame within {self._MAX_WS_MESSAGES} messages"
                )

            # At least one message should have been received
            assert len(messages) >= 1

    @pytest.mark.skip(reason="Starlette TestClient WS hangs on disconnect")
    def test_ws_cancel(self, client):
        with client.websocket_connect("/api/v1/portal/ws") as ws:
            connected = ws.receive_json()
            assert connected["type"] == "connected"

            ws.send_json({"type": "cancel"})
            result = ws.receive_json()
            assert result["type"] == "result"
            assert result["data"]["status"] == "cancelled"

    @pytest.mark.skip(reason="Starlette TestClient WS hangs on disconnect")
    def test_ws_no_skills_error(self, client):
        # NOTE(review): the HTTP test_chat_no_skills_available expects "hello"
        # to hit the greeting fast-path and succeed without skills. Confirm
        # whether the WS path really errors here before un-skipping.
        with client.websocket_connect("/api/v1/portal/ws") as ws:
            connected = ws.receive_json()
            assert connected["type"] == "connected"

            ws.send_json({"type": "chat", "message": "hello"})
            msg = ws.receive_json()
            assert msg["type"] == "error"
            assert "No skills available" in msg["data"]["message"]
|