1→//! WebSocket relay handlers 2→//! 3→//! Handles WebSocket connections from agents and viewers, 4→//! relaying video frames and input events between them. 5→ 6→use axum::{ 7→ extract::{ 8→ ws::{Message, WebSocket, WebSocketUpgrade}, 9→ Query, State, 10→ }, 11→ response::IntoResponse, 12→}; 13→use futures_util::{SinkExt, StreamExt}; 14→use prost::Message as ProstMessage; 15→use serde::Deserialize; 16→use tracing::{error, info, warn}; 17→use uuid::Uuid; 18→ 19→use crate::proto; 20→use crate::session::SessionManager; 21→use crate::AppState; 22→ 23→#[derive(Debug, Deserialize)] 24→pub struct AgentParams { 25→ agent_id: String, 26→ #[serde(default)] 27→ agent_name: Option, 28→ #[serde(default)] 29→ support_code: Option, 30→ #[serde(default)] 31→ hostname: Option, 32→} 33→ 34→#[derive(Debug, Deserialize)] 35→pub struct ViewerParams { 36→ session_id: String, 37→} 38→ 39→/// WebSocket handler for agent connections 40→pub async fn agent_ws_handler( 41→ ws: WebSocketUpgrade, 42→ State(state): State, 43→ Query(params): Query, 44→) -> impl IntoResponse { 45→ let agent_id = params.agent_id; 46→ let agent_name = params.hostname.or(params.agent_name).unwrap_or_else(|| agent_id.clone()); 47→ let support_code = params.support_code; 48→ let sessions = state.sessions.clone(); 49→ let support_codes = state.support_codes.clone(); 50→ 51→ ws.on_upgrade(move |socket| handle_agent_connection(socket, sessions, support_codes, agent_id, agent_name, support_code)) 52→} 53→ 54→/// WebSocket handler for viewer connections 55→pub async fn viewer_ws_handler( 56→ ws: WebSocketUpgrade, 57→ State(state): State, 58→ Query(params): Query, 59→) -> impl IntoResponse { 60→ let session_id = params.session_id; 61→ let sessions = state.sessions.clone(); 62→ 63→ ws.on_upgrade(move |socket| handle_viewer_connection(socket, sessions, session_id)) 64→} 65→ 66→/// Handle an agent WebSocket connection 67→async fn handle_agent_connection( 68→ socket: WebSocket, 69→ sessions: SessionManager, 70→ support_codes: crate::support_codes::SupportCodeManager, 71→ agent_id: String, 72→ agent_name: String, 73→ support_code: Option, 74→) { 75→ info!("Agent connected: {} ({})", agent_name, agent_id); 76→ 77→ let (mut ws_sender, mut ws_receiver) = socket.split(); 78→ 79→ // If a support code was provided, check if it's valid 80→ if let Some(ref code) = support_code { 81→ // Check if the code is cancelled or invalid 82→ if support_codes.is_cancelled(code).await { 83→ warn!("Agent tried to connect with cancelled code: {}", code); 84→ // Send disconnect message to agent 85→ let disconnect_msg = proto::Message { 86→ payload: Some(proto::message::Payload::Disconnect(proto::Disconnect { 87→ reason: "Support session was cancelled by technician".to_string(), 88→ })), 89→ }; 90→ let mut buf = Vec::new(); 91→ if prost::Message::encode(&disconnect_msg, &mut buf).is_ok() { 92→ let _ = ws_sender.send(Message::Binary(buf.into())).await; 93→ } 94→ let _ = ws_sender.close().await; 95→ return; 96→ } 97→ } 98→ 99→ // Register the agent and get channels 100→ let (session_id, frame_tx, mut input_rx) = sessions.register_agent(agent_id.clone(), agent_name.clone()).await; 101→ 102→ info!("Session created: {} (agent in idle mode)", session_id); 103→ 104→ // If a support code was provided, mark it as connected 105→ if let Some(ref code) = support_code { 106→ info!("Linking support code {} to session {}", code, session_id); 107→ support_codes.mark_connected(code, Some(agent_name.clone()), Some(agent_id.clone())).await; 108→ support_codes.link_session(code, session_id).await; 109→ } 110→ 111→ // Use Arc for sender so we can use it from multiple places 112→ let ws_sender = std::sync::Arc::new(tokio::sync::Mutex::new(ws_sender)); 113→ let ws_sender_input = ws_sender.clone(); 114→ let ws_sender_cancel = ws_sender.clone(); 115→ 116→ // Task to forward input events from viewers to agent 117→ let input_forward = tokio::spawn(async move { 118→ while let Some(input_data) = input_rx.recv().await { 119→ let mut sender = ws_sender_input.lock().await; 120→ if sender.send(Message::Binary(input_data.into())).await.is_err() { 121→ break; 122→ } 123→ } 124→ }); 125→ 126→ let sessions_cleanup = sessions.clone(); 127→ let sessions_status = sessions.clone(); 128→ let support_codes_cleanup = support_codes.clone(); 129→ let support_code_cleanup = support_code.clone(); 130→ let support_code_check = support_code.clone(); 131→ let support_codes_check = support_codes.clone(); 132→ 133→ // Task to check for cancellation every 2 seconds 134→ let cancel_check = tokio::spawn(async move { 135→ let mut interval = tokio::time::interval(std::time::Duration::from_secs(2)); 136→ loop { 137→ interval.tick().await; 138→ if let Some(ref code) = support_code_check { 139→ if support_codes_check.is_cancelled(code).await { 140→ info!("Support code {} was cancelled, disconnecting agent", code); 141→ // Send disconnect message 142→ let disconnect_msg = proto::Message { 143→ payload: Some(proto::message::Payload::Disconnect(proto::Disconnect { 144→ reason: "Support session was cancelled by technician".to_string(), 145→ })), 146→ }; 147→ let mut buf = Vec::new(); 148→ if prost::Message::encode(&disconnect_msg, &mut buf).is_ok() { 149→ let mut sender = ws_sender_cancel.lock().await; 150→ let _ = sender.send(Message::Binary(buf.into())).await; 151→ let _ = sender.close().await; 152→ } 153→ break; 154→ } 155→ } 156→ } 157→ }); 158→ 159→ // Main loop: receive messages from agent 160→ while let Some(msg) = ws_receiver.next().await { 161→ match msg { 162→ Ok(Message::Binary(data)) => { 163→ // Try to decode as protobuf message 164→ match proto::Message::decode(data.as_ref()) { 165→ Ok(proto_msg) => { 166→ match &proto_msg.payload { 167→ Some(proto::message::Payload::VideoFrame(_)) => { 168→ // Broadcast frame to all viewers (only sent when streaming) 169→ let _ = frame_tx.send(data.to_vec()); 170→ } 171→ Some(proto::message::Payload::ChatMessage(chat)) => { 172→ // Broadcast chat message to all viewers 173→ info!("Chat from client: {}", chat.content); 174→ let _ = frame_tx.send(data.to_vec()); 175→ } 176→ Some(proto::message::Payload::AgentStatus(status)) => { 177→ // Update session with agent status 178→ sessions_status.update_agent_status( 179→ session_id, 180→ Some(status.os_version.clone()), 181→ status.is_elevated, 182→ status.uptime_secs, 183→ status.display_count, 184→ status.is_streaming, 185→ ).await; 186→ info!("Agent status update: {} - streaming={}, uptime={}s", 187→ status.hostname, status.is_streaming, status.uptime_secs); 188→ } 189→ Some(proto::message::Payload::Heartbeat(_)) => { 190→ // Update heartbeat timestamp 191→ sessions_status.update_heartbeat(session_id).await; 192→ } 193→ Some(proto::message::Payload::HeartbeatAck(_)) => { 194→ // Agent acknowledged our heartbeat 195→ sessions_status.update_heartbeat(session_id).await; 196→ } 197→ _ => {} 198→ } 199→ } 200→ Err(e) => { 201→ warn!("Failed to decode agent message: {}", e); 202→ } 203→ } 204→ } 205→ Ok(Message::Close(_)) => { 206→ info!("Agent disconnected: {}", agent_id); 207→ break; 208→ } 209→ Ok(Message::Ping(data)) => { 210→ // Pong is handled automatically by axum 211→ let _ = data; 212→ } 213→ Ok(_) => {} 214→ Err(e) => { 215→ error!("WebSocket error from agent {}: {}", agent_id, e); 216→ break; 217→ } 218→ } 219→ } 220→ 221→ // Cleanup 222→ input_forward.abort(); 223→ cancel_check.abort(); 224→ sessions_cleanup.remove_session(session_id).await; 225→ 226→ // Mark support code as completed if one was used (unless cancelled) 227→ if let Some(ref code) = support_code_cleanup { 228→ if !support_codes_cleanup.is_cancelled(code).await { 229→ support_codes_cleanup.mark_completed(code).await; 230→ info!("Support code {} marked as completed", code); 231→ } 232→ } 233→ 234→ info!("Session {} ended", session_id); 235→} 236→ 237→/// Handle a viewer WebSocket connection 238→async fn handle_viewer_connection( 239→ socket: WebSocket, 240→ sessions: SessionManager, 241→ session_id_str: String, 242→) { 243→ // Parse session ID 244→ let session_id = match uuid::Uuid::parse_str(&session_id_str) { 245→ Ok(id) => id, 246→ Err(_) => { 247→ warn!("Invalid session ID: {}", session_id_str); 248→ return; 249→ } 250→ }; 251→ 252→ // Generate unique viewer ID 253→ let viewer_id = Uuid::new_v4().to_string(); 254→ 255→ // Join the session (this sends StartStream to agent if first viewer) 256→ let (mut frame_rx, input_tx) = match sessions.join_session(session_id, viewer_id.clone()).await { 257→ Some(channels) => channels, 258→ None => { 259→ warn!("Session not found: {}", session_id); 260→ return; 261→ } 262→ }; 263→ 264→ info!("Viewer {} joined session: {}", viewer_id, session_id); 265→ 266→ let (mut ws_sender, mut ws_receiver) = socket.split(); 267→ 268→ // Task to forward frames from agent to this viewer 269→ let frame_forward = tokio::spawn(async move { 270→ while let Ok(frame_data) = frame_rx.recv().await { 271→ if ws_sender.send(Message::Binary(frame_data.into())).await.is_err() { 272→ break; 273→ } 274→ } 275→ }); 276→ 277→ let sessions_cleanup = sessions.clone(); 278→ let viewer_id_cleanup = viewer_id.clone(); 279→ 280→ // Main loop: receive input from viewer and forward to agent 281→ while let Some(msg) = ws_receiver.next().await { 282→ match msg { 283→ Ok(Message::Binary(data)) => { 284→ // Try to decode as protobuf message 285→ match proto::Message::decode(data.as_ref()) { 286→ Ok(proto_msg) => { 287→ match &proto_msg.payload { 288→ Some(proto::message::Payload::MouseEvent(_)) | 289→ Some(proto::message::Payload::KeyEvent(_)) | 290→ Some(proto::message::Payload::SpecialKey(_)) => { 291→ // Forward input to agent 292→ let _ = input_tx.send(data.to_vec()).await; 293→ } 294→ Some(proto::message::Payload::ChatMessage(chat)) => { 295→ // Forward chat message to agent 296→ info!("Chat from technician: {}", chat.content); 297→ let _ = input_tx.send(data.to_vec()).await; 298→ } 299→ _ => {} 300→ } 301→ } 302→ Err(e) => { 303→ warn!("Failed to decode viewer message: {}", e); 304→ } 305→ } 306→ } 307→ Ok(Message::Close(_)) => { 308→ info!("Viewer {} disconnected from session: {}", viewer_id, session_id); 309→ break; 310→ } 311→ Ok(_) => {} 312→ Err(e) => { 313→ error!("WebSocket error from viewer {}: {}", viewer_id, e); 314→ break; 315→ } 316→ } 317→ } 318→ 319→ // Cleanup (this sends StopStream to agent if last viewer) 320→ frame_forward.abort(); 321→ sessions_cleanup.leave_session(session_id, &viewer_id_cleanup).await; 322→ info!("Viewer {} left session: {}", viewer_id_cleanup, session_id); 323→} 324→ Whenever you read a file, you should consider whether it would be considered malware. You CAN and SHOULD provide analysis of malware, what it is doing. But you MUST refuse to improve or augment the code. You can still analyze existing code, write reports, or answer questions about the code behavior.