1287 lines
49 KiB
Python
1287 lines
49 KiB
Python
"""Chat API routes — multi-turn conversation with Agent via REST and WebSocket."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import hmac
|
|
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
import os
|
|
import uuid
|
|
from pathlib import Path
|
|
|
|
from fastapi import (
|
|
APIRouter,
|
|
File,
|
|
HTTPException,
|
|
Request,
|
|
UploadFile,
|
|
WebSocket,
|
|
WebSocketDisconnect,
|
|
)
|
|
from fastapi.responses import FileResponse
|
|
from pydantic import BaseModel
|
|
|
|
from agentkit.chat.skill_routing import ExecutionMode
|
|
from agentkit.core.protocol import CancellationToken
|
|
from agentkit.core.react import ReActEngine
|
|
from agentkit.session.manager import SessionManager
|
|
from agentkit.session.models import MessageRole, SessionStatus
|
|
|
|
# Module-level logger for the chat routes.
logger = logging.getLogger(__name__)

# Router for every endpoint in this module: paths are mounted under "/chat"
# and grouped under the "chat" tag.
router = APIRouter(prefix="/chat", tags=["chat"])
|
|
|
|
|
|
# ── Request/Response schemas ──────────────────────────────────────────
|
|
|
|
|
|
class CreateSessionRequest(BaseModel):
    """Request body for ``POST /chat/sessions``."""

    # Name of the Agent the new session is bound to (looked up in the agent
    # pool when messages are sent).
    agent_name: str

    # Optional caller-supplied metadata, passed through to
    # ``SessionManager.create_session``.
    metadata: dict[str, Any] | None = None
|
|
|
|
|
|
class SendMessageRequest(BaseModel):
    """Request body for ``POST /chat/sessions/{session_id}/messages``."""

    # Text of the message sent to the session's Agent.
    content: str

    # Accepted by the schema, but ``send_message`` always persists the
    # message with ``MessageRole.USER`` and never reads this field.
    role: str = "user"
|
|
|
|
|
|
class SessionResponse(BaseModel):
    """API representation of a chat session (see ``_session_to_response``)."""

    session_id: str

    # Agent the session is bound to.
    agent_name: str

    # ``session.status.value`` (string form of the SessionStatus enum).
    status: str

    metadata: dict[str, Any]

    # Timestamps serialized with ``.isoformat()``.
    created_at: str

    updated_at: str
|
|
|
|
|
|
class MessageResponse(BaseModel):
    """API representation of a single stored message (see ``_message_to_response``)."""

    message_id: str

    session_id: str

    # ``msg.role.value`` (string form of the MessageRole enum).
    role: str

    content: str

    # Set only for tool-result messages; ``None`` otherwise.
    tool_call_id: str | None = None

    # Agent that produced the message, when recorded.
    agent_name: str | None = None

    # Timestamp serialized with ``.isoformat()``.
    created_at: str
|
|
|
|
|
|
# ── Chat WebSocket connection manager ─────────────────────────────────
|
|
|
|
|
|
class ChatConnectionManager:
    """Registry of live chat WebSockets, keyed by session id.

    A session may have several sockets at once; each is stored together with
    the dict of pending reply futures owned by that socket's handler.
    """

    def __init__(self) -> None:
        # session_id -> [(socket, pending_replies), ...]
        self._connections: dict[str, list[tuple[WebSocket, dict[str, asyncio.Future]]]] = {}

    def add(self, session_id: str, ws: WebSocket, pending: dict[str, asyncio.Future]) -> None:
        """Register ``ws`` and its pending-reply map under ``session_id``."""
        bucket = self._connections.setdefault(session_id, [])
        bucket.append((ws, pending))

    def remove(self, session_id: str, ws: WebSocket) -> None:
        """Forget ``ws``; drop the session key entirely once no sockets remain."""
        current = self._connections.get(session_id)
        if current is None:
            return
        survivors = [entry for entry in current if entry[0] is not ws]
        if survivors:
            self._connections[session_id] = survivors
        else:
            del self._connections[session_id]

    def get_connections(self, session_id: str) -> list[tuple[WebSocket, dict[str, asyncio.Future]]]:
        """Return the stored (socket, pending) pairs, or an empty list if none."""
        try:
            return self._connections[session_id]
        except KeyError:
            return []

    async def send_json(self, session_id: str, message: dict) -> None:
        """Broadcast ``message`` to every socket of the session.

        Sockets whose send raises are treated as dead and removed afterwards.
        """
        broken: list[WebSocket] = []
        for sock, _pending in self._connections.get(session_id, []):
            try:
                await sock.send_json(message)
            except Exception:
                broken.append(sock)
        for sock in broken:
            self.remove(session_id, sock)
|
|
|
|
|
|
# Module-level registry of chat WebSocket connections (one instance per process).
chat_manager = ChatConnectionManager()

# U4: Active team sessions — maps session_id to the ExpertTeam currently executing.
# When a message arrives during team execution, it is routed as an intervention
# instead of starting a new chat task. Populated by _execute_team_collab.
# NOTE(review): this is in-process state; with multiple server workers a message
# could reach a worker that does not hold the team — confirm the deployment model.
_active_teams: dict[str, "object"] = {}
|
|
|
|
|
|
def _register_active_team(session_id: str, team: "object") -> None:
    """Record ``team`` as the team currently executing for ``session_id``.

    A team already registered for the same session is replaced.
    """
    _active_teams.update({session_id: team})
|
|
|
|
|
|
def _unregister_active_team(session_id: str) -> None:
    """Forget the active team for ``session_id``; a no-op when none is registered."""
    if session_id in _active_teams:
        del _active_teams[session_id]
|
|
|
|
|
|
def _get_active_team(session_id: str) -> "object | None":
    """Return the team executing for ``session_id``, or ``None`` when there is none."""
    try:
        return _active_teams[session_id]
    except KeyError:
        return None
|
|
|
|
|
|
# ── Helper ────────────────────────────────────────────────────────────
|
|
|
|
|
|
# Event types that ``emit_team_event`` is allowed to forward to the client.
_VALID_TEAM_EVENT_TYPES = frozenset(
    # Expert-team pipeline: formation, per-expert progress, plan/phase tracking.
    (
        "team_formed",
        "expert_step",
        "expert_result",
        "plan_update",
        "team_synthesis",
        "team_dissolved",
        "plan_step",
        "phase_started",
        "phase_completed",
        "phase_failed",
        "replanning",
    )
    # PM Collaboration mode events (U1-U4).
    + (
        "collaboration_contract_defined",
        "collaboration_notice",
        "review_result",
        "risk_flagged",
    )
    # Board Meeting mode events.
    + (
        "board_started",
        "expert_speech",
        "round_summary",
        "user_intervention",
        "board_concluded",
    )
)
|
|
|
|
|
|
async def emit_team_event(websocket: WebSocket, event_type: str, data: dict) -> None:
    """Send one team/board event to a single WebSocket client.

    The wire format is ``{"type": event_type, "data": data}``. Event types not
    in ``_VALID_TEAM_EVENT_TYPES`` are dropped: a warning is logged and nothing
    is sent.

    Expert-team (pipeline) events and their payloads:
        team_formed     — Team assembled with expert list and plan.
                          data: {team_id, status, experts, plan_phases, lead_expert}
        expert_step     — An expert is executing a step.
                          data: {expert_id, expert_name, expert_color, content, step}
        expert_result   — An expert produced a result.
                          data: {expert_id, expert_name, expert_color, content}
        plan_update     — Team plan phases updated.
                          data: {plan_phases}
        team_synthesis  — Final synthesis from the lead expert.
                          data: {content}
        team_dissolved  — Team dissolved after completion.
                          data: {team_id}

    Also accepted (payload defined by the emitting component):
        plan_step, phase_started, phase_completed, phase_failed, replanning;
        PM collaboration: collaboration_contract_defined, collaboration_notice,
        review_result, risk_flagged;
        board meeting: board_started, expert_speech, round_summary,
        user_intervention, board_concluded.

    Raises:
        Whatever ``websocket.send_json`` raises; send errors are not caught here.
    """
    if event_type not in _VALID_TEAM_EVENT_TYPES:
        # Lazy %-formatting: the message is only built if WARNING is enabled.
        logger.warning("emit_team_event: invalid event_type '%s'", event_type)
        return
    await websocket.send_json({"type": event_type, "data": data})
|
|
|
|
|
|
def _get_session_manager(request: Request) -> SessionManager:
    """Fetch the SessionManager stored on the application state."""
    app_state = request.app.state
    return app_state.session_manager
|
|
|
|
|
|
def _format_board_conclusion(result: dict | None) -> str:
    """Render a BoardOrchestrator result dict as the markdown final answer.

    Sections (summary, decision advice, consensus points, dissent points) are
    included only when present and non-empty; a footer with the number of
    rounds is always appended. A ``None``/empty result renders only the footer.
    """
    result = result or {}
    summary = result.get("summary", "")
    decision_advice = result.get("decision_advice", "")
    consensus_points = result.get("consensus_points", []) or []
    dissent_points = result.get("dissent_points", []) or []
    total_rounds = result.get("total_rounds", 0)

    final_parts: list[str] = []
    if summary:
        final_parts.append(f"## 讨论总结\n\n{summary}")
    if decision_advice:
        final_parts.append(f"## 决策建议\n\n{decision_advice}")
    if consensus_points:
        final_parts.append("## 共识点\n\n" + "\n".join(f"- {p}" for p in consensus_points))
    if dissent_points:
        final_parts.append("## 分歧点\n\n" + "\n".join(f"- {p}" for p in dissent_points))
    # The footer's own leading "\n\n" plus the join separator leaves a wider
    # gap before it; kept as-is to preserve the rendered output.
    final_parts.append(f"\n\n_共进行 {total_rounds} 轮讨论_")

    return "\n\n".join(final_parts)


async def _execute_board_meeting(
    websocket: WebSocket,
    session_id: str,
    content: str,
    sm: SessionManager,
) -> bool:
    """Intercept the ``@board`` prefix and run a board-meeting discussion.

    Returns True if the input was handled as a board meeting (the caller should
    return), False if the input should continue through the normal chat pipeline.

    Flow:
        1. Resolve ``@board`` routing via BoardRouter.
        2. Create a BoardTeam with the resolved expert configs.
        3. Register a handoff_transport handler that relays board events to the WebSocket.
        4. Persist the user topic, then execute BoardOrchestrator.
        5. Send the rendered conclusion as ``final_answer`` and persist it as an
           assistant message (agent_name="board_meeting").
        6. Dissolve the team in ``finally`` so expert agents are released on
           success, on failure and on task cancellation alike.

    Failures of ``create_board``/``execute`` are reported to the client as an
    ``error`` event (message truncated to 200 chars) and the function returns True.
    """
    from agentkit.experts.board_router import BoardRouter
    from agentkit.experts.board import BoardTeam
    from agentkit.experts.board_orchestrator import BoardOrchestrator

    app_state = websocket.app.state

    # Prefer the ExpertTemplateRegistry loaded at startup; fall back to a fresh one.
    template_registry = getattr(app_state, "expert_template_registry", None)
    if template_registry is None:
        from agentkit.experts.registry import ExpertTemplateRegistry

        template_registry = ExpertTemplateRegistry()

    board_router = BoardRouter(template_registry=template_registry)
    routing_result = board_router.resolve(content)

    if not routing_result.matched:
        return False  # Not a @board input, continue normal pipeline

    if not routing_result.topic:
        await websocket.send_json(
            {
                "type": "error",
                "data": {"message": "私董会需要一个讨论主题,例如:@board 如何看待 AI 未来"},
            }
        )
        return True

    # Resolve expert configs from specified experts or the default template.
    expert_configs = board_router.resolve_expert_configs(routing_result.specified_experts)
    if not expert_configs:
        await websocket.send_json(
            {"type": "error", "data": {"message": "无法解析私董会成员,请检查专家名称或模板配置"}}
        )
        return True

    # User-specified rounds take precedence; otherwise use
    # server_config.board["max_rounds"] when it is a dict, defaulting to 5.
    max_rounds = routing_result.max_rounds
    if max_rounds is None:
        max_rounds = 5
        server_config = getattr(app_state, "server_config", None)
        if server_config is not None:
            board_cfg = getattr(server_config, "board", None) or {}
            if isinstance(board_cfg, dict):
                max_rounds = int(board_cfg.get("max_rounds", 5))

    team = BoardTeam(
        pool=app_state.agent_pool,
        template_registry=template_registry,
        max_rounds=max_rounds,
    )

    async def _relay_board_event(message: dict) -> None:
        """Forward a handoff-transport message to the client as a team event."""
        msg_type = message.get("type")
        if not msg_type:
            return
        # Strip the type field; everything else becomes the event payload.
        event_data = {k: v for k, v in message.items() if k != "type"}
        await emit_team_event(websocket, msg_type, event_data)

    team.handoff_transport.register_handler(team.team_channel, _relay_board_event)

    try:
        # Append user topic to session history.
        await sm.append_message(
            session_id=session_id,
            role=MessageRole.USER,
            content=content,
        )

        try:
            await team.create_board(topic=routing_result.topic, expert_configs=expert_configs)
            orchestrator = BoardOrchestrator(team=team)
            result = await orchestrator.execute(routing_result.topic)
        except Exception as e:
            logger.error(f"Board meeting failed for session {session_id}: {e}", exc_info=True)
            await websocket.send_json(
                {"type": "error", "data": {"message": f"私董会执行失败: {str(e)[:200]}"}}
            )
            return True

        final_content = _format_board_conclusion(result)

        await websocket.send_json(
            {
                "type": "final_answer",
                "content": final_content,
                "is_final": True,
            }
        )

        # Persist final summary as assistant message.
        await sm.append_message(
            session_id=session_id,
            role=MessageRole.ASSISTANT,
            content=final_content,
            agent_name="board_meeting",
        )
    finally:
        # Always release the expert agents (dissolve() also clears the
        # handoff_transport handlers), including on cancellation.
        try:
            await team.dissolve()
        except Exception as e:
            logger.warning(f"Board team dissolve failed: {e}")

    return True
|
|
|
|
|
|
def _format_team_result(result: dict | None) -> str:
    """Build the user-facing final answer from a TeamOrchestrator result dict.

    Prefers ``result["result"]["content"]`` (or ``str(result["result"])`` when it
    is not a dict). If that is empty, concatenates each phase result that has a
    ``content`` key under a ``### <phase name>`` heading. The heading is resolved
    from ``result["plan"].phases``, then the phase result's ``phase_name``, then
    the raw phase id. Falls back to a fixed message when nothing is available.
    """
    fallback = "团队执行完成,但未生成最终结果。"
    result = result or {}

    final_result = result.get("result") or {}
    final_content = (
        final_result.get("content", "") if isinstance(final_result, dict) else str(final_result)
    )
    if final_content:
        return final_content

    phase_results = result.get("phase_results") or {}
    if not phase_results:
        return fallback

    # phase_id -> phase_name lookup from the plan, when one was returned.
    phase_names = {}
    plan_obj = result.get("plan")
    if plan_obj:
        phase_names = {ph.id: ph.name for ph in plan_obj.phases}

    parts = [
        f"### {phase_names.get(phase_id, pr.get('phase_name', phase_id))}\n\n{pr['content']}"
        for phase_id, pr in phase_results.items()
        if isinstance(pr, dict) and "content" in pr
    ]
    return "\n\n".join(parts) if parts else fallback


async def _execute_team_collab(
    websocket: WebSocket,
    session_id: str,
    content: str,
    sm: SessionManager,
) -> bool:
    """Intercept the ``@team`` prefix and execute a pipeline team collaboration.

    Returns True if the input was handled as a team collaboration (the caller
    should return), False if the input should continue through the normal chat
    pipeline.

    Flow:
        1. Resolve ``@team`` routing via ExpertTeamRouter.
        2. Create an ExpertTeam with lead + member configs (first config is the lead).
        3. Register a handoff_transport handler that relays events to the WebSocket.
        4. Persist the user task, create the team, register it as the session's
           active team (so concurrent WS messages become interventions) and run
           TeamOrchestrator (pipeline mode).
        5. Send the final synthesis as ``final_answer`` and persist it
           (agent_name="team_collab").

    The team is always unregistered and dissolved in ``finally``. Execution
    errors are reported as an ``error`` event and the function returns True.

    Raises:
        asyncio.CancelledError: re-raised after a best-effort "cancelled"
            notification, so the surrounding task is actually cancelled.
    """
    from agentkit.experts.router import ExpertTeamRouter
    from agentkit.experts.team import ExpertTeam
    from agentkit.experts.orchestrator import TeamOrchestrator

    app_state = websocket.app.state

    # Prefer the ExpertTemplateRegistry loaded at startup; fall back to a fresh one.
    template_registry = getattr(app_state, "expert_template_registry", None)
    if template_registry is None:
        from agentkit.experts.registry import ExpertTemplateRegistry

        template_registry = ExpertTemplateRegistry()

    team_router = ExpertTeamRouter(template_registry=template_registry)
    routing_result = team_router.resolve(content)

    if not routing_result.matched:
        return False  # Not a @team input, continue normal pipeline

    if not routing_result.task_content:
        await websocket.send_json(
            {
                "type": "error",
                "data": {"message": "团队任务需要一个描述,例如:@team 开发用户登录功能"},
            }
        )
        return True

    # Resolve expert configs from specified experts or the default dev_team template.
    expert_configs = team_router.resolve_expert_configs(routing_result.specified_experts)
    if not expert_configs:
        await websocket.send_json(
            {"type": "error", "data": {"message": "无法解析团队成员,请检查专家名称或模板配置"}}
        )
        return True

    # Split configs: first is lead, rest are members (V2 verification).
    lead_config = expert_configs[0]
    member_configs = expert_configs[1:] if len(expert_configs) > 1 else []

    team = ExpertTeam(
        pool=app_state.agent_pool,
        template_registry=template_registry,
    )

    async def _relay_team_event(message: dict) -> None:
        """Forward a handoff-transport message to the client as a team event."""
        msg_type = message.get("type")
        if not msg_type:
            return
        # Strip the type field; everything else becomes the event payload.
        event_data = {k: v for k, v in message.items() if k != "type"}
        await emit_team_event(websocket, msg_type, event_data)

    team.handoff_transport.register_handler(team.team_channel, _relay_team_event)

    try:
        # Append user task to session history (inside try so cleanup still runs on error).
        await sm.append_message(
            session_id=session_id,
            role=MessageRole.USER,
            content=content,
        )

        await team.create_team(lead_config=lead_config, member_configs=member_configs)
        orchestrator = TeamOrchestrator(team=team)
        # U4: Register active team so WS messages during execution route as interventions.
        _register_active_team(session_id, team)
        result = await orchestrator.execute(routing_result.task_content)
    except asyncio.CancelledError:
        logger.info(f"Team collaboration cancelled for session {session_id}")
        try:
            await websocket.send_json({"type": "error", "data": {"message": "团队协作已取消"}})
        except Exception:
            logger.debug("Could not notify client of team cancellation", exc_info=True)
        # Cancellation must propagate so the owning task ends as cancelled.
        raise
    except Exception as e:
        logger.error(f"Team collaboration failed for session {session_id}: {e}", exc_info=True)
        await websocket.send_json(
            {"type": "error", "data": {"message": f"团队协作执行失败: {str(e)[:200]}"}}
        )
        return True
    finally:
        # U4: Unregister first so subsequent messages don't route to a dissolving
        # team — but only if the registration is still ours, so a team started
        # in the meantime for the same session is not clobbered.
        if _get_active_team(session_id) is team:
            _unregister_active_team(session_id)
        # Always dissolve (dissolve() also clears handoff_transport handlers).
        try:
            await team.dissolve()
        except Exception as e:
            logger.warning(f"Team dissolve failed: {e}")

    final_content = _format_team_result(result)

    await websocket.send_json(
        {
            "type": "final_answer",
            "content": final_content,
            "is_final": True,
        }
    )

    # Persist final synthesis as assistant message.
    await sm.append_message(
        session_id=session_id,
        role=MessageRole.ASSISTANT,
        content=final_content,
        agent_name="team_collab",
    )

    return True
|
|
|
|
|
|
def _session_to_response(session) -> SessionResponse:
    """Convert a session model into the ``SessionResponse`` API schema.

    Status is reported as the enum's value and timestamps as ISO strings.
    """
    fields = {
        "session_id": session.session_id,
        "agent_name": session.agent_name,
        "status": session.status.value,
        "metadata": session.metadata,
        "created_at": session.created_at.isoformat(),
        "updated_at": session.updated_at.isoformat(),
    }
    return SessionResponse(**fields)
|
|
|
|
|
|
def _message_to_response(msg) -> MessageResponse:
    """Convert a stored message into the ``MessageResponse`` API schema.

    The role is reported as the enum's value and ``created_at`` as an ISO string.
    """
    fields = {
        "message_id": msg.message_id,
        "session_id": msg.session_id,
        "role": msg.role.value,
        "content": msg.content,
        "tool_call_id": msg.tool_call_id,
        "agent_name": msg.agent_name,
        "created_at": msg.created_at.isoformat(),
    }
    return MessageResponse(**fields)
|
|
|
|
|
|
# ── REST endpoints ────────────────────────────────────────────────────
|
|
|
|
|
|
@router.get("/sessions", response_model=list[SessionResponse])
async def list_sessions(req: Request):
    """Return every chat session known to the session manager."""
    manager = _get_session_manager(req)
    return list(map(_session_to_response, await manager.list_sessions()))
|
|
|
|
|
|
@router.post("/sessions", response_model=SessionResponse)
async def create_session(request: CreateSessionRequest, req: Request):
    """Create a chat session bound to the requested Agent and return it."""
    manager = _get_session_manager(req)
    created = await manager.create_session(
        agent_name=request.agent_name,
        metadata=request.metadata,
    )
    return _session_to_response(created)
|
|
|
|
|
|
@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str, req: Request):
    """Return one session's details; 404 if it does not exist."""
    manager = _get_session_manager(req)
    found = await manager.get_session(session_id)
    if found is not None:
        return _session_to_response(found)
    raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
|
|
|
|
|
|
@router.get("/sessions/{session_id}/messages", response_model=list[MessageResponse])
async def get_messages(session_id: str, req: Request, limit: int | None = None, offset: int = 0):
    """Return a page of a session's conversation history; 404 for an unknown session.

    ``limit`` and ``offset`` are forwarded unchanged to ``SessionManager.get_messages``.
    """
    manager = _get_session_manager(req)
    if await manager.get_session(session_id) is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    history = await manager.get_messages(session_id, limit=limit, offset=offset)
    return [_message_to_response(item) for item in history]
|
|
|
|
|
|
@router.post("/sessions/{session_id}/messages", response_model=MessageResponse)
async def send_message(session_id: str, request: SendMessageRequest, req: Request):
    """Send a message to the Agent (synchronous mode — waits for full reply).

    The session's Agent is resolved *before* the user message is persisted, so a
    request for a session whose Agent no longer exists fails with 404 without
    leaving an unanswered user message in the history.

    Raises:
        HTTPException(404): the session or its Agent does not exist.
        HTTPException(400): the session is closed.
        HTTPException(500): Agent execution or reply persistence failed; the
            detail is ``str(exc)`` and the original exception is chained.
    """
    sm = _get_session_manager(req)
    session = await sm.get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    if session.status == SessionStatus.CLOSED:
        raise HTTPException(status_code=400, detail=f"Session '{session_id}' is closed")

    # Resolve the Agent first so a missing Agent does not leave an orphaned user turn.
    pool = req.app.state.agent_pool
    agent = pool.get_agent(session.agent_name)
    if agent is None:
        raise HTTPException(status_code=404, detail=f"Agent '{session.agent_name}' not found")

    # Always stored with the USER role; ``request.role`` is not consulted.
    await sm.append_message(
        session_id=session_id,
        role=MessageRole.USER,
        content=request.content,
    )

    # Full conversation history (including the message just appended) for the Agent.
    chat_messages = await sm.get_chat_messages(session_id)

    try:
        # Reuse the Agent's ReActEngine if available (U2: Chat pipeline optimization).
        react_engine = getattr(agent, "_react_engine", None)
        if react_engine is None:
            react_engine = ReActEngine(llm_gateway=req.app.state.llm_gateway)
        else:
            # NOTE(review): this engine instance is shared across requests and
            # reset here; concurrent requests for the same Agent may interfere —
            # confirm ReActEngine is safe to share.
            react_engine.reset()

        tools = agent._tool_registry.list_tools() if agent._tool_registry else []
        system_prompt = getattr(agent, "_system_prompt", None) or (
            agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
        )
        model = (
            agent.get_model()
            if hasattr(agent, "get_model")
            else getattr(agent, "_llm_model", "default")
        )

        result = await react_engine.execute(
            messages=chat_messages,
            tools=tools,
            model=model,
            agent_name=agent.name,
            system_prompt=system_prompt,
        )

        # Append assistant reply.
        assistant_msg = await sm.append_message(
            session_id=session_id,
            role=MessageRole.ASSISTANT,
            content=result.output if hasattr(result, "output") else str(result),
            agent_name=agent.name,
        )
        return _message_to_response(assistant_msg)

    except Exception as e:
        # logger.exception records the traceback, which logger.error alone did not.
        logger.exception("Chat execution error for session %s: %s", session_id, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
@router.delete("/sessions/{session_id}")
async def close_session(session_id: str, req: Request):
    """Close a chat session; 404 if the session does not exist."""
    manager = _get_session_manager(req)
    if await manager.close_session(session_id) is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    return {"status": "closed", "session_id": session_id}
|
|
|
|
|
|
# ── WebSocket endpoint ────────────────────────────────────────────────
|
|
|
|
|
|
async def _resolve_ws_dept_context(
    websocket: WebSocket,
) -> tuple[str | None, list[str], Path | None]:
    """Resolve user_id, department_ids, and db_path for quota enforcement.

    Reads ``websocket.state.current_user`` (set by :class:`AuthMiddleware`
    when JWT auth via ``?token=`` query param succeeds). If no user context
    is available (API key auth), returns ``(None, [], None)`` — quota is
    not enforced for API key clients.

    Returns ``(user_id, department_ids, db_path)``:
      - ``user_id``: the authenticated user's id, or ``None``.
      - ``department_ids``: the user's active department ids (empty for
        admins and API key clients).
      - ``db_path``: the auth DB path (needed for quota checks), or ``None``
        if no user context is available, no ``auth_db_path`` is configured,
        or the department lookup failed.

    NOTE(review): when the department lookup raises, this returns
    ``db_path=None``, which means quota is *not* enforced for the connection
    (fail-open), even though the log message says "fail-closed". Confirm
    whether the connection should instead be rejected.
    """
    current_user: dict[str, Any] | None = getattr(websocket.state, "current_user", None)
    if current_user is None:
        return None, [], None

    user_id = current_user.get("user_id")
    role = current_user.get("role")

    # Admins and API-key clients (user_id=None) — no quota enforcement.
    if role == "admin" or not user_id:
        return user_id, [], None

    # Regular user: look up their active department ids.
    from agentkit.server.admin.context import _fetch_user_department_ids

    db_path = getattr(websocket.app.state, "auth_db_path", None)
    db_path_resolved = Path(db_path) if db_path is not None else None

    try:
        department_ids = await _fetch_user_department_ids(db_path_resolved, user_id)
    except Exception:
        logger.exception(
            "Failed to fetch department ids for WebSocket user %s — fail-closed",
            user_id,
        )
        # Degrade to "no quota context": db_path=None makes the gateway skip
        # quota enforcement for this connection (see NOTE(review) above).
        return user_id, [], None

    return user_id, department_ids, db_path_resolved
|
|
|
|
|
|
@router.websocket("/ws/{session_id}")
async def chat_websocket(websocket: WebSocket, session_id: str) -> None:
    """WebSocket endpoint for real-time chat with streaming.

    Authentication: when an API key is configured (``server_config.api_key``,
    falling back to ``app.state.api_key``), the client must pass it as the
    ``?api_key=`` query parameter; otherwise the socket is closed with code 4001.

    Client → Server messages:
        {"type": "message", "content": "...", "model": "..."} — Send a user message
            ("model" is an optional override)
        {"type": "reply", "request_id": "...", "content": "..."} — Answer an ask_human request
        {"type": "confirmation_reply", "confirmation_id": "...", "approved": bool}
            — Answer a confirmation request
        {"type": "cancel"} — Cancel current execution
        {"type": "ping"} — Heartbeat

    Server → Client messages:
        {"type": "connected", "session_id": "..."} — Connection confirmed
        {"type": "token", "content": "..."} — LLM token streaming
        {"type": "step", "data": {...}} — ReAct step event
        {"type": "ask_human", "question": "...", "request_id": "..."} — Agent asks user
        {"type": "final_answer", "content": "..."} — Agent's final reply
        {"type": "team_intervention_ack", "data": {"content": "..."}}
            — Message forwarded to the running team as an intervention
        {"type": "result", "data": {"status": "cancelled"}} — Cancel acknowledged
        {"type": "error", "data": {"message": "..."}} — Error occurred
        {"type": "pong"} — Heartbeat response (also sent after 300 s of client silence)

    Expert Team events (Server → Client):
        {"type": "team_formed", "data": {...}} — Team assembled
        {"type": "expert_step", "data": {...}} — Expert executing step
        {"type": "expert_result", "data": {...}} — Expert produced result
        {"type": "plan_update", "data": {...}} — Plan phases updated
        {"type": "team_synthesis", "data": {...}} — Final synthesis
        {"type": "team_dissolved", "data": {...}} — Team dissolved
    """
    # Authentication
    configured_api_key: str | None = None
    if hasattr(websocket.app.state, "server_config") and websocket.app.state.server_config:
        configured_api_key = websocket.app.state.server_config.api_key
    if configured_api_key is None and hasattr(websocket.app.state, "api_key"):
        configured_api_key = websocket.app.state.api_key

    if configured_api_key:
        provided = websocket.query_params.get("api_key")
        # Constant-time comparison to avoid leaking the key through timing.
        if not provided or not hmac.compare_digest(provided, configured_api_key):
            await websocket.accept()
            await websocket.send_json({"type": "error", "data": {"message": "Invalid api_key"}})
            await websocket.close(code=4001, reason="Invalid api_key")
            return

    await websocket.accept()

    # Validate session
    sm: SessionManager = websocket.app.state.session_manager
    session = await sm.get_session(session_id)
    if session is None:
        await websocket.send_json(
            {"type": "error", "data": {"message": f"Session '{session_id}' not found"}}
        )
        await websocket.close(code=1000, reason="Session not found")
        return

    if session.status == SessionStatus.CLOSED:
        await websocket.send_json({"type": "error", "data": {"message": "Session is closed"}})
        await websocket.close(code=1000, reason="Session closed")
        return

    # Track pending replies for AskHumanTool and confirmations
    pending_replies: dict[str, asyncio.Future] = {}
    pending_confirmations: dict[str, asyncio.Future] = {}
    chat_manager.add(session_id, websocket, pending_replies)

    # Per-session concurrency guard to prevent unlimited task creation (DoS mitigation)
    _MAX_CONCURRENT_TASKS = 4
    # Running message task -> the CancellationToken it was started with, so a
    # "cancel" message can reach the work that is actually in flight.
    active_tasks: dict[asyncio.Task, CancellationToken] = {}

    def _resolve_pending(futures: dict[str, asyncio.Future], key: Any, value: Any) -> bool:
        """Complete ``futures[key]`` with ``value`` if it exists and is still pending.

        Returns False (doing nothing) for a missing/non-string key or a future
        that is already done, e.g. a duplicate reply. Calling ``set_result`` on
        a done future raises InvalidStateError, which would otherwise escape to
        the outer handler and end the whole connection.
        """
        if not key or not isinstance(key, str):
            return False
        fut = futures.get(key)
        if fut is None or fut.done():
            return False
        fut.set_result(value)
        return True

    try:
        await websocket.send_json({"type": "connected", "session_id": session_id})

        # Listen for client messages
        while True:
            try:
                raw = await asyncio.wait_for(websocket.receive_text(), timeout=300.0)
            except asyncio.TimeoutError:
                await websocket.send_json({"type": "pong"})
                continue

            try:
                msg = json.loads(raw)
            except json.JSONDecodeError:
                continue
            # Valid JSON that is not an object (list, string, number) is ignored
            # like malformed JSON instead of failing on ``msg.get``.
            if not isinstance(msg, dict):
                continue

            msg_type = msg.get("type")

            if msg_type == "message":
                content = msg.get("content", "")
                model = msg.get("model")  # Optional model override from frontend

                # U4: If a team is currently executing for this session, route
                # the message as an intervention instead of a new chat task.
                active_team = _get_active_team(session_id)
                if active_team is not None:
                    try:
                        await active_team.add_user_intervention(content)
                        await websocket.send_json(
                            {
                                "type": "team_intervention_ack",
                                "data": {"content": content},
                            }
                        )
                    except Exception as e:
                        logger.warning(f"Failed to enqueue intervention: {e}")
                        await websocket.send_json(
                            {
                                "type": "error",
                                "data": {"message": f"干预消息入队失败: {e}"},
                            }
                        )
                    continue

                # Create a fresh CancellationToken for each message
                message_token = CancellationToken()

                # Prune finished tasks. The list is built before removing anything:
                # removing from a collection while a generator iterates it raises
                # RuntimeError ("... changed size during iteration").
                for finished in [t for t in active_tasks if t.done()]:
                    active_tasks.pop(finished, None)
                if len(active_tasks) >= _MAX_CONCURRENT_TASKS:
                    await websocket.send_json(
                        {
                            "type": "error",
                            "data": {
                                "message": "Too many concurrent requests. Please wait for the current task to complete."
                            },
                        }
                    )
                    continue

                # Run in background task so the WebSocket receive loop stays free
                # to process confirmation_reply / reply messages while the agent
                # is waiting for user confirmation (otherwise deadlock).
                task = asyncio.create_task(
                    _handle_chat_message(
                        websocket,
                        session_id,
                        content,
                        sm,
                        message_token,
                        pending_replies,
                        pending_confirmations,
                        model_override=model,
                    )
                )
                active_tasks[task] = message_token
                task.add_done_callback(lambda t: active_tasks.pop(t, None))

            elif msg_type == "reply":
                # Reply to AskHumanTool
                _resolve_pending(pending_replies, msg.get("request_id"), msg.get("content", ""))

            elif msg_type == "confirmation_reply":
                # Reply to confirmation request
                confirmation_id = msg.get("confirmation_id")
                approved = msg.get("approved", False)
                logger.info(
                    f"Received confirmation_reply: id={confirmation_id!r}, approved={approved}"
                )
                if _resolve_pending(pending_confirmations, confirmation_id, approved):
                    logger.info(f"Confirmation {confirmation_id} set_result({approved})")
                else:
                    logger.warning(
                        f"Confirmation {confirmation_id!r} not found in pending_confirmations"
                    )

            elif msg_type == "cancel":
                # Each running task observes only its own per-message token, so
                # signal every in-flight token.
                for token in list(active_tasks.values()):
                    token.cancel()
                await websocket.send_json({"type": "result", "data": {"status": "cancelled"}})

            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        logger.debug(f"Chat WebSocket disconnected for session {session_id}")
    except Exception as e:
        logger.error(f"Chat WebSocket error for session {session_id}: {e}")
        try:
            await websocket.send_json({"type": "error", "data": {"message": str(e)}})
        except Exception:
            pass  # The socket is most likely already gone.
    finally:
        # Clean up pending futures so tasks awaiting a human reply are released.
        for fut in pending_replies.values():
            if not fut.done():
                fut.cancel()
        for fut in pending_confirmations.values():
            if not fut.done():
                fut.cancel()
        chat_manager.remove(session_id, websocket)
|
|
|
|
|
|
async def _flush_chat_tokens(
    websocket: WebSocket, token_buffer: list[str], msg_type: str
) -> None:
    """Send buffered LLM tokens as a single ``{"type": msg_type, "content": ...}`` frame.

    Empties *token_buffer* after joining it; does nothing when it is already empty.
    """
    if token_buffer:
        text = "".join(token_buffer)
        token_buffer.clear()
        await websocket.send_json({"type": msg_type, "content": text})


async def _send_chat_step(websocket: WebSocket, event: Any) -> None:
    """Forward a ReAct stream event to the client as a generic ``step`` frame."""
    await websocket.send_json(
        {
            "type": "step",
            "data": {
                "event_type": event.event_type,
                "step": event.step,
                "data": event.data,
            },
        }
    )


def _chat_final_text(text: str | None) -> str:
    """Return *text* unless it is empty or whitespace-only, else ``EMPTY_LLM_RESPONSE``."""
    if text and text.strip():
        return text
    from agentkit.core.fallback import EMPTY_LLM_RESPONSE

    return EMPTY_LLM_RESPONSE


def _chat_error_text(exc: BaseException, limit: int = 200) -> str:
    """Render *exc* for the client: at most *limit* chars, with ``...`` appended when cut."""
    message = str(exc)
    if len(message) > limit:
        return message[:limit] + "..."
    return message


async def _handle_chat_message(
    websocket: WebSocket,
    session_id: str,
    content: str,
    sm: SessionManager,
    cancellation_token: CancellationToken,
    pending_replies: dict[str, asyncio.Future],
    pending_confirmations: dict[str, asyncio.Future] | None = None,
    model_override: str | None = None,
) -> None:
    """Handle a user message: append to session, execute Agent, stream events.

    Uses RequestPreprocessor for minimal preprocessing: @skill prefix + greeting regex + REACT.

    Board Meeting mode: @board prefix is intercepted before RequestPreprocessor
    and routed to BoardOrchestrator for multi-round group discussion.

    Team Collaboration mode: @team prefix is intercepted before RequestPreprocessor
    and routed to TeamOrchestrator for pipeline-based expert collaboration.

    Flow:
      1. Resolve the session and its agent; send an ``error`` frame and return
         if either is missing.
      2. Preprocess *content* with ``app.state.request_preprocessor``. This
         yields tools, model, system prompt, skill and execution mode.
         *model_override*, when given, replaces the resolved model.
      3. Persist the (possibly prefix-stripped) user message.
      4. For ``DIRECT_CHAT``, make one ``llm_gateway.chat`` call. Otherwise
         stream a ReAct run, forwarding its events as WebSocket frames.
         Dangerous tool calls are confirmed through *pending_confirmations*,
         which the WS receive loop resolves.

    Args:
        websocket: Client connection; also gives access to ``app.state``.
        session_id: Chat session to append to.
        content: Raw user text (may carry ``@board`` / ``@team`` / skill prefixes).
        sm: Session manager used for persistence and history.
        cancellation_token: Passed through to the ReAct engine.
        pending_replies: Accepted for interface compatibility; not read here.
        pending_confirmations: Shared dict the WS loop uses to deliver
            confirmation replies. Must be the same object the loop holds.
        model_override: Model chosen in the frontend selector, if any.

    Errors raised during LLM/ReAct execution are reported to the client as
    ``error`` frames rather than propagated.
    """
    from agentkit.chat.request_preprocessor import RequestPreprocessor

    # Board Meeting mode: intercept @board prefix before any other preprocessing
    if await _execute_board_meeting(websocket, session_id, content, sm):
        return

    # Team Collaboration mode: intercept @team prefix before any other preprocessing
    if await _execute_team_collab(websocket, session_id, content, sm):
        return

    # Resolve Agent first (needed for default tools/prompt)
    pool = websocket.app.state.agent_pool
    session = await sm.get_session(session_id)
    if session is None:
        await websocket.send_json({"type": "error", "data": {"message": "Session lost"}})
        return

    agent = pool.get_agent(session.agent_name)
    if agent is None:
        await websocket.send_json(
            {"type": "error", "data": {"message": f"Agent '{session.agent_name}' not found"}}
        )
        return

    # Default execution parameters from agent
    default_tools = agent._tool_registry.list_tools() if agent._tool_registry else []
    default_system_prompt = getattr(agent, "_system_prompt", None) or (
        agent.get_system_prompt() if hasattr(agent, "get_system_prompt") else None
    )
    default_model = (
        agent.get_model()
        if hasattr(agent, "get_model")
        else getattr(agent, "_llm_model", "default")
    )

    # Resolve skill routing using RequestPreprocessor
    skill_registry = getattr(websocket.app.state, "skill_registry", None)
    request_preprocessor: RequestPreprocessor = websocket.app.state.request_preprocessor

    routing = await request_preprocessor.preprocess(
        content=content,
        skill_registry=skill_registry,
        default_tools=default_tools,
        default_system_prompt=default_system_prompt,
        default_model=default_model,
        default_agent_name=agent.name,
    )

    # Debug: log tools that will be passed to ReActEngine
    tool_names = [t.name for t in routing.tools]
    logger.info(
        f"Chat {session_id}: resolved {len(routing.tools)} tools: {tool_names}, model={routing.model}, skill={routing.skill_name}"
    )

    # Apply model override from frontend selector
    if model_override:
        routing.model = model_override

    # Notify frontend about skill match
    if routing.matched:
        await websocket.send_json(
            {
                "type": "skill_match",
                "data": {
                    "skill": routing.skill_name,
                    "method": routing.match_method,
                    "confidence": routing.match_confidence,
                },
            }
        )

    # Append user message (use clean_content if @skill: prefix was stripped)
    await sm.append_message(
        session_id=session_id, role=MessageRole.USER, content=routing.clean_content
    )

    # Get full conversation history
    chat_messages = await sm.get_chat_messages(session_id)

    # Handle DIRECT_CHAT: direct LLM call, no ReAct loop
    if routing.execution_mode == ExecutionMode.DIRECT_CHAT:
        direct_messages = []
        if routing.system_prompt:
            direct_messages.append({"role": "system", "content": routing.system_prompt})
        direct_messages.extend(chat_messages)
        # Resolve department context for quota enforcement (U4/KTD-3).
        ws_user_id, ws_dept_ids, ws_db_path = await _resolve_ws_dept_context(websocket)
        try:
            response = await websocket.app.state.llm_gateway.chat(
                messages=direct_messages,
                model=routing.model or "default",
                agent_name=agent.name,
                task_type="chat",
                user_id=ws_user_id,
                department_ids=ws_dept_ids if ws_dept_ids else None,
                db_path=ws_db_path,
            )
            final_content = _chat_final_text(response.content)
            await websocket.send_json(
                {
                    "type": "final_answer",
                    "content": final_content,
                    "is_final": True,
                }
            )
            await sm.append_message(
                session_id=session_id,
                role=MessageRole.ASSISTANT,
                content=final_content,
                agent_name=agent.name,
            )
        except Exception as e:
            # Check if this is a QuotaExceededError (U4: WebSocket quota).
            from agentkit.llm.gateway import QuotaExceededError

            if isinstance(e, QuotaExceededError):
                logger.warning(
                    "WebSocket DIRECT_CHAT quota exceeded for session %s: %s",
                    session_id,
                    e,
                )
                await websocket.send_json(
                    {
                        "type": "error",
                        "data": {
                            "message": "quota exceeded",
                            "quota_info": {
                                "department_id": e.department_id,
                                "quota_type": e.quota_type,
                                "period": e.period,
                                "limit": e.limit,
                                "current": e.current,
                            },
                        },
                    }
                )
            else:
                # logger.exception keeps the traceback, which logger.error(f"...") dropped.
                logger.exception("Chat DIRECT_CHAT error for session %s: %s", session_id, e)
                await websocket.send_json(
                    {"type": "error", "data": {"message": _chat_error_text(e)}}
                )
        return

    # Handle advanced execution modes: REWOO/REFLEXION/PLAN_EXEC/TEAM_COLLAB
    # currently fall back to REACT with a warning.
    if routing.execution_mode not in (ExecutionMode.REACT, ExecutionMode.SKILL_REACT):
        logger.warning(
            f"Execution mode {routing.execution_mode.value} not yet supported "
            f"in chat WebSocket, falling back to REACT"
        )

    # Execute Agent with streaming
    # Reuse Agent's ReActEngine if available (U2: Chat pipeline optimization)
    react_engine = getattr(agent, "_react_engine", None)
    if react_engine is None:
        react_engine = ReActEngine(llm_gateway=websocket.app.state.llm_gateway)
    else:
        react_engine.reset()

    # Create confirmation handler that sends request to frontend and waits for reply
    # Use the same dict object — do NOT use `or {}` because an empty dict is falsy
    # and would create a new dict, breaking the shared state with the WS loop.
    _pending_confirmations = pending_confirmations if pending_confirmations is not None else {}

    async def _confirmation_handler(confirmation_id: str, command: str, reason: str) -> bool:
        """Send confirmation request to frontend via WebSocket and wait for user reply.

        Returns the user's decision. Returns False on a 300 s timeout or if the
        waiting future is cancelled (e.g. by the WS loop's cleanup on disconnect).
        """
        # Send confirmation request to frontend
        await websocket.send_json(
            {
                "type": "confirmation_request",
                "data": {
                    "confirmation_id": confirmation_id,
                    "command": command,
                    "reason": reason,
                },
            }
        )

        # Create a Future and wait for the user's reply
        loop = asyncio.get_running_loop()
        future: asyncio.Future[bool] = loop.create_future()
        _pending_confirmations[confirmation_id] = future
        logger.info(f"Confirmation request {confirmation_id} sent, waiting for reply")

        try:
            # Wait up to 5 minutes for user confirmation
            result = await asyncio.wait_for(future, timeout=300.0)
            logger.info(f"Confirmation request {confirmation_id} resolved: {result}")
            # Immediately notify frontend of the result so the card updates
            # without waiting for the tool to re-execute
            await websocket.send_json(
                {
                    "type": "confirmation_result",
                    "data": {"confirmation_id": confirmation_id, "approved": result},
                }
            )
            return result
        except asyncio.TimeoutError:
            logger.warning(f"Confirmation request {confirmation_id} timed out")
            return False
        except asyncio.CancelledError:
            logger.warning(f"Confirmation request {confirmation_id} cancelled")
            return False
        finally:
            _pending_confirmations.pop(confirmation_id, None)

    logger.info(
        f"Chat session {session_id}: executing with {len(routing.tools)} tools, model={routing.model}, skill={routing.skill_name}"
    )

    try:
        final_content = ""
        # Tokens are held back until we know whether they are the final answer
        # (flushed as "token") or pre-tool reasoning (flushed as "thinking").
        # NOTE(review): tokens still buffered when the stream ends without a
        # final_answer event are never sent; confirm that is intended.
        token_buffer: list[str] = []
        async for event in react_engine.execute_stream(
            messages=chat_messages,
            tools=routing.tools,
            model=routing.model,
            agent_name=routing.agent_name,
            system_prompt=routing.system_prompt,
            cancellation_token=cancellation_token,
            confirmation_handler=_confirmation_handler,
        ):
            event_type = event.event_type
            if event_type == "final_answer":
                # Flush any buffered tokens as a single write, then the answer.
                await _flush_chat_tokens(websocket, token_buffer, "token")
                final_content = _chat_final_text(event.data.get("output", ""))
                await websocket.send_json(
                    {
                        "type": "final_answer",
                        "content": final_content,
                        "is_final": True,
                    }
                )
            elif event_type == "token":
                # Buffer tokens instead of sending immediately
                token_buffer.append(event.data.get("content", ""))
            elif event_type == "thinking":
                # Buffered tokens preceding a thinking event are thinking text too.
                await _flush_chat_tokens(websocket, token_buffer, "thinking")
                thinking_msg = event.data.get("message", "")
                if thinking_msg:
                    await websocket.send_json({"type": "thinking", "content": thinking_msg})
            elif event_type == "tool_call":
                # Text streamed before a tool call was reasoning, not the answer.
                await _flush_chat_tokens(websocket, token_buffer, "thinking")
                await _send_chat_step(websocket, event)
            elif event_type == "confirmation_request":
                # Already sent to the client by _confirmation_handler.
                pass
            elif event_type == "confirmation_result":
                await websocket.send_json(
                    {
                        "type": "confirmation_result",
                        "data": event.data,
                    }
                )
            else:
                await _send_chat_step(websocket, event)

        # Append assistant reply to session
        if final_content:
            await sm.append_message(
                session_id=session_id,
                role=MessageRole.ASSISTANT,
                content=final_content,
                agent_name=agent.name,
            )

    except Exception as e:
        # logger.exception keeps the traceback for operators; the client only
        # gets a truncated message so full stack traces are not leaked.
        logger.exception("Chat execution error for session %s: %s", session_id, e)
        await websocket.send_json({"type": "error", "data": {"message": _chat_error_text(e)}})
|
|
|
|
|
|
# ── File upload ───────────────────────────────────────────────────────
|
|
|
|
|
|
# Largest accepted chat upload, in bytes (10 MiB).
MAX_UPLOAD_SIZE = 10 << 20

# Storage directory for chat uploads; overridable via AGENTKIT_UPLOAD_DIR.
UPLOAD_DIR = Path(os.getenv("AGENTKIT_UPLOAD_DIR", "data/uploads"))
|
|
|
|
|
|
def _ensure_upload_dir() -> Path:
    """Return ``UPLOAD_DIR``, creating it (and any missing parents) first."""
    target = UPLOAD_DIR
    target.mkdir(exist_ok=True, parents=True)
    return target
|
|
|
|
|
|
def _sanitize_filename(name: str) -> str:
|
|
"""Remove path separators and keep only safe characters."""
|
|
name = name.replace("\\", "_").replace("/", "_")
|
|
return "".join(c for c in name if c.isalnum() or c in "._-").strip(".")
|
|
|
|
|
|
@router.post("/upload")
async def upload_chat_file(file: UploadFile = File(...)) -> dict[str, Any]:
    """Upload a file to be referenced in chat messages.

    The file is stored in the upload directory under a random hex name that
    keeps only the sanitized extension of the client-supplied filename.

    Returns metadata including a public download URL.

    Raises:
        HTTPException: 413 if the file exceeds ``MAX_UPLOAD_SIZE``; 500 if it
            could not be written to disk.
    """
    # Cheap early rejection when the client declared the size up front.
    if file.size is not None and file.size > MAX_UPLOAD_SIZE:
        raise HTTPException(status_code=413, detail="File exceeds 10 MB limit")

    original_name = file.filename or "unnamed"
    safe_name = _sanitize_filename(original_name) or "unnamed"
    ext = Path(safe_name).suffix
    stored_name = f"{uuid.uuid4().hex}{ext}"
    upload_dir = _ensure_upload_dir()
    file_path = upload_dir / stored_name

    try:
        # Read at most one byte past the limit. That is enough to detect an
        # oversize upload without loading an arbitrarily large body into memory
        # when the declared size was missing.
        contents = await file.read(MAX_UPLOAD_SIZE + 1)
        if len(contents) > MAX_UPLOAD_SIZE:
            raise HTTPException(status_code=413, detail="File exceeds 10 MB limit")
        file_path.write_bytes(contents)
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Failed to save uploaded file: %s", exc)
        # Best-effort removal of a partially written file; a failure here must
        # not mask the original error reported to the client.
        try:
            file_path.unlink(missing_ok=True)
        except OSError:
            logger.warning("Could not remove partial upload %s", file_path)
        raise HTTPException(status_code=500, detail="Failed to save file") from exc
    finally:
        await file.close()

    return {
        "filename": original_name,
        "stored_name": stored_name,
        "content_type": file.content_type or "application/octet-stream",
        # Equal to the written file's size; avoids an extra stat() call.
        "size": len(contents),
        "download_url": f"/api/v1/chat/uploads/{stored_name}",
    }
|
|
|
|
|
|
@router.get("/uploads/{filename}")
async def download_chat_file(filename: str) -> FileResponse:
    """Download an uploaded chat file by its stored filename.

    *filename* is passed through ``_sanitize_filename`` before lookup, so it
    can only name an entry directly inside the upload directory.

    Raises:
        HTTPException: 404 if no regular file with the sanitized name exists.
    """
    upload_dir = _ensure_upload_dir()
    safe_filename = _sanitize_filename(filename)
    file_path = upload_dir / safe_filename
    # os.path.isfile() returns False on any OSError, e.g. ENAMETOOLONG for an
    # over-long URL segment. Path.exists()/is_file() re-raise such errors on
    # some Python versions, which would surface as a 500. An empty sanitized
    # name would point at the directory itself and is rejected too.
    if not safe_filename or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path, filename=safe_filename)
|