fix(security): Implement Phase 1 critical security fixes

CORS:
- Restrict CORS to DASHBOARD_URL environment variable
- Default to production dashboard domain

Authentication:
- Add AuthUser requirement to all agent management endpoints
- Add AuthUser requirement to all command endpoints
- Add AuthUser requirement to all metrics endpoints
- Add audit logging for command execution (user_id tracked)

Agent Security:
- Replace Unicode characters with ASCII markers [OK]/[ERROR]/[WARNING]
- Add certificate pinning for update downloads (allowlist domains)
- Fix insecure temp file creation (use /var/run/gururmm with 0700 perms)
- Fix rollback script backgrounding (use setsid instead of literal &)

Dashboard Security:
- Move token storage from localStorage to sessionStorage
- Add proper TypeScript types (remove 'any' from error handlers)
- Centralize token management functions

Legacy Agent:
- Add -AllowInsecureTLS parameter (opt-in required)
- Add Windows Event Log audit trail when insecure mode used
- Update documentation with security warnings

Closes: Phase 1 items in issue #1

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-20 21:16:24 -07:00
parent 6d3271c144
commit 65086f4407
15 changed files with 1708 additions and 99 deletions

View File

@@ -1,19 +1,56 @@
#Requires -Version 2.0 #Requires -Version 2.0
<# <#
.SYNOPSIS .SYNOPSIS
GuruRMM Legacy Agent - PowerShell-based agent for Windows Server 2008 R2 and older systems GuruRMM Legacy Agent for Windows Server 2008 R2 and older systems.
.DESCRIPTION .DESCRIPTION
Lightweight RMM agent that: This PowerShell-based agent is designed for legacy Windows systems that cannot
- Registers with GuruRMM server using site code run the modern Rust-based GuruRMM agent. It provides basic RMM functionality
- Reports system information including registration, heartbeat, system info collection, and remote command
- Executes remote scripts/commands execution.
- Monitors system health
IMPORTANT: This agent is intended for legacy systems only. For Windows 10/
Server 2016 and newer, use the native Rust agent instead.
.PARAMETER ConfigPath
Path to the agent configuration file. Default: $env:ProgramData\GuruRMM\agent.json
.PARAMETER ServerUrl
The URL of the GuruRMM server (e.g., https://rmm.example.com)
.PARAMETER SiteCode
The site code for agent registration (e.g., ACME-CORP-1234)
.PARAMETER AllowInsecureTLS
[SECURITY RISK] Disables SSL/TLS certificate validation. Required ONLY for
systems with self-signed certificates or broken certificate chains.
WARNING: This flag makes the connection vulnerable to man-in-the-middle
attacks. Only use on isolated networks or when absolutely necessary.
This flag must be explicitly provided - certificate validation is enabled
by default.
.PARAMETER Register
Register this agent with the server.
.EXAMPLE
# Secure installation (recommended)
.\GuruRMM-Agent.ps1 -Register -ServerUrl "https://rmm.example.com" -SiteCode "ACME-CORP-1234"
.EXAMPLE
# Insecure installation (legacy systems with self-signed certs ONLY)
.\GuruRMM-Agent.ps1 -Register -ServerUrl "https://rmm.example.com" -SiteCode "ACME-CORP-1234" -AllowInsecureTLS
.EXAMPLE
# Run the agent
.\GuruRMM-Agent.ps1
.NOTES .NOTES
Compatible with PowerShell 2.0+ (Windows Server 2008 R2) Version: 1.1.0
Requires: PowerShell 2.0+
Platforms: Windows Server 2008 R2, Windows 7, and newer
Author: GuruRMM Author: GuruRMM
Version: 1.0.0
#> #>
param( param(
@@ -27,18 +64,23 @@ param(
[string]$SiteCode, [string]$SiteCode,
[Parameter()] [Parameter()]
[string]$ServerUrl = "https://rmm-api.azcomputerguru.com" [string]$ServerUrl = "https://rmm-api.azcomputerguru.com",
[Parameter()]
[switch]$AllowInsecureTLS
) )
# ============================================================================ # ============================================================================
# Configuration # Configuration
# ============================================================================ # ============================================================================
$script:Version = "1.0.0" $script:Version = "1.1.0"
$script:AgentType = "powershell-legacy" $script:AgentType = "powershell-legacy"
$script:ConfigDir = "$env:ProgramData\GuruRMM" $script:ConfigDir = "$env:ProgramData\GuruRMM"
$script:LogFile = "$script:ConfigDir\agent.log" $script:LogFile = "$script:ConfigDir\agent.log"
$script:PollInterval = 60 # seconds $script:PollInterval = 60 # seconds
$script:AllowInsecureTLS = $AllowInsecureTLS
$script:TLSInitialized = $false
# ============================================================================ # ============================================================================
# Logging # Logging
@@ -67,6 +109,63 @@ function Write-Log {
} catch {} } catch {}
} }
# ============================================================================
# TLS Initialization
# ============================================================================
function Initialize-TLS {
if ($script:TLSInitialized) {
return
}
# Configure TLS - prefer TLS 1.2
try {
[System.Net.ServicePointManager]::SecurityProtocol = [System.Net.SecurityProtocolType]::Tls12
Write-Log "TLS 1.2 configured successfully" "INFO"
} catch {
Write-Log "TLS 1.2 not available, trying TLS 1.1" "WARN"
try {
[System.Net.ServicePointManager]::SecurityProtocol = [System.Net.SecurityProtocolType]::Tls11
} catch {
Write-Log "TLS 1.1 not available - using system default TLS" "WARN"
try {
[System.Net.ServicePointManager]::SecurityProtocol = [System.Net.SecurityProtocolType]::Tls
} catch {
Write-Log "TLS configuration failed - connection security may be limited" "WARN"
}
}
}
# Certificate validation - ONLY disable if explicitly requested
if ($script:AllowInsecureTLS) {
Write-Log "============================================" "WARN"
Write-Log "[SECURITY WARNING] Certificate validation DISABLED" "WARN"
Write-Log "This makes the connection vulnerable to MITM attacks" "WARN"
Write-Log "Only use on legacy systems with self-signed certificates" "WARN"
Write-Log "============================================" "WARN"
# Log to Windows Event Log for audit trail
try {
$source = "GuruRMM"
if (-not [System.Diagnostics.EventLog]::SourceExists($source)) {
New-EventLog -LogName Application -Source $source -ErrorAction SilentlyContinue
}
Write-EventLog -LogName Application -Source $source -EventId 1001 -EntryType Warning `
-Message "GuruRMM agent started with certificate validation disabled (-AllowInsecureTLS). This is a security risk."
} catch {
Write-Log "Could not write to Windows Event Log: $_" "WARN"
}
[System.Net.ServicePointManager]::ServerCertificateValidationCallback = { $true }
} else {
Write-Log "Certificate validation ENABLED (secure mode)" "INFO"
# Ensure callback is reset to default (validate certificates)
[System.Net.ServicePointManager]::ServerCertificateValidationCallback = $null
}
$script:TLSInitialized = $true
}
# ============================================================================ # ============================================================================
# HTTP Functions (PS 2.0 compatible) # HTTP Functions (PS 2.0 compatible)
# ============================================================================ # ============================================================================
@@ -82,6 +181,9 @@ function Invoke-ApiRequest {
$url = "$($script:Config.ServerUrl)$Endpoint" $url = "$($script:Config.ServerUrl)$Endpoint"
try { try {
# Initialize TLS settings (only runs once)
Initialize-TLS
# Use .NET WebClient for PS 2.0 compatibility # Use .NET WebClient for PS 2.0 compatibility
$webClient = New-Object System.Net.WebClient $webClient = New-Object System.Net.WebClient
$webClient.Headers.Add("Content-Type", "application/json") $webClient.Headers.Add("Content-Type", "application/json")
@@ -91,17 +193,6 @@ function Invoke-ApiRequest {
$webClient.Headers.Add("Authorization", "Bearer $ApiKey") $webClient.Headers.Add("Authorization", "Bearer $ApiKey")
} }
# Handle TLS (important for older systems)
try {
[System.Net.ServicePointManager]::SecurityProtocol = [System.Net.SecurityProtocolType]::Tls12
} catch {
# Fallback for systems without TLS 1.2
[System.Net.ServicePointManager]::SecurityProtocol = [System.Net.SecurityProtocolType]::Tls
}
# Ignore certificate errors for self-signed certs (optional)
[System.Net.ServicePointManager]::ServerCertificateValidationCallback = { $true }
if ($Method -eq "GET") { if ($Method -eq "GET") {
$response = $webClient.DownloadString($url) $response = $webClient.DownloadString($url)
} else { } else {

View File

@@ -15,8 +15,20 @@
.PARAMETER ServerUrl .PARAMETER ServerUrl
The GuruRMM server URL (default: https://rmm-api.azcomputerguru.com) The GuruRMM server URL (default: https://rmm-api.azcomputerguru.com)
.PARAMETER AllowInsecureTLS
[SECURITY RISK] Disables SSL/TLS certificate validation. Required ONLY for
systems with self-signed certificates or broken certificate chains.
WARNING: This flag makes the connection vulnerable to man-in-the-middle
attacks. Only use on isolated networks or when absolutely necessary.
.EXAMPLE .EXAMPLE
# Secure installation (recommended)
.\Install-GuruRMM.ps1 -SiteCode DARK-GROVE-7839 .\Install-GuruRMM.ps1 -SiteCode DARK-GROVE-7839
.EXAMPLE
# Insecure installation (legacy systems with self-signed certs ONLY)
.\Install-GuruRMM.ps1 -SiteCode DARK-GROVE-7839 -AllowInsecureTLS
#> #>
param( param(
@@ -24,7 +36,10 @@ param(
[string]$SiteCode, [string]$SiteCode,
[Parameter()] [Parameter()]
[string]$ServerUrl = "https://rmm-api.azcomputerguru.com" [string]$ServerUrl = "https://rmm-api.azcomputerguru.com",
[Parameter()]
[switch]$AllowInsecureTLS
) )
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
@@ -112,8 +127,15 @@ try {
# Step 3: Register agent # Step 3: Register agent
Write-Status "Registering with GuruRMM server..." Write-Status "Registering with GuruRMM server..."
if ($AllowInsecureTLS) {
Write-Status "[SECURITY WARNING] Installing with certificate validation DISABLED" "WARN"
Write-Status "This makes the connection vulnerable to MITM attacks" "WARN"
}
try { try {
$registerArgs = "-ExecutionPolicy Bypass -File `"$destScript`" -SiteCode `"$SiteCode`" -ServerUrl `"$ServerUrl`"" $registerArgs = "-ExecutionPolicy Bypass -File `"$destScript`" -SiteCode `"$SiteCode`" -ServerUrl `"$ServerUrl`""
if ($AllowInsecureTLS) {
$registerArgs += " -AllowInsecureTLS"
}
$process = Start-Process powershell.exe -ArgumentList $registerArgs -Wait -PassThru -NoNewWindow $process = Start-Process powershell.exe -ArgumentList $registerArgs -Wait -PassThru -NoNewWindow
if ($process.ExitCode -ne 0) { if ($process.ExitCode -ne 0) {
@@ -137,13 +159,19 @@ try {
# Step 5: Create scheduled task # Step 5: Create scheduled task
try { try {
# Create the task to run at startup and every 5 minutes # Create the task to run at startup
$taskCommand = "powershell.exe -ExecutionPolicy Bypass -WindowStyle Hidden -File `"$destScript`"" $taskCommand = "powershell.exe -ExecutionPolicy Bypass -WindowStyle Hidden -File `"$destScript`""
if ($AllowInsecureTLS) {
$taskCommand += " -AllowInsecureTLS"
}
# Create task that runs at system startup # Create task that runs at system startup
schtasks /create /tn $TaskName /tr $taskCommand /sc onstart /ru SYSTEM /rl HIGHEST /f | Out-Null schtasks /create /tn $TaskName /tr $taskCommand /sc onstart /ru SYSTEM /rl HIGHEST /f | Out-Null
Write-Status "Scheduled task created: $TaskName" "OK" Write-Status "Scheduled task created: $TaskName" "OK"
if ($AllowInsecureTLS) {
Write-Status "Task configured with -AllowInsecureTLS flag" "WARN"
}
} catch { } catch {
Write-Status "Failed to create scheduled task: $($_.Exception.Message)" "ERROR" Write-Status "Failed to create scheduled task: $($_.Exception.Message)" "ERROR"
Write-Status "You may need to manually create the task" "WARN" Write-Status "You may need to manually create the task" "WARN"

View File

@@ -45,6 +45,9 @@ thiserror = "1"
# UUID for identifiers # UUID for identifiers
uuid = { version = "1", features = ["v4", "serde"] } uuid = { version = "1", features = ["v4", "serde"] }
# URL parsing for download validation
url = "2"
# SHA256 checksums for update verification # SHA256 checksums for update verification
sha2 = "0.10" sha2 = "0.10"

View File

@@ -457,14 +457,14 @@ WantedBy=multi-user.target
anyhow::bail!("systemctl enable failed"); anyhow::bail!("systemctl enable failed");
} }
println!("\n GuruRMM Agent installed successfully!"); println!("\n[OK] GuruRMM Agent installed successfully!");
println!("\nInstalled files:"); println!("\nInstalled files:");
println!(" Binary: {}", binary_dest); println!(" Binary: {}", binary_dest);
println!(" Config: {}", config_dest); println!(" Config: {}", config_dest);
println!(" Service: {}", unit_file); println!(" Service: {}", unit_file);
if config_needs_manual_edit { if config_needs_manual_edit {
println!("\n⚠️ IMPORTANT: Edit {} with your server URL and API key!", config_dest); println!("\n[WARNING] IMPORTANT: Edit {} with your server URL and API key!", config_dest);
println!("\nNext steps:"); println!("\nNext steps:");
println!(" 1. Edit {} with your server URL and API key", config_dest); println!(" 1. Edit {} with your server URL and API key", config_dest);
println!(" 2. Start the service: sudo systemctl start {}", SERVICE_NAME); println!(" 2. Start the service: sudo systemctl start {}", SERVICE_NAME);
@@ -475,9 +475,9 @@ WantedBy=multi-user.target
.status(); .status();
if status.is_ok() && status.unwrap().success() { if status.is_ok() && status.unwrap().success() {
println!(" Service started successfully!"); println!("[OK] Service started successfully!");
} else { } else {
println!("⚠️ Failed to start service. Check logs: sudo journalctl -u {} -f", SERVICE_NAME); println!("[WARNING] Failed to start service. Check logs: sudo journalctl -u {} -f", SERVICE_NAME);
} }
} }
@@ -556,7 +556,7 @@ async fn uninstall_systemd_service() -> Result<()> {
.args(["daemon-reload"]) .args(["daemon-reload"])
.status(); .status();
println!("\n GuruRMM Agent uninstalled successfully!"); println!("\n[OK] GuruRMM Agent uninstalled successfully!");
println!("\nNote: Config directory {} was preserved.", CONFIG_DIR); println!("\nNote: Config directory {} was preserved.", CONFIG_DIR);
println!("Remove it manually if no longer needed: sudo rm -rf {}", CONFIG_DIR); println!("Remove it manually if no longer needed: sudo rm -rf {}", CONFIG_DIR);
@@ -582,7 +582,7 @@ async fn start_service() -> Result<()> {
.context("Failed to start service")?; .context("Failed to start service")?;
if status.success() { if status.success() {
println!("** Service started successfully"); println!("[OK] Service started successfully");
println!("Check status: sudo systemctl status gururmm-agent"); println!("Check status: sudo systemctl status gururmm-agent");
} else { } else {
anyhow::bail!("Failed to start service. Check: sudo journalctl -u gururmm-agent -n 50"); anyhow::bail!("Failed to start service. Check: sudo journalctl -u gururmm-agent -n 50");
@@ -616,7 +616,7 @@ async fn stop_service() -> Result<()> {
.context("Failed to stop service")?; .context("Failed to stop service")?;
if status.success() { if status.success() {
println!("** Service stopped successfully"); println!("[OK] Service stopped successfully");
} else { } else {
anyhow::bail!("Failed to stop service"); anyhow::bail!("Failed to stop service");
} }

View File

@@ -177,7 +177,36 @@ impl AgentUpdater {
} }
/// Download the new binary to a temp file /// Download the new binary to a temp file
///
/// Security: Validates URL against allowed domains and requires HTTPS for external hosts
async fn download_binary(&self, url: &str) -> Result<PathBuf> { async fn download_binary(&self, url: &str) -> Result<PathBuf> {
// Validate URL is from trusted domain
let allowed_domains = [
"rmm-api.azcomputerguru.com",
"downloads.azcomputerguru.com",
"172.16.3.30", // Internal server
];
let parsed_url = url::Url::parse(url)
.context("Invalid download URL")?;
let host = parsed_url.host_str()
.ok_or_else(|| anyhow::anyhow!("No host in download URL"))?;
if !allowed_domains.iter().any(|d| host == *d || host.ends_with(&format!(".{}", d))) {
return Err(anyhow::anyhow!(
"Download URL host '{}' not in allowed domains",
host
));
}
// Require HTTPS (except for local/internal IPs)
if parsed_url.scheme() != "https" && !host.starts_with("172.16.") && !host.starts_with("192.168.") {
return Err(anyhow::anyhow!("Download URL must use HTTPS"));
}
info!("[OK] URL validation passed: {}", url);
let response = self.http_client.get(url) let response = self.http_client.get(url)
.send() .send()
.await .await
@@ -273,10 +302,26 @@ impl AgentUpdater {
#[cfg(unix)] #[cfg(unix)]
async fn create_unix_rollback_watchdog(&self) -> Result<()> { async fn create_unix_rollback_watchdog(&self) -> Result<()> {
use std::os::unix::fs::PermissionsExt;
let backup_path = self.config.backup_path(); let backup_path = self.config.backup_path();
let binary_path = &self.config.binary_path; let binary_path = &self.config.binary_path;
let timeout = self.config.rollback_timeout_secs; let timeout = self.config.rollback_timeout_secs;
// Use secure directory instead of /tmp/ (world-writable)
let script_dir = PathBuf::from("/var/run/gururmm");
// Create directory if needed with restricted permissions (owner only)
if !script_dir.exists() {
tokio::fs::create_dir_all(&script_dir).await
.context("Failed to create secure script directory")?;
std::fs::set_permissions(&script_dir, std::fs::Permissions::from_mode(0o700))
.context("Failed to set script directory permissions")?;
}
// Use UUID in filename to prevent predictable paths
let script_path = script_dir.join(format!("rollback-{}.sh", Uuid::new_v4()));
let script = format!(r#"#!/bin/bash let script = format!(r#"#!/bin/bash
# GuruRMM Rollback Watchdog # GuruRMM Rollback Watchdog
# Auto-generated - will be deleted after successful update # Auto-generated - will be deleted after successful update
@@ -284,49 +329,50 @@ impl AgentUpdater {
BACKUP="{backup}" BACKUP="{backup}"
BINARY="{binary}" BINARY="{binary}"
TIMEOUT={timeout} TIMEOUT={timeout}
SCRIPT_PATH="{script}"
sleep $TIMEOUT sleep $TIMEOUT
# Check if agent service is running # Check if agent service is running
if ! systemctl is-active --quiet gururmm-agent 2>/dev/null; then if ! systemctl is-active --quiet gururmm-agent 2>/dev/null; then
echo "Agent not running after update, rolling back..." echo "[WARNING] Agent not running after update, rolling back..."
if [ -f "$BACKUP" ]; then if [ -f "$BACKUP" ]; then
cp "$BACKUP" "$BINARY" cp "$BACKUP" "$BINARY"
chmod +x "$BINARY" chmod +x "$BINARY"
systemctl start gururmm-agent systemctl start gururmm-agent
echo "Rollback completed" echo "[OK] Rollback completed"
else else
echo "No backup file found!" echo "[ERROR] No backup file found!"
fi fi
fi fi
# Clean up this script # Clean up this script
rm -f /tmp/gururmm-rollback.sh rm -f "$SCRIPT_PATH"
"#, "#,
backup = backup_path.display(), backup = backup_path.display(),
binary = binary_path.display(), binary = binary_path.display(),
timeout = timeout timeout = timeout,
script = script_path.display()
); );
let script_path = PathBuf::from("/tmp/gururmm-rollback.sh"); fs::write(&script_path, script).await
fs::write(&script_path, script).await?; .context("Failed to write rollback script")?;
// Make executable and run in background // Set restrictive permissions (700 - owner only)
tokio::process::Command::new("chmod") std::fs::set_permissions(&script_path, std::fs::Permissions::from_mode(0o700))
.arg("+x") .context("Failed to set rollback script permissions")?;
.arg(&script_path)
.status()
.await?;
// Spawn as detached background process // Spawn as detached background process using setsid (not nohup with "&" literal arg)
tokio::process::Command::new("nohup") tokio::process::Command::new("setsid")
.arg("bash") .arg("bash")
.arg(&script_path) .arg(&script_path)
.arg("&") .stdin(std::process::Stdio::null())
.stdout(std::process::Stdio::null())
.stderr(std::process::Stdio::null())
.spawn() .spawn()
.context("Failed to spawn rollback watchdog")?; .context("Failed to spawn rollback watchdog")?;
info!("Rollback watchdog started (timeout: {}s)", timeout); info!("[OK] Rollback watchdog started (timeout: {}s)", timeout);
Ok(()) Ok(())
} }
@@ -524,12 +570,29 @@ Remove-Item -Path $MyInvocation.MyCommand.Path -Force
pub async fn cancel_rollback_watchdog(&self) { pub async fn cancel_rollback_watchdog(&self) {
#[cfg(unix)] #[cfg(unix)]
{ {
// Kill the watchdog script // Kill any running rollback watchdog scripts
let _ = tokio::process::Command::new("pkill") let _ = tokio::process::Command::new("pkill")
.args(["-f", "gururmm-rollback.sh"]) .args(["-f", "rollback-.*\\.sh"])
.status() .status()
.await; .await;
let _ = fs::remove_file("/tmp/gururmm-rollback.sh").await;
// Clean up the secure script directory
let script_dir = PathBuf::from("/var/run/gururmm");
if script_dir.exists() {
// Remove all rollback scripts in the directory
if let Ok(mut entries) = tokio::fs::read_dir(&script_dir).await {
while let Ok(Some(entry)) = entries.next_entry().await {
let path = entry.path();
if path.file_name()
.and_then(|n| n.to_str())
.map(|n| n.starts_with("rollback-"))
.unwrap_or(false)
{
let _ = fs::remove_file(&path).await;
}
}
}
}
} }
#[cfg(windows)] #[cfg(windows)]

View File

@@ -1,4 +1,4 @@
import axios from "axios"; import axios, { AxiosError } from "axios";
// Default to production URL, override with VITE_API_URL for local dev // Default to production URL, override with VITE_API_URL for local dev
const API_URL = import.meta.env.VITE_API_URL || "https://rmm-api.azcomputerguru.com"; const API_URL = import.meta.env.VITE_API_URL || "https://rmm-api.azcomputerguru.com";
@@ -10,22 +10,41 @@ export const api = axios.create({
}, },
}); });
// Add auth token to requests // Token management - use sessionStorage (cleared on tab close) instead of localStorage
// This provides better security against XSS attacks as tokens are not persisted
const TOKEN_KEY = "gururmm_auth_token";
export const getToken = (): string | null => {
return sessionStorage.getItem(TOKEN_KEY);
};
export const setToken = (token: string): void => {
sessionStorage.setItem(TOKEN_KEY, token);
};
export const clearToken = (): void => {
sessionStorage.removeItem(TOKEN_KEY);
};
// Request interceptor - add auth header
api.interceptors.request.use((config) => { api.interceptors.request.use((config) => {
const token = localStorage.getItem("token"); const token = getToken();
if (token) { if (token) {
config.headers.Authorization = `Bearer ${token}`; config.headers.Authorization = `Bearer ${token}`;
} }
return config; return config;
}); });
// Handle auth errors // Response interceptor - handle 401 unauthorized
api.interceptors.response.use( api.interceptors.response.use(
(response) => response, (response) => response,
(error) => { (error: AxiosError) => {
if (error.response?.status === 401) { if (error.response?.status === 401) {
localStorage.removeItem("token"); clearToken();
window.location.href = "/login"; // Use a more graceful redirect that preserves SPA state
if (window.location.pathname !== "/login") {
window.location.href = "/login";
}
} }
return Promise.reject(error); return Promise.reject(error);
} }
@@ -156,9 +175,31 @@ export interface RegisterRequest {
// API functions // API functions
export const authApi = { export const authApi = {
login: (data: LoginRequest) => api.post<LoginResponse>("/api/auth/login", data), login: async (data: LoginRequest): Promise<LoginResponse> => {
register: (data: RegisterRequest) => api.post<LoginResponse>("/api/auth/register", data), const response = await api.post<LoginResponse>("/api/auth/login", data);
if (response.data.token) {
setToken(response.data.token);
}
return response.data;
},
register: async (data: RegisterRequest): Promise<LoginResponse> => {
const response = await api.post<LoginResponse>("/api/auth/register", data);
if (response.data.token) {
setToken(response.data.token);
}
return response.data;
},
me: () => api.get<User>("/api/auth/me"), me: () => api.get<User>("/api/auth/me"),
logout: (): void => {
clearToken();
},
isAuthenticated: (): boolean => {
return !!getToken();
},
}; };
export const agentsApi = { export const agentsApi = {

View File

@@ -1,9 +1,9 @@
import { createContext, useContext, useState, useEffect, ReactNode } from "react"; import { createContext, useContext, useState, useEffect, ReactNode } from "react";
import { User, authApi } from "../api/client"; import { User, authApi, getToken, clearToken } from "../api/client";
interface AuthContextType { interface AuthContextType {
user: User | null; user: User | null;
token: string | null; isAuthenticated: boolean;
isLoading: boolean; isLoading: boolean;
login: (email: string, password: string) => Promise<void>; login: (email: string, password: string) => Promise<void>;
register: (email: string, password: string, name?: string) => Promise<void>; register: (email: string, password: string, name?: string) => Promise<void>;
@@ -14,46 +14,49 @@ const AuthContext = createContext<AuthContextType | null>(null);
export function AuthProvider({ children }: { children: ReactNode }) { export function AuthProvider({ children }: { children: ReactNode }) {
const [user, setUser] = useState<User | null>(null); const [user, setUser] = useState<User | null>(null);
const [token, setToken] = useState<string | null>(() => localStorage.getItem("token"));
const [isLoading, setIsLoading] = useState(true); const [isLoading, setIsLoading] = useState(true);
// Check authentication status on mount
useEffect(() => { useEffect(() => {
if (token) { const checkAuth = async () => {
authApi const token = getToken();
.me() if (token) {
.then((res) => setUser(res.data)) try {
.catch(() => { const res = await authApi.me();
localStorage.removeItem("token"); setUser(res.data);
setToken(null); } catch {
}) // Token is invalid or expired, clear it
.finally(() => setIsLoading(false)); clearToken();
} else { setUser(null);
}
}
setIsLoading(false); setIsLoading(false);
} };
}, [token]);
checkAuth();
}, []);
const login = async (email: string, password: string) => { const login = async (email: string, password: string) => {
const res = await authApi.login({ email, password }); const response = await authApi.login({ email, password });
localStorage.setItem("token", res.data.token); // Token is automatically stored by authApi.login
setToken(res.data.token); setUser(response.user);
setUser(res.data.user);
}; };
const register = async (email: string, password: string, name?: string) => { const register = async (email: string, password: string, name?: string) => {
const res = await authApi.register({ email, password, name }); const response = await authApi.register({ email, password, name });
localStorage.setItem("token", res.data.token); // Token is automatically stored by authApi.register
setToken(res.data.token); setUser(response.user);
setUser(res.data.user);
}; };
const logout = () => { const logout = () => {
localStorage.removeItem("token"); authApi.logout();
setToken(null);
setUser(null); setUser(null);
}; };
const isAuthenticated = authApi.isAuthenticated();
return ( return (
<AuthContext.Provider value={{ user, token, isLoading, login, register, logout }}> <AuthContext.Provider value={{ user, isAuthenticated, isLoading, login, register, logout }}>
{children} {children}
</AuthContext.Provider> </AuthContext.Provider>
); );

View File

@@ -1,10 +1,16 @@
import { useState, FormEvent } from "react"; import { useState, FormEvent } from "react";
import { Link, useNavigate } from "react-router-dom"; import { Link, useNavigate } from "react-router-dom";
import { AxiosError } from "axios";
import { useAuth } from "../hooks/useAuth"; import { useAuth } from "../hooks/useAuth";
import { Card, CardHeader, CardTitle, CardDescription, CardContent } from "../components/Card"; import { Card, CardHeader, CardTitle, CardDescription, CardContent } from "../components/Card";
import { Input } from "../components/Input"; import { Input } from "../components/Input";
import { Button } from "../components/Button"; import { Button } from "../components/Button";
interface ApiErrorResponse {
error?: string;
message?: string;
}
export function Login() { export function Login() {
const [email, setEmail] = useState(""); const [email, setEmail] = useState("");
const [password, setPassword] = useState(""); const [password, setPassword] = useState("");
@@ -21,8 +27,15 @@ export function Login() {
try { try {
await login(email, password); await login(email, password);
navigate("/"); navigate("/");
} catch (err: any) { } catch (err) {
setError(err.response?.data?.error || "Login failed. Please try again."); if (err instanceof AxiosError) {
const errorData = err.response?.data as ApiErrorResponse | undefined;
setError(errorData?.error || errorData?.message || err.message || "Login failed. Please try again.");
} else if (err instanceof Error) {
setError(err.message);
} else {
setError("An unexpected error occurred");
}
} finally { } finally {
setIsLoading(false); setIsLoading(false);
} }

View File

@@ -1,10 +1,16 @@
import { useState, FormEvent } from "react"; import { useState, FormEvent } from "react";
import { Link, useNavigate } from "react-router-dom"; import { Link, useNavigate } from "react-router-dom";
import { AxiosError } from "axios";
import { useAuth } from "../hooks/useAuth"; import { useAuth } from "../hooks/useAuth";
import { Card, CardHeader, CardTitle, CardDescription, CardContent } from "../components/Card"; import { Card, CardHeader, CardTitle, CardDescription, CardContent } from "../components/Card";
import { Input } from "../components/Input"; import { Input } from "../components/Input";
import { Button } from "../components/Button"; import { Button } from "../components/Button";
interface ApiErrorResponse {
error?: string;
message?: string;
}
export function Register() { export function Register() {
const [email, setEmail] = useState(""); const [email, setEmail] = useState("");
const [password, setPassword] = useState(""); const [password, setPassword] = useState("");
@@ -34,8 +40,15 @@ export function Register() {
try { try {
await register(email, password, name || undefined); await register(email, password, name || undefined);
navigate("/"); navigate("/");
} catch (err: any) { } catch (err) {
setError(err.response?.data?.error || "Registration failed. Please try again."); if (err instanceof AxiosError) {
const errorData = err.response?.data as ApiErrorResponse | undefined;
setError(errorData?.error || errorData?.message || err.message || "Registration failed. Please try again.");
} else if (err instanceof Error) {
setError(err.message);
} else {
setError("An unexpected error occurred");
}
} finally { } finally {
setIsLoading(false); setIsLoading(false);
} }

File diff suppressed because it is too large Load Diff

View File

@@ -11,6 +11,7 @@ axum = { version = "0.7", features = ["ws", "macros"] }
axum-extra = { version = "0.9", features = ["typed-header"] } axum-extra = { version = "0.9", features = ["typed-header"] }
tower = { version = "0.5", features = ["util", "timeout"] } tower = { version = "0.5", features = ["util", "timeout"] }
tower-http = { version = "0.6", features = ["cors", "trace", "compression-gzip"] } tower-http = { version = "0.6", features = ["cors", "trace", "compression-gzip"] }
http = "1"
# Async runtime # Async runtime
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }

View File

@@ -8,6 +8,7 @@ use axum::{
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid; use uuid::Uuid;
use crate::auth::AuthUser;
use crate::db::{self, AgentResponse, AgentStats}; use crate::db::{self, AgentResponse, AgentStats};
use crate::ws::{generate_api_key, hash_api_key}; use crate::ws::{generate_api_key, hash_api_key};
use crate::AppState; use crate::AppState;
@@ -29,10 +30,20 @@ pub struct RegisterAgentRequest {
} }
/// Register a new agent (generates API key) /// Register a new agent (generates API key)
/// Requires authentication to prevent unauthorized agent registration.
pub async fn register_agent( pub async fn register_agent(
State(state): State<AppState>, State(state): State<AppState>,
user: AuthUser,
Json(req): Json<RegisterAgentRequest>, Json(req): Json<RegisterAgentRequest>,
) -> Result<Json<RegisterAgentResponse>, (StatusCode, String)> { ) -> Result<Json<RegisterAgentResponse>, (StatusCode, String)> {
// Log who is registering the agent
tracing::info!(
user_id = %user.user_id,
hostname = %req.hostname,
os_type = %req.os_type,
"Agent registration initiated by user"
);
// Generate a new API key // Generate a new API key
let api_key = generate_api_key(&state.config.auth.api_key_prefix); let api_key = generate_api_key(&state.config.auth.api_key_prefix);
let api_key_hash = hash_api_key(&api_key); let api_key_hash = hash_api_key(&api_key);
@@ -50,6 +61,12 @@ pub async fn register_agent(
.await .await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
tracing::info!(
user_id = %user.user_id,
agent_id = %agent.id,
"Agent registered successfully"
);
Ok(Json(RegisterAgentResponse { Ok(Json(RegisterAgentResponse {
agent_id: agent.id, agent_id: agent.id,
api_key, // Return the plain API key (only shown once!) api_key, // Return the plain API key (only shown once!)
@@ -59,8 +76,10 @@ pub async fn register_agent(
} }
/// List all agents /// List all agents
/// Requires authentication.
pub async fn list_agents( pub async fn list_agents(
State(state): State<AppState>, State(state): State<AppState>,
_user: AuthUser,
) -> Result<Json<Vec<AgentResponse>>, (StatusCode, String)> { ) -> Result<Json<Vec<AgentResponse>>, (StatusCode, String)> {
let agents = db::get_all_agents(&state.db) let agents = db::get_all_agents(&state.db)
.await .await
@@ -71,8 +90,10 @@ pub async fn list_agents(
} }
/// Get a specific agent /// Get a specific agent
/// Requires authentication.
pub async fn get_agent( pub async fn get_agent(
State(state): State<AppState>, State(state): State<AppState>,
_user: AuthUser,
Path(id): Path<Uuid>, Path(id): Path<Uuid>,
) -> Result<Json<AgentResponse>, (StatusCode, String)> { ) -> Result<Json<AgentResponse>, (StatusCode, String)> {
let agent = db::get_agent_by_id(&state.db, id) let agent = db::get_agent_by_id(&state.db, id)
@@ -84,8 +105,10 @@ pub async fn get_agent(
} }
/// Delete an agent /// Delete an agent
/// Requires authentication.
pub async fn delete_agent( pub async fn delete_agent(
State(state): State<AppState>, State(state): State<AppState>,
_user: AuthUser,
Path(id): Path<Uuid>, Path(id): Path<Uuid>,
) -> Result<StatusCode, (StatusCode, String)> { ) -> Result<StatusCode, (StatusCode, String)> {
// Check if agent is connected and disconnect it // Check if agent is connected and disconnect it
@@ -106,8 +129,10 @@ pub async fn delete_agent(
} }
/// Get agent statistics /// Get agent statistics
/// Requires authentication.
pub async fn get_stats( pub async fn get_stats(
State(state): State<AppState>, State(state): State<AppState>,
_user: AuthUser,
) -> Result<Json<AgentStats>, (StatusCode, String)> { ) -> Result<Json<AgentStats>, (StatusCode, String)> {
let stats = db::get_agent_stats(&state.db) let stats = db::get_agent_stats(&state.db)
.await .await
@@ -123,8 +148,10 @@ pub struct MoveAgentRequest {
} }
/// Move an agent to a different site /// Move an agent to a different site
/// Requires authentication.
pub async fn move_agent( pub async fn move_agent(
State(state): State<AppState>, State(state): State<AppState>,
_user: AuthUser,
Path(id): Path<Uuid>, Path(id): Path<Uuid>,
Json(req): Json<MoveAgentRequest>, Json(req): Json<MoveAgentRequest>,
) -> Result<Json<AgentResponse>, (StatusCode, String)> { ) -> Result<Json<AgentResponse>, (StatusCode, String)> {
@@ -149,8 +176,10 @@ pub async fn move_agent(
} }
/// List all agents with full details (site/client info) /// List all agents with full details (site/client info)
/// Requires authentication.
pub async fn list_agents_with_details( pub async fn list_agents_with_details(
State(state): State<AppState>, State(state): State<AppState>,
_user: AuthUser,
) -> Result<Json<Vec<db::AgentWithDetails>>, (StatusCode, String)> { ) -> Result<Json<Vec<db::AgentWithDetails>>, (StatusCode, String)> {
let agents = db::get_all_agents_with_details(&state.db) let agents = db::get_all_agents_with_details(&state.db)
.await .await
@@ -160,8 +189,10 @@ pub async fn list_agents_with_details(
} }
/// List unassigned agents (not belonging to any site) /// List unassigned agents (not belonging to any site)
/// Requires authentication.
pub async fn list_unassigned_agents( pub async fn list_unassigned_agents(
State(state): State<AppState>, State(state): State<AppState>,
_user: AuthUser,
) -> Result<Json<Vec<AgentResponse>>, (StatusCode, String)> { ) -> Result<Json<Vec<AgentResponse>>, (StatusCode, String)> {
let agents = db::get_unassigned_agents(&state.db) let agents = db::get_unassigned_agents(&state.db)
.await .await
@@ -172,8 +203,10 @@ pub async fn list_unassigned_agents(
} }
/// Get extended state for an agent (network interfaces, uptime, etc.) /// Get extended state for an agent (network interfaces, uptime, etc.)
/// Requires authentication.
pub async fn get_agent_state( pub async fn get_agent_state(
State(state): State<AppState>, State(state): State<AppState>,
_user: AuthUser,
Path(id): Path<Uuid>, Path(id): Path<Uuid>,
) -> Result<Json<db::AgentState>, (StatusCode, String)> { ) -> Result<Json<db::AgentState>, (StatusCode, String)> {
let agent_state = db::get_agent_state(&state.db, id) let agent_state = db::get_agent_state(&state.db, id)

View File

@@ -8,6 +8,7 @@ use axum::{
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid; use uuid::Uuid;
use crate::auth::AuthUser;
use crate::db::{self, Command}; use crate::db::{self, Command};
use crate::ws::{CommandPayload, ServerMessage}; use crate::ws::{CommandPayload, ServerMessage};
use crate::AppState; use crate::AppState;
@@ -43,23 +44,33 @@ pub struct CommandsQuery {
} }
/// Send a command to an agent /// Send a command to an agent
/// Requires authentication. Logs the user who sent the command for audit trail.
pub async fn send_command( pub async fn send_command(
State(state): State<AppState>, State(state): State<AppState>,
user: AuthUser,
Path(agent_id): Path<Uuid>, Path(agent_id): Path<Uuid>,
Json(req): Json<SendCommandRequest>, Json(req): Json<SendCommandRequest>,
) -> Result<Json<SendCommandResponse>, (StatusCode, String)> { ) -> Result<Json<SendCommandResponse>, (StatusCode, String)> {
// Log the command being sent for audit trail
tracing::info!(
user_id = %user.user_id,
agent_id = %agent_id,
command_type = %req.command_type,
"Command sent by user"
);
// Verify agent exists // Verify agent exists
let agent = db::get_agent_by_id(&state.db, agent_id) let _agent = db::get_agent_by_id(&state.db, agent_id)
.await .await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((StatusCode::NOT_FOUND, "Agent not found".to_string()))?; .ok_or((StatusCode::NOT_FOUND, "Agent not found".to_string()))?;
// Create command record // Create command record with user ID for audit trail
let create = db::CreateCommand { let create = db::CreateCommand {
agent_id, agent_id,
command_type: req.command_type.clone(), command_type: req.command_type.clone(),
command_text: req.command.clone(), command_text: req.command.clone(),
created_by: None, // TODO: Get from JWT created_by: Some(user.user_id),
}; };
let command = db::create_command(&state.db, create) let command = db::create_command(&state.db, create)
@@ -100,8 +111,10 @@ pub async fn send_command(
} }
/// List recent commands /// List recent commands
/// Requires authentication.
pub async fn list_commands( pub async fn list_commands(
State(state): State<AppState>, State(state): State<AppState>,
_user: AuthUser,
Query(query): Query<CommandsQuery>, Query(query): Query<CommandsQuery>,
) -> Result<Json<Vec<Command>>, (StatusCode, String)> { ) -> Result<Json<Vec<Command>>, (StatusCode, String)> {
let limit = query.limit.unwrap_or(50).min(500); let limit = query.limit.unwrap_or(50).min(500);
@@ -114,8 +127,10 @@ pub async fn list_commands(
} }
/// Get a specific command by ID /// Get a specific command by ID
/// Requires authentication.
pub async fn get_command( pub async fn get_command(
State(state): State<AppState>, State(state): State<AppState>,
_user: AuthUser,
Path(id): Path<Uuid>, Path(id): Path<Uuid>,
) -> Result<Json<Command>, (StatusCode, String)> { ) -> Result<Json<Command>, (StatusCode, String)> {
let command = db::get_command_by_id(&state.db, id) let command = db::get_command_by_id(&state.db, id)

View File

@@ -5,10 +5,11 @@ use axum::{
http::StatusCode, http::StatusCode,
Json, Json,
}; };
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Utc};
use serde::Deserialize; use serde::Deserialize;
use uuid::Uuid; use uuid::Uuid;
use crate::auth::AuthUser;
use crate::db::{self, Metrics, MetricsSummary}; use crate::db::{self, Metrics, MetricsSummary};
use crate::AppState; use crate::AppState;
@@ -26,13 +27,15 @@ pub struct MetricsQuery {
} }
/// Get metrics for a specific agent /// Get metrics for a specific agent
/// Requires authentication.
pub async fn get_agent_metrics( pub async fn get_agent_metrics(
State(state): State<AppState>, State(state): State<AppState>,
_user: AuthUser,
Path(id): Path<Uuid>, Path(id): Path<Uuid>,
Query(query): Query<MetricsQuery>, Query(query): Query<MetricsQuery>,
) -> Result<Json<Vec<Metrics>>, (StatusCode, String)> { ) -> Result<Json<Vec<Metrics>>, (StatusCode, String)> {
// First verify the agent exists // First verify the agent exists
let agent = db::get_agent_by_id(&state.db, id) let _agent = db::get_agent_by_id(&state.db, id)
.await .await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((StatusCode::NOT_FOUND, "Agent not found".to_string()))?; .ok_or((StatusCode::NOT_FOUND, "Agent not found".to_string()))?;
@@ -54,8 +57,10 @@ pub async fn get_agent_metrics(
} }
/// Get summary metrics across all agents /// Get summary metrics across all agents
/// Requires authentication.
pub async fn get_summary( pub async fn get_summary(
State(state): State<AppState>, State(state): State<AppState>,
_user: AuthUser,
) -> Result<Json<MetricsSummary>, (StatusCode, String)> { ) -> Result<Json<MetricsSummary>, (StatusCode, String)> {
let summary = db::get_metrics_summary(&state.db) let summary = db::get_metrics_summary(&state.db)
.await .await

View File

@@ -24,7 +24,8 @@ use axum::{
}; };
use sqlx::postgres::PgPoolOptions; use sqlx::postgres::PgPoolOptions;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tower_http::cors::{Any, CorsLayer}; use http::HeaderValue;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
use tracing::info; use tracing::info;
@@ -129,11 +130,34 @@ async fn main() -> Result<()> {
/// Build the application router /// Build the application router
fn build_router(state: AppState) -> Router { fn build_router(state: AppState) -> Router {
// CORS configuration (allow dashboard access) // TODO: Add rate limiting for registration endpoints using tower-governor
// Currently, registration is protected by AuthUser authentication.
// For additional protection against brute-force attacks, consider adding:
// - tower-governor crate for per-IP rate limiting on /api/agents/register
// - Configurable limits via environment variables
// Reference: https://docs.rs/tower-governor/latest/tower_governor/
// CORS configuration - restrict to specific dashboard origin
let dashboard_origin = std::env::var("DASHBOARD_URL")
.unwrap_or_else(|_| "https://rmm.azcomputerguru.com".to_string());
let cors = CorsLayer::new() let cors = CorsLayer::new()
.allow_origin(Any) .allow_origin(AllowOrigin::exact(
.allow_methods(Any) HeaderValue::from_str(&dashboard_origin).expect("Invalid DASHBOARD_URL"),
.allow_headers(Any); ))
.allow_methods([
http::Method::GET,
http::Method::POST,
http::Method::PUT,
http::Method::DELETE,
http::Method::OPTIONS,
])
.allow_headers([
http::header::AUTHORIZATION,
http::header::CONTENT_TYPE,
http::header::ACCEPT,
])
.allow_credentials(true);
Router::new() Router::new()
// Health check // Health check