feat(chat): add Chat API routes with REST + WebSocket bidirectional communication

This commit is contained in:
chiguyong 2026-06-07 22:49:26 +08:00
parent 493187782c
commit 6013d5189b
4 changed files with 525 additions and 1 deletions

View File

@ -21,7 +21,7 @@ from agentkit.skills.base import Skill, SkillConfig
from agentkit.skills.registry import SkillRegistry
from agentkit.tools.registry import ToolRegistry
from agentkit.server.config import ServerConfig
from agentkit.server.routes import agents, tasks, skills, llm, health, metrics, ws, evolution, memory
from agentkit.server.routes import agents, tasks, skills, llm, health, metrics, ws, evolution, memory, chat
from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware
from agentkit.server.task_store import create_task_store
from agentkit.server.runner import BackgroundRunner
@ -301,6 +301,19 @@ def create_app(
app.state.server_config = server_config
app.state.api_key = effective_api_key
# Initialize session manager for Chat mode
from agentkit.session.manager import SessionManager
from agentkit.session.store import create_session_store
session_config = {}
if server_config and hasattr(server_config, "session") and server_config.session:
session_config = server_config.session
session_store = create_session_store(
backend=session_config.get("backend", "memory"),
redis_url=session_config.get("redis_url", "redis://localhost:6379/0"),
ttl_seconds=session_config.get("ttl_seconds", 86400),
)
app.state.session_manager = SessionManager(store=session_store)
# Initialize evolution store if configured
if server_config and hasattr(server_config, 'evolution') and server_config.evolution:
try:
@ -426,5 +439,6 @@ def create_app(
app.include_router(ws.router, prefix="/api/v1")
app.include_router(evolution.router, prefix="/api/v1")
app.include_router(memory.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
return app

View File

@ -106,6 +106,7 @@ class ServerConfig:
mcp_servers: dict[str, MCPServerConfig] | None = None,
telemetry: dict[str, Any] | None = None,
compression: dict[str, Any] | None = None,
session: dict[str, Any] | None = None,
on_change: Callable[["ServerConfig"], None] | None = None,
):
self.host = host
@ -124,6 +125,7 @@ class ServerConfig:
self.mcp_servers = mcp_servers or {}
self.telemetry = telemetry or {}
self.compression = compression or {}
self.session = session or {}
self.on_change = on_change
# Config watching state
@ -172,6 +174,9 @@ class ServerConfig:
# Compression config
compression_data = data.get("compression", {})
# Session config
session_data = data.get("session", {})
return cls(
host=server.get("host", "0.0.0.0"),
port=server.get("port", 8001),
@ -189,6 +194,7 @@ class ServerConfig:
mcp_servers=mcp_servers,
telemetry=telemetry_data,
compression=compression_data,
session=session_data,
)
@staticmethod
@ -380,6 +386,7 @@ class ServerConfig:
self.mcp_servers = new_config.mcp_servers
self.telemetry = new_config.telemetry
self.compression = new_config.compression
self.session = new_config.session
self._last_mtime = new_config._last_mtime
logger.info(f"Config reloaded from {path}")

View File

@ -0,0 +1,405 @@
"""Chat API routes — multi-turn conversation with Agent via REST and WebSocket."""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, Request
from pydantic import BaseModel
from agentkit.core.protocol import CancellationToken
from agentkit.core.react import ReActEngine
from agentkit.session.manager import SessionManager
from agentkit.session.models import MessageRole, SessionStatus
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/chat", tags=["chat"])
# ── Request/Response schemas ──────────────────────────────────────────
class CreateSessionRequest(BaseModel):
agent_name: str
metadata: dict[str, Any] | None = None
class SendMessageRequest(BaseModel):
content: str
role: str = "user"
class SessionResponse(BaseModel):
session_id: str
agent_name: str
status: str
metadata: dict[str, Any]
created_at: str
updated_at: str
class MessageResponse(BaseModel):
message_id: str
session_id: str
role: str
content: str
tool_call_id: str | None = None
agent_name: str | None = None
created_at: str
# ── Chat WebSocket connection manager ─────────────────────────────────
class ChatConnectionManager:
"""Track active WebSocket connections per session_id."""
def __init__(self) -> None:
# session_id -> list of (websocket, pending_replies)
self._connections: dict[str, list[tuple[WebSocket, dict[str, asyncio.Future]]]] = {}
def add(self, session_id: str, ws: WebSocket, pending: dict[str, asyncio.Future]) -> None:
self._connections.setdefault(session_id, []).append((ws, pending))
def remove(self, session_id: str, ws: WebSocket) -> None:
conns = self._connections.get(session_id)
if conns is None:
return
self._connections[session_id] = [(w, p) for w, p in conns if w is not ws]
if not self._connections[session_id]:
del self._connections[session_id]
def get_connections(self, session_id: str) -> list[tuple[WebSocket, dict[str, asyncio.Future]]]:
return self._connections.get(session_id, [])
async def send_json(self, session_id: str, message: dict) -> None:
"""Send a JSON message to all connections for a session."""
conns = self._connections.get(session_id, [])
stale: list[WebSocket] = []
for ws, _ in conns:
try:
await ws.send_json(message)
except Exception:
stale.append(ws)
for ws in stale:
self.remove(session_id, ws)
chat_manager = ChatConnectionManager()
# ── Helper ────────────────────────────────────────────────────────────
def _get_session_manager(request: Request) -> SessionManager:
return request.app.state.session_manager
def _session_to_response(session) -> SessionResponse:
return SessionResponse(
session_id=session.session_id,
agent_name=session.agent_name,
status=session.status.value,
metadata=session.metadata,
created_at=session.created_at.isoformat(),
updated_at=session.updated_at.isoformat(),
)
def _message_to_response(msg) -> MessageResponse:
return MessageResponse(
message_id=msg.message_id,
session_id=msg.session_id,
role=msg.role.value,
content=msg.content,
tool_call_id=msg.tool_call_id,
agent_name=msg.agent_name,
created_at=msg.created_at.isoformat(),
)
# ── REST endpoints ────────────────────────────────────────────────────
@router.post("/sessions", response_model=SessionResponse)
async def create_session(request: CreateSessionRequest, req: Request):
"""Create a new chat session bound to an Agent."""
sm = _get_session_manager(req)
session = await sm.create_session(
agent_name=request.agent_name,
metadata=request.metadata,
)
return _session_to_response(session)
@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str, req: Request):
"""Get session information."""
sm = _get_session_manager(req)
session = await sm.get_session(session_id)
if session is None:
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
return _session_to_response(session)
@router.get("/sessions/{session_id}/messages", response_model=list[MessageResponse])
async def get_messages(session_id: str, limit: int | None = None, offset: int = 0, req: Request = None):
"""Get conversation history for a session."""
sm = _get_session_manager(req)
session = await sm.get_session(session_id)
if session is None:
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
messages = await sm.get_messages(session_id, limit=limit, offset=offset)
return [_message_to_response(m) for m in messages]
@router.post("/sessions/{session_id}/messages", response_model=MessageResponse)
async def send_message(session_id: str, request: SendMessageRequest, req: Request):
"""Send a message to the Agent (synchronous mode — waits for full reply)."""
sm = _get_session_manager(req)
session = await sm.get_session(session_id)
if session is None:
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
if session.status == SessionStatus.CLOSED:
raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")
# Append user message
user_msg = await sm.append_message(
session_id=session_id,
role=MessageRole.USER,
content=request.content,
)
# Get full conversation history for the Agent
chat_messages = await sm.get_chat_messages(session_id)
# Resolve the Agent
pool = req.app.state.agent_pool
agent = pool.get_agent(session.agent_name)
if agent is None:
raise HTTPException(status_code=404, detail=f"Agent '{session.agent_name}' not found")
# Execute the Agent
try:
react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway)
tools = list(agent._tool_registry._tools.values()) if agent._tool_registry else []
result = await react_engine.execute(
messages=chat_messages,
tools=tools,
model=agent._llm_model if hasattr(agent, "_llm_model") else "default",
agent_name=agent.name,
system_prompt=agent._system_prompt if hasattr(agent, "_system_prompt") else None,
)
# Append assistant reply
assistant_msg = await sm.append_message(
session_id=session_id,
role=MessageRole.ASSISTANT,
content=result.output if hasattr(result, "output") else str(result),
agent_name=agent.name,
)
return _message_to_response(assistant_msg)
except Exception as e:
logger.error(f"Chat execution error for session {session_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/sessions/{session_id}")
async def close_session(session_id: str, req: Request):
"""Close a chat session."""
sm = _get_session_manager(req)
session = await sm.close_session(session_id)
if session is None:
raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
return {"status": "closed", "session_id": session_id}
# ── WebSocket endpoint ────────────────────────────────────────────────
@router.websocket("/ws/{session_id}")
async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
"""WebSocket endpoint for real-time chat with streaming.
Client Server messages:
{"type": "message", "content": "..."} Send a user message
{"type": "cancel"} Cancel current execution
{"type": "ping"} Heartbeat
Server Client messages:
{"type": "connected", "session_id": "..."} Connection confirmed
{"type": "token", "content": "..."} LLM token streaming
{"type": "step", "data": {...}} ReAct step event
{"type": "ask_human", "question": "...", "request_id": "..."} Agent asks user
{"type": "final_answer", "content": "..."} Agent's final reply
{"type": "error", "data": {"message": "..."}} Error occurred
{"type": "pong"} Heartbeat response
"""
# Authentication
configured_api_key: str | None = None
if hasattr(websocket.app.state, "server_config") and websocket.app.state.server_config:
configured_api_key = websocket.app.state.server_config.api_key
if configured_api_key is None and hasattr(websocket.app.state, "api_key"):
configured_api_key = websocket.app.state.api_key
if configured_api_key:
provided = websocket.query_params.get("api_key")
if provided != configured_api_key:
await websocket.accept()
await websocket.send_json({"type": "error", "data": {"message": "Invalid api_key"}})
await websocket.close(code=4001, reason="Invalid api_key")
return
await websocket.accept()
# Validate session
sm: SessionManager = websocket.app.state.session_manager
session = await sm.get_session(session_id)
if session is None:
await websocket.send_json({"type": "error", "data": {"message": f"Session '{session_id}' not found"}})
await websocket.close(code=1000, reason="Session not found")
return
if session.status == SessionStatus.CLOSED:
await websocket.send_json({"type": "error", "data": {"message": "Session is closed"}})
await websocket.close(code=1000, reason="Session closed")
return
# Track pending replies for AskHumanTool
pending_replies: dict[str, asyncio.Future] = {}
chat_manager.add(session_id, websocket, pending_replies)
cancellation_token = CancellationToken()
try:
await websocket.send_json({"type": "connected", "session_id": session_id})
# Listen for client messages
while True:
try:
raw = await asyncio.wait_for(websocket.receive_text(), timeout=300.0)
except asyncio.TimeoutError:
await websocket.send_json({"type": "pong"})
continue
try:
msg = json.loads(raw)
except json.JSONDecodeError:
continue
msg_type = msg.get("type")
if msg_type == "message":
content = msg.get("content", "")
await _handle_chat_message(
websocket, session_id, content, sm, cancellation_token, pending_replies
)
elif msg_type == "reply":
# Reply to AskHumanTool
request_id = msg.get("request_id")
reply_content = msg.get("content", "")
if request_id and request_id in pending_replies:
pending_replies[request_id].set_result(reply_content)
elif msg_type == "cancel":
cancellation_token.cancel()
await websocket.send_json({"type": "result", "data": {"status": "cancelled"}})
elif msg_type == "ping":
await websocket.send_json({"type": "pong"})
except WebSocketDisconnect:
logger.debug(f"Chat WebSocket disconnected for session {session_id}")
except Exception as e:
logger.error(f"Chat WebSocket error for session {session_id}: {e}")
try:
await websocket.send_json({"type": "error", "data": {"message": str(e)}})
except Exception:
pass
finally:
# Clean up pending futures
for fut in pending_replies.values():
if not fut.done():
fut.cancel()
chat_manager.remove(session_id, websocket)
async def _handle_chat_message(
websocket: WebSocket,
session_id: str,
content: str,
sm: SessionManager,
cancellation_token: CancellationToken,
pending_replies: dict[str, asyncio.Future],
) -> None:
"""Handle a user message: append to session, execute Agent, stream events."""
# Append user message
await sm.append_message(session_id=session_id, role=MessageRole.USER, content=content)
# Get full conversation history
chat_messages = await sm.get_chat_messages(session_id)
# Resolve Agent
pool = websocket.app.state.agent_pool
session = await sm.get_session(session_id)
if session is None:
await websocket.send_json({"type": "error", "data": {"message": "Session lost"}})
return
agent = pool.get_agent(session.agent_name)
if agent is None:
await websocket.send_json({"type": "error", "data": {"message": f"Agent '{session.agent_name}' not found"}})
return
# Execute Agent with streaming
react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
tools = list(agent._tool_registry._tools.values()) if agent._tool_registry else []
try:
final_content = ""
async for event in react_engine.execute_stream(
messages=chat_messages,
tools=tools,
model=agent._llm_model if hasattr(agent, "_llm_model") else "default",
agent_name=agent.name,
system_prompt=agent._system_prompt if hasattr(agent, "_system_prompt") else None,
cancellation_token=cancellation_token,
):
if event.event_type == "final_answer":
final_content = event.data.get("output", "")
await websocket.send_json({
"type": "final_answer",
"content": final_content,
})
elif event.event_type == "token":
await websocket.send_json({
"type": "token",
"content": event.data.get("content", ""),
})
else:
await websocket.send_json({
"type": "step",
"data": {
"event_type": event.event_type,
"step": event.step,
"data": event.data,
},
})
# Append assistant reply to session
if final_content:
await sm.append_message(
session_id=session_id,
role=MessageRole.ASSISTANT,
content=final_content,
agent_name=agent.name,
)
except Exception as e:
await websocket.send_json({"type": "error", "data": {"message": str(e)}})

View File

@ -0,0 +1,98 @@
"""Tests for Chat API routes."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient
from agentkit.session.manager import SessionManager
from agentkit.session.store import InMemorySessionStore
from agentkit.session.models import SessionStatus
@pytest.fixture
def app_with_chat():
"""Create a FastAPI app with Chat routes and mocked dependencies."""
from fastapi import FastAPI
from agentkit.server.routes.chat import router
app = FastAPI()
app.include_router(router, prefix="/api/v1")
# Mock app.state dependencies
app.state.session_manager = SessionManager(store=InMemorySessionStore())
app.state.llm_gateway = MagicMock()
app.state.agent_pool = MagicMock()
app.state.server_config = MagicMock()
app.state.server_config.api_key = None
return app
@pytest.fixture
def client(app_with_chat):
return TestClient(app_with_chat)
class TestChatSessionCRUD:
def test_create_session(self, client):
resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
assert resp.status_code == 200
data = resp.json()
assert data["agent_name"] == "test-agent"
assert data["status"] == "active"
assert "session_id" in data
def test_get_session(self, client):
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
session_id = create_resp.json()["session_id"]
get_resp = client.get(f"/api/v1/chat/sessions/{session_id}")
assert get_resp.status_code == 200
assert get_resp.json()["session_id"] == session_id
def test_get_nonexistent_session(self, client):
resp = client.get("/api/v1/chat/sessions/nonexistent")
assert resp.status_code == 404
def test_close_session(self, client):
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
session_id = create_resp.json()["session_id"]
close_resp = client.delete(f"/api/v1/chat/sessions/{session_id}")
assert close_resp.status_code == 200
assert close_resp.json()["status"] == "closed"
def test_close_nonexistent_session(self, client):
resp = client.delete("/api/v1/chat/sessions/nonexistent")
assert resp.status_code == 404
def test_get_messages_empty(self, client):
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
session_id = create_resp.json()["session_id"]
msgs_resp = client.get(f"/api/v1/chat/sessions/{session_id}/messages")
assert msgs_resp.status_code == 200
assert msgs_resp.json() == []
def test_get_messages_nonexistent_session(self, client):
resp = client.get("/api/v1/chat/sessions/nonexistent/messages")
assert resp.status_code == 404
def test_send_message_closed_session(self, client):
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
session_id = create_resp.json()["session_id"]
client.delete(f"/api/v1/chat/sessions/{session_id}")
msg_resp = client.post(
f"/api/v1/chat/sessions/{session_id}/messages",
json={"content": "Hello"},
)
assert msg_resp.status_code == 400
def test_send_message_nonexistent_session(self, client):
resp = client.post(
"/api/v1/chat/sessions/nonexistent/messages",
json={"content": "Hello"},
)
assert resp.status_code == 404