fischer-agentkit/src/agentkit/server/routes/chat.py

463 lines
17 KiB
Python

"""Chat API routes — multi-turn conversation with Agent via REST and WebSocket."""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, Request
from pydantic import BaseModel
from agentkit.core.protocol import CancellationToken
from agentkit.core.react import ReActEngine
from agentkit.session.manager import SessionManager
from agentkit.session.models import MessageRole, SessionStatus
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/chat", tags=["chat"])
# ── Request/Response schemas ──────────────────────────────────────────
class CreateSessionRequest(BaseModel):
    """Request body for creating a chat session (POST /chat/sessions)."""

    # Name of the agent the new session is bound to; create_session()
    # forwards it to SessionManager.create_session().
    agent_name: str
    # Optional free-form metadata, forwarded unchanged to
    # SessionManager.create_session().
    metadata: dict[str, Any] | None = None
class SendMessageRequest(BaseModel):
    """Request body for posting a message to an existing chat session."""

    # Text of the message; send_message() stores it as the user's turn.
    content: str
    # Defaults to "user".
    # NOTE(review): send_message() in this module always stores the message
    # with MessageRole.USER and never reads this field — confirm whether the
    # field is meant to be honoured or is kept only for API compatibility.
    role: str = "user"
class SessionResponse(BaseModel):
    """Serialized view of a chat session returned by the REST endpoints.

    Instances are built by _session_to_response().
    """

    session_id: str
    agent_name: str
    # The ``.value`` of the session's status enum.
    status: str
    metadata: dict[str, Any]
    # Timestamps rendered with ``datetime.isoformat()``.
    created_at: str
    updated_at: str
class MessageResponse(BaseModel):
    """Serialized view of one conversation message returned by the REST endpoints.

    Instances are built by _message_to_response().
    """

    message_id: str
    session_id: str
    # The ``.value`` of the message's role enum.
    role: str
    content: str
    # Copied from the stored message; may be absent.
    tool_call_id: str | None = None
    # Copied from the stored message; may be absent.
    agent_name: str | None = None
    # Timestamp rendered with ``datetime.isoformat()``.
    created_at: str
# ── Chat WebSocket connection manager ─────────────────────────────────
class ChatConnectionManager:
    """Registry of live chat WebSockets, grouped by session id.

    Every registered socket is stored together with the dict of pending
    reply futures that belongs to that connection.
    """

    def __init__(self) -> None:
        # session_id -> [(websocket, pending_replies), ...]
        self._connections: dict[str, list[tuple[WebSocket, dict[str, asyncio.Future]]]] = {}

    def add(self, session_id: str, ws: WebSocket, pending: dict[str, asyncio.Future]) -> None:
        """Register ``ws`` (with its pending-reply map) under ``session_id``."""
        bucket = self._connections.get(session_id)
        if bucket is None:
            bucket = self._connections[session_id] = []
        bucket.append((ws, pending))

    def remove(self, session_id: str, ws: WebSocket) -> None:
        """Forget ``ws`` for ``session_id``; drop the session entry once it is empty."""
        if session_id not in self._connections:
            return
        survivors = [
            entry for entry in self._connections[session_id] if entry[0] is not ws
        ]
        if survivors:
            self._connections[session_id] = survivors
        else:
            del self._connections[session_id]

    def get_connections(self, session_id: str) -> list[tuple[WebSocket, dict[str, asyncio.Future]]]:
        """Return the registered (websocket, pending_replies) pairs, or an empty list."""
        try:
            return self._connections[session_id]
        except KeyError:
            return []

    async def send_json(self, session_id: str, message: dict) -> None:
        """Send a JSON message to all connections for a session.

        Sockets whose send raises are unregistered after the broadcast pass.
        """
        failed: list[WebSocket] = []
        for sock, _pending in self._connections.get(session_id, []):
            try:
                await sock.send_json(message)
            except Exception:
                # Any send failure marks the socket as stale.
                failed.append(sock)
        for sock in failed:
            self.remove(session_id, sock)
# Module-level registry instance; chat_websocket() registers each accepted
# connection here and unregisters it when the connection ends.
chat_manager = ChatConnectionManager()
# ── Helper ────────────────────────────────────────────────────────────
def _get_session_manager(request: Request) -> SessionManager:
    """Fetch the SessionManager stored on the application state."""
    app_state = request.app.state
    return app_state.session_manager
def _session_to_response(session) -> SessionResponse:
    """Convert a session object into its REST response schema.

    The status enum is flattened to its ``.value`` and both timestamps are
    rendered with ``isoformat()``.
    """
    fields = {
        "session_id": session.session_id,
        "agent_name": session.agent_name,
        "status": session.status.value,
        "metadata": session.metadata,
        "created_at": session.created_at.isoformat(),
        "updated_at": session.updated_at.isoformat(),
    }
    return SessionResponse(**fields)
def _message_to_response(msg) -> MessageResponse:
    """Convert a stored message into its REST response schema.

    The role enum is flattened to its ``.value`` and the timestamp is
    rendered with ``isoformat()``.
    """
    fields = {
        "message_id": msg.message_id,
        "session_id": msg.session_id,
        "role": msg.role.value,
        "content": msg.content,
        "tool_call_id": msg.tool_call_id,
        "agent_name": msg.agent_name,
        "created_at": msg.created_at.isoformat(),
    }
    return MessageResponse(**fields)
# ── REST endpoints ────────────────────────────────────────────────────
@router.get("/sessions", response_model=list[SessionResponse])
async def list_sessions(req: Request):
    """Return every chat session known to the session manager."""
    manager = _get_session_manager(req)
    all_sessions = await manager.list_sessions()
    return list(map(_session_to_response, all_sessions))
@router.post("/sessions", response_model=SessionResponse)
async def create_session(request: CreateSessionRequest, req: Request):
    """Open a new chat session for the agent named in the request body."""
    manager = _get_session_manager(req)
    created = await manager.create_session(
        agent_name=request.agent_name, metadata=request.metadata
    )
    return _session_to_response(created)
@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str, req: Request):
    """Look up one session by id.

    Raises:
        HTTPException: 404 when the session manager has no such session.
    """
    found = await _get_session_manager(req).get_session(session_id)
    if found is not None:
        return _session_to_response(found)
    raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
@router.get("/sessions/{session_id}/messages", response_model=list[MessageResponse])
async def get_messages(session_id: str, req: Request, limit: int | None = None, offset: int = 0):
    """Return the stored messages of a session.

    ``limit`` and ``offset`` are passed straight through to
    ``SessionManager.get_messages``.

    Raises:
        HTTPException: 404 when the session does not exist.
    """
    manager = _get_session_manager(req)
    if await manager.get_session(session_id) is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    history = await manager.get_messages(session_id, limit=limit, offset=offset)
    return list(map(_message_to_response, history))
@router.post("/sessions/{session_id}/messages", response_model=MessageResponse)
async def send_message(session_id: str, request: SendMessageRequest, req: Request):
    """Send a message to the Agent (synchronous mode — waits for full reply).

    The session and its agent are validated first; only then is the user
    message persisted, so a rejected request leaves no orphaned user turn in
    the history. The agent is run through a ReActEngine and its output is
    stored and returned as the assistant message.

    Raises:
        HTTPException: 404 if the session or its agent does not exist,
            400 if the session is closed, 500 if agent execution (or storing
            its reply) fails.
    """
    sm = _get_session_manager(req)
    session = await sm.get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    if session.status == SessionStatus.CLOSED:
        raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")

    # Resolve the Agent before touching the history: if it is missing we
    # return 404 without having stored a user message nobody will answer.
    pool = req.app.state.agent_pool
    agent = pool.get_agent(session.agent_name)
    if agent is None:
        raise HTTPException(status_code=404, detail=f"Agent '{session.agent_name}' not found")

    # Append user message.
    # NOTE(review): request.role is accepted by the schema but not consulted
    # here; the message is always stored as MessageRole.USER.
    await sm.append_message(
        session_id=session_id,
        role=MessageRole.USER,
        content=request.content,
    )

    # Get full conversation history for the Agent
    chat_messages = await sm.get_chat_messages(session_id)

    # Execute the Agent
    try:
        react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway)
        tools = agent._tool_registry.list_tools() if agent._tool_registry else []
        # Prefer an explicit _system_prompt attribute, else ask the agent.
        system_prompt = getattr(agent, "_system_prompt", None) or (
            agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
        )
        model = (
            agent.get_model()
            if hasattr(agent, "get_model")
            else getattr(agent, "_llm_model", "default")
        )
        result = await react_engine.execute(
            messages=chat_messages,
            tools=tools,
            model=model,
            agent_name=agent.name,
            system_prompt=system_prompt,
        )

        # Append assistant reply
        assistant_msg = await sm.append_message(
            session_id=session_id,
            role=MessageRole.ASSISTANT,
            content=result.output if hasattr(result, "output") else str(result),
            agent_name=agent.name,
        )
        return _message_to_response(assistant_msg)
    except Exception as e:
        # logger.exception keeps the traceback in the server log; the
        # client still only receives the exception message.
        logger.exception("Chat execution error for session %s: %s", session_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.delete("/sessions/{session_id}")
async def close_session(session_id: str, req: Request):
    """Close a chat session and report the result.

    Raises:
        HTTPException: 404 when the session manager returns no session.
    """
    manager = _get_session_manager(req)
    closed = await manager.close_session(session_id)
    if closed is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    return dict(status="closed", session_id=session_id)
# ── WebSocket endpoint ────────────────────────────────────────────────
@router.websocket("/ws/{session_id}")
async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
    """WebSocket endpoint for real-time chat with streaming.

    Client → Server messages:
        {"type": "message", "content": "..."}                      — Send a user message
        {"type": "reply", "request_id": "...", "content": "..."}   — Answer a pending ask_human request
        {"type": "cancel"}                                         — Cancel current execution
        {"type": "ping"}                                           — Heartbeat

    Server → Client messages:
        {"type": "connected", "session_id": "..."}                 — Connection confirmed
        {"type": "skill_match", "data": {...}}                     — Sent by _handle_chat_message on a skill match
        {"type": "token", "content": "..."}                        — LLM token streaming
        {"type": "step", "data": {...}}                            — ReAct step event
        {"type": "ask_human", "question": "...", "request_id": "..."} — Agent asks user
        {"type": "final_answer", "content": "..."}                 — Agent's final reply
        {"type": "result", "data": {"status": "cancelled"}}        — Acknowledges a cancel request
        {"type": "error", "data": {"message": "..."}}              — Error occurred
        {"type": "pong"}                                           — Heartbeat response

    Each user message is executed in a background task so the receive loop
    stays responsive: "cancel", "reply" and "ping" frames are processed while
    the agent is still running. Messages are still executed strictly one at
    a time; a message that arrives during an execution waits for it to end.
    """
    import hmac  # local import: only needed for the constant-time key check

    # ── Authentication ──
    app_state = websocket.app.state
    configured_api_key: str | None = None
    if hasattr(app_state, "server_config") and app_state.server_config:
        configured_api_key = app_state.server_config.api_key
    if configured_api_key is None and hasattr(app_state, "api_key"):
        configured_api_key = app_state.api_key
    if configured_api_key:
        provided = websocket.query_params.get("api_key") or ""
        # Constant-time comparison avoids leaking key contents via timing.
        keys_match = hmac.compare_digest(
            provided.encode("utf-8"), str(configured_api_key).encode("utf-8")
        )
        if not keys_match:
            # Accept first so the client can receive the error frame.
            await websocket.accept()
            await websocket.send_json({"type": "error", "data": {"message": "Invalid api_key"}})
            await websocket.close(code=4001, reason="Invalid api_key")
            return

    await websocket.accept()

    # ── Validate session ──
    sm: SessionManager = websocket.app.state.session_manager
    session = await sm.get_session(session_id)
    if session is None:
        await websocket.send_json({"type": "error", "data": {"message": f"Session '{session_id}' not found"}})
        await websocket.close(code=1000, reason="Session not found")
        return
    if session.status == SessionStatus.CLOSED:
        await websocket.send_json({"type": "error", "data": {"message": "Session is closed"}})
        await websocket.close(code=1000, reason="Session closed")
        return

    # Track pending replies for AskHumanTool
    pending_replies: dict[str, asyncio.Future] = {}
    chat_manager.add(session_id, websocket, pending_replies)

    # Token and task of the execution currently in flight (if any). "cancel"
    # must act on *this* token — the one handed to the running execution.
    active_token: CancellationToken | None = None
    active_task: asyncio.Task | None = None

    async def _run_message(content: str, token: CancellationToken) -> None:
        """Run one message to completion; report failures instead of raising."""
        try:
            await _handle_chat_message(
                websocket, session_id, content, sm, token, pending_replies
            )
        except Exception as exc:
            logger.exception("Chat WebSocket error for session %s: %s", session_id, exc)
            try:
                await websocket.send_json({"type": "error", "data": {"message": str(exc)}})
            except Exception:
                # Socket is most likely gone; the receive loop will notice.
                logger.debug("Could not deliver error frame for session %s", session_id)

    try:
        await websocket.send_json({"type": "connected", "session_id": session_id})

        # Listen for client messages
        while True:
            try:
                raw = await asyncio.wait_for(websocket.receive_text(), timeout=300.0)
            except asyncio.TimeoutError:
                # Idle for 5 minutes: emit a heartbeat and keep listening.
                await websocket.send_json({"type": "pong"})
                continue

            try:
                msg = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if not isinstance(msg, dict):
                # Valid JSON but not an object (e.g. a list or string): ignore
                # it instead of crashing the connection on msg.get().
                continue

            msg_type = msg.get("type")
            if msg_type == "message":
                content = msg.get("content", "")
                if active_task is not None:
                    # Keep executions strictly sequential, as before.
                    await active_task
                # Create a fresh CancellationToken for each message
                active_token = CancellationToken()
                active_task = asyncio.create_task(_run_message(content, active_token))
            elif msg_type == "reply":
                # Reply to AskHumanTool
                request_id = msg.get("request_id")
                reply_content = msg.get("content", "")
                fut = pending_replies.get(request_id) if isinstance(request_id, str) else None
                # A duplicate reply must not raise InvalidStateError.
                if fut is not None and not fut.done():
                    fut.set_result(reply_content)
            elif msg_type == "cancel":
                if active_token is not None:
                    active_token.cancel()
                await websocket.send_json({"type": "result", "data": {"status": "cancelled"}})
            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        logger.debug("Chat WebSocket disconnected for session %s", session_id)
    except Exception as e:
        logger.exception("Chat WebSocket error for session %s: %s", session_id, e)
        try:
            await websocket.send_json({"type": "error", "data": {"message": str(e)}})
        except Exception:
            # Best effort only: the socket may already be closed.
            logger.debug("Could not deliver error frame for session %s", session_id)
    finally:
        # Stop any execution still running for this connection.
        if active_token is not None:
            active_token.cancel()
        if active_task is not None and not active_task.done():
            active_task.cancel()
        # Clean up pending futures
        for fut in pending_replies.values():
            if not fut.done():
                fut.cancel()
        chat_manager.remove(session_id, websocket)
def _stream_event_payload(event) -> dict[str, Any]:
    """Translate one ReAct stream event into the JSON frame sent to the client."""
    if event.event_type == "final_answer":
        return {"type": "final_answer", "content": event.data.get("output", "")}
    if event.event_type == "token":
        return {"type": "token", "content": event.data.get("content", "")}
    # Every other event type is forwarded as a generic "step" frame.
    return {
        "type": "step",
        "data": {
            "event_type": event.event_type,
            "step": event.step,
            "data": event.data,
        },
    }


async def _handle_chat_message(
    websocket: WebSocket,
    session_id: str,
    content: str,
    sm: SessionManager,
    cancellation_token: CancellationToken,
    pending_replies: dict[str, asyncio.Future],
) -> None:
    """Handle a user message: append to session, execute Agent, stream events.

    The session's agent supplies default tools, system prompt and model.
    ``resolve_skill_routing`` then decides the effective tools, model, agent
    name and system prompt; when it reports a match, a ``skill_match`` frame
    is sent to the client before execution starts.

    Args:
        websocket: Connection that receives the streamed frames.
        session_id: Session the message belongs to.
        content: Raw user text as received from the client.
        sm: Session manager used to read and append history.
        cancellation_token: Passed to the engine's ``execute_stream``.
        pending_replies: Pending reply futures of this connection.
            NOTE(review): not used in this function — confirm whether it
            should be wired into the engine or tools.

    Errors raised during streaming or while storing the assistant reply are
    logged and reported to the client as an ``error`` frame (message
    truncated to 200 characters); they are not re-raised.
    """
    from agentkit.chat.skill_routing import resolve_skill_routing

    # Resolve Agent first (needed for default tools/prompt)
    pool = websocket.app.state.agent_pool
    session = await sm.get_session(session_id)
    if session is None:
        await websocket.send_json({"type": "error", "data": {"message": "Session lost"}})
        return
    agent = pool.get_agent(session.agent_name)
    if agent is None:
        await websocket.send_json({"type": "error", "data": {"message": f"Agent '{session.agent_name}' not found"}})
        return

    # Default execution parameters from agent
    default_tools = agent._tool_registry.list_tools() if agent._tool_registry else []
    default_system_prompt = getattr(agent, "_system_prompt", None) or (
        agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
    )
    default_model = (
        agent.get_model()
        if hasattr(agent, "get_model")
        else getattr(agent, "_llm_model", "default")
    )

    # Resolve skill routing using shared module
    skill_registry = getattr(websocket.app.state, "skill_registry", None)
    intent_router = getattr(websocket.app.state, "intent_router", None)
    routing = await resolve_skill_routing(
        content=content,
        skill_registry=skill_registry,
        intent_router=intent_router,
        default_tools=default_tools,
        default_system_prompt=default_system_prompt,
        default_model=default_model,
        default_agent_name=agent.name,
        agent_tool_registry=agent._tool_registry if agent._tool_registry else None,
        session_id=session_id,
    )

    # Notify frontend about skill match
    if routing.matched:
        await websocket.send_json({
            "type": "skill_match",
            "data": {
                "skill": routing.skill_name,
                "method": routing.match_method,
                "confidence": routing.match_confidence,
            },
        })

    # Append user message (use clean_content if @skill: prefix was stripped)
    await sm.append_message(session_id=session_id, role=MessageRole.USER, content=routing.clean_content)

    # Get full conversation history
    chat_messages = await sm.get_chat_messages(session_id)

    # Execute Agent with streaming
    react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
    logger.info(
        "Chat session %s: executing with %d tools, model=%s, skill=%s",
        session_id, len(routing.tools), routing.model, routing.skill_name,
    )
    try:
        final_content = ""
        async for event in react_engine.execute_stream(
            messages=chat_messages,
            tools=routing.tools,
            model=routing.model,
            agent_name=routing.agent_name,
            system_prompt=routing.system_prompt,
            cancellation_token=cancellation_token,
        ):
            payload = _stream_event_payload(event)
            if payload["type"] == "final_answer":
                # Remember the answer so it can be persisted after streaming.
                final_content = payload["content"]
            await websocket.send_json(payload)

        # Append assistant reply to session
        if final_content:
            await sm.append_message(
                session_id=session_id,
                role=MessageRole.ASSISTANT,
                content=final_content,
                agent_name=agent.name,
            )
    except Exception as e:
        # Keep the full traceback in the server log ...
        logger.exception("Chat execution error for session %s: %s", session_id, e)
        # ... but show the user only the (truncated) message, no stack trace.
        error_msg = str(e)
        if len(error_msg) > 200:
            error_msg = error_msg[:200] + "..."
        await websocket.send_json({"type": "error", "data": {"message": error_msg}})