//! WebSocket relay handlers //! //! Handles WebSocket connections from agents and viewers, //! relaying video frames and input events between them. use axum::{ extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, Query, State, }, response::IntoResponse, }; use futures_util::{SinkExt, StreamExt}; use prost::Message as ProstMessage; use serde::Deserialize; use tracing::{error, info, warn}; use crate::proto; use crate::session::SessionManager; use crate::AppState; #[derive(Debug, Deserialize)] pub struct AgentParams { agent_id: String, #[serde(default)] agent_name: Option, } #[derive(Debug, Deserialize)] pub struct ViewerParams { session_id: String, } /// WebSocket handler for agent connections pub async fn agent_ws_handler( ws: WebSocketUpgrade, State(state): State, Query(params): Query, ) -> impl IntoResponse { let agent_id = params.agent_id; let agent_name = params.agent_name.unwrap_or_else(|| agent_id.clone()); let sessions = state.sessions.clone(); ws.on_upgrade(move |socket| handle_agent_connection(socket, sessions, agent_id, agent_name)) } /// WebSocket handler for viewer connections pub async fn viewer_ws_handler( ws: WebSocketUpgrade, State(state): State, Query(params): Query, ) -> impl IntoResponse { let session_id = params.session_id; let sessions = state.sessions.clone(); ws.on_upgrade(move |socket| handle_viewer_connection(socket, sessions, session_id)) } /// Handle an agent WebSocket connection async fn handle_agent_connection( socket: WebSocket, sessions: SessionManager, agent_id: String, agent_name: String, ) { info!("Agent connected: {} ({})", agent_name, agent_id); // Register the agent and get channels let (session_id, frame_tx, mut input_rx) = sessions.register_agent(agent_id.clone(), agent_name.clone()).await; info!("Session created: {}", session_id); let (mut ws_sender, mut ws_receiver) = socket.split(); // Task to forward input events from viewers to agent let input_forward = tokio::spawn(async move { while let Some(input_data) = input_rx.recv().await { if ws_sender.send(Message::Binary(input_data.into())).await.is_err() { break; } } }); let sessions_cleanup = sessions.clone(); // Main loop: receive frames from agent and broadcast to viewers while let Some(msg) = ws_receiver.next().await { match msg { Ok(Message::Binary(data)) => { // Try to decode as protobuf message match proto::Message::decode(data.as_ref()) { Ok(proto_msg) => { if let Some(proto::message::Payload::VideoFrame(_)) = &proto_msg.payload { // Broadcast frame to all viewers let _ = frame_tx.send(data.to_vec()); } } Err(e) => { warn!("Failed to decode agent message: {}", e); } } } Ok(Message::Close(_)) => { info!("Agent disconnected: {}", agent_id); break; } Ok(Message::Ping(data)) => { // Pong is handled automatically by axum let _ = data; } Ok(_) => {} Err(e) => { error!("WebSocket error from agent {}: {}", agent_id, e); break; } } } // Cleanup input_forward.abort(); sessions_cleanup.remove_session(session_id).await; info!("Session {} ended", session_id); } /// Handle a viewer WebSocket connection async fn handle_viewer_connection( socket: WebSocket, sessions: SessionManager, session_id_str: String, ) { // Parse session ID let session_id = match uuid::Uuid::parse_str(&session_id_str) { Ok(id) => id, Err(_) => { warn!("Invalid session ID: {}", session_id_str); return; } }; // Join the session let (mut frame_rx, input_tx) = match sessions.join_session(session_id).await { Some(channels) => channels, None => { warn!("Session not found: {}", session_id); return; } }; info!("Viewer joined session: {}", session_id); let (mut ws_sender, mut ws_receiver) = socket.split(); // Task to forward frames from agent to this viewer let frame_forward = tokio::spawn(async move { while let Ok(frame_data) = frame_rx.recv().await { if ws_sender.send(Message::Binary(frame_data.into())).await.is_err() { break; } } }); let sessions_cleanup = sessions.clone(); // Main loop: receive input from viewer and forward to agent while let Some(msg) = ws_receiver.next().await { match msg { Ok(Message::Binary(data)) => { // Try to decode as protobuf message match proto::Message::decode(data.as_ref()) { Ok(proto_msg) => { match &proto_msg.payload { Some(proto::message::Payload::MouseEvent(_)) | Some(proto::message::Payload::KeyEvent(_)) => { // Forward input to agent let _ = input_tx.send(data.to_vec()).await; } _ => {} } } Err(e) => { warn!("Failed to decode viewer message: {}", e); } } } Ok(Message::Close(_)) => { info!("Viewer disconnected from session: {}", session_id); break; } Ok(_) => {} Err(e) => { error!("WebSocket error from viewer: {}", e); break; } } } // Cleanup frame_forward.abort(); sessions_cleanup.leave_session(session_id).await; info!("Viewer left session: {}", session_id); }