//! WebSocket relay handlers //! //! Handles WebSocket connections from agents and viewers, //! relaying video frames and input events between them. use axum::{ extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, Query, State, }, response::IntoResponse, }; use futures_util::{SinkExt, StreamExt}; use prost::Message as ProstMessage; use serde::Deserialize; use tracing::{error, info, warn}; use uuid::Uuid; use crate::proto; use crate::session::SessionManager; use crate::AppState; #[derive(Debug, Deserialize)] pub struct AgentParams { agent_id: String, #[serde(default)] agent_name: Option, #[serde(default)] support_code: Option, #[serde(default)] hostname: Option, } #[derive(Debug, Deserialize)] pub struct ViewerParams { session_id: String, #[serde(default = "default_viewer_name")] viewer_name: String, } fn default_viewer_name() -> String { "Technician".to_string() } /// WebSocket handler for agent connections pub async fn agent_ws_handler( ws: WebSocketUpgrade, State(state): State, Query(params): Query, ) -> impl IntoResponse { let agent_id = params.agent_id; let agent_name = params.hostname.or(params.agent_name).unwrap_or_else(|| agent_id.clone()); let support_code = params.support_code; let sessions = state.sessions.clone(); let support_codes = state.support_codes.clone(); ws.on_upgrade(move |socket| handle_agent_connection(socket, sessions, support_codes, agent_id, agent_name, support_code)) } /// WebSocket handler for viewer connections pub async fn viewer_ws_handler( ws: WebSocketUpgrade, State(state): State, Query(params): Query, ) -> impl IntoResponse { let session_id = params.session_id; let viewer_name = params.viewer_name; let sessions = state.sessions.clone(); ws.on_upgrade(move |socket| handle_viewer_connection(socket, sessions, session_id, viewer_name)) } /// Handle an agent WebSocket connection async fn handle_agent_connection( socket: WebSocket, sessions: SessionManager, support_codes: crate::support_codes::SupportCodeManager, agent_id: String, agent_name: String, support_code: Option, ) { info!("Agent connected: {} ({})", agent_name, agent_id); let (mut ws_sender, mut ws_receiver) = socket.split(); // If a support code was provided, check if it's valid if let Some(ref code) = support_code { // Check if the code is cancelled or invalid if support_codes.is_cancelled(code).await { warn!("Agent tried to connect with cancelled code: {}", code); // Send disconnect message to agent let disconnect_msg = proto::Message { payload: Some(proto::message::Payload::Disconnect(proto::Disconnect { reason: "Support session was cancelled by technician".to_string(), })), }; let mut buf = Vec::new(); if prost::Message::encode(&disconnect_msg, &mut buf).is_ok() { let _ = ws_sender.send(Message::Binary(buf.into())).await; } let _ = ws_sender.close().await; return; } } // Register the agent and get channels // Persistent agents (no support code) keep their session when disconnected let is_persistent = support_code.is_none(); let (session_id, frame_tx, mut input_rx) = sessions.register_agent(agent_id.clone(), agent_name.clone(), is_persistent).await; info!("Session created: {} (agent in idle mode)", session_id); // If a support code was provided, mark it as connected if let Some(ref code) = support_code { info!("Linking support code {} to session {}", code, session_id); support_codes.mark_connected(code, Some(agent_name.clone()), Some(agent_id.clone())).await; support_codes.link_session(code, session_id).await; } // Use Arc for sender so we can use it from multiple places let ws_sender = std::sync::Arc::new(tokio::sync::Mutex::new(ws_sender)); let ws_sender_input = ws_sender.clone(); let ws_sender_cancel = ws_sender.clone(); // Task to forward input events from viewers to agent let input_forward = tokio::spawn(async move { while let Some(input_data) = input_rx.recv().await { let mut sender = ws_sender_input.lock().await; if sender.send(Message::Binary(input_data.into())).await.is_err() { break; } } }); let sessions_cleanup = sessions.clone(); let sessions_status = sessions.clone(); let support_codes_cleanup = support_codes.clone(); let support_code_cleanup = support_code.clone(); let support_code_check = support_code.clone(); let support_codes_check = support_codes.clone(); // Task to check for cancellation every 2 seconds let cancel_check = tokio::spawn(async move { let mut interval = tokio::time::interval(std::time::Duration::from_secs(2)); loop { interval.tick().await; if let Some(ref code) = support_code_check { if support_codes_check.is_cancelled(code).await { info!("Support code {} was cancelled, disconnecting agent", code); // Send disconnect message let disconnect_msg = proto::Message { payload: Some(proto::message::Payload::Disconnect(proto::Disconnect { reason: "Support session was cancelled by technician".to_string(), })), }; let mut buf = Vec::new(); if prost::Message::encode(&disconnect_msg, &mut buf).is_ok() { let mut sender = ws_sender_cancel.lock().await; let _ = sender.send(Message::Binary(buf.into())).await; let _ = sender.close().await; } break; } } } }); // Main loop: receive messages from agent while let Some(msg) = ws_receiver.next().await { match msg { Ok(Message::Binary(data)) => { // Try to decode as protobuf message match proto::Message::decode(data.as_ref()) { Ok(proto_msg) => { match &proto_msg.payload { Some(proto::message::Payload::VideoFrame(_)) => { // Broadcast frame to all viewers (only sent when streaming) let _ = frame_tx.send(data.to_vec()); } Some(proto::message::Payload::ChatMessage(chat)) => { // Broadcast chat message to all viewers info!("Chat from client: {}", chat.content); let _ = frame_tx.send(data.to_vec()); } Some(proto::message::Payload::AgentStatus(status)) => { // Update session with agent status sessions_status.update_agent_status( session_id, Some(status.os_version.clone()), status.is_elevated, status.uptime_secs, status.display_count, status.is_streaming, ).await; info!("Agent status update: {} - streaming={}, uptime={}s", status.hostname, status.is_streaming, status.uptime_secs); } Some(proto::message::Payload::Heartbeat(_)) => { // Update heartbeat timestamp sessions_status.update_heartbeat(session_id).await; } Some(proto::message::Payload::HeartbeatAck(_)) => { // Agent acknowledged our heartbeat sessions_status.update_heartbeat(session_id).await; } _ => {} } } Err(e) => { warn!("Failed to decode agent message: {}", e); } } } Ok(Message::Close(_)) => { info!("Agent disconnected: {}", agent_id); break; } Ok(Message::Ping(data)) => { // Pong is handled automatically by axum let _ = data; } Ok(_) => {} Err(e) => { error!("WebSocket error from agent {}: {}", agent_id, e); break; } } } // Cleanup input_forward.abort(); cancel_check.abort(); // Mark agent as disconnected (persistent agents stay in list as offline) sessions_cleanup.mark_agent_disconnected(session_id).await; // Mark support code as completed if one was used (unless cancelled) if let Some(ref code) = support_code_cleanup { if !support_codes_cleanup.is_cancelled(code).await { support_codes_cleanup.mark_completed(code).await; info!("Support code {} marked as completed", code); } } info!("Session {} ended", session_id); } /// Handle a viewer WebSocket connection async fn handle_viewer_connection( socket: WebSocket, sessions: SessionManager, session_id_str: String, viewer_name: String, ) { // Parse session ID let session_id = match uuid::Uuid::parse_str(&session_id_str) { Ok(id) => id, Err(_) => { warn!("Invalid session ID: {}", session_id_str); return; } }; // Generate unique viewer ID let viewer_id = Uuid::new_v4().to_string(); // Join the session (this sends StartStream to agent if first viewer) let (mut frame_rx, input_tx) = match sessions.join_session(session_id, viewer_id.clone(), viewer_name.clone()).await { Some(channels) => channels, None => { warn!("Session not found: {}", session_id); return; } }; info!("Viewer {} ({}) joined session: {}", viewer_name, viewer_id, session_id); let (mut ws_sender, mut ws_receiver) = socket.split(); // Task to forward frames from agent to this viewer let frame_forward = tokio::spawn(async move { while let Ok(frame_data) = frame_rx.recv().await { if ws_sender.send(Message::Binary(frame_data.into())).await.is_err() { break; } } }); let sessions_cleanup = sessions.clone(); let viewer_id_cleanup = viewer_id.clone(); // Main loop: receive input from viewer and forward to agent while let Some(msg) = ws_receiver.next().await { match msg { Ok(Message::Binary(data)) => { // Try to decode as protobuf message match proto::Message::decode(data.as_ref()) { Ok(proto_msg) => { match &proto_msg.payload { Some(proto::message::Payload::MouseEvent(_)) | Some(proto::message::Payload::KeyEvent(_)) | Some(proto::message::Payload::SpecialKey(_)) => { // Forward input to agent let _ = input_tx.send(data.to_vec()).await; } Some(proto::message::Payload::ChatMessage(chat)) => { // Forward chat message to agent info!("Chat from technician: {}", chat.content); let _ = input_tx.send(data.to_vec()).await; } _ => {} } } Err(e) => { warn!("Failed to decode viewer message: {}", e); } } } Ok(Message::Close(_)) => { info!("Viewer {} disconnected from session: {}", viewer_id, session_id); break; } Ok(_) => {} Err(e) => { error!("WebSocket error from viewer {}: {}", viewer_id, e); break; } } } // Cleanup (this sends StopStream to agent if last viewer) frame_forward.abort(); sessions_cleanup.leave_session(session_id, &viewer_id_cleanup).await; info!("Viewer {} left session: {}", viewer_id_cleanup, session_id); }