463 lines
17 KiB
Python
463 lines
17 KiB
Python
"""Chat API routes — multi-turn conversation with Agent via REST and WebSocket."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect, Request
|
|
from pydantic import BaseModel
|
|
|
|
from agentkit.core.protocol import CancellationToken
|
|
from agentkit.core.react import ReActEngine
|
|
from agentkit.session.manager import SessionManager
|
|
from agentkit.session.models import MessageRole, SessionStatus
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router on which every route in this module is registered; all paths below
# are relative to the "/chat" prefix.
router = APIRouter(prefix="/chat", tags=["chat"])
|
|
|
|
|
|
# ── Request/Response schemas ──────────────────────────────────────────
|
|
|
|
|
|
class CreateSessionRequest(BaseModel):
    """Request body for the create-session endpoint."""

    # Name of the agent the new session is bound to; forwarded unchanged to
    # ``SessionManager.create_session`` by the ``create_session`` route.
    agent_name: str
    # Optional free-form metadata forwarded with the session; ``None`` when
    # the client omits it.
    metadata: dict[str, Any] | None = None
|
|
|
|
|
|
class SendMessageRequest(BaseModel):
    """Request body for the synchronous send-message endpoint."""

    # Text of the message that is stored and handed to the agent.
    content: str
    # Declared role of the sender.
    # NOTE(review): the ``send_message`` route in this file always stores the
    # message as ``MessageRole.USER`` and never reads this field — confirm
    # whether it is intentionally ignored.
    role: str = "user"
|
|
|
|
|
|
class SessionResponse(BaseModel):
    """Serialized view of a chat session returned by the REST endpoints."""

    # Identifier of the session.
    session_id: str
    # Name of the agent the session is bound to.
    agent_name: str
    # String value of the session's status enum (``session.status.value``).
    status: str
    # Metadata stored with the session.
    metadata: dict[str, Any]
    # Timestamps rendered with ``datetime.isoformat()``.
    created_at: str
    updated_at: str
|
|
|
|
|
|
class MessageResponse(BaseModel):
    """Serialized view of a single conversation message."""

    # Identifier of the message and of the session it belongs to.
    message_id: str
    session_id: str
    # String value of the message's role enum (``msg.role.value``).
    role: str
    # Message text.
    content: str
    # Copied from the message's ``tool_call_id`` attribute; may be ``None``.
    tool_call_id: str | None = None
    # Copied from the message's ``agent_name`` attribute; may be ``None``.
    agent_name: str | None = None
    # Creation timestamp rendered with ``datetime.isoformat()``.
    created_at: str
|
|
|
|
|
|
# ── Chat WebSocket connection manager ─────────────────────────────────
|
|
|
|
|
|
class ChatConnectionManager:
    """Track active WebSocket connections per session_id.

    Every session maps to a list of ``(websocket, pending_replies)`` pairs,
    where ``pending_replies`` is the per-connection ``request_id -> Future``
    dict supplied by the caller of :meth:`add`.
    """

    def __init__(self) -> None:
        # session_id -> list of (websocket, pending_replies)
        self._connections: dict[str, list[tuple[WebSocket, dict[str, asyncio.Future]]]] = {}

    def add(self, session_id: str, ws: WebSocket, pending: dict[str, asyncio.Future]) -> None:
        """Register ``ws`` (with its pending-reply dict) under ``session_id``."""
        self._connections.setdefault(session_id, []).append((ws, pending))

    def remove(self, session_id: str, ws: WebSocket) -> None:
        """Unregister ``ws`` from ``session_id``; unknown sessions are ignored.

        Connections are matched by identity. The session entry is dropped
        entirely once its last connection is removed.
        """
        conns = self._connections.get(session_id)
        if conns is None:
            return
        remaining = [(w, p) for w, p in conns if w is not ws]
        if remaining:
            self._connections[session_id] = remaining
        else:
            del self._connections[session_id]

    def get_connections(self, session_id: str) -> list[tuple[WebSocket, dict[str, asyncio.Future]]]:
        """Return a snapshot of the connections registered for ``session_id``.

        A new list is returned on every call so callers cannot mutate the
        manager's internal state by accident; the tuples (and therefore the
        pending-reply dicts) are the live objects.
        """
        return list(self._connections.get(session_id, ()))

    async def send_json(self, session_id: str, message: dict) -> None:
        """Send a JSON message to all connections for a session.

        Delivery is best-effort: a connection whose ``send_json`` raises is
        treated as stale and removed after the broadcast.
        """
        stale: list[WebSocket] = []
        # Iterate over a snapshot: each send awaits, and another task may call
        # add()/remove() for this session while we are suspended.
        for ws, _ in self.get_connections(session_id):
            try:
                await ws.send_json(message)
            except Exception:
                # Deliberately broad — any send failure marks the socket dead.
                stale.append(ws)
        for ws in stale:
            self.remove(session_id, ws)
|
|
|
|
|
|
# Module-level connection registry; the WebSocket endpoint in this module
# registers and unregisters its connections here.
chat_manager = ChatConnectionManager()
|
|
|
|
|
|
# ── Helper ────────────────────────────────────────────────────────────
|
|
|
|
|
|
def _get_session_manager(request: Request) -> SessionManager:
    """Return the ``session_manager`` stored on the application state."""
    app_state = request.app.state
    return app_state.session_manager
|
|
|
|
|
|
def _session_to_response(session) -> SessionResponse:
    """Build the API schema for ``session``.

    The status enum is flattened to its ``.value`` and both timestamps are
    rendered with ``isoformat()``; the other fields are copied as-is.
    """
    payload = {
        "session_id": session.session_id,
        "agent_name": session.agent_name,
        "status": session.status.value,
        "metadata": session.metadata,
        "created_at": session.created_at.isoformat(),
        "updated_at": session.updated_at.isoformat(),
    }
    return SessionResponse(**payload)
|
|
|
|
|
|
def _message_to_response(msg) -> MessageResponse:
    """Build the API schema for ``msg``.

    The role enum is flattened to its ``.value`` and the creation timestamp
    is rendered with ``isoformat()``; the other fields are copied as-is.
    """
    payload = {
        "message_id": msg.message_id,
        "session_id": msg.session_id,
        "role": msg.role.value,
        "content": msg.content,
        "tool_call_id": msg.tool_call_id,
        "agent_name": msg.agent_name,
        "created_at": msg.created_at.isoformat(),
    }
    return MessageResponse(**payload)
|
|
|
|
|
|
# ── REST endpoints ────────────────────────────────────────────────────
|
|
|
|
|
|
@router.get("/sessions", response_model=list[SessionResponse])
async def list_sessions(req: Request):
    """Return every chat session known to the session manager."""
    manager = _get_session_manager(req)
    all_sessions = await manager.list_sessions()
    return list(map(_session_to_response, all_sessions))
|
|
|
|
|
|
@router.post("/sessions", response_model=SessionResponse)
async def create_session(request: CreateSessionRequest, req: Request):
    """Open a new chat session for the agent named in the request body."""
    manager = _get_session_manager(req)
    created = await manager.create_session(agent_name=request.agent_name, metadata=request.metadata)
    return _session_to_response(created)
|
|
|
|
|
|
@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str, req: Request):
    """Return one session, or respond 404 when the ID is unknown."""
    found = await _get_session_manager(req).get_session(session_id)
    if found is not None:
        return _session_to_response(found)
    raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
|
|
|
|
|
@router.get("/sessions/{session_id}/messages", response_model=list[MessageResponse])
async def get_messages(session_id: str, req: Request, limit: int | None = None, offset: int = 0):
    """Return the stored conversation history of a session.

    ``limit`` and ``offset`` are passed through to the session manager
    unchanged. Responds 404 when the session does not exist.
    """
    manager = _get_session_manager(req)
    if await manager.get_session(session_id) is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    history = await manager.get_messages(session_id, limit=limit, offset=offset)
    return list(map(_message_to_response, history))
|
|
|
|
|
|
@router.post("/sessions/{session_id}/messages", response_model=MessageResponse)
async def send_message(session_id: str, request: SendMessageRequest, req: Request):
    """Send a message to the Agent (synchronous mode — waits for full reply).

    The user message is stored, the full conversation is handed to a
    ``ReActEngine`` configured from the session's agent, and the assistant
    reply is stored and returned.

    Raises:
        HTTPException: 404 if the session or its agent does not exist,
            400 if the session is closed, 500 if agent execution fails.
    """
    sm = _get_session_manager(req)
    session = await sm.get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    if session.status == SessionStatus.CLOSED:
        raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")

    # Resolve the Agent *before* persisting anything, so a missing agent does
    # not leave an unanswered user message in the session history.
    agent = req.app.state.agent_pool.get_agent(session.agent_name)
    if agent is None:
        raise HTTPException(status_code=404, detail=f"Agent '{session.agent_name}' not found")

    # Append user message
    await sm.append_message(
        session_id=session_id,
        role=MessageRole.USER,
        content=request.content,
    )

    # Get full conversation history for the Agent
    chat_messages = await sm.get_chat_messages(session_id)

    # Execute the Agent
    try:
        react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway)
        tools = agent._tool_registry.list_tools() if agent._tool_registry else []
        # Prefer an explicitly stored prompt, then fall back to the accessor.
        system_prompt = getattr(agent, "_system_prompt", None) or (
            agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
        )
        model = agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default")
        result = await react_engine.execute(
            messages=chat_messages,
            tools=tools,
            model=model,
            agent_name=agent.name,
            system_prompt=system_prompt,
        )

        # Append assistant reply
        assistant_msg = await sm.append_message(
            session_id=session_id,
            role=MessageRole.ASSISTANT,
            content=result.output if hasattr(result, "output") else str(result),
            agent_name=agent.name,
        )
        return _message_to_response(assistant_msg)

    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("Chat execution error for session %s: %s", session_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.delete("/sessions/{session_id}")
async def close_session(session_id: str, req: Request):
    """Close a chat session; responds 404 when the ID is unknown."""
    manager = _get_session_manager(req)
    closed = await manager.close_session(session_id)
    if closed is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    return dict(status="closed", session_id=session_id)
|
|
|
|
|
|
# ── WebSocket endpoint ────────────────────────────────────────────────
|
|
|
|
|
|
@router.websocket("/ws/{session_id}")
async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
    """WebSocket endpoint for real-time chat with streaming.

    Client → Server messages:
        {"type": "message", "content": "..."}                     — Send a user message
        {"type": "reply", "request_id": "...", "content": "..."}  — Answer a pending ask_human request
        {"type": "cancel"}                                        — Cancel current execution
        {"type": "ping"}                                          — Heartbeat

    Server → Client messages:
        {"type": "connected", "session_id": "..."}                — Connection confirmed
        {"type": "token", "content": "..."}                       — LLM token streaming
        {"type": "step", "data": {...}}                           — ReAct step event
        {"type": "ask_human", "question": "...", "request_id": "..."} — Agent asks user
        {"type": "final_answer", "content": "..."}                — Agent's final reply
        {"type": "result", "data": {"status": "cancelled"}}       — Cancel acknowledged
        {"type": "error", "data": {"message": "..."}}             — Error occurred
        {"type": "pong"}                                          — Heartbeat response (also sent
                                                                    after 300 s without client traffic)

    User messages are queued and executed one at a time by a background
    task, so the receive loop stays free to process ``cancel`` and ``reply``
    while an agent run is in progress. Frames that are not JSON objects are
    ignored.
    """
    import hmac  # local import: only needed for the constant-time key check

    # Authentication
    state = websocket.app.state
    configured_api_key: str | None = None
    if hasattr(state, "server_config") and state.server_config:
        configured_api_key = state.server_config.api_key
    if configured_api_key is None and hasattr(state, "api_key"):
        configured_api_key = state.api_key

    if configured_api_key:
        provided = websocket.query_params.get("api_key")
        # Compare as bytes in constant time; a missing key is always rejected.
        key_ok = provided is not None and hmac.compare_digest(
            provided.encode("utf-8"), configured_api_key.encode("utf-8")
        )
        if not key_ok:
            # Accept first so the client can receive the error payload.
            await websocket.accept()
            await websocket.send_json({"type": "error", "data": {"message": "Invalid api_key"}})
            await websocket.close(code=4001, reason="Invalid api_key")
            return

    await websocket.accept()

    # Validate session
    sm: SessionManager = websocket.app.state.session_manager
    session = await sm.get_session(session_id)
    if session is None:
        await websocket.send_json({"type": "error", "data": {"message": f"Session '{session_id}' not found"}})
        await websocket.close(code=1000, reason="Session not found")
        return

    if session.status == SessionStatus.CLOSED:
        await websocket.send_json({"type": "error", "data": {"message": "Session is closed"}})
        await websocket.close(code=1000, reason="Session closed")
        return

    # Track pending replies for AskHumanTool
    pending_replies: dict[str, asyncio.Future] = {}
    chat_manager.add(session_id, websocket, pending_replies)

    # User messages waiting to be executed, in arrival order.
    inbox: asyncio.Queue[str] = asyncio.Queue()
    # Token of the run currently executing (None while idle); this is the
    # token a client "cancel" must act on.
    active_token: CancellationToken | None = None

    async def _run_queued_messages() -> None:
        """Execute queued user messages sequentially, one fresh token per run."""
        nonlocal active_token
        while True:
            queued_content = await inbox.get()
            active_token = CancellationToken()
            try:
                await _handle_chat_message(
                    websocket, session_id, queued_content, sm, active_token, pending_replies
                )
            except WebSocketDisconnect:
                # The receive loop observes the disconnect and cleans up.
                return
            except Exception as exc:
                logger.exception("Chat WebSocket error for session %s: %s", session_id, exc)
                try:
                    await websocket.send_json({"type": "error", "data": {"message": str(exc)}})
                except Exception:
                    # Socket is unusable; nothing more this task can do.
                    return
            finally:
                active_token = None

    worker = asyncio.create_task(_run_queued_messages())

    try:
        await websocket.send_json({"type": "connected", "session_id": session_id})

        # Listen for client messages
        while True:
            try:
                raw = await asyncio.wait_for(websocket.receive_text(), timeout=300.0)
            except asyncio.TimeoutError:
                await websocket.send_json({"type": "pong"})
                continue

            try:
                msg = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if not isinstance(msg, dict):
                # Valid JSON but not an object (e.g. a list or number): ignore.
                continue

            msg_type = msg.get("type")

            if msg_type == "message":
                inbox.put_nowait(msg.get("content", ""))

            elif msg_type == "reply":
                # Reply to AskHumanTool
                request_id = msg.get("request_id")
                fut = pending_replies.get(request_id) if isinstance(request_id, str) and request_id else None
                # Skip futures that are already resolved or cancelled; calling
                # set_result on them would raise InvalidStateError.
                if fut is not None and not fut.done():
                    fut.set_result(msg.get("content", ""))

            elif msg_type == "cancel":
                if active_token is not None:
                    active_token.cancel()
                await websocket.send_json({"type": "result", "data": {"status": "cancelled"}})

            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        logger.debug(f"Chat WebSocket disconnected for session {session_id}")
    except Exception as e:
        logger.exception("Chat WebSocket error for session %s: %s", session_id, e)
        try:
            await websocket.send_json({"type": "error", "data": {"message": str(e)}})
        except Exception:
            # Best effort only: the socket may already be closed.
            pass
    finally:
        # Stop any run still in progress and the worker that drives it.
        if active_token is not None:
            active_token.cancel()
        worker.cancel()
        # Clean up pending futures
        for fut in pending_replies.values():
            if not fut.done():
                fut.cancel()
        chat_manager.remove(session_id, websocket)
|
|
|
|
|
|
async def _handle_chat_message(
    websocket: WebSocket,
    session_id: str,
    content: str,
    sm: SessionManager,
    cancellation_token: CancellationToken,
    pending_replies: dict[str, asyncio.Future],
) -> None:
    """Handle a user message: append to session, execute Agent, stream events.

    When skills are registered, attempts to route the user's message to a
    matching skill via IntentRouter. If a skill is matched, the skill's
    prompt, tools, and execution_mode are used instead of the default agent's.

    Engine events are forwarded to the client as ``final_answer``, ``token``
    or ``step`` messages. A non-empty final answer is stored as the assistant
    reply. Execution errors are reported to the client as an ``error``
    message (truncated to 200 characters); a client disconnect is re-raised
    unchanged so the caller can clean up.

    ``pending_replies`` is accepted for interface compatibility; this
    function does not read it.
    """
    from agentkit.chat.skill_routing import resolve_skill_routing

    # Resolve Agent first (needed for default tools/prompt)
    pool = websocket.app.state.agent_pool
    session = await sm.get_session(session_id)
    if session is None:
        await websocket.send_json({"type": "error", "data": {"message": "Session lost"}})
        return

    agent = pool.get_agent(session.agent_name)
    if agent is None:
        await websocket.send_json({"type": "error", "data": {"message": f"Agent '{session.agent_name}' not found"}})
        return

    # Default execution parameters from agent
    tool_registry = agent._tool_registry if agent._tool_registry else None
    default_tools = tool_registry.list_tools() if tool_registry else []
    # Prefer an explicitly stored prompt, then fall back to the accessor.
    default_system_prompt = getattr(agent, "_system_prompt", None) or (
        agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
    )
    default_model = agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default")

    # Resolve skill routing using shared module
    skill_registry = getattr(websocket.app.state, "skill_registry", None)
    intent_router = getattr(websocket.app.state, "intent_router", None)

    routing = await resolve_skill_routing(
        content=content,
        skill_registry=skill_registry,
        intent_router=intent_router,
        default_tools=default_tools,
        default_system_prompt=default_system_prompt,
        default_model=default_model,
        default_agent_name=agent.name,
        agent_tool_registry=tool_registry,
        session_id=session_id,
    )

    # Notify frontend about skill match
    if routing.matched:
        await websocket.send_json({
            "type": "skill_match",
            "data": {
                "skill": routing.skill_name,
                "method": routing.match_method,
                "confidence": routing.match_confidence,
            },
        })

    # Append user message (use clean_content if @skill: prefix was stripped)
    await sm.append_message(session_id=session_id, role=MessageRole.USER, content=routing.clean_content)

    # Get full conversation history
    chat_messages = await sm.get_chat_messages(session_id)

    # Execute Agent with streaming
    react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)

    logger.info(
        "Chat session %s: executing with %s tools, model=%s, skill=%s",
        session_id,
        len(routing.tools),
        routing.model,
        routing.skill_name,
    )

    try:
        final_content = ""
        async for event in react_engine.execute_stream(
            messages=chat_messages,
            tools=routing.tools,
            model=routing.model,
            agent_name=routing.agent_name,
            system_prompt=routing.system_prompt,
            cancellation_token=cancellation_token,
        ):
            if event.event_type == "final_answer":
                final_content = event.data.get("output", "")
                await websocket.send_json({
                    "type": "final_answer",
                    "content": final_content,
                })
            elif event.event_type == "token":
                await websocket.send_json({
                    "type": "token",
                    "content": event.data.get("content", ""),
                })
            else:
                # Every other engine event is forwarded as a generic step.
                await websocket.send_json({
                    "type": "step",
                    "data": {
                        "event_type": event.event_type,
                        "step": event.step,
                        "data": event.data,
                    },
                })

        # Append assistant reply to session
        if final_content:
            await sm.append_message(
                session_id=session_id,
                role=MessageRole.ASSISTANT,
                content=final_content,
                agent_name=agent.name,
            )

    except WebSocketDisconnect:
        # The client went away mid-stream: this is not an execution error and
        # there is no socket left to report to, so let the caller handle it.
        raise
    except Exception as e:
        # logger.exception keeps the traceback in the server log only.
        logger.exception("Chat execution error for session %s: %s", session_id, e)
        # Show meaningful error to user, but avoid leaking full stack traces
        error_msg = str(e)
        # Truncate very long error messages
        if len(error_msg) > 200:
            error_msg = error_msg[:200] + "..."
        await websocket.send_json({"type": "error", "data": {"message": error_msg}})
|