diff --git a/src/agentkit/server/app.py b/src/agentkit/server/app.py index 65d4650..6ad6018 100644 --- a/src/agentkit/server/app.py +++ b/src/agentkit/server/app.py @@ -21,7 +21,7 @@ from agentkit.skills.base import Skill, SkillConfig from agentkit.skills.registry import SkillRegistry from agentkit.tools.registry import ToolRegistry from agentkit.server.config import ServerConfig -from agentkit.server.routes import agents, tasks, skills, llm, health, metrics, ws, evolution, memory +from agentkit.server.routes import agents, tasks, skills, llm, health, metrics, ws, evolution, memory, chat from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware from agentkit.server.task_store import create_task_store from agentkit.server.runner import BackgroundRunner @@ -301,6 +301,19 @@ def create_app( app.state.server_config = server_config app.state.api_key = effective_api_key + # Initialize session manager for Chat mode + from agentkit.session.manager import SessionManager + from agentkit.session.store import create_session_store + session_config = {} + if server_config and hasattr(server_config, "session") and server_config.session: + session_config = server_config.session + session_store = create_session_store( + backend=session_config.get("backend", "memory"), + redis_url=session_config.get("redis_url", "redis://localhost:6379/0"), + ttl_seconds=session_config.get("ttl_seconds", 86400), + ) + app.state.session_manager = SessionManager(store=session_store) + # Initialize evolution store if configured if server_config and hasattr(server_config, 'evolution') and server_config.evolution: try: @@ -426,5 +439,6 @@ def create_app( app.include_router(ws.router, prefix="/api/v1") app.include_router(evolution.router, prefix="/api/v1") app.include_router(memory.router, prefix="/api/v1") + app.include_router(chat.router, prefix="/api/v1") return app diff --git a/src/agentkit/server/config.py b/src/agentkit/server/config.py index be7b66a..449d644 100644 --- a/src/agentkit/server/config.py +++ b/src/agentkit/server/config.py @@ -106,6 +106,7 @@ class ServerConfig: mcp_servers: dict[str, MCPServerConfig] | None = None, telemetry: dict[str, Any] | None = None, compression: dict[str, Any] | None = None, + session: dict[str, Any] | None = None, on_change: Callable[["ServerConfig"], None] | None = None, ): self.host = host @@ -124,6 +125,7 @@ class ServerConfig: self.mcp_servers = mcp_servers or {} self.telemetry = telemetry or {} self.compression = compression or {} + self.session = session or {} self.on_change = on_change # Config watching state @@ -172,6 +174,9 @@ class ServerConfig: # Compression config compression_data = data.get("compression", {}) + # Session config + session_data = data.get("session", {}) + return cls( host=server.get("host", "0.0.0.0"), port=server.get("port", 8001), @@ -189,6 +194,7 @@ class ServerConfig: mcp_servers=mcp_servers, telemetry=telemetry_data, compression=compression_data, + session=session_data, ) @staticmethod @@ -380,6 +386,7 @@ class ServerConfig: self.mcp_servers = new_config.mcp_servers self.telemetry = new_config.telemetry self.compression = new_config.compression + self.session = new_config.session self._last_mtime = new_config._last_mtime logger.info(f"Config reloaded from {path}") diff --git a/src/agentkit/server/routes/chat.py b/src/agentkit/server/routes/chat.py new file mode 100644 index 0000000..e7a1ba1 --- /dev/null +++ b/src/agentkit/server/routes/chat.py @@ -0,0 +1,405 @@ +"""Chat API routes — multi-turn conversation with Agent via REST and WebSocket.""" + +from __future__ import annotations + +import asyncio +import json +import logging +from typing import Any + +from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, Request +from pydantic import BaseModel + +from agentkit.core.protocol import CancellationToken +from agentkit.core.react import ReActEngine +from agentkit.session.manager import SessionManager +from agentkit.session.models import MessageRole, SessionStatus + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/chat", tags=["chat"]) + + +# ── Request/Response schemas ────────────────────────────────────────── + + +class CreateSessionRequest(BaseModel): + agent_name: str + metadata: dict[str, Any] | None = None + + +class SendMessageRequest(BaseModel): + content: str + role: str = "user" + + +class SessionResponse(BaseModel): + session_id: str + agent_name: str + status: str + metadata: dict[str, Any] + created_at: str + updated_at: str + + +class MessageResponse(BaseModel): + message_id: str + session_id: str + role: str + content: str + tool_call_id: str | None = None + agent_name: str | None = None + created_at: str + + +# ── Chat WebSocket connection manager ───────────────────────────────── + + +class ChatConnectionManager: + """Track active WebSocket connections per session_id.""" + + def __init__(self) -> None: + # session_id -> list of (websocket, pending_replies) + self._connections: dict[str, list[tuple[WebSocket, dict[str, asyncio.Future]]]] = {} + + def add(self, session_id: str, ws: WebSocket, pending: dict[str, asyncio.Future]) -> None: + self._connections.setdefault(session_id, []).append((ws, pending)) + + def remove(self, session_id: str, ws: WebSocket) -> None: + conns = self._connections.get(session_id) + if conns is None: + return + self._connections[session_id] = [(w, p) for w, p in conns if w is not ws] + if not self._connections[session_id]: + del self._connections[session_id] + + def get_connections(self, session_id: str) -> list[tuple[WebSocket, dict[str, asyncio.Future]]]: + return self._connections.get(session_id, []) + + async def send_json(self, session_id: str, message: dict) -> None: + """Send a JSON message to all connections for a session.""" + conns = self._connections.get(session_id, []) + stale: list[WebSocket] = [] + for ws, _ in conns: + try: + await ws.send_json(message) + except Exception: + stale.append(ws) + for ws in stale: + self.remove(session_id, ws) + + +chat_manager = ChatConnectionManager() + + +# ── Helper ──────────────────────────────────────────────────────────── + + +def _get_session_manager(request: Request) -> SessionManager: + return request.app.state.session_manager + + +def _session_to_response(session) -> SessionResponse: + return SessionResponse( + session_id=session.session_id, + agent_name=session.agent_name, + status=session.status.value, + metadata=session.metadata, + created_at=session.created_at.isoformat(), + updated_at=session.updated_at.isoformat(), + ) + + +def _message_to_response(msg) -> MessageResponse: + return MessageResponse( + message_id=msg.message_id, + session_id=msg.session_id, + role=msg.role.value, + content=msg.content, + tool_call_id=msg.tool_call_id, + agent_name=msg.agent_name, + created_at=msg.created_at.isoformat(), + ) + + +# ── REST endpoints ──────────────────────────────────────────────────── + + +@router.post("/sessions", response_model=SessionResponse) +async def create_session(request: CreateSessionRequest, req: Request): + """Create a new chat session bound to an Agent.""" + sm = _get_session_manager(req) + session = await sm.create_session( + agent_name=request.agent_name, + metadata=request.metadata, + ) + return _session_to_response(session) + + +@router.get("/sessions/{session_id}", response_model=SessionResponse) +async def get_session(session_id: str, req: Request): + """Get session information.""" + sm = _get_session_manager(req) + session = await sm.get_session(session_id) + if session is None: + raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found") + return _session_to_response(session) + + +@router.get("/sessions/{session_id}/messages", response_model=list[MessageResponse]) +async def get_messages(session_id: str, limit: int | None = None, offset: int = 0, req: Request = None): + """Get conversation history for a session.""" + sm = _get_session_manager(req) + session = await sm.get_session(session_id) + if session is None: + raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found") + messages = await sm.get_messages(session_id, limit=limit, offset=offset) + return [_message_to_response(m) for m in messages] + + +@router.post("/sessions/{session_id}/messages", response_model=MessageResponse) +async def send_message(session_id: str, request: SendMessageRequest, req: Request): + """Send a message to the Agent (synchronous mode — waits for full reply).""" + sm = _get_session_manager(req) + session = await sm.get_session(session_id) + if session is None: + raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found") + if session.status == SessionStatus.CLOSED: + raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed") + + # Append user message + user_msg = await sm.append_message( + session_id=session_id, + role=MessageRole.USER, + content=request.content, + ) + + # Get full conversation history for the Agent + chat_messages = await sm.get_chat_messages(session_id) + + # Resolve the Agent + pool = req.app.state.agent_pool + agent = pool.get_agent(session.agent_name) + if agent is None: + raise HTTPException(status_code=404, detail=f"Agent '{session.agent_name}' not found") + + # Execute the Agent + try: + react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway) + tools = list(agent._tool_registry._tools.values()) if agent._tool_registry else [] + result = await react_engine.execute( + messages=chat_messages, + tools=tools, + model=agent._llm_model if hasattr(agent, "_llm_model") else "default", + agent_name=agent.name, + system_prompt=agent._system_prompt if hasattr(agent, "_system_prompt") else None, + ) + + # Append assistant reply + assistant_msg = await sm.append_message( + session_id=session_id, + role=MessageRole.ASSISTANT, + content=result.output if hasattr(result, "output") else str(result), + agent_name=agent.name, + ) + return _message_to_response(assistant_msg) + + except Exception as e: + logger.error(f"Chat execution error for session {session_id}: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete("/sessions/{session_id}") +async def close_session(session_id: str, req: Request): + """Close a chat session.""" + sm = _get_session_manager(req) + session = await sm.close_session(session_id) + if session is None: + raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found") + return {"status": "closed", "session_id": session_id} + + +# ── WebSocket endpoint ──────────────────────────────────────────────── + + +@router.websocket("/ws/{session_id}") +async def chat_websocket(websocket: WebSocket, session_id: str) -> None: + """WebSocket endpoint for real-time chat with streaming. + + Client → Server messages: + {"type": "message", "content": "..."} — Send a user message + {"type": "cancel"} — Cancel current execution + {"type": "ping"} — Heartbeat + + Server → Client messages: + {"type": "connected", "session_id": "..."} — Connection confirmed + {"type": "token", "content": "..."} — LLM token streaming + {"type": "step", "data": {...}} — ReAct step event + {"type": "ask_human", "question": "...", "request_id": "..."} — Agent asks user + {"type": "final_answer", "content": "..."} — Agent's final reply + {"type": "error", "data": {"message": "..."}} — Error occurred + {"type": "pong"} — Heartbeat response + """ + # Authentication + configured_api_key: str | None = None + if hasattr(websocket.app.state, "server_config") and websocket.app.state.server_config: + configured_api_key = websocket.app.state.server_config.api_key + if configured_api_key is None and hasattr(websocket.app.state, "api_key"): + configured_api_key = websocket.app.state.api_key + + if configured_api_key: + provided = websocket.query_params.get("api_key") + if provided != configured_api_key: + await websocket.accept() + await websocket.send_json({"type": "error", "data": {"message": "Invalid api_key"}}) + await websocket.close(code=4001, reason="Invalid api_key") + return + + await websocket.accept() + + # Validate session + sm: SessionManager = websocket.app.state.session_manager + session = await sm.get_session(session_id) + if session is None: + await websocket.send_json({"type": "error", "data": {"message": f"Session '{session_id}' not found"}}) + await websocket.close(code=1000, reason="Session not found") + return + + if session.status == SessionStatus.CLOSED: + await websocket.send_json({"type": "error", "data": {"message": "Session is closed"}}) + await websocket.close(code=1000, reason="Session closed") + return + + # Track pending replies for AskHumanTool + pending_replies: dict[str, asyncio.Future] = {} + chat_manager.add(session_id, websocket, pending_replies) + + cancellation_token = CancellationToken() + + try: + await websocket.send_json({"type": "connected", "session_id": session_id}) + + # Listen for client messages + while True: + try: + raw = await asyncio.wait_for(websocket.receive_text(), timeout=300.0) + except asyncio.TimeoutError: + await websocket.send_json({"type": "pong"}) + continue + + try: + msg = json.loads(raw) + except json.JSONDecodeError: + continue + + msg_type = msg.get("type") + + if msg_type == "message": + content = msg.get("content", "") + await _handle_chat_message( + websocket, session_id, content, sm, cancellation_token, pending_replies + ) + + elif msg_type == "reply": + # Reply to AskHumanTool + request_id = msg.get("request_id") + reply_content = msg.get("content", "") + if request_id and request_id in pending_replies: + pending_replies[request_id].set_result(reply_content) + + elif msg_type == "cancel": + cancellation_token.cancel() + await websocket.send_json({"type": "result", "data": {"status": "cancelled"}}) + + elif msg_type == "ping": + await websocket.send_json({"type": "pong"}) + + except WebSocketDisconnect: + logger.debug(f"Chat WebSocket disconnected for session {session_id}") + except Exception as e: + logger.error(f"Chat WebSocket error for session {session_id}: {e}") + try: + await websocket.send_json({"type": "error", "data": {"message": str(e)}}) + except Exception: + pass + finally: + # Clean up pending futures + for fut in pending_replies.values(): + if not fut.done(): + fut.cancel() + chat_manager.remove(session_id, websocket) + + +async def _handle_chat_message( + websocket: WebSocket, + session_id: str, + content: str, + sm: SessionManager, + cancellation_token: CancellationToken, + pending_replies: dict[str, asyncio.Future], +) -> None: + """Handle a user message: append to session, execute Agent, stream events.""" + # Append user message + await sm.append_message(session_id=session_id, role=MessageRole.USER, content=content) + + # Get full conversation history + chat_messages = await sm.get_chat_messages(session_id) + + # Resolve Agent + pool = websocket.app.state.agent_pool + session = await sm.get_session(session_id) + if session is None: + await websocket.send_json({"type": "error", "data": {"message": "Session lost"}}) + return + + agent = pool.get_agent(session.agent_name) + if agent is None: + await websocket.send_json({"type": "error", "data": {"message": f"Agent '{session.agent_name}' not found"}}) + return + + # Execute Agent with streaming + react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway) + tools = list(agent._tool_registry._tools.values()) if agent._tool_registry else [] + + try: + final_content = "" + async for event in react_engine.execute_stream( + messages=chat_messages, + tools=tools, + model=agent._llm_model if hasattr(agent, "_llm_model") else "default", + agent_name=agent.name, + system_prompt=agent._system_prompt if hasattr(agent, "_system_prompt") else None, + cancellation_token=cancellation_token, + ): + if event.event_type == "final_answer": + final_content = event.data.get("output", "") + await websocket.send_json({ + "type": "final_answer", + "content": final_content, + }) + elif event.event_type == "token": + await websocket.send_json({ + "type": "token", + "content": event.data.get("content", ""), + }) + else: + await websocket.send_json({ + "type": "step", + "data": { + "event_type": event.event_type, + "step": event.step, + "data": event.data, + }, + }) + + # Append assistant reply to session + if final_content: + await sm.append_message( + session_id=session_id, + role=MessageRole.ASSISTANT, + content=final_content, + agent_name=agent.name, + ) + + except Exception as e: + await websocket.send_json({"type": "error", "data": {"message": str(e)}}) diff --git a/tests/unit/test_chat_routes.py b/tests/unit/test_chat_routes.py new file mode 100644 index 0000000..fbcae8c --- /dev/null +++ b/tests/unit/test_chat_routes.py @@ -0,0 +1,98 @@ +"""Tests for Chat API routes.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi.testclient import TestClient + +from agentkit.session.manager import SessionManager +from agentkit.session.store import InMemorySessionStore +from agentkit.session.models import SessionStatus + + +@pytest.fixture +def app_with_chat(): + """Create a FastAPI app with Chat routes and mocked dependencies.""" + from fastapi import FastAPI + from agentkit.server.routes.chat import router + + app = FastAPI() + app.include_router(router, prefix="/api/v1") + + # Mock app.state dependencies + app.state.session_manager = SessionManager(store=InMemorySessionStore()) + app.state.llm_gateway = MagicMock() + app.state.agent_pool = MagicMock() + app.state.server_config = MagicMock() + app.state.server_config.api_key = None + + return app + + +@pytest.fixture +def client(app_with_chat): + return TestClient(app_with_chat) + + +class TestChatSessionCRUD: + def test_create_session(self, client): + resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"}) + assert resp.status_code == 200 + data = resp.json() + assert data["agent_name"] == "test-agent" + assert data["status"] == "active" + assert "session_id" in data + + def test_get_session(self, client): + create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"}) + session_id = create_resp.json()["session_id"] + + get_resp = client.get(f"/api/v1/chat/sessions/{session_id}") + assert get_resp.status_code == 200 + assert get_resp.json()["session_id"] == session_id + + def test_get_nonexistent_session(self, client): + resp = client.get("/api/v1/chat/sessions/nonexistent") + assert resp.status_code == 404 + + def test_close_session(self, client): + create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"}) + session_id = create_resp.json()["session_id"] + + close_resp = client.delete(f"/api/v1/chat/sessions/{session_id}") + assert close_resp.status_code == 200 + assert close_resp.json()["status"] == "closed" + + def test_close_nonexistent_session(self, client): + resp = client.delete("/api/v1/chat/sessions/nonexistent") + assert resp.status_code == 404 + + def test_get_messages_empty(self, client): + create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"}) + session_id = create_resp.json()["session_id"] + + msgs_resp = client.get(f"/api/v1/chat/sessions/{session_id}/messages") + assert msgs_resp.status_code == 200 + assert msgs_resp.json() == [] + + def test_get_messages_nonexistent_session(self, client): + resp = client.get("/api/v1/chat/sessions/nonexistent/messages") + assert resp.status_code == 404 + + def test_send_message_closed_session(self, client): + create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"}) + session_id = create_resp.json()["session_id"] + + client.delete(f"/api/v1/chat/sessions/{session_id}") + + msg_resp = client.post( + f"/api/v1/chat/sessions/{session_id}/messages", + json={"content": "Hello"}, + ) + assert msg_resp.status_code == 400 + + def test_send_message_nonexistent_session(self, client): + resp = client.post( + "/api/v1/chat/sessions/nonexistent/messages", + json={"content": "Hello"}, + ) + assert resp.status_code == 404