"""Chat API routes — multi-turn conversation with Agent via REST and WebSocket.""" from __future__ import annotations import asyncio import hmac import json import logging from typing import Any, TYPE_CHECKING import os import uuid from pathlib import Path from fastapi import ( APIRouter, File, HTTPException, Request, UploadFile, WebSocket, WebSocketDisconnect, ) from fastapi.responses import FileResponse from pydantic import BaseModel from agentkit.chat.skill_routing import ExecutionMode from agentkit.core.phase import default_policy, policy_from_config from agentkit.core.protocol import CancellationToken from agentkit.core.react import ReActEngine from agentkit.server._fallback_chain import execute_with_fallback_chain from agentkit.session.manager import SessionManager from agentkit.session.models import MessageRole, SessionStatus from agentkit.tools.advance_phase import AdvancePhaseTool if TYPE_CHECKING: from agentkit.llm.gateway import LLMGateway from agentkit.server.config import ServerConfig from agentkit.tools.base import Tool logger = logging.getLogger(__name__) router = APIRouter(prefix="/chat", tags=["chat"]) # ── Request/Response schemas ────────────────────────────────────────── class CreateSessionRequest(BaseModel): agent_name: str metadata: dict[str, Any] | None = None class SendMessageRequest(BaseModel): content: str role: str = "user" # Optional execution mode override. "plan_exec" → 501 (KTD4: WebSocket only). execution_mode: str | None = None class SessionResponse(BaseModel): session_id: str agent_name: str status: str metadata: dict[str, Any] created_at: str updated_at: str class MessageResponse(BaseModel): message_id: str session_id: str role: str content: str tool_call_id: str | None = None agent_name: str | None = None created_at: str # ── Chat WebSocket connection manager ───────────────────────────────── class ChatConnectionManager: """Track active WebSocket connections per session_id.""" def __init__(self) -> None: # session_id -> list of (websocket, pending_replies) self._connections: dict[str, list[tuple[WebSocket, dict[str, asyncio.Future]]]] = {} def add(self, session_id: str, ws: WebSocket, pending: dict[str, asyncio.Future]) -> None: self._connections.setdefault(session_id, []).append((ws, pending)) def remove(self, session_id: str, ws: WebSocket) -> None: conns = self._connections.get(session_id) if conns is None: return self._connections[session_id] = [(w, p) for w, p in conns if w is not ws] if not self._connections[session_id]: del self._connections[session_id] def get_connections(self, session_id: str) -> list[tuple[WebSocket, dict[str, asyncio.Future]]]: return self._connections.get(session_id, []) async def send_json(self, session_id: str, message: dict) -> None: """Send a JSON message to all connections for a session.""" conns = self._connections.get(session_id, []) stale: list[WebSocket] = [] for ws, _ in conns: try: await ws.send_json(message) except (ConnectionError, RuntimeError, asyncio.TimeoutError): stale.append(ws) for ws in stale: self.remove(session_id, ws) chat_manager = ChatConnectionManager() # U4: Active team sessions — maps session_id to the ExpertTeam currently executing. # When a message arrives during team execution, it is routed as an intervention # instead of starting a new chat task. Populated by _execute_team_collab. _active_teams: dict[str, "object"] = {} def _register_active_team(session_id: str, team: "object") -> None: """Register an active team for a session (intervention routing).""" _active_teams[session_id] = team def _unregister_active_team(session_id: str) -> None: """Unregister the active team for a session.""" _active_teams.pop(session_id, None) def _get_active_team(session_id: str) -> "object | None": """Get the active team for a session, if any.""" return _active_teams.get(session_id) # ── Helper ──────────────────────────────────────────────────────────── _VALID_TEAM_EVENT_TYPES = frozenset( { "team_formed", "expert_step", "expert_result", "expert_result_chunk", "expert_result_chunk_reset", "plan_update", "team_synthesis", "team_synthesis_chunk", "team_dissolved", "plan_step", "phase_started", "phase_completed", "phase_failed", "replanning", # PM Collaboration 模式事件 (U1-U4) "collaboration_contract_defined", "collaboration_notice", "review_result", "risk_flagged", # Board Meeting 模式事件 "board_started", "expert_speech", "expert_speech_chunk", "round_summary", "user_intervention", "board_concluded", } ) async def emit_team_event(websocket: WebSocket, event_type: str, data: dict) -> None: """Emit a team-related WebSocket event. Supported event types: team_formed — Team assembled with expert list and plan. data: {team_id, status, experts, plan_phases, lead_expert} expert_step — An expert is executing a step. data: {expert_id, expert_name, expert_color, content, step} expert_result — An expert produced a result. data: {expert_id, expert_name, expert_color, content} plan_update — Team plan phases updated. data: {plan_phases} team_synthesis — Final synthesis from the lead expert. data: {content} team_dissolved — Team dissolved after completion. data: {team_id} """ if event_type not in _VALID_TEAM_EVENT_TYPES: logger.warning(f"emit_team_event: invalid event_type '{event_type}'") return await websocket.send_json( { "type": event_type, "data": data, } ) def _get_session_manager(request: Request) -> SessionManager: return request.app.state.session_manager async def _execute_board_meeting( websocket: WebSocket, session_id: str, content: str, sm: SessionManager, ) -> bool: """Intercept @board prefix and execute a board meeting discussion. Returns True if the input was handled as a board meeting (caller should return), False if the input should continue through the normal chat pipeline. Flow: 1. Resolve @board routing via BoardRouter 2. Create BoardTeam with expert configs 3. Register handoff_transport handler to relay events to WebSocket 4. Execute BoardOrchestrator 5. Send final conclusion as final_answer 6. Persist user topic + final summary to session history """ from agentkit.experts.board_router import BoardRouter from agentkit.experts.board import BoardTeam from agentkit.experts.board_orchestrator import BoardOrchestrator app_state = websocket.app.state # Resolve ExpertTemplateRegistry from app.state (loaded at startup) template_registry = getattr(app_state, "expert_template_registry", None) if template_registry is None: from agentkit.experts.registry import ExpertTemplateRegistry template_registry = ExpertTemplateRegistry() board_router = BoardRouter(template_registry=template_registry) routing_result = board_router.resolve(content) if not routing_result.matched: return False # Not a @board input, continue normal pipeline if not routing_result.topic: await websocket.send_json( { "type": "error", "data": {"message": "私董会需要一个讨论主题,例如:@board 如何看待 AI 未来"}, } ) return True # Resolve expert configs from specified experts or default template expert_configs = board_router.resolve_expert_configs(routing_result.specified_experts) if not expert_configs: await websocket.send_json( {"type": "error", "data": {"message": "无法解析私董会成员,请检查专家名称或模板配置"}} ) return True # Read board config from server_config if available; user-specified rounds take precedence max_rounds = routing_result.max_rounds if max_rounds is None: max_rounds = 5 server_config = getattr(app_state, "server_config", None) if server_config is not None: board_cfg = getattr(server_config, "board", None) or {} if isinstance(board_cfg, dict): max_rounds = int(board_cfg.get("max_rounds", 5)) # Create BoardTeam team = BoardTeam( pool=app_state.agent_pool, template_registry=template_registry, max_rounds=max_rounds, ) # Register handoff_transport handler to relay board events to WebSocket async def _relay_board_event(message: dict) -> None: msg_type = message.get("type") if not msg_type: return # Strip internal fields, keep only event data event_data = {k: v for k, v in message.items() if k != "type"} await emit_team_event(websocket, msg_type, event_data) # Persist board events so a page reload can reconstruct the # discussion (otherwise the user only sees the final conclusion # and loses every expert speech). We persist on the terminal # "board_started" / "expert_speech" / "round_summary" / # "board_concluded" events — chunks are intermediate and would # explode the history size. The "board_started" payload carries # the expert list (name/avatar/color/is_moderator/persona) so a # reload can rebuild the boardState and the sidebar can show the # "私董会" badge — without it, every restored conversation # appears as a plain chat with 0 experts. # # We also stash the rendering hint (message_type + expert identity) # in Message.metadata so a reload can still show the speech as a # proper board_speech card instead of a plain assistant bubble. experts_data = event_data.get("experts") board_started_text = ( f"私董会开始:{event_data.get('topic', '')}" if event_data.get("topic") else "私董会开始" ) persistable: dict[str, tuple[str, str, dict[str, object] | None]] = { "board_started": ( "assistant", board_started_text, { "message_type": "board_started", "board_started": { "team_id": event_data.get("team_id"), "topic": event_data.get("topic"), "experts": experts_data, "max_rounds": event_data.get("max_rounds"), }, }, ), "expert_speech": ( "assistant", event_data.get("content", ""), { "message_type": "board_speech", "expert_name": event_data.get("expert_name"), "expert_avatar": event_data.get("expert_avatar"), "expert_color": event_data.get("expert_color"), "board_round": event_data.get("round"), "board_role": event_data.get("role"), }, ), "round_summary": ( "assistant", event_data.get("content", ""), { "message_type": "board_summary", "expert_name": event_data.get("moderator_name"), "expert_avatar": event_data.get("moderator_avatar"), "expert_color": event_data.get("moderator_color"), "board_round": event_data.get("round"), "board_role": "summary", }, ), "board_concluded": ( "assistant", event_data.get("summary") or "私董会已结束", { "message_type": "board_conclusion", "board_round": event_data.get("total_rounds"), "board_conclusion": { "summary": event_data.get("summary", ""), "decision_advice": event_data.get("decision_advice", ""), "total_rounds": event_data.get("total_rounds", 0), "consensus_points": event_data.get("consensus_points", []), "dissent_points": event_data.get("dissent_points", []), }, }, ), } if msg_type in persistable: role, content, meta = persistable[msg_type] if content: try: await sm.append_message( session_id=session_id, role=MessageRole(role), content=content, metadata=meta, ) except Exception as persist_err: # noqa: BLE001 # Persistence is best-effort; never let a save error # break the live stream. logger.warning(f"Failed to persist {msg_type} to session store: {persist_err}") team.handoff_transport.register_handler(team.team_channel, _relay_board_event) # Append user topic to session history — store the resolved topic, not # the raw "@board:experts rounds=5 ..." syntax. Frontend renders the # original prefix as a structured bubble; persisting it raw means the # bubble shows the full prefix even on history reload. await sm.append_message( session_id=session_id, role=MessageRole.USER, content=f"@board {routing_result.topic}", ) try: await team.create_board(topic=routing_result.topic, expert_configs=expert_configs) orchestrator = BoardOrchestrator(team=team) result = await orchestrator.execute(routing_result.topic) except asyncio.CancelledError: raise except Exception as e: logger.error(f"Board meeting failed for session {session_id}: {e}", exc_info=True) await websocket.send_json( {"type": "error", "data": {"message": f"私董会执行失败: {str(e)[:200]}"}} ) try: await team.dissolve() except (RuntimeError, asyncio.TimeoutError, ConnectionError): pass return True finally: # dissolve() already clears handlers via handoff_transport.close() pass # Build final answer text from conclusion summary = result.get("summary", "") decision_advice = result.get("decision_advice", "") consensus_points = result.get("consensus_points", []) or [] dissent_points = result.get("dissent_points", []) or [] total_rounds = result.get("total_rounds", 0) final_parts: list[str] = [] if summary: final_parts.append(f"## 讨论总结\n\n{summary}") if decision_advice: final_parts.append(f"## 决策建议\n\n{decision_advice}") if consensus_points: final_parts.append("## 共识点\n\n" + "\n".join(f"- {p}" for p in consensus_points)) if dissent_points: final_parts.append("## 分歧点\n\n" + "\n".join(f"- {p}" for p in dissent_points)) final_parts.append(f"\n\n_共进行 {total_rounds} 轮讨论_") final_content = "\n\n".join(final_parts) await websocket.send_json( { "type": "final_answer", "content": final_content, "is_final": True, } ) # Persist final summary as assistant message await sm.append_message( session_id=session_id, role=MessageRole.ASSISTANT, content=final_content, agent_name="board_meeting", ) # Dissolve the team to release expert agents try: await team.dissolve() except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e: logger.warning(f"Board team dissolve failed: {e}") return True async def _execute_team_collab( websocket: WebSocket, session_id: str, content: str, sm: SessionManager, ) -> bool: """Intercept @team prefix and execute a pipeline team collaboration. Returns True if the input was handled as a team collaboration (caller should return), False if the input should continue through the normal chat pipeline. Flow: 1. Resolve @team routing via ExpertTeamRouter 2. Create ExpertTeam with lead + member configs 3. Register handoff_transport handler to relay events to WebSocket 4. Execute TeamOrchestrator (pipeline mode) 5. Send final synthesis as final_answer 6. Persist user task + final result to session history """ from agentkit.experts.router import ExpertTeamRouter from agentkit.experts.team import ExpertTeam from agentkit.experts.orchestrator import TeamOrchestrator app_state = websocket.app.state # Resolve ExpertTemplateRegistry from app.state (loaded at startup) template_registry = getattr(app_state, "expert_template_registry", None) if template_registry is None: from agentkit.experts.registry import ExpertTemplateRegistry template_registry = ExpertTemplateRegistry() team_router = ExpertTeamRouter(template_registry=template_registry) routing_result = team_router.resolve(content) if not routing_result.matched: return False # Not a @team input, continue normal pipeline if not routing_result.task_content: await websocket.send_json( { "type": "error", "data": {"message": "团队任务需要一个描述,例如:@team 开发用户登录功能"}, } ) return True # Resolve expert configs from specified experts or default dev_team template expert_configs = team_router.resolve_expert_configs(routing_result.specified_experts) if not expert_configs: await websocket.send_json( {"type": "error", "data": {"message": "无法解析团队成员,请检查专家名称或模板配置"}} ) return True # Split configs: first is lead, rest are members (V2 verification) lead_config = expert_configs[0] member_configs = expert_configs[1:] if len(expert_configs) > 1 else [] # Create ExpertTeam team = ExpertTeam( pool=app_state.agent_pool, template_registry=template_registry, redis_client=getattr(app_state, "working_redis_client", None), ) # Register handoff_transport handler to relay team events to WebSocket async def _relay_team_event(message: dict) -> None: msg_type = message.get("type") if not msg_type: return # Strip internal fields, keep only event data event_data = {k: v for k, v in message.items() if k != "type"} await emit_team_event(websocket, msg_type, event_data) team.handoff_transport.register_handler(team.team_channel, _relay_team_event) try: # Append user task to session history (inside try so we can compensate on error) await sm.append_message( session_id=session_id, role=MessageRole.USER, content=content, ) await team.create_team(lead_config=lead_config, member_configs=member_configs) # U7: Create checkpoint manager for crash recovery from agentkit.orchestrator.checkpoint import PipelineCheckpoint checkpoint = PipelineCheckpoint( redis_client=getattr(app_state, "working_redis_client", None) ) orchestrator = TeamOrchestrator(team=team, checkpoint=checkpoint) # U4: Register active team so WS messages during execution route as interventions _register_active_team(session_id, team) result = await orchestrator.execute(routing_result.task_content) except asyncio.CancelledError: logger.info(f"Team collaboration cancelled for session {session_id}") await websocket.send_json({"type": "error", "data": {"message": "团队协作已取消"}}) return True except Exception as e: logger.error(f"Team collaboration failed for session {session_id}: {e}", exc_info=True) await websocket.send_json( {"type": "error", "data": {"message": f"团队协作执行失败: {str(e)[:200]}"}} ) return True finally: # U4: Always unregister the active team first so subsequent messages # don't route to a dissolving team. _unregister_active_team(session_id) # Always dissolve the team and remove handler to avoid leaks try: await team.dissolve() except (RuntimeError, asyncio.TimeoutError, ConnectionError) as e: logger.warning(f"Team dissolve failed: {e}") # dissolve() already clears handlers via handoff_transport.close() # Build final answer text from synthesis result final_result = result.get("result") or {} final_content = ( final_result.get("content", "") if isinstance(final_result, dict) else str(final_result) ) if not final_content: # Fallback: use phase results if synthesis is empty phase_results = result.get("phase_results") or {} if phase_results: parts = [] # Build a phase_id -> phase_name lookup from the plan phase_names = {} plan_obj = result.get("plan") if plan_obj: for ph in plan_obj.phases: phase_names[ph.id] = ph.name for phase_id, pr in phase_results.items(): if isinstance(pr, dict) and "content" in pr: parts.append( f"### {phase_names.get(phase_id, pr.get('phase_name', phase_id))}\n\n{pr['content']}" ) final_content = "\n\n".join(parts) if parts else "团队执行完成,但未生成最终结果。" else: final_content = "团队执行完成,但未生成最终结果。" await websocket.send_json( { "type": "final_answer", "content": final_content, "is_final": True, } ) # Persist final synthesis as assistant message await sm.append_message( session_id=session_id, role=MessageRole.ASSISTANT, content=final_content, agent_name="team_collab", ) return True def _session_to_response(session) -> SessionResponse: return SessionResponse( session_id=session.session_id, agent_name=session.agent_name, status=session.status.value, metadata=session.metadata, created_at=session.created_at.isoformat(), updated_at=session.updated_at.isoformat(), ) def _message_to_response(msg) -> MessageResponse: return MessageResponse( message_id=msg.message_id, session_id=msg.session_id, role=msg.role.value, content=msg.content, tool_call_id=msg.tool_call_id, agent_name=msg.agent_name, created_at=msg.created_at.isoformat(), ) def _build_phase_engine( *, server_config: ServerConfig | None, llm_gateway: LLMGateway, execution_mode: ExecutionMode, base_tools: list[Tool], session_id: str = "", ) -> tuple[ReActEngine | None, list[Tool] | None, str | None]: """Build a PLAN_EXEC engine with PhasePolicy + AdvancePhaseTool. Encapsulates the WS path's phase_policy construction so the REST path can reuse it without duplicating config-lookup + policy_from_config + AdvancePhaseTool registration. KTD5: PLAN_EXEC bypasses the fallback chain — callers must NOT route the returned engine through ``execute_with_fallback_chain``. Args: server_config: ``app.state.server_config`` (or None for tests). llm_gateway: ``app.state.llm_gateway``. execution_mode: routing.execution_mode (WS) or PLAN_EXEC (REST). base_tools: routing.tools (WS) or agent tool list (REST). session_id: included in log lines for traceability only. Returns ``(engine, tools_with_advance_phase, error_message)``: - execution_mode != PLAN_EXEC → ``(None, None, None)`` (fall back to REACT). - plan_exec.enabled=False → ``(None, None, None)`` (fall back to REACT). - phase policy construction failed → ``(None, None, error_message)``. - PLAN_EXEC engaged → ``(engine, tools_with_advance_phase, None)``. """ if execution_mode != ExecutionMode.PLAN_EXEC: return (None, None, None) plan_exec_cfg = getattr(server_config, "plan_exec", None) or {} if plan_exec_cfg.get("enabled", True) is False: logger.info( "PLAN_EXEC disabled by config (plan_exec.enabled=False), " "falling back to REACT for session %s", session_id, ) return (None, None, None) try: phase_policy = policy_from_config(plan_exec_cfg) if phase_policy is None: # Empty config (no `plan_exec:` section) → use KTD5 defaults. phase_policy = default_policy() except (ValueError, TypeError, KeyError) as e: logger.error( "PLAN_EXEC phase policy construction failed for session %s: %s", session_id, e, ) return (None, None, f"phase policy error: {str(e)[:200]}") engine = ReActEngine( llm_gateway=llm_gateway, phase_policy=phase_policy, ) advance_phase_tool = AdvancePhaseTool(engine=engine) tools_with_advance_phase = list(base_tools) + [advance_phase_tool] return (engine, tools_with_advance_phase, None) # ── REST endpoints ──────────────────────────────────────────────────── @router.get("/sessions", response_model=list[SessionResponse]) async def list_sessions(req: Request): """List all chat sessions.""" sm = _get_session_manager(req) sessions = await sm.list_sessions() return [_session_to_response(s) for s in sessions] @router.post("/sessions", response_model=SessionResponse) async def create_session(request: CreateSessionRequest, req: Request): """Create a new chat session bound to an Agent.""" sm = _get_session_manager(req) session = await sm.create_session( agent_name=request.agent_name, metadata=request.metadata, ) return _session_to_response(session) @router.get("/sessions/{session_id}", response_model=SessionResponse) async def get_session(session_id: str, req: Request): """Get session information.""" sm = _get_session_manager(req) session = await sm.get_session(session_id) if session is None: raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found") return _session_to_response(session) @router.get("/sessions/{session_id}/messages", response_model=list[MessageResponse]) async def get_messages(session_id: str, req: Request, limit: int | None = None, offset: int = 0): """Get conversation history for a session.""" sm = _get_session_manager(req) session = await sm.get_session(session_id) if session is None: raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found") messages = await sm.get_messages(session_id, limit=limit, offset=offset) return [_message_to_response(m) for m in messages] @router.post("/sessions/{session_id}/messages", response_model=MessageResponse) async def send_message(session_id: str, request: SendMessageRequest, req: Request): """Send a message to the Agent (synchronous mode — waits for full reply).""" sm = _get_session_manager(req) session = await sm.get_session(session_id) if session is None: raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found") if session.status == SessionStatus.CLOSED: raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed") # U3: PLAN_EXEC via REST — non-streaming, bypasses the fallback chain # (KTD5: PLAN_EXEC and execute_with_fallback_chain are mutually exclusive). # When plan_exec is disabled by config, falls through to the REACT path below. if request.execution_mode == "plan_exec": # Resolve the Agent early — PLAN_EXEC needs its tool list + system prompt. pool = req.app.state.agent_pool agent = pool.get_agent(session.agent_name) if agent is None: raise HTTPException(status_code=404, detail=f"Agent '{session.agent_name}' not found") plan_exec_engine, plan_exec_tools, plan_exec_error = _build_phase_engine( server_config=getattr(req.app.state, "server_config", None), llm_gateway=req.app.state.llm_gateway, execution_mode=ExecutionMode.PLAN_EXEC, base_tools=agent._tool_registry.list_tools() if agent._tool_registry else [], session_id=session_id, ) if plan_exec_error is not None: raise HTTPException(status_code=500, detail=plan_exec_error) if plan_exec_engine is not None: # PLAN_EXEC engaged — append user msg, execute non-streaming, return. await sm.append_message( session_id=session_id, role=MessageRole.USER, content=request.content, ) chat_messages = await sm.get_chat_messages(session_id) system_prompt = getattr(agent, "_system_prompt", None) or ( agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None ) try: plan_exec_result = await plan_exec_engine.execute( messages=chat_messages, tools=plan_exec_tools, model=agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default"), agent_name=agent.name, system_prompt=system_prompt, ) except asyncio.CancelledError: raise except Exception as e: logger.error(f"PLAN_EXEC execution error for session {session_id}: {e}") raise HTTPException(status_code=500, detail=str(e)) assistant_msg = await sm.append_message( session_id=session_id, role=MessageRole.ASSISTANT, content=plan_exec_result.output, agent_name=agent.name, ) return _message_to_response(assistant_msg) # else: plan_exec.enabled=False → fall through to REACT path below. # Append user message await sm.append_message( session_id=session_id, role=MessageRole.USER, content=request.content, ) # Get full conversation history for the Agent chat_messages = await sm.get_chat_messages(session_id) # Resolve the Agent pool = req.app.state.agent_pool agent = pool.get_agent(session.agent_name) if agent is None: raise HTTPException(status_code=404, detail=f"Agent '{session.agent_name}' not found") # Execute the Agent try: # Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization) react_engine = getattr(agent, "_react_engine", None) if react_engine is None: react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway) else: react_engine.reset() tools = agent._tool_registry.list_tools() if agent._tool_registry else [] system_prompt = getattr(agent, "_system_prompt", None) or ( agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None ) # G7/U3: Three-tier fallback chain (main → Recovery → Emergency). # Wired only here (KTD5); CLI / ReWOO / Reflexion internal ReAct bypass. server_config = getattr(req.app.state, "server_config", None) fallback_chain_cfg = ( getattr(server_config, "fallback_chain", None) if server_config else None ) chat_result = await execute_with_fallback_chain( react_engine=react_engine, llm_gateway=req.app.state.llm_gateway, messages=chat_messages, tools=tools, model=agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default"), agent_name=agent.name, system_prompt=system_prompt, fallback_chain_config=fallback_chain_cfg, ) # Append assistant reply assistant_msg = await sm.append_message( session_id=session_id, role=MessageRole.ASSISTANT, content=chat_result.output, agent_name=agent.name, ) response = _message_to_response(assistant_msg) # Attach structured error payload when Emergency tier fired. if chat_result.error_struct is not None: response_dict = ( response.model_dump() if hasattr(response, "model_dump") else dict(response) ) response_dict["error_struct"] = chat_result.error_struct response_dict["fallback_status"] = chat_result.status return response_dict return response except asyncio.CancelledError: raise except Exception as e: logger.error(f"Chat execution error for session {session_id}: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.delete("/sessions/{session_id}") async def close_session(session_id: str, req: Request): """Close a chat session.""" sm = _get_session_manager(req) session = await sm.close_session(session_id) if session is None: raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found") return {"status": "closed", "session_id": session_id} # ── WebSocket endpoint ──────────────────────────────────────────────── async def _resolve_ws_dept_context( websocket: WebSocket, ) -> tuple[str | None, list[str], Path | None]: """Resolve user_id, department_ids, and db_path for quota enforcement. Reads ``websocket.state.current_user`` (set by :class:`AuthMiddleware` when JWT auth via ``?token=`` query param succeeds). If no user context is available (API key auth), returns ``(None, [], None)`` — quota is not enforced for API key clients. Returns ``(user_id, department_ids, db_path)``: - ``user_id``: the authenticated user's id, or ``None``. - ``department_ids``: the user's active department ids (empty for admins and API key clients). - ``db_path``: the auth DB path (needed for quota checks), or ``None`` if no user context is available. """ current_user: dict[str, Any] | None = getattr(websocket.state, "current_user", None) if current_user is None: return None, [], None user_id = current_user.get("user_id") role = current_user.get("role") # Admins and API-key clients (user_id=None) — no quota enforcement. if role == "admin" or not user_id: return user_id, [], None # Regular user: look up their active department ids. from agentkit.server.admin.context import _fetch_user_department_ids db_path = getattr(websocket.app.state, "auth_db_path", None) if db_path is None: db_path_resolved = None else: db_path_resolved = Path(db_path) try: department_ids = await _fetch_user_department_ids(db_path_resolved, user_id) except (ConnectionError, OSError, asyncio.TimeoutError, ValueError, KeyError, RuntimeError): logger.exception( "Failed to fetch department ids for WebSocket user %s — fail-closed", user_id, ) # Fail-closed: return empty list with db_path=None so gateway # skips quota enforcement but we log the failure. The user # gets a degraded experience rather than a security bypass. return user_id, [], None return user_id, department_ids, db_path_resolved return user_id, [], None @router.websocket("/ws/{session_id}") async def chat_websocket(websocket: WebSocket, session_id: str) -> None: """WebSocket endpoint for real-time chat with streaming. Client → Server messages: {"type": "message", "content": "..."} — Send a user message {"type": "cancel"} — Cancel current execution {"type": "ping"} — Heartbeat Server → Client messages: {"type": "connected", "session_id": "..."} — Connection confirmed {"type": "token", "content": "..."} — LLM token streaming {"type": "step", "data": {...}} — ReAct step event {"type": "ask_human", "question": "...", "request_id": "..."} — Agent asks user {"type": "final_answer", "content": "..."} — Agent's final reply {"type": "error", "data": {"message": "..."}} — Error occurred {"type": "pong"} — Heartbeat response Expert Team events (Server → Client): {"type": "team_formed", "data": {...}} — Team assembled {"type": "expert_step", "data": {...}} — Expert executing step {"type": "expert_result", "data": {...}} — Expert produced result {"type": "plan_update", "data": {...}} — Plan phases updated {"type": "team_synthesis", "data": {...}} — Final synthesis {"type": "team_dissolved", "data": {...}} — Team dissolved """ # Authentication configured_api_key: str | None = None if hasattr(websocket.app.state, "server_config") and websocket.app.state.server_config: configured_api_key = websocket.app.state.server_config.api_key if configured_api_key is None and hasattr(websocket.app.state, "api_key"): configured_api_key = websocket.app.state.api_key if configured_api_key: provided = websocket.query_params.get("api_key") if not provided or not hmac.compare_digest(provided, configured_api_key): await websocket.accept() await websocket.send_json({"type": "error", "data": {"message": "Invalid api_key"}}) await websocket.close(code=4001, reason="Invalid api_key") return await websocket.accept() # Validate session sm: SessionManager = websocket.app.state.session_manager session = await sm.get_session(session_id) if session is None: await websocket.send_json( {"type": "error", "data": {"message": f"Session '{session_id}' not found"}} ) await websocket.close(code=1000, reason="Session not found") return if session.status == SessionStatus.CLOSED: await websocket.send_json({"type": "error", "data": {"message": "Session is closed"}}) await websocket.close(code=1000, reason="Session closed") return # Track pending replies for AskHumanTool and confirmations pending_replies: dict[str, asyncio.Future] = {} pending_confirmations: dict[str, asyncio.Future] = {} chat_manager.add(session_id, websocket, pending_replies) cancellation_token = CancellationToken() # Per-session concurrency guard to prevent unlimited task creation (DoS mitigation) _MAX_CONCURRENT_TASKS = 4 active_tasks: set[asyncio.Task] = set() try: await websocket.send_json({"type": "connected", "session_id": session_id}) # Listen for client messages while True: try: raw = await asyncio.wait_for(websocket.receive_text(), timeout=300.0) except asyncio.TimeoutError: await websocket.send_json({"type": "pong"}) continue try: msg = json.loads(raw) except json.JSONDecodeError: continue msg_type = msg.get("type") if msg_type == "message": content = msg.get("content", "") model = msg.get("model") # Optional model override from frontend # U4: If a team is currently executing for this session, route # the message as an intervention instead of a new chat task. active_team = _get_active_team(session_id) if active_team is not None: try: await active_team.add_user_intervention(content) await websocket.send_json( { "type": "team_intervention_ack", "data": {"content": content}, } ) except (asyncio.QueueFull, RuntimeError, ConnectionError) as e: logger.warning(f"Failed to enqueue intervention: {e}") await websocket.send_json( { "type": "error", "data": {"message": f"干预消息入队失败: {e}"}, } ) continue # Create a fresh CancellationToken for each message message_token = CancellationToken() # Guard against unlimited concurrent tasks # Clean up completed tasks first active_tasks.difference_update(t for t in active_tasks if t.done()) if len(active_tasks) >= _MAX_CONCURRENT_TASKS: await websocket.send_json( { "type": "error", "data": { "message": "Too many concurrent requests. Please wait for the current task to complete." }, } ) continue # Run in background task so the WebSocket receive loop stays free # to process confirmation_reply / reply messages while the agent # is waiting for user confirmation (otherwise deadlock). task = asyncio.create_task( _handle_chat_message( websocket, session_id, content, sm, message_token, pending_replies, pending_confirmations, model_override=model, ) ) active_tasks.add(task) task.add_done_callback(active_tasks.discard) elif msg_type == "reply": # Reply to AskHumanTool request_id = msg.get("request_id") reply_content = msg.get("content", "") if request_id and request_id in pending_replies: pending_replies[request_id].set_result(reply_content) elif msg_type == "confirmation_reply": # Reply to confirmation request confirmation_id = msg.get("confirmation_id") approved = msg.get("approved", False) logger.info( f"Received confirmation_reply: id={confirmation_id!r}, approved={approved}" ) if confirmation_id and confirmation_id in pending_confirmations: pending_confirmations[confirmation_id].set_result(approved) logger.info(f"Confirmation {confirmation_id} set_result({approved})") else: logger.warning( f"Confirmation {confirmation_id!r} not found in pending_confirmations" ) elif msg_type == "cancel": cancellation_token.cancel() await websocket.send_json({"type": "result", "data": {"status": "cancelled"}}) elif msg_type == "ping": await websocket.send_json({"type": "pong"}) except WebSocketDisconnect: logger.debug(f"Chat WebSocket disconnected for session {session_id}") except asyncio.CancelledError: raise except Exception as e: logger.error(f"Chat WebSocket error for session {session_id}: {e}") try: await websocket.send_json({"type": "error", "data": {"message": str(e)}}) except (ConnectionError, RuntimeError, asyncio.TimeoutError): pass finally: # Clean up pending futures for fut in pending_replies.values(): if not fut.done(): fut.cancel() for fut in pending_confirmations.values(): if not fut.done(): fut.cancel() chat_manager.remove(session_id, websocket) async def _handle_chat_message( websocket: WebSocket, session_id: str, content: str, sm: SessionManager, cancellation_token: CancellationToken, pending_replies: dict[str, asyncio.Future], pending_confirmations: dict[str, asyncio.Future] | None = None, model_override: str | None = None, ) -> None: """Handle a user message: append to session, execute Agent, stream events. Uses RequestPreprocessor for minimal preprocessing: @skill prefix + greeting regex + REACT. Board Meeting mode: @board prefix is intercepted before RequestPreprocessor and routed to BoardOrchestrator for multi-round group discussion. Team Collaboration mode: @team prefix is intercepted before RequestPreprocessor and routed to TeamOrchestrator for pipeline-based expert collaboration. """ from agentkit.chat.request_preprocessor import RequestPreprocessor # Board Meeting mode: intercept @board prefix before any other preprocessing if await _execute_board_meeting(websocket, session_id, content, sm): return # Team Collaboration mode: intercept @team prefix before any other preprocessing if await _execute_team_collab(websocket, session_id, content, sm): return # Resolve Agent first (needed for default tools/prompt) pool = websocket.app.state.agent_pool session = await sm.get_session(session_id) if session is None: await websocket.send_json({"type": "error", "data": {"message": "Session lost"}}) return agent = pool.get_agent(session.agent_name) if agent is None: await websocket.send_json( {"type": "error", "data": {"message": f"Agent '{session.agent_name}' not found"}} ) return # Default execution parameters from agent default_tools = agent._tool_registry.list_tools() if agent._tool_registry else [] default_system_prompt = getattr(agent, "_system_prompt", None) or ( agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None ) default_model = ( agent.get_model() if hasattr(agent, "get_model") else getattr(agent, "_llm_model", "default") ) # Resolve skill routing using RequestPreprocessor skill_registry = getattr(websocket.app.state, "skill_registry", None) request_preprocessor: RequestPreprocessor = websocket.app.state.request_preprocessor routing = await request_preprocessor.preprocess( content=content, skill_registry=skill_registry, default_tools=default_tools, default_system_prompt=default_system_prompt, default_model=default_model, default_agent_name=agent.name, ) # Debug: log tools that will be passed to ReActEngine tool_names = [t.name for t in routing.tools] logger.info( f"Chat {session_id}: resolved {len(routing.tools)} tools: {tool_names}, model={routing.model}, skill={routing.skill_name}" ) # Apply model override from frontend selector if model_override: routing.model = model_override # Notify frontend about skill match if routing.matched: await websocket.send_json( { "type": "skill_match", "data": { "skill": routing.skill_name, "method": routing.match_method, "confidence": routing.match_confidence, }, } ) # Append user message (use clean_content if @skill: prefix was stripped) await sm.append_message( session_id=session_id, role=MessageRole.USER, content=routing.clean_content ) # Get full conversation history chat_messages = await sm.get_chat_messages(session_id) # Handle DIRECT_CHAT: direct LLM call, no ReAct loop if routing.execution_mode == ExecutionMode.DIRECT_CHAT: direct_messages = [] if routing.system_prompt: direct_messages.append({"role": "system", "content": routing.system_prompt}) direct_messages.extend(chat_messages) # Resolve department context for quota enforcement (U4/KTD-3). ws_user_id, ws_dept_ids, ws_db_path = await _resolve_ws_dept_context(websocket) try: response = await websocket.app.state.llm_gateway.chat( messages=direct_messages, model=routing.model or "default", agent_name=agent.name, task_type="chat", user_id=ws_user_id, department_ids=ws_dept_ids if ws_dept_ids else None, db_path=ws_db_path, ) final_content = response.content or "" if not final_content or not final_content.strip(): from agentkit.core.fallback import EMPTY_LLM_RESPONSE final_content = EMPTY_LLM_RESPONSE await websocket.send_json( { "type": "final_answer", "content": final_content, "is_final": True, } ) await sm.append_message( session_id=session_id, role=MessageRole.ASSISTANT, content=final_content, agent_name=agent.name, ) except asyncio.CancelledError: raise except Exception as e: # Check if this is a QuotaExceededError (U4: WebSocket quota). from agentkit.llm.gateway import QuotaExceededError if isinstance(e, QuotaExceededError): logger.warning( "WebSocket DIRECT_CHAT quota exceeded for session %s: %s", session_id, e, ) await websocket.send_json( { "type": "error", "data": { "message": "quota exceeded", "quota_info": { "department_id": e.department_id, "quota_type": e.quota_type, "period": e.period, "limit": e.limit, "current": e.current, }, }, } ) else: logger.error(f"Chat DIRECT_CHAT error for session {session_id}: {e}") await websocket.send_json({"type": "error", "data": {"message": str(e)[:200]}}) return # U4/G6: PLAN_EXEC — build PhasePolicy from server config. # KTD5 (Wave 2): fallback chain NOT applied to PLAN_EXEC — phase policy and # fallback chain are mutually exclusive. PLAN_EXEC uses its own engine. # U3: logic extracted into _build_phase_engine so REST can reuse it. plan_exec_engine, plan_exec_tools, plan_exec_error = _build_phase_engine( server_config=getattr(websocket.app.state, "server_config", None), llm_gateway=websocket.app.state.llm_gateway, execution_mode=routing.execution_mode, base_tools=routing.tools, session_id=session_id, ) if plan_exec_error is not None: await websocket.send_json( { "type": "error", # Truncate to 200 chars to match nearby error paths and # avoid leaking config internals (see chat.py:1090, 1320). "data": {"message": plan_exec_error}, } ) return # Handle advanced execution modes: REWOO/REFLEXION/TEAM_COLLAB # still fall back to REACT with a warning. PLAN_EXEC is handled above. if routing.execution_mode not in ( ExecutionMode.REACT, ExecutionMode.SKILL_REACT, ExecutionMode.PLAN_EXEC, ): logger.warning( f"Execution mode {routing.execution_mode.value} not implemented " f"in chat WebSocket path, falling back to REACT" ) # Execute Agent with streaming # Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization). # PLAN_EXEC creates a fresh engine with phase_policy set (cannot reuse the # agent's _react_engine — it has no policy). if plan_exec_engine is not None: react_engine = plan_exec_engine routing.tools = plan_exec_tools else: react_engine = getattr(agent, "_react_engine", None) if react_engine is None: react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway) else: react_engine.reset() # Create confirmation handler that sends request to frontend and waits for reply # Use the same dict object — do NOT use `or {}` because an empty dict is falsy # and would create a new dict, breaking the shared state with the WS loop. _pending_confirmations = pending_confirmations if pending_confirmations is not None else {} async def _confirmation_handler(confirmation_id: str, command: str, reason: str) -> bool: """Send confirmation request to frontend via WebSocket and wait for user reply.""" # Send confirmation request to frontend await websocket.send_json( { "type": "confirmation_request", "data": { "confirmation_id": confirmation_id, "command": command, "reason": reason, }, } ) # Create a Future and wait for the user's reply loop = asyncio.get_running_loop() future: asyncio.Future[bool] = loop.create_future() _pending_confirmations[confirmation_id] = future logger.info(f"Confirmation request {confirmation_id} sent, waiting for reply") try: # Wait up to 5 minutes for user confirmation result = await asyncio.wait_for(future, timeout=300.0) logger.info(f"Confirmation request {confirmation_id} resolved: {result}") # Immediately notify frontend of the result so the card updates # without waiting for the tool to re-execute await websocket.send_json( { "type": "confirmation_result", "data": {"confirmation_id": confirmation_id, "approved": result}, } ) return result except asyncio.TimeoutError: logger.warning(f"Confirmation request {confirmation_id} timed out") return False except asyncio.CancelledError: logger.warning(f"Confirmation request {confirmation_id} cancelled") return False finally: _pending_confirmations.pop(confirmation_id, None) logger.info( f"Chat session {session_id}: executing with {len(routing.tools)} tools, model={routing.model}, skill={routing.skill_name}" ) try: final_content = "" token_buffer: list[str] = [] # Track phase transitions for phase_changed events (PLAN_EXEC only). # For non-PLAN_EXEC modes, current_phase is always None → no events. prev_phase = react_engine.current_phase async for event in react_engine.execute_stream( messages=chat_messages, tools=routing.tools, model=routing.model, agent_name=routing.agent_name, system_prompt=routing.system_prompt, cancellation_token=cancellation_token, confirmation_handler=_confirmation_handler, ): if event.event_type == "final_answer": # Flush any buffered tokens as a single write if token_buffer: await websocket.send_json({"type": "token", "content": "".join(token_buffer)}) token_buffer.clear() # Then send final answer final_content = event.data.get("output", "") if not final_content or not final_content.strip(): from agentkit.core.fallback import EMPTY_LLM_RESPONSE final_content = EMPTY_LLM_RESPONSE await websocket.send_json( { "type": "final_answer", "content": final_content, "is_final": True, } ) elif event.event_type == "token": # Buffer tokens instead of sending immediately token_buffer.append(event.data.get("content", "")) elif event.event_type == "thinking": # If we have buffered tokens, convert them to a thinking event if token_buffer: buffered_text = "".join(token_buffer) token_buffer.clear() await websocket.send_json({"type": "thinking", "content": buffered_text}) # Also send the thinking event content thinking_msg = event.data.get("message", "") if thinking_msg: await websocket.send_json({"type": "thinking", "content": thinking_msg}) elif event.event_type == "tool_call": # Convert buffered tokens to thinking (they were "thinking" text before tool call) if token_buffer: buffered_text = "".join(token_buffer) token_buffer.clear() await websocket.send_json({"type": "thinking", "content": buffered_text}) await websocket.send_json( { "type": "step", "data": { "event_type": event.event_type, "step": event.step, "data": event.data, }, } ) elif event.event_type == "confirmation_request": pass elif event.event_type == "confirmation_result": await websocket.send_json( { "type": "confirmation_result", "data": event.data, } ) elif event.event_type == "phase_violation": # Wave 4 U2: forward phase violations to the client so the # frontend can surface them in the PhaseIndicator UI (alongside # the LLM reinjection that already happens via the tool_result # error dict). await websocket.send_json( { "type": "phase_violation", "data": event.data, } ) else: await websocket.send_json( { "type": "step", "data": { "event_type": event.event_type, "step": event.step, "data": event.data, }, } ) # U4/G6: emit phase_changed event when the phase state machine # transitions (PLAN_EXEC only). For non-PLAN_EXEC modes, # current_phase is always None → this branch never fires. curr_phase = react_engine.current_phase if curr_phase != prev_phase: await websocket.send_json( { "type": "phase_changed", "data": { "phase": curr_phase.value if curr_phase else None, "previous": prev_phase.value if prev_phase else None, }, } ) prev_phase = curr_phase # Append assistant reply to session if final_content: await sm.append_message( session_id=session_id, role=MessageRole.ASSISTANT, content=final_content, agent_name=agent.name, ) except asyncio.CancelledError: raise except Exception as e: logger.error(f"Chat execution error for session {session_id}: {e}") # Show meaningful error to user, but avoid leaking full stack traces error_msg = str(e) # Truncate very long error messages if len(error_msg) > 200: error_msg = error_msg[:200] + "..." await websocket.send_json({"type": "error", "data": {"message": error_msg}}) # ── File upload ─────────────────────────────────────────────────────── MAX_UPLOAD_SIZE = 10 * 1024 * 1024 # 10 MB UPLOAD_DIR = Path(os.environ.get("AGENTKIT_UPLOAD_DIR", "data/uploads")) def _ensure_upload_dir() -> Path: UPLOAD_DIR.mkdir(parents=True, exist_ok=True) return UPLOAD_DIR def _sanitize_filename(name: str) -> str: """Remove path separators and keep only safe characters.""" name = name.replace("\\", "_").replace("/", "_") return "".join(c for c in name if c.isalnum() or c in "._-").strip(".") @router.post("/upload") async def upload_chat_file(file: UploadFile = File(...)) -> dict[str, Any]: """Upload a file to be referenced in chat messages. Returns metadata including a public download URL. """ if file.size is not None and file.size > MAX_UPLOAD_SIZE: raise HTTPException(status_code=413, detail="File exceeds 10 MB limit") original_name = file.filename or "unnamed" safe_name = _sanitize_filename(original_name) or "unnamed" ext = Path(safe_name).suffix stored_name = f"{uuid.uuid4().hex}{ext}" upload_dir = _ensure_upload_dir() file_path = upload_dir / stored_name try: contents = await file.read() if len(contents) > MAX_UPLOAD_SIZE: raise HTTPException(status_code=413, detail="File exceeds 10 MB limit") file_path.write_bytes(contents) except HTTPException: raise except asyncio.CancelledError: raise except Exception as exc: logger.error(f"Failed to save uploaded file: {exc}") raise HTTPException(status_code=500, detail="Failed to save file") from exc finally: await file.close() return { "filename": original_name, "stored_name": stored_name, "content_type": file.content_type or "application/octet-stream", "size": file_path.stat().st_size, "download_url": f"/api/v1/chat/uploads/{stored_name}", } @router.get("/uploads/{filename}") async def download_chat_file(filename: str) -> FileResponse: """Download an uploaded chat file by its stored filename.""" upload_dir = _ensure_upload_dir() safe_filename = _sanitize_filename(filename) file_path = upload_dir / safe_filename if not file_path.exists() or not file_path.is_file(): raise HTTPException(status_code=404, detail="File not found") return FileResponse(file_path, filename=safe_filename)