style: cargo fmt --all — make codebase rustfmt-clean
Some checks failed
Build and Test / Build Server (Linux) (push) Failing after 2m59s
Build and Test / Build Agent (Windows) (push) Has started running
Build and Test / Security Audit (push) Has been cancelled
Build and Test / Build Summary (push) Has been cancelled
Run Tests / Test Server (push) Has been cancelled
Run Tests / Test Agent (push) Has been cancelled
Run Tests / Code Coverage (push) Has been cancelled
Run Tests / Lint and Format Check (push) Has been cancelled
Some checks failed
Build and Test / Build Server (Linux) (push) Failing after 2m59s
Build and Test / Build Agent (Windows) (push) Has started running
Build and Test / Security Audit (push) Has been cancelled
Build and Test / Build Summary (push) Has been cancelled
Run Tests / Test Server (push) Has been cancelled
Run Tests / Test Agent (push) Has been cancelled
Run Tests / Code Coverage (push) Has been cancelled
Run Tests / Lint and Format Check (push) Has been cancelled
First run of the build-and-test CI gate (cargo fmt --all -- --check) surfaced pre-existing formatting drift across the agent and server crates. Apply rustfmt across the workspace so the codebase meets its own CI gate. Pure formatting; no logic changes. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -13,7 +13,9 @@ fn main() -> Result<()> {
|
||||
println!("cargo:rerun-if-changed=../.git/index");
|
||||
|
||||
// Build timestamp (UTC)
|
||||
let build_timestamp = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC").to_string();
|
||||
let build_timestamp = chrono::Utc::now()
|
||||
.format("%Y-%m-%d %H:%M:%S UTC")
|
||||
.to_string();
|
||||
println!("cargo:rustc-env=BUILD_TIMESTAMP={}", build_timestamp);
|
||||
|
||||
// Git commit hash (short)
|
||||
@@ -53,7 +55,10 @@ fn main() -> Result<()> {
|
||||
.ok()
|
||||
.map(|o| !o.stdout.is_empty())
|
||||
.unwrap_or(false);
|
||||
println!("cargo:rustc-env=GIT_DIRTY={}", if git_dirty { "dirty" } else { "clean" });
|
||||
println!(
|
||||
"cargo:rustc-env=GIT_DIRTY={}",
|
||||
if git_dirty { "dirty" } else { "clean" }
|
||||
);
|
||||
|
||||
// Git commit date
|
||||
let git_commit_date = Command::new("git")
|
||||
|
||||
@@ -25,7 +25,8 @@ use windows_service::{
|
||||
// Service configuration
|
||||
const SERVICE_NAME: &str = "GuruConnectSAS";
|
||||
const SERVICE_DISPLAY_NAME: &str = "GuruConnect SAS Service";
|
||||
const SERVICE_DESCRIPTION: &str = "Handles Secure Attention Sequence (Ctrl+Alt+Del) for GuruConnect remote sessions";
|
||||
const SERVICE_DESCRIPTION: &str =
|
||||
"Handles Secure Attention Sequence (Ctrl+Alt+Del) for GuruConnect remote sessions";
|
||||
const PIPE_NAME: &str = r"\\.\pipe\guruconnect-sas";
|
||||
const INSTALL_DIR: &str = r"C:\Program Files\GuruConnect";
|
||||
|
||||
@@ -360,18 +361,16 @@ fn run_pipe_server() -> Result<()> {
|
||||
tracing::info!("Received command: {}", command);
|
||||
|
||||
let response = match command {
|
||||
"sas" => {
|
||||
match send_sas() {
|
||||
Ok(()) => {
|
||||
tracing::info!("SendSAS executed successfully");
|
||||
"ok\n"
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("SendSAS failed: {}", e);
|
||||
"error\n"
|
||||
}
|
||||
"sas" => match send_sas() {
|
||||
Ok(()) => {
|
||||
tracing::info!("SendSAS executed successfully");
|
||||
"ok\n"
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("SendSAS failed: {}", e);
|
||||
"error\n"
|
||||
}
|
||||
},
|
||||
"ping" => {
|
||||
tracing::info!("Ping received");
|
||||
"pong\n"
|
||||
@@ -432,7 +431,8 @@ fn install_service() -> Result<()> {
|
||||
// Get current executable path
|
||||
let current_exe = std::env::current_exe().context("Failed to get current executable")?;
|
||||
|
||||
let binary_dest = std::path::PathBuf::from(format!(r"{}\\guruconnect-sas-service.exe", INSTALL_DIR));
|
||||
let binary_dest =
|
||||
std::path::PathBuf::from(format!(r"{}\\guruconnect-sas-service.exe", INSTALL_DIR));
|
||||
|
||||
// Create install directory
|
||||
std::fs::create_dir_all(INSTALL_DIR).context("Failed to create install directory")?;
|
||||
@@ -462,7 +462,9 @@ fn install_service() -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
service.delete().context("Failed to delete existing service")?;
|
||||
service
|
||||
.delete()
|
||||
.context("Failed to delete existing service")?;
|
||||
drop(service);
|
||||
std::thread::sleep(Duration::from_secs(2));
|
||||
}
|
||||
@@ -482,7 +484,10 @@ fn install_service() -> Result<()> {
|
||||
};
|
||||
|
||||
let service = manager
|
||||
.create_service(&service_info, ServiceAccess::CHANGE_CONFIG | ServiceAccess::START)
|
||||
.create_service(
|
||||
&service_info,
|
||||
ServiceAccess::CHANGE_CONFIG | ServiceAccess::START,
|
||||
)
|
||||
.context("Failed to create service")?;
|
||||
|
||||
// Set description
|
||||
@@ -514,13 +519,11 @@ fn install_service() -> Result<()> {
|
||||
fn uninstall_service() -> Result<()> {
|
||||
println!("Uninstalling GuruConnect SAS Service...");
|
||||
|
||||
let binary_path = std::path::PathBuf::from(format!(r"{}\\guruconnect-sas-service.exe", INSTALL_DIR));
|
||||
let binary_path =
|
||||
std::path::PathBuf::from(format!(r"{}\\guruconnect-sas-service.exe", INSTALL_DIR));
|
||||
|
||||
let manager = ServiceManager::local_computer(
|
||||
None::<&str>,
|
||||
ServiceManagerAccess::CONNECT,
|
||||
)
|
||||
.context("Failed to connect to Service Control Manager. Run as Administrator.")?;
|
||||
let manager = ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT)
|
||||
.context("Failed to connect to Service Control Manager. Run as Administrator.")?;
|
||||
|
||||
match manager.open_service(
|
||||
SERVICE_NAME,
|
||||
@@ -558,17 +561,19 @@ fn uninstall_service() -> Result<()> {
|
||||
|
||||
/// Start the service
|
||||
fn start_service() -> Result<()> {
|
||||
let manager = ServiceManager::local_computer(
|
||||
None::<&str>,
|
||||
ServiceManagerAccess::CONNECT,
|
||||
)
|
||||
.context("Failed to connect to Service Control Manager")?;
|
||||
let manager = ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT)
|
||||
.context("Failed to connect to Service Control Manager")?;
|
||||
|
||||
let service = manager
|
||||
.open_service(SERVICE_NAME, ServiceAccess::START | ServiceAccess::QUERY_STATUS)
|
||||
.open_service(
|
||||
SERVICE_NAME,
|
||||
ServiceAccess::START | ServiceAccess::QUERY_STATUS,
|
||||
)
|
||||
.context("Failed to open service. Is it installed?")?;
|
||||
|
||||
service.start::<String>(&[]).context("Failed to start service")?;
|
||||
service
|
||||
.start::<String>(&[])
|
||||
.context("Failed to start service")?;
|
||||
|
||||
std::thread::sleep(Duration::from_secs(1));
|
||||
|
||||
@@ -584,14 +589,14 @@ fn start_service() -> Result<()> {
|
||||
|
||||
/// Stop the service
|
||||
fn stop_service() -> Result<()> {
|
||||
let manager = ServiceManager::local_computer(
|
||||
None::<&str>,
|
||||
ServiceManagerAccess::CONNECT,
|
||||
)
|
||||
.context("Failed to connect to Service Control Manager")?;
|
||||
let manager = ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT)
|
||||
.context("Failed to connect to Service Control Manager")?;
|
||||
|
||||
let service = manager
|
||||
.open_service(SERVICE_NAME, ServiceAccess::STOP | ServiceAccess::QUERY_STATUS)
|
||||
.open_service(
|
||||
SERVICE_NAME,
|
||||
ServiceAccess::STOP | ServiceAccess::QUERY_STATUS,
|
||||
)
|
||||
.context("Failed to open service")?;
|
||||
|
||||
service.stop().context("Failed to stop service")?;
|
||||
@@ -610,11 +615,8 @@ fn stop_service() -> Result<()> {
|
||||
|
||||
/// Query service status
|
||||
fn query_status() -> Result<()> {
|
||||
let manager = ServiceManager::local_computer(
|
||||
None::<&str>,
|
||||
ServiceManagerAccess::CONNECT,
|
||||
)
|
||||
.context("Failed to connect to Service Control Manager")?;
|
||||
let manager = ServiceManager::local_computer(None::<&str>, ServiceManagerAccess::CONNECT)
|
||||
.context("Failed to connect to Service Control Manager")?;
|
||||
|
||||
match manager.open_service(SERVICE_NAME, ServiceAccess::QUERY_STATUS) {
|
||||
Ok(service) => {
|
||||
|
||||
@@ -53,11 +53,11 @@ impl Display {
|
||||
/// Enumerate all connected displays
|
||||
#[cfg(windows)]
|
||||
pub fn enumerate_displays() -> Result<Vec<Display>> {
|
||||
use std::mem;
|
||||
use windows::Win32::Foundation::{BOOL, LPARAM, RECT};
|
||||
use windows::Win32::Graphics::Gdi::{
|
||||
EnumDisplayMonitors, GetMonitorInfoW, HMONITOR, MONITORINFOEXW,
|
||||
};
|
||||
use windows::Win32::Foundation::{BOOL, LPARAM, RECT};
|
||||
use std::mem;
|
||||
|
||||
let mut displays = Vec::new();
|
||||
let mut display_id = 0u32;
|
||||
@@ -98,7 +98,11 @@ pub fn enumerate_displays() -> Result<Vec<Display>> {
|
||||
if GetMonitorInfoW(hmonitor, &mut info.monitorInfo as *mut _ as *mut _).as_bool() {
|
||||
let rect = info.monitorInfo.rcMonitor;
|
||||
let name = String::from_utf16_lossy(
|
||||
&info.szDevice[..info.szDevice.iter().position(|&c| c == 0).unwrap_or(info.szDevice.len())]
|
||||
&info.szDevice[..info
|
||||
.szDevice
|
||||
.iter()
|
||||
.position(|&c| c == 0)
|
||||
.unwrap_or(info.szDevice.len())],
|
||||
);
|
||||
|
||||
let is_primary = (info.monitorInfo.dwFlags & 1) != 0; // MONITORINFOF_PRIMARY
|
||||
|
||||
@@ -10,19 +10,18 @@ use anyhow::{Context, Result};
|
||||
use std::ptr;
|
||||
use std::time::Instant;
|
||||
|
||||
use windows::core::Interface;
|
||||
use windows::Win32::Graphics::Direct3D::D3D_DRIVER_TYPE_UNKNOWN;
|
||||
use windows::Win32::Graphics::Direct3D11::{
|
||||
D3D11CreateDevice, ID3D11Device, ID3D11DeviceContext, ID3D11Texture2D,
|
||||
D3D11_SDK_VERSION, D3D11_TEXTURE2D_DESC,
|
||||
D3D11_USAGE_STAGING, D3D11_MAPPED_SUBRESOURCE, D3D11_MAP_READ,
|
||||
D3D11_MAPPED_SUBRESOURCE, D3D11_MAP_READ, D3D11_SDK_VERSION, D3D11_TEXTURE2D_DESC,
|
||||
D3D11_USAGE_STAGING,
|
||||
};
|
||||
use windows::Win32::Graphics::Dxgi::{
|
||||
CreateDXGIFactory1, IDXGIAdapter1, IDXGIFactory1, IDXGIOutput, IDXGIOutput1,
|
||||
IDXGIOutputDuplication, IDXGIResource, DXGI_ERROR_ACCESS_LOST,
|
||||
DXGI_ERROR_WAIT_TIMEOUT, DXGI_OUTDUPL_DESC, DXGI_OUTDUPL_FRAME_INFO,
|
||||
DXGI_RESOURCE_PRIORITY_MAXIMUM,
|
||||
IDXGIOutputDuplication, IDXGIResource, DXGI_ERROR_ACCESS_LOST, DXGI_ERROR_WAIT_TIMEOUT,
|
||||
DXGI_OUTDUPL_DESC, DXGI_OUTDUPL_FRAME_INFO, DXGI_RESOURCE_PRIORITY_MAXIMUM,
|
||||
};
|
||||
use windows::core::Interface;
|
||||
|
||||
/// DXGI Desktop Duplication capturer
|
||||
pub struct DxgiCapturer {
|
||||
@@ -56,11 +55,16 @@ impl DxgiCapturer {
|
||||
/// Create D3D device and output duplication
|
||||
fn create_duplication(
|
||||
target_display: &Display,
|
||||
) -> Result<(ID3D11Device, ID3D11DeviceContext, IDXGIOutputDuplication, DXGI_OUTDUPL_DESC)> {
|
||||
) -> Result<(
|
||||
ID3D11Device,
|
||||
ID3D11DeviceContext,
|
||||
IDXGIOutputDuplication,
|
||||
DXGI_OUTDUPL_DESC,
|
||||
)> {
|
||||
unsafe {
|
||||
// Create DXGI factory
|
||||
let factory: IDXGIFactory1 = CreateDXGIFactory1()
|
||||
.context("Failed to create DXGI factory")?;
|
||||
let factory: IDXGIFactory1 =
|
||||
CreateDXGIFactory1().context("Failed to create DXGI factory")?;
|
||||
|
||||
// Find the adapter and output for this display
|
||||
let (adapter, output) = Self::find_adapter_output(&factory, target_display)?;
|
||||
@@ -86,11 +90,13 @@ impl DxgiCapturer {
|
||||
let context = context.context("D3D11 context is None")?;
|
||||
|
||||
// Get IDXGIOutput1 interface
|
||||
let output1: IDXGIOutput1 = output.cast()
|
||||
let output1: IDXGIOutput1 = output
|
||||
.cast()
|
||||
.context("Failed to get IDXGIOutput1 interface")?;
|
||||
|
||||
// Create output duplication
|
||||
let duplication = output1.DuplicateOutput(&device)
|
||||
let duplication = output1
|
||||
.DuplicateOutput(&device)
|
||||
.context("Failed to create output duplication")?;
|
||||
|
||||
// Get duplication description
|
||||
@@ -135,7 +141,11 @@ impl DxgiCapturer {
|
||||
let desc = output.GetDesc()?;
|
||||
|
||||
let name = String::from_utf16_lossy(
|
||||
&desc.DeviceName[..desc.DeviceName.iter().position(|&c| c == 0).unwrap_or(desc.DeviceName.len())]
|
||||
&desc.DeviceName[..desc
|
||||
.DeviceName
|
||||
.iter()
|
||||
.position(|&c| c == 0)
|
||||
.unwrap_or(desc.DeviceName.len())],
|
||||
);
|
||||
|
||||
if name == display.name || desc.Monitor.0 as isize == display.handle {
|
||||
@@ -149,10 +159,8 @@ impl DxgiCapturer {
|
||||
}
|
||||
|
||||
// If we didn't find the specific display, use the first one
|
||||
let adapter: IDXGIAdapter1 = factory.EnumAdapters1(0)
|
||||
.context("No adapters found")?;
|
||||
let output: IDXGIOutput = adapter.EnumOutputs(0)
|
||||
.context("No outputs found")?;
|
||||
let adapter: IDXGIAdapter1 = factory.EnumAdapters1(0).context("No adapters found")?;
|
||||
let output: IDXGIOutput = adapter.EnumOutputs(0).context("No outputs found")?;
|
||||
|
||||
Ok((adapter, output))
|
||||
}
|
||||
@@ -171,7 +179,8 @@ impl DxgiCapturer {
|
||||
desc.MiscFlags = Default::default();
|
||||
|
||||
let mut staging: Option<ID3D11Texture2D> = None;
|
||||
self.device.CreateTexture2D(&desc, None, Some(&mut staging))
|
||||
self.device
|
||||
.CreateTexture2D(&desc, None, Some(&mut staging))
|
||||
.context("Failed to create staging texture")?;
|
||||
|
||||
let staging = staging.context("Staging texture is None")?;
|
||||
@@ -188,7 +197,10 @@ impl DxgiCapturer {
|
||||
}
|
||||
|
||||
/// Acquire the next frame from the desktop
|
||||
fn acquire_frame(&mut self, timeout_ms: u32) -> Result<Option<(ID3D11Texture2D, DXGI_OUTDUPL_FRAME_INFO)>> {
|
||||
fn acquire_frame(
|
||||
&mut self,
|
||||
timeout_ms: u32,
|
||||
) -> Result<Option<(ID3D11Texture2D, DXGI_OUTDUPL_FRAME_INFO)>> {
|
||||
unsafe {
|
||||
let mut frame_info = DXGI_OUTDUPL_FRAME_INFO::default();
|
||||
let mut desktop_resource: Option<IDXGIResource> = None;
|
||||
@@ -209,7 +221,8 @@ impl DxgiCapturer {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let texture: ID3D11Texture2D = resource.cast()
|
||||
let texture: ID3D11Texture2D = resource
|
||||
.cast()
|
||||
.context("Failed to cast to ID3D11Texture2D")?;
|
||||
|
||||
Ok(Some((texture, frame_info)))
|
||||
@@ -223,9 +236,7 @@ impl DxgiCapturer {
|
||||
tracing::warn!("Desktop duplication access lost, will need to recreate");
|
||||
Err(anyhow::anyhow!("Access lost"))
|
||||
}
|
||||
Err(e) => {
|
||||
Err(e).context("Failed to acquire frame")
|
||||
}
|
||||
Err(e) => Err(e).context("Failed to acquire frame"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,12 +7,11 @@ use super::{CapturedFrame, Capturer, Display};
|
||||
use anyhow::Result;
|
||||
use std::time::Instant;
|
||||
|
||||
use windows::Win32::Graphics::Gdi::{
|
||||
BitBlt, CreateCompatibleBitmap, CreateCompatibleDC, DeleteDC, DeleteObject,
|
||||
GetDIBits, SelectObject, BITMAPINFO, BITMAPINFOHEADER, BI_RGB, DIB_RGB_COLORS,
|
||||
SRCCOPY, GetDC, ReleaseDC,
|
||||
};
|
||||
use windows::Win32::Foundation::HWND;
|
||||
use windows::Win32::Graphics::Gdi::{
|
||||
BitBlt, CreateCompatibleBitmap, CreateCompatibleDC, DeleteDC, DeleteObject, GetDC, GetDIBits,
|
||||
ReleaseDC, SelectObject, BITMAPINFO, BITMAPINFOHEADER, BI_RGB, DIB_RGB_COLORS, SRCCOPY,
|
||||
};
|
||||
|
||||
/// GDI-based screen capturer
|
||||
pub struct GdiCapturer {
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
//! Provides DXGI Desktop Duplication for high-performance screen capture on Windows 8+,
|
||||
//! with GDI fallback for legacy systems or edge cases.
|
||||
|
||||
mod display;
|
||||
#[cfg(windows)]
|
||||
mod dxgi;
|
||||
#[cfg(windows)]
|
||||
mod gdi;
|
||||
mod display;
|
||||
|
||||
pub use display::{Display, DisplayInfo};
|
||||
|
||||
@@ -61,7 +61,11 @@ pub trait Capturer: Send {
|
||||
|
||||
/// Create a capturer for the specified display
|
||||
#[cfg(windows)]
|
||||
pub fn create_capturer(display: Display, use_dxgi: bool, gdi_fallback: bool) -> Result<Box<dyn Capturer>> {
|
||||
pub fn create_capturer(
|
||||
display: Display,
|
||||
use_dxgi: bool,
|
||||
gdi_fallback: bool,
|
||||
) -> Result<Box<dyn Capturer>> {
|
||||
if use_dxgi {
|
||||
match dxgi::DxgiCapturer::new(display.clone()) {
|
||||
Ok(capturer) => {
|
||||
@@ -83,7 +87,11 @@ pub fn create_capturer(display: Display, use_dxgi: bool, gdi_fallback: bool) ->
|
||||
}
|
||||
|
||||
#[cfg(not(windows))]
|
||||
pub fn create_capturer(_display: Display, _use_dxgi: bool, _gdi_fallback: bool) -> Result<Box<dyn Capturer>> {
|
||||
pub fn create_capturer(
|
||||
_display: Display,
|
||||
_use_dxgi: bool,
|
||||
_gdi_fallback: bool,
|
||||
) -> Result<Box<dyn Capturer>> {
|
||||
anyhow::bail!("Screen capture only supported on Windows")
|
||||
}
|
||||
|
||||
|
||||
@@ -6,10 +6,10 @@
|
||||
use std::sync::mpsc::{self, Receiver, Sender};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread;
|
||||
use tracing::{info, warn, error};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
#[cfg(windows)]
|
||||
use windows::Win32::UI::WindowsAndMessaging::*;
|
||||
use windows::core::PCWSTR;
|
||||
#[cfg(windows)]
|
||||
use windows::Win32::Foundation::*;
|
||||
#[cfg(windows)]
|
||||
@@ -17,7 +17,7 @@ use windows::Win32::Graphics::Gdi::*;
|
||||
#[cfg(windows)]
|
||||
use windows::Win32::System::LibraryLoader::GetModuleHandleW;
|
||||
#[cfg(windows)]
|
||||
use windows::core::PCWSTR;
|
||||
use windows::Win32::UI::WindowsAndMessaging::*;
|
||||
|
||||
/// A chat message
|
||||
#[derive(Debug, Clone)]
|
||||
|
||||
@@ -221,11 +221,10 @@ impl Config {
|
||||
|
||||
/// Read embedded configuration from the executable
|
||||
pub fn read_embedded_config() -> Result<EmbeddedConfig> {
|
||||
let exe_path = std::env::current_exe()
|
||||
.context("Failed to get current executable path")?;
|
||||
let exe_path = std::env::current_exe().context("Failed to get current executable path")?;
|
||||
|
||||
let mut file = std::fs::File::open(&exe_path)
|
||||
.context("Failed to open executable for reading")?;
|
||||
let mut file =
|
||||
std::fs::File::open(&exe_path).context("Failed to open executable for reading")?;
|
||||
|
||||
let file_size = file.metadata()?.len();
|
||||
if file_size < (MAGIC_MARKER.len() + 4) as u64 {
|
||||
@@ -245,7 +244,8 @@ impl Config {
|
||||
file.read_exact(&mut buffer)?;
|
||||
|
||||
// Find magic marker
|
||||
let marker_pos = buffer.windows(MAGIC_MARKER.len())
|
||||
let marker_pos = buffer
|
||||
.windows(MAGIC_MARKER.len())
|
||||
.rposition(|window| window == MAGIC_MARKER)
|
||||
.ok_or_else(|| anyhow!("Magic marker not found"))?;
|
||||
|
||||
@@ -269,11 +269,13 @@ impl Config {
|
||||
}
|
||||
|
||||
let config_bytes = &buffer[config_start..config_start + config_length];
|
||||
let config: EmbeddedConfig = serde_json::from_slice(config_bytes)
|
||||
.context("Failed to parse embedded config JSON")?;
|
||||
let config: EmbeddedConfig =
|
||||
serde_json::from_slice(config_bytes).context("Failed to parse embedded config JSON")?;
|
||||
|
||||
info!("Loaded embedded config: server={}, company={:?}",
|
||||
config.server_url, config.company);
|
||||
info!(
|
||||
"Loaded embedded config: server={}, company={:?}",
|
||||
config.server_url, config.company
|
||||
);
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
@@ -338,8 +340,8 @@ impl Config {
|
||||
let contents = std::fs::read_to_string(&config_path)
|
||||
.with_context(|| format!("Failed to read config from {:?}", config_path))?;
|
||||
|
||||
let mut config: Config = toml::from_str(&contents)
|
||||
.with_context(|| "Failed to parse config file")?;
|
||||
let mut config: Config =
|
||||
toml::from_str(&contents).with_context(|| "Failed to parse config file")?;
|
||||
|
||||
// Ensure agent_id is set and saved
|
||||
if config.agent_id.is_empty() {
|
||||
@@ -357,11 +359,11 @@ impl Config {
|
||||
let server_url = std::env::var("GURUCONNECT_SERVER_URL")
|
||||
.unwrap_or_else(|_| "wss://connect.azcomputerguru.com/ws/agent".to_string());
|
||||
|
||||
let api_key = std::env::var("GURUCONNECT_API_KEY")
|
||||
.unwrap_or_else(|_| "dev-key".to_string());
|
||||
let api_key =
|
||||
std::env::var("GURUCONNECT_API_KEY").unwrap_or_else(|_| "dev-key".to_string());
|
||||
|
||||
let agent_id = std::env::var("GURUCONNECT_AGENT_ID")
|
||||
.unwrap_or_else(|_| generate_agent_id());
|
||||
let agent_id =
|
||||
std::env::var("GURUCONNECT_AGENT_ID").unwrap_or_else(|_| generate_agent_id());
|
||||
|
||||
let config = Config {
|
||||
server_url,
|
||||
@@ -409,13 +411,11 @@ impl Config {
|
||||
|
||||
/// Get the hostname to use
|
||||
pub fn hostname(&self) -> String {
|
||||
self.hostname_override
|
||||
.clone()
|
||||
.unwrap_or_else(|| {
|
||||
hostname::get()
|
||||
.map(|h| h.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|_| "unknown".to_string())
|
||||
})
|
||||
self.hostname_override.clone().unwrap_or_else(|| {
|
||||
hostname::get()
|
||||
.map(|h| h.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|_| "unknown".to_string())
|
||||
})
|
||||
}
|
||||
|
||||
/// Save current configuration to file
|
||||
|
||||
@@ -10,7 +10,7 @@ mod raw;
|
||||
pub use raw::RawEncoder;
|
||||
|
||||
use crate::capture::CapturedFrame;
|
||||
use crate::proto::{VideoFrame, RawFrame, DirtyRect as ProtoDirtyRect};
|
||||
use crate::proto::{DirtyRect as ProtoDirtyRect, RawFrame, VideoFrame};
|
||||
use anyhow::Result;
|
||||
|
||||
/// Encoded frame ready for transmission
|
||||
|
||||
@@ -122,12 +122,7 @@ impl RawEncoder {
|
||||
}
|
||||
|
||||
/// Extract pixels for dirty rectangles only
|
||||
fn extract_dirty_pixels(
|
||||
&self,
|
||||
data: &[u8],
|
||||
width: u32,
|
||||
dirty_rects: &[DirtyRect],
|
||||
) -> Vec<u8> {
|
||||
fn extract_dirty_pixels(&self, data: &[u8], width: u32, dirty_rects: &[DirtyRect]) -> Vec<u8> {
|
||||
let stride = (width * 4) as usize;
|
||||
let mut pixels = Vec::new();
|
||||
|
||||
@@ -174,7 +169,8 @@ impl Encoder for RawEncoder {
|
||||
if dirty_rects.len() > 50 {
|
||||
(frame.data.clone(), Vec::new(), true)
|
||||
} else {
|
||||
let dirty_pixels = self.extract_dirty_pixels(&frame.data, frame.width, &dirty_rects);
|
||||
let dirty_pixels =
|
||||
self.extract_dirty_pixels(&frame.data, frame.width, &dirty_rects);
|
||||
(dirty_pixels, dirty_rects, false)
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -4,9 +4,9 @@ use anyhow::Result;
|
||||
|
||||
#[cfg(windows)]
|
||||
use windows::Win32::UI::Input::KeyboardAndMouse::{
|
||||
SendInput, INPUT, INPUT_0, INPUT_KEYBOARD, KEYBD_EVENT_FLAGS, KEYEVENTF_EXTENDEDKEY,
|
||||
KEYEVENTF_KEYUP, KEYEVENTF_SCANCODE, KEYEVENTF_UNICODE, KEYBDINPUT,
|
||||
MapVirtualKeyW, MAPVK_VK_TO_VSC_EX,
|
||||
MapVirtualKeyW, SendInput, INPUT, INPUT_0, INPUT_KEYBOARD, KEYBDINPUT, KEYBD_EVENT_FLAGS,
|
||||
KEYEVENTF_EXTENDEDKEY, KEYEVENTF_KEYUP, KEYEVENTF_SCANCODE, KEYEVENTF_UNICODE,
|
||||
MAPVK_VK_TO_VSC_EX,
|
||||
};
|
||||
|
||||
/// Keyboard input controller
|
||||
@@ -144,8 +144,8 @@ impl KeyboardController {
|
||||
tracing::info!("SAS service not available, trying direct sas.dll...");
|
||||
|
||||
// Tier 2: Try using the sas.dll directly (requires SYSTEM privileges)
|
||||
use windows::Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryW};
|
||||
use windows::core::PCWSTR;
|
||||
use windows::Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryW};
|
||||
|
||||
unsafe {
|
||||
let dll_name: Vec<u16> = "sas.dll\0".encode_utf16().collect();
|
||||
@@ -195,7 +195,7 @@ impl KeyboardController {
|
||||
0x5D | // Applications key
|
||||
0x6F | // Numpad Divide
|
||||
0x90 | // Num Lock
|
||||
0x91 // Scroll Lock
|
||||
0x91 // Scroll Lock
|
||||
)
|
||||
}
|
||||
|
||||
@@ -205,11 +205,7 @@ impl KeyboardController {
|
||||
let sent = unsafe { SendInput(inputs, std::mem::size_of::<INPUT>() as i32) };
|
||||
|
||||
if sent as usize != inputs.len() {
|
||||
anyhow::bail!(
|
||||
"SendInput failed: sent {} of {} inputs",
|
||||
sent,
|
||||
inputs.len()
|
||||
);
|
||||
anyhow::bail!("SendInput failed: sent {} of {} inputs", sent, inputs.len());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -250,7 +246,7 @@ pub mod vk {
|
||||
pub const ESCAPE: u16 = 0x1B;
|
||||
pub const SPACE: u16 = 0x20;
|
||||
pub const PRIOR: u16 = 0x21; // Page Up
|
||||
pub const NEXT: u16 = 0x22; // Page Down
|
||||
pub const NEXT: u16 = 0x22; // Page Down
|
||||
pub const END: u16 = 0x23;
|
||||
pub const HOME: u16 = 0x24;
|
||||
pub const LEFT: u16 = 0x25;
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
//!
|
||||
//! Handles mouse and keyboard input simulation using Windows SendInput API.
|
||||
|
||||
mod mouse;
|
||||
mod keyboard;
|
||||
mod mouse;
|
||||
|
||||
pub use mouse::MouseController;
|
||||
pub use keyboard::KeyboardController;
|
||||
pub use mouse::MouseController;
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
|
||||
@@ -19,8 +19,7 @@ const XBUTTON2: u32 = 0x0002;
|
||||
|
||||
#[cfg(windows)]
|
||||
use windows::Win32::UI::WindowsAndMessaging::{
|
||||
GetSystemMetrics, SM_CXVIRTUALSCREEN, SM_CYVIRTUALSCREEN, SM_XVIRTUALSCREEN,
|
||||
SM_YVIRTUALSCREEN,
|
||||
GetSystemMetrics, SM_CXVIRTUALSCREEN, SM_CYVIRTUALSCREEN, SM_XVIRTUALSCREEN, SM_YVIRTUALSCREEN,
|
||||
};
|
||||
|
||||
/// Mouse input controller
|
||||
@@ -190,9 +189,7 @@ impl MouseController {
|
||||
/// Send input events
|
||||
#[cfg(windows)]
|
||||
fn send_input(&self, inputs: &[INPUT]) -> Result<()> {
|
||||
let sent = unsafe {
|
||||
SendInput(inputs, std::mem::size_of::<INPUT>() as i32)
|
||||
};
|
||||
let sent = unsafe { SendInput(inputs, std::mem::size_of::<INPUT>() as i32) };
|
||||
|
||||
if sent as usize != inputs.len() {
|
||||
anyhow::bail!("SendInput failed: sent {} of {} inputs", sent, inputs.len());
|
||||
|
||||
@@ -6,18 +6,18 @@
|
||||
//! - UAC elevation with graceful fallback
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use tracing::{info, warn, error};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
#[cfg(windows)]
|
||||
use windows::{
|
||||
core::PCWSTR,
|
||||
Win32::Foundation::HANDLE,
|
||||
Win32::Security::{GetTokenInformation, TokenElevation, TOKEN_ELEVATION, TOKEN_QUERY},
|
||||
Win32::System::Threading::{GetCurrentProcess, OpenProcessToken},
|
||||
Win32::System::Registry::{
|
||||
RegCreateKeyExW, RegSetValueExW, RegCloseKey, HKEY, HKEY_CLASSES_ROOT,
|
||||
HKEY_CURRENT_USER, KEY_WRITE, REG_SZ, REG_OPTION_NON_VOLATILE,
|
||||
RegCloseKey, RegCreateKeyExW, RegSetValueExW, HKEY, HKEY_CLASSES_ROOT, HKEY_CURRENT_USER,
|
||||
KEY_WRITE, REG_OPTION_NON_VOLATILE, REG_SZ,
|
||||
},
|
||||
Win32::System::Threading::{GetCurrentProcess, OpenProcessToken},
|
||||
Win32::UI::Shell::ShellExecuteW,
|
||||
Win32::UI::WindowsAndMessaging::SW_SHOWNORMAL,
|
||||
};
|
||||
@@ -67,11 +67,10 @@ pub fn get_install_path(elevated: bool) -> std::path::PathBuf {
|
||||
if elevated {
|
||||
std::path::PathBuf::from(SYSTEM_INSTALL_PATH)
|
||||
} else {
|
||||
let local_app_data = std::env::var("LOCALAPPDATA")
|
||||
.unwrap_or_else(|_| {
|
||||
let home = std::env::var("USERPROFILE").unwrap_or_else(|_| ".".to_string());
|
||||
format!(r"{}\AppData\Local", home)
|
||||
});
|
||||
let local_app_data = std::env::var("LOCALAPPDATA").unwrap_or_else(|_| {
|
||||
let home = std::env::var("USERPROFILE").unwrap_or_else(|_| ".".to_string());
|
||||
format!(r"{}\AppData\Local", home)
|
||||
});
|
||||
std::path::PathBuf::from(local_app_data).join(USER_INSTALL_PATH)
|
||||
}
|
||||
}
|
||||
@@ -305,7 +304,7 @@ pub fn install(force_user_install: bool) -> Result<()> {
|
||||
#[cfg(windows)]
|
||||
pub fn is_protocol_handler_registered() -> bool {
|
||||
use windows::Win32::System::Registry::{
|
||||
RegOpenKeyExW, RegCloseKey, HKEY_CLASSES_ROOT, HKEY_CURRENT_USER, KEY_READ,
|
||||
RegCloseKey, RegOpenKeyExW, HKEY_CLASSES_ROOT, HKEY_CURRENT_USER, KEY_READ,
|
||||
};
|
||||
|
||||
unsafe {
|
||||
@@ -318,7 +317,9 @@ pub fn is_protocol_handler_registered() -> bool {
|
||||
0,
|
||||
KEY_READ,
|
||||
&mut key,
|
||||
).is_ok() {
|
||||
)
|
||||
.is_ok()
|
||||
{
|
||||
let _ = RegCloseKey(key);
|
||||
return true;
|
||||
}
|
||||
@@ -331,7 +332,9 @@ pub fn is_protocol_handler_registered() -> bool {
|
||||
0,
|
||||
KEY_READ,
|
||||
&mut key,
|
||||
).is_ok() {
|
||||
)
|
||||
.is_ok()
|
||||
{
|
||||
let _ = RegCloseKey(key);
|
||||
return true;
|
||||
}
|
||||
@@ -355,22 +358,25 @@ pub fn parse_protocol_url(url_str: &str) -> Result<(String, String, Option<Strin
|
||||
//
|
||||
// Note: In URL parsing, "view" becomes the host, SESSION_ID is the path
|
||||
|
||||
let url = url::Url::parse(url_str)
|
||||
.map_err(|e| anyhow!("Invalid URL: {}", e))?;
|
||||
let url = url::Url::parse(url_str).map_err(|e| anyhow!("Invalid URL: {}", e))?;
|
||||
|
||||
if url.scheme() != "guruconnect" {
|
||||
return Err(anyhow!("Invalid scheme: expected guruconnect://"));
|
||||
}
|
||||
|
||||
// The "action" (view/connect) is parsed as the host
|
||||
let action = url.host_str()
|
||||
let action = url
|
||||
.host_str()
|
||||
.ok_or_else(|| anyhow!("Missing action in URL"))?;
|
||||
|
||||
// The session ID is the first path segment
|
||||
let path = url.path().trim_start_matches('/');
|
||||
info!("URL path: '{}', host: '{:?}'", path, url.host_str());
|
||||
let session_id = if path.is_empty() {
|
||||
return Err(anyhow!("Invalid URL: Missing session ID (path was empty, full URL: {})", url_str));
|
||||
return Err(anyhow!(
|
||||
"Invalid URL: Missing session ID (path was empty, full URL: {})",
|
||||
url_str
|
||||
));
|
||||
} else {
|
||||
path.split('/').next().unwrap_or("").to_string()
|
||||
};
|
||||
@@ -411,7 +417,5 @@ fn to_wide(s: &str) -> Vec<u16> {
|
||||
|
||||
#[cfg(windows)]
|
||||
fn description_to_bytes(wide: &[u16]) -> Vec<u8> {
|
||||
wide.iter()
|
||||
.flat_map(|w| w.to_le_bytes())
|
||||
.collect()
|
||||
wide.iter().flat_map(|w| w.to_le_bytes()).collect()
|
||||
}
|
||||
|
||||
@@ -92,16 +92,18 @@ pub mod build_info {
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::{Parser, Subcommand};
|
||||
use tracing::{info, error, warn, Level};
|
||||
use tracing::{error, info, warn, Level};
|
||||
use tracing_subscriber::FmtSubscriber;
|
||||
|
||||
#[cfg(windows)]
|
||||
use windows::Win32::UI::WindowsAndMessaging::{MessageBoxW, MB_OK, MB_ICONINFORMATION, MB_ICONERROR};
|
||||
#[cfg(windows)]
|
||||
use windows::core::PCWSTR;
|
||||
#[cfg(windows)]
|
||||
use windows::Win32::System::Console::{AllocConsole, GetConsoleWindow};
|
||||
#[cfg(windows)]
|
||||
use windows::Win32::UI::WindowsAndMessaging::{
|
||||
MessageBoxW, MB_ICONERROR, MB_ICONINFORMATION, MB_OK,
|
||||
};
|
||||
#[cfg(windows)]
|
||||
use windows::Win32::UI::WindowsAndMessaging::{ShowWindow, SW_SHOW};
|
||||
|
||||
/// GuruConnect Remote Desktop
|
||||
@@ -140,7 +142,11 @@ enum Commands {
|
||||
session_id: String,
|
||||
|
||||
/// Server URL
|
||||
#[arg(short, long, default_value = "wss://connect.azcomputerguru.com/ws/viewer")]
|
||||
#[arg(
|
||||
short,
|
||||
long,
|
||||
default_value = "wss://connect.azcomputerguru.com/ws/viewer"
|
||||
)]
|
||||
server: String,
|
||||
|
||||
/// API key for authentication
|
||||
@@ -177,15 +183,27 @@ fn main() -> Result<()> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Initialize logging
|
||||
let level = if cli.verbose { Level::DEBUG } else { Level::INFO };
|
||||
let level = if cli.verbose {
|
||||
Level::DEBUG
|
||||
} else {
|
||||
Level::INFO
|
||||
};
|
||||
FmtSubscriber::builder()
|
||||
.with_max_level(level)
|
||||
.with_target(true)
|
||||
.with_thread_ids(true)
|
||||
.init();
|
||||
|
||||
info!("GuruConnect {} ({})", build_info::short_version(), build_info::BUILD_TARGET);
|
||||
info!("Built: {} | Commit: {}", build_info::BUILD_TIMESTAMP, build_info::GIT_COMMIT_DATE);
|
||||
info!(
|
||||
"GuruConnect {} ({})",
|
||||
build_info::short_version(),
|
||||
build_info::BUILD_TARGET
|
||||
);
|
||||
info!(
|
||||
"Built: {} | Commit: {}",
|
||||
build_info::BUILD_TIMESTAMP,
|
||||
build_info::GIT_COMMIT_DATE
|
||||
);
|
||||
|
||||
// Handle post-update cleanup
|
||||
if cli.post_update {
|
||||
@@ -194,21 +212,18 @@ fn main() -> Result<()> {
|
||||
}
|
||||
|
||||
match cli.command {
|
||||
Some(Commands::Agent { code }) => {
|
||||
run_agent_mode(code)
|
||||
}
|
||||
Some(Commands::View { session_id, server, api_key }) => {
|
||||
run_viewer_mode(&server, &session_id, &api_key)
|
||||
}
|
||||
Some(Commands::Install { user_only, elevated }) => {
|
||||
run_install(user_only || elevated)
|
||||
}
|
||||
Some(Commands::Uninstall) => {
|
||||
run_uninstall()
|
||||
}
|
||||
Some(Commands::Launch { url }) => {
|
||||
run_launch(&url)
|
||||
}
|
||||
Some(Commands::Agent { code }) => run_agent_mode(code),
|
||||
Some(Commands::View {
|
||||
session_id,
|
||||
server,
|
||||
api_key,
|
||||
}) => run_viewer_mode(&server, &session_id, &api_key),
|
||||
Some(Commands::Install {
|
||||
user_only,
|
||||
elevated,
|
||||
}) => run_install(user_only || elevated),
|
||||
Some(Commands::Uninstall) => run_uninstall(),
|
||||
Some(Commands::Launch { url }) => run_launch(&url),
|
||||
Some(Commands::VersionInfo) => {
|
||||
// Show detailed version info (allocate console on Windows for visibility)
|
||||
#[cfg(windows)]
|
||||
@@ -341,7 +356,10 @@ fn run_install(force_user_install: bool) -> Result<()> {
|
||||
|
||||
match install::install(force_user_install) {
|
||||
Ok(()) => {
|
||||
show_message_box("GuruConnect", "Installation complete!\n\nYou can now use guruconnect:// links.");
|
||||
show_message_box(
|
||||
"GuruConnect",
|
||||
"Installation complete!\n\nYou can now use guruconnect:// links.",
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -467,7 +485,11 @@ async fn run_agent(config: config::Config) -> Result<()> {
|
||||
}
|
||||
|
||||
// Create tray icon
|
||||
let tray = match tray::TrayController::new(&hostname, config.support_code.as_deref(), is_support_session) {
|
||||
let tray = match tray::TrayController::new(
|
||||
&hostname,
|
||||
config.support_code.as_deref(),
|
||||
is_support_session,
|
||||
) {
|
||||
Ok(t) => {
|
||||
info!("Tray icon created");
|
||||
Some(t)
|
||||
@@ -503,7 +525,10 @@ async fn run_agent(config: config::Config) -> Result<()> {
|
||||
t.update_status("Status: Connected");
|
||||
}
|
||||
|
||||
if let Err(e) = session.run_with_tray(tray.as_ref(), chat_ctrl.as_ref()).await {
|
||||
if let Err(e) = session
|
||||
.run_with_tray(tray.as_ref(), chat_ctrl.as_ref())
|
||||
.await
|
||||
{
|
||||
let error_msg = e.to_string();
|
||||
|
||||
if error_msg.contains("USER_EXIT") {
|
||||
@@ -515,7 +540,10 @@ async fn run_agent(config: config::Config) -> Result<()> {
|
||||
if error_msg.contains("SESSION_CANCELLED") {
|
||||
info!("Session was cancelled by technician");
|
||||
cleanup_on_exit();
|
||||
show_message_box("Support Session Ended", "The support session was cancelled.");
|
||||
show_message_box(
|
||||
"Support Session Ended",
|
||||
"The support session was cancelled.",
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -524,7 +552,10 @@ async fn run_agent(config: config::Config) -> Result<()> {
|
||||
if let Err(e) = startup::uninstall() {
|
||||
warn!("Uninstall failed: {}", e);
|
||||
}
|
||||
show_message_box("Remote Session Ended", "The session was ended by the administrator.");
|
||||
show_message_box(
|
||||
"Remote Session Ended",
|
||||
"The session was ended by the administrator.",
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -533,7 +564,10 @@ async fn run_agent(config: config::Config) -> Result<()> {
|
||||
if let Err(e) = startup::uninstall() {
|
||||
warn!("Uninstall failed: {}", e);
|
||||
}
|
||||
show_message_box("GuruConnect Removed", "This computer has been removed from remote management.");
|
||||
show_message_box(
|
||||
"GuruConnect Removed",
|
||||
"This computer has been removed from remote management.",
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -551,7 +585,10 @@ async fn run_agent(config: config::Config) -> Result<()> {
|
||||
if error_msg.contains("cancelled") {
|
||||
info!("Support code was cancelled");
|
||||
cleanup_on_exit();
|
||||
show_message_box("Support Session Cancelled", "This support session has been cancelled.");
|
||||
show_message_box(
|
||||
"Support Session Cancelled",
|
||||
"This support session has been cancelled.",
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
||||
@@ -18,11 +18,7 @@ pub fn request_sas() -> Result<()> {
|
||||
info!("Requesting SAS via service pipe...");
|
||||
|
||||
// Try to connect to the pipe
|
||||
let mut pipe = match OpenOptions::new()
|
||||
.read(true)
|
||||
.write(true)
|
||||
.open(PIPE_NAME)
|
||||
{
|
||||
let mut pipe = match OpenOptions::new().read(true).write(true).open(PIPE_NAME) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to SAS service pipe: {}", e);
|
||||
@@ -40,7 +36,8 @@ pub fn request_sas() -> Result<()> {
|
||||
|
||||
// Read the response
|
||||
let mut response = [0u8; 64];
|
||||
let n = pipe.read(&mut response)
|
||||
let n = pipe
|
||||
.read(&mut response)
|
||||
.context("Failed to read response from SAS service")?;
|
||||
|
||||
let response_str = String::from_utf8_lossy(&response[..n]);
|
||||
@@ -59,7 +56,10 @@ pub fn request_sas() -> Result<()> {
|
||||
}
|
||||
_ => {
|
||||
error!("Unexpected response from SAS service: {}", response_str);
|
||||
Err(anyhow::anyhow!("Unexpected SAS service response: {}", response_str))
|
||||
Err(anyhow::anyhow!(
|
||||
"Unexpected SAS service response: {}",
|
||||
response_str
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -67,11 +67,7 @@ pub fn request_sas() -> Result<()> {
|
||||
/// Check if the SAS service is available
|
||||
pub fn is_service_available() -> bool {
|
||||
// Try to open the pipe
|
||||
if let Ok(mut pipe) = OpenOptions::new()
|
||||
.read(true)
|
||||
.write(true)
|
||||
.open(PIPE_NAME)
|
||||
{
|
||||
if let Ok(mut pipe) = OpenOptions::new().read(true).write(true).open(PIPE_NAME) {
|
||||
// Send a ping command
|
||||
if pipe.write_all(b"ping\n").is_ok() {
|
||||
let mut response = [0u8; 64];
|
||||
|
||||
@@ -37,9 +37,9 @@ fn show_debug_console() {
|
||||
// No-op on non-Windows platforms
|
||||
}
|
||||
|
||||
use crate::proto::{Message, message, ChatMessage, AgentStatus, Heartbeat, HeartbeatAck};
|
||||
use crate::proto::{message, AgentStatus, ChatMessage, Heartbeat, HeartbeatAck, Message};
|
||||
use crate::transport::WebSocketTransport;
|
||||
use crate::tray::{TrayController, TrayAction};
|
||||
use crate::tray::{TrayAction, TrayController};
|
||||
use anyhow::Result;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
@@ -71,8 +71,8 @@ pub struct SessionManager {
|
||||
enum SessionState {
|
||||
Disconnected,
|
||||
Connecting,
|
||||
Idle, // Connected but not streaming - minimal resource usage
|
||||
Streaming, // Actively capturing and sending frames
|
||||
Idle, // Connected but not streaming - minimal resource usage
|
||||
Streaming, // Actively capturing and sending frames
|
||||
}
|
||||
|
||||
impl SessionManager {
|
||||
@@ -103,10 +103,11 @@ impl SessionManager {
|
||||
&self.config.api_key,
|
||||
Some(&self.hostname),
|
||||
self.config.support_code.as_deref(),
|
||||
).await?;
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.transport = Some(transport);
|
||||
self.state = SessionState::Idle; // Start in idle mode
|
||||
self.state = SessionState::Idle; // Start in idle mode
|
||||
|
||||
tracing::info!("Connected to server, entering idle mode");
|
||||
|
||||
@@ -120,8 +121,12 @@ impl SessionManager {
|
||||
}
|
||||
|
||||
tracing::info!("Initializing streaming resources...");
|
||||
tracing::info!("Capture config: use_dxgi={}, gdi_fallback={}, fps={}",
|
||||
self.config.capture.use_dxgi, self.config.capture.gdi_fallback, self.config.capture.fps);
|
||||
tracing::info!(
|
||||
"Capture config: use_dxgi={}, gdi_fallback={}, fps={}",
|
||||
self.config.capture.use_dxgi,
|
||||
self.config.capture.gdi_fallback,
|
||||
self.config.capture.fps
|
||||
);
|
||||
|
||||
// Get primary display with panic protection
|
||||
tracing::debug!("Enumerating displays...");
|
||||
@@ -132,12 +137,19 @@ impl SessionManager {
|
||||
return Err(anyhow::anyhow!("Display enumeration panicked"));
|
||||
}
|
||||
};
|
||||
tracing::info!("Using display: {} ({}x{})",
|
||||
primary_display.name, primary_display.width, primary_display.height);
|
||||
tracing::info!(
|
||||
"Using display: {} ({}x{})",
|
||||
primary_display.name,
|
||||
primary_display.width,
|
||||
primary_display.height
|
||||
);
|
||||
|
||||
// Create capturer with panic protection
|
||||
// Force GDI mode if DXGI fails or panics
|
||||
tracing::debug!("Creating capturer (DXGI={})...", self.config.capture.use_dxgi);
|
||||
tracing::debug!(
|
||||
"Creating capturer (DXGI={})...",
|
||||
self.config.capture.use_dxgi
|
||||
);
|
||||
let capturer = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
capture::create_capturer(
|
||||
primary_display.clone(),
|
||||
@@ -157,13 +169,13 @@ impl SessionManager {
|
||||
tracing::info!("Capturer created successfully");
|
||||
|
||||
// Create encoder with panic protection
|
||||
tracing::debug!("Creating encoder (codec={}, quality={})...",
|
||||
self.config.encoding.codec, self.config.encoding.quality);
|
||||
tracing::debug!(
|
||||
"Creating encoder (codec={}, quality={})...",
|
||||
self.config.encoding.codec,
|
||||
self.config.encoding.quality
|
||||
);
|
||||
let encoder = match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
encoder::create_encoder(
|
||||
&self.config.encoding.codec,
|
||||
self.config.encoding.quality,
|
||||
)
|
||||
encoder::create_encoder(&self.config.encoding.codec, self.config.encoding.quality)
|
||||
})) {
|
||||
Ok(result) => result?,
|
||||
Err(e) => {
|
||||
@@ -202,7 +214,9 @@ impl SessionManager {
|
||||
|
||||
/// Get display count for status reports
|
||||
fn get_display_count(&self) -> i32 {
|
||||
capture::enumerate_displays().map(|d| d.len() as i32).unwrap_or(1)
|
||||
capture::enumerate_displays()
|
||||
.map(|d| d.len() as i32)
|
||||
.unwrap_or(1)
|
||||
}
|
||||
|
||||
/// Send agent status to server
|
||||
@@ -249,7 +263,11 @@ impl SessionManager {
|
||||
}
|
||||
|
||||
/// Run the session main loop with tray and chat event processing
|
||||
pub async fn run_with_tray(&mut self, tray: Option<&TrayController>, chat: Option<&ChatController>) -> Result<()> {
|
||||
pub async fn run_with_tray(
|
||||
&mut self,
|
||||
tray: Option<&TrayController>,
|
||||
chat: Option<&ChatController>,
|
||||
) -> Result<()> {
|
||||
if self.transport.is_none() {
|
||||
anyhow::bail!("Not connected");
|
||||
}
|
||||
@@ -395,12 +413,23 @@ impl SessionManager {
|
||||
}
|
||||
|
||||
// Periodic update check (only for persistent agents, not support sessions)
|
||||
if self.config.support_code.is_none() && last_update_check.elapsed() >= UPDATE_CHECK_INTERVAL {
|
||||
if self.config.support_code.is_none()
|
||||
&& last_update_check.elapsed() >= UPDATE_CHECK_INTERVAL
|
||||
{
|
||||
last_update_check = Instant::now();
|
||||
let server_url = self.config.server_url.replace("/ws/agent", "").replace("wss://", "https://").replace("ws://", "http://");
|
||||
let server_url = self
|
||||
.config
|
||||
.server_url
|
||||
.replace("/ws/agent", "")
|
||||
.replace("wss://", "https://")
|
||||
.replace("ws://", "http://");
|
||||
match crate::update::check_for_update(&server_url).await {
|
||||
Ok(Some(version_info)) => {
|
||||
tracing::info!("Update available: {} -> {}", crate::build_info::VERSION, version_info.latest_version);
|
||||
tracing::info!(
|
||||
"Update available: {} -> {}",
|
||||
crate::build_info::VERSION,
|
||||
version_info.latest_version
|
||||
);
|
||||
if let Err(e) = crate::update::perform_update(&version_info).await {
|
||||
tracing::error!("Auto-update failed: {}", e);
|
||||
}
|
||||
@@ -429,7 +458,9 @@ impl SessionManager {
|
||||
if let Ok(encoded) = encoder.encode(&frame) {
|
||||
if encoded.size > 0 {
|
||||
let msg = Message {
|
||||
payload: Some(message::Payload::VideoFrame(encoded.frame)),
|
||||
payload: Some(message::Payload::VideoFrame(
|
||||
encoded.frame,
|
||||
)),
|
||||
};
|
||||
let transport = self.transport.as_mut().unwrap();
|
||||
if let Err(e) = transport.send(msg).await {
|
||||
@@ -472,26 +503,40 @@ impl SessionManager {
|
||||
match msg.payload {
|
||||
Some(message::Payload::MouseEvent(mouse)) => {
|
||||
if let Some(input) = self.input.as_mut() {
|
||||
use crate::proto::MouseEventType;
|
||||
use crate::input::MouseButton;
|
||||
use crate::proto::MouseEventType;
|
||||
|
||||
match MouseEventType::try_from(mouse.event_type).unwrap_or(MouseEventType::MouseMove) {
|
||||
match MouseEventType::try_from(mouse.event_type)
|
||||
.unwrap_or(MouseEventType::MouseMove)
|
||||
{
|
||||
MouseEventType::MouseMove => {
|
||||
input.mouse_move(mouse.x, mouse.y)?;
|
||||
}
|
||||
MouseEventType::MouseDown => {
|
||||
input.mouse_move(mouse.x, mouse.y)?;
|
||||
if let Some(ref buttons) = mouse.buttons {
|
||||
if buttons.left { input.mouse_click(MouseButton::Left, true)?; }
|
||||
if buttons.right { input.mouse_click(MouseButton::Right, true)?; }
|
||||
if buttons.middle { input.mouse_click(MouseButton::Middle, true)?; }
|
||||
if buttons.left {
|
||||
input.mouse_click(MouseButton::Left, true)?;
|
||||
}
|
||||
if buttons.right {
|
||||
input.mouse_click(MouseButton::Right, true)?;
|
||||
}
|
||||
if buttons.middle {
|
||||
input.mouse_click(MouseButton::Middle, true)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
MouseEventType::MouseUp => {
|
||||
if let Some(ref buttons) = mouse.buttons {
|
||||
if buttons.left { input.mouse_click(MouseButton::Left, false)?; }
|
||||
if buttons.right { input.mouse_click(MouseButton::Right, false)?; }
|
||||
if buttons.middle { input.mouse_click(MouseButton::Middle, false)?; }
|
||||
if buttons.left {
|
||||
input.mouse_click(MouseButton::Left, false)?;
|
||||
}
|
||||
if buttons.right {
|
||||
input.mouse_click(MouseButton::Right, false)?;
|
||||
}
|
||||
if buttons.middle {
|
||||
input.mouse_click(MouseButton::Middle, false)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
MouseEventType::MouseWheel => {
|
||||
@@ -538,10 +583,19 @@ impl SessionManager {
|
||||
tracing::info!("Update command received from server: {}", cmd.reason);
|
||||
// Trigger update check and perform update if available
|
||||
// The server URL is derived from the config
|
||||
let server_url = self.config.server_url.replace("/ws/agent", "").replace("wss://", "https://").replace("ws://", "http://");
|
||||
let server_url = self
|
||||
.config
|
||||
.server_url
|
||||
.replace("/ws/agent", "")
|
||||
.replace("wss://", "https://")
|
||||
.replace("ws://", "http://");
|
||||
match crate::update::check_for_update(&server_url).await {
|
||||
Ok(Some(version_info)) => {
|
||||
tracing::info!("Update available: {} -> {}", crate::build_info::VERSION, version_info.latest_version);
|
||||
tracing::info!(
|
||||
"Update available: {} -> {}",
|
||||
crate::build_info::VERSION,
|
||||
version_info.latest_version
|
||||
);
|
||||
if let Err(e) = crate::update::perform_update(&version_info).await {
|
||||
tracing::error!("Update failed: {}", e);
|
||||
}
|
||||
|
||||
@@ -3,15 +3,15 @@
|
||||
//! Handles adding/removing the agent from Windows startup.
|
||||
|
||||
use anyhow::Result;
|
||||
use tracing::{info, warn, error};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
#[cfg(windows)]
|
||||
use windows::Win32::System::Registry::{
|
||||
RegOpenKeyExW, RegSetValueExW, RegDeleteValueW, RegCloseKey,
|
||||
HKEY_CURRENT_USER, KEY_WRITE, REG_SZ,
|
||||
};
|
||||
#[cfg(windows)]
|
||||
use windows::core::PCWSTR;
|
||||
#[cfg(windows)]
|
||||
use windows::Win32::System::Registry::{
|
||||
RegCloseKey, RegDeleteValueW, RegOpenKeyExW, RegSetValueExW, HKEY_CURRENT_USER, KEY_WRITE,
|
||||
REG_SZ,
|
||||
};
|
||||
|
||||
const STARTUP_KEY: &str = r"Software\Microsoft\Windows\CurrentVersion\Run";
|
||||
const STARTUP_VALUE_NAME: &str = "GuruConnect";
|
||||
@@ -61,10 +61,8 @@ pub fn add_to_startup() -> Result<()> {
|
||||
let hkey_raw = std::mem::transmute::<_, windows::Win32::System::Registry::HKEY>(hkey);
|
||||
|
||||
// Set the value
|
||||
let data_bytes = std::slice::from_raw_parts(
|
||||
value_data.as_ptr() as *const u8,
|
||||
value_data.len() * 2,
|
||||
);
|
||||
let data_bytes =
|
||||
std::slice::from_raw_parts(value_data.as_ptr() as *const u8, value_data.len() * 2);
|
||||
|
||||
let set_result = RegSetValueExW(
|
||||
hkey_raw,
|
||||
@@ -168,7 +166,10 @@ pub fn uninstall() -> Result<()> {
|
||||
);
|
||||
|
||||
if result.is_err() {
|
||||
warn!("Failed to schedule file deletion: {:?}. File may need manual removal.", result);
|
||||
warn!(
|
||||
"Failed to schedule file deletion: {:?}. File may need manual removal.",
|
||||
result
|
||||
);
|
||||
} else {
|
||||
info!("Executable scheduled for deletion on reboot");
|
||||
}
|
||||
@@ -185,12 +186,15 @@ pub fn install_sas_service() -> Result<()> {
|
||||
|
||||
// Check if the SAS service binary exists alongside the agent
|
||||
let exe_path = std::env::current_exe()?;
|
||||
let exe_dir = exe_path.parent().ok_or_else(|| anyhow::anyhow!("No parent directory"))?;
|
||||
let exe_dir = exe_path
|
||||
.parent()
|
||||
.ok_or_else(|| anyhow::anyhow!("No parent directory"))?;
|
||||
let sas_binary = exe_dir.join("guruconnect-sas-service.exe");
|
||||
|
||||
if !sas_binary.exists() {
|
||||
// Also check in Program Files
|
||||
let program_files = std::path::PathBuf::from(r"C:\Program Files\GuruConnect\guruconnect-sas-service.exe");
|
||||
let program_files =
|
||||
std::path::PathBuf::from(r"C:\Program Files\GuruConnect\guruconnect-sas-service.exe");
|
||||
if !program_files.exists() {
|
||||
warn!("SAS service binary not found");
|
||||
return Ok(());
|
||||
@@ -232,16 +236,18 @@ pub fn uninstall_sas_service() -> Result<()> {
|
||||
|
||||
// Try to find and run the uninstall command
|
||||
let paths = [
|
||||
std::env::current_exe().ok().and_then(|p| p.parent().map(|d| d.join("guruconnect-sas-service.exe"))),
|
||||
Some(std::path::PathBuf::from(r"C:\Program Files\GuruConnect\guruconnect-sas-service.exe")),
|
||||
std::env::current_exe()
|
||||
.ok()
|
||||
.and_then(|p| p.parent().map(|d| d.join("guruconnect-sas-service.exe"))),
|
||||
Some(std::path::PathBuf::from(
|
||||
r"C:\Program Files\GuruConnect\guruconnect-sas-service.exe",
|
||||
)),
|
||||
];
|
||||
|
||||
for path_opt in paths.iter() {
|
||||
if let Some(ref path) = path_opt {
|
||||
if path.exists() {
|
||||
let output = std::process::Command::new(path)
|
||||
.arg("uninstall")
|
||||
.output();
|
||||
let output = std::process::Command::new(path).arg("uninstall").output();
|
||||
|
||||
if let Ok(result) = output {
|
||||
if result.status.success() {
|
||||
|
||||
@@ -103,11 +103,7 @@ impl WebSocketTransport {
|
||||
let mut stream = stream.lock().await;
|
||||
|
||||
// Use try_next for non-blocking receive
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(1),
|
||||
stream.next(),
|
||||
)
|
||||
.await
|
||||
match tokio::time::timeout(std::time::Duration::from_millis(1), stream.next()).await
|
||||
{
|
||||
Ok(Some(Ok(ws_msg))) => Ok(Some(ws_msg)),
|
||||
Ok(Some(Err(e))) => Err(anyhow::anyhow!("WebSocket error: {}", e)),
|
||||
|
||||
@@ -9,12 +9,12 @@ use anyhow::Result;
|
||||
use muda::{Menu, MenuEvent, MenuItem, PredefinedMenuItem, Submenu};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tray_icon::{Icon, TrayIcon, TrayIconBuilder, TrayIconEvent};
|
||||
use tracing::{info, warn};
|
||||
use tray_icon::{Icon, TrayIcon, TrayIconBuilder, TrayIconEvent};
|
||||
|
||||
#[cfg(windows)]
|
||||
use windows::Win32::UI::WindowsAndMessaging::{
|
||||
PeekMessageW, TranslateMessage, DispatchMessageW, MSG, PM_REMOVE,
|
||||
DispatchMessageW, PeekMessageW, TranslateMessage, MSG, PM_REMOVE,
|
||||
};
|
||||
|
||||
/// Events that can be triggered from the tray menu
|
||||
@@ -38,7 +38,11 @@ pub struct TrayController {
|
||||
impl TrayController {
|
||||
/// Create a new tray controller
|
||||
/// `allow_end_session` - If true, show "End Session" menu item (only for support sessions)
|
||||
pub fn new(machine_name: &str, support_code: Option<&str>, allow_end_session: bool) -> Result<Self> {
|
||||
pub fn new(
|
||||
machine_name: &str,
|
||||
support_code: Option<&str>,
|
||||
allow_end_session: bool,
|
||||
) -> Result<Self> {
|
||||
// Create menu items
|
||||
let status_text = if let Some(code) = support_code {
|
||||
format!("Support Session: {}", code)
|
||||
@@ -166,9 +170,9 @@ fn create_default_icon() -> Result<Icon> {
|
||||
|
||||
if dist <= radius {
|
||||
// Green circle
|
||||
rgba[idx] = 76; // R
|
||||
rgba[idx] = 76; // R
|
||||
rgba[idx + 1] = 175; // G
|
||||
rgba[idx + 2] = 80; // B
|
||||
rgba[idx + 2] = 80; // B
|
||||
rgba[idx + 3] = 255; // A
|
||||
} else if dist <= radius + 1.0 {
|
||||
// Anti-aliased edge
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
//! in-place binary replacement with restart.
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use sha2::{Sha256, Digest};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::path::PathBuf;
|
||||
use tracing::{info, warn, error};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::build_info;
|
||||
|
||||
@@ -38,7 +38,7 @@ pub async fn check_for_update(server_base_url: &str) -> Result<Option<VersionInf
|
||||
info!("Checking for updates at {}", url);
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.danger_accept_invalid_certs(true) // For self-signed certs in dev
|
||||
.danger_accept_invalid_certs(true) // For self-signed certs in dev
|
||||
.build()?;
|
||||
|
||||
let response = client
|
||||
@@ -79,11 +79,8 @@ fn is_newer_version(available: &str, current: &str) -> bool {
|
||||
let available_clean = available.split('-').next().unwrap_or(available);
|
||||
let current_clean = current.split('-').next().unwrap_or(current);
|
||||
|
||||
let parse_version = |s: &str| -> Vec<u32> {
|
||||
s.split('.')
|
||||
.filter_map(|p| p.parse().ok())
|
||||
.collect()
|
||||
};
|
||||
let parse_version =
|
||||
|s: &str| -> Vec<u32> { s.split('.').filter_map(|p| p.parse().ok()).collect() };
|
||||
|
||||
let av = parse_version(available_clean);
|
||||
let cv = parse_version(current_clean);
|
||||
@@ -112,7 +109,7 @@ pub async fn download_update(version_info: &VersionInfo) -> Result<PathBuf> {
|
||||
|
||||
let response = client
|
||||
.get(&version_info.download_url)
|
||||
.timeout(std::time::Duration::from_secs(300)) // 5 minutes for large files
|
||||
.timeout(std::time::Duration::from_secs(300)) // 5 minutes for large files
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
@@ -147,7 +144,10 @@ pub fn verify_checksum(file_path: &PathBuf, expected_sha256: &str) -> Result<boo
|
||||
if matches {
|
||||
info!("Checksum verified: {}", computed);
|
||||
} else {
|
||||
error!("Checksum mismatch! Expected: {}, Got: {}", expected_sha256, computed);
|
||||
error!(
|
||||
"Checksum mismatch! Expected: {}, Got: {}",
|
||||
expected_sha256, computed
|
||||
);
|
||||
}
|
||||
|
||||
Ok(matches)
|
||||
@@ -160,7 +160,8 @@ pub fn install_update(temp_path: &PathBuf) -> Result<PathBuf> {
|
||||
|
||||
// Get current executable path
|
||||
let current_exe = std::env::current_exe()?;
|
||||
let exe_dir = current_exe.parent()
|
||||
let exe_dir = current_exe
|
||||
.parent()
|
||||
.ok_or_else(|| anyhow!("Cannot get executable directory"))?;
|
||||
|
||||
// Create paths for backup and new executable
|
||||
@@ -257,10 +258,11 @@ pub fn cleanup_post_update() {
|
||||
#[cfg(windows)]
|
||||
fn schedule_delete_on_reboot(path: &PathBuf) {
|
||||
use std::os::windows::ffi::OsStrExt;
|
||||
use windows::Win32::Storage::FileSystem::{MoveFileExW, MOVEFILE_DELAY_UNTIL_REBOOT};
|
||||
use windows::core::PCWSTR;
|
||||
use windows::Win32::Storage::FileSystem::{MoveFileExW, MOVEFILE_DELAY_UNTIL_REBOOT};
|
||||
|
||||
let path_wide: Vec<u16> = path.as_os_str()
|
||||
let path_wide: Vec<u16> = path
|
||||
.as_os_str()
|
||||
.encode_wide()
|
||||
.chain(std::iter::once(0))
|
||||
.collect();
|
||||
|
||||
@@ -37,11 +37,11 @@ mod vk {
|
||||
pub const VK_RSHIFT: u32 = 0xA1;
|
||||
pub const VK_LCONTROL: u32 = 0xA2;
|
||||
pub const VK_RCONTROL: u32 = 0xA3;
|
||||
pub const VK_LMENU: u32 = 0xA4; // Left Alt
|
||||
pub const VK_RMENU: u32 = 0xA5; // Right Alt
|
||||
pub const VK_LMENU: u32 = 0xA4; // Left Alt
|
||||
pub const VK_RMENU: u32 = 0xA5; // Right Alt
|
||||
pub const VK_TAB: u32 = 0x09;
|
||||
pub const VK_ESCAPE: u32 = 0x1B;
|
||||
pub const VK_SNAPSHOT: u32 = 0x2C; // Print Screen
|
||||
pub const VK_SNAPSHOT: u32 = 0x2C; // Print Screen
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
@@ -53,15 +53,12 @@ pub struct KeyboardHook {
|
||||
impl KeyboardHook {
|
||||
pub fn new(input_tx: mpsc::Sender<InputEvent>) -> Result<Self> {
|
||||
// Store the sender globally for the hook callback
|
||||
INPUT_TX.set(input_tx).map_err(|_| anyhow::anyhow!("Input TX already set"))?;
|
||||
INPUT_TX
|
||||
.set(input_tx)
|
||||
.map_err(|_| anyhow::anyhow!("Input TX already set"))?;
|
||||
|
||||
unsafe {
|
||||
let hook = SetWindowsHookExW(
|
||||
WH_KEYBOARD_LL,
|
||||
Some(keyboard_hook_proc),
|
||||
None,
|
||||
0,
|
||||
)?;
|
||||
let hook = SetWindowsHookExW(WH_KEYBOARD_LL, Some(keyboard_hook_proc), None, 0)?;
|
||||
|
||||
HOOK_HANDLE = hook;
|
||||
Ok(Self { _hook: hook })
|
||||
@@ -82,11 +79,7 @@ impl Drop for KeyboardHook {
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
unsafe extern "system" fn keyboard_hook_proc(
|
||||
code: i32,
|
||||
wparam: WPARAM,
|
||||
lparam: LPARAM,
|
||||
) -> LRESULT {
|
||||
unsafe extern "system" fn keyboard_hook_proc(code: i32, wparam: WPARAM, lparam: LPARAM) -> LRESULT {
|
||||
if code >= 0 {
|
||||
let kb_struct = &*(lparam.0 as *const KBDLLHOOKSTRUCT);
|
||||
let vk_code = kb_struct.vkCode;
|
||||
@@ -97,10 +90,7 @@ unsafe extern "system" fn keyboard_hook_proc(
|
||||
|
||||
if is_down || is_up {
|
||||
// Check if this is a key we want to intercept (Win key, Alt+Tab, etc.)
|
||||
let should_intercept = matches!(
|
||||
vk_code,
|
||||
vk::VK_LWIN | vk::VK_RWIN | vk::VK_APPS
|
||||
);
|
||||
let should_intercept = matches!(vk_code, vk::VK_LWIN | vk::VK_RWIN | vk::VK_APPS);
|
||||
|
||||
// Send the key event to the remote
|
||||
if let Some(tx) = INPUT_TX.get() {
|
||||
@@ -114,7 +104,12 @@ unsafe extern "system" fn keyboard_hook_proc(
|
||||
};
|
||||
|
||||
let _ = tx.try_send(InputEvent::Key(event));
|
||||
trace!("Key hook: vk={:#x} scan={} down={}", vk_code, scan_code, is_down);
|
||||
trace!(
|
||||
"Key hook: vk={:#x} scan={} down={}",
|
||||
vk_code,
|
||||
scan_code,
|
||||
is_down
|
||||
);
|
||||
}
|
||||
|
||||
// For Win key, consume the event so it doesn't open Start menu locally
|
||||
@@ -133,12 +128,12 @@ fn get_current_modifiers() -> proto::Modifiers {
|
||||
|
||||
unsafe {
|
||||
proto::Modifiers {
|
||||
ctrl: GetAsyncKeyState(0x11) < 0, // VK_CONTROL
|
||||
alt: GetAsyncKeyState(0x12) < 0, // VK_MENU
|
||||
shift: GetAsyncKeyState(0x10) < 0, // VK_SHIFT
|
||||
ctrl: GetAsyncKeyState(0x11) < 0, // VK_CONTROL
|
||||
alt: GetAsyncKeyState(0x12) < 0, // VK_MENU
|
||||
shift: GetAsyncKeyState(0x10) < 0, // VK_SHIFT
|
||||
meta: GetAsyncKeyState(0x5B) < 0 || GetAsyncKeyState(0x5C) < 0, // VK_LWIN/RWIN
|
||||
caps_lock: GetAsyncKeyState(0x14) & 1 != 0, // VK_CAPITAL
|
||||
num_lock: GetAsyncKeyState(0x90) & 1 != 0, // VK_NUMLOCK
|
||||
num_lock: GetAsyncKeyState(0x90) & 1 != 0, // VK_NUMLOCK
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::proto;
|
||||
use anyhow::Result;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use tracing::{info, error, warn};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ViewerEvent {
|
||||
@@ -93,16 +93,18 @@ pub async fn run(server_url: &str, session_id: &str, api_key: &str) -> Result<()
|
||||
}
|
||||
}
|
||||
Some(proto::message::Payload::CursorPosition(pos)) => {
|
||||
let _ = viewer_tx_recv.send(ViewerEvent::CursorPosition(
|
||||
pos.x, pos.y, pos.visible
|
||||
)).await;
|
||||
let _ = viewer_tx_recv
|
||||
.send(ViewerEvent::CursorPosition(pos.x, pos.y, pos.visible))
|
||||
.await;
|
||||
}
|
||||
Some(proto::message::Payload::CursorShape(shape)) => {
|
||||
let _ = viewer_tx_recv.send(ViewerEvent::CursorShape(shape)).await;
|
||||
}
|
||||
Some(proto::message::Payload::Disconnect(d)) => {
|
||||
warn!("Server disconnected: {}", d.reason);
|
||||
let _ = viewer_tx_recv.send(ViewerEvent::Disconnected(d.reason)).await;
|
||||
let _ = viewer_tx_recv
|
||||
.send(ViewerEvent::Disconnected(d.reason))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
//! Window rendering and frame display
|
||||
|
||||
use super::{ViewerEvent, InputEvent};
|
||||
use crate::proto;
|
||||
#[cfg(windows)]
|
||||
use super::input;
|
||||
use super::{InputEvent, ViewerEvent};
|
||||
use crate::proto;
|
||||
use anyhow::Result;
|
||||
use std::num::NonZeroU32;
|
||||
use std::sync::Arc;
|
||||
@@ -43,10 +43,7 @@ struct ViewerApp {
|
||||
}
|
||||
|
||||
impl ViewerApp {
|
||||
fn new(
|
||||
viewer_rx: mpsc::Receiver<ViewerEvent>,
|
||||
input_tx: mpsc::Sender<InputEvent>,
|
||||
) -> Self {
|
||||
fn new(viewer_rx: mpsc::Receiver<ViewerEvent>, input_tx: mpsc::Sender<InputEvent>) -> Self {
|
||||
Self {
|
||||
window: None,
|
||||
surface: None,
|
||||
@@ -112,7 +109,9 @@ impl ViewerApp {
|
||||
}
|
||||
|
||||
fn render(&mut self) {
|
||||
let Some(surface) = &mut self.surface else { return };
|
||||
let Some(surface) = &mut self.surface else {
|
||||
return;
|
||||
};
|
||||
let Some(window) = &self.window else { return };
|
||||
|
||||
if self.frame_buffer.is_empty() || self.frame_width == 0 || self.frame_height == 0 {
|
||||
|
||||
@@ -6,23 +6,17 @@ use bytes::Bytes;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use prost::Message as ProstMessage;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_tungstenite::{
|
||||
connect_async,
|
||||
tungstenite::protocol::Message as WsMessage,
|
||||
MaybeTlsStream, WebSocketStream,
|
||||
connect_async, tungstenite::protocol::Message as WsMessage, MaybeTlsStream, WebSocketStream,
|
||||
};
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::{debug, error, trace};
|
||||
|
||||
pub type WsSender = futures_util::stream::SplitSink<
|
||||
WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
WsMessage,
|
||||
>;
|
||||
pub type WsSender =
|
||||
futures_util::stream::SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, WsMessage>;
|
||||
|
||||
pub type WsReceiver = futures_util::stream::SplitStream<
|
||||
WebSocketStream<MaybeTlsStream<TcpStream>>,
|
||||
>;
|
||||
pub type WsReceiver = futures_util::stream::SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
|
||||
|
||||
/// Receiver wrapper that parses protobuf messages
|
||||
pub struct MessageReceiver {
|
||||
@@ -88,10 +82,7 @@ pub async fn connect(url: &str, token: &str) -> Result<(WsSender, MessageReceive
|
||||
}
|
||||
|
||||
/// Send a protobuf message over the WebSocket
|
||||
pub async fn send_message(
|
||||
sender: &Arc<Mutex<WsSender>>,
|
||||
msg: &proto::Message,
|
||||
) -> Result<()> {
|
||||
pub async fn send_message(sender: &Arc<Mutex<WsSender>>, msg: &proto::Message) -> Result<()> {
|
||||
let mut buf = Vec::with_capacity(msg.encoded_len());
|
||||
msg.encode(&mut buf)?;
|
||||
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
//! Authentication API endpoints
|
||||
|
||||
use axum::{
|
||||
extract::{State, Request},
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
Json,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::auth::{
|
||||
verify_password, AuthenticatedUser, JwtConfig,
|
||||
};
|
||||
use crate::auth::{verify_password, AuthenticatedUser, JwtConfig};
|
||||
use crate::db;
|
||||
use crate::AppState;
|
||||
|
||||
@@ -89,16 +87,15 @@ pub async fn login(
|
||||
}
|
||||
|
||||
// Verify password
|
||||
let password_valid = verify_password(&request.password, &user.password_hash)
|
||||
.map_err(|e| {
|
||||
tracing::error!("Password verification error: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Internal server error".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
let password_valid = verify_password(&request.password, &user.password_hash).map_err(|e| {
|
||||
tracing::error!("Password verification error: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Internal server error".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
if !password_valid {
|
||||
return Err((
|
||||
@@ -118,21 +115,18 @@ pub async fn login(
|
||||
let _ = db::update_last_login(db.pool(), user.id).await;
|
||||
|
||||
// Create JWT token
|
||||
let token = state.jwt_config.create_token(
|
||||
user.id,
|
||||
&user.username,
|
||||
&user.role,
|
||||
permissions.clone(),
|
||||
)
|
||||
.map_err(|e| {
|
||||
tracing::error!("Token creation error: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Failed to create token".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
let token = state
|
||||
.jwt_config
|
||||
.create_token(user.id, &user.username, &user.role, permissions.clone())
|
||||
.map_err(|e| {
|
||||
tracing::error!("Token creation error: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Failed to create token".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
tracing::info!("User {} logged in successfully", user.username);
|
||||
|
||||
@@ -288,16 +282,15 @@ pub async fn change_password(
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
let new_hash = crate::auth::hash_password(&request.new_password)
|
||||
.map_err(|e| {
|
||||
tracing::error!("Password hashing error: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Failed to hash password".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
let new_hash = crate::auth::hash_password(&request.new_password).map_err(|e| {
|
||||
tracing::error!("Password hashing error: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Failed to hash password".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Update password
|
||||
db::update_user_password(db.pool(), user_id, &new_hash)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
//! Logout and token revocation endpoints
|
||||
|
||||
use axum::{
|
||||
extract::{Request, State, Path},
|
||||
http::{StatusCode, HeaderMap},
|
||||
extract::{Path, Request, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
Json,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
use serde::Serialize;
|
||||
use tracing::{info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::AuthenticatedUser;
|
||||
use crate::AppState;
|
||||
@@ -15,7 +15,9 @@ use crate::AppState;
|
||||
use super::auth::ErrorResponse;
|
||||
|
||||
/// Extract JWT token from Authorization header
|
||||
fn extract_token_from_headers(headers: &HeaderMap) -> Result<String, (StatusCode, Json<ErrorResponse>)> {
|
||||
fn extract_token_from_headers(
|
||||
headers: &HeaderMap,
|
||||
) -> Result<String, (StatusCode, Json<ErrorResponse>)> {
|
||||
let auth_header = headers
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
@@ -28,16 +30,14 @@ fn extract_token_from_headers(headers: &HeaderMap) -> Result<String, (StatusCode
|
||||
)
|
||||
})?;
|
||||
|
||||
let token = auth_header
|
||||
.strip_prefix("Bearer ")
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ErrorResponse {
|
||||
error: "Invalid Authorization format".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
let token = auth_header.strip_prefix("Bearer ").ok_or_else(|| {
|
||||
(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ErrorResponse {
|
||||
error: "Invalid Authorization format".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(token.to_string())
|
||||
}
|
||||
@@ -124,7 +124,8 @@ pub async fn revoke_user_tokens(
|
||||
Err((
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
Json(ErrorResponse {
|
||||
error: "User token revocation not yet implemented - requires session tracking table".to_string(),
|
||||
error: "User token revocation not yet implemented - requires session tracking table"
|
||||
.to_string(),
|
||||
}),
|
||||
))
|
||||
}
|
||||
@@ -179,10 +180,16 @@ pub async fn cleanup_blacklist(
|
||||
));
|
||||
}
|
||||
|
||||
let removed = state.token_blacklist.cleanup_expired(&state.jwt_config).await;
|
||||
let removed = state
|
||||
.token_blacklist
|
||||
.cleanup_expired(&state.jwt_config)
|
||||
.await;
|
||||
let remaining = state.token_blacklist.len().await;
|
||||
|
||||
info!("Admin {} cleaned up blacklist: {} tokens removed, {} remaining", admin.username, removed, remaining);
|
||||
info!(
|
||||
"Admin {} cleaned up blacklist: {} tokens removed, {} remaining",
|
||||
admin.username, removed, remaining
|
||||
);
|
||||
|
||||
Ok(Json(CleanupResponse {
|
||||
removed_count: removed,
|
||||
|
||||
@@ -13,7 +13,7 @@ use axum::{
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use tracing::{info, warn, error};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
/// Magic marker for embedded configuration (must match agent)
|
||||
const MAGIC_MARKER: &[u8] = b"GURUCONFIG";
|
||||
@@ -87,7 +87,7 @@ pub async fn download_viewer() -> impl IntoResponse {
|
||||
.header(header::CONTENT_TYPE, "application/octet-stream")
|
||||
.header(
|
||||
header::CONTENT_DISPOSITION,
|
||||
"attachment; filename=\"GuruConnect-Viewer.exe\""
|
||||
"attachment; filename=\"GuruConnect-Viewer.exe\"",
|
||||
)
|
||||
.header(header::CONTENT_LENGTH, binary_data.len())
|
||||
.body(Body::from(binary_data))
|
||||
@@ -104,9 +104,7 @@ pub async fn download_viewer() -> impl IntoResponse {
|
||||
}
|
||||
|
||||
/// Download support session binary (code embedded in filename)
|
||||
pub async fn download_support(
|
||||
Query(params): Query<SupportDownloadParams>,
|
||||
) -> impl IntoResponse {
|
||||
pub async fn download_support(Query(params): Query<SupportDownloadParams>) -> impl IntoResponse {
|
||||
// Validate support code (must be 6 digits)
|
||||
let code = params.code.trim();
|
||||
if code.len() != 6 || !code.chars().all(|c| c.is_ascii_digit()) {
|
||||
@@ -120,7 +118,11 @@ pub async fn download_support(
|
||||
|
||||
match std::fs::read(&binary_path) {
|
||||
Ok(binary_data) => {
|
||||
info!("Serving support session download for code {} ({} bytes)", code, binary_data.len());
|
||||
info!(
|
||||
"Serving support session download for code {} ({} bytes)",
|
||||
code,
|
||||
binary_data.len()
|
||||
);
|
||||
|
||||
// Filename includes the support code
|
||||
let filename = format!("GuruConnect-{}.exe", code);
|
||||
@@ -130,7 +132,7 @@ pub async fn download_support(
|
||||
.header(header::CONTENT_TYPE, "application/octet-stream")
|
||||
.header(
|
||||
header::CONTENT_DISPOSITION,
|
||||
format!("attachment; filename=\"{}\"", filename)
|
||||
format!("attachment; filename=\"{}\"", filename),
|
||||
)
|
||||
.header(header::CONTENT_LENGTH, binary_data.len())
|
||||
.body(Body::from(binary_data))
|
||||
@@ -147,9 +149,7 @@ pub async fn download_support(
|
||||
}
|
||||
|
||||
/// Download permanent agent binary with embedded configuration
|
||||
pub async fn download_agent(
|
||||
Query(params): Query<AgentDownloadParams>,
|
||||
) -> impl IntoResponse {
|
||||
pub async fn download_agent(Query(params): Query<AgentDownloadParams>) -> impl IntoResponse {
|
||||
let binary_path = get_base_binary_path();
|
||||
|
||||
// Read base binary
|
||||
@@ -167,10 +167,13 @@ pub async fn download_agent(
|
||||
// Build embedded config
|
||||
let config = EmbeddedConfig {
|
||||
server_url: "wss://connect.azcomputerguru.com/ws/agent".to_string(),
|
||||
api_key: params.api_key.unwrap_or_else(|| "managed-agent".to_string()),
|
||||
api_key: params
|
||||
.api_key
|
||||
.unwrap_or_else(|| "managed-agent".to_string()),
|
||||
company: params.company.clone(),
|
||||
site: params.site.clone(),
|
||||
tags: params.tags
|
||||
tags: params
|
||||
.tags
|
||||
.as_ref()
|
||||
.map(|t| t.split(',').map(|s| s.trim().to_string()).collect())
|
||||
.unwrap_or_default(),
|
||||
@@ -196,18 +199,25 @@ pub async fn download_agent(
|
||||
|
||||
info!(
|
||||
"Serving permanent agent download: company={:?}, site={:?}, tags={:?} ({} bytes)",
|
||||
config.company, config.site, config.tags, binary_data.len()
|
||||
config.company,
|
||||
config.site,
|
||||
config.tags,
|
||||
binary_data.len()
|
||||
);
|
||||
|
||||
// Generate filename based on company/site
|
||||
let filename = match (¶ms.company, ¶ms.site) {
|
||||
(Some(company), Some(site)) => {
|
||||
format!("GuruConnect-{}-{}-Setup.exe", sanitize_filename(company), sanitize_filename(site))
|
||||
format!(
|
||||
"GuruConnect-{}-{}-Setup.exe",
|
||||
sanitize_filename(company),
|
||||
sanitize_filename(site)
|
||||
)
|
||||
}
|
||||
(Some(company), None) => {
|
||||
format!("GuruConnect-{}-Setup.exe", sanitize_filename(company))
|
||||
}
|
||||
_ => "GuruConnect-Setup.exe".to_string()
|
||||
_ => "GuruConnect-Setup.exe".to_string(),
|
||||
};
|
||||
|
||||
Response::builder()
|
||||
@@ -215,7 +225,7 @@ pub async fn download_agent(
|
||||
.header(header::CONTENT_TYPE, "application/octet-stream")
|
||||
.header(
|
||||
header::CONTENT_DISPOSITION,
|
||||
format!("attachment; filename=\"{}\"", filename)
|
||||
format!("attachment; filename=\"{}\"", filename),
|
||||
)
|
||||
.header(header::CONTENT_LENGTH, binary_data.len())
|
||||
.body(Body::from(binary_data))
|
||||
|
||||
@@ -3,19 +3,19 @@
|
||||
pub mod auth;
|
||||
pub mod auth_logout;
|
||||
pub mod changelog;
|
||||
pub mod users;
|
||||
pub mod releases;
|
||||
pub mod downloads;
|
||||
pub mod releases;
|
||||
pub mod users;
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State, Query},
|
||||
extract::{Path, Query, State},
|
||||
Json,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::session::SessionManager;
|
||||
use crate::db;
|
||||
use crate::session::SessionManager;
|
||||
|
||||
/// Viewer info returned by API
|
||||
#[derive(Debug, Serialize)]
|
||||
@@ -78,9 +78,7 @@ impl From<crate::session::Session> for SessionInfo {
|
||||
}
|
||||
|
||||
/// List all active sessions
|
||||
pub async fn list_sessions(
|
||||
State(sessions): State<SessionManager>,
|
||||
) -> Json<Vec<SessionInfo>> {
|
||||
pub async fn list_sessions(State(sessions): State<SessionManager>) -> Json<Vec<SessionInfo>> {
|
||||
let sessions = sessions.list_sessions().await;
|
||||
Json(sessions.into_iter().map(SessionInfo::from).collect())
|
||||
}
|
||||
@@ -93,7 +91,9 @@ pub async fn get_session(
|
||||
let session_id = Uuid::parse_str(&id)
|
||||
.map_err(|_| (axum::http::StatusCode::BAD_REQUEST, "Invalid session ID"))?;
|
||||
|
||||
let session = sessions.get_session(session_id).await
|
||||
let session = sessions
|
||||
.get_session(session_id)
|
||||
.await
|
||||
.ok_or((axum::http::StatusCode::NOT_FOUND, "Session not found"))?;
|
||||
|
||||
Ok(Json(SessionInfo::from(session)))
|
||||
|
||||
@@ -129,17 +129,15 @@ pub async fn list_releases(
|
||||
)
|
||||
})?;
|
||||
|
||||
let releases = db::get_all_releases(db.pool())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Database error: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Failed to fetch releases".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
let releases = db::get_all_releases(db.pool()).await.map_err(|e| {
|
||||
tracing::error!("Database error: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Failed to fetch releases".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(releases.into_iter().map(ReleaseInfo::from).collect()))
|
||||
}
|
||||
@@ -171,7 +169,10 @@ pub async fn create_release(
|
||||
|
||||
// Validate checksum format (64 hex chars for SHA-256)
|
||||
if request.checksum_sha256.len() != 64
|
||||
|| !request.checksum_sha256.chars().all(|c| c.is_ascii_hexdigit())
|
||||
|| !request
|
||||
.checksum_sha256
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_hexdigit())
|
||||
{
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
@@ -349,17 +350,15 @@ pub async fn delete_release(
|
||||
)
|
||||
})?;
|
||||
|
||||
let deleted = db::delete_release(db.pool(), &version)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Database error: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Failed to delete release".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
let deleted = db::delete_release(db.pool(), &version).await.map_err(|e| {
|
||||
tracing::error!("Database error: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Failed to delete release".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
if deleted {
|
||||
tracing::info!("Deleted release: {}", version);
|
||||
|
||||
@@ -72,17 +72,15 @@ pub async fn list_users(
|
||||
)
|
||||
})?;
|
||||
|
||||
let users = db::get_all_users(db.pool())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Database error: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Failed to fetch users".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
let users = db::get_all_users(db.pool()).await.map_err(|e| {
|
||||
tracing::error!("Database error: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Failed to fetch users".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut result = Vec::new();
|
||||
for user in users {
|
||||
@@ -210,7 +208,13 @@ pub async fn create_user(
|
||||
} else {
|
||||
// Default permissions based on role
|
||||
let default_perms = match request.role.as_str() {
|
||||
"admin" => vec!["view", "control", "transfer", "manage_users", "manage_clients"],
|
||||
"admin" => vec![
|
||||
"view",
|
||||
"control",
|
||||
"transfer",
|
||||
"manage_users",
|
||||
"manage_clients",
|
||||
],
|
||||
"operator" => vec!["view", "control", "transfer"],
|
||||
"viewer" => vec!["view"],
|
||||
_ => vec!["view"],
|
||||
@@ -455,17 +459,15 @@ pub async fn delete_user(
|
||||
));
|
||||
}
|
||||
|
||||
let deleted = db::delete_user(db.pool(), user_id)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Database error: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Failed to delete user".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
let deleted = db::delete_user(db.pool(), user_id).await.map_err(|e| {
|
||||
tracing::error!("Database error: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: "Failed to delete user".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
if deleted {
|
||||
tracing::info!("Deleted user: {}", id);
|
||||
@@ -506,13 +508,22 @@ pub async fn set_permissions(
|
||||
})?;
|
||||
|
||||
// Validate permissions
|
||||
let valid_permissions = ["view", "control", "transfer", "manage_users", "manage_clients"];
|
||||
let valid_permissions = [
|
||||
"view",
|
||||
"control",
|
||||
"transfer",
|
||||
"manage_users",
|
||||
"manage_clients",
|
||||
];
|
||||
for perm in &request.permissions {
|
||||
if !valid_permissions.contains(&perm.as_str()) {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(ErrorResponse {
|
||||
error: format!("Invalid permission: {}. Valid: {:?}", perm, valid_permissions),
|
||||
error: format!(
|
||||
"Invalid permission: {}. Valid: {:?}",
|
||||
perm, valid_permissions
|
||||
),
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -54,7 +54,10 @@ pub struct JwtConfig {
|
||||
impl JwtConfig {
|
||||
/// Create new JWT config
|
||||
pub fn new(secret: String, expiry_hours: i64) -> Self {
|
||||
Self { secret, expiry_hours }
|
||||
Self {
|
||||
secret,
|
||||
expiry_hours,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a JWT token for a user
|
||||
@@ -97,9 +100,9 @@ impl JwtConfig {
|
||||
pub fn validate_token(&self, token: &str) -> Result<Claims> {
|
||||
// SEC-13: Explicit validation configuration
|
||||
let mut validation = Validation::default();
|
||||
validation.validate_exp = true; // Enforce expiration check
|
||||
validation.validate_exp = true; // Enforce expiration check
|
||||
validation.validate_nbf = false; // Not using "not before" claim
|
||||
validation.leeway = 0; // No clock skew tolerance
|
||||
validation.leeway = 0; // No clock skew tolerance
|
||||
|
||||
let token_data = decode::<Claims>(
|
||||
token,
|
||||
@@ -129,12 +132,14 @@ mod tests {
|
||||
let config = JwtConfig::new("test-secret".to_string(), 24);
|
||||
let user_id = Uuid::new_v4();
|
||||
|
||||
let token = config.create_token(
|
||||
user_id,
|
||||
"testuser",
|
||||
"admin",
|
||||
vec!["view".to_string(), "control".to_string()],
|
||||
).unwrap();
|
||||
let token = config
|
||||
.create_token(
|
||||
user_id,
|
||||
"testuser",
|
||||
"admin",
|
||||
vec!["view".to_string(), "control".to_string()],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let claims = config.validate_token(&token).unwrap();
|
||||
assert_eq!(claims.username, "testuser");
|
||||
|
||||
@@ -8,7 +8,7 @@ pub mod password;
|
||||
pub mod token_blacklist;
|
||||
|
||||
pub use jwt::{Claims, JwtConfig};
|
||||
pub use password::{hash_password, verify_password, generate_random_password};
|
||||
pub use password::{generate_random_password, hash_password, verify_password};
|
||||
pub use token_blacklist::TokenBlacklist;
|
||||
|
||||
use axum::{
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use argon2::{
|
||||
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||||
Argon2, Algorithm, Version, Params,
|
||||
Algorithm, Argon2, Params, Version,
|
||||
};
|
||||
|
||||
/// Hash a password using Argon2id
|
||||
@@ -22,9 +22,9 @@ pub fn hash_password(password: &str) -> Result<String> {
|
||||
|
||||
// Explicitly use Argon2id (Algorithm::Argon2id)
|
||||
let argon2 = Argon2::new(
|
||||
Algorithm::Argon2id, // SEC-9: Explicit Argon2id variant
|
||||
Version::V0x13, // Latest version
|
||||
Params::default(), // Default params (19456 KiB, 2 iterations, 1 parallelism)
|
||||
Algorithm::Argon2id, // SEC-9: Explicit Argon2id variant
|
||||
Version::V0x13, // Latest version
|
||||
Params::default(), // Default params (19456 KiB, 2 iterations, 1 parallelism)
|
||||
);
|
||||
|
||||
let hash = argon2
|
||||
@@ -35,12 +35,14 @@ pub fn hash_password(password: &str) -> Result<String> {
|
||||
|
||||
/// Verify a password against a stored hash
|
||||
pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
|
||||
let parsed_hash = PasswordHash::new(hash)
|
||||
.map_err(|e| anyhow!("Invalid password hash format: {}", e))?;
|
||||
let parsed_hash =
|
||||
PasswordHash::new(hash).map_err(|e| anyhow!("Invalid password hash format: {}", e))?;
|
||||
|
||||
// Argon2::default() uses Argon2id, but we verify against the hash's embedded algorithm
|
||||
let argon2 = Argon2::default();
|
||||
Ok(argon2.verify_password(password.as_bytes(), &parsed_hash).is_ok())
|
||||
Ok(argon2
|
||||
.verify_password(password.as_bytes(), &parsed_hash)
|
||||
.is_ok())
|
||||
}
|
||||
|
||||
/// Generate a random password (for initial admin)
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{info, debug};
|
||||
use tracing::{debug, info};
|
||||
|
||||
/// Token blacklist for revocation
|
||||
///
|
||||
@@ -41,7 +41,10 @@ impl TokenBlacklist {
|
||||
let was_new = tokens.insert(token.to_string());
|
||||
|
||||
if was_new {
|
||||
debug!("Token revoked and added to blacklist (length: {})", token.len());
|
||||
debug!(
|
||||
"Token revoked and added to blacklist (length: {})",
|
||||
token.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,7 +95,11 @@ impl TokenBlacklist {
|
||||
let removed = original_len - tokens.len();
|
||||
|
||||
if removed > 0 {
|
||||
info!("Cleaned {} expired tokens from blacklist ({} remaining)", removed, tokens.len());
|
||||
info!(
|
||||
"Cleaned {} expired tokens from blacklist ({} remaining)",
|
||||
removed,
|
||||
tokens.len()
|
||||
);
|
||||
}
|
||||
|
||||
removed
|
||||
|
||||
@@ -36,8 +36,10 @@ impl EventTypes {
|
||||
pub const CONNECTION_REJECTED_NO_AUTH: &'static str = "connection_rejected_no_auth";
|
||||
pub const CONNECTION_REJECTED_INVALID_CODE: &'static str = "connection_rejected_invalid_code";
|
||||
pub const CONNECTION_REJECTED_EXPIRED_CODE: &'static str = "connection_rejected_expired_code";
|
||||
pub const CONNECTION_REJECTED_INVALID_API_KEY: &'static str = "connection_rejected_invalid_api_key";
|
||||
pub const CONNECTION_REJECTED_CANCELLED_CODE: &'static str = "connection_rejected_cancelled_code";
|
||||
pub const CONNECTION_REJECTED_INVALID_API_KEY: &'static str =
|
||||
"connection_rejected_invalid_api_key";
|
||||
pub const CONNECTION_REJECTED_CANCELLED_CODE: &'static str =
|
||||
"connection_rejected_cancelled_code";
|
||||
}
|
||||
|
||||
/// Log a session event
|
||||
|
||||
@@ -80,7 +80,7 @@ pub async fn update_machine_status(
|
||||
/// Get all persistent machines (for restore on startup)
|
||||
pub async fn get_all_machines(pool: &PgPool) -> Result<Vec<Machine>, sqlx::Error> {
|
||||
sqlx::query_as::<_, Machine>(
|
||||
"SELECT * FROM connect_machines WHERE is_persistent = true ORDER BY hostname"
|
||||
"SELECT * FROM connect_machines WHERE is_persistent = true ORDER BY hostname",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
@@ -91,20 +91,20 @@ pub async fn get_machine_by_agent_id(
|
||||
pool: &PgPool,
|
||||
agent_id: &str,
|
||||
) -> Result<Option<Machine>, sqlx::Error> {
|
||||
sqlx::query_as::<_, Machine>(
|
||||
"SELECT * FROM connect_machines WHERE agent_id = $1"
|
||||
)
|
||||
.bind(agent_id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
sqlx::query_as::<_, Machine>("SELECT * FROM connect_machines WHERE agent_id = $1")
|
||||
.bind(agent_id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Mark machine as offline
|
||||
pub async fn mark_machine_offline(pool: &PgPool, agent_id: &str) -> Result<(), sqlx::Error> {
|
||||
sqlx::query("UPDATE connect_machines SET status = 'offline', last_seen = NOW() WHERE agent_id = $1")
|
||||
.bind(agent_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
sqlx::query(
|
||||
"UPDATE connect_machines SET status = 'offline', last_seen = NOW() WHERE agent_id = $1",
|
||||
)
|
||||
.bind(agent_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -3,24 +3,24 @@
|
||||
//! Handles persistence for machines, sessions, and audit logging.
|
||||
//! Optional - server works without database if DATABASE_URL not set.
|
||||
|
||||
pub mod machines;
|
||||
pub mod sessions;
|
||||
pub mod events;
|
||||
pub mod machines;
|
||||
pub mod releases;
|
||||
pub mod sessions;
|
||||
pub mod support_codes;
|
||||
pub mod users;
|
||||
pub mod releases;
|
||||
|
||||
use anyhow::Result;
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
use sqlx::PgPool;
|
||||
use tracing::info;
|
||||
|
||||
pub use machines::*;
|
||||
pub use sessions::*;
|
||||
pub use events::*;
|
||||
pub use machines::*;
|
||||
pub use releases::*;
|
||||
pub use sessions::*;
|
||||
pub use support_codes::*;
|
||||
pub use users::*;
|
||||
pub use releases::*;
|
||||
|
||||
/// Database connection pool wrapper
|
||||
#[derive(Clone)]
|
||||
|
||||
@@ -45,7 +45,7 @@ pub async fn create_session(
|
||||
pub async fn end_session(
|
||||
pool: &PgPool,
|
||||
session_id: Uuid,
|
||||
status: &str, // 'ended' or 'disconnected' or 'timeout'
|
||||
status: &str, // 'ended' or 'disconnected' or 'timeout'
|
||||
) -> Result<(), sqlx::Error> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
@@ -64,7 +64,10 @@ pub async fn end_session(
|
||||
}
|
||||
|
||||
/// Get session by ID
|
||||
pub async fn get_session(pool: &PgPool, session_id: Uuid) -> Result<Option<DbSession>, sqlx::Error> {
|
||||
pub async fn get_session(
|
||||
pool: &PgPool,
|
||||
session_id: Uuid,
|
||||
) -> Result<Option<DbSession>, sqlx::Error> {
|
||||
sqlx::query_as::<_, DbSession>("SELECT * FROM connect_sessions WHERE id = $1")
|
||||
.bind(session_id)
|
||||
.fetch_optional(pool)
|
||||
@@ -85,12 +88,9 @@ pub async fn get_active_sessions_for_machine(
|
||||
}
|
||||
|
||||
/// Get recent sessions (for dashboard)
|
||||
pub async fn get_recent_sessions(
|
||||
pool: &PgPool,
|
||||
limit: i64,
|
||||
) -> Result<Vec<DbSession>, sqlx::Error> {
|
||||
pub async fn get_recent_sessions(pool: &PgPool, limit: i64) -> Result<Vec<DbSession>, sqlx::Error> {
|
||||
sqlx::query_as::<_, DbSession>(
|
||||
"SELECT * FROM connect_sessions ORDER BY started_at DESC LIMIT $1"
|
||||
"SELECT * FROM connect_sessions ORDER BY started_at DESC LIMIT $1",
|
||||
)
|
||||
.bind(limit)
|
||||
.fetch_all(pool)
|
||||
@@ -103,7 +103,7 @@ pub async fn get_sessions_for_machine(
|
||||
machine_id: Uuid,
|
||||
) -> Result<Vec<DbSession>, sqlx::Error> {
|
||||
sqlx::query_as::<_, DbSession>(
|
||||
"SELECT * FROM connect_sessions WHERE machine_id = $1 ORDER BY started_at DESC"
|
||||
"SELECT * FROM connect_sessions WHERE machine_id = $1 ORDER BY started_at DESC",
|
||||
)
|
||||
.bind(machine_id)
|
||||
.fetch_all(pool)
|
||||
|
||||
@@ -40,13 +40,14 @@ pub async fn create_support_code(
|
||||
}
|
||||
|
||||
/// Get support code by code string
|
||||
pub async fn get_support_code(pool: &PgPool, code: &str) -> Result<Option<DbSupportCode>, sqlx::Error> {
|
||||
sqlx::query_as::<_, DbSupportCode>(
|
||||
"SELECT * FROM connect_support_codes WHERE code = $1"
|
||||
)
|
||||
.bind(code)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
pub async fn get_support_code(
|
||||
pool: &PgPool,
|
||||
code: &str,
|
||||
) -> Result<Option<DbSupportCode>, sqlx::Error> {
|
||||
sqlx::query_as::<_, DbSupportCode>("SELECT * FROM connect_support_codes WHERE code = $1")
|
||||
.bind(code)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Update support code when client connects
|
||||
@@ -107,7 +108,7 @@ pub async fn get_active_support_codes(pool: &PgPool) -> Result<Vec<DbSupportCode
|
||||
/// Check if code exists and is valid for connection
|
||||
pub async fn is_code_valid(pool: &PgPool, code: &str) -> Result<bool, sqlx::Error> {
|
||||
let result = sqlx::query_scalar::<_, bool>(
|
||||
"SELECT EXISTS(SELECT 1 FROM connect_support_codes WHERE code = $1 AND status = 'pending')"
|
||||
"SELECT EXISTS(SELECT 1 FROM connect_support_codes WHERE code = $1 AND status = 'pending')",
|
||||
)
|
||||
.bind(code)
|
||||
.fetch_one(pool)
|
||||
|
||||
@@ -49,33 +49,27 @@ impl From<User> for UserInfo {
|
||||
|
||||
/// Get user by username
|
||||
pub async fn get_user_by_username(pool: &PgPool, username: &str) -> Result<Option<User>> {
|
||||
let user = sqlx::query_as::<_, User>(
|
||||
"SELECT * FROM users WHERE username = $1"
|
||||
)
|
||||
.bind(username)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE username = $1")
|
||||
.bind(username)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
/// Get user by ID
|
||||
pub async fn get_user_by_id(pool: &PgPool, id: Uuid) -> Result<Option<User>> {
|
||||
let user = sqlx::query_as::<_, User>(
|
||||
"SELECT * FROM users WHERE id = $1"
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE id = $1")
|
||||
.bind(id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
/// Get all users
|
||||
pub async fn get_all_users(pool: &PgPool) -> Result<Vec<User>> {
|
||||
let users = sqlx::query_as::<_, User>(
|
||||
"SELECT * FROM users ORDER BY username"
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
let users = sqlx::query_as::<_, User>("SELECT * FROM users ORDER BY username")
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
Ok(users)
|
||||
}
|
||||
|
||||
@@ -92,7 +86,7 @@ pub async fn create_user(
|
||||
INSERT INTO users (username, password_hash, email, role)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING *
|
||||
"#
|
||||
"#,
|
||||
)
|
||||
.bind(username)
|
||||
.bind(password_hash)
|
||||
@@ -117,7 +111,7 @@ pub async fn update_user(
|
||||
SET email = $2, role = $3, enabled = $4, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING *
|
||||
"#
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.bind(email)
|
||||
@@ -129,18 +123,13 @@ pub async fn update_user(
|
||||
}
|
||||
|
||||
/// Update user password
|
||||
pub async fn update_user_password(
|
||||
pool: &PgPool,
|
||||
id: Uuid,
|
||||
password_hash: &str,
|
||||
) -> Result<bool> {
|
||||
let result = sqlx::query(
|
||||
"UPDATE users SET password_hash = $2, updated_at = NOW() WHERE id = $1"
|
||||
)
|
||||
.bind(id)
|
||||
.bind(password_hash)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
pub async fn update_user_password(pool: &PgPool, id: Uuid, password_hash: &str) -> Result<bool> {
|
||||
let result =
|
||||
sqlx::query("UPDATE users SET password_hash = $2, updated_at = NOW() WHERE id = $1")
|
||||
.bind(id)
|
||||
.bind(password_hash)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
@@ -172,12 +161,11 @@ pub async fn count_users(pool: &PgPool) -> Result<i64> {
|
||||
|
||||
/// Get user permissions
|
||||
pub async fn get_user_permissions(pool: &PgPool, user_id: Uuid) -> Result<Vec<String>> {
|
||||
let perms: Vec<(String,)> = sqlx::query_as(
|
||||
"SELECT permission FROM user_permissions WHERE user_id = $1"
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
let perms: Vec<(String,)> =
|
||||
sqlx::query_as("SELECT permission FROM user_permissions WHERE user_id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
Ok(perms.into_iter().map(|p| p.0).collect())
|
||||
}
|
||||
|
||||
@@ -195,25 +183,22 @@ pub async fn set_user_permissions(
|
||||
|
||||
// Insert new
|
||||
for perm in permissions {
|
||||
sqlx::query(
|
||||
"INSERT INTO user_permissions (user_id, permission) VALUES ($1, $2)"
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(perm)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
sqlx::query("INSERT INTO user_permissions (user_id, permission) VALUES ($1, $2)")
|
||||
.bind(user_id)
|
||||
.bind(perm)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get user's accessible client IDs (empty = all access)
|
||||
pub async fn get_user_client_access(pool: &PgPool, user_id: Uuid) -> Result<Vec<Uuid>> {
|
||||
let clients: Vec<(Uuid,)> = sqlx::query_as(
|
||||
"SELECT client_id FROM user_client_access WHERE user_id = $1"
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
let clients: Vec<(Uuid,)> =
|
||||
sqlx::query_as("SELECT client_id FROM user_client_access WHERE user_id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
Ok(clients.into_iter().map(|c| c.0).collect())
|
||||
}
|
||||
|
||||
@@ -231,23 +216,17 @@ pub async fn set_user_client_access(
|
||||
|
||||
// Insert new
|
||||
for client_id in client_ids {
|
||||
sqlx::query(
|
||||
"INSERT INTO user_client_access (user_id, client_id) VALUES ($1, $2)"
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(client_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
sqlx::query("INSERT INTO user_client_access (user_id, client_id) VALUES ($1, $2)")
|
||||
.bind(user_id)
|
||||
.bind(client_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if user has access to a specific client
|
||||
pub async fn user_has_client_access(
|
||||
pool: &PgPool,
|
||||
user_id: Uuid,
|
||||
client_id: Uuid,
|
||||
) -> Result<bool> {
|
||||
pub async fn user_has_client_access(pool: &PgPool, user_id: Uuid, client_id: Uuid) -> Result<bool> {
|
||||
// Admins have access to all
|
||||
let user = get_user_by_id(pool, user_id).await?;
|
||||
if let Some(u) = user {
|
||||
@@ -258,7 +237,7 @@ pub async fn user_has_client_access(
|
||||
|
||||
// Check explicit access
|
||||
let access: Option<(Uuid,)> = sqlx::query_as(
|
||||
"SELECT client_id FROM user_client_access WHERE user_id = $1 AND client_id = $2"
|
||||
"SELECT client_id FROM user_client_access WHERE user_id = $1 AND client_id = $2",
|
||||
)
|
||||
.bind(user_id)
|
||||
.bind(client_id)
|
||||
@@ -271,12 +250,11 @@ pub async fn user_has_client_access(
|
||||
}
|
||||
|
||||
// Check if user has ANY access restrictions
|
||||
let count: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM user_client_access WHERE user_id = $1"
|
||||
)
|
||||
.bind(user_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
let count: (i64,) =
|
||||
sqlx::query_as("SELECT COUNT(*) FROM user_client_access WHERE user_id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
// No restrictions means access to all
|
||||
Ok(count.0 == 0)
|
||||
|
||||
@@ -3,44 +3,44 @@
|
||||
//! Handles connections from both agents and dashboard viewers,
|
||||
//! relaying video frames and input events between them.
|
||||
|
||||
mod api;
|
||||
mod auth;
|
||||
mod config;
|
||||
mod db;
|
||||
mod metrics;
|
||||
mod middleware;
|
||||
mod relay;
|
||||
mod session;
|
||||
mod auth;
|
||||
mod api;
|
||||
mod db;
|
||||
mod support_codes;
|
||||
mod middleware;
|
||||
mod utils;
|
||||
mod metrics;
|
||||
|
||||
pub mod proto {
|
||||
include!(concat!(env!("OUT_DIR"), "/guruconnect.rs"));
|
||||
}
|
||||
|
||||
use anyhow::Result;
|
||||
use axum::http::{HeaderValue, Method};
|
||||
use axum::{
|
||||
Router,
|
||||
routing::{get, post, put, delete},
|
||||
extract::{Path, State, Json, Query, Request},
|
||||
response::{Html, IntoResponse},
|
||||
extract::{Json, Path, Query, Request, State},
|
||||
http::StatusCode,
|
||||
middleware::{self as axum_middleware, Next},
|
||||
response::{Html, IntoResponse},
|
||||
routing::{delete, get, post, put},
|
||||
Router,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tower_http::cors::{Any, CorsLayer, AllowOrigin};
|
||||
use axum::http::{Method, HeaderValue};
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
|
||||
use tower_http::services::ServeDir;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing::{info, Level};
|
||||
use tracing_subscriber::FmtSubscriber;
|
||||
use serde::Deserialize;
|
||||
|
||||
use support_codes::{SupportCodeManager, CreateCodeRequest, SupportCode, CodeValidation};
|
||||
use auth::{JwtConfig, TokenBlacklist, hash_password, generate_random_password, AuthenticatedUser};
|
||||
use auth::{generate_random_password, hash_password, AuthenticatedUser, JwtConfig, TokenBlacklist};
|
||||
use metrics::SharedMetrics;
|
||||
use prometheus_client::registry::Registry;
|
||||
use support_codes::{CodeValidation, CreateCodeRequest, SupportCode, SupportCodeManager};
|
||||
|
||||
/// Application state
|
||||
#[derive(Clone)]
|
||||
@@ -67,7 +67,9 @@ async fn auth_layer(
|
||||
next: Next,
|
||||
) -> impl IntoResponse {
|
||||
request.extensions_mut().insert(state.jwt_config.clone());
|
||||
request.extensions_mut().insert(Arc::new(state.token_blacklist.clone()));
|
||||
request
|
||||
.extensions_mut()
|
||||
.insert(Arc::new(state.token_blacklist.clone()));
|
||||
next.run(request).await
|
||||
}
|
||||
|
||||
@@ -89,8 +91,9 @@ async fn main() -> Result<()> {
|
||||
info!("Loaded configuration, listening on {}", listen_addr);
|
||||
|
||||
// JWT configuration - REQUIRED for security
|
||||
let jwt_secret = std::env::var("JWT_SECRET")
|
||||
.expect("JWT_SECRET environment variable must be set! Generate one with: openssl rand -base64 64");
|
||||
let jwt_secret = std::env::var("JWT_SECRET").expect(
|
||||
"JWT_SECRET environment variable must be set! Generate one with: openssl rand -base64 64",
|
||||
);
|
||||
|
||||
if jwt_secret.len() < 32 {
|
||||
panic!("JWT_SECRET must be at least 32 characters long for security!");
|
||||
@@ -114,7 +117,10 @@ async fn main() -> Result<()> {
|
||||
Some(db)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to connect to database: {}. Running without persistence.", e);
|
||||
tracing::warn!(
|
||||
"Failed to connect to database: {}. Running without persistence.",
|
||||
e
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
@@ -194,9 +200,14 @@ async fn main() -> Result<()> {
|
||||
if let Some(ref db) = database {
|
||||
match db::machines::get_all_machines(db.pool()).await {
|
||||
Ok(machines) => {
|
||||
info!("Restoring {} persistent machines from database", machines.len());
|
||||
info!(
|
||||
"Restoring {} persistent machines from database",
|
||||
machines.len()
|
||||
);
|
||||
for machine in machines {
|
||||
sessions.restore_offline_machine(&machine.agent_id, &machine.hostname).await;
|
||||
sessions
|
||||
.restore_offline_machine(&machine.agent_id, &machine.hostname)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -254,92 +265,117 @@ async fn main() -> Result<()> {
|
||||
.route("/health", get(health))
|
||||
// Prometheus metrics (no auth required - for monitoring)
|
||||
.route("/metrics", get(prometheus_metrics))
|
||||
|
||||
// Auth endpoints (TODO: Add rate limiting - see SEC2_RATE_LIMITING_TODO.md)
|
||||
.route("/api/auth/login", post(api::auth::login))
|
||||
.route("/api/auth/change-password", post(api::auth::change_password))
|
||||
.route(
|
||||
"/api/auth/change-password",
|
||||
post(api::auth::change_password),
|
||||
)
|
||||
.route("/api/auth/me", get(api::auth::get_me))
|
||||
.route("/api/auth/logout", post(api::auth_logout::logout))
|
||||
.route("/api/auth/revoke-token", post(api::auth_logout::revoke_own_token))
|
||||
.route("/api/auth/admin/revoke-user", post(api::auth_logout::revoke_user_tokens))
|
||||
.route("/api/auth/blacklist/stats", get(api::auth_logout::get_blacklist_stats))
|
||||
.route("/api/auth/blacklist/cleanup", post(api::auth_logout::cleanup_blacklist))
|
||||
|
||||
.route(
|
||||
"/api/auth/revoke-token",
|
||||
post(api::auth_logout::revoke_own_token),
|
||||
)
|
||||
.route(
|
||||
"/api/auth/admin/revoke-user",
|
||||
post(api::auth_logout::revoke_user_tokens),
|
||||
)
|
||||
.route(
|
||||
"/api/auth/blacklist/stats",
|
||||
get(api::auth_logout::get_blacklist_stats),
|
||||
)
|
||||
.route(
|
||||
"/api/auth/blacklist/cleanup",
|
||||
post(api::auth_logout::cleanup_blacklist),
|
||||
)
|
||||
// User management (admin only)
|
||||
.route("/api/users", get(api::users::list_users))
|
||||
.route("/api/users", post(api::users::create_user))
|
||||
.route("/api/users/:id", get(api::users::get_user))
|
||||
.route("/api/users/:id", put(api::users::update_user))
|
||||
.route("/api/users/:id", delete(api::users::delete_user))
|
||||
.route("/api/users/:id/permissions", put(api::users::set_permissions))
|
||||
.route(
|
||||
"/api/users/:id/permissions",
|
||||
put(api::users::set_permissions),
|
||||
)
|
||||
.route("/api/users/:id/clients", put(api::users::set_client_access))
|
||||
|
||||
// Portal API - Support codes (TODO: Add rate limiting)
|
||||
.route("/api/codes", post(create_code))
|
||||
.route("/api/codes", get(list_codes))
|
||||
.route("/api/codes/:code/validate", get(validate_code))
|
||||
.route("/api/codes/:code/cancel", post(cancel_code))
|
||||
|
||||
// WebSocket endpoints
|
||||
.route("/ws/agent", get(relay::agent_ws_handler))
|
||||
.route("/ws/viewer", get(relay::viewer_ws_handler))
|
||||
|
||||
// REST API - Sessions
|
||||
.route("/api/sessions", get(list_sessions))
|
||||
.route("/api/sessions/:id", get(get_session))
|
||||
.route("/api/sessions/:id", delete(disconnect_session))
|
||||
|
||||
// REST API - Machines
|
||||
.route("/api/machines", get(list_machines))
|
||||
.route("/api/machines/:agent_id", get(get_machine))
|
||||
.route("/api/machines/:agent_id", delete(delete_machine))
|
||||
.route("/api/machines/:agent_id/history", get(get_machine_history))
|
||||
.route("/api/machines/:agent_id/update", post(trigger_machine_update))
|
||||
|
||||
.route(
|
||||
"/api/machines/:agent_id/update",
|
||||
post(trigger_machine_update),
|
||||
)
|
||||
// REST API - Releases and Version
|
||||
.route("/api/version", get(api::releases::get_version)) // No auth - for agent polling
|
||||
.route("/api/version", get(api::releases::get_version)) // No auth - for agent polling
|
||||
.route("/api/releases", get(api::releases::list_releases))
|
||||
.route("/api/releases", post(api::releases::create_release))
|
||||
.route("/api/releases/:version", get(api::releases::get_release))
|
||||
.route("/api/releases/:version", put(api::releases::update_release))
|
||||
.route("/api/releases/:version", delete(api::releases::delete_release))
|
||||
|
||||
.route(
|
||||
"/api/releases/:version",
|
||||
delete(api::releases::delete_release),
|
||||
)
|
||||
// Changelog (no auth - public, like /api/version)
|
||||
// Single route: version == "latest" selects the latest file; axum 0.7 / matchit 0.7
|
||||
// panics if a static segment and a path param share this position, so do not split it.
|
||||
.route("/api/changelog/:component/:version", get(api::changelog::get))
|
||||
|
||||
.route(
|
||||
"/api/changelog/:component/:version",
|
||||
get(api::changelog::get),
|
||||
)
|
||||
// Agent downloads (no auth - public download links)
|
||||
.route("/api/download/viewer", get(api::downloads::download_viewer))
|
||||
.route("/api/download/support", get(api::downloads::download_support))
|
||||
.route(
|
||||
"/api/download/support",
|
||||
get(api::downloads::download_support),
|
||||
)
|
||||
.route("/api/download/agent", get(api::downloads::download_agent))
|
||||
|
||||
// HTML page routes (clean URLs)
|
||||
.route("/login", get(serve_login))
|
||||
.route("/dashboard", get(serve_dashboard))
|
||||
.route("/users", get(serve_users))
|
||||
|
||||
// State and middleware
|
||||
.with_state(state.clone())
|
||||
.layer(axum_middleware::from_fn_with_state(state, auth_layer))
|
||||
|
||||
// Serve static files for portal (fallback)
|
||||
.fallback_service(ServeDir::new("static").append_index_html_on_directories(true))
|
||||
|
||||
// Middleware
|
||||
.layer(axum_middleware::from_fn(middleware::add_security_headers)) // SEC-7 & SEC-12
|
||||
.layer(axum_middleware::from_fn(middleware::add_security_headers)) // SEC-7 & SEC-12
|
||||
.layer(TraceLayer::new_for_http())
|
||||
// SEC-11: Restricted CORS configuration
|
||||
.layer({
|
||||
let cors = CorsLayer::new()
|
||||
// Allow requests from the production domain and localhost (for development)
|
||||
.allow_origin([
|
||||
"https://connect.azcomputerguru.com".parse::<HeaderValue>().unwrap(),
|
||||
"https://connect.azcomputerguru.com"
|
||||
.parse::<HeaderValue>()
|
||||
.unwrap(),
|
||||
"http://localhost:3002".parse::<HeaderValue>().unwrap(),
|
||||
"http://127.0.0.1:3002".parse::<HeaderValue>().unwrap(),
|
||||
])
|
||||
// Allow only necessary HTTP methods
|
||||
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS])
|
||||
.allow_methods([
|
||||
Method::GET,
|
||||
Method::POST,
|
||||
Method::PUT,
|
||||
Method::DELETE,
|
||||
Method::OPTIONS,
|
||||
])
|
||||
// Allow common headers needed for API requests
|
||||
.allow_headers([
|
||||
axum::http::header::AUTHORIZATION,
|
||||
@@ -360,8 +396,9 @@ async fn main() -> Result<()> {
|
||||
// Use into_make_service_with_connect_info to enable IP address extraction
|
||||
axum::serve(
|
||||
listener,
|
||||
app.into_make_service_with_connect_info::<SocketAddr>()
|
||||
).await?;
|
||||
app.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -371,9 +408,7 @@ async fn health() -> &'static str {
|
||||
}
|
||||
|
||||
/// Prometheus metrics endpoint
|
||||
async fn prometheus_metrics(
|
||||
State(state): State<AppState>,
|
||||
) -> String {
|
||||
async fn prometheus_metrics(State(state): State<AppState>) -> String {
|
||||
use prometheus_client::encoding::text::encode;
|
||||
|
||||
let registry = state.registry.lock().unwrap();
|
||||
@@ -385,7 +420,7 @@ async fn prometheus_metrics(
|
||||
// Support code API handlers
|
||||
|
||||
async fn create_code(
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
State(state): State<AppState>,
|
||||
Json(request): Json<CreateCodeRequest>,
|
||||
) -> Json<SupportCode> {
|
||||
@@ -395,7 +430,7 @@ async fn create_code(
|
||||
}
|
||||
|
||||
async fn list_codes(
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
State(state): State<AppState>,
|
||||
) -> Json<Vec<SupportCode>> {
|
||||
Json(state.support_codes.list_active_codes().await)
|
||||
@@ -414,7 +449,7 @@ async fn validate_code(
|
||||
}
|
||||
|
||||
async fn cancel_code(
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
State(state): State<AppState>,
|
||||
Path(code): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
@@ -428,7 +463,7 @@ async fn cancel_code(
|
||||
// Session API handlers (updated to use AppState)
|
||||
|
||||
async fn list_sessions(
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
State(state): State<AppState>,
|
||||
) -> Json<Vec<api::SessionInfo>> {
|
||||
let sessions = state.sessions.list_sessions().await;
|
||||
@@ -436,21 +471,24 @@ async fn list_sessions(
|
||||
}
|
||||
|
||||
async fn get_session(
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<api::SessionInfo>, (StatusCode, &'static str)> {
|
||||
let session_id = uuid::Uuid::parse_str(&id)
|
||||
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID"))?;
|
||||
let session_id =
|
||||
uuid::Uuid::parse_str(&id).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID"))?;
|
||||
|
||||
let session = state.sessions.get_session(session_id).await
|
||||
let session = state
|
||||
.sessions
|
||||
.get_session(session_id)
|
||||
.await
|
||||
.ok_or((StatusCode::NOT_FOUND, "Session not found"))?;
|
||||
|
||||
Ok(Json(api::SessionInfo::from(session)))
|
||||
}
|
||||
|
||||
async fn disconnect_session(
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
) -> impl IntoResponse {
|
||||
@@ -459,7 +497,11 @@ async fn disconnect_session(
|
||||
Err(_) => return (StatusCode::BAD_REQUEST, "Invalid session ID"),
|
||||
};
|
||||
|
||||
if state.sessions.disconnect_session(session_id, "Disconnected by administrator").await {
|
||||
if state
|
||||
.sessions
|
||||
.disconnect_session(session_id, "Disconnected by administrator")
|
||||
.await
|
||||
{
|
||||
info!("Session {} disconnected by admin", session_id);
|
||||
(StatusCode::OK, "Session disconnected")
|
||||
} else {
|
||||
@@ -470,27 +512,35 @@ async fn disconnect_session(
|
||||
// Machine API handlers
|
||||
|
||||
async fn list_machines(
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
State(state): State<AppState>,
|
||||
) -> Result<Json<Vec<api::MachineInfo>>, (StatusCode, &'static str)> {
|
||||
let db = state.db.as_ref()
|
||||
let db = state
|
||||
.db
|
||||
.as_ref()
|
||||
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?;
|
||||
|
||||
let machines = db::machines::get_all_machines(db.pool()).await
|
||||
let machines = db::machines::get_all_machines(db.pool())
|
||||
.await
|
||||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?;
|
||||
|
||||
Ok(Json(machines.into_iter().map(api::MachineInfo::from).collect()))
|
||||
Ok(Json(
|
||||
machines.into_iter().map(api::MachineInfo::from).collect(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn get_machine(
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
State(state): State<AppState>,
|
||||
Path(agent_id): Path<String>,
|
||||
) -> Result<Json<api::MachineInfo>, (StatusCode, &'static str)> {
|
||||
let db = state.db.as_ref()
|
||||
let db = state
|
||||
.db
|
||||
.as_ref()
|
||||
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?;
|
||||
|
||||
let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id).await
|
||||
let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id)
|
||||
.await
|
||||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?
|
||||
.ok_or((StatusCode::NOT_FOUND, "Machine not found"))?;
|
||||
|
||||
@@ -498,24 +548,29 @@ async fn get_machine(
|
||||
}
|
||||
|
||||
async fn get_machine_history(
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
State(state): State<AppState>,
|
||||
Path(agent_id): Path<String>,
|
||||
) -> Result<Json<api::MachineHistory>, (StatusCode, &'static str)> {
|
||||
let db = state.db.as_ref()
|
||||
let db = state
|
||||
.db
|
||||
.as_ref()
|
||||
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?;
|
||||
|
||||
// Get machine
|
||||
let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id).await
|
||||
let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id)
|
||||
.await
|
||||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?
|
||||
.ok_or((StatusCode::NOT_FOUND, "Machine not found"))?;
|
||||
|
||||
// Get sessions for this machine
|
||||
let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id).await
|
||||
let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id)
|
||||
.await
|
||||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?;
|
||||
|
||||
// Get events for this machine
|
||||
let events = db::events::get_events_for_machine(db.pool(), machine.id).await
|
||||
let events = db::events::get_events_for_machine(db.pool(), machine.id)
|
||||
.await
|
||||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?;
|
||||
|
||||
let history = api::MachineHistory {
|
||||
@@ -529,24 +584,29 @@ async fn get_machine_history(
|
||||
}
|
||||
|
||||
async fn delete_machine(
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
State(state): State<AppState>,
|
||||
Path(agent_id): Path<String>,
|
||||
Query(params): Query<api::DeleteMachineParams>,
|
||||
) -> Result<Json<api::DeleteMachineResponse>, (StatusCode, &'static str)> {
|
||||
let db = state.db.as_ref()
|
||||
let db = state
|
||||
.db
|
||||
.as_ref()
|
||||
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?;
|
||||
|
||||
// Get machine first
|
||||
let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id).await
|
||||
let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id)
|
||||
.await
|
||||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?
|
||||
.ok_or((StatusCode::NOT_FOUND, "Machine not found"))?;
|
||||
|
||||
// Export history if requested
|
||||
let history = if params.export {
|
||||
let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id).await
|
||||
let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id)
|
||||
.await
|
||||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?;
|
||||
let events = db::events::get_events_for_machine(db.pool(), machine.id).await
|
||||
let events = db::events::get_events_for_machine(db.pool(), machine.id)
|
||||
.await
|
||||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?;
|
||||
|
||||
Some(api::MachineHistory {
|
||||
@@ -565,11 +625,14 @@ async fn delete_machine(
|
||||
// Find session for this agent
|
||||
if let Some(session) = state.sessions.get_session_by_agent(&agent_id).await {
|
||||
if session.is_online {
|
||||
uninstall_sent = state.sessions.send_admin_command(
|
||||
session.id,
|
||||
proto::AdminCommandType::AdminUninstall,
|
||||
"Deleted by administrator",
|
||||
).await;
|
||||
uninstall_sent = state
|
||||
.sessions
|
||||
.send_admin_command(
|
||||
session.id,
|
||||
proto::AdminCommandType::AdminUninstall,
|
||||
"Deleted by administrator",
|
||||
)
|
||||
.await;
|
||||
if uninstall_sent {
|
||||
info!("Sent uninstall command to agent {}", agent_id);
|
||||
}
|
||||
@@ -581,10 +644,19 @@ async fn delete_machine(
|
||||
state.sessions.remove_agent(&agent_id).await;
|
||||
|
||||
// Delete from database (cascades to sessions and events)
|
||||
db::machines::delete_machine(db.pool(), &agent_id).await
|
||||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Failed to delete machine"))?;
|
||||
db::machines::delete_machine(db.pool(), &agent_id)
|
||||
.await
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Failed to delete machine",
|
||||
)
|
||||
})?;
|
||||
|
||||
info!("Deleted machine {} (uninstall_sent: {})", agent_id, uninstall_sent);
|
||||
info!(
|
||||
"Deleted machine {} (uninstall_sent: {})",
|
||||
agent_id, uninstall_sent
|
||||
);
|
||||
|
||||
Ok(Json(api::DeleteMachineResponse {
|
||||
success: true,
|
||||
@@ -603,27 +675,34 @@ struct TriggerUpdateRequest {
|
||||
|
||||
/// Trigger update on a specific machine
|
||||
async fn trigger_machine_update(
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
_user: AuthenticatedUser, // Require authentication
|
||||
State(state): State<AppState>,
|
||||
Path(agent_id): Path<String>,
|
||||
Json(request): Json<TriggerUpdateRequest>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, &'static str)> {
|
||||
let db = state.db.as_ref()
|
||||
let db = state
|
||||
.db
|
||||
.as_ref()
|
||||
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?;
|
||||
|
||||
// Get the target release (either specified or latest stable)
|
||||
let release = if let Some(version) = request.version {
|
||||
db::releases::get_release_by_version(db.pool(), &version).await
|
||||
db::releases::get_release_by_version(db.pool(), &version)
|
||||
.await
|
||||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?
|
||||
.ok_or((StatusCode::NOT_FOUND, "Release version not found"))?
|
||||
} else {
|
||||
db::releases::get_latest_stable_release(db.pool()).await
|
||||
db::releases::get_latest_stable_release(db.pool())
|
||||
.await
|
||||
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?
|
||||
.ok_or((StatusCode::NOT_FOUND, "No stable release available"))?
|
||||
};
|
||||
|
||||
// Find session for this agent
|
||||
let session = state.sessions.get_session_by_agent(&agent_id).await
|
||||
let session = state
|
||||
.sessions
|
||||
.get_session_by_agent(&agent_id)
|
||||
.await
|
||||
.ok_or((StatusCode::NOT_FOUND, "Agent not found or offline"))?;
|
||||
|
||||
if !session.is_online {
|
||||
@@ -632,21 +711,31 @@ async fn trigger_machine_update(
|
||||
|
||||
// Send update command via WebSocket
|
||||
// For now, we send admin command - later we'll include UpdateInfo in the message
|
||||
let sent = state.sessions.send_admin_command(
|
||||
session.id,
|
||||
proto::AdminCommandType::AdminUpdate,
|
||||
&format!("Update to version {}", release.version),
|
||||
).await;
|
||||
let sent = state
|
||||
.sessions
|
||||
.send_admin_command(
|
||||
session.id,
|
||||
proto::AdminCommandType::AdminUpdate,
|
||||
&format!("Update to version {}", release.version),
|
||||
)
|
||||
.await;
|
||||
|
||||
if sent {
|
||||
info!("Sent update command to agent {} (version {})", agent_id, release.version);
|
||||
info!(
|
||||
"Sent update command to agent {} (version {})",
|
||||
agent_id, release.version
|
||||
);
|
||||
|
||||
// Update machine update status in database
|
||||
let _ = db::releases::update_machine_update_status(db.pool(), &agent_id, "downloading").await;
|
||||
let _ =
|
||||
db::releases::update_machine_update_status(db.pool(), &agent_id, "downloading").await;
|
||||
|
||||
Ok((StatusCode::OK, "Update command sent"))
|
||||
} else {
|
||||
Err((StatusCode::INTERNAL_SERVER_ERROR, "Failed to send update command"))
|
||||
Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Failed to send update command",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,26 +22,26 @@ pub struct RequestLabels {
|
||||
/// Metrics labels for session events
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
|
||||
pub struct SessionLabels {
|
||||
pub status: String, // created, closed, failed, expired
|
||||
pub status: String, // created, closed, failed, expired
|
||||
}
|
||||
|
||||
/// Metrics labels for connection events
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
|
||||
pub struct ConnectionLabels {
|
||||
pub conn_type: String, // agent, viewer, dashboard
|
||||
pub conn_type: String, // agent, viewer, dashboard
|
||||
}
|
||||
|
||||
/// Metrics labels for error tracking
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
|
||||
pub struct ErrorLabels {
|
||||
pub error_type: String, // auth, database, websocket, protocol, internal
|
||||
pub error_type: String, // auth, database, websocket, protocol, internal
|
||||
}
|
||||
|
||||
/// Metrics labels for database operations
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
|
||||
pub struct DatabaseLabels {
|
||||
pub operation: String, // select, insert, update, delete
|
||||
pub status: String, // success, error
|
||||
pub operation: String, // select, insert, update, delete
|
||||
pub status: String, // success, error
|
||||
}
|
||||
|
||||
/// GuruConnect server metrics
|
||||
@@ -82,9 +82,10 @@ impl Metrics {
|
||||
requests_total.clone(),
|
||||
);
|
||||
|
||||
let request_duration_seconds = Family::<RequestLabels, Histogram>::new_with_constructor(|| {
|
||||
Histogram::new(exponential_buckets(0.001, 2.0, 10)) // 1ms to ~1s
|
||||
});
|
||||
let request_duration_seconds =
|
||||
Family::<RequestLabels, Histogram>::new_with_constructor(|| {
|
||||
Histogram::new(exponential_buckets(0.001, 2.0, 10)) // 1ms to ~1s
|
||||
});
|
||||
registry.register(
|
||||
"guruconnect_request_duration_seconds",
|
||||
"HTTP request duration in seconds",
|
||||
@@ -106,7 +107,7 @@ impl Metrics {
|
||||
active_sessions.clone(),
|
||||
);
|
||||
|
||||
let session_duration_seconds = Histogram::new(exponential_buckets(1.0, 2.0, 15)); // 1s to ~9 hours
|
||||
let session_duration_seconds = Histogram::new(exponential_buckets(1.0, 2.0, 15)); // 1s to ~9 hours
|
||||
registry.register(
|
||||
"guruconnect_session_duration_seconds",
|
||||
"Session duration in seconds",
|
||||
@@ -144,9 +145,10 @@ impl Metrics {
|
||||
db_operations_total.clone(),
|
||||
);
|
||||
|
||||
let db_query_duration_seconds = Family::<DatabaseLabels, Histogram>::new_with_constructor(|| {
|
||||
Histogram::new(exponential_buckets(0.0001, 2.0, 12)) // 0.1ms to ~400ms
|
||||
});
|
||||
let db_query_duration_seconds =
|
||||
Family::<DatabaseLabels, Histogram>::new_with_constructor(|| {
|
||||
Histogram::new(exponential_buckets(0.0001, 2.0, 12)) // 0.1ms to ~400ms
|
||||
});
|
||||
registry.register(
|
||||
"guruconnect_db_query_duration_seconds",
|
||||
"Database query duration in seconds",
|
||||
@@ -188,7 +190,13 @@ impl Metrics {
|
||||
}
|
||||
|
||||
/// Record request duration
|
||||
pub fn record_request_duration(&self, method: &str, path: &str, status: u16, duration_secs: f64) {
|
||||
pub fn record_request_duration(
|
||||
&self,
|
||||
method: &str,
|
||||
path: &str,
|
||||
status: u16,
|
||||
duration_secs: f64,
|
||||
) {
|
||||
self.request_duration_seconds
|
||||
.get_or_create(&RequestLabels {
|
||||
method: method.to_string(),
|
||||
|
||||
@@ -3,17 +3,10 @@
|
||||
//! SEC-7: XSS Prevention via Content-Security-Policy
|
||||
//! SEC-12: Additional security headers
|
||||
|
||||
use axum::{
|
||||
extract::Request,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use axum::{extract::Request, middleware::Next, response::Response};
|
||||
|
||||
/// Add security headers to all responses
|
||||
pub async fn add_security_headers(
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
pub async fn add_security_headers(request: Request, next: Next) -> Response {
|
||||
let mut response = next.run(request).await;
|
||||
let headers = response.headers_mut();
|
||||
|
||||
@@ -35,22 +28,13 @@ pub async fn add_security_headers(
|
||||
);
|
||||
|
||||
// SEC-12: X-Frame-Options (Clickjacking protection)
|
||||
headers.insert(
|
||||
"X-Frame-Options",
|
||||
"DENY".parse().unwrap(),
|
||||
);
|
||||
headers.insert("X-Frame-Options", "DENY".parse().unwrap());
|
||||
|
||||
// SEC-12: X-Content-Type-Options (MIME sniffing protection)
|
||||
headers.insert(
|
||||
"X-Content-Type-Options",
|
||||
"nosniff".parse().unwrap(),
|
||||
);
|
||||
headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
|
||||
|
||||
// SEC-12: X-XSS-Protection (Legacy XSS filter - deprecated but still useful)
|
||||
headers.insert(
|
||||
"X-XSS-Protection",
|
||||
"1; mode=block".parse().unwrap(),
|
||||
);
|
||||
headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
|
||||
|
||||
// SEC-12: Referrer-Policy (Control referrer information)
|
||||
headers.insert(
|
||||
|
||||
@@ -6,21 +6,21 @@
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
Query, State, ConnectInfo,
|
||||
ConnectInfo, Query, State,
|
||||
},
|
||||
response::IntoResponse,
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
};
|
||||
use std::net::SocketAddr;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use prost::Message as ProstMessage;
|
||||
use serde::Deserialize;
|
||||
use std::net::SocketAddr;
|
||||
use tracing::{error, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::db::{self, Database};
|
||||
use crate::proto;
|
||||
use crate::session::SessionManager;
|
||||
use crate::db::{self, Database};
|
||||
use crate::AppState;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -59,7 +59,11 @@ pub async fn agent_ws_handler(
|
||||
Query(params): Query<AgentParams>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let agent_id = params.agent_id.clone();
|
||||
let agent_name = params.hostname.clone().or(params.agent_name.clone()).unwrap_or_else(|| agent_id.clone());
|
||||
let agent_name = params
|
||||
.hostname
|
||||
.clone()
|
||||
.or(params.agent_name.clone())
|
||||
.unwrap_or_else(|| agent_id.clone());
|
||||
let support_code = params.support_code.clone();
|
||||
let api_key = params.api_key.clone();
|
||||
let client_ip = addr.ip();
|
||||
@@ -69,7 +73,10 @@ pub async fn agent_ws_handler(
|
||||
// API key = persistent managed agent
|
||||
|
||||
if support_code.is_none() && api_key.is_none() {
|
||||
warn!("Agent connection rejected: {} from {} - no support code or API key", agent_id, client_ip);
|
||||
warn!(
|
||||
"Agent connection rejected: {} from {} - no support code or API key",
|
||||
agent_id, client_ip
|
||||
);
|
||||
|
||||
// Log failed connection attempt to database
|
||||
if let Some(ref db) = state.db {
|
||||
@@ -84,7 +91,8 @@ pub async fn agent_ws_handler(
|
||||
"agent_id": agent_id
|
||||
})),
|
||||
Some(client_ip),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
@@ -95,7 +103,10 @@ pub async fn agent_ws_handler(
|
||||
// Check if it's a valid, pending support code
|
||||
let code_info = state.support_codes.get_status(code).await;
|
||||
if code_info.is_none() {
|
||||
warn!("Agent connection rejected: {} from {} - invalid support code {}", agent_id, client_ip, code);
|
||||
warn!(
|
||||
"Agent connection rejected: {} from {} - invalid support code {}",
|
||||
agent_id, client_ip, code
|
||||
);
|
||||
|
||||
// Log failed connection attempt
|
||||
if let Some(ref db) = state.db {
|
||||
@@ -111,14 +122,18 @@ pub async fn agent_ws_handler(
|
||||
"agent_id": agent_id
|
||||
})),
|
||||
Some(client_ip),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
let status = code_info.unwrap();
|
||||
if status != "pending" && status != "connected" {
|
||||
warn!("Agent connection rejected: {} from {} - support code {} has status {}", agent_id, client_ip, code, status);
|
||||
warn!(
|
||||
"Agent connection rejected: {} from {} - support code {} has status {}",
|
||||
agent_id, client_ip, code, status
|
||||
);
|
||||
|
||||
// Log failed connection attempt (expired/cancelled code)
|
||||
if let Some(ref db) = state.db {
|
||||
@@ -140,12 +155,16 @@ pub async fn agent_ws_handler(
|
||||
"agent_id": agent_id
|
||||
})),
|
||||
Some(client_ip),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
info!("Agent {} from {} authenticated via support code {}", agent_id, client_ip, code);
|
||||
info!(
|
||||
"Agent {} from {} authenticated via support code {}",
|
||||
agent_id, client_ip, code
|
||||
);
|
||||
}
|
||||
|
||||
// Validate API key if provided (for persistent agents)
|
||||
@@ -153,7 +172,10 @@ pub async fn agent_ws_handler(
|
||||
// For now, we'll accept API keys that match the JWT secret or a configured agent key
|
||||
// In production, this should validate against a database of registered agents
|
||||
if !validate_agent_api_key(&state, key).await {
|
||||
warn!("Agent connection rejected: {} from {} - invalid API key", agent_id, client_ip);
|
||||
warn!(
|
||||
"Agent connection rejected: {} from {} - invalid API key",
|
||||
agent_id, client_ip
|
||||
);
|
||||
|
||||
// Log failed connection attempt
|
||||
if let Some(ref db) = state.db {
|
||||
@@ -168,19 +190,34 @@ pub async fn agent_ws_handler(
|
||||
"agent_id": agent_id
|
||||
})),
|
||||
Some(client_ip),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
info!("Agent {} from {} authenticated via API key", agent_id, client_ip);
|
||||
info!(
|
||||
"Agent {} from {} authenticated via API key",
|
||||
agent_id, client_ip
|
||||
);
|
||||
}
|
||||
|
||||
let sessions = state.sessions.clone();
|
||||
let support_codes = state.support_codes.clone();
|
||||
let db = state.db.clone();
|
||||
|
||||
Ok(ws.on_upgrade(move |socket| handle_agent_connection(socket, sessions, support_codes, db, agent_id, agent_name, support_code, Some(client_ip))))
|
||||
Ok(ws.on_upgrade(move |socket| {
|
||||
handle_agent_connection(
|
||||
socket,
|
||||
sessions,
|
||||
support_codes,
|
||||
db,
|
||||
agent_id,
|
||||
agent_name,
|
||||
support_code,
|
||||
Some(client_ip),
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
||||
/// Validate an agent API key
|
||||
@@ -212,24 +249,42 @@ pub async fn viewer_ws_handler(
|
||||
|
||||
// Require JWT token for viewers
|
||||
let token = params.token.ok_or_else(|| {
|
||||
warn!("Viewer connection rejected from {}: missing token", client_ip);
|
||||
warn!(
|
||||
"Viewer connection rejected from {}: missing token",
|
||||
client_ip
|
||||
);
|
||||
StatusCode::UNAUTHORIZED
|
||||
})?;
|
||||
|
||||
// Validate the token
|
||||
let claims = state.jwt_config.validate_token(&token).map_err(|e| {
|
||||
warn!("Viewer connection rejected from {}: invalid token: {}", client_ip, e);
|
||||
warn!(
|
||||
"Viewer connection rejected from {}: invalid token: {}",
|
||||
client_ip, e
|
||||
);
|
||||
StatusCode::UNAUTHORIZED
|
||||
})?;
|
||||
|
||||
info!("Viewer {} authenticated via JWT from {}", claims.username, client_ip);
|
||||
info!(
|
||||
"Viewer {} authenticated via JWT from {}",
|
||||
claims.username, client_ip
|
||||
);
|
||||
|
||||
let session_id = params.session_id;
|
||||
let viewer_name = params.viewer_name;
|
||||
let sessions = state.sessions.clone();
|
||||
let db = state.db.clone();
|
||||
|
||||
Ok(ws.on_upgrade(move |socket| handle_viewer_connection(socket, sessions, db, session_id, viewer_name, Some(client_ip))))
|
||||
Ok(ws.on_upgrade(move |socket| {
|
||||
handle_viewer_connection(
|
||||
socket,
|
||||
sessions,
|
||||
db,
|
||||
session_id,
|
||||
viewer_name,
|
||||
Some(client_ip),
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
||||
/// Handle an agent WebSocket connection
|
||||
@@ -243,7 +298,10 @@ async fn handle_agent_connection(
|
||||
support_code: Option<String>,
|
||||
client_ip: Option<std::net::IpAddr>,
|
||||
) {
|
||||
info!("Agent connected: {} ({}) from {:?}", agent_name, agent_id, client_ip);
|
||||
info!(
|
||||
"Agent connected: {} ({}) from {:?}",
|
||||
agent_name, agent_id, client_ip
|
||||
);
|
||||
|
||||
let (mut ws_sender, mut ws_receiver) = socket.split();
|
||||
|
||||
@@ -270,7 +328,9 @@ async fn handle_agent_connection(
|
||||
// Register the agent and get channels
|
||||
// Persistent agents (no support code) keep their session when disconnected
|
||||
let is_persistent = support_code.is_none();
|
||||
let (session_id, frame_tx, mut input_rx) = sessions.register_agent(agent_id.clone(), agent_name.clone(), is_persistent).await;
|
||||
let (session_id, frame_tx, mut input_rx) = sessions
|
||||
.register_agent(agent_id.clone(), agent_name.clone(), is_persistent)
|
||||
.await;
|
||||
|
||||
info!("Session created: {} (agent in idle mode)", session_id);
|
||||
|
||||
@@ -285,15 +345,20 @@ async fn handle_agent_connection(
|
||||
machine.id,
|
||||
support_code.is_some(),
|
||||
support_code.as_deref(),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
|
||||
// Log session started event
|
||||
let _ = db::events::log_event(
|
||||
db.pool(),
|
||||
session_id,
|
||||
db::events::EventTypes::SESSION_STARTED,
|
||||
None, None, None, client_ip,
|
||||
).await;
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
client_ip,
|
||||
)
|
||||
.await;
|
||||
|
||||
Some(machine.id)
|
||||
}
|
||||
@@ -309,7 +374,9 @@ async fn handle_agent_connection(
|
||||
// If a support code was provided, mark it as connected
|
||||
if let Some(ref code) = support_code {
|
||||
info!("Linking support code {} to session {}", code, session_id);
|
||||
support_codes.mark_connected(code, Some(agent_name.clone()), Some(agent_id.clone())).await;
|
||||
support_codes
|
||||
.mark_connected(code, Some(agent_name.clone()), Some(agent_id.clone()))
|
||||
.await;
|
||||
support_codes.link_session(code, session_id).await;
|
||||
|
||||
// Database: update support code
|
||||
@@ -320,7 +387,8 @@ async fn handle_agent_connection(
|
||||
Some(session_id),
|
||||
Some(&agent_name),
|
||||
Some(&agent_id),
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -333,7 +401,11 @@ async fn handle_agent_connection(
|
||||
let input_forward = tokio::spawn(async move {
|
||||
while let Some(input_data) = input_rx.recv().await {
|
||||
let mut sender = ws_sender_input.lock().await;
|
||||
if sender.send(Message::Binary(input_data.into())).await.is_err() {
|
||||
if sender
|
||||
.send(Message::Binary(input_data.into()))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -406,22 +478,29 @@ async fn handle_agent_connection(
|
||||
} else {
|
||||
Some(status.site.clone())
|
||||
};
|
||||
sessions_status.update_agent_status(
|
||||
session_id,
|
||||
Some(status.os_version.clone()),
|
||||
status.is_elevated,
|
||||
status.uptime_secs,
|
||||
status.display_count,
|
||||
status.is_streaming,
|
||||
agent_version.clone(),
|
||||
organization.clone(),
|
||||
site.clone(),
|
||||
status.tags.clone(),
|
||||
).await;
|
||||
sessions_status
|
||||
.update_agent_status(
|
||||
session_id,
|
||||
Some(status.os_version.clone()),
|
||||
status.is_elevated,
|
||||
status.uptime_secs,
|
||||
status.display_count,
|
||||
status.is_streaming,
|
||||
agent_version.clone(),
|
||||
organization.clone(),
|
||||
site.clone(),
|
||||
status.tags.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// Update version in database if present
|
||||
if let (Some(ref db), Some(ref version)) = (&db, &agent_version) {
|
||||
let _ = crate::db::releases::update_machine_version(db.pool(), &agent_id, version).await;
|
||||
let _ = crate::db::releases::update_machine_version(
|
||||
db.pool(),
|
||||
&agent_id,
|
||||
version,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Update organization/site/tags in database if present
|
||||
@@ -432,7 +511,8 @@ async fn handle_agent_connection(
|
||||
organization.as_deref(),
|
||||
site.as_deref(),
|
||||
&status.tags,
|
||||
).await;
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
info!("Agent status update: {} - streaming={}, uptime={}s, version={:?}, org={:?}, site={:?}",
|
||||
@@ -489,8 +569,12 @@ async fn handle_agent_connection(
|
||||
db.pool(),
|
||||
session_id,
|
||||
db::events::EventTypes::SESSION_ENDED,
|
||||
None, None, None, client_ip,
|
||||
).await;
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
client_ip,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Mark support code as completed if one was used (unless cancelled)
|
||||
@@ -532,7 +616,10 @@ async fn handle_viewer_connection(
|
||||
let viewer_id = Uuid::new_v4().to_string();
|
||||
|
||||
// Join the session (this sends StartStream to agent if first viewer)
|
||||
let (mut frame_rx, input_tx) = match sessions.join_session(session_id, viewer_id.clone(), viewer_name.clone()).await {
|
||||
let (mut frame_rx, input_tx) = match sessions
|
||||
.join_session(session_id, viewer_id.clone(), viewer_name.clone())
|
||||
.await
|
||||
{
|
||||
Some(channels) => channels,
|
||||
None => {
|
||||
warn!("Session not found: {}", session_id);
|
||||
@@ -540,7 +627,10 @@ async fn handle_viewer_connection(
|
||||
}
|
||||
};
|
||||
|
||||
info!("Viewer {} ({}) joined session: {} from {:?}", viewer_name, viewer_id, session_id, client_ip);
|
||||
info!(
|
||||
"Viewer {} ({}) joined session: {} from {:?}",
|
||||
viewer_name, viewer_id, session_id, client_ip
|
||||
);
|
||||
|
||||
// Database: log viewer joined event
|
||||
if let Some(ref db) = db {
|
||||
@@ -550,8 +640,10 @@ async fn handle_viewer_connection(
|
||||
db::events::EventTypes::VIEWER_JOINED,
|
||||
Some(&viewer_id),
|
||||
Some(&viewer_name),
|
||||
None, client_ip,
|
||||
).await;
|
||||
None,
|
||||
client_ip,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
let (mut ws_sender, mut ws_receiver) = socket.split();
|
||||
@@ -559,7 +651,11 @@ async fn handle_viewer_connection(
|
||||
// Task to forward frames from agent to this viewer
|
||||
let frame_forward = tokio::spawn(async move {
|
||||
while let Ok(frame_data) = frame_rx.recv().await {
|
||||
if ws_sender.send(Message::Binary(frame_data.into())).await.is_err() {
|
||||
if ws_sender
|
||||
.send(Message::Binary(frame_data.into()))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -577,9 +673,9 @@ async fn handle_viewer_connection(
|
||||
match proto::Message::decode(data.as_ref()) {
|
||||
Ok(proto_msg) => {
|
||||
match &proto_msg.payload {
|
||||
Some(proto::message::Payload::MouseEvent(_)) |
|
||||
Some(proto::message::Payload::KeyEvent(_)) |
|
||||
Some(proto::message::Payload::SpecialKey(_)) => {
|
||||
Some(proto::message::Payload::MouseEvent(_))
|
||||
| Some(proto::message::Payload::KeyEvent(_))
|
||||
| Some(proto::message::Payload::SpecialKey(_)) => {
|
||||
// Forward input to agent
|
||||
let _ = input_tx.send(data.to_vec()).await;
|
||||
}
|
||||
@@ -597,7 +693,10 @@ async fn handle_viewer_connection(
|
||||
}
|
||||
}
|
||||
Ok(Message::Close(_)) => {
|
||||
info!("Viewer {} disconnected from session: {}", viewer_id, session_id);
|
||||
info!(
|
||||
"Viewer {} disconnected from session: {}",
|
||||
viewer_id, session_id
|
||||
);
|
||||
break;
|
||||
}
|
||||
Ok(_) => {}
|
||||
@@ -610,7 +709,9 @@ async fn handle_viewer_connection(
|
||||
|
||||
// Cleanup (this sends StopStream to agent if last viewer)
|
||||
frame_forward.abort();
|
||||
sessions_cleanup.leave_session(session_id, &viewer_id_cleanup).await;
|
||||
sessions_cleanup
|
||||
.leave_session(session_id, &viewer_id_cleanup)
|
||||
.await;
|
||||
|
||||
// Database: log viewer left event
|
||||
if let Some(ref db) = db {
|
||||
@@ -620,8 +721,10 @@ async fn handle_viewer_connection(
|
||||
db::events::EventTypes::VIEWER_LEFT,
|
||||
Some(&viewer_id_cleanup),
|
||||
Some(&viewer_name_cleanup),
|
||||
None, client_ip,
|
||||
).await;
|
||||
None,
|
||||
client_ip,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
info!("Viewer {} left session: {}", viewer_id_cleanup, session_id);
|
||||
|
||||
@@ -37,20 +37,20 @@ pub struct Session {
|
||||
pub agent_name: String,
|
||||
pub started_at: chrono::DateTime<chrono::Utc>,
|
||||
pub viewer_count: usize,
|
||||
pub viewers: Vec<ViewerInfo>, // List of connected technicians
|
||||
pub viewers: Vec<ViewerInfo>, // List of connected technicians
|
||||
pub is_streaming: bool,
|
||||
pub is_online: bool, // Whether agent is currently connected
|
||||
pub is_persistent: bool, // Persistent agent (no support code) vs support session
|
||||
pub is_online: bool, // Whether agent is currently connected
|
||||
pub is_persistent: bool, // Persistent agent (no support code) vs support session
|
||||
pub last_heartbeat: chrono::DateTime<chrono::Utc>,
|
||||
// Agent status info
|
||||
pub os_version: Option<String>,
|
||||
pub is_elevated: bool,
|
||||
pub uptime_secs: i64,
|
||||
pub display_count: i32,
|
||||
pub agent_version: Option<String>, // Agent software version
|
||||
pub organization: Option<String>, // Company/organization name
|
||||
pub site: Option<String>, // Site/location name
|
||||
pub tags: Vec<String>, // Tags for categorization
|
||||
pub agent_version: Option<String>, // Agent software version
|
||||
pub organization: Option<String>, // Company/organization name
|
||||
pub site: Option<String>, // Site/location name
|
||||
pub tags: Vec<String>, // Tags for categorization
|
||||
}
|
||||
|
||||
/// Channel for sending frames from agent to viewers
|
||||
@@ -92,7 +92,12 @@ impl SessionManager {
|
||||
|
||||
/// Register a new agent and create a session
|
||||
/// If agent was previously connected (offline session exists), reuse that session
|
||||
pub async fn register_agent(&self, agent_id: AgentId, agent_name: String, is_persistent: bool) -> (SessionId, FrameSender, InputReceiver) {
|
||||
pub async fn register_agent(
|
||||
&self,
|
||||
agent_id: AgentId,
|
||||
agent_name: String,
|
||||
is_persistent: bool,
|
||||
) -> (SessionId, FrameSender, InputReceiver) {
|
||||
// Check if this agent already has an offline session (reconnecting)
|
||||
{
|
||||
let agents = self.agents.read().await;
|
||||
@@ -101,7 +106,11 @@ impl SessionManager {
|
||||
if let Some(session_data) = sessions.get_mut(&existing_session_id) {
|
||||
if !session_data.info.is_online {
|
||||
// Reuse existing session - mark as online and create new channels
|
||||
tracing::info!("Agent {} reconnecting to existing session {}", agent_id, existing_session_id);
|
||||
tracing::info!(
|
||||
"Agent {} reconnecting to existing session {}",
|
||||
agent_id,
|
||||
existing_session_id
|
||||
);
|
||||
|
||||
let (frame_tx, _) = broadcast::channel(16);
|
||||
let (input_tx, input_rx) = tokio::sync::mpsc::channel(64);
|
||||
@@ -230,7 +239,9 @@ impl SessionManager {
|
||||
let sessions = self.sessions.read().await;
|
||||
sessions
|
||||
.iter()
|
||||
.filter(|(_, data)| data.last_heartbeat_instant.elapsed().as_secs() > HEARTBEAT_TIMEOUT_SECS)
|
||||
.filter(|(_, data)| {
|
||||
data.last_heartbeat_instant.elapsed().as_secs() > HEARTBEAT_TIMEOUT_SECS
|
||||
})
|
||||
.map(|(id, _)| *id)
|
||||
.collect()
|
||||
}
|
||||
@@ -251,7 +262,12 @@ impl SessionManager {
|
||||
}
|
||||
|
||||
/// Join a session as a viewer, returns channels and sends StartStream to agent
|
||||
pub async fn join_session(&self, session_id: SessionId, viewer_id: ViewerId, viewer_name: String) -> Option<(FrameReceiver, InputSender)> {
|
||||
pub async fn join_session(
|
||||
&self,
|
||||
session_id: SessionId,
|
||||
viewer_id: ViewerId,
|
||||
viewer_name: String,
|
||||
) -> Option<(FrameReceiver, InputSender)> {
|
||||
let mut sessions = self.sessions.write().await;
|
||||
let session_data = sessions.get_mut(&session_id)?;
|
||||
|
||||
@@ -274,10 +290,20 @@ impl SessionManager {
|
||||
|
||||
// If this is the first viewer, send StartStream to agent
|
||||
if was_empty {
|
||||
tracing::info!("Viewer {} ({}) joined session {}, sending StartStream", viewer_name, viewer_id, session_id);
|
||||
tracing::info!(
|
||||
"Viewer {} ({}) joined session {}, sending StartStream",
|
||||
viewer_name,
|
||||
viewer_id,
|
||||
session_id
|
||||
);
|
||||
Self::send_start_stream_internal(session_data, &viewer_id).await;
|
||||
} else {
|
||||
tracing::info!("Viewer {} ({}) joined session {}", viewer_name, viewer_id, session_id);
|
||||
tracing::info!(
|
||||
"Viewer {} ({}) joined session {}",
|
||||
viewer_name,
|
||||
viewer_id,
|
||||
session_id
|
||||
);
|
||||
}
|
||||
|
||||
Some((frame_rx, input_tx))
|
||||
@@ -312,12 +338,20 @@ impl SessionManager {
|
||||
|
||||
// If no more viewers, send StopStream to agent
|
||||
if session_data.viewers.is_empty() {
|
||||
tracing::info!("Last viewer {} ({}) left session {}, sending StopStream",
|
||||
viewer_name.as_deref().unwrap_or("unknown"), viewer_id, session_id);
|
||||
tracing::info!(
|
||||
"Last viewer {} ({}) left session {}, sending StopStream",
|
||||
viewer_name.as_deref().unwrap_or("unknown"),
|
||||
viewer_id,
|
||||
session_id
|
||||
);
|
||||
Self::send_stop_stream_internal(session_data, viewer_id).await;
|
||||
} else {
|
||||
tracing::info!("Viewer {} ({}) left session {}",
|
||||
viewer_name.as_deref().unwrap_or("unknown"), viewer_id, session_id);
|
||||
tracing::info!(
|
||||
"Viewer {} ({}) left session {}",
|
||||
viewer_name.as_deref().unwrap_or("unknown"),
|
||||
viewer_id,
|
||||
session_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -347,8 +381,11 @@ impl SessionManager {
|
||||
if let Some(session_data) = sessions.get_mut(&session_id) {
|
||||
if session_data.info.is_persistent {
|
||||
// Persistent agent - keep session but mark as offline
|
||||
tracing::info!("Persistent agent {} marked offline (session {} preserved)",
|
||||
session_data.info.agent_id, session_id);
|
||||
tracing::info!(
|
||||
"Persistent agent {} marked offline (session {} preserved)",
|
||||
session_data.info.agent_id,
|
||||
session_id
|
||||
);
|
||||
session_data.info.is_online = false;
|
||||
session_data.info.is_streaming = false;
|
||||
session_data.info.viewer_count = 0;
|
||||
@@ -410,7 +447,12 @@ impl SessionManager {
|
||||
|
||||
/// Send an admin command to an agent (uninstall, restart, etc.)
|
||||
/// Returns true if the message was sent successfully
|
||||
pub async fn send_admin_command(&self, session_id: SessionId, command: crate::proto::AdminCommandType, reason: &str) -> bool {
|
||||
pub async fn send_admin_command(
|
||||
&self,
|
||||
session_id: SessionId,
|
||||
command: crate::proto::AdminCommandType,
|
||||
reason: &str,
|
||||
) -> bool {
|
||||
let sessions = self.sessions.read().await;
|
||||
if let Some(session_data) = sessions.get(&session_id) {
|
||||
if !session_data.info.is_online {
|
||||
@@ -471,7 +513,7 @@ impl SessionManager {
|
||||
viewer_count: 0,
|
||||
viewers: Vec::new(),
|
||||
is_streaming: false,
|
||||
is_online: false, // Offline until agent reconnects
|
||||
is_online: false, // Offline until agent reconnects
|
||||
is_persistent: true,
|
||||
last_heartbeat: now,
|
||||
os_version: None,
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
//! Handles generation and validation of 6-digit support codes
|
||||
//! for one-time remote support sessions.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use chrono::{DateTime, Utc};
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// A support session code
|
||||
@@ -27,10 +27,10 @@ pub struct SupportCode {
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum CodeStatus {
|
||||
Pending, // Waiting for client to connect
|
||||
Connected, // Client connected, session active
|
||||
Completed, // Session ended normally
|
||||
Cancelled, // Code cancelled by tech
|
||||
Pending, // Waiting for client to connect
|
||||
Connected, // Client connected, session active
|
||||
Completed, // Session ended normally
|
||||
Cancelled, // Code cancelled by tech
|
||||
}
|
||||
|
||||
/// Request to create a new support code
|
||||
@@ -69,11 +69,11 @@ impl SupportCodeManager {
|
||||
async fn generate_unique_code(&self) -> String {
|
||||
let codes = self.codes.read().await;
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
|
||||
loop {
|
||||
let code: u32 = rng.gen_range(100000..999999);
|
||||
let code_str = code.to_string();
|
||||
|
||||
|
||||
if !codes.contains_key(&code_str) {
|
||||
return code_str;
|
||||
}
|
||||
@@ -84,11 +84,13 @@ impl SupportCodeManager {
|
||||
pub async fn create_code(&self, request: CreateCodeRequest) -> SupportCode {
|
||||
let code = self.generate_unique_code().await;
|
||||
let session_id = Uuid::new_v4();
|
||||
|
||||
|
||||
let support_code = SupportCode {
|
||||
code: code.clone(),
|
||||
session_id,
|
||||
created_by: request.technician_name.unwrap_or_else(|| "Unknown".to_string()),
|
||||
created_by: request
|
||||
.technician_name
|
||||
.unwrap_or_else(|| "Unknown".to_string()),
|
||||
created_at: Utc::now(),
|
||||
status: CodeStatus::Pending,
|
||||
client_name: None,
|
||||
@@ -108,10 +110,12 @@ impl SupportCodeManager {
|
||||
/// Validate a code and return session info
|
||||
pub async fn validate_code(&self, code: &str) -> CodeValidation {
|
||||
let codes = self.codes.read().await;
|
||||
|
||||
|
||||
match codes.get(code) {
|
||||
Some(support_code) => {
|
||||
if support_code.status == CodeStatus::Pending || support_code.status == CodeStatus::Connected {
|
||||
if support_code.status == CodeStatus::Pending
|
||||
|| support_code.status == CodeStatus::Connected
|
||||
{
|
||||
CodeValidation {
|
||||
valid: true,
|
||||
session_id: Some(support_code.session_id.to_string()),
|
||||
@@ -137,7 +141,12 @@ impl SupportCodeManager {
|
||||
}
|
||||
|
||||
/// Mark a code as connected
|
||||
pub async fn mark_connected(&self, code: &str, client_name: Option<String>, client_machine: Option<String>) {
|
||||
pub async fn mark_connected(
|
||||
&self,
|
||||
code: &str,
|
||||
client_name: Option<String>,
|
||||
client_machine: Option<String>,
|
||||
) {
|
||||
let mut codes = self.codes.write().await;
|
||||
if let Some(support_code) = codes.get_mut(code) {
|
||||
support_code.status = CodeStatus::Connected;
|
||||
@@ -180,7 +189,9 @@ impl SupportCodeManager {
|
||||
pub async fn cancel_code(&self, code: &str) -> bool {
|
||||
let mut codes = self.codes.write().await;
|
||||
if let Some(support_code) = codes.get_mut(code) {
|
||||
if support_code.status == CodeStatus::Pending || support_code.status == CodeStatus::Connected {
|
||||
if support_code.status == CodeStatus::Pending
|
||||
|| support_code.status == CodeStatus::Connected
|
||||
{
|
||||
support_code.status = CodeStatus::Cancelled;
|
||||
return true;
|
||||
}
|
||||
@@ -191,13 +202,19 @@ impl SupportCodeManager {
|
||||
/// Check if a code is cancelled
|
||||
pub async fn is_cancelled(&self, code: &str) -> bool {
|
||||
let codes = self.codes.read().await;
|
||||
codes.get(code).map(|c| c.status == CodeStatus::Cancelled).unwrap_or(false)
|
||||
codes
|
||||
.get(code)
|
||||
.map(|c| c.status == CodeStatus::Cancelled)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Check if a code is valid for connection (exists and is pending)
|
||||
pub async fn is_valid_for_connection(&self, code: &str) -> bool {
|
||||
let codes = self.codes.read().await;
|
||||
codes.get(code).map(|c| c.status == CodeStatus::Pending).unwrap_or(false)
|
||||
codes
|
||||
.get(code)
|
||||
.map(|c| c.status == CodeStatus::Pending)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// List all codes (for dashboard)
|
||||
@@ -209,7 +226,8 @@ impl SupportCodeManager {
|
||||
/// List active codes only
|
||||
pub async fn list_active_codes(&self) -> Vec<SupportCode> {
|
||||
let codes = self.codes.read().await;
|
||||
codes.values()
|
||||
codes
|
||||
.values()
|
||||
.filter(|c| c.status == CodeStatus::Pending || c.status == CodeStatus::Connected)
|
||||
.cloned()
|
||||
.collect()
|
||||
|
||||
@@ -11,18 +11,29 @@ use anyhow::{anyhow, Result};
|
||||
pub fn validate_api_key_strength(api_key: &str) -> Result<()> {
|
||||
// Minimum length check
|
||||
if api_key.len() < 32 {
|
||||
return Err(anyhow!("API key must be at least 32 characters long for security"));
|
||||
return Err(anyhow!(
|
||||
"API key must be at least 32 characters long for security"
|
||||
));
|
||||
}
|
||||
|
||||
// Check for common weak keys
|
||||
let weak_keys = [
|
||||
"password", "12345", "admin", "test", "api_key",
|
||||
"secret", "changeme", "default", "guruconnect"
|
||||
"password",
|
||||
"12345",
|
||||
"admin",
|
||||
"test",
|
||||
"api_key",
|
||||
"secret",
|
||||
"changeme",
|
||||
"default",
|
||||
"guruconnect",
|
||||
];
|
||||
let lowercase_key = api_key.to_lowercase();
|
||||
for weak in &weak_keys {
|
||||
if lowercase_key.contains(weak) {
|
||||
return Err(anyhow!("API key contains weak/common patterns and is not secure"));
|
||||
return Err(anyhow!(
|
||||
"API key contains weak/common patterns and is not secure"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,6 +64,9 @@ mod tests {
|
||||
assert!(validate_api_key_strength("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa").is_err());
|
||||
|
||||
// Good key
|
||||
assert!(validate_api_key_strength("KfPrjjC3J6YMx9q1yjPxZAYkHLM2JdFy1XRxHJ9oPnw0NU3xH074ufHk7fj").is_ok());
|
||||
assert!(validate_api_key_strength(
|
||||
"KfPrjjC3J6YMx9q1yjPxZAYkHLM2JdFy1XRxHJ9oPnw0NU3xH074ufHk7fj"
|
||||
)
|
||||
.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user