Implement idle/active mode for scalable agent connections

- Add StartStream/StopStream/AgentStatus messages to protobuf
- Agent now starts in idle mode (heartbeat only, no capture)
- Agent enters streaming mode when viewer connects (StartStream)
- Agent returns to idle when all viewers disconnect (StopStream)
- Server tracks viewer IDs and sends start/stop commands
- Heartbeat mechanism with 90 second timeout detection
- Session API now includes streaming status and agent info

This allows 2000+ agents to connect with minimal bandwidth.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-28 17:24:51 -07:00
parent 5bb5116b92
commit 4417fdfb6e
6 changed files with 455 additions and 176 deletions

View File

@@ -204,8 +204,9 @@ fn cleanup_on_exit() {
} }
async fn run_agent(config: config::Config) -> Result<()> { async fn run_agent(config: config::Config) -> Result<()> {
// Create session manager // Create session manager with elevation status
let mut session = session::SessionManager::new(config.clone()); let elevated = is_elevated();
let mut session = session::SessionManager::new(config.clone(), elevated);
let is_support_session = config.support_code.is_some(); let is_support_session = config.support_code.is_some();
let hostname = config.hostname(); let hostname = config.hostname();

View File

@@ -2,8 +2,8 @@
//! //!
//! Handles the lifecycle of a remote session including: //! Handles the lifecycle of a remote session including:
//! - Connection to server //! - Connection to server
//! - Authentication //! - Idle mode (heartbeat only, minimal resources)
//! - Frame capture and encoding loop //! - Active/streaming mode (capture and send frames)
//! - Input event handling //! - Input event handling
#[cfg(windows)] #[cfg(windows)]
@@ -36,36 +36,58 @@ fn show_debug_console() {
fn show_debug_console() { fn show_debug_console() {
// No-op on non-Windows platforms // No-op on non-Windows platforms
} }
use crate::proto::{Message, message, ChatMessage};
use crate::proto::{Message, message, ChatMessage, AgentStatus, Heartbeat, HeartbeatAck};
use crate::transport::WebSocketTransport; use crate::transport::WebSocketTransport;
use crate::tray::{TrayController, TrayAction}; use crate::tray::{TrayController, TrayAction};
use anyhow::Result; use anyhow::Result;
use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::sync::mpsc;
// Heartbeat interval (30 seconds)
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
// Status report interval (60 seconds)
const STATUS_INTERVAL: Duration = Duration::from_secs(60);
/// Session manager handles the remote control session /// Session manager handles the remote control session
pub struct SessionManager { pub struct SessionManager {
config: Config, config: Config,
transport: Option<WebSocketTransport>, transport: Option<WebSocketTransport>,
state: SessionState, state: SessionState,
// Lazy-initialized streaming resources
capturer: Option<Box<dyn Capturer>>,
encoder: Option<Box<dyn Encoder>>,
input: Option<InputController>,
// Streaming state
current_viewer_id: Option<String>,
// System info for status reports
hostname: String,
is_elevated: bool,
start_time: Instant,
} }
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
enum SessionState { enum SessionState {
Disconnected, Disconnected,
Connecting, Connecting,
Connected, Idle, // Connected but not streaming - minimal resource usage
Active, Streaming, // Actively capturing and sending frames
} }
impl SessionManager { impl SessionManager {
/// Create a new session manager /// Create a new session manager
pub fn new(config: Config) -> Self { pub fn new(config: Config, is_elevated: bool) -> Self {
let hostname = config.hostname();
Self { Self {
config, config,
transport: None, transport: None,
state: SessionState::Disconnected, state: SessionState::Disconnected,
capturer: None,
encoder: None,
input: None,
current_viewer_id: None,
hostname,
is_elevated,
start_time: Instant::now(),
} }
} }
@@ -73,104 +95,110 @@ impl SessionManager {
pub async fn connect(&mut self) -> Result<()> { pub async fn connect(&mut self) -> Result<()> {
self.state = SessionState::Connecting; self.state = SessionState::Connecting;
let hostname = self.config.hostname();
let transport = WebSocketTransport::connect( let transport = WebSocketTransport::connect(
&self.config.server_url, &self.config.server_url,
&self.config.agent_id, &self.config.agent_id,
&self.config.api_key, &self.config.api_key,
Some(&hostname), Some(&self.hostname),
self.config.support_code.as_deref(), self.config.support_code.as_deref(),
).await?; ).await?;
self.transport = Some(transport); self.transport = Some(transport);
self.state = SessionState::Connected; self.state = SessionState::Idle; // Start in idle mode
tracing::info!("Connected to server, entering idle mode");
Ok(()) Ok(())
} }
/// Run the session main loop /// Initialize streaming resources (capturer, encoder, input)
pub async fn run(&mut self) -> Result<()> { fn init_streaming(&mut self) -> Result<()> {
if self.transport.is_none() { if self.capturer.is_some() {
anyhow::bail!("Not connected"); return Ok(()); // Already initialized
} }
self.state = SessionState::Active; tracing::info!("Initializing streaming resources...");
// Get primary display // Get primary display
let primary_display = capture::primary_display()?; let primary_display = capture::primary_display()?;
tracing::info!("Using display: {} ({}x{})", primary_display.name, primary_display.width, primary_display.height); tracing::info!("Using display: {} ({}x{})",
primary_display.name, primary_display.width, primary_display.height);
// Create capturer // Create capturer
let mut capturer = capture::create_capturer( let capturer = capture::create_capturer(
primary_display.clone(), primary_display.clone(),
self.config.capture.use_dxgi, self.config.capture.use_dxgi,
self.config.capture.gdi_fallback, self.config.capture.gdi_fallback,
)?; )?;
self.capturer = Some(capturer);
// Create encoder // Create encoder
let mut encoder = encoder::create_encoder( let encoder = encoder::create_encoder(
&self.config.encoding.codec, &self.config.encoding.codec,
self.config.encoding.quality, self.config.encoding.quality,
)?; )?;
self.encoder = Some(encoder);
// Create input controller // Create input controller
let mut input = InputController::new()?; let input = InputController::new()?;
self.input = Some(input);
// Calculate frame interval tracing::info!("Streaming resources initialized");
let frame_interval = Duration::from_millis(1000 / self.config.capture.fps as u64); Ok(())
let mut last_frame_time = Instant::now();
// Main loop
loop {
// Check for incoming messages (non-blocking)
// Collect messages first, then release borrow before handling
let messages: Vec<Message> = {
let transport = self.transport.as_mut().unwrap();
let mut msgs = Vec::new();
while let Some(msg) = transport.try_recv()? {
msgs.push(msg);
} }
msgs
/// Release streaming resources to save CPU/memory when idle
fn release_streaming(&mut self) {
if self.capturer.is_some() {
tracing::info!("Releasing streaming resources");
self.capturer = None;
self.encoder = None;
self.input = None;
self.current_viewer_id = None;
}
}
/// Get display count for status reports
fn get_display_count(&self) -> i32 {
capture::enumerate_displays().map(|d| d.len() as i32).unwrap_or(1)
}
/// Send agent status to server
async fn send_status(&mut self) -> Result<()> {
let status = AgentStatus {
hostname: self.hostname.clone(),
os_version: std::env::consts::OS.to_string(),
is_elevated: self.is_elevated,
uptime_secs: self.start_time.elapsed().as_secs() as i64,
display_count: self.get_display_count(),
is_streaming: self.state == SessionState::Streaming,
}; };
for msg in messages {
self.handle_message(&mut input, msg)?;
}
// Capture and send frame if interval elapsed
if last_frame_time.elapsed() >= frame_interval {
last_frame_time = Instant::now();
if let Some(frame) = capturer.capture()? {
let encoded = encoder.encode(&frame)?;
// Skip empty frames (no changes)
if encoded.size > 0 {
let msg = Message { let msg = Message {
payload: Some(message::Payload::VideoFrame(encoded.frame)), payload: Some(message::Payload::AgentStatus(status)),
}; };
let transport = self.transport.as_mut().unwrap();
if let Some(transport) = self.transport.as_mut() {
transport.send(msg).await?; transport.send(msg).await?;
} }
}
Ok(())
} }
// Small sleep to prevent busy loop /// Send heartbeat to server
tokio::time::sleep(Duration::from_millis(1)).await; async fn send_heartbeat(&mut self) -> Result<()> {
let heartbeat = Heartbeat {
timestamp: chrono::Utc::now().timestamp_millis(),
};
// Check if still connected let msg = Message {
if let Some(transport) = self.transport.as_ref() { payload: Some(message::Payload::Heartbeat(heartbeat)),
if !transport.is_connected() { };
tracing::warn!("Connection lost");
break; if let Some(transport) = self.transport.as_mut() {
} transport.send(msg).await?;
} else {
tracing::warn!("Transport is None");
break;
}
} }
self.state = SessionState::Disconnected;
Ok(()) Ok(())
} }
@@ -180,31 +208,14 @@ impl SessionManager {
anyhow::bail!("Not connected"); anyhow::bail!("Not connected");
} }
self.state = SessionState::Active; // Send initial status
self.send_status().await?;
// Get primary display // Timing for heartbeat and status
let primary_display = capture::primary_display()?; let mut last_heartbeat = Instant::now();
tracing::info!("Using display: {} ({}x{})", primary_display.name, primary_display.width, primary_display.height); let mut last_status = Instant::now();
// Create capturer
let mut capturer = capture::create_capturer(
primary_display.clone(),
self.config.capture.use_dxgi,
self.config.capture.gdi_fallback,
)?;
// Create encoder
let mut encoder = encoder::create_encoder(
&self.config.encoding.codec,
self.config.encoding.quality,
)?;
// Create input controller
let mut input = InputController::new()?;
// Calculate frame interval
let frame_interval = Duration::from_millis(1000 / self.config.capture.fps as u64);
let mut last_frame_time = Instant::now(); let mut last_frame_time = Instant::now();
let frame_interval = Duration::from_millis(1000 / self.config.capture.fps as u64);
// Main loop // Main loop
loop { loop {
@@ -217,7 +228,6 @@ impl SessionManager {
return Err(anyhow::anyhow!("USER_EXIT: Session ended by user")); return Err(anyhow::anyhow!("USER_EXIT: Session ended by user"));
} }
TrayAction::ShowDetails => { TrayAction::ShowDetails => {
// TODO: Show a details dialog
tracing::info!("User requested details (not yet implemented)"); tracing::info!("User requested details (not yet implemented)");
} }
TrayAction::ShowDebugWindow => { TrayAction::ShowDebugWindow => {
@@ -226,14 +236,13 @@ impl SessionManager {
} }
} }
// Check if exit was requested
if t.exit_requested() { if t.exit_requested() {
tracing::info!("Exit requested via tray"); tracing::info!("Exit requested via tray");
return Err(anyhow::anyhow!("USER_EXIT: Exit requested by user")); return Err(anyhow::anyhow!("USER_EXIT: Exit requested by user"));
} }
} }
// Check for incoming messages (non-blocking) // Process incoming messages
let messages: Vec<Message> = { let messages: Vec<Message> = {
let transport = self.transport.as_mut().unwrap(); let transport = self.transport.as_mut().unwrap();
let mut msgs = Vec::new(); let mut msgs = Vec::new();
@@ -254,12 +263,56 @@ impl SessionManager {
timestamp: chat_msg.timestamp, timestamp: chat_msg.timestamp,
}); });
} }
continue; // Don't pass to handle_message continue;
}
self.handle_message(&mut input, msg)?;
} }
// Check for outgoing chat messages from user // Handle control messages that affect state
if let Some(ref payload) = msg.payload {
match payload {
message::Payload::StartStream(start) => {
tracing::info!("StartStream received from viewer: {}", start.viewer_id);
if let Err(e) = self.init_streaming() {
tracing::error!("Failed to init streaming: {}", e);
} else {
self.state = SessionState::Streaming;
self.current_viewer_id = Some(start.viewer_id.clone());
tracing::info!("Now streaming to viewer {}", start.viewer_id);
}
continue;
}
message::Payload::StopStream(stop) => {
tracing::info!("StopStream received for viewer: {}", stop.viewer_id);
// Only stop if it matches current viewer
if self.current_viewer_id.as_ref() == Some(&stop.viewer_id) {
self.release_streaming();
self.state = SessionState::Idle;
tracing::info!("Stopped streaming, returning to idle mode");
}
continue;
}
message::Payload::Heartbeat(hb) => {
// Respond to server heartbeat with ack
let ack = HeartbeatAck {
client_timestamp: hb.timestamp,
server_timestamp: chrono::Utc::now().timestamp_millis(),
};
let ack_msg = Message {
payload: Some(message::Payload::HeartbeatAck(ack)),
};
if let Some(transport) = self.transport.as_mut() {
let _ = transport.send(ack_msg).await;
}
continue;
}
_ => {}
}
}
// Handle other messages (input events, disconnect, etc.)
self.handle_message(msg)?;
}
// Check for outgoing chat messages
if let Some(c) = chat { if let Some(c) = chat {
if let Some(outgoing) = c.poll_outgoing() { if let Some(outgoing) = c.poll_outgoing() {
let chat_proto = ChatMessage { let chat_proto = ChatMessage {
@@ -276,26 +329,59 @@ impl SessionManager {
} }
} }
// Capture and send frame if interval elapsed // State-specific behavior
match self.state {
SessionState::Idle => {
// In idle mode, just send heartbeats and status periodically
if last_heartbeat.elapsed() >= HEARTBEAT_INTERVAL {
last_heartbeat = Instant::now();
if let Err(e) = self.send_heartbeat().await {
tracing::warn!("Failed to send heartbeat: {}", e);
}
}
if last_status.elapsed() >= STATUS_INTERVAL {
last_status = Instant::now();
if let Err(e) = self.send_status().await {
tracing::warn!("Failed to send status: {}", e);
}
}
// Longer sleep in idle mode to reduce CPU usage
tokio::time::sleep(Duration::from_millis(100)).await;
}
SessionState::Streaming => {
// In streaming mode, capture and send frames
if last_frame_time.elapsed() >= frame_interval { if last_frame_time.elapsed() >= frame_interval {
last_frame_time = Instant::now(); last_frame_time = Instant::now();
if let Some(frame) = capturer.capture()? { if let (Some(capturer), Some(encoder)) =
let encoded = encoder.encode(&frame)?; (self.capturer.as_mut(), self.encoder.as_mut())
{
// Skip empty frames (no changes) if let Ok(Some(frame)) = capturer.capture() {
if let Ok(encoded) = encoder.encode(&frame) {
if encoded.size > 0 { if encoded.size > 0 {
let msg = Message { let msg = Message {
payload: Some(message::Payload::VideoFrame(encoded.frame)), payload: Some(message::Payload::VideoFrame(encoded.frame)),
}; };
let transport = self.transport.as_mut().unwrap(); let transport = self.transport.as_mut().unwrap();
transport.send(msg).await?; if let Err(e) = transport.send(msg).await {
tracing::warn!("Failed to send frame: {}", e);
}
}
}
} }
} }
} }
// Small sleep to prevent busy loop // Short sleep in streaming mode
tokio::time::sleep(Duration::from_millis(1)).await; tokio::time::sleep(Duration::from_millis(1)).await;
}
_ => {
// Disconnected or connecting - shouldn't be in main loop
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
// Check if still connected // Check if still connected
if let Some(transport) = self.transport.as_ref() { if let Some(transport) = self.transport.as_ref() {
@@ -309,15 +395,16 @@ impl SessionManager {
} }
} }
self.release_streaming();
self.state = SessionState::Disconnected; self.state = SessionState::Disconnected;
Ok(()) Ok(())
} }
/// Handle incoming message from server /// Handle incoming message from server
fn handle_message(&mut self, input: &mut InputController, msg: Message) -> Result<()> { fn handle_message(&mut self, msg: Message) -> Result<()> {
match msg.payload { match msg.payload {
Some(message::Payload::MouseEvent(mouse)) => { Some(message::Payload::MouseEvent(mouse)) => {
// Handle mouse event if let Some(input) = self.input.as_mut() {
use crate::proto::MouseEventType; use crate::proto::MouseEventType;
use crate::input::MouseButton; use crate::input::MouseButton;
@@ -345,13 +432,16 @@ impl SessionManager {
} }
} }
} }
}
Some(message::Payload::KeyEvent(key)) => { Some(message::Payload::KeyEvent(key)) => {
// Handle keyboard event if let Some(input) = self.input.as_mut() {
input.key_event(key.vk_code as u16, key.down)?; input.key_event(key.vk_code as u16, key.down)?;
} }
}
Some(message::Payload::SpecialKey(special)) => { Some(message::Payload::SpecialKey(special)) => {
if let Some(input) = self.input.as_mut() {
use crate::proto::SpecialKey; use crate::proto::SpecialKey;
match SpecialKey::try_from(special.key).ok() { match SpecialKey::try_from(special.key).ok() {
Some(SpecialKey::CtrlAltDel) => { Some(SpecialKey::CtrlAltDel) => {
@@ -360,19 +450,13 @@ impl SessionManager {
_ => {} _ => {}
} }
} }
Some(message::Payload::Heartbeat(_)) => {
// Respond to heartbeat
// TODO: Send heartbeat ack
} }
Some(message::Payload::Disconnect(disc)) => { Some(message::Payload::Disconnect(disc)) => {
tracing::info!("Disconnect requested: {}", disc.reason); tracing::info!("Disconnect requested: {}", disc.reason);
// Check if this is a cancellation (support session)
if disc.reason.contains("cancelled") { if disc.reason.contains("cancelled") {
return Err(anyhow::anyhow!("SESSION_CANCELLED: {}", disc.reason)); return Err(anyhow::anyhow!("SESSION_CANCELLED: {}", disc.reason));
} }
// Check if this is an admin disconnect (persistent session)
if disc.reason.contains("administrator") || disc.reason.contains("Disconnected") { if disc.reason.contains("administrator") || disc.reason.contains("Disconnected") {
return Err(anyhow::anyhow!("ADMIN_DISCONNECT: {}", disc.reason)); return Err(anyhow::anyhow!("ADMIN_DISCONNECT: {}", disc.reason));
} }

View File

@@ -257,6 +257,27 @@ message Disconnect {
string reason = 1; string reason = 1;
} }
// Server commands agent to start streaming video
message StartStream {
string viewer_id = 1; // ID of viewer requesting stream
int32 display_id = 2; // Which display to stream (0 = primary)
}
// Server commands agent to stop streaming
message StopStream {
string viewer_id = 1; // Which viewer disconnected
}
// Agent reports its status periodically when idle
message AgentStatus {
string hostname = 1;
string os_version = 2;
bool is_elevated = 3;
int64 uptime_secs = 4;
int32 display_count = 5;
bool is_streaming = 6;
}
// ============================================================================ // ============================================================================
// Top-Level Message Wrapper // Top-Level Message Wrapper
// ============================================================================ // ============================================================================
@@ -293,6 +314,9 @@ message Message {
Heartbeat heartbeat = 50; Heartbeat heartbeat = 50;
HeartbeatAck heartbeat_ack = 51; HeartbeatAck heartbeat_ack = 51;
Disconnect disconnect = 52; Disconnect disconnect = 52;
StartStream start_stream = 53;
StopStream stop_stream = 54;
AgentStatus agent_status = 55;
// Chat // Chat
ChatMessage chat_message = 60; ChatMessage chat_message = 60;

View File

@@ -17,6 +17,12 @@ pub struct SessionInfo {
pub agent_name: String, pub agent_name: String,
pub started_at: String, pub started_at: String,
pub viewer_count: usize, pub viewer_count: usize,
pub is_streaming: bool,
pub last_heartbeat: String,
pub os_version: Option<String>,
pub is_elevated: bool,
pub uptime_secs: i64,
pub display_count: i32,
} }
impl From<crate::session::Session> for SessionInfo { impl From<crate::session::Session> for SessionInfo {
@@ -27,6 +33,12 @@ impl From<crate::session::Session> for SessionInfo {
agent_name: s.agent_name, agent_name: s.agent_name,
started_at: s.started_at.to_rfc3339(), started_at: s.started_at.to_rfc3339(),
viewer_count: s.viewer_count, viewer_count: s.viewer_count,
is_streaming: s.is_streaming,
last_heartbeat: s.last_heartbeat.to_rfc3339(),
os_version: s.os_version,
is_elevated: s.is_elevated,
uptime_secs: s.uptime_secs,
display_count: s.display_count,
} }
} }
} }

View File

@@ -14,6 +14,7 @@ use futures_util::{SinkExt, StreamExt};
use prost::Message as ProstMessage; use prost::Message as ProstMessage;
use serde::Deserialize; use serde::Deserialize;
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use uuid::Uuid;
use crate::proto; use crate::proto;
use crate::session::SessionManager; use crate::session::SessionManager;
@@ -98,7 +99,7 @@ async fn handle_agent_connection(
// Register the agent and get channels // Register the agent and get channels
let (session_id, frame_tx, mut input_rx) = sessions.register_agent(agent_id.clone(), agent_name.clone()).await; let (session_id, frame_tx, mut input_rx) = sessions.register_agent(agent_id.clone(), agent_name.clone()).await;
info!("Session created: {}", session_id); info!("Session created: {} (agent in idle mode)", session_id);
// If a support code was provided, mark it as connected // If a support code was provided, mark it as connected
if let Some(ref code) = support_code { if let Some(ref code) = support_code {
@@ -123,6 +124,7 @@ async fn handle_agent_connection(
}); });
let sessions_cleanup = sessions.clone(); let sessions_cleanup = sessions.clone();
let sessions_status = sessions.clone();
let support_codes_cleanup = support_codes.clone(); let support_codes_cleanup = support_codes.clone();
let support_code_cleanup = support_code.clone(); let support_code_cleanup = support_code.clone();
let support_code_check = support_code.clone(); let support_code_check = support_code.clone();
@@ -154,7 +156,7 @@ async fn handle_agent_connection(
} }
}); });
// Main loop: receive frames from agent and broadcast to viewers // Main loop: receive messages from agent
while let Some(msg) = ws_receiver.next().await { while let Some(msg) = ws_receiver.next().await {
match msg { match msg {
Ok(Message::Binary(data)) => { Ok(Message::Binary(data)) => {
@@ -163,7 +165,7 @@ async fn handle_agent_connection(
Ok(proto_msg) => { Ok(proto_msg) => {
match &proto_msg.payload { match &proto_msg.payload {
Some(proto::message::Payload::VideoFrame(_)) => { Some(proto::message::Payload::VideoFrame(_)) => {
// Broadcast frame to all viewers // Broadcast frame to all viewers (only sent when streaming)
let _ = frame_tx.send(data.to_vec()); let _ = frame_tx.send(data.to_vec());
} }
Some(proto::message::Payload::ChatMessage(chat)) => { Some(proto::message::Payload::ChatMessage(chat)) => {
@@ -171,6 +173,27 @@ async fn handle_agent_connection(
info!("Chat from client: {}", chat.content); info!("Chat from client: {}", chat.content);
let _ = frame_tx.send(data.to_vec()); let _ = frame_tx.send(data.to_vec());
} }
Some(proto::message::Payload::AgentStatus(status)) => {
// Update session with agent status
sessions_status.update_agent_status(
session_id,
Some(status.os_version.clone()),
status.is_elevated,
status.uptime_secs,
status.display_count,
status.is_streaming,
).await;
info!("Agent status update: {} - streaming={}, uptime={}s",
status.hostname, status.is_streaming, status.uptime_secs);
}
Some(proto::message::Payload::Heartbeat(_)) => {
// Update heartbeat timestamp
sessions_status.update_heartbeat(session_id).await;
}
Some(proto::message::Payload::HeartbeatAck(_)) => {
// Agent acknowledged our heartbeat
sessions_status.update_heartbeat(session_id).await;
}
_ => {} _ => {}
} }
} }
@@ -226,8 +249,11 @@ async fn handle_viewer_connection(
} }
}; };
// Join the session // Generate unique viewer ID
let (mut frame_rx, input_tx) = match sessions.join_session(session_id).await { let viewer_id = Uuid::new_v4().to_string();
// Join the session (this sends StartStream to agent if first viewer)
let (mut frame_rx, input_tx) = match sessions.join_session(session_id, viewer_id.clone()).await {
Some(channels) => channels, Some(channels) => channels,
None => { None => {
warn!("Session not found: {}", session_id); warn!("Session not found: {}", session_id);
@@ -235,7 +261,7 @@ async fn handle_viewer_connection(
} }
}; };
info!("Viewer joined session: {}", session_id); info!("Viewer {} joined session: {}", viewer_id, session_id);
let (mut ws_sender, mut ws_receiver) = socket.split(); let (mut ws_sender, mut ws_receiver) = socket.split();
@@ -249,6 +275,7 @@ async fn handle_viewer_connection(
}); });
let sessions_cleanup = sessions.clone(); let sessions_cleanup = sessions.clone();
let viewer_id_cleanup = viewer_id.clone();
// Main loop: receive input from viewer and forward to agent // Main loop: receive input from viewer and forward to agent
while let Some(msg) = ws_receiver.next().await { while let Some(msg) = ws_receiver.next().await {
@@ -259,7 +286,8 @@ async fn handle_viewer_connection(
Ok(proto_msg) => { Ok(proto_msg) => {
match &proto_msg.payload { match &proto_msg.payload {
Some(proto::message::Payload::MouseEvent(_)) | Some(proto::message::Payload::MouseEvent(_)) |
Some(proto::message::Payload::KeyEvent(_)) => { Some(proto::message::Payload::KeyEvent(_)) |
Some(proto::message::Payload::SpecialKey(_)) => {
// Forward input to agent // Forward input to agent
let _ = input_tx.send(data.to_vec()).await; let _ = input_tx.send(data.to_vec()).await;
} }
@@ -277,19 +305,19 @@ async fn handle_viewer_connection(
} }
} }
Ok(Message::Close(_)) => { Ok(Message::Close(_)) => {
info!("Viewer disconnected from session: {}", session_id); info!("Viewer {} disconnected from session: {}", viewer_id, session_id);
break; break;
} }
Ok(_) => {} Ok(_) => {}
Err(e) => { Err(e) => {
error!("WebSocket error from viewer: {}", e); error!("WebSocket error from viewer {}: {}", viewer_id, e);
break; break;
} }
} }
} }
// Cleanup // Cleanup (this sends StopStream to agent if last viewer)
frame_forward.abort(); frame_forward.abort();
sessions_cleanup.leave_session(session_id).await; sessions_cleanup.leave_session(session_id, &viewer_id_cleanup).await;
info!("Viewer left session: {}", session_id); info!("Viewer {} left session: {}", viewer_id_cleanup, session_id);
} }

View File

@@ -3,8 +3,9 @@
//! Manages active remote desktop sessions, tracking which agents //! Manages active remote desktop sessions, tracking which agents
//! are connected and which viewers are watching them. //! are connected and which viewers are watching them.
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, RwLock}; use tokio::sync::{broadcast, RwLock};
use uuid::Uuid; use uuid::Uuid;
@@ -14,6 +15,12 @@ pub type SessionId = Uuid;
/// Unique identifier for an agent /// Unique identifier for an agent
pub type AgentId = String; pub type AgentId = String;
/// Unique identifier for a viewer
pub type ViewerId = String;
/// Heartbeat timeout (90 seconds - 3x the agent's 30 second interval)
const HEARTBEAT_TIMEOUT_SECS: u64 = 90;
/// Session state /// Session state
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Session { pub struct Session {
@@ -22,6 +29,13 @@ pub struct Session {
pub agent_name: String, pub agent_name: String,
pub started_at: chrono::DateTime<chrono::Utc>, pub started_at: chrono::DateTime<chrono::Utc>,
pub viewer_count: usize, pub viewer_count: usize,
pub is_streaming: bool,
pub last_heartbeat: chrono::DateTime<chrono::Utc>,
// Agent status info
pub os_version: Option<String>,
pub is_elevated: bool,
pub uptime_secs: i64,
pub display_count: i32,
} }
/// Channel for sending frames from agent to viewers /// Channel for sending frames from agent to viewers
@@ -40,6 +54,10 @@ struct SessionData {
/// Channel for input events (viewer -> agent) /// Channel for input events (viewer -> agent)
input_tx: InputSender, input_tx: InputSender,
input_rx: Option<InputReceiver>, input_rx: Option<InputReceiver>,
/// Set of connected viewer IDs
viewers: HashSet<ViewerId>,
/// Instant for heartbeat tracking
last_heartbeat_instant: Instant,
} }
/// Manages all active sessions /// Manages all active sessions
@@ -65,19 +83,28 @@ impl SessionManager {
let (frame_tx, _) = broadcast::channel(16); // Buffer 16 frames let (frame_tx, _) = broadcast::channel(16); // Buffer 16 frames
let (input_tx, input_rx) = tokio::sync::mpsc::channel(64); // Buffer 64 input events let (input_tx, input_rx) = tokio::sync::mpsc::channel(64); // Buffer 64 input events
let now = chrono::Utc::now();
let session = Session { let session = Session {
id: session_id, id: session_id,
agent_id: agent_id.clone(), agent_id: agent_id.clone(),
agent_name, agent_name,
started_at: chrono::Utc::now(), started_at: now,
viewer_count: 0, viewer_count: 0,
is_streaming: false,
last_heartbeat: now,
os_version: None,
is_elevated: false,
uptime_secs: 0,
display_count: 1,
}; };
let session_data = SessionData { let session_data = SessionData {
info: session, info: session,
frame_tx: frame_tx.clone(), frame_tx: frame_tx.clone(),
input_tx, input_tx,
input_rx: None, // Will be taken by the agent handler input_rx: None,
viewers: HashSet::new(),
last_heartbeat_instant: Instant::now(),
}; };
let mut sessions = self.sessions.write().await; let mut sessions = self.sessions.write().await;
@@ -89,6 +116,59 @@ impl SessionManager {
(session_id, frame_tx, input_rx) (session_id, frame_tx, input_rx)
} }
/// Update agent status from heartbeat or status message
pub async fn update_agent_status(
&self,
session_id: SessionId,
os_version: Option<String>,
is_elevated: bool,
uptime_secs: i64,
display_count: i32,
is_streaming: bool,
) {
let mut sessions = self.sessions.write().await;
if let Some(session_data) = sessions.get_mut(&session_id) {
session_data.info.last_heartbeat = chrono::Utc::now();
session_data.last_heartbeat_instant = Instant::now();
session_data.info.is_streaming = is_streaming;
if let Some(os) = os_version {
session_data.info.os_version = Some(os);
}
session_data.info.is_elevated = is_elevated;
session_data.info.uptime_secs = uptime_secs;
session_data.info.display_count = display_count;
}
}
/// Update heartbeat timestamp
pub async fn update_heartbeat(&self, session_id: SessionId) {
let mut sessions = self.sessions.write().await;
if let Some(session_data) = sessions.get_mut(&session_id) {
session_data.info.last_heartbeat = chrono::Utc::now();
session_data.last_heartbeat_instant = Instant::now();
}
}
/// Check if a session has timed out (no heartbeat for too long)
pub async fn is_session_timed_out(&self, session_id: SessionId) -> bool {
let sessions = self.sessions.read().await;
if let Some(session_data) = sessions.get(&session_id) {
session_data.last_heartbeat_instant.elapsed().as_secs() > HEARTBEAT_TIMEOUT_SECS
} else {
true // Non-existent sessions are considered timed out
}
}
/// Get sessions that have timed out
pub async fn get_timed_out_sessions(&self) -> Vec<SessionId> {
let sessions = self.sessions.read().await;
sessions
.iter()
.filter(|(_, data)| data.last_heartbeat_instant.elapsed().as_secs() > HEARTBEAT_TIMEOUT_SECS)
.map(|(id, _)| *id)
.collect()
}
/// Get a session by agent ID /// Get a session by agent ID
pub async fn get_session_by_agent(&self, agent_id: &str) -> Option<Session> { pub async fn get_session_by_agent(&self, agent_id: &str) -> Option<Session> {
let agents = self.agents.read().await; let agents = self.agents.read().await;
@@ -104,24 +184,74 @@ impl SessionManager {
sessions.get(&session_id).map(|s| s.info.clone()) sessions.get(&session_id).map(|s| s.info.clone())
} }
/// Join a session as a viewer /// Join a session as a viewer, returns channels and sends StartStream to agent
pub async fn join_session(&self, session_id: SessionId) -> Option<(FrameReceiver, InputSender)> { pub async fn join_session(&self, session_id: SessionId, viewer_id: ViewerId) -> Option<(FrameReceiver, InputSender)> {
let mut sessions = self.sessions.write().await; let mut sessions = self.sessions.write().await;
let session_data = sessions.get_mut(&session_id)?; let session_data = sessions.get_mut(&session_id)?;
session_data.info.viewer_count += 1; let was_empty = session_data.viewers.is_empty();
session_data.viewers.insert(viewer_id.clone());
session_data.info.viewer_count = session_data.viewers.len();
let frame_rx = session_data.frame_tx.subscribe(); let frame_rx = session_data.frame_tx.subscribe();
let input_tx = session_data.input_tx.clone(); let input_tx = session_data.input_tx.clone();
// If this is the first viewer, send StartStream to agent
if was_empty {
tracing::info!("First viewer {} joined session {}, sending StartStream", viewer_id, session_id);
self.send_start_stream_internal(session_data, &viewer_id).await;
}
Some((frame_rx, input_tx)) Some((frame_rx, input_tx))
} }
/// Leave a session as a viewer /// Internal helper to send StartStream message
pub async fn leave_session(&self, session_id: SessionId) { async fn send_start_stream_internal(session_data: &SessionData, viewer_id: &str) {
use crate::proto;
use prost::Message;
let start_stream = proto::Message {
payload: Some(proto::message::Payload::StartStream(proto::StartStream {
viewer_id: viewer_id.to_string(),
display_id: 0, // Primary display
})),
};
let mut buf = Vec::new();
if start_stream.encode(&mut buf).is_ok() {
let _ = session_data.input_tx.send(buf).await;
}
}
/// Leave a session as a viewer, sends StopStream if no viewers left
pub async fn leave_session(&self, session_id: SessionId, viewer_id: &ViewerId) {
let mut sessions = self.sessions.write().await; let mut sessions = self.sessions.write().await;
if let Some(session_data) = sessions.get_mut(&session_id) { if let Some(session_data) = sessions.get_mut(&session_id) {
session_data.info.viewer_count = session_data.info.viewer_count.saturating_sub(1); session_data.viewers.remove(viewer_id);
session_data.info.viewer_count = session_data.viewers.len();
// If no more viewers, send StopStream to agent
if session_data.viewers.is_empty() {
tracing::info!("Last viewer {} left session {}, sending StopStream", viewer_id, session_id);
self.send_stop_stream_internal(session_data, viewer_id).await;
}
}
}
/// Internal helper to send StopStream message
async fn send_stop_stream_internal(session_data: &SessionData, viewer_id: &str) {
use crate::proto;
use prost::Message;
let stop_stream = proto::Message {
payload: Some(proto::message::Payload::StopStream(proto::StopStream {
viewer_id: viewer_id.to_string(),
})),
};
let mut buf = Vec::new();
if stop_stream.encode(&mut buf).is_ok() {
let _ = session_data.input_tx.send(buf).await;
} }
} }