diff --git a/agent/build.rs b/agent/build.rs index bc1f387..7656145 100644 --- a/agent/build.rs +++ b/agent/build.rs @@ -13,7 +13,9 @@ fn main() -> Result<()> { println!("cargo:rerun-if-changed=../.git/index"); // Build timestamp (UTC) - let build_timestamp = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC").to_string(); + let build_timestamp = chrono::Utc::now() + .format("%Y-%m-%d %H:%M:%S UTC") + .to_string(); println!("cargo:rustc-env=BUILD_TIMESTAMP={}", build_timestamp); // Git commit hash (short) @@ -53,7 +55,10 @@ fn main() -> Result<()> { .ok() .map(|o| !o.stdout.is_empty()) .unwrap_or(false); - println!("cargo:rustc-env=GIT_DIRTY={}", if git_dirty { "dirty" } else { "clean" }); + println!( + "cargo:rustc-env=GIT_DIRTY={}", + if git_dirty { "dirty" } else { "clean" } + ); // Git commit date let git_commit_date = Command::new("git") diff --git a/agent/src/bin/sas_service.rs b/agent/src/bin/sas_service.rs index 23e2c8f..7e0df96 100644 --- a/agent/src/bin/sas_service.rs +++ b/agent/src/bin/sas_service.rs @@ -25,7 +25,8 @@ use windows_service::{ // Service configuration const SERVICE_NAME: &str = "GuruConnectSAS"; const SERVICE_DISPLAY_NAME: &str = "GuruConnect SAS Service"; -const SERVICE_DESCRIPTION: &str = "Handles Secure Attention Sequence (Ctrl+Alt+Del) for GuruConnect remote sessions"; +const SERVICE_DESCRIPTION: &str = + "Handles Secure Attention Sequence (Ctrl+Alt+Del) for GuruConnect remote sessions"; const PIPE_NAME: &str = r"\\.\pipe\guruconnect-sas"; const INSTALL_DIR: &str = r"C:\Program Files\GuruConnect"; @@ -360,18 +361,16 @@ fn run_pipe_server() -> Result<()> { tracing::info!("Received command: {}", command); let response = match command { - "sas" => { - match send_sas() { - Ok(()) => { - tracing::info!("SendSAS executed successfully"); - "ok\n" - } - Err(e) => { - tracing::error!("SendSAS failed: {}", e); - "error\n" - } + "sas" => match send_sas() { + Ok(()) => { + tracing::info!("SendSAS executed successfully"); + "ok\n" } - } + Err(e) => { + tracing::error!("SendSAS failed: {}", e); + "error\n" + } + }, "ping" => { tracing::info!("Ping received"); "pong\n" @@ -432,7 +431,8 @@ fn install_service() -> Result<()> { // Get current executable path let current_exe = std::env::current_exe().context("Failed to get current executable")?; - let binary_dest = std::path::PathBuf::from(format!(r"{}\\guruconnect-sas-service.exe", INSTALL_DIR)); + let binary_dest = + std::path::PathBuf::from(format!(r"{}\\guruconnect-sas-service.exe", INSTALL_DIR)); // Create install directory std::fs::create_dir_all(INSTALL_DIR).context("Failed to create install directory")?; @@ -462,7 +462,9 @@ fn install_service() -> Result<()> { } } - service.delete().context("Failed to delete existing service")?; + service + .delete() + .context("Failed to delete existing service")?; drop(service); std::thread::sleep(Duration::from_secs(2)); } @@ -482,7 +484,10 @@ fn install_service() -> Result<()> { }; let service = manager - .create_service(&service_info, ServiceAccess::CHANGE_CONFIG | ServiceAccess::START) + .create_service( + &service_info, + ServiceAccess::CHANGE_CONFIG | ServiceAccess::START, + ) .context("Failed to create service")?; // Set description @@ -514,13 +519,11 @@ fn install_service() -> Result<()> { fn uninstall_service() -> Result<()> { println!("Uninstalling GuruConnect SAS Service..."); - let binary_path = std::path::PathBuf::from(format!(r"{}\\guruconnect-sas-service.exe", INSTALL_DIR)); + let binary_path = + std::path::PathBuf::from(format!(r"{}\\guruconnect-sas-service.exe", INSTALL_DIR)); - let manager = ServiceManager::local_computer( - None::<&str>, - ServiceManagerAccess::CONNECT, - ) - .context("Failed to connect to Service Control Manager. Run as Administrator.")?; + let manager = ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT) + .context("Failed to connect to Service Control Manager. Run as Administrator.")?; match manager.open_service( SERVICE_NAME, @@ -558,17 +561,19 @@ fn uninstall_service() -> Result<()> { /// Start the service fn start_service() -> Result<()> { - let manager = ServiceManager::local_computer( - None::<&str>, - ServiceManagerAccess::CONNECT, - ) - .context("Failed to connect to Service Control Manager")?; + let manager = ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT) + .context("Failed to connect to Service Control Manager")?; let service = manager - .open_service(SERVICE_NAME, ServiceAccess::START | ServiceAccess::QUERY_STATUS) + .open_service( + SERVICE_NAME, + ServiceAccess::START | ServiceAccess::QUERY_STATUS, + ) .context("Failed to open service. Is it installed?")?; - service.start::(&[]).context("Failed to start service")?; + service + .start::(&[]) + .context("Failed to start service")?; std::thread::sleep(Duration::from_secs(1)); @@ -584,14 +589,14 @@ fn start_service() -> Result<()> { /// Stop the service fn stop_service() -> Result<()> { - let manager = ServiceManager::local_computer( - None::<&str>, - ServiceManagerAccess::CONNECT, - ) - .context("Failed to connect to Service Control Manager")?; + let manager = ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT) + .context("Failed to connect to Service Control Manager")?; let service = manager - .open_service(SERVICE_NAME, ServiceAccess::STOP | ServiceAccess::QUERY_STATUS) + .open_service( + SERVICE_NAME, + ServiceAccess::STOP | ServiceAccess::QUERY_STATUS, + ) .context("Failed to open service")?; service.stop().context("Failed to stop service")?; @@ -610,11 +615,8 @@ fn stop_service() -> Result<()> { /// Query service status fn query_status() -> Result<()> { - let manager = ServiceManager::local_computer( - None::<&str>, - ServiceManagerAccess::CONNECT, - ) - .context("Failed to connect to Service Control Manager")?; + let manager = ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT) + .context("Failed to connect to Service Control Manager")?; match manager.open_service(SERVICE_NAME, ServiceAccess::QUERY_STATUS) { Ok(service) => { diff --git a/agent/src/capture/display.rs b/agent/src/capture/display.rs index 52c4ae2..097f81f 100644 --- a/agent/src/capture/display.rs +++ b/agent/src/capture/display.rs @@ -53,11 +53,11 @@ impl Display { /// Enumerate all connected displays #[cfg(windows)] pub fn enumerate_displays() -> Result> { + use std::mem; + use windows::Win32::Foundation::{BOOL, LPARAM, RECT}; use windows::Win32::Graphics::Gdi::{ EnumDisplayMonitors, GetMonitorInfoW, HMONITOR, MONITORINFOEXW, }; - use windows::Win32::Foundation::{BOOL, LPARAM, RECT}; - use std::mem; let mut displays = Vec::new(); let mut display_id = 0u32; @@ -98,7 +98,11 @@ pub fn enumerate_displays() -> Result> { if GetMonitorInfoW(hmonitor, &mut info.monitorInfo as *mut _ as *mut _).as_bool() { let rect = info.monitorInfo.rcMonitor; let name = String::from_utf16_lossy( - &info.szDevice[..info.szDevice.iter().position(|&c| c == 0).unwrap_or(info.szDevice.len())] + &info.szDevice[..info + .szDevice + .iter() + .position(|&c| c == 0) + .unwrap_or(info.szDevice.len())], ); let is_primary = (info.monitorInfo.dwFlags & 1) != 0; // MONITORINFOF_PRIMARY diff --git a/agent/src/capture/dxgi.rs b/agent/src/capture/dxgi.rs index c455a78..b0ebb4c 100644 --- a/agent/src/capture/dxgi.rs +++ b/agent/src/capture/dxgi.rs @@ -10,19 +10,18 @@ use anyhow::{Context, Result}; use std::ptr; use std::time::Instant; +use windows::core::Interface; use windows::Win32::Graphics::Direct3D::D3D_DRIVER_TYPE_UNKNOWN; use windows::Win32::Graphics::Direct3D11::{ D3D11CreateDevice, ID3D11Device, ID3D11DeviceContext, ID3D11Texture2D, - D3D11_SDK_VERSION, D3D11_TEXTURE2D_DESC, - D3D11_USAGE_STAGING, D3D11_MAPPED_SUBRESOURCE, D3D11_MAP_READ, + D3D11_MAPPED_SUBRESOURCE, D3D11_MAP_READ, D3D11_SDK_VERSION, D3D11_TEXTURE2D_DESC, + D3D11_USAGE_STAGING, }; use windows::Win32::Graphics::Dxgi::{ CreateDXGIFactory1, IDXGIAdapter1, IDXGIFactory1, IDXGIOutput, IDXGIOutput1, - IDXGIOutputDuplication, IDXGIResource, DXGI_ERROR_ACCESS_LOST, - DXGI_ERROR_WAIT_TIMEOUT, DXGI_OUTDUPL_DESC, DXGI_OUTDUPL_FRAME_INFO, - DXGI_RESOURCE_PRIORITY_MAXIMUM, + IDXGIOutputDuplication, IDXGIResource, DXGI_ERROR_ACCESS_LOST, DXGI_ERROR_WAIT_TIMEOUT, + DXGI_OUTDUPL_DESC, DXGI_OUTDUPL_FRAME_INFO, DXGI_RESOURCE_PRIORITY_MAXIMUM, }; -use windows::core::Interface; /// DXGI Desktop Duplication capturer pub struct DxgiCapturer { @@ -56,11 +55,16 @@ impl DxgiCapturer { /// Create D3D device and output duplication fn create_duplication( target_display: &Display, - ) -> Result<(ID3D11Device, ID3D11DeviceContext, IDXGIOutputDuplication, DXGI_OUTDUPL_DESC)> { + ) -> Result<( + ID3D11Device, + ID3D11DeviceContext, + IDXGIOutputDuplication, + DXGI_OUTDUPL_DESC, + )> { unsafe { // Create DXGI factory - let factory: IDXGIFactory1 = CreateDXGIFactory1() - .context("Failed to create DXGI factory")?; + let factory: IDXGIFactory1 = + CreateDXGIFactory1().context("Failed to create DXGI factory")?; // Find the adapter and output for this display let (adapter, output) = Self::find_adapter_output(&factory, target_display)?; @@ -86,11 +90,13 @@ impl DxgiCapturer { let context = context.context("D3D11 context is None")?; // Get IDXGIOutput1 interface - let output1: IDXGIOutput1 = output.cast() + let output1: IDXGIOutput1 = output + .cast() .context("Failed to get IDXGIOutput1 interface")?; // Create output duplication - let duplication = output1.DuplicateOutput(&device) + let duplication = output1 + .DuplicateOutput(&device) .context("Failed to create output duplication")?; // Get duplication description @@ -135,7 +141,11 @@ impl DxgiCapturer { let desc = output.GetDesc()?; let name = String::from_utf16_lossy( - &desc.DeviceName[..desc.DeviceName.iter().position(|&c| c == 0).unwrap_or(desc.DeviceName.len())] + &desc.DeviceName[..desc + .DeviceName + .iter() + .position(|&c| c == 0) + .unwrap_or(desc.DeviceName.len())], ); if name == display.name || desc.Monitor.0 as isize == display.handle { @@ -149,10 +159,8 @@ impl DxgiCapturer { } // If we didn't find the specific display, use the first one - let adapter: IDXGIAdapter1 = factory.EnumAdapters1(0) - .context("No adapters found")?; - let output: IDXGIOutput = adapter.EnumOutputs(0) - .context("No outputs found")?; + let adapter: IDXGIAdapter1 = factory.EnumAdapters1(0).context("No adapters found")?; + let output: IDXGIOutput = adapter.EnumOutputs(0).context("No outputs found")?; Ok((adapter, output)) } @@ -171,7 +179,8 @@ impl DxgiCapturer { desc.MiscFlags = Default::default(); let mut staging: Option = None; - self.device.CreateTexture2D(&desc, None, Some(&mut staging)) + self.device + .CreateTexture2D(&desc, None, Some(&mut staging)) .context("Failed to create staging texture")?; let staging = staging.context("Staging texture is None")?; @@ -188,7 +197,10 @@ impl DxgiCapturer { } /// Acquire the next frame from the desktop - fn acquire_frame(&mut self, timeout_ms: u32) -> Result> { + fn acquire_frame( + &mut self, + timeout_ms: u32, + ) -> Result> { unsafe { let mut frame_info = DXGI_OUTDUPL_FRAME_INFO::default(); let mut desktop_resource: Option = None; @@ -209,7 +221,8 @@ impl DxgiCapturer { return Ok(None); } - let texture: ID3D11Texture2D = resource.cast() + let texture: ID3D11Texture2D = resource + .cast() .context("Failed to cast to ID3D11Texture2D")?; Ok(Some((texture, frame_info))) @@ -223,9 +236,7 @@ impl DxgiCapturer { tracing::warn!("Desktop duplication access lost, will need to recreate"); Err(anyhow::anyhow!("Access lost")) } - Err(e) => { - Err(e).context("Failed to acquire frame") - } + Err(e) => Err(e).context("Failed to acquire frame"), } } } diff --git a/agent/src/capture/gdi.rs b/agent/src/capture/gdi.rs index 6fc67d0..10f43ed 100644 --- a/agent/src/capture/gdi.rs +++ b/agent/src/capture/gdi.rs @@ -7,12 +7,11 @@ use super::{CapturedFrame, Capturer, Display}; use anyhow::Result; use std::time::Instant; -use windows::Win32::Graphics::Gdi::{ - BitBlt, CreateCompatibleBitmap, CreateCompatibleDC, DeleteDC, DeleteObject, - GetDIBits, SelectObject, BITMAPINFO, BITMAPINFOHEADER, BI_RGB, DIB_RGB_COLORS, - SRCCOPY, GetDC, ReleaseDC, -}; use windows::Win32::Foundation::HWND; +use windows::Win32::Graphics::Gdi::{ + BitBlt, CreateCompatibleBitmap, CreateCompatibleDC, DeleteDC, DeleteObject, GetDC, GetDIBits, + ReleaseDC, SelectObject, BITMAPINFO, BITMAPINFOHEADER, BI_RGB, DIB_RGB_COLORS, SRCCOPY, +}; /// GDI-based screen capturer pub struct GdiCapturer { diff --git a/agent/src/capture/mod.rs b/agent/src/capture/mod.rs index 407bc19..f6b9126 100644 --- a/agent/src/capture/mod.rs +++ b/agent/src/capture/mod.rs @@ -3,11 +3,11 @@ //! Provides DXGI Desktop Duplication for high-performance screen capture on Windows 8+, //! with GDI fallback for legacy systems or edge cases. +mod display; #[cfg(windows)] mod dxgi; #[cfg(windows)] mod gdi; -mod display; pub use display::{Display, DisplayInfo}; @@ -61,7 +61,11 @@ pub trait Capturer: Send { /// Create a capturer for the specified display #[cfg(windows)] -pub fn create_capturer(display: Display, use_dxgi: bool, gdi_fallback: bool) -> Result> { +pub fn create_capturer( + display: Display, + use_dxgi: bool, + gdi_fallback: bool, +) -> Result> { if use_dxgi { match dxgi::DxgiCapturer::new(display.clone()) { Ok(capturer) => { @@ -83,7 +87,11 @@ pub fn create_capturer(display: Display, use_dxgi: bool, gdi_fallback: bool) -> } #[cfg(not(windows))] -pub fn create_capturer(_display: Display, _use_dxgi: bool, _gdi_fallback: bool) -> Result> { +pub fn create_capturer( + _display: Display, + _use_dxgi: bool, + _gdi_fallback: bool, +) -> Result> { anyhow::bail!("Screen capture only supported on Windows") } diff --git a/agent/src/chat/mod.rs b/agent/src/chat/mod.rs index 57fd9f7..910ed28 100644 --- a/agent/src/chat/mod.rs +++ b/agent/src/chat/mod.rs @@ -6,10 +6,10 @@ use std::sync::mpsc::{self, Receiver, Sender}; use std::sync::{Arc, Mutex}; use std::thread; -use tracing::{info, warn, error}; +use tracing::{error, info, warn}; #[cfg(windows)] -use windows::Win32::UI::WindowsAndMessaging::*; +use windows::core::PCWSTR; #[cfg(windows)] use windows::Win32::Foundation::*; #[cfg(windows)] @@ -17,7 +17,7 @@ use windows::Win32::Graphics::Gdi::*; #[cfg(windows)] use windows::Win32::System::LibraryLoader::GetModuleHandleW; #[cfg(windows)] -use windows::core::PCWSTR; +use windows::Win32::UI::WindowsAndMessaging::*; /// A chat message #[derive(Debug, Clone)] diff --git a/agent/src/config.rs b/agent/src/config.rs index a2a9872..f007457 100644 --- a/agent/src/config.rs +++ b/agent/src/config.rs @@ -221,11 +221,10 @@ impl Config { /// Read embedded configuration from the executable pub fn read_embedded_config() -> Result { - let exe_path = std::env::current_exe() - .context("Failed to get current executable path")?; + let exe_path = std::env::current_exe().context("Failed to get current executable path")?; - let mut file = std::fs::File::open(&exe_path) - .context("Failed to open executable for reading")?; + let mut file = + std::fs::File::open(&exe_path).context("Failed to open executable for reading")?; let file_size = file.metadata()?.len(); if file_size < (MAGIC_MARKER.len() + 4) as u64 { @@ -245,7 +244,8 @@ impl Config { file.read_exact(&mut buffer)?; // Find magic marker - let marker_pos = buffer.windows(MAGIC_MARKER.len()) + let marker_pos = buffer + .windows(MAGIC_MARKER.len()) .rposition(|window| window == MAGIC_MARKER) .ok_or_else(|| anyhow!("Magic marker not found"))?; @@ -269,11 +269,13 @@ impl Config { } let config_bytes = &buffer[config_start..config_start + config_length]; - let config: EmbeddedConfig = serde_json::from_slice(config_bytes) - .context("Failed to parse embedded config JSON")?; + let config: EmbeddedConfig = + serde_json::from_slice(config_bytes).context("Failed to parse embedded config JSON")?; - info!("Loaded embedded config: server={}, company={:?}", - config.server_url, config.company); + info!( + "Loaded embedded config: server={}, company={:?}", + config.server_url, config.company + ); Ok(config) } @@ -338,8 +340,8 @@ impl Config { let contents = std::fs::read_to_string(&config_path) .with_context(|| format!("Failed to read config from {:?}", config_path))?; - let mut config: Config = toml::from_str(&contents) - .with_context(|| "Failed to parse config file")?; + let mut config: Config = + toml::from_str(&contents).with_context(|| "Failed to parse config file")?; // Ensure agent_id is set and saved if config.agent_id.is_empty() { @@ -357,11 +359,11 @@ impl Config { let server_url = std::env::var("GURUCONNECT_SERVER_URL") .unwrap_or_else(|_| "wss://connect.azcomputerguru.com/ws/agent".to_string()); - let api_key = std::env::var("GURUCONNECT_API_KEY") - .unwrap_or_else(|_| "dev-key".to_string()); + let api_key = + std::env::var("GURUCONNECT_API_KEY").unwrap_or_else(|_| "dev-key".to_string()); - let agent_id = std::env::var("GURUCONNECT_AGENT_ID") - .unwrap_or_else(|_| generate_agent_id()); + let agent_id = + std::env::var("GURUCONNECT_AGENT_ID").unwrap_or_else(|_| generate_agent_id()); let config = Config { server_url, @@ -409,13 +411,11 @@ impl Config { /// Get the hostname to use pub fn hostname(&self) -> String { - self.hostname_override - .clone() - .unwrap_or_else(|| { - hostname::get() - .map(|h| h.to_string_lossy().to_string()) - .unwrap_or_else(|_| "unknown".to_string()) - }) + self.hostname_override.clone().unwrap_or_else(|| { + hostname::get() + .map(|h| h.to_string_lossy().to_string()) + .unwrap_or_else(|_| "unknown".to_string()) + }) } /// Save current configuration to file diff --git a/agent/src/encoder/mod.rs b/agent/src/encoder/mod.rs index 74a174c..c9e5be7 100644 --- a/agent/src/encoder/mod.rs +++ b/agent/src/encoder/mod.rs @@ -10,7 +10,7 @@ mod raw; pub use raw::RawEncoder; use crate::capture::CapturedFrame; -use crate::proto::{VideoFrame, RawFrame, DirtyRect as ProtoDirtyRect}; +use crate::proto::{DirtyRect as ProtoDirtyRect, RawFrame, VideoFrame}; use anyhow::Result; /// Encoded frame ready for transmission diff --git a/agent/src/encoder/raw.rs b/agent/src/encoder/raw.rs index 3282438..4b63661 100644 --- a/agent/src/encoder/raw.rs +++ b/agent/src/encoder/raw.rs @@ -122,12 +122,7 @@ impl RawEncoder { } /// Extract pixels for dirty rectangles only - fn extract_dirty_pixels( - &self, - data: &[u8], - width: u32, - dirty_rects: &[DirtyRect], - ) -> Vec { + fn extract_dirty_pixels(&self, data: &[u8], width: u32, dirty_rects: &[DirtyRect]) -> Vec { let stride = (width * 4) as usize; let mut pixels = Vec::new(); @@ -174,7 +169,8 @@ impl Encoder for RawEncoder { if dirty_rects.len() > 50 { (frame.data.clone(), Vec::new(), true) } else { - let dirty_pixels = self.extract_dirty_pixels(&frame.data, frame.width, &dirty_rects); + let dirty_pixels = + self.extract_dirty_pixels(&frame.data, frame.width, &dirty_rects); (dirty_pixels, dirty_rects, false) } } else { diff --git a/agent/src/input/keyboard.rs b/agent/src/input/keyboard.rs index f99a90e..3f54fd5 100644 --- a/agent/src/input/keyboard.rs +++ b/agent/src/input/keyboard.rs @@ -4,9 +4,9 @@ use anyhow::Result; #[cfg(windows)] use windows::Win32::UI::Input::KeyboardAndMouse::{ - SendInput, INPUT, INPUT_0, INPUT_KEYBOARD, KEYBD_EVENT_FLAGS, KEYEVENTF_EXTENDEDKEY, - KEYEVENTF_KEYUP, KEYEVENTF_SCANCODE, KEYEVENTF_UNICODE, KEYBDINPUT, - MapVirtualKeyW, MAPVK_VK_TO_VSC_EX, + MapVirtualKeyW, SendInput, INPUT, INPUT_0, INPUT_KEYBOARD, KEYBDINPUT, KEYBD_EVENT_FLAGS, + KEYEVENTF_EXTENDEDKEY, KEYEVENTF_KEYUP, KEYEVENTF_SCANCODE, KEYEVENTF_UNICODE, + MAPVK_VK_TO_VSC_EX, }; /// Keyboard input controller @@ -144,8 +144,8 @@ impl KeyboardController { tracing::info!("SAS service not available, trying direct sas.dll..."); // Tier 2: Try using the sas.dll directly (requires SYSTEM privileges) - use windows::Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryW}; use windows::core::PCWSTR; + use windows::Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryW}; unsafe { let dll_name: Vec = "sas.dll\0".encode_utf16().collect(); @@ -195,7 +195,7 @@ impl KeyboardController { 0x5D | // Applications key 0x6F | // Numpad Divide 0x90 | // Num Lock - 0x91 // Scroll Lock + 0x91 // Scroll Lock ) } @@ -205,11 +205,7 @@ impl KeyboardController { let sent = unsafe { SendInput(inputs, std::mem::size_of::() as i32) }; if sent as usize != inputs.len() { - anyhow::bail!( - "SendInput failed: sent {} of {} inputs", - sent, - inputs.len() - ); + anyhow::bail!("SendInput failed: sent {} of {} inputs", sent, inputs.len()); } Ok(()) @@ -250,7 +246,7 @@ pub mod vk { pub const ESCAPE: u16 = 0x1B; pub const SPACE: u16 = 0x20; pub const PRIOR: u16 = 0x21; // Page Up - pub const NEXT: u16 = 0x22; // Page Down + pub const NEXT: u16 = 0x22; // Page Down pub const END: u16 = 0x23; pub const HOME: u16 = 0x24; pub const LEFT: u16 = 0x25; diff --git a/agent/src/input/mod.rs b/agent/src/input/mod.rs index d6ac3b5..989cb85 100644 --- a/agent/src/input/mod.rs +++ b/agent/src/input/mod.rs @@ -2,11 +2,11 @@ //! //! Handles mouse and keyboard input simulation using Windows SendInput API. -mod mouse; mod keyboard; +mod mouse; -pub use mouse::MouseController; pub use keyboard::KeyboardController; +pub use mouse::MouseController; use anyhow::Result; diff --git a/agent/src/input/mouse.rs b/agent/src/input/mouse.rs index 29c3945..04d2965 100644 --- a/agent/src/input/mouse.rs +++ b/agent/src/input/mouse.rs @@ -19,8 +19,7 @@ const XBUTTON2: u32 = 0x0002; #[cfg(windows)] use windows::Win32::UI::WindowsAndMessaging::{ - GetSystemMetrics, SM_CXVIRTUALSCREEN, SM_CYVIRTUALSCREEN, SM_XVIRTUALSCREEN, - SM_YVIRTUALSCREEN, + GetSystemMetrics, SM_CXVIRTUALSCREEN, SM_CYVIRTUALSCREEN, SM_XVIRTUALSCREEN, SM_YVIRTUALSCREEN, }; /// Mouse input controller @@ -190,9 +189,7 @@ impl MouseController { /// Send input events #[cfg(windows)] fn send_input(&self, inputs: &[INPUT]) -> Result<()> { - let sent = unsafe { - SendInput(inputs, std::mem::size_of::() as i32) - }; + let sent = unsafe { SendInput(inputs, std::mem::size_of::() as i32) }; if sent as usize != inputs.len() { anyhow::bail!("SendInput failed: sent {} of {} inputs", sent, inputs.len()); diff --git a/agent/src/install.rs b/agent/src/install.rs index 2dcfca9..4d81c41 100644 --- a/agent/src/install.rs +++ b/agent/src/install.rs @@ -6,18 +6,18 @@ //! - UAC elevation with graceful fallback use anyhow::{anyhow, Result}; -use tracing::{info, warn, error}; +use tracing::{error, info, warn}; #[cfg(windows)] use windows::{ core::PCWSTR, Win32::Foundation::HANDLE, Win32::Security::{GetTokenInformation, TokenElevation, TOKEN_ELEVATION, TOKEN_QUERY}, - Win32::System::Threading::{GetCurrentProcess, OpenProcessToken}, Win32::System::Registry::{ - RegCreateKeyExW, RegSetValueExW, RegCloseKey, HKEY, HKEY_CLASSES_ROOT, - HKEY_CURRENT_USER, KEY_WRITE, REG_SZ, REG_OPTION_NON_VOLATILE, + RegCloseKey, RegCreateKeyExW, RegSetValueExW, HKEY, HKEY_CLASSES_ROOT, HKEY_CURRENT_USER, + KEY_WRITE, REG_OPTION_NON_VOLATILE, REG_SZ, }, + Win32::System::Threading::{GetCurrentProcess, OpenProcessToken}, Win32::UI::Shell::ShellExecuteW, Win32::UI::WindowsAndMessaging::SW_SHOWNORMAL, }; @@ -67,11 +67,10 @@ pub fn get_install_path(elevated: bool) -> std::path::PathBuf { if elevated { std::path::PathBuf::from(SYSTEM_INSTALL_PATH) } else { - let local_app_data = std::env::var("LOCALAPPDATA") - .unwrap_or_else(|_| { - let home = std::env::var("USERPROFILE").unwrap_or_else(|_| ".".to_string()); - format!(r"{}\AppData\Local", home) - }); + let local_app_data = std::env::var("LOCALAPPDATA").unwrap_or_else(|_| { + let home = std::env::var("USERPROFILE").unwrap_or_else(|_| ".".to_string()); + format!(r"{}\AppData\Local", home) + }); std::path::PathBuf::from(local_app_data).join(USER_INSTALL_PATH) } } @@ -305,7 +304,7 @@ pub fn install(force_user_install: bool) -> Result<()> { #[cfg(windows)] pub fn is_protocol_handler_registered() -> bool { use windows::Win32::System::Registry::{ - RegOpenKeyExW, RegCloseKey, HKEY_CLASSES_ROOT, HKEY_CURRENT_USER, KEY_READ, + RegCloseKey, RegOpenKeyExW, HKEY_CLASSES_ROOT, HKEY_CURRENT_USER, KEY_READ, }; unsafe { @@ -318,7 +317,9 @@ pub fn is_protocol_handler_registered() -> bool { 0, KEY_READ, &mut key, - ).is_ok() { + ) + .is_ok() + { let _ = RegCloseKey(key); return true; } @@ -331,7 +332,9 @@ pub fn is_protocol_handler_registered() -> bool { 0, KEY_READ, &mut key, - ).is_ok() { + ) + .is_ok() + { let _ = RegCloseKey(key); return true; } @@ -355,22 +358,25 @@ pub fn parse_protocol_url(url_str: &str) -> Result<(String, String, Option Vec { #[cfg(windows)] fn description_to_bytes(wide: &[u16]) -> Vec { - wide.iter() - .flat_map(|w| w.to_le_bytes()) - .collect() + wide.iter().flat_map(|w| w.to_le_bytes()).collect() } diff --git a/agent/src/main.rs b/agent/src/main.rs index 8d980e4..5fdf9c6 100644 --- a/agent/src/main.rs +++ b/agent/src/main.rs @@ -92,16 +92,18 @@ pub mod build_info { use anyhow::Result; use clap::{Parser, Subcommand}; -use tracing::{info, error, warn, Level}; +use tracing::{error, info, warn, Level}; use tracing_subscriber::FmtSubscriber; -#[cfg(windows)] -use windows::Win32::UI::WindowsAndMessaging::{MessageBoxW, MB_OK, MB_ICONINFORMATION, MB_ICONERROR}; #[cfg(windows)] use windows::core::PCWSTR; #[cfg(windows)] use windows::Win32::System::Console::{AllocConsole, GetConsoleWindow}; #[cfg(windows)] +use windows::Win32::UI::WindowsAndMessaging::{ + MessageBoxW, MB_ICONERROR, MB_ICONINFORMATION, MB_OK, +}; +#[cfg(windows)] use windows::Win32::UI::WindowsAndMessaging::{ShowWindow, SW_SHOW}; /// GuruConnect Remote Desktop @@ -140,7 +142,11 @@ enum Commands { session_id: String, /// Server URL - #[arg(short, long, default_value = "wss://connect.azcomputerguru.com/ws/viewer")] + #[arg( + short, + long, + default_value = "wss://connect.azcomputerguru.com/ws/viewer" + )] server: String, /// API key for authentication @@ -177,15 +183,27 @@ fn main() -> Result<()> { let cli = Cli::parse(); // Initialize logging - let level = if cli.verbose { Level::DEBUG } else { Level::INFO }; + let level = if cli.verbose { + Level::DEBUG + } else { + Level::INFO + }; FmtSubscriber::builder() .with_max_level(level) .with_target(true) .with_thread_ids(true) .init(); - info!("GuruConnect {} ({})", build_info::short_version(), build_info::BUILD_TARGET); - info!("Built: {} | Commit: {}", build_info::BUILD_TIMESTAMP, build_info::GIT_COMMIT_DATE); + info!( + "GuruConnect {} ({})", + build_info::short_version(), + build_info::BUILD_TARGET + ); + info!( + "Built: {} | Commit: {}", + build_info::BUILD_TIMESTAMP, + build_info::GIT_COMMIT_DATE + ); // Handle post-update cleanup if cli.post_update { @@ -194,21 +212,18 @@ fn main() -> Result<()> { } match cli.command { - Some(Commands::Agent { code }) => { - run_agent_mode(code) - } - Some(Commands::View { session_id, server, api_key }) => { - run_viewer_mode(&server, &session_id, &api_key) - } - Some(Commands::Install { user_only, elevated }) => { - run_install(user_only || elevated) - } - Some(Commands::Uninstall) => { - run_uninstall() - } - Some(Commands::Launch { url }) => { - run_launch(&url) - } + Some(Commands::Agent { code }) => run_agent_mode(code), + Some(Commands::View { + session_id, + server, + api_key, + }) => run_viewer_mode(&server, &session_id, &api_key), + Some(Commands::Install { + user_only, + elevated, + }) => run_install(user_only || elevated), + Some(Commands::Uninstall) => run_uninstall(), + Some(Commands::Launch { url }) => run_launch(&url), Some(Commands::VersionInfo) => { // Show detailed version info (allocate console on Windows for visibility) #[cfg(windows)] @@ -341,7 +356,10 @@ fn run_install(force_user_install: bool) -> Result<()> { match install::install(force_user_install) { Ok(()) => { - show_message_box("GuruConnect", "Installation complete!\n\nYou can now use guruconnect:// links."); + show_message_box( + "GuruConnect", + "Installation complete!\n\nYou can now use guruconnect:// links.", + ); Ok(()) } Err(e) => { @@ -467,7 +485,11 @@ async fn run_agent(config: config::Config) -> Result<()> { } // Create tray icon - let tray = match tray::TrayController::new(&hostname, config.support_code.as_deref(), is_support_session) { + let tray = match tray::TrayController::new( + &hostname, + config.support_code.as_deref(), + is_support_session, + ) { Ok(t) => { info!("Tray icon created"); Some(t) @@ -503,7 +525,10 @@ async fn run_agent(config: config::Config) -> Result<()> { t.update_status("Status: Connected"); } - if let Err(e) = session.run_with_tray(tray.as_ref(), chat_ctrl.as_ref()).await { + if let Err(e) = session + .run_with_tray(tray.as_ref(), chat_ctrl.as_ref()) + .await + { let error_msg = e.to_string(); if error_msg.contains("USER_EXIT") { @@ -515,7 +540,10 @@ async fn run_agent(config: config::Config) -> Result<()> { if error_msg.contains("SESSION_CANCELLED") { info!("Session was cancelled by technician"); cleanup_on_exit(); - show_message_box("Support Session Ended", "The support session was cancelled."); + show_message_box( + "Support Session Ended", + "The support session was cancelled.", + ); return Ok(()); } @@ -524,7 +552,10 @@ async fn run_agent(config: config::Config) -> Result<()> { if let Err(e) = startup::uninstall() { warn!("Uninstall failed: {}", e); } - show_message_box("Remote Session Ended", "The session was ended by the administrator."); + show_message_box( + "Remote Session Ended", + "The session was ended by the administrator.", + ); return Ok(()); } @@ -533,7 +564,10 @@ async fn run_agent(config: config::Config) -> Result<()> { if let Err(e) = startup::uninstall() { warn!("Uninstall failed: {}", e); } - show_message_box("GuruConnect Removed", "This computer has been removed from remote management."); + show_message_box( + "GuruConnect Removed", + "This computer has been removed from remote management.", + ); return Ok(()); } @@ -551,7 +585,10 @@ async fn run_agent(config: config::Config) -> Result<()> { if error_msg.contains("cancelled") { info!("Support code was cancelled"); cleanup_on_exit(); - show_message_box("Support Session Cancelled", "This support session has been cancelled."); + show_message_box( + "Support Session Cancelled", + "This support session has been cancelled.", + ); return Ok(()); } diff --git a/agent/src/sas_client.rs b/agent/src/sas_client.rs index 2251757..4a2afad 100644 --- a/agent/src/sas_client.rs +++ b/agent/src/sas_client.rs @@ -18,11 +18,7 @@ pub fn request_sas() -> Result<()> { info!("Requesting SAS via service pipe..."); // Try to connect to the pipe - let mut pipe = match OpenOptions::new() - .read(true) - .write(true) - .open(PIPE_NAME) - { + let mut pipe = match OpenOptions::new().read(true).write(true).open(PIPE_NAME) { Ok(p) => p, Err(e) => { warn!("Failed to connect to SAS service pipe: {}", e); @@ -40,7 +36,8 @@ pub fn request_sas() -> Result<()> { // Read the response let mut response = [0u8; 64]; - let n = pipe.read(&mut response) + let n = pipe + .read(&mut response) .context("Failed to read response from SAS service")?; let response_str = String::from_utf8_lossy(&response[..n]); @@ -59,7 +56,10 @@ pub fn request_sas() -> Result<()> { } _ => { error!("Unexpected response from SAS service: {}", response_str); - Err(anyhow::anyhow!("Unexpected SAS service response: {}", response_str)) + Err(anyhow::anyhow!( + "Unexpected SAS service response: {}", + response_str + )) } } } @@ -67,11 +67,7 @@ pub fn request_sas() -> Result<()> { /// Check if the SAS service is available pub fn is_service_available() -> bool { // Try to open the pipe - if let Ok(mut pipe) = OpenOptions::new() - .read(true) - .write(true) - .open(PIPE_NAME) - { + if let Ok(mut pipe) = OpenOptions::new().read(true).write(true).open(PIPE_NAME) { // Send a ping command if pipe.write_all(b"ping\n").is_ok() { let mut response = [0u8; 64]; diff --git a/agent/src/session/mod.rs b/agent/src/session/mod.rs index 2790968..1a9b3c9 100644 --- a/agent/src/session/mod.rs +++ b/agent/src/session/mod.rs @@ -37,9 +37,9 @@ fn show_debug_console() { // No-op on non-Windows platforms } -use crate::proto::{Message, message, ChatMessage, AgentStatus, Heartbeat, HeartbeatAck}; +use crate::proto::{message, AgentStatus, ChatMessage, Heartbeat, HeartbeatAck, Message}; use crate::transport::WebSocketTransport; -use crate::tray::{TrayController, TrayAction}; +use crate::tray::{TrayAction, TrayController}; use anyhow::Result; use std::time::{Duration, Instant}; @@ -71,8 +71,8 @@ pub struct SessionManager { enum SessionState { Disconnected, Connecting, - Idle, // Connected but not streaming - minimal resource usage - Streaming, // Actively capturing and sending frames + Idle, // Connected but not streaming - minimal resource usage + Streaming, // Actively capturing and sending frames } impl SessionManager { @@ -103,10 +103,11 @@ impl SessionManager { &self.config.api_key, Some(&self.hostname), self.config.support_code.as_deref(), - ).await?; + ) + .await?; self.transport = Some(transport); - self.state = SessionState::Idle; // Start in idle mode + self.state = SessionState::Idle; // Start in idle mode tracing::info!("Connected to server, entering idle mode"); @@ -120,8 +121,12 @@ impl SessionManager { } tracing::info!("Initializing streaming resources..."); - tracing::info!("Capture config: use_dxgi={}, gdi_fallback={}, fps={}", - self.config.capture.use_dxgi, self.config.capture.gdi_fallback, self.config.capture.fps); + tracing::info!( + "Capture config: use_dxgi={}, gdi_fallback={}, fps={}", + self.config.capture.use_dxgi, + self.config.capture.gdi_fallback, + self.config.capture.fps + ); // Get primary display with panic protection tracing::debug!("Enumerating displays..."); @@ -132,12 +137,19 @@ impl SessionManager { return Err(anyhow::anyhow!("Display enumeration panicked")); } }; - tracing::info!("Using display: {} ({}x{})", - primary_display.name, primary_display.width, primary_display.height); + tracing::info!( + "Using display: {} ({}x{})", + primary_display.name, + primary_display.width, + primary_display.height + ); // Create capturer with panic protection // Force GDI mode if DXGI fails or panics - tracing::debug!("Creating capturer (DXGI={})...", self.config.capture.use_dxgi); + tracing::debug!( + "Creating capturer (DXGI={})...", + self.config.capture.use_dxgi + ); let capturer = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { capture::create_capturer( primary_display.clone(), @@ -157,13 +169,13 @@ impl SessionManager { tracing::info!("Capturer created successfully"); // Create encoder with panic protection - tracing::debug!("Creating encoder (codec={}, quality={})...", - self.config.encoding.codec, self.config.encoding.quality); + tracing::debug!( + "Creating encoder (codec={}, quality={})...", + self.config.encoding.codec, + self.config.encoding.quality + ); let encoder = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - encoder::create_encoder( - &self.config.encoding.codec, - self.config.encoding.quality, - ) + encoder::create_encoder(&self.config.encoding.codec, self.config.encoding.quality) })) { Ok(result) => result?, Err(e) => { @@ -202,7 +214,9 @@ impl SessionManager { /// Get display count for status reports fn get_display_count(&self) -> i32 { - capture::enumerate_displays().map(|d| d.len() as i32).unwrap_or(1) + capture::enumerate_displays() + .map(|d| d.len() as i32) + .unwrap_or(1) } /// Send agent status to server @@ -249,7 +263,11 @@ impl SessionManager { } /// Run the session main loop with tray and chat event processing - pub async fn run_with_tray(&mut self, tray: Option<&TrayController>, chat: Option<&ChatController>) -> Result<()> { + pub async fn run_with_tray( + &mut self, + tray: Option<&TrayController>, + chat: Option<&ChatController>, + ) -> Result<()> { if self.transport.is_none() { anyhow::bail!("Not connected"); } @@ -395,12 +413,23 @@ impl SessionManager { } // Periodic update check (only for persistent agents, not support sessions) - if self.config.support_code.is_none() && last_update_check.elapsed() >= UPDATE_CHECK_INTERVAL { + if self.config.support_code.is_none() + && last_update_check.elapsed() >= UPDATE_CHECK_INTERVAL + { last_update_check = Instant::now(); - let server_url = self.config.server_url.replace("/ws/agent", "").replace("wss://", "https://").replace("ws://", "http://"); + let server_url = self + .config + .server_url + .replace("/ws/agent", "") + .replace("wss://", "https://") + .replace("ws://", "http://"); match crate::update::check_for_update(&server_url).await { Ok(Some(version_info)) => { - tracing::info!("Update available: {} -> {}", crate::build_info::VERSION, version_info.latest_version); + tracing::info!( + "Update available: {} -> {}", + crate::build_info::VERSION, + version_info.latest_version + ); if let Err(e) = crate::update::perform_update(&version_info).await { tracing::error!("Auto-update failed: {}", e); } @@ -429,7 +458,9 @@ impl SessionManager { if let Ok(encoded) = encoder.encode(&frame) { if encoded.size > 0 { let msg = Message { - payload: Some(message::Payload::VideoFrame(encoded.frame)), + payload: Some(message::Payload::VideoFrame( + encoded.frame, + )), }; let transport = self.transport.as_mut().unwrap(); if let Err(e) = transport.send(msg).await { @@ -472,26 +503,40 @@ impl SessionManager { match msg.payload { Some(message::Payload::MouseEvent(mouse)) => { if let Some(input) = self.input.as_mut() { - use crate::proto::MouseEventType; use crate::input::MouseButton; + use crate::proto::MouseEventType; - match MouseEventType::try_from(mouse.event_type).unwrap_or(MouseEventType::MouseMove) { + match MouseEventType::try_from(mouse.event_type) + .unwrap_or(MouseEventType::MouseMove) + { MouseEventType::MouseMove => { input.mouse_move(mouse.x, mouse.y)?; } MouseEventType::MouseDown => { input.mouse_move(mouse.x, mouse.y)?; if let Some(ref buttons) = mouse.buttons { - if buttons.left { input.mouse_click(MouseButton::Left, true)?; } - if buttons.right { input.mouse_click(MouseButton::Right, true)?; } - if buttons.middle { input.mouse_click(MouseButton::Middle, true)?; } + if buttons.left { + input.mouse_click(MouseButton::Left, true)?; + } + if buttons.right { + input.mouse_click(MouseButton::Right, true)?; + } + if buttons.middle { + input.mouse_click(MouseButton::Middle, true)?; + } } } MouseEventType::MouseUp => { if let Some(ref buttons) = mouse.buttons { - if buttons.left { input.mouse_click(MouseButton::Left, false)?; } - if buttons.right { input.mouse_click(MouseButton::Right, false)?; } - if buttons.middle { input.mouse_click(MouseButton::Middle, false)?; } + if buttons.left { + input.mouse_click(MouseButton::Left, false)?; + } + if buttons.right { + input.mouse_click(MouseButton::Right, false)?; + } + if buttons.middle { + input.mouse_click(MouseButton::Middle, false)?; + } } } MouseEventType::MouseWheel => { @@ -538,10 +583,19 @@ impl SessionManager { tracing::info!("Update command received from server: {}", cmd.reason); // Trigger update check and perform update if available // The server URL is derived from the config - let server_url = self.config.server_url.replace("/ws/agent", "").replace("wss://", "https://").replace("ws://", "http://"); + let server_url = self + .config + .server_url + .replace("/ws/agent", "") + .replace("wss://", "https://") + .replace("ws://", "http://"); match crate::update::check_for_update(&server_url).await { Ok(Some(version_info)) => { - tracing::info!("Update available: {} -> {}", crate::build_info::VERSION, version_info.latest_version); + tracing::info!( + "Update available: {} -> {}", + crate::build_info::VERSION, + version_info.latest_version + ); if let Err(e) = crate::update::perform_update(&version_info).await { tracing::error!("Update failed: {}", e); } diff --git a/agent/src/startup.rs b/agent/src/startup.rs index 728596a..ce8327d 100644 --- a/agent/src/startup.rs +++ b/agent/src/startup.rs @@ -3,15 +3,15 @@ //! Handles adding/removing the agent from Windows startup. use anyhow::Result; -use tracing::{info, warn, error}; +use tracing::{error, info, warn}; -#[cfg(windows)] -use windows::Win32::System::Registry::{ - RegOpenKeyExW, RegSetValueExW, RegDeleteValueW, RegCloseKey, - HKEY_CURRENT_USER, KEY_WRITE, REG_SZ, -}; #[cfg(windows)] use windows::core::PCWSTR; +#[cfg(windows)] +use windows::Win32::System::Registry::{ + RegCloseKey, RegDeleteValueW, RegOpenKeyExW, RegSetValueExW, HKEY_CURRENT_USER, KEY_WRITE, + REG_SZ, +}; const STARTUP_KEY: &str = r"Software\Microsoft\Windows\CurrentVersion\Run"; const STARTUP_VALUE_NAME: &str = "GuruConnect"; @@ -61,10 +61,8 @@ pub fn add_to_startup() -> Result<()> { let hkey_raw = std::mem::transmute::<_, windows::Win32::System::Registry::HKEY>(hkey); // Set the value - let data_bytes = std::slice::from_raw_parts( - value_data.as_ptr() as *const u8, - value_data.len() * 2, - ); + let data_bytes = + std::slice::from_raw_parts(value_data.as_ptr() as *const u8, value_data.len() * 2); let set_result = RegSetValueExW( hkey_raw, @@ -168,7 +166,10 @@ pub fn uninstall() -> Result<()> { ); if result.is_err() { - warn!("Failed to schedule file deletion: {:?}. File may need manual removal.", result); + warn!( + "Failed to schedule file deletion: {:?}. File may need manual removal.", + result + ); } else { info!("Executable scheduled for deletion on reboot"); } @@ -185,12 +186,15 @@ pub fn install_sas_service() -> Result<()> { // Check if the SAS service binary exists alongside the agent let exe_path = std::env::current_exe()?; - let exe_dir = exe_path.parent().ok_or_else(|| anyhow::anyhow!("No parent directory"))?; + let exe_dir = exe_path + .parent() + .ok_or_else(|| anyhow::anyhow!("No parent directory"))?; let sas_binary = exe_dir.join("guruconnect-sas-service.exe"); if !sas_binary.exists() { // Also check in Program Files - let program_files = std::path::PathBuf::from(r"C:\Program Files\GuruConnect\guruconnect-sas-service.exe"); + let program_files = + std::path::PathBuf::from(r"C:\Program Files\GuruConnect\guruconnect-sas-service.exe"); if !program_files.exists() { warn!("SAS service binary not found"); return Ok(()); @@ -232,16 +236,18 @@ pub fn uninstall_sas_service() -> Result<()> { // Try to find and run the uninstall command let paths = [ - std::env::current_exe().ok().and_then(|p| p.parent().map(|d| d.join("guruconnect-sas-service.exe"))), - Some(std::path::PathBuf::from(r"C:\Program Files\GuruConnect\guruconnect-sas-service.exe")), + std::env::current_exe() + .ok() + .and_then(|p| p.parent().map(|d| d.join("guruconnect-sas-service.exe"))), + Some(std::path::PathBuf::from( + r"C:\Program Files\GuruConnect\guruconnect-sas-service.exe", + )), ]; for path_opt in paths.iter() { if let Some(ref path) = path_opt { if path.exists() { - let output = std::process::Command::new(path) - .arg("uninstall") - .output(); + let output = std::process::Command::new(path).arg("uninstall").output(); if let Ok(result) = output { if result.status.success() { diff --git a/agent/src/transport/websocket.rs b/agent/src/transport/websocket.rs index c4d5bbe..33ad74f 100644 --- a/agent/src/transport/websocket.rs +++ b/agent/src/transport/websocket.rs @@ -103,11 +103,7 @@ impl WebSocketTransport { let mut stream = stream.lock().await; // Use try_next for non-blocking receive - match tokio::time::timeout( - std::time::Duration::from_millis(1), - stream.next(), - ) - .await + match tokio::time::timeout(std::time::Duration::from_millis(1), stream.next()).await { Ok(Some(Ok(ws_msg))) => Ok(Some(ws_msg)), Ok(Some(Err(e))) => Err(anyhow::anyhow!("WebSocket error: {}", e)), diff --git a/agent/src/tray/mod.rs b/agent/src/tray/mod.rs index 03d8049..2e3bf3c 100644 --- a/agent/src/tray/mod.rs +++ b/agent/src/tray/mod.rs @@ -9,12 +9,12 @@ use anyhow::Result; use muda::{Menu, MenuEvent, MenuItem, PredefinedMenuItem, Submenu}; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use tray_icon::{Icon, TrayIcon, TrayIconBuilder, TrayIconEvent}; use tracing::{info, warn}; +use tray_icon::{Icon, TrayIcon, TrayIconBuilder, TrayIconEvent}; #[cfg(windows)] use windows::Win32::UI::WindowsAndMessaging::{ - PeekMessageW, TranslateMessage, DispatchMessageW, MSG, PM_REMOVE, + DispatchMessageW, PeekMessageW, TranslateMessage, MSG, PM_REMOVE, }; /// Events that can be triggered from the tray menu @@ -38,7 +38,11 @@ pub struct TrayController { impl TrayController { /// Create a new tray controller /// `allow_end_session` - If true, show "End Session" menu item (only for support sessions) - pub fn new(machine_name: &str, support_code: Option<&str>, allow_end_session: bool) -> Result { + pub fn new( + machine_name: &str, + support_code: Option<&str>, + allow_end_session: bool, + ) -> Result { // Create menu items let status_text = if let Some(code) = support_code { format!("Support Session: {}", code) @@ -166,9 +170,9 @@ fn create_default_icon() -> Result { if dist <= radius { // Green circle - rgba[idx] = 76; // R + rgba[idx] = 76; // R rgba[idx + 1] = 175; // G - rgba[idx + 2] = 80; // B + rgba[idx + 2] = 80; // B rgba[idx + 3] = 255; // A } else if dist <= radius + 1.0 { // Anti-aliased edge diff --git a/agent/src/update.rs b/agent/src/update.rs index ea9785c..a1bf768 100644 --- a/agent/src/update.rs +++ b/agent/src/update.rs @@ -4,9 +4,9 @@ //! in-place binary replacement with restart. use anyhow::{anyhow, Result}; -use sha2::{Sha256, Digest}; +use sha2::{Digest, Sha256}; use std::path::PathBuf; -use tracing::{info, warn, error}; +use tracing::{error, info, warn}; use crate::build_info; @@ -38,7 +38,7 @@ pub async fn check_for_update(server_base_url: &str) -> Result bool { let available_clean = available.split('-').next().unwrap_or(available); let current_clean = current.split('-').next().unwrap_or(current); - let parse_version = |s: &str| -> Vec { - s.split('.') - .filter_map(|p| p.parse().ok()) - .collect() - }; + let parse_version = + |s: &str| -> Vec { s.split('.').filter_map(|p| p.parse().ok()).collect() }; let av = parse_version(available_clean); let cv = parse_version(current_clean); @@ -112,7 +109,7 @@ pub async fn download_update(version_info: &VersionInfo) -> Result { let response = client .get(&version_info.download_url) - .timeout(std::time::Duration::from_secs(300)) // 5 minutes for large files + .timeout(std::time::Duration::from_secs(300)) // 5 minutes for large files .send() .await?; @@ -147,7 +144,10 @@ pub fn verify_checksum(file_path: &PathBuf, expected_sha256: &str) -> Result Result { // Get current executable path let current_exe = std::env::current_exe()?; - let exe_dir = current_exe.parent() + let exe_dir = current_exe + .parent() .ok_or_else(|| anyhow!("Cannot get executable directory"))?; // Create paths for backup and new executable @@ -257,10 +258,11 @@ pub fn cleanup_post_update() { #[cfg(windows)] fn schedule_delete_on_reboot(path: &PathBuf) { use std::os::windows::ffi::OsStrExt; - use windows::Win32::Storage::FileSystem::{MoveFileExW, MOVEFILE_DELAY_UNTIL_REBOOT}; use windows::core::PCWSTR; + use windows::Win32::Storage::FileSystem::{MoveFileExW, MOVEFILE_DELAY_UNTIL_REBOOT}; - let path_wide: Vec = path.as_os_str() + let path_wide: Vec = path + .as_os_str() .encode_wide() .chain(std::iter::once(0)) .collect(); diff --git a/agent/src/viewer/input.rs b/agent/src/viewer/input.rs index 553c951..1c9e50c 100644 --- a/agent/src/viewer/input.rs +++ b/agent/src/viewer/input.rs @@ -37,11 +37,11 @@ mod vk { pub const VK_RSHIFT: u32 = 0xA1; pub const VK_LCONTROL: u32 = 0xA2; pub const VK_RCONTROL: u32 = 0xA3; - pub const VK_LMENU: u32 = 0xA4; // Left Alt - pub const VK_RMENU: u32 = 0xA5; // Right Alt + pub const VK_LMENU: u32 = 0xA4; // Left Alt + pub const VK_RMENU: u32 = 0xA5; // Right Alt pub const VK_TAB: u32 = 0x09; pub const VK_ESCAPE: u32 = 0x1B; - pub const VK_SNAPSHOT: u32 = 0x2C; // Print Screen + pub const VK_SNAPSHOT: u32 = 0x2C; // Print Screen } #[cfg(windows)] @@ -53,15 +53,12 @@ pub struct KeyboardHook { impl KeyboardHook { pub fn new(input_tx: mpsc::Sender) -> Result { // Store the sender globally for the hook callback - INPUT_TX.set(input_tx).map_err(|_| anyhow::anyhow!("Input TX already set"))?; + INPUT_TX + .set(input_tx) + .map_err(|_| anyhow::anyhow!("Input TX already set"))?; unsafe { - let hook = SetWindowsHookExW( - WH_KEYBOARD_LL, - Some(keyboard_hook_proc), - None, - 0, - )?; + let hook = SetWindowsHookExW(WH_KEYBOARD_LL, Some(keyboard_hook_proc), None, 0)?; HOOK_HANDLE = hook; Ok(Self { _hook: hook }) @@ -82,11 +79,7 @@ impl Drop for KeyboardHook { } #[cfg(windows)] -unsafe extern "system" fn keyboard_hook_proc( - code: i32, - wparam: WPARAM, - lparam: LPARAM, -) -> LRESULT { +unsafe extern "system" fn keyboard_hook_proc(code: i32, wparam: WPARAM, lparam: LPARAM) -> LRESULT { if code >= 0 { let kb_struct = &*(lparam.0 as *const KBDLLHOOKSTRUCT); let vk_code = kb_struct.vkCode; @@ -97,10 +90,7 @@ unsafe extern "system" fn keyboard_hook_proc( if is_down || is_up { // Check if this is a key we want to intercept (Win key, Alt+Tab, etc.) - let should_intercept = matches!( - vk_code, - vk::VK_LWIN | vk::VK_RWIN | vk::VK_APPS - ); + let should_intercept = matches!(vk_code, vk::VK_LWIN | vk::VK_RWIN | vk::VK_APPS); // Send the key event to the remote if let Some(tx) = INPUT_TX.get() { @@ -114,7 +104,12 @@ unsafe extern "system" fn keyboard_hook_proc( }; let _ = tx.try_send(InputEvent::Key(event)); - trace!("Key hook: vk={:#x} scan={} down={}", vk_code, scan_code, is_down); + trace!( + "Key hook: vk={:#x} scan={} down={}", + vk_code, + scan_code, + is_down + ); } // For Win key, consume the event so it doesn't open Start menu locally @@ -133,12 +128,12 @@ fn get_current_modifiers() -> proto::Modifiers { unsafe { proto::Modifiers { - ctrl: GetAsyncKeyState(0x11) < 0, // VK_CONTROL - alt: GetAsyncKeyState(0x12) < 0, // VK_MENU - shift: GetAsyncKeyState(0x10) < 0, // VK_SHIFT + ctrl: GetAsyncKeyState(0x11) < 0, // VK_CONTROL + alt: GetAsyncKeyState(0x12) < 0, // VK_MENU + shift: GetAsyncKeyState(0x10) < 0, // VK_SHIFT meta: GetAsyncKeyState(0x5B) < 0 || GetAsyncKeyState(0x5C) < 0, // VK_LWIN/RWIN caps_lock: GetAsyncKeyState(0x14) & 1 != 0, // VK_CAPITAL - num_lock: GetAsyncKeyState(0x90) & 1 != 0, // VK_NUMLOCK + num_lock: GetAsyncKeyState(0x90) & 1 != 0, // VK_NUMLOCK } } } diff --git a/agent/src/viewer/mod.rs b/agent/src/viewer/mod.rs index 44315e1..ff6fe10 100644 --- a/agent/src/viewer/mod.rs +++ b/agent/src/viewer/mod.rs @@ -11,7 +11,7 @@ use crate::proto; use anyhow::Result; use std::sync::Arc; use tokio::sync::{mpsc, Mutex}; -use tracing::{info, error, warn}; +use tracing::{error, info, warn}; #[derive(Debug, Clone)] pub enum ViewerEvent { @@ -93,16 +93,18 @@ pub async fn run(server_url: &str, session_id: &str, api_key: &str) -> Result<() } } Some(proto::message::Payload::CursorPosition(pos)) => { - let _ = viewer_tx_recv.send(ViewerEvent::CursorPosition( - pos.x, pos.y, pos.visible - )).await; + let _ = viewer_tx_recv + .send(ViewerEvent::CursorPosition(pos.x, pos.y, pos.visible)) + .await; } Some(proto::message::Payload::CursorShape(shape)) => { let _ = viewer_tx_recv.send(ViewerEvent::CursorShape(shape)).await; } Some(proto::message::Payload::Disconnect(d)) => { warn!("Server disconnected: {}", d.reason); - let _ = viewer_tx_recv.send(ViewerEvent::Disconnected(d.reason)).await; + let _ = viewer_tx_recv + .send(ViewerEvent::Disconnected(d.reason)) + .await; break; } _ => {} diff --git a/agent/src/viewer/render.rs b/agent/src/viewer/render.rs index eee9a63..ea76cc0 100644 --- a/agent/src/viewer/render.rs +++ b/agent/src/viewer/render.rs @@ -1,9 +1,9 @@ //! Window rendering and frame display -use super::{ViewerEvent, InputEvent}; -use crate::proto; #[cfg(windows)] use super::input; +use super::{InputEvent, ViewerEvent}; +use crate::proto; use anyhow::Result; use std::num::NonZeroU32; use std::sync::Arc; @@ -43,10 +43,7 @@ struct ViewerApp { } impl ViewerApp { - fn new( - viewer_rx: mpsc::Receiver, - input_tx: mpsc::Sender, - ) -> Self { + fn new(viewer_rx: mpsc::Receiver, input_tx: mpsc::Sender) -> Self { Self { window: None, surface: None, @@ -112,7 +109,9 @@ impl ViewerApp { } fn render(&mut self) { - let Some(surface) = &mut self.surface else { return }; + let Some(surface) = &mut self.surface else { + return; + }; let Some(window) = &self.window else { return }; if self.frame_buffer.is_empty() || self.frame_width == 0 || self.frame_height == 0 { diff --git a/agent/src/viewer/transport.rs b/agent/src/viewer/transport.rs index 8826a8e..57eb6aa 100644 --- a/agent/src/viewer/transport.rs +++ b/agent/src/viewer/transport.rs @@ -6,23 +6,17 @@ use bytes::Bytes; use futures_util::{SinkExt, StreamExt}; use prost::Message as ProstMessage; use std::sync::Arc; +use tokio::net::TcpStream; use tokio::sync::Mutex; use tokio_tungstenite::{ - connect_async, - tungstenite::protocol::Message as WsMessage, - MaybeTlsStream, WebSocketStream, + connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream, }; -use tokio::net::TcpStream; use tracing::{debug, error, trace}; -pub type WsSender = futures_util::stream::SplitSink< - WebSocketStream>, - WsMessage, ->; +pub type WsSender = + futures_util::stream::SplitSink>, WsMessage>; -pub type WsReceiver = futures_util::stream::SplitStream< - WebSocketStream>, ->; +pub type WsReceiver = futures_util::stream::SplitStream>>; /// Receiver wrapper that parses protobuf messages pub struct MessageReceiver { @@ -88,10 +82,7 @@ pub async fn connect(url: &str, token: &str) -> Result<(WsSender, MessageReceive } /// Send a protobuf message over the WebSocket -pub async fn send_message( - sender: &Arc>, - msg: &proto::Message, -) -> Result<()> { +pub async fn send_message(sender: &Arc>, msg: &proto::Message) -> Result<()> { let mut buf = Vec::with_capacity(msg.encoded_len()); msg.encode(&mut buf)?; diff --git a/server/src/api/auth.rs b/server/src/api/auth.rs index de1a933..9e759e3 100644 --- a/server/src/api/auth.rs +++ b/server/src/api/auth.rs @@ -1,15 +1,13 @@ //! Authentication API endpoints use axum::{ - extract::{State, Request}, + extract::{Request, State}, http::StatusCode, Json, }; use serde::{Deserialize, Serialize}; -use crate::auth::{ - verify_password, AuthenticatedUser, JwtConfig, -}; +use crate::auth::{verify_password, AuthenticatedUser, JwtConfig}; use crate::db; use crate::AppState; @@ -89,16 +87,15 @@ pub async fn login( } // Verify password - let password_valid = verify_password(&request.password, &user.password_hash) - .map_err(|e| { - tracing::error!("Password verification error: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: "Internal server error".to_string(), - }), - ) - })?; + let password_valid = verify_password(&request.password, &user.password_hash).map_err(|e| { + tracing::error!("Password verification error: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Internal server error".to_string(), + }), + ) + })?; if !password_valid { return Err(( @@ -118,21 +115,18 @@ pub async fn login( let _ = db::update_last_login(db.pool(), user.id).await; // Create JWT token - let token = state.jwt_config.create_token( - user.id, - &user.username, - &user.role, - permissions.clone(), - ) - .map_err(|e| { - tracing::error!("Token creation error: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: "Failed to create token".to_string(), - }), - ) - })?; + let token = state + .jwt_config + .create_token(user.id, &user.username, &user.role, permissions.clone()) + .map_err(|e| { + tracing::error!("Token creation error: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to create token".to_string(), + }), + ) + })?; tracing::info!("User {} logged in successfully", user.username); @@ -288,16 +282,15 @@ pub async fn change_password( } // Hash new password - let new_hash = crate::auth::hash_password(&request.new_password) - .map_err(|e| { - tracing::error!("Password hashing error: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: "Failed to hash password".to_string(), - }), - ) - })?; + let new_hash = crate::auth::hash_password(&request.new_password).map_err(|e| { + tracing::error!("Password hashing error: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to hash password".to_string(), + }), + ) + })?; // Update password db::update_user_password(db.pool(), user_id, &new_hash) diff --git a/server/src/api/auth_logout.rs b/server/src/api/auth_logout.rs index 2270fa4..f8b095a 100644 --- a/server/src/api/auth_logout.rs +++ b/server/src/api/auth_logout.rs @@ -1,13 +1,13 @@ //! Logout and token revocation endpoints use axum::{ - extract::{Request, State, Path}, - http::{StatusCode, HeaderMap}, + extract::{Path, Request, State}, + http::{HeaderMap, StatusCode}, Json, }; -use uuid::Uuid; use serde::Serialize; use tracing::{info, warn}; +use uuid::Uuid; use crate::auth::AuthenticatedUser; use crate::AppState; @@ -15,7 +15,9 @@ use crate::AppState; use super::auth::ErrorResponse; /// Extract JWT token from Authorization header -fn extract_token_from_headers(headers: &HeaderMap) -> Result)> { +fn extract_token_from_headers( + headers: &HeaderMap, +) -> Result)> { let auth_header = headers .get("Authorization") .and_then(|v| v.to_str().ok()) @@ -28,16 +30,14 @@ fn extract_token_from_headers(headers: &HeaderMap) -> Result impl IntoResponse { .header(header::CONTENT_TYPE, "application/octet-stream") .header( header::CONTENT_DISPOSITION, - "attachment; filename=\"GuruConnect-Viewer.exe\"" + "attachment; filename=\"GuruConnect-Viewer.exe\"", ) .header(header::CONTENT_LENGTH, binary_data.len()) .body(Body::from(binary_data)) @@ -104,9 +104,7 @@ pub async fn download_viewer() -> impl IntoResponse { } /// Download support session binary (code embedded in filename) -pub async fn download_support( - Query(params): Query, -) -> impl IntoResponse { +pub async fn download_support(Query(params): Query) -> impl IntoResponse { // Validate support code (must be 6 digits) let code = params.code.trim(); if code.len() != 6 || !code.chars().all(|c| c.is_ascii_digit()) { @@ -120,7 +118,11 @@ pub async fn download_support( match std::fs::read(&binary_path) { Ok(binary_data) => { - info!("Serving support session download for code {} ({} bytes)", code, binary_data.len()); + info!( + "Serving support session download for code {} ({} bytes)", + code, + binary_data.len() + ); // Filename includes the support code let filename = format!("GuruConnect-{}.exe", code); @@ -130,7 +132,7 @@ pub async fn download_support( .header(header::CONTENT_TYPE, "application/octet-stream") .header( header::CONTENT_DISPOSITION, - format!("attachment; filename=\"{}\"", filename) + format!("attachment; filename=\"{}\"", filename), ) .header(header::CONTENT_LENGTH, binary_data.len()) .body(Body::from(binary_data)) @@ -147,9 +149,7 @@ pub async fn download_support( } /// Download permanent agent binary with embedded configuration -pub async fn download_agent( - Query(params): Query, -) -> impl IntoResponse { +pub async fn download_agent(Query(params): Query) -> impl IntoResponse { let binary_path = get_base_binary_path(); // Read base binary @@ -167,10 +167,13 @@ pub async fn download_agent( // Build embedded config let config = EmbeddedConfig { server_url: "wss://connect.azcomputerguru.com/ws/agent".to_string(), - api_key: params.api_key.unwrap_or_else(|| "managed-agent".to_string()), + api_key: params + .api_key + .unwrap_or_else(|| "managed-agent".to_string()), company: params.company.clone(), site: params.site.clone(), - tags: params.tags + tags: params + .tags .as_ref() .map(|t| t.split(',').map(|s| s.trim().to_string()).collect()) .unwrap_or_default(), @@ -196,18 +199,25 @@ pub async fn download_agent( info!( "Serving permanent agent download: company={:?}, site={:?}, tags={:?} ({} bytes)", - config.company, config.site, config.tags, binary_data.len() + config.company, + config.site, + config.tags, + binary_data.len() ); // Generate filename based on company/site let filename = match (¶ms.company, ¶ms.site) { (Some(company), Some(site)) => { - format!("GuruConnect-{}-{}-Setup.exe", sanitize_filename(company), sanitize_filename(site)) + format!( + "GuruConnect-{}-{}-Setup.exe", + sanitize_filename(company), + sanitize_filename(site) + ) } (Some(company), None) => { format!("GuruConnect-{}-Setup.exe", sanitize_filename(company)) } - _ => "GuruConnect-Setup.exe".to_string() + _ => "GuruConnect-Setup.exe".to_string(), }; Response::builder() @@ -215,7 +225,7 @@ pub async fn download_agent( .header(header::CONTENT_TYPE, "application/octet-stream") .header( header::CONTENT_DISPOSITION, - format!("attachment; filename=\"{}\"", filename) + format!("attachment; filename=\"{}\"", filename), ) .header(header::CONTENT_LENGTH, binary_data.len()) .body(Body::from(binary_data)) diff --git a/server/src/api/mod.rs b/server/src/api/mod.rs index 2deb5e3..7cd0738 100644 --- a/server/src/api/mod.rs +++ b/server/src/api/mod.rs @@ -3,19 +3,19 @@ pub mod auth; pub mod auth_logout; pub mod changelog; -pub mod users; -pub mod releases; pub mod downloads; +pub mod releases; +pub mod users; use axum::{ - extract::{Path, State, Query}, + extract::{Path, Query, State}, Json, }; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use crate::session::SessionManager; use crate::db; +use crate::session::SessionManager; /// Viewer info returned by API #[derive(Debug, Serialize)] @@ -78,9 +78,7 @@ impl From for SessionInfo { } /// List all active sessions -pub async fn list_sessions( - State(sessions): State, -) -> Json> { +pub async fn list_sessions(State(sessions): State) -> Json> { let sessions = sessions.list_sessions().await; Json(sessions.into_iter().map(SessionInfo::from).collect()) } @@ -93,7 +91,9 @@ pub async fn get_session( let session_id = Uuid::parse_str(&id) .map_err(|_| (axum::http::StatusCode::BAD_REQUEST, "Invalid session ID"))?; - let session = sessions.get_session(session_id).await + let session = sessions + .get_session(session_id) + .await .ok_or((axum::http::StatusCode::NOT_FOUND, "Session not found"))?; Ok(Json(SessionInfo::from(session))) diff --git a/server/src/api/releases.rs b/server/src/api/releases.rs index 2e2b9ba..9d86445 100644 --- a/server/src/api/releases.rs +++ b/server/src/api/releases.rs @@ -129,17 +129,15 @@ pub async fn list_releases( ) })?; - let releases = db::get_all_releases(db.pool()) - .await - .map_err(|e| { - tracing::error!("Database error: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: "Failed to fetch releases".to_string(), - }), - ) - })?; + let releases = db::get_all_releases(db.pool()).await.map_err(|e| { + tracing::error!("Database error: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to fetch releases".to_string(), + }), + ) + })?; Ok(Json(releases.into_iter().map(ReleaseInfo::from).collect())) } @@ -171,7 +169,10 @@ pub async fn create_release( // Validate checksum format (64 hex chars for SHA-256) if request.checksum_sha256.len() != 64 - || !request.checksum_sha256.chars().all(|c| c.is_ascii_hexdigit()) + || !request + .checksum_sha256 + .chars() + .all(|c| c.is_ascii_hexdigit()) { return Err(( StatusCode::BAD_REQUEST, @@ -349,17 +350,15 @@ pub async fn delete_release( ) })?; - let deleted = db::delete_release(db.pool(), &version) - .await - .map_err(|e| { - tracing::error!("Database error: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: "Failed to delete release".to_string(), - }), - ) - })?; + let deleted = db::delete_release(db.pool(), &version).await.map_err(|e| { + tracing::error!("Database error: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to delete release".to_string(), + }), + ) + })?; if deleted { tracing::info!("Deleted release: {}", version); diff --git a/server/src/api/users.rs b/server/src/api/users.rs index 895717c..cfa5f51 100644 --- a/server/src/api/users.rs +++ b/server/src/api/users.rs @@ -72,17 +72,15 @@ pub async fn list_users( ) })?; - let users = db::get_all_users(db.pool()) - .await - .map_err(|e| { - tracing::error!("Database error: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: "Failed to fetch users".to_string(), - }), - ) - })?; + let users = db::get_all_users(db.pool()).await.map_err(|e| { + tracing::error!("Database error: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to fetch users".to_string(), + }), + ) + })?; let mut result = Vec::new(); for user in users { @@ -210,7 +208,13 @@ pub async fn create_user( } else { // Default permissions based on role let default_perms = match request.role.as_str() { - "admin" => vec!["view", "control", "transfer", "manage_users", "manage_clients"], + "admin" => vec![ + "view", + "control", + "transfer", + "manage_users", + "manage_clients", + ], "operator" => vec!["view", "control", "transfer"], "viewer" => vec!["view"], _ => vec!["view"], @@ -455,17 +459,15 @@ pub async fn delete_user( )); } - let deleted = db::delete_user(db.pool(), user_id) - .await - .map_err(|e| { - tracing::error!("Database error: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: "Failed to delete user".to_string(), - }), - ) - })?; + let deleted = db::delete_user(db.pool(), user_id).await.map_err(|e| { + tracing::error!("Database error: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(ErrorResponse { + error: "Failed to delete user".to_string(), + }), + ) + })?; if deleted { tracing::info!("Deleted user: {}", id); @@ -506,13 +508,22 @@ pub async fn set_permissions( })?; // Validate permissions - let valid_permissions = ["view", "control", "transfer", "manage_users", "manage_clients"]; + let valid_permissions = [ + "view", + "control", + "transfer", + "manage_users", + "manage_clients", + ]; for perm in &request.permissions { if !valid_permissions.contains(&perm.as_str()) { return Err(( StatusCode::BAD_REQUEST, Json(ErrorResponse { - error: format!("Invalid permission: {}. Valid: {:?}", perm, valid_permissions), + error: format!( + "Invalid permission: {}. Valid: {:?}", + perm, valid_permissions + ), }), )); } diff --git a/server/src/auth/jwt.rs b/server/src/auth/jwt.rs index 62202f4..ce8562f 100644 --- a/server/src/auth/jwt.rs +++ b/server/src/auth/jwt.rs @@ -54,7 +54,10 @@ pub struct JwtConfig { impl JwtConfig { /// Create new JWT config pub fn new(secret: String, expiry_hours: i64) -> Self { - Self { secret, expiry_hours } + Self { + secret, + expiry_hours, + } } /// Create a JWT token for a user @@ -97,9 +100,9 @@ impl JwtConfig { pub fn validate_token(&self, token: &str) -> Result { // SEC-13: Explicit validation configuration let mut validation = Validation::default(); - validation.validate_exp = true; // Enforce expiration check + validation.validate_exp = true; // Enforce expiration check validation.validate_nbf = false; // Not using "not before" claim - validation.leeway = 0; // No clock skew tolerance + validation.leeway = 0; // No clock skew tolerance let token_data = decode::( token, @@ -129,12 +132,14 @@ mod tests { let config = JwtConfig::new("test-secret".to_string(), 24); let user_id = Uuid::new_v4(); - let token = config.create_token( - user_id, - "testuser", - "admin", - vec!["view".to_string(), "control".to_string()], - ).unwrap(); + let token = config + .create_token( + user_id, + "testuser", + "admin", + vec!["view".to_string(), "control".to_string()], + ) + .unwrap(); let claims = config.validate_token(&token).unwrap(); assert_eq!(claims.username, "testuser"); diff --git a/server/src/auth/mod.rs b/server/src/auth/mod.rs index 5682ca8..df924d4 100644 --- a/server/src/auth/mod.rs +++ b/server/src/auth/mod.rs @@ -8,7 +8,7 @@ pub mod password; pub mod token_blacklist; pub use jwt::{Claims, JwtConfig}; -pub use password::{hash_password, verify_password, generate_random_password}; +pub use password::{generate_random_password, hash_password, verify_password}; pub use token_blacklist::TokenBlacklist; use axum::{ diff --git a/server/src/auth/password.rs b/server/src/auth/password.rs index 7bb1c2f..34adb0c 100644 --- a/server/src/auth/password.rs +++ b/server/src/auth/password.rs @@ -6,7 +6,7 @@ use anyhow::{anyhow, Result}; use argon2::{ password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, - Argon2, Algorithm, Version, Params, + Algorithm, Argon2, Params, Version, }; /// Hash a password using Argon2id @@ -22,9 +22,9 @@ pub fn hash_password(password: &str) -> Result { // Explicitly use Argon2id (Algorithm::Argon2id) let argon2 = Argon2::new( - Algorithm::Argon2id, // SEC-9: Explicit Argon2id variant - Version::V0x13, // Latest version - Params::default(), // Default params (19456 KiB, 2 iterations, 1 parallelism) + Algorithm::Argon2id, // SEC-9: Explicit Argon2id variant + Version::V0x13, // Latest version + Params::default(), // Default params (19456 KiB, 2 iterations, 1 parallelism) ); let hash = argon2 @@ -35,12 +35,14 @@ pub fn hash_password(password: &str) -> Result { /// Verify a password against a stored hash pub fn verify_password(password: &str, hash: &str) -> Result { - let parsed_hash = PasswordHash::new(hash) - .map_err(|e| anyhow!("Invalid password hash format: {}", e))?; + let parsed_hash = + PasswordHash::new(hash).map_err(|e| anyhow!("Invalid password hash format: {}", e))?; // Argon2::default() uses Argon2id, but we verify against the hash's embedded algorithm let argon2 = Argon2::default(); - Ok(argon2.verify_password(password.as_bytes(), &parsed_hash).is_ok()) + Ok(argon2 + .verify_password(password.as_bytes(), &parsed_hash) + .is_ok()) } /// Generate a random password (for initial admin) diff --git a/server/src/auth/token_blacklist.rs b/server/src/auth/token_blacklist.rs index c392be7..8b6a8e2 100644 --- a/server/src/auth/token_blacklist.rs +++ b/server/src/auth/token_blacklist.rs @@ -6,7 +6,7 @@ use std::collections::HashSet; use std::sync::Arc; use tokio::sync::RwLock; -use tracing::{info, debug}; +use tracing::{debug, info}; /// Token blacklist for revocation /// @@ -41,7 +41,10 @@ impl TokenBlacklist { let was_new = tokens.insert(token.to_string()); if was_new { - debug!("Token revoked and added to blacklist (length: {})", token.len()); + debug!( + "Token revoked and added to blacklist (length: {})", + token.len() + ); } } @@ -92,7 +95,11 @@ impl TokenBlacklist { let removed = original_len - tokens.len(); if removed > 0 { - info!("Cleaned {} expired tokens from blacklist ({} remaining)", removed, tokens.len()); + info!( + "Cleaned {} expired tokens from blacklist ({} remaining)", + removed, + tokens.len() + ); } removed diff --git a/server/src/db/events.rs b/server/src/db/events.rs index 87c120f..f5bd971 100644 --- a/server/src/db/events.rs +++ b/server/src/db/events.rs @@ -36,8 +36,10 @@ impl EventTypes { pub const CONNECTION_REJECTED_NO_AUTH: &'static str = "connection_rejected_no_auth"; pub const CONNECTION_REJECTED_INVALID_CODE: &'static str = "connection_rejected_invalid_code"; pub const CONNECTION_REJECTED_EXPIRED_CODE: &'static str = "connection_rejected_expired_code"; - pub const CONNECTION_REJECTED_INVALID_API_KEY: &'static str = "connection_rejected_invalid_api_key"; - pub const CONNECTION_REJECTED_CANCELLED_CODE: &'static str = "connection_rejected_cancelled_code"; + pub const CONNECTION_REJECTED_INVALID_API_KEY: &'static str = + "connection_rejected_invalid_api_key"; + pub const CONNECTION_REJECTED_CANCELLED_CODE: &'static str = + "connection_rejected_cancelled_code"; } /// Log a session event diff --git a/server/src/db/machines.rs b/server/src/db/machines.rs index fd44a9b..b5cd52f 100644 --- a/server/src/db/machines.rs +++ b/server/src/db/machines.rs @@ -80,7 +80,7 @@ pub async fn update_machine_status( /// Get all persistent machines (for restore on startup) pub async fn get_all_machines(pool: &PgPool) -> Result, sqlx::Error> { sqlx::query_as::<_, Machine>( - "SELECT * FROM connect_machines WHERE is_persistent = true ORDER BY hostname" + "SELECT * FROM connect_machines WHERE is_persistent = true ORDER BY hostname", ) .fetch_all(pool) .await @@ -91,20 +91,20 @@ pub async fn get_machine_by_agent_id( pool: &PgPool, agent_id: &str, ) -> Result, sqlx::Error> { - sqlx::query_as::<_, Machine>( - "SELECT * FROM connect_machines WHERE agent_id = $1" - ) - .bind(agent_id) - .fetch_optional(pool) - .await + sqlx::query_as::<_, Machine>("SELECT * FROM connect_machines WHERE agent_id = $1") + .bind(agent_id) + .fetch_optional(pool) + .await } /// Mark machine as offline pub async fn mark_machine_offline(pool: &PgPool, agent_id: &str) -> Result<(), sqlx::Error> { - sqlx::query("UPDATE connect_machines SET status = 'offline', last_seen = NOW() WHERE agent_id = $1") - .bind(agent_id) - .execute(pool) - .await?; + sqlx::query( + "UPDATE connect_machines SET status = 'offline', last_seen = NOW() WHERE agent_id = $1", + ) + .bind(agent_id) + .execute(pool) + .await?; Ok(()) } diff --git a/server/src/db/mod.rs b/server/src/db/mod.rs index 3fa9559..69ecf61 100644 --- a/server/src/db/mod.rs +++ b/server/src/db/mod.rs @@ -3,24 +3,24 @@ //! Handles persistence for machines, sessions, and audit logging. //! Optional - server works without database if DATABASE_URL not set. -pub mod machines; -pub mod sessions; pub mod events; +pub mod machines; +pub mod releases; +pub mod sessions; pub mod support_codes; pub mod users; -pub mod releases; use anyhow::Result; use sqlx::postgres::PgPoolOptions; use sqlx::PgPool; use tracing::info; -pub use machines::*; -pub use sessions::*; pub use events::*; +pub use machines::*; +pub use releases::*; +pub use sessions::*; pub use support_codes::*; pub use users::*; -pub use releases::*; /// Database connection pool wrapper #[derive(Clone)] diff --git a/server/src/db/sessions.rs b/server/src/db/sessions.rs index 488af14..2dbf0e1 100644 --- a/server/src/db/sessions.rs +++ b/server/src/db/sessions.rs @@ -45,7 +45,7 @@ pub async fn create_session( pub async fn end_session( pool: &PgPool, session_id: Uuid, - status: &str, // 'ended' or 'disconnected' or 'timeout' + status: &str, // 'ended' or 'disconnected' or 'timeout' ) -> Result<(), sqlx::Error> { sqlx::query( r#" @@ -64,7 +64,10 @@ pub async fn end_session( } /// Get session by ID -pub async fn get_session(pool: &PgPool, session_id: Uuid) -> Result, sqlx::Error> { +pub async fn get_session( + pool: &PgPool, + session_id: Uuid, +) -> Result, sqlx::Error> { sqlx::query_as::<_, DbSession>("SELECT * FROM connect_sessions WHERE id = $1") .bind(session_id) .fetch_optional(pool) @@ -85,12 +88,9 @@ pub async fn get_active_sessions_for_machine( } /// Get recent sessions (for dashboard) -pub async fn get_recent_sessions( - pool: &PgPool, - limit: i64, -) -> Result, sqlx::Error> { +pub async fn get_recent_sessions(pool: &PgPool, limit: i64) -> Result, sqlx::Error> { sqlx::query_as::<_, DbSession>( - "SELECT * FROM connect_sessions ORDER BY started_at DESC LIMIT $1" + "SELECT * FROM connect_sessions ORDER BY started_at DESC LIMIT $1", ) .bind(limit) .fetch_all(pool) @@ -103,7 +103,7 @@ pub async fn get_sessions_for_machine( machine_id: Uuid, ) -> Result, sqlx::Error> { sqlx::query_as::<_, DbSession>( - "SELECT * FROM connect_sessions WHERE machine_id = $1 ORDER BY started_at DESC" + "SELECT * FROM connect_sessions WHERE machine_id = $1 ORDER BY started_at DESC", ) .bind(machine_id) .fetch_all(pool) diff --git a/server/src/db/support_codes.rs b/server/src/db/support_codes.rs index 41b38eb..b6f4d15 100644 --- a/server/src/db/support_codes.rs +++ b/server/src/db/support_codes.rs @@ -40,13 +40,14 @@ pub async fn create_support_code( } /// Get support code by code string -pub async fn get_support_code(pool: &PgPool, code: &str) -> Result, sqlx::Error> { - sqlx::query_as::<_, DbSupportCode>( - "SELECT * FROM connect_support_codes WHERE code = $1" - ) - .bind(code) - .fetch_optional(pool) - .await +pub async fn get_support_code( + pool: &PgPool, + code: &str, +) -> Result, sqlx::Error> { + sqlx::query_as::<_, DbSupportCode>("SELECT * FROM connect_support_codes WHERE code = $1") + .bind(code) + .fetch_optional(pool) + .await } /// Update support code when client connects @@ -107,7 +108,7 @@ pub async fn get_active_support_codes(pool: &PgPool) -> Result Result { let result = sqlx::query_scalar::<_, bool>( - "SELECT EXISTS(SELECT 1 FROM connect_support_codes WHERE code = $1 AND status = 'pending')" + "SELECT EXISTS(SELECT 1 FROM connect_support_codes WHERE code = $1 AND status = 'pending')", ) .bind(code) .fetch_one(pool) diff --git a/server/src/db/users.rs b/server/src/db/users.rs index f75d219..8e099e2 100644 --- a/server/src/db/users.rs +++ b/server/src/db/users.rs @@ -49,33 +49,27 @@ impl From for UserInfo { /// Get user by username pub async fn get_user_by_username(pool: &PgPool, username: &str) -> Result> { - let user = sqlx::query_as::<_, User>( - "SELECT * FROM users WHERE username = $1" - ) - .bind(username) - .fetch_optional(pool) - .await?; + let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE username = $1") + .bind(username) + .fetch_optional(pool) + .await?; Ok(user) } /// Get user by ID pub async fn get_user_by_id(pool: &PgPool, id: Uuid) -> Result> { - let user = sqlx::query_as::<_, User>( - "SELECT * FROM users WHERE id = $1" - ) - .bind(id) - .fetch_optional(pool) - .await?; + let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE id = $1") + .bind(id) + .fetch_optional(pool) + .await?; Ok(user) } /// Get all users pub async fn get_all_users(pool: &PgPool) -> Result> { - let users = sqlx::query_as::<_, User>( - "SELECT * FROM users ORDER BY username" - ) - .fetch_all(pool) - .await?; + let users = sqlx::query_as::<_, User>("SELECT * FROM users ORDER BY username") + .fetch_all(pool) + .await?; Ok(users) } @@ -92,7 +86,7 @@ pub async fn create_user( INSERT INTO users (username, password_hash, email, role) VALUES ($1, $2, $3, $4) RETURNING * - "# + "#, ) .bind(username) .bind(password_hash) @@ -117,7 +111,7 @@ pub async fn update_user( SET email = $2, role = $3, enabled = $4, updated_at = NOW() WHERE id = $1 RETURNING * - "# + "#, ) .bind(id) .bind(email) @@ -129,18 +123,13 @@ pub async fn update_user( } /// Update user password -pub async fn update_user_password( - pool: &PgPool, - id: Uuid, - password_hash: &str, -) -> Result { - let result = sqlx::query( - "UPDATE users SET password_hash = $2, updated_at = NOW() WHERE id = $1" - ) - .bind(id) - .bind(password_hash) - .execute(pool) - .await?; +pub async fn update_user_password(pool: &PgPool, id: Uuid, password_hash: &str) -> Result { + let result = + sqlx::query("UPDATE users SET password_hash = $2, updated_at = NOW() WHERE id = $1") + .bind(id) + .bind(password_hash) + .execute(pool) + .await?; Ok(result.rows_affected() > 0) } @@ -172,12 +161,11 @@ pub async fn count_users(pool: &PgPool) -> Result { /// Get user permissions pub async fn get_user_permissions(pool: &PgPool, user_id: Uuid) -> Result> { - let perms: Vec<(String,)> = sqlx::query_as( - "SELECT permission FROM user_permissions WHERE user_id = $1" - ) - .bind(user_id) - .fetch_all(pool) - .await?; + let perms: Vec<(String,)> = + sqlx::query_as("SELECT permission FROM user_permissions WHERE user_id = $1") + .bind(user_id) + .fetch_all(pool) + .await?; Ok(perms.into_iter().map(|p| p.0).collect()) } @@ -195,25 +183,22 @@ pub async fn set_user_permissions( // Insert new for perm in permissions { - sqlx::query( - "INSERT INTO user_permissions (user_id, permission) VALUES ($1, $2)" - ) - .bind(user_id) - .bind(perm) - .execute(pool) - .await?; + sqlx::query("INSERT INTO user_permissions (user_id, permission) VALUES ($1, $2)") + .bind(user_id) + .bind(perm) + .execute(pool) + .await?; } Ok(()) } /// Get user's accessible client IDs (empty = all access) pub async fn get_user_client_access(pool: &PgPool, user_id: Uuid) -> Result> { - let clients: Vec<(Uuid,)> = sqlx::query_as( - "SELECT client_id FROM user_client_access WHERE user_id = $1" - ) - .bind(user_id) - .fetch_all(pool) - .await?; + let clients: Vec<(Uuid,)> = + sqlx::query_as("SELECT client_id FROM user_client_access WHERE user_id = $1") + .bind(user_id) + .fetch_all(pool) + .await?; Ok(clients.into_iter().map(|c| c.0).collect()) } @@ -231,23 +216,17 @@ pub async fn set_user_client_access( // Insert new for client_id in client_ids { - sqlx::query( - "INSERT INTO user_client_access (user_id, client_id) VALUES ($1, $2)" - ) - .bind(user_id) - .bind(client_id) - .execute(pool) - .await?; + sqlx::query("INSERT INTO user_client_access (user_id, client_id) VALUES ($1, $2)") + .bind(user_id) + .bind(client_id) + .execute(pool) + .await?; } Ok(()) } /// Check if user has access to a specific client -pub async fn user_has_client_access( - pool: &PgPool, - user_id: Uuid, - client_id: Uuid, -) -> Result { +pub async fn user_has_client_access(pool: &PgPool, user_id: Uuid, client_id: Uuid) -> Result { // Admins have access to all let user = get_user_by_id(pool, user_id).await?; if let Some(u) = user { @@ -258,7 +237,7 @@ pub async fn user_has_client_access( // Check explicit access let access: Option<(Uuid,)> = sqlx::query_as( - "SELECT client_id FROM user_client_access WHERE user_id = $1 AND client_id = $2" + "SELECT client_id FROM user_client_access WHERE user_id = $1 AND client_id = $2", ) .bind(user_id) .bind(client_id) @@ -271,12 +250,11 @@ pub async fn user_has_client_access( } // Check if user has ANY access restrictions - let count: (i64,) = sqlx::query_as( - "SELECT COUNT(*) FROM user_client_access WHERE user_id = $1" - ) - .bind(user_id) - .fetch_one(pool) - .await?; + let count: (i64,) = + sqlx::query_as("SELECT COUNT(*) FROM user_client_access WHERE user_id = $1") + .bind(user_id) + .fetch_one(pool) + .await?; // No restrictions means access to all Ok(count.0 == 0) diff --git a/server/src/main.rs b/server/src/main.rs index 7316cbf..5cb1185 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -3,44 +3,44 @@ //! Handles connections from both agents and dashboard viewers, //! relaying video frames and input events between them. +mod api; +mod auth; mod config; +mod db; +mod metrics; +mod middleware; mod relay; mod session; -mod auth; -mod api; -mod db; mod support_codes; -mod middleware; mod utils; -mod metrics; pub mod proto { include!(concat!(env!("OUT_DIR"), "/guruconnect.rs")); } use anyhow::Result; +use axum::http::{HeaderValue, Method}; use axum::{ - Router, - routing::{get, post, put, delete}, - extract::{Path, State, Json, Query, Request}, - response::{Html, IntoResponse}, + extract::{Json, Path, Query, Request, State}, http::StatusCode, middleware::{self as axum_middleware, Next}, + response::{Html, IntoResponse}, + routing::{delete, get, post, put}, + Router, }; +use serde::Deserialize; use std::net::SocketAddr; use std::sync::Arc; -use tower_http::cors::{Any, CorsLayer, AllowOrigin}; -use axum::http::{Method, HeaderValue}; -use tower_http::trace::TraceLayer; +use tower_http::cors::{AllowOrigin, Any, CorsLayer}; use tower_http::services::ServeDir; +use tower_http::trace::TraceLayer; use tracing::{info, Level}; use tracing_subscriber::FmtSubscriber; -use serde::Deserialize; -use support_codes::{SupportCodeManager, CreateCodeRequest, SupportCode, CodeValidation}; -use auth::{JwtConfig, TokenBlacklist, hash_password, generate_random_password, AuthenticatedUser}; +use auth::{generate_random_password, hash_password, AuthenticatedUser, JwtConfig, TokenBlacklist}; use metrics::SharedMetrics; use prometheus_client::registry::Registry; +use support_codes::{CodeValidation, CreateCodeRequest, SupportCode, SupportCodeManager}; /// Application state #[derive(Clone)] @@ -67,7 +67,9 @@ async fn auth_layer( next: Next, ) -> impl IntoResponse { request.extensions_mut().insert(state.jwt_config.clone()); - request.extensions_mut().insert(Arc::new(state.token_blacklist.clone())); + request + .extensions_mut() + .insert(Arc::new(state.token_blacklist.clone())); next.run(request).await } @@ -89,8 +91,9 @@ async fn main() -> Result<()> { info!("Loaded configuration, listening on {}", listen_addr); // JWT configuration - REQUIRED for security - let jwt_secret = std::env::var("JWT_SECRET") - .expect("JWT_SECRET environment variable must be set! Generate one with: openssl rand -base64 64"); + let jwt_secret = std::env::var("JWT_SECRET").expect( + "JWT_SECRET environment variable must be set! Generate one with: openssl rand -base64 64", + ); if jwt_secret.len() < 32 { panic!("JWT_SECRET must be at least 32 characters long for security!"); @@ -114,7 +117,10 @@ async fn main() -> Result<()> { Some(db) } Err(e) => { - tracing::warn!("Failed to connect to database: {}. Running without persistence.", e); + tracing::warn!( + "Failed to connect to database: {}. Running without persistence.", + e + ); None } } @@ -194,9 +200,14 @@ async fn main() -> Result<()> { if let Some(ref db) = database { match db::machines::get_all_machines(db.pool()).await { Ok(machines) => { - info!("Restoring {} persistent machines from database", machines.len()); + info!( + "Restoring {} persistent machines from database", + machines.len() + ); for machine in machines { - sessions.restore_offline_machine(&machine.agent_id, &machine.hostname).await; + sessions + .restore_offline_machine(&machine.agent_id, &machine.hostname) + .await; } } Err(e) => { @@ -254,92 +265,117 @@ async fn main() -> Result<()> { .route("/health", get(health)) // Prometheus metrics (no auth required - for monitoring) .route("/metrics", get(prometheus_metrics)) - // Auth endpoints (TODO: Add rate limiting - see SEC2_RATE_LIMITING_TODO.md) .route("/api/auth/login", post(api::auth::login)) - .route("/api/auth/change-password", post(api::auth::change_password)) + .route( + "/api/auth/change-password", + post(api::auth::change_password), + ) .route("/api/auth/me", get(api::auth::get_me)) .route("/api/auth/logout", post(api::auth_logout::logout)) - .route("/api/auth/revoke-token", post(api::auth_logout::revoke_own_token)) - .route("/api/auth/admin/revoke-user", post(api::auth_logout::revoke_user_tokens)) - .route("/api/auth/blacklist/stats", get(api::auth_logout::get_blacklist_stats)) - .route("/api/auth/blacklist/cleanup", post(api::auth_logout::cleanup_blacklist)) - + .route( + "/api/auth/revoke-token", + post(api::auth_logout::revoke_own_token), + ) + .route( + "/api/auth/admin/revoke-user", + post(api::auth_logout::revoke_user_tokens), + ) + .route( + "/api/auth/blacklist/stats", + get(api::auth_logout::get_blacklist_stats), + ) + .route( + "/api/auth/blacklist/cleanup", + post(api::auth_logout::cleanup_blacklist), + ) // User management (admin only) .route("/api/users", get(api::users::list_users)) .route("/api/users", post(api::users::create_user)) .route("/api/users/:id", get(api::users::get_user)) .route("/api/users/:id", put(api::users::update_user)) .route("/api/users/:id", delete(api::users::delete_user)) - .route("/api/users/:id/permissions", put(api::users::set_permissions)) + .route( + "/api/users/:id/permissions", + put(api::users::set_permissions), + ) .route("/api/users/:id/clients", put(api::users::set_client_access)) - // Portal API - Support codes (TODO: Add rate limiting) .route("/api/codes", post(create_code)) .route("/api/codes", get(list_codes)) .route("/api/codes/:code/validate", get(validate_code)) .route("/api/codes/:code/cancel", post(cancel_code)) - // WebSocket endpoints .route("/ws/agent", get(relay::agent_ws_handler)) .route("/ws/viewer", get(relay::viewer_ws_handler)) - // REST API - Sessions .route("/api/sessions", get(list_sessions)) .route("/api/sessions/:id", get(get_session)) .route("/api/sessions/:id", delete(disconnect_session)) - // REST API - Machines .route("/api/machines", get(list_machines)) .route("/api/machines/:agent_id", get(get_machine)) .route("/api/machines/:agent_id", delete(delete_machine)) .route("/api/machines/:agent_id/history", get(get_machine_history)) - .route("/api/machines/:agent_id/update", post(trigger_machine_update)) - + .route( + "/api/machines/:agent_id/update", + post(trigger_machine_update), + ) // REST API - Releases and Version - .route("/api/version", get(api::releases::get_version)) // No auth - for agent polling + .route("/api/version", get(api::releases::get_version)) // No auth - for agent polling .route("/api/releases", get(api::releases::list_releases)) .route("/api/releases", post(api::releases::create_release)) .route("/api/releases/:version", get(api::releases::get_release)) .route("/api/releases/:version", put(api::releases::update_release)) - .route("/api/releases/:version", delete(api::releases::delete_release)) - + .route( + "/api/releases/:version", + delete(api::releases::delete_release), + ) // Changelog (no auth - public, like /api/version) // Single route: version == "latest" selects the latest file; axum 0.7 / matchit 0.7 // panics if a static segment and a path param share this position, so do not split it. - .route("/api/changelog/:component/:version", get(api::changelog::get)) - + .route( + "/api/changelog/:component/:version", + get(api::changelog::get), + ) // Agent downloads (no auth - public download links) .route("/api/download/viewer", get(api::downloads::download_viewer)) - .route("/api/download/support", get(api::downloads::download_support)) + .route( + "/api/download/support", + get(api::downloads::download_support), + ) .route("/api/download/agent", get(api::downloads::download_agent)) - // HTML page routes (clean URLs) .route("/login", get(serve_login)) .route("/dashboard", get(serve_dashboard)) .route("/users", get(serve_users)) - // State and middleware .with_state(state.clone()) .layer(axum_middleware::from_fn_with_state(state, auth_layer)) - // Serve static files for portal (fallback) .fallback_service(ServeDir::new("static").append_index_html_on_directories(true)) - // Middleware - .layer(axum_middleware::from_fn(middleware::add_security_headers)) // SEC-7 & SEC-12 + .layer(axum_middleware::from_fn(middleware::add_security_headers)) // SEC-7 & SEC-12 .layer(TraceLayer::new_for_http()) // SEC-11: Restricted CORS configuration .layer({ let cors = CorsLayer::new() // Allow requests from the production domain and localhost (for development) .allow_origin([ - "https://connect.azcomputerguru.com".parse::().unwrap(), + "https://connect.azcomputerguru.com" + .parse::() + .unwrap(), "http://localhost:3002".parse::().unwrap(), "http://127.0.0.1:3002".parse::().unwrap(), ]) // Allow only necessary HTTP methods - .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS]) + .allow_methods([ + Method::GET, + Method::POST, + Method::PUT, + Method::DELETE, + Method::OPTIONS, + ]) // Allow common headers needed for API requests .allow_headers([ axum::http::header::AUTHORIZATION, @@ -360,8 +396,9 @@ async fn main() -> Result<()> { // Use into_make_service_with_connect_info to enable IP address extraction axum::serve( listener, - app.into_make_service_with_connect_info::() - ).await?; + app.into_make_service_with_connect_info::(), + ) + .await?; Ok(()) } @@ -371,9 +408,7 @@ async fn health() -> &'static str { } /// Prometheus metrics endpoint -async fn prometheus_metrics( - State(state): State, -) -> String { +async fn prometheus_metrics(State(state): State) -> String { use prometheus_client::encoding::text::encode; let registry = state.registry.lock().unwrap(); @@ -385,7 +420,7 @@ async fn prometheus_metrics( // Support code API handlers async fn create_code( - _user: AuthenticatedUser, // Require authentication + _user: AuthenticatedUser, // Require authentication State(state): State, Json(request): Json, ) -> Json { @@ -395,7 +430,7 @@ async fn create_code( } async fn list_codes( - _user: AuthenticatedUser, // Require authentication + _user: AuthenticatedUser, // Require authentication State(state): State, ) -> Json> { Json(state.support_codes.list_active_codes().await) @@ -414,7 +449,7 @@ async fn validate_code( } async fn cancel_code( - _user: AuthenticatedUser, // Require authentication + _user: AuthenticatedUser, // Require authentication State(state): State, Path(code): Path, ) -> impl IntoResponse { @@ -428,7 +463,7 @@ async fn cancel_code( // Session API handlers (updated to use AppState) async fn list_sessions( - _user: AuthenticatedUser, // Require authentication + _user: AuthenticatedUser, // Require authentication State(state): State, ) -> Json> { let sessions = state.sessions.list_sessions().await; @@ -436,21 +471,24 @@ async fn list_sessions( } async fn get_session( - _user: AuthenticatedUser, // Require authentication + _user: AuthenticatedUser, // Require authentication State(state): State, Path(id): Path, ) -> Result, (StatusCode, &'static str)> { - let session_id = uuid::Uuid::parse_str(&id) - .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID"))?; + let session_id = + uuid::Uuid::parse_str(&id).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID"))?; - let session = state.sessions.get_session(session_id).await + let session = state + .sessions + .get_session(session_id) + .await .ok_or((StatusCode::NOT_FOUND, "Session not found"))?; Ok(Json(api::SessionInfo::from(session))) } async fn disconnect_session( - _user: AuthenticatedUser, // Require authentication + _user: AuthenticatedUser, // Require authentication State(state): State, Path(id): Path, ) -> impl IntoResponse { @@ -459,7 +497,11 @@ async fn disconnect_session( Err(_) => return (StatusCode::BAD_REQUEST, "Invalid session ID"), }; - if state.sessions.disconnect_session(session_id, "Disconnected by administrator").await { + if state + .sessions + .disconnect_session(session_id, "Disconnected by administrator") + .await + { info!("Session {} disconnected by admin", session_id); (StatusCode::OK, "Session disconnected") } else { @@ -470,27 +512,35 @@ async fn disconnect_session( // Machine API handlers async fn list_machines( - _user: AuthenticatedUser, // Require authentication + _user: AuthenticatedUser, // Require authentication State(state): State, ) -> Result>, (StatusCode, &'static str)> { - let db = state.db.as_ref() + let db = state + .db + .as_ref() .ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?; - let machines = db::machines::get_all_machines(db.pool()).await + let machines = db::machines::get_all_machines(db.pool()) + .await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?; - Ok(Json(machines.into_iter().map(api::MachineInfo::from).collect())) + Ok(Json( + machines.into_iter().map(api::MachineInfo::from).collect(), + )) } async fn get_machine( - _user: AuthenticatedUser, // Require authentication + _user: AuthenticatedUser, // Require authentication State(state): State, Path(agent_id): Path, ) -> Result, (StatusCode, &'static str)> { - let db = state.db.as_ref() + let db = state + .db + .as_ref() .ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?; - let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id).await + let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id) + .await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))? .ok_or((StatusCode::NOT_FOUND, "Machine not found"))?; @@ -498,24 +548,29 @@ async fn get_machine( } async fn get_machine_history( - _user: AuthenticatedUser, // Require authentication + _user: AuthenticatedUser, // Require authentication State(state): State, Path(agent_id): Path, ) -> Result, (StatusCode, &'static str)> { - let db = state.db.as_ref() + let db = state + .db + .as_ref() .ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?; // Get machine - let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id).await + let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id) + .await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))? .ok_or((StatusCode::NOT_FOUND, "Machine not found"))?; // Get sessions for this machine - let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id).await + let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id) + .await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?; // Get events for this machine - let events = db::events::get_events_for_machine(db.pool(), machine.id).await + let events = db::events::get_events_for_machine(db.pool(), machine.id) + .await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?; let history = api::MachineHistory { @@ -529,24 +584,29 @@ async fn get_machine_history( } async fn delete_machine( - _user: AuthenticatedUser, // Require authentication + _user: AuthenticatedUser, // Require authentication State(state): State, Path(agent_id): Path, Query(params): Query, ) -> Result, (StatusCode, &'static str)> { - let db = state.db.as_ref() + let db = state + .db + .as_ref() .ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?; // Get machine first - let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id).await + let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id) + .await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))? .ok_or((StatusCode::NOT_FOUND, "Machine not found"))?; // Export history if requested let history = if params.export { - let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id).await + let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id) + .await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?; - let events = db::events::get_events_for_machine(db.pool(), machine.id).await + let events = db::events::get_events_for_machine(db.pool(), machine.id) + .await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?; Some(api::MachineHistory { @@ -565,11 +625,14 @@ async fn delete_machine( // Find session for this agent if let Some(session) = state.sessions.get_session_by_agent(&agent_id).await { if session.is_online { - uninstall_sent = state.sessions.send_admin_command( - session.id, - proto::AdminCommandType::AdminUninstall, - "Deleted by administrator", - ).await; + uninstall_sent = state + .sessions + .send_admin_command( + session.id, + proto::AdminCommandType::AdminUninstall, + "Deleted by administrator", + ) + .await; if uninstall_sent { info!("Sent uninstall command to agent {}", agent_id); } @@ -581,10 +644,19 @@ async fn delete_machine( state.sessions.remove_agent(&agent_id).await; // Delete from database (cascades to sessions and events) - db::machines::delete_machine(db.pool(), &agent_id).await - .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Failed to delete machine"))?; + db::machines::delete_machine(db.pool(), &agent_id) + .await + .map_err(|_| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to delete machine", + ) + })?; - info!("Deleted machine {} (uninstall_sent: {})", agent_id, uninstall_sent); + info!( + "Deleted machine {} (uninstall_sent: {})", + agent_id, uninstall_sent + ); Ok(Json(api::DeleteMachineResponse { success: true, @@ -603,27 +675,34 @@ struct TriggerUpdateRequest { /// Trigger update on a specific machine async fn trigger_machine_update( - _user: AuthenticatedUser, // Require authentication + _user: AuthenticatedUser, // Require authentication State(state): State, Path(agent_id): Path, Json(request): Json, ) -> Result { - let db = state.db.as_ref() + let db = state + .db + .as_ref() .ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?; // Get the target release (either specified or latest stable) let release = if let Some(version) = request.version { - db::releases::get_release_by_version(db.pool(), &version).await + db::releases::get_release_by_version(db.pool(), &version) + .await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))? .ok_or((StatusCode::NOT_FOUND, "Release version not found"))? } else { - db::releases::get_latest_stable_release(db.pool()).await + db::releases::get_latest_stable_release(db.pool()) + .await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))? .ok_or((StatusCode::NOT_FOUND, "No stable release available"))? }; // Find session for this agent - let session = state.sessions.get_session_by_agent(&agent_id).await + let session = state + .sessions + .get_session_by_agent(&agent_id) + .await .ok_or((StatusCode::NOT_FOUND, "Agent not found or offline"))?; if !session.is_online { @@ -632,21 +711,31 @@ async fn trigger_machine_update( // Send update command via WebSocket // For now, we send admin command - later we'll include UpdateInfo in the message - let sent = state.sessions.send_admin_command( - session.id, - proto::AdminCommandType::AdminUpdate, - &format!("Update to version {}", release.version), - ).await; + let sent = state + .sessions + .send_admin_command( + session.id, + proto::AdminCommandType::AdminUpdate, + &format!("Update to version {}", release.version), + ) + .await; if sent { - info!("Sent update command to agent {} (version {})", agent_id, release.version); + info!( + "Sent update command to agent {} (version {})", + agent_id, release.version + ); // Update machine update status in database - let _ = db::releases::update_machine_update_status(db.pool(), &agent_id, "downloading").await; + let _ = + db::releases::update_machine_update_status(db.pool(), &agent_id, "downloading").await; Ok((StatusCode::OK, "Update command sent")) } else { - Err((StatusCode::INTERNAL_SERVER_ERROR, "Failed to send update command")) + Err(( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to send update command", + )) } } diff --git a/server/src/metrics/mod.rs b/server/src/metrics/mod.rs index b78ed76..fa01b4f 100644 --- a/server/src/metrics/mod.rs +++ b/server/src/metrics/mod.rs @@ -22,26 +22,26 @@ pub struct RequestLabels { /// Metrics labels for session events #[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] pub struct SessionLabels { - pub status: String, // created, closed, failed, expired + pub status: String, // created, closed, failed, expired } /// Metrics labels for connection events #[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] pub struct ConnectionLabels { - pub conn_type: String, // agent, viewer, dashboard + pub conn_type: String, // agent, viewer, dashboard } /// Metrics labels for error tracking #[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] pub struct ErrorLabels { - pub error_type: String, // auth, database, websocket, protocol, internal + pub error_type: String, // auth, database, websocket, protocol, internal } /// Metrics labels for database operations #[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)] pub struct DatabaseLabels { - pub operation: String, // select, insert, update, delete - pub status: String, // success, error + pub operation: String, // select, insert, update, delete + pub status: String, // success, error } /// GuruConnect server metrics @@ -82,9 +82,10 @@ impl Metrics { requests_total.clone(), ); - let request_duration_seconds = Family::::new_with_constructor(|| { - Histogram::new(exponential_buckets(0.001, 2.0, 10)) // 1ms to ~1s - }); + let request_duration_seconds = + Family::::new_with_constructor(|| { + Histogram::new(exponential_buckets(0.001, 2.0, 10)) // 1ms to ~1s + }); registry.register( "guruconnect_request_duration_seconds", "HTTP request duration in seconds", @@ -106,7 +107,7 @@ impl Metrics { active_sessions.clone(), ); - let session_duration_seconds = Histogram::new(exponential_buckets(1.0, 2.0, 15)); // 1s to ~9 hours + let session_duration_seconds = Histogram::new(exponential_buckets(1.0, 2.0, 15)); // 1s to ~9 hours registry.register( "guruconnect_session_duration_seconds", "Session duration in seconds", @@ -144,9 +145,10 @@ impl Metrics { db_operations_total.clone(), ); - let db_query_duration_seconds = Family::::new_with_constructor(|| { - Histogram::new(exponential_buckets(0.0001, 2.0, 12)) // 0.1ms to ~400ms - }); + let db_query_duration_seconds = + Family::::new_with_constructor(|| { + Histogram::new(exponential_buckets(0.0001, 2.0, 12)) // 0.1ms to ~400ms + }); registry.register( "guruconnect_db_query_duration_seconds", "Database query duration in seconds", @@ -188,7 +190,13 @@ impl Metrics { } /// Record request duration - pub fn record_request_duration(&self, method: &str, path: &str, status: u16, duration_secs: f64) { + pub fn record_request_duration( + &self, + method: &str, + path: &str, + status: u16, + duration_secs: f64, + ) { self.request_duration_seconds .get_or_create(&RequestLabels { method: method.to_string(), diff --git a/server/src/middleware/security_headers.rs b/server/src/middleware/security_headers.rs index bfde76a..13afbcd 100644 --- a/server/src/middleware/security_headers.rs +++ b/server/src/middleware/security_headers.rs @@ -3,17 +3,10 @@ //! SEC-7: XSS Prevention via Content-Security-Policy //! SEC-12: Additional security headers -use axum::{ - extract::Request, - middleware::Next, - response::Response, -}; +use axum::{extract::Request, middleware::Next, response::Response}; /// Add security headers to all responses -pub async fn add_security_headers( - request: Request, - next: Next, -) -> Response { +pub async fn add_security_headers(request: Request, next: Next) -> Response { let mut response = next.run(request).await; let headers = response.headers_mut(); @@ -35,22 +28,13 @@ pub async fn add_security_headers( ); // SEC-12: X-Frame-Options (Clickjacking protection) - headers.insert( - "X-Frame-Options", - "DENY".parse().unwrap(), - ); + headers.insert("X-Frame-Options", "DENY".parse().unwrap()); // SEC-12: X-Content-Type-Options (MIME sniffing protection) - headers.insert( - "X-Content-Type-Options", - "nosniff".parse().unwrap(), - ); + headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap()); // SEC-12: X-XSS-Protection (Legacy XSS filter - deprecated but still useful) - headers.insert( - "X-XSS-Protection", - "1; mode=block".parse().unwrap(), - ); + headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap()); // SEC-12: Referrer-Policy (Control referrer information) headers.insert( diff --git a/server/src/relay/mod.rs b/server/src/relay/mod.rs index 531368a..60ee330 100644 --- a/server/src/relay/mod.rs +++ b/server/src/relay/mod.rs @@ -6,21 +6,21 @@ use axum::{ extract::{ ws::{Message, WebSocket, WebSocketUpgrade}, - Query, State, ConnectInfo, + ConnectInfo, Query, State, }, - response::IntoResponse, http::StatusCode, + response::IntoResponse, }; -use std::net::SocketAddr; use futures_util::{SinkExt, StreamExt}; use prost::Message as ProstMessage; use serde::Deserialize; +use std::net::SocketAddr; use tracing::{error, info, warn}; use uuid::Uuid; +use crate::db::{self, Database}; use crate::proto; use crate::session::SessionManager; -use crate::db::{self, Database}; use crate::AppState; #[derive(Debug, Deserialize)] @@ -59,7 +59,11 @@ pub async fn agent_ws_handler( Query(params): Query, ) -> Result { let agent_id = params.agent_id.clone(); - let agent_name = params.hostname.clone().or(params.agent_name.clone()).unwrap_or_else(|| agent_id.clone()); + let agent_name = params + .hostname + .clone() + .or(params.agent_name.clone()) + .unwrap_or_else(|| agent_id.clone()); let support_code = params.support_code.clone(); let api_key = params.api_key.clone(); let client_ip = addr.ip(); @@ -69,7 +73,10 @@ pub async fn agent_ws_handler( // API key = persistent managed agent if support_code.is_none() && api_key.is_none() { - warn!("Agent connection rejected: {} from {} - no support code or API key", agent_id, client_ip); + warn!( + "Agent connection rejected: {} from {} - no support code or API key", + agent_id, client_ip + ); // Log failed connection attempt to database if let Some(ref db) = state.db { @@ -84,7 +91,8 @@ pub async fn agent_ws_handler( "agent_id": agent_id })), Some(client_ip), - ).await; + ) + .await; } return Err(StatusCode::UNAUTHORIZED); @@ -95,7 +103,10 @@ pub async fn agent_ws_handler( // Check if it's a valid, pending support code let code_info = state.support_codes.get_status(code).await; if code_info.is_none() { - warn!("Agent connection rejected: {} from {} - invalid support code {}", agent_id, client_ip, code); + warn!( + "Agent connection rejected: {} from {} - invalid support code {}", + agent_id, client_ip, code + ); // Log failed connection attempt if let Some(ref db) = state.db { @@ -111,14 +122,18 @@ pub async fn agent_ws_handler( "agent_id": agent_id })), Some(client_ip), - ).await; + ) + .await; } return Err(StatusCode::UNAUTHORIZED); } let status = code_info.unwrap(); if status != "pending" && status != "connected" { - warn!("Agent connection rejected: {} from {} - support code {} has status {}", agent_id, client_ip, code, status); + warn!( + "Agent connection rejected: {} from {} - support code {} has status {}", + agent_id, client_ip, code, status + ); // Log failed connection attempt (expired/cancelled code) if let Some(ref db) = state.db { @@ -140,12 +155,16 @@ pub async fn agent_ws_handler( "agent_id": agent_id })), Some(client_ip), - ).await; + ) + .await; } return Err(StatusCode::UNAUTHORIZED); } - info!("Agent {} from {} authenticated via support code {}", agent_id, client_ip, code); + info!( + "Agent {} from {} authenticated via support code {}", + agent_id, client_ip, code + ); } // Validate API key if provided (for persistent agents) @@ -153,7 +172,10 @@ pub async fn agent_ws_handler( // For now, we'll accept API keys that match the JWT secret or a configured agent key // In production, this should validate against a database of registered agents if !validate_agent_api_key(&state, key).await { - warn!("Agent connection rejected: {} from {} - invalid API key", agent_id, client_ip); + warn!( + "Agent connection rejected: {} from {} - invalid API key", + agent_id, client_ip + ); // Log failed connection attempt if let Some(ref db) = state.db { @@ -168,19 +190,34 @@ pub async fn agent_ws_handler( "agent_id": agent_id })), Some(client_ip), - ).await; + ) + .await; } return Err(StatusCode::UNAUTHORIZED); } - info!("Agent {} from {} authenticated via API key", agent_id, client_ip); + info!( + "Agent {} from {} authenticated via API key", + agent_id, client_ip + ); } let sessions = state.sessions.clone(); let support_codes = state.support_codes.clone(); let db = state.db.clone(); - Ok(ws.on_upgrade(move |socket| handle_agent_connection(socket, sessions, support_codes, db, agent_id, agent_name, support_code, Some(client_ip)))) + Ok(ws.on_upgrade(move |socket| { + handle_agent_connection( + socket, + sessions, + support_codes, + db, + agent_id, + agent_name, + support_code, + Some(client_ip), + ) + })) } /// Validate an agent API key @@ -212,24 +249,42 @@ pub async fn viewer_ws_handler( // Require JWT token for viewers let token = params.token.ok_or_else(|| { - warn!("Viewer connection rejected from {}: missing token", client_ip); + warn!( + "Viewer connection rejected from {}: missing token", + client_ip + ); StatusCode::UNAUTHORIZED })?; // Validate the token let claims = state.jwt_config.validate_token(&token).map_err(|e| { - warn!("Viewer connection rejected from {}: invalid token: {}", client_ip, e); + warn!( + "Viewer connection rejected from {}: invalid token: {}", + client_ip, e + ); StatusCode::UNAUTHORIZED })?; - info!("Viewer {} authenticated via JWT from {}", claims.username, client_ip); + info!( + "Viewer {} authenticated via JWT from {}", + claims.username, client_ip + ); let session_id = params.session_id; let viewer_name = params.viewer_name; let sessions = state.sessions.clone(); let db = state.db.clone(); - Ok(ws.on_upgrade(move |socket| handle_viewer_connection(socket, sessions, db, session_id, viewer_name, Some(client_ip)))) + Ok(ws.on_upgrade(move |socket| { + handle_viewer_connection( + socket, + sessions, + db, + session_id, + viewer_name, + Some(client_ip), + ) + })) } /// Handle an agent WebSocket connection @@ -243,7 +298,10 @@ async fn handle_agent_connection( support_code: Option, client_ip: Option, ) { - info!("Agent connected: {} ({}) from {:?}", agent_name, agent_id, client_ip); + info!( + "Agent connected: {} ({}) from {:?}", + agent_name, agent_id, client_ip + ); let (mut ws_sender, mut ws_receiver) = socket.split(); @@ -270,7 +328,9 @@ async fn handle_agent_connection( // Register the agent and get channels // Persistent agents (no support code) keep their session when disconnected let is_persistent = support_code.is_none(); - let (session_id, frame_tx, mut input_rx) = sessions.register_agent(agent_id.clone(), agent_name.clone(), is_persistent).await; + let (session_id, frame_tx, mut input_rx) = sessions + .register_agent(agent_id.clone(), agent_name.clone(), is_persistent) + .await; info!("Session created: {} (agent in idle mode)", session_id); @@ -285,15 +345,20 @@ async fn handle_agent_connection( machine.id, support_code.is_some(), support_code.as_deref(), - ).await; + ) + .await; // Log session started event let _ = db::events::log_event( db.pool(), session_id, db::events::EventTypes::SESSION_STARTED, - None, None, None, client_ip, - ).await; + None, + None, + None, + client_ip, + ) + .await; Some(machine.id) } @@ -309,7 +374,9 @@ async fn handle_agent_connection( // If a support code was provided, mark it as connected if let Some(ref code) = support_code { info!("Linking support code {} to session {}", code, session_id); - support_codes.mark_connected(code, Some(agent_name.clone()), Some(agent_id.clone())).await; + support_codes + .mark_connected(code, Some(agent_name.clone()), Some(agent_id.clone())) + .await; support_codes.link_session(code, session_id).await; // Database: update support code @@ -320,7 +387,8 @@ async fn handle_agent_connection( Some(session_id), Some(&agent_name), Some(&agent_id), - ).await; + ) + .await; } } @@ -333,7 +401,11 @@ async fn handle_agent_connection( let input_forward = tokio::spawn(async move { while let Some(input_data) = input_rx.recv().await { let mut sender = ws_sender_input.lock().await; - if sender.send(Message::Binary(input_data.into())).await.is_err() { + if sender + .send(Message::Binary(input_data.into())) + .await + .is_err() + { break; } } @@ -406,22 +478,29 @@ async fn handle_agent_connection( } else { Some(status.site.clone()) }; - sessions_status.update_agent_status( - session_id, - Some(status.os_version.clone()), - status.is_elevated, - status.uptime_secs, - status.display_count, - status.is_streaming, - agent_version.clone(), - organization.clone(), - site.clone(), - status.tags.clone(), - ).await; + sessions_status + .update_agent_status( + session_id, + Some(status.os_version.clone()), + status.is_elevated, + status.uptime_secs, + status.display_count, + status.is_streaming, + agent_version.clone(), + organization.clone(), + site.clone(), + status.tags.clone(), + ) + .await; // Update version in database if present if let (Some(ref db), Some(ref version)) = (&db, &agent_version) { - let _ = crate::db::releases::update_machine_version(db.pool(), &agent_id, version).await; + let _ = crate::db::releases::update_machine_version( + db.pool(), + &agent_id, + version, + ) + .await; } // Update organization/site/tags in database if present @@ -432,7 +511,8 @@ async fn handle_agent_connection( organization.as_deref(), site.as_deref(), &status.tags, - ).await; + ) + .await; } info!("Agent status update: {} - streaming={}, uptime={}s, version={:?}, org={:?}, site={:?}", @@ -489,8 +569,12 @@ async fn handle_agent_connection( db.pool(), session_id, db::events::EventTypes::SESSION_ENDED, - None, None, None, client_ip, - ).await; + None, + None, + None, + client_ip, + ) + .await; } // Mark support code as completed if one was used (unless cancelled) @@ -532,7 +616,10 @@ async fn handle_viewer_connection( let viewer_id = Uuid::new_v4().to_string(); // Join the session (this sends StartStream to agent if first viewer) - let (mut frame_rx, input_tx) = match sessions.join_session(session_id, viewer_id.clone(), viewer_name.clone()).await { + let (mut frame_rx, input_tx) = match sessions + .join_session(session_id, viewer_id.clone(), viewer_name.clone()) + .await + { Some(channels) => channels, None => { warn!("Session not found: {}", session_id); @@ -540,7 +627,10 @@ async fn handle_viewer_connection( } }; - info!("Viewer {} ({}) joined session: {} from {:?}", viewer_name, viewer_id, session_id, client_ip); + info!( + "Viewer {} ({}) joined session: {} from {:?}", + viewer_name, viewer_id, session_id, client_ip + ); // Database: log viewer joined event if let Some(ref db) = db { @@ -550,8 +640,10 @@ async fn handle_viewer_connection( db::events::EventTypes::VIEWER_JOINED, Some(&viewer_id), Some(&viewer_name), - None, client_ip, - ).await; + None, + client_ip, + ) + .await; } let (mut ws_sender, mut ws_receiver) = socket.split(); @@ -559,7 +651,11 @@ async fn handle_viewer_connection( // Task to forward frames from agent to this viewer let frame_forward = tokio::spawn(async move { while let Ok(frame_data) = frame_rx.recv().await { - if ws_sender.send(Message::Binary(frame_data.into())).await.is_err() { + if ws_sender + .send(Message::Binary(frame_data.into())) + .await + .is_err() + { break; } } @@ -577,9 +673,9 @@ async fn handle_viewer_connection( match proto::Message::decode(data.as_ref()) { Ok(proto_msg) => { match &proto_msg.payload { - Some(proto::message::Payload::MouseEvent(_)) | - Some(proto::message::Payload::KeyEvent(_)) | - Some(proto::message::Payload::SpecialKey(_)) => { + Some(proto::message::Payload::MouseEvent(_)) + | Some(proto::message::Payload::KeyEvent(_)) + | Some(proto::message::Payload::SpecialKey(_)) => { // Forward input to agent let _ = input_tx.send(data.to_vec()).await; } @@ -597,7 +693,10 @@ async fn handle_viewer_connection( } } Ok(Message::Close(_)) => { - info!("Viewer {} disconnected from session: {}", viewer_id, session_id); + info!( + "Viewer {} disconnected from session: {}", + viewer_id, session_id + ); break; } Ok(_) => {} @@ -610,7 +709,9 @@ async fn handle_viewer_connection( // Cleanup (this sends StopStream to agent if last viewer) frame_forward.abort(); - sessions_cleanup.leave_session(session_id, &viewer_id_cleanup).await; + sessions_cleanup + .leave_session(session_id, &viewer_id_cleanup) + .await; // Database: log viewer left event if let Some(ref db) = db { @@ -620,8 +721,10 @@ async fn handle_viewer_connection( db::events::EventTypes::VIEWER_LEFT, Some(&viewer_id_cleanup), Some(&viewer_name_cleanup), - None, client_ip, - ).await; + None, + client_ip, + ) + .await; } info!("Viewer {} left session: {}", viewer_id_cleanup, session_id); diff --git a/server/src/session/mod.rs b/server/src/session/mod.rs index 64cc804..0d05cb0 100644 --- a/server/src/session/mod.rs +++ b/server/src/session/mod.rs @@ -37,20 +37,20 @@ pub struct Session { pub agent_name: String, pub started_at: chrono::DateTime, pub viewer_count: usize, - pub viewers: Vec, // List of connected technicians + pub viewers: Vec, // List of connected technicians pub is_streaming: bool, - pub is_online: bool, // Whether agent is currently connected - pub is_persistent: bool, // Persistent agent (no support code) vs support session + pub is_online: bool, // Whether agent is currently connected + pub is_persistent: bool, // Persistent agent (no support code) vs support session pub last_heartbeat: chrono::DateTime, // Agent status info pub os_version: Option, pub is_elevated: bool, pub uptime_secs: i64, pub display_count: i32, - pub agent_version: Option, // Agent software version - pub organization: Option, // Company/organization name - pub site: Option, // Site/location name - pub tags: Vec, // Tags for categorization + pub agent_version: Option, // Agent software version + pub organization: Option, // Company/organization name + pub site: Option, // Site/location name + pub tags: Vec, // Tags for categorization } /// Channel for sending frames from agent to viewers @@ -92,7 +92,12 @@ impl SessionManager { /// Register a new agent and create a session /// If agent was previously connected (offline session exists), reuse that session - pub async fn register_agent(&self, agent_id: AgentId, agent_name: String, is_persistent: bool) -> (SessionId, FrameSender, InputReceiver) { + pub async fn register_agent( + &self, + agent_id: AgentId, + agent_name: String, + is_persistent: bool, + ) -> (SessionId, FrameSender, InputReceiver) { // Check if this agent already has an offline session (reconnecting) { let agents = self.agents.read().await; @@ -101,7 +106,11 @@ impl SessionManager { if let Some(session_data) = sessions.get_mut(&existing_session_id) { if !session_data.info.is_online { // Reuse existing session - mark as online and create new channels - tracing::info!("Agent {} reconnecting to existing session {}", agent_id, existing_session_id); + tracing::info!( + "Agent {} reconnecting to existing session {}", + agent_id, + existing_session_id + ); let (frame_tx, _) = broadcast::channel(16); let (input_tx, input_rx) = tokio::sync::mpsc::channel(64); @@ -230,7 +239,9 @@ impl SessionManager { let sessions = self.sessions.read().await; sessions .iter() - .filter(|(_, data)| data.last_heartbeat_instant.elapsed().as_secs() > HEARTBEAT_TIMEOUT_SECS) + .filter(|(_, data)| { + data.last_heartbeat_instant.elapsed().as_secs() > HEARTBEAT_TIMEOUT_SECS + }) .map(|(id, _)| *id) .collect() } @@ -251,7 +262,12 @@ impl SessionManager { } /// Join a session as a viewer, returns channels and sends StartStream to agent - pub async fn join_session(&self, session_id: SessionId, viewer_id: ViewerId, viewer_name: String) -> Option<(FrameReceiver, InputSender)> { + pub async fn join_session( + &self, + session_id: SessionId, + viewer_id: ViewerId, + viewer_name: String, + ) -> Option<(FrameReceiver, InputSender)> { let mut sessions = self.sessions.write().await; let session_data = sessions.get_mut(&session_id)?; @@ -274,10 +290,20 @@ impl SessionManager { // If this is the first viewer, send StartStream to agent if was_empty { - tracing::info!("Viewer {} ({}) joined session {}, sending StartStream", viewer_name, viewer_id, session_id); + tracing::info!( + "Viewer {} ({}) joined session {}, sending StartStream", + viewer_name, + viewer_id, + session_id + ); Self::send_start_stream_internal(session_data, &viewer_id).await; } else { - tracing::info!("Viewer {} ({}) joined session {}", viewer_name, viewer_id, session_id); + tracing::info!( + "Viewer {} ({}) joined session {}", + viewer_name, + viewer_id, + session_id + ); } Some((frame_rx, input_tx)) @@ -312,12 +338,20 @@ impl SessionManager { // If no more viewers, send StopStream to agent if session_data.viewers.is_empty() { - tracing::info!("Last viewer {} ({}) left session {}, sending StopStream", - viewer_name.as_deref().unwrap_or("unknown"), viewer_id, session_id); + tracing::info!( + "Last viewer {} ({}) left session {}, sending StopStream", + viewer_name.as_deref().unwrap_or("unknown"), + viewer_id, + session_id + ); Self::send_stop_stream_internal(session_data, viewer_id).await; } else { - tracing::info!("Viewer {} ({}) left session {}", - viewer_name.as_deref().unwrap_or("unknown"), viewer_id, session_id); + tracing::info!( + "Viewer {} ({}) left session {}", + viewer_name.as_deref().unwrap_or("unknown"), + viewer_id, + session_id + ); } } } @@ -347,8 +381,11 @@ impl SessionManager { if let Some(session_data) = sessions.get_mut(&session_id) { if session_data.info.is_persistent { // Persistent agent - keep session but mark as offline - tracing::info!("Persistent agent {} marked offline (session {} preserved)", - session_data.info.agent_id, session_id); + tracing::info!( + "Persistent agent {} marked offline (session {} preserved)", + session_data.info.agent_id, + session_id + ); session_data.info.is_online = false; session_data.info.is_streaming = false; session_data.info.viewer_count = 0; @@ -410,7 +447,12 @@ impl SessionManager { /// Send an admin command to an agent (uninstall, restart, etc.) /// Returns true if the message was sent successfully - pub async fn send_admin_command(&self, session_id: SessionId, command: crate::proto::AdminCommandType, reason: &str) -> bool { + pub async fn send_admin_command( + &self, + session_id: SessionId, + command: crate::proto::AdminCommandType, + reason: &str, + ) -> bool { let sessions = self.sessions.read().await; if let Some(session_data) = sessions.get(&session_id) { if !session_data.info.is_online { @@ -471,7 +513,7 @@ impl SessionManager { viewer_count: 0, viewers: Vec::new(), is_streaming: false, - is_online: false, // Offline until agent reconnects + is_online: false, // Offline until agent reconnects is_persistent: true, last_heartbeat: now, os_version: None, diff --git a/server/src/support_codes.rs b/server/src/support_codes.rs index 8cf98ac..c10e82a 100644 --- a/server/src/support_codes.rs +++ b/server/src/support_codes.rs @@ -3,12 +3,12 @@ //! Handles generation and validation of 6-digit support codes //! for one-time remote support sessions. -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::RwLock; use chrono::{DateTime, Utc}; use rand::Rng; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; use uuid::Uuid; /// A support session code @@ -27,10 +27,10 @@ pub struct SupportCode { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "lowercase")] pub enum CodeStatus { - Pending, // Waiting for client to connect - Connected, // Client connected, session active - Completed, // Session ended normally - Cancelled, // Code cancelled by tech + Pending, // Waiting for client to connect + Connected, // Client connected, session active + Completed, // Session ended normally + Cancelled, // Code cancelled by tech } /// Request to create a new support code @@ -69,11 +69,11 @@ impl SupportCodeManager { async fn generate_unique_code(&self) -> String { let codes = self.codes.read().await; let mut rng = rand::thread_rng(); - + loop { let code: u32 = rng.gen_range(100000..999999); let code_str = code.to_string(); - + if !codes.contains_key(&code_str) { return code_str; } @@ -84,11 +84,13 @@ impl SupportCodeManager { pub async fn create_code(&self, request: CreateCodeRequest) -> SupportCode { let code = self.generate_unique_code().await; let session_id = Uuid::new_v4(); - + let support_code = SupportCode { code: code.clone(), session_id, - created_by: request.technician_name.unwrap_or_else(|| "Unknown".to_string()), + created_by: request + .technician_name + .unwrap_or_else(|| "Unknown".to_string()), created_at: Utc::now(), status: CodeStatus::Pending, client_name: None, @@ -108,10 +110,12 @@ impl SupportCodeManager { /// Validate a code and return session info pub async fn validate_code(&self, code: &str) -> CodeValidation { let codes = self.codes.read().await; - + match codes.get(code) { Some(support_code) => { - if support_code.status == CodeStatus::Pending || support_code.status == CodeStatus::Connected { + if support_code.status == CodeStatus::Pending + || support_code.status == CodeStatus::Connected + { CodeValidation { valid: true, session_id: Some(support_code.session_id.to_string()), @@ -137,7 +141,12 @@ impl SupportCodeManager { } /// Mark a code as connected - pub async fn mark_connected(&self, code: &str, client_name: Option, client_machine: Option) { + pub async fn mark_connected( + &self, + code: &str, + client_name: Option, + client_machine: Option, + ) { let mut codes = self.codes.write().await; if let Some(support_code) = codes.get_mut(code) { support_code.status = CodeStatus::Connected; @@ -180,7 +189,9 @@ impl SupportCodeManager { pub async fn cancel_code(&self, code: &str) -> bool { let mut codes = self.codes.write().await; if let Some(support_code) = codes.get_mut(code) { - if support_code.status == CodeStatus::Pending || support_code.status == CodeStatus::Connected { + if support_code.status == CodeStatus::Pending + || support_code.status == CodeStatus::Connected + { support_code.status = CodeStatus::Cancelled; return true; } @@ -191,13 +202,19 @@ impl SupportCodeManager { /// Check if a code is cancelled pub async fn is_cancelled(&self, code: &str) -> bool { let codes = self.codes.read().await; - codes.get(code).map(|c| c.status == CodeStatus::Cancelled).unwrap_or(false) + codes + .get(code) + .map(|c| c.status == CodeStatus::Cancelled) + .unwrap_or(false) } /// Check if a code is valid for connection (exists and is pending) pub async fn is_valid_for_connection(&self, code: &str) -> bool { let codes = self.codes.read().await; - codes.get(code).map(|c| c.status == CodeStatus::Pending).unwrap_or(false) + codes + .get(code) + .map(|c| c.status == CodeStatus::Pending) + .unwrap_or(false) } /// List all codes (for dashboard) @@ -209,7 +226,8 @@ impl SupportCodeManager { /// List active codes only pub async fn list_active_codes(&self) -> Vec { let codes = self.codes.read().await; - codes.values() + codes + .values() .filter(|c| c.status == CodeStatus::Pending || c.status == CodeStatus::Connected) .cloned() .collect() diff --git a/server/src/utils/validation.rs b/server/src/utils/validation.rs index a5edcd9..84689a3 100644 --- a/server/src/utils/validation.rs +++ b/server/src/utils/validation.rs @@ -11,18 +11,29 @@ use anyhow::{anyhow, Result}; pub fn validate_api_key_strength(api_key: &str) -> Result<()> { // Minimum length check if api_key.len() < 32 { - return Err(anyhow!("API key must be at least 32 characters long for security")); + return Err(anyhow!( + "API key must be at least 32 characters long for security" + )); } // Check for common weak keys let weak_keys = [ - "password", "12345", "admin", "test", "api_key", - "secret", "changeme", "default", "guruconnect" + "password", + "12345", + "admin", + "test", + "api_key", + "secret", + "changeme", + "default", + "guruconnect", ]; let lowercase_key = api_key.to_lowercase(); for weak in &weak_keys { if lowercase_key.contains(weak) { - return Err(anyhow!("API key contains weak/common patterns and is not secure")); + return Err(anyhow!( + "API key contains weak/common patterns and is not secure" + )); } } @@ -53,6 +64,9 @@ mod tests { assert!(validate_api_key_strength("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa").is_err()); // Good key - assert!(validate_api_key_strength("KfPrjjC3J6YMx9q1yjPxZAYkHLM2JdFy1XRxHJ9oPnw0NU3xH074ufHk7fj").is_ok()); + assert!(validate_api_key_strength( + "KfPrjjC3J6YMx9q1yjPxZAYkHLM2JdFy1XRxHJ9oPnw0NU3xH074ufHk7fj" + ) + .is_ok()); } }