diff --git a/agent/src/main.rs b/agent/src/main.rs index de3d9d1..3e2132b 100644 --- a/agent/src/main.rs +++ b/agent/src/main.rs @@ -23,6 +23,11 @@ use anyhow::Result; use tracing::{info, error, Level}; use tracing_subscriber::FmtSubscriber; +#[cfg(windows)] +use windows::Win32::UI::WindowsAndMessaging::{MessageBoxW, MB_OK, MB_ICONINFORMATION}; +#[cfg(windows)] +use windows::core::PCWSTR; + /// Extract a 6-digit support code from the executable's filename. /// Looks for patterns like "GuruConnect-123456.exe" or "123456.exe" fn extract_code_from_filename() -> Option { @@ -50,6 +55,37 @@ fn extract_code_from_filename() -> Option { None } +/// Show a message box to the user (Windows only) +#[cfg(windows)] +fn show_message_box(title: &str, message: &str) { + use std::ffi::OsStr; + use std::os::windows::ffi::OsStrExt; + + // Convert strings to wide strings for Windows API + let title_wide: Vec = OsStr::new(title) + .encode_wide() + .chain(std::iter::once(0)) + .collect(); + let message_wide: Vec = OsStr::new(message) + .encode_wide() + .chain(std::iter::once(0)) + .collect(); + + unsafe { + MessageBoxW( + None, + PCWSTR(message_wide.as_ptr()), + PCWSTR(title_wide.as_ptr()), + MB_OK | MB_ICONINFORMATION, + ); + } +} + +#[cfg(not(windows))] +fn show_message_box(_title: &str, _message: &str) { + // No-op on non-Windows platforms +} + #[tokio::main] async fn main() -> Result<()> { // Initialize logging @@ -85,6 +121,7 @@ async fn main() -> Result<()> { async fn run_agent(config: config::Config) -> Result<()> { // Create session manager let mut session = session::SessionManager::new(config.clone()); + let is_support_session = config.support_code.is_some(); // Connect to server and run main loop loop { @@ -96,15 +133,47 @@ async fn run_agent(config: config::Config) -> Result<()> { // Run session until disconnect if let Err(e) = session.run().await { + let error_msg = e.to_string(); + + // Check if this is a cancellation + if error_msg.contains("SESSION_CANCELLED") { + info!("Session was cancelled by technician"); + show_message_box( + "Support Session Ended", + "The support session was cancelled by the technician.\n\nThis window will close automatically.", + ); + // Exit cleanly without reconnecting + return Ok(()); + } + error!("Session error: {}", e); } } Err(e) => { + let error_msg = e.to_string(); + + // Check if connection was rejected due to cancelled code + if error_msg.contains("cancelled") { + info!("Support code was cancelled before connection"); + show_message_box( + "Support Session Cancelled", + "This support session has been cancelled.\n\nPlease contact your technician for a new support code.", + ); + // Exit cleanly without reconnecting + return Ok(()); + } + error!("Connection failed: {}", e); } } - // Wait before reconnecting + // For support sessions, don't reconnect if something goes wrong + if is_support_session { + info!("Support session ended, not reconnecting"); + return Ok(()); + } + + // Wait before reconnecting (only for persistent agent connections) info!("Reconnecting in 5 seconds..."); tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; } diff --git a/agent/src/session/mod.rs b/agent/src/session/mod.rs index 1c3daa1..7837b5e 100644 --- a/agent/src/session/mod.rs +++ b/agent/src/session/mod.rs @@ -202,6 +202,10 @@ impl SessionManager { Some(message::Payload::Disconnect(disc)) => { tracing::info!("Disconnect requested: {}", disc.reason); + // Check if this is a cancellation + if disc.reason.contains("cancelled") { + return Err(anyhow::anyhow!("SESSION_CANCELLED: {}", disc.reason)); + } return Err(anyhow::anyhow!("Disconnect: {}", disc.reason)); } diff --git a/server/src/relay/mod.rs b/server/src/relay/mod.rs index da5e94f..8882fab 100644 --- a/server/src/relay/mod.rs +++ b/server/src/relay/mod.rs @@ -73,6 +73,28 @@ async fn handle_agent_connection( ) { info!("Agent connected: {} ({})", agent_name, agent_id); + let (mut ws_sender, mut ws_receiver) = socket.split(); + + // If a support code was provided, check if it's valid + if let Some(ref code) = support_code { + // Check if the code is cancelled or invalid + if support_codes.is_cancelled(code).await { + warn!("Agent tried to connect with cancelled code: {}", code); + // Send disconnect message to agent + let disconnect_msg = proto::Message { + payload: Some(proto::message::Payload::Disconnect(proto::Disconnect { + reason: "Support session was cancelled by technician".to_string(), + })), + }; + let mut buf = Vec::new(); + if prost::Message::encode(&disconnect_msg, &mut buf).is_ok() { + let _ = ws_sender.send(Message::Binary(buf.into())).await; + } + let _ = ws_sender.close().await; + return; + } + } + // Register the agent and get channels let (session_id, frame_tx, mut input_rx) = sessions.register_agent(agent_id.clone(), agent_name.clone()).await; @@ -85,12 +107,16 @@ async fn handle_agent_connection( support_codes.link_session(code, session_id).await; } - let (mut ws_sender, mut ws_receiver) = socket.split(); + // Use Arc for sender so we can use it from multiple places + let ws_sender = std::sync::Arc::new(tokio::sync::Mutex::new(ws_sender)); + let ws_sender_input = ws_sender.clone(); + let ws_sender_cancel = ws_sender.clone(); // Task to forward input events from viewers to agent let input_forward = tokio::spawn(async move { while let Some(input_data) = input_rx.recv().await { - if ws_sender.send(Message::Binary(input_data.into())).await.is_err() { + let mut sender = ws_sender_input.lock().await; + if sender.send(Message::Binary(input_data.into())).await.is_err() { break; } } @@ -99,6 +125,34 @@ async fn handle_agent_connection( let sessions_cleanup = sessions.clone(); let support_codes_cleanup = support_codes.clone(); let support_code_cleanup = support_code.clone(); + let support_code_check = support_code.clone(); + let support_codes_check = support_codes.clone(); + + // Task to check for cancellation every 2 seconds + let cancel_check = tokio::spawn(async move { + let mut interval = tokio::time::interval(std::time::Duration::from_secs(2)); + loop { + interval.tick().await; + if let Some(ref code) = support_code_check { + if support_codes_check.is_cancelled(code).await { + info!("Support code {} was cancelled, disconnecting agent", code); + // Send disconnect message + let disconnect_msg = proto::Message { + payload: Some(proto::message::Payload::Disconnect(proto::Disconnect { + reason: "Support session was cancelled by technician".to_string(), + })), + }; + let mut buf = Vec::new(); + if prost::Message::encode(&disconnect_msg, &mut buf).is_ok() { + let mut sender = ws_sender_cancel.lock().await; + let _ = sender.send(Message::Binary(buf.into())).await; + let _ = sender.close().await; + } + break; + } + } + } + }); // Main loop: receive frames from agent and broadcast to viewers while let Some(msg) = ws_receiver.next().await { @@ -135,12 +189,15 @@ async fn handle_agent_connection( // Cleanup input_forward.abort(); + cancel_check.abort(); sessions_cleanup.remove_session(session_id).await; - // Mark support code as completed if one was used + // Mark support code as completed if one was used (unless cancelled) if let Some(ref code) = support_code_cleanup { - support_codes_cleanup.mark_completed(code).await; - info!("Support code {} marked as completed", code); + if !support_codes_cleanup.is_cancelled(code).await { + support_codes_cleanup.mark_completed(code).await; + info!("Support code {} marked as completed", code); + } } info!("Session {} ended", session_id); diff --git a/server/src/support_codes.rs b/server/src/support_codes.rs index ae119ea..5207a5e 100644 --- a/server/src/support_codes.rs +++ b/server/src/support_codes.rs @@ -176,11 +176,11 @@ impl SupportCodeManager { } } - /// Cancel a code + /// Cancel a code (works for both pending and connected) pub async fn cancel_code(&self, code: &str) -> bool { let mut codes = self.codes.write().await; if let Some(support_code) = codes.get_mut(code) { - if support_code.status == CodeStatus::Pending { + if support_code.status == CodeStatus::Pending || support_code.status == CodeStatus::Connected { support_code.status = CodeStatus::Cancelled; return true; } @@ -188,6 +188,18 @@ impl SupportCodeManager { false } + /// Check if a code is cancelled + pub async fn is_cancelled(&self, code: &str) -> bool { + let codes = self.codes.read().await; + codes.get(code).map(|c| c.status == CodeStatus::Cancelled).unwrap_or(false) + } + + /// Check if a code is valid for connection (exists and is pending) + pub async fn is_valid_for_connection(&self, code: &str) -> bool { + let codes = self.codes.read().await; + codes.get(code).map(|c| c.status == CodeStatus::Pending).unwrap_or(false) + } + /// List all codes (for dashboard) pub async fn list_codes(&self) -> Vec { let codes = self.codes.read().await;