From 4417fdfb6e81075d8e4e6bc4bdcb6997e103d964 Mon Sep 17 00:00:00 2001 From: Mike Swanson Date: Sun, 28 Dec 2025 17:24:51 -0700 Subject: [PATCH] Implement idle/active mode for scalable agent connections MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add StartStream/StopStream/AgentStatus messages to protobuf - Agent now starts in idle mode (heartbeat only, no capture) - Agent enters streaming mode when viewer connects (StartStream) - Agent returns to idle when all viewers disconnect (StopStream) - Server tracks viewer IDs and sends start/stop commands - Heartbeat mechanism with 90 second timeout detection - Session API now includes streaming status and agent info This allows 2000+ agents to connect with minimal bandwidth. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- agent/src/main.rs | 5 +- agent/src/session/mod.rs | 390 +++++++++++++++++++++++--------------- proto/guruconnect.proto | 24 +++ server/src/api/mod.rs | 12 ++ server/src/relay/mod.rs | 52 +++-- server/src/session/mod.rs | 148 ++++++++++++++- 6 files changed, 455 insertions(+), 176 deletions(-) diff --git a/agent/src/main.rs b/agent/src/main.rs index 580dc77..3efaf3b 100644 --- a/agent/src/main.rs +++ b/agent/src/main.rs @@ -204,8 +204,9 @@ fn cleanup_on_exit() { } async fn run_agent(config: config::Config) -> Result<()> { - // Create session manager - let mut session = session::SessionManager::new(config.clone()); + // Create session manager with elevation status + let elevated = is_elevated(); + let mut session = session::SessionManager::new(config.clone(), elevated); let is_support_session = config.support_code.is_some(); let hostname = config.hostname(); diff --git a/agent/src/session/mod.rs b/agent/src/session/mod.rs index 2977fdf..6ffdf37 100644 --- a/agent/src/session/mod.rs +++ b/agent/src/session/mod.rs @@ -2,8 +2,8 @@ //! //! Handles the lifecycle of a remote session including: //! - Connection to server -//! - Authentication -//! - Frame capture and encoding loop +//! - Idle mode (heartbeat only, minimal resources) +//! - Active/streaming mode (capture and send frames) //! - Input event handling #[cfg(windows)] @@ -36,36 +36,58 @@ fn show_debug_console() { fn show_debug_console() { // No-op on non-Windows platforms } -use crate::proto::{Message, message, ChatMessage}; + +use crate::proto::{Message, message, ChatMessage, AgentStatus, Heartbeat, HeartbeatAck}; use crate::transport::WebSocketTransport; use crate::tray::{TrayController, TrayAction}; use anyhow::Result; -use std::sync::Arc; use std::time::{Duration, Instant}; -use tokio::sync::mpsc; + +// Heartbeat interval (30 seconds) +const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30); +// Status report interval (60 seconds) +const STATUS_INTERVAL: Duration = Duration::from_secs(60); /// Session manager handles the remote control session pub struct SessionManager { config: Config, transport: Option, state: SessionState, + // Lazy-initialized streaming resources + capturer: Option>, + encoder: Option>, + input: Option, + // Streaming state + current_viewer_id: Option, + // System info for status reports + hostname: String, + is_elevated: bool, + start_time: Instant, } #[derive(Debug, Clone, PartialEq)] enum SessionState { Disconnected, Connecting, - Connected, - Active, + Idle, // Connected but not streaming - minimal resource usage + Streaming, // Actively capturing and sending frames } impl SessionManager { /// Create a new session manager - pub fn new(config: Config) -> Self { + pub fn new(config: Config, is_elevated: bool) -> Self { + let hostname = config.hostname(); Self { config, transport: None, state: SessionState::Disconnected, + capturer: None, + encoder: None, + input: None, + current_viewer_id: None, + hostname, + is_elevated, + start_time: Instant::now(), } } @@ -73,104 +95,110 @@ impl SessionManager { pub async fn connect(&mut self) -> Result<()> { self.state = SessionState::Connecting; - let hostname = self.config.hostname(); let transport = WebSocketTransport::connect( &self.config.server_url, &self.config.agent_id, &self.config.api_key, - Some(&hostname), + Some(&self.hostname), self.config.support_code.as_deref(), ).await?; self.transport = Some(transport); - self.state = SessionState::Connected; + self.state = SessionState::Idle; // Start in idle mode + + tracing::info!("Connected to server, entering idle mode"); Ok(()) } - /// Run the session main loop - pub async fn run(&mut self) -> Result<()> { - if self.transport.is_none() { - anyhow::bail!("Not connected"); + /// Initialize streaming resources (capturer, encoder, input) + fn init_streaming(&mut self) -> Result<()> { + if self.capturer.is_some() { + return Ok(()); // Already initialized } - self.state = SessionState::Active; + tracing::info!("Initializing streaming resources..."); // Get primary display let primary_display = capture::primary_display()?; - tracing::info!("Using display: {} ({}x{})", primary_display.name, primary_display.width, primary_display.height); + tracing::info!("Using display: {} ({}x{})", + primary_display.name, primary_display.width, primary_display.height); // Create capturer - let mut capturer = capture::create_capturer( + let capturer = capture::create_capturer( primary_display.clone(), self.config.capture.use_dxgi, self.config.capture.gdi_fallback, )?; + self.capturer = Some(capturer); // Create encoder - let mut encoder = encoder::create_encoder( + let encoder = encoder::create_encoder( &self.config.encoding.codec, self.config.encoding.quality, )?; + self.encoder = Some(encoder); // Create input controller - let mut input = InputController::new()?; + let input = InputController::new()?; + self.input = Some(input); - // Calculate frame interval - let frame_interval = Duration::from_millis(1000 / self.config.capture.fps as u64); - let mut last_frame_time = Instant::now(); + tracing::info!("Streaming resources initialized"); + Ok(()) + } - // Main loop - loop { - // Check for incoming messages (non-blocking) - // Collect messages first, then release borrow before handling - let messages: Vec = { - let transport = self.transport.as_mut().unwrap(); - let mut msgs = Vec::new(); - while let Some(msg) = transport.try_recv()? { - msgs.push(msg); - } - msgs - }; + /// Release streaming resources to save CPU/memory when idle + fn release_streaming(&mut self) { + if self.capturer.is_some() { + tracing::info!("Releasing streaming resources"); + self.capturer = None; + self.encoder = None; + self.input = None; + self.current_viewer_id = None; + } + } - for msg in messages { - self.handle_message(&mut input, msg)?; - } + /// Get display count for status reports + fn get_display_count(&self) -> i32 { + capture::enumerate_displays().map(|d| d.len() as i32).unwrap_or(1) + } - // Capture and send frame if interval elapsed - if last_frame_time.elapsed() >= frame_interval { - last_frame_time = Instant::now(); + /// Send agent status to server + async fn send_status(&mut self) -> Result<()> { + let status = AgentStatus { + hostname: self.hostname.clone(), + os_version: std::env::consts::OS.to_string(), + is_elevated: self.is_elevated, + uptime_secs: self.start_time.elapsed().as_secs() as i64, + display_count: self.get_display_count(), + is_streaming: self.state == SessionState::Streaming, + }; - if let Some(frame) = capturer.capture()? { - let encoded = encoder.encode(&frame)?; + let msg = Message { + payload: Some(message::Payload::AgentStatus(status)), + }; - // Skip empty frames (no changes) - if encoded.size > 0 { - let msg = Message { - payload: Some(message::Payload::VideoFrame(encoded.frame)), - }; - let transport = self.transport.as_mut().unwrap(); - transport.send(msg).await?; - } - } - } - - // Small sleep to prevent busy loop - tokio::time::sleep(Duration::from_millis(1)).await; - - // Check if still connected - if let Some(transport) = self.transport.as_ref() { - if !transport.is_connected() { - tracing::warn!("Connection lost"); - break; - } - } else { - tracing::warn!("Transport is None"); - break; - } + if let Some(transport) = self.transport.as_mut() { + transport.send(msg).await?; + } + + Ok(()) + } + + /// Send heartbeat to server + async fn send_heartbeat(&mut self) -> Result<()> { + let heartbeat = Heartbeat { + timestamp: chrono::Utc::now().timestamp_millis(), + }; + + let msg = Message { + payload: Some(message::Payload::Heartbeat(heartbeat)), + }; + + if let Some(transport) = self.transport.as_mut() { + transport.send(msg).await?; } - self.state = SessionState::Disconnected; Ok(()) } @@ -180,31 +208,14 @@ impl SessionManager { anyhow::bail!("Not connected"); } - self.state = SessionState::Active; + // Send initial status + self.send_status().await?; - // Get primary display - let primary_display = capture::primary_display()?; - tracing::info!("Using display: {} ({}x{})", primary_display.name, primary_display.width, primary_display.height); - - // Create capturer - let mut capturer = capture::create_capturer( - primary_display.clone(), - self.config.capture.use_dxgi, - self.config.capture.gdi_fallback, - )?; - - // Create encoder - let mut encoder = encoder::create_encoder( - &self.config.encoding.codec, - self.config.encoding.quality, - )?; - - // Create input controller - let mut input = InputController::new()?; - - // Calculate frame interval - let frame_interval = Duration::from_millis(1000 / self.config.capture.fps as u64); + // Timing for heartbeat and status + let mut last_heartbeat = Instant::now(); + let mut last_status = Instant::now(); let mut last_frame_time = Instant::now(); + let frame_interval = Duration::from_millis(1000 / self.config.capture.fps as u64); // Main loop loop { @@ -217,7 +228,6 @@ impl SessionManager { return Err(anyhow::anyhow!("USER_EXIT: Session ended by user")); } TrayAction::ShowDetails => { - // TODO: Show a details dialog tracing::info!("User requested details (not yet implemented)"); } TrayAction::ShowDebugWindow => { @@ -226,14 +236,13 @@ impl SessionManager { } } - // Check if exit was requested if t.exit_requested() { tracing::info!("Exit requested via tray"); return Err(anyhow::anyhow!("USER_EXIT: Exit requested by user")); } } - // Check for incoming messages (non-blocking) + // Process incoming messages let messages: Vec = { let transport = self.transport.as_mut().unwrap(); let mut msgs = Vec::new(); @@ -254,12 +263,56 @@ impl SessionManager { timestamp: chat_msg.timestamp, }); } - continue; // Don't pass to handle_message + continue; } - self.handle_message(&mut input, msg)?; + + // Handle control messages that affect state + if let Some(ref payload) = msg.payload { + match payload { + message::Payload::StartStream(start) => { + tracing::info!("StartStream received from viewer: {}", start.viewer_id); + if let Err(e) = self.init_streaming() { + tracing::error!("Failed to init streaming: {}", e); + } else { + self.state = SessionState::Streaming; + self.current_viewer_id = Some(start.viewer_id.clone()); + tracing::info!("Now streaming to viewer {}", start.viewer_id); + } + continue; + } + message::Payload::StopStream(stop) => { + tracing::info!("StopStream received for viewer: {}", stop.viewer_id); + // Only stop if it matches current viewer + if self.current_viewer_id.as_ref() == Some(&stop.viewer_id) { + self.release_streaming(); + self.state = SessionState::Idle; + tracing::info!("Stopped streaming, returning to idle mode"); + } + continue; + } + message::Payload::Heartbeat(hb) => { + // Respond to server heartbeat with ack + let ack = HeartbeatAck { + client_timestamp: hb.timestamp, + server_timestamp: chrono::Utc::now().timestamp_millis(), + }; + let ack_msg = Message { + payload: Some(message::Payload::HeartbeatAck(ack)), + }; + if let Some(transport) = self.transport.as_mut() { + let _ = transport.send(ack_msg).await; + } + continue; + } + _ => {} + } + } + + // Handle other messages (input events, disconnect, etc.) + self.handle_message(msg)?; } - // Check for outgoing chat messages from user + // Check for outgoing chat messages if let Some(c) = chat { if let Some(outgoing) = c.poll_outgoing() { let chat_proto = ChatMessage { @@ -276,27 +329,60 @@ impl SessionManager { } } - // Capture and send frame if interval elapsed - if last_frame_time.elapsed() >= frame_interval { - last_frame_time = Instant::now(); - - if let Some(frame) = capturer.capture()? { - let encoded = encoder.encode(&frame)?; - - // Skip empty frames (no changes) - if encoded.size > 0 { - let msg = Message { - payload: Some(message::Payload::VideoFrame(encoded.frame)), - }; - let transport = self.transport.as_mut().unwrap(); - transport.send(msg).await?; + // State-specific behavior + match self.state { + SessionState::Idle => { + // In idle mode, just send heartbeats and status periodically + if last_heartbeat.elapsed() >= HEARTBEAT_INTERVAL { + last_heartbeat = Instant::now(); + if let Err(e) = self.send_heartbeat().await { + tracing::warn!("Failed to send heartbeat: {}", e); + } } + + if last_status.elapsed() >= STATUS_INTERVAL { + last_status = Instant::now(); + if let Err(e) = self.send_status().await { + tracing::warn!("Failed to send status: {}", e); + } + } + + // Longer sleep in idle mode to reduce CPU usage + tokio::time::sleep(Duration::from_millis(100)).await; + } + SessionState::Streaming => { + // In streaming mode, capture and send frames + if last_frame_time.elapsed() >= frame_interval { + last_frame_time = Instant::now(); + + if let (Some(capturer), Some(encoder)) = + (self.capturer.as_mut(), self.encoder.as_mut()) + { + if let Ok(Some(frame)) = capturer.capture() { + if let Ok(encoded) = encoder.encode(&frame) { + if encoded.size > 0 { + let msg = Message { + payload: Some(message::Payload::VideoFrame(encoded.frame)), + }; + let transport = self.transport.as_mut().unwrap(); + if let Err(e) = transport.send(msg).await { + tracing::warn!("Failed to send frame: {}", e); + } + } + } + } + } + } + + // Short sleep in streaming mode + tokio::time::sleep(Duration::from_millis(1)).await; + } + _ => { + // Disconnected or connecting - shouldn't be in main loop + tokio::time::sleep(Duration::from_millis(100)).await; } } - // Small sleep to prevent busy loop - tokio::time::sleep(Duration::from_millis(1)).await; - // Check if still connected if let Some(transport) = self.transport.as_ref() { if !transport.is_connected() { @@ -309,70 +395,68 @@ impl SessionManager { } } + self.release_streaming(); self.state = SessionState::Disconnected; Ok(()) } /// Handle incoming message from server - fn handle_message(&mut self, input: &mut InputController, msg: Message) -> Result<()> { + fn handle_message(&mut self, msg: Message) -> Result<()> { match msg.payload { Some(message::Payload::MouseEvent(mouse)) => { - // Handle mouse event - use crate::proto::MouseEventType; - use crate::input::MouseButton; + if let Some(input) = self.input.as_mut() { + use crate::proto::MouseEventType; + use crate::input::MouseButton; - match MouseEventType::try_from(mouse.event_type).unwrap_or(MouseEventType::MouseMove) { - MouseEventType::MouseMove => { - input.mouse_move(mouse.x, mouse.y)?; - } - MouseEventType::MouseDown => { - input.mouse_move(mouse.x, mouse.y)?; - if let Some(ref buttons) = mouse.buttons { - if buttons.left { input.mouse_click(MouseButton::Left, true)?; } - if buttons.right { input.mouse_click(MouseButton::Right, true)?; } - if buttons.middle { input.mouse_click(MouseButton::Middle, true)?; } + match MouseEventType::try_from(mouse.event_type).unwrap_or(MouseEventType::MouseMove) { + MouseEventType::MouseMove => { + input.mouse_move(mouse.x, mouse.y)?; } - } - MouseEventType::MouseUp => { - if let Some(ref buttons) = mouse.buttons { - if buttons.left { input.mouse_click(MouseButton::Left, false)?; } - if buttons.right { input.mouse_click(MouseButton::Right, false)?; } - if buttons.middle { input.mouse_click(MouseButton::Middle, false)?; } + MouseEventType::MouseDown => { + input.mouse_move(mouse.x, mouse.y)?; + if let Some(ref buttons) = mouse.buttons { + if buttons.left { input.mouse_click(MouseButton::Left, true)?; } + if buttons.right { input.mouse_click(MouseButton::Right, true)?; } + if buttons.middle { input.mouse_click(MouseButton::Middle, true)?; } + } + } + MouseEventType::MouseUp => { + if let Some(ref buttons) = mouse.buttons { + if buttons.left { input.mouse_click(MouseButton::Left, false)?; } + if buttons.right { input.mouse_click(MouseButton::Right, false)?; } + if buttons.middle { input.mouse_click(MouseButton::Middle, false)?; } + } + } + MouseEventType::MouseWheel => { + input.mouse_scroll(mouse.wheel_delta_x, mouse.wheel_delta_y)?; } - } - MouseEventType::MouseWheel => { - input.mouse_scroll(mouse.wheel_delta_x, mouse.wheel_delta_y)?; } } } Some(message::Payload::KeyEvent(key)) => { - // Handle keyboard event - input.key_event(key.vk_code as u16, key.down)?; - } - - Some(message::Payload::SpecialKey(special)) => { - use crate::proto::SpecialKey; - match SpecialKey::try_from(special.key).ok() { - Some(SpecialKey::CtrlAltDel) => { - input.send_ctrl_alt_del()?; - } - _ => {} + if let Some(input) = self.input.as_mut() { + input.key_event(key.vk_code as u16, key.down)?; } } - Some(message::Payload::Heartbeat(_)) => { - // Respond to heartbeat - // TODO: Send heartbeat ack + Some(message::Payload::SpecialKey(special)) => { + if let Some(input) = self.input.as_mut() { + use crate::proto::SpecialKey; + match SpecialKey::try_from(special.key).ok() { + Some(SpecialKey::CtrlAltDel) => { + input.send_ctrl_alt_del()?; + } + _ => {} + } + } } Some(message::Payload::Disconnect(disc)) => { tracing::info!("Disconnect requested: {}", disc.reason); - // Check if this is a cancellation (support session) if disc.reason.contains("cancelled") { return Err(anyhow::anyhow!("SESSION_CANCELLED: {}", disc.reason)); } - // Check if this is an admin disconnect (persistent session) if disc.reason.contains("administrator") || disc.reason.contains("Disconnected") { return Err(anyhow::anyhow!("ADMIN_DISCONNECT: {}", disc.reason)); } diff --git a/proto/guruconnect.proto b/proto/guruconnect.proto index cc8c5a5..c399e0d 100644 --- a/proto/guruconnect.proto +++ b/proto/guruconnect.proto @@ -257,6 +257,27 @@ message Disconnect { string reason = 1; } +// Server commands agent to start streaming video +message StartStream { + string viewer_id = 1; // ID of viewer requesting stream + int32 display_id = 2; // Which display to stream (0 = primary) +} + +// Server commands agent to stop streaming +message StopStream { + string viewer_id = 1; // Which viewer disconnected +} + +// Agent reports its status periodically when idle +message AgentStatus { + string hostname = 1; + string os_version = 2; + bool is_elevated = 3; + int64 uptime_secs = 4; + int32 display_count = 5; + bool is_streaming = 6; +} + // ============================================================================ // Top-Level Message Wrapper // ============================================================================ @@ -293,6 +314,9 @@ message Message { Heartbeat heartbeat = 50; HeartbeatAck heartbeat_ack = 51; Disconnect disconnect = 52; + StartStream start_stream = 53; + StopStream stop_stream = 54; + AgentStatus agent_status = 55; // Chat ChatMessage chat_message = 60; diff --git a/server/src/api/mod.rs b/server/src/api/mod.rs index 684ed77..a48e5be 100644 --- a/server/src/api/mod.rs +++ b/server/src/api/mod.rs @@ -17,6 +17,12 @@ pub struct SessionInfo { pub agent_name: String, pub started_at: String, pub viewer_count: usize, + pub is_streaming: bool, + pub last_heartbeat: String, + pub os_version: Option, + pub is_elevated: bool, + pub uptime_secs: i64, + pub display_count: i32, } impl From for SessionInfo { @@ -27,6 +33,12 @@ impl From for SessionInfo { agent_name: s.agent_name, started_at: s.started_at.to_rfc3339(), viewer_count: s.viewer_count, + is_streaming: s.is_streaming, + last_heartbeat: s.last_heartbeat.to_rfc3339(), + os_version: s.os_version, + is_elevated: s.is_elevated, + uptime_secs: s.uptime_secs, + display_count: s.display_count, } } } diff --git a/server/src/relay/mod.rs b/server/src/relay/mod.rs index 251ed77..b947220 100644 --- a/server/src/relay/mod.rs +++ b/server/src/relay/mod.rs @@ -14,6 +14,7 @@ use futures_util::{SinkExt, StreamExt}; use prost::Message as ProstMessage; use serde::Deserialize; use tracing::{error, info, warn}; +use uuid::Uuid; use crate::proto; use crate::session::SessionManager; @@ -98,7 +99,7 @@ async fn handle_agent_connection( // Register the agent and get channels let (session_id, frame_tx, mut input_rx) = sessions.register_agent(agent_id.clone(), agent_name.clone()).await; - info!("Session created: {}", session_id); + info!("Session created: {} (agent in idle mode)", session_id); // If a support code was provided, mark it as connected if let Some(ref code) = support_code { @@ -123,6 +124,7 @@ async fn handle_agent_connection( }); let sessions_cleanup = sessions.clone(); + let sessions_status = sessions.clone(); let support_codes_cleanup = support_codes.clone(); let support_code_cleanup = support_code.clone(); let support_code_check = support_code.clone(); @@ -154,7 +156,7 @@ async fn handle_agent_connection( } }); - // Main loop: receive frames from agent and broadcast to viewers + // Main loop: receive messages from agent while let Some(msg) = ws_receiver.next().await { match msg { Ok(Message::Binary(data)) => { @@ -163,7 +165,7 @@ async fn handle_agent_connection( Ok(proto_msg) => { match &proto_msg.payload { Some(proto::message::Payload::VideoFrame(_)) => { - // Broadcast frame to all viewers + // Broadcast frame to all viewers (only sent when streaming) let _ = frame_tx.send(data.to_vec()); } Some(proto::message::Payload::ChatMessage(chat)) => { @@ -171,6 +173,27 @@ async fn handle_agent_connection( info!("Chat from client: {}", chat.content); let _ = frame_tx.send(data.to_vec()); } + Some(proto::message::Payload::AgentStatus(status)) => { + // Update session with agent status + sessions_status.update_agent_status( + session_id, + Some(status.os_version.clone()), + status.is_elevated, + status.uptime_secs, + status.display_count, + status.is_streaming, + ).await; + info!("Agent status update: {} - streaming={}, uptime={}s", + status.hostname, status.is_streaming, status.uptime_secs); + } + Some(proto::message::Payload::Heartbeat(_)) => { + // Update heartbeat timestamp + sessions_status.update_heartbeat(session_id).await; + } + Some(proto::message::Payload::HeartbeatAck(_)) => { + // Agent acknowledged our heartbeat + sessions_status.update_heartbeat(session_id).await; + } _ => {} } } @@ -226,8 +249,11 @@ async fn handle_viewer_connection( } }; - // Join the session - let (mut frame_rx, input_tx) = match sessions.join_session(session_id).await { + // Generate unique viewer ID + let viewer_id = Uuid::new_v4().to_string(); + + // Join the session (this sends StartStream to agent if first viewer) + let (mut frame_rx, input_tx) = match sessions.join_session(session_id, viewer_id.clone()).await { Some(channels) => channels, None => { warn!("Session not found: {}", session_id); @@ -235,7 +261,7 @@ async fn handle_viewer_connection( } }; - info!("Viewer joined session: {}", session_id); + info!("Viewer {} joined session: {}", viewer_id, session_id); let (mut ws_sender, mut ws_receiver) = socket.split(); @@ -249,6 +275,7 @@ async fn handle_viewer_connection( }); let sessions_cleanup = sessions.clone(); + let viewer_id_cleanup = viewer_id.clone(); // Main loop: receive input from viewer and forward to agent while let Some(msg) = ws_receiver.next().await { @@ -259,7 +286,8 @@ async fn handle_viewer_connection( Ok(proto_msg) => { match &proto_msg.payload { Some(proto::message::Payload::MouseEvent(_)) | - Some(proto::message::Payload::KeyEvent(_)) => { + Some(proto::message::Payload::KeyEvent(_)) | + Some(proto::message::Payload::SpecialKey(_)) => { // Forward input to agent let _ = input_tx.send(data.to_vec()).await; } @@ -277,19 +305,19 @@ async fn handle_viewer_connection( } } Ok(Message::Close(_)) => { - info!("Viewer disconnected from session: {}", session_id); + info!("Viewer {} disconnected from session: {}", viewer_id, session_id); break; } Ok(_) => {} Err(e) => { - error!("WebSocket error from viewer: {}", e); + error!("WebSocket error from viewer {}: {}", viewer_id, e); break; } } } - // Cleanup + // Cleanup (this sends StopStream to agent if last viewer) frame_forward.abort(); - sessions_cleanup.leave_session(session_id).await; - info!("Viewer left session: {}", session_id); + sessions_cleanup.leave_session(session_id, &viewer_id_cleanup).await; + info!("Viewer {} left session: {}", viewer_id_cleanup, session_id); } diff --git a/server/src/session/mod.rs b/server/src/session/mod.rs index 7fe0609..1c27353 100644 --- a/server/src/session/mod.rs +++ b/server/src/session/mod.rs @@ -3,8 +3,9 @@ //! Manages active remote desktop sessions, tracking which agents //! are connected and which viewers are watching them. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; +use std::time::Instant; use tokio::sync::{broadcast, RwLock}; use uuid::Uuid; @@ -14,6 +15,12 @@ pub type SessionId = Uuid; /// Unique identifier for an agent pub type AgentId = String; +/// Unique identifier for a viewer +pub type ViewerId = String; + +/// Heartbeat timeout (90 seconds - 3x the agent's 30 second interval) +const HEARTBEAT_TIMEOUT_SECS: u64 = 90; + /// Session state #[derive(Debug, Clone)] pub struct Session { @@ -22,6 +29,13 @@ pub struct Session { pub agent_name: String, pub started_at: chrono::DateTime, pub viewer_count: usize, + pub is_streaming: bool, + pub last_heartbeat: chrono::DateTime, + // Agent status info + pub os_version: Option, + pub is_elevated: bool, + pub uptime_secs: i64, + pub display_count: i32, } /// Channel for sending frames from agent to viewers @@ -40,6 +54,10 @@ struct SessionData { /// Channel for input events (viewer -> agent) input_tx: InputSender, input_rx: Option, + /// Set of connected viewer IDs + viewers: HashSet, + /// Instant for heartbeat tracking + last_heartbeat_instant: Instant, } /// Manages all active sessions @@ -65,19 +83,28 @@ impl SessionManager { let (frame_tx, _) = broadcast::channel(16); // Buffer 16 frames let (input_tx, input_rx) = tokio::sync::mpsc::channel(64); // Buffer 64 input events + let now = chrono::Utc::now(); let session = Session { id: session_id, agent_id: agent_id.clone(), agent_name, - started_at: chrono::Utc::now(), + started_at: now, viewer_count: 0, + is_streaming: false, + last_heartbeat: now, + os_version: None, + is_elevated: false, + uptime_secs: 0, + display_count: 1, }; let session_data = SessionData { info: session, frame_tx: frame_tx.clone(), input_tx, - input_rx: None, // Will be taken by the agent handler + input_rx: None, + viewers: HashSet::new(), + last_heartbeat_instant: Instant::now(), }; let mut sessions = self.sessions.write().await; @@ -89,6 +116,59 @@ impl SessionManager { (session_id, frame_tx, input_rx) } + /// Update agent status from heartbeat or status message + pub async fn update_agent_status( + &self, + session_id: SessionId, + os_version: Option, + is_elevated: bool, + uptime_secs: i64, + display_count: i32, + is_streaming: bool, + ) { + let mut sessions = self.sessions.write().await; + if let Some(session_data) = sessions.get_mut(&session_id) { + session_data.info.last_heartbeat = chrono::Utc::now(); + session_data.last_heartbeat_instant = Instant::now(); + session_data.info.is_streaming = is_streaming; + if let Some(os) = os_version { + session_data.info.os_version = Some(os); + } + session_data.info.is_elevated = is_elevated; + session_data.info.uptime_secs = uptime_secs; + session_data.info.display_count = display_count; + } + } + + /// Update heartbeat timestamp + pub async fn update_heartbeat(&self, session_id: SessionId) { + let mut sessions = self.sessions.write().await; + if let Some(session_data) = sessions.get_mut(&session_id) { + session_data.info.last_heartbeat = chrono::Utc::now(); + session_data.last_heartbeat_instant = Instant::now(); + } + } + + /// Check if a session has timed out (no heartbeat for too long) + pub async fn is_session_timed_out(&self, session_id: SessionId) -> bool { + let sessions = self.sessions.read().await; + if let Some(session_data) = sessions.get(&session_id) { + session_data.last_heartbeat_instant.elapsed().as_secs() > HEARTBEAT_TIMEOUT_SECS + } else { + true // Non-existent sessions are considered timed out + } + } + + /// Get sessions that have timed out + pub async fn get_timed_out_sessions(&self) -> Vec { + let sessions = self.sessions.read().await; + sessions + .iter() + .filter(|(_, data)| data.last_heartbeat_instant.elapsed().as_secs() > HEARTBEAT_TIMEOUT_SECS) + .map(|(id, _)| *id) + .collect() + } + /// Get a session by agent ID pub async fn get_session_by_agent(&self, agent_id: &str) -> Option { let agents = self.agents.read().await; @@ -104,24 +184,74 @@ impl SessionManager { sessions.get(&session_id).map(|s| s.info.clone()) } - /// Join a session as a viewer - pub async fn join_session(&self, session_id: SessionId) -> Option<(FrameReceiver, InputSender)> { + /// Join a session as a viewer, returns channels and sends StartStream to agent + pub async fn join_session(&self, session_id: SessionId, viewer_id: ViewerId) -> Option<(FrameReceiver, InputSender)> { let mut sessions = self.sessions.write().await; let session_data = sessions.get_mut(&session_id)?; - session_data.info.viewer_count += 1; + let was_empty = session_data.viewers.is_empty(); + session_data.viewers.insert(viewer_id.clone()); + session_data.info.viewer_count = session_data.viewers.len(); let frame_rx = session_data.frame_tx.subscribe(); let input_tx = session_data.input_tx.clone(); + // If this is the first viewer, send StartStream to agent + if was_empty { + tracing::info!("First viewer {} joined session {}, sending StartStream", viewer_id, session_id); + self.send_start_stream_internal(session_data, &viewer_id).await; + } + Some((frame_rx, input_tx)) } - /// Leave a session as a viewer - pub async fn leave_session(&self, session_id: SessionId) { + /// Internal helper to send StartStream message + async fn send_start_stream_internal(session_data: &SessionData, viewer_id: &str) { + use crate::proto; + use prost::Message; + + let start_stream = proto::Message { + payload: Some(proto::message::Payload::StartStream(proto::StartStream { + viewer_id: viewer_id.to_string(), + display_id: 0, // Primary display + })), + }; + + let mut buf = Vec::new(); + if start_stream.encode(&mut buf).is_ok() { + let _ = session_data.input_tx.send(buf).await; + } + } + + /// Leave a session as a viewer, sends StopStream if no viewers left + pub async fn leave_session(&self, session_id: SessionId, viewer_id: &ViewerId) { let mut sessions = self.sessions.write().await; if let Some(session_data) = sessions.get_mut(&session_id) { - session_data.info.viewer_count = session_data.info.viewer_count.saturating_sub(1); + session_data.viewers.remove(viewer_id); + session_data.info.viewer_count = session_data.viewers.len(); + + // If no more viewers, send StopStream to agent + if session_data.viewers.is_empty() { + tracing::info!("Last viewer {} left session {}, sending StopStream", viewer_id, session_id); + self.send_stop_stream_internal(session_data, viewer_id).await; + } + } + } + + /// Internal helper to send StopStream message + async fn send_stop_stream_internal(session_data: &SessionData, viewer_id: &str) { + use crate::proto; + use prost::Message; + + let stop_stream = proto::Message { + payload: Some(proto::message::Payload::StopStream(proto::StopStream { + viewer_id: viewer_id.to_string(), + })), + }; + + let mut buf = Vec::new(); + if stop_stream.encode(&mut buf).is_ok() { + let _ = session_data.input_tx.send(buf).await; } }