fix(security): Implement Phase 1 critical security fixes

CORS:
- Restrict CORS to DASHBOARD_URL environment variable
- Default to production dashboard domain

Authentication:
- Add AuthUser requirement to all agent management endpoints
- Add AuthUser requirement to all command endpoints
- Add AuthUser requirement to all metrics endpoints
- Add audit logging for command execution (user_id tracked)

Agent Security:
- Replace Unicode characters with ASCII markers [OK]/[ERROR]/[WARNING]
- Add certificate pinning for update downloads (allowlist domains)
- Fix insecure temp file creation (use /var/run/gururmm with 0700 perms)
- Fix rollback script backgrounding (use setsid instead of literal &)

Dashboard Security:
- Move token storage from localStorage to sessionStorage
- Add proper TypeScript types (remove 'any' from error handlers)
- Centralize token management functions

Legacy Agent:
- Add -AllowInsecureTLS parameter (opt-in required)
- Add Windows Event Log audit trail when insecure mode used
- Update documentation with security warnings

Closes: Phase 1 items in issue #1

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-20 21:16:24 -07:00
parent 6d3271c144
commit 65086f4407
15 changed files with 1708 additions and 99 deletions

View File

@@ -11,6 +11,7 @@ axum = { version = "0.7", features = ["ws", "macros"] }
axum-extra = { version = "0.9", features = ["typed-header"] }
tower = { version = "0.5", features = ["util", "timeout"] }
tower-http = { version = "0.6", features = ["cors", "trace", "compression-gzip"] }
http = "1"
# Async runtime
tokio = { version = "1", features = ["full"] }

View File

@@ -8,6 +8,7 @@ use axum::{
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::auth::AuthUser;
use crate::db::{self, AgentResponse, AgentStats};
use crate::ws::{generate_api_key, hash_api_key};
use crate::AppState;
@@ -29,10 +30,20 @@ pub struct RegisterAgentRequest {
}
/// Register a new agent (generates API key)
/// Requires authentication to prevent unauthorized agent registration.
pub async fn register_agent(
State(state): State<AppState>,
user: AuthUser,
Json(req): Json<RegisterAgentRequest>,
) -> Result<Json<RegisterAgentResponse>, (StatusCode, String)> {
// Log who is registering the agent
tracing::info!(
user_id = %user.user_id,
hostname = %req.hostname,
os_type = %req.os_type,
"Agent registration initiated by user"
);
// Generate a new API key
let api_key = generate_api_key(&state.config.auth.api_key_prefix);
let api_key_hash = hash_api_key(&api_key);
@@ -50,6 +61,12 @@ pub async fn register_agent(
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
tracing::info!(
user_id = %user.user_id,
agent_id = %agent.id,
"Agent registered successfully"
);
Ok(Json(RegisterAgentResponse {
agent_id: agent.id,
api_key, // Return the plain API key (only shown once!)
@@ -59,8 +76,10 @@ pub async fn register_agent(
}
/// List all agents
/// Requires authentication.
pub async fn list_agents(
State(state): State<AppState>,
_user: AuthUser,
) -> Result<Json<Vec<AgentResponse>>, (StatusCode, String)> {
let agents = db::get_all_agents(&state.db)
.await
@@ -71,8 +90,10 @@ pub async fn list_agents(
}
/// Get a specific agent
/// Requires authentication.
pub async fn get_agent(
State(state): State<AppState>,
_user: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<AgentResponse>, (StatusCode, String)> {
let agent = db::get_agent_by_id(&state.db, id)
@@ -84,8 +105,10 @@ pub async fn get_agent(
}
/// Delete an agent
/// Requires authentication.
pub async fn delete_agent(
State(state): State<AppState>,
_user: AuthUser,
Path(id): Path<Uuid>,
) -> Result<StatusCode, (StatusCode, String)> {
// Check if agent is connected and disconnect it
@@ -106,8 +129,10 @@ pub async fn delete_agent(
}
/// Get agent statistics
/// Requires authentication.
pub async fn get_stats(
State(state): State<AppState>,
_user: AuthUser,
) -> Result<Json<AgentStats>, (StatusCode, String)> {
let stats = db::get_agent_stats(&state.db)
.await
@@ -123,8 +148,10 @@ pub struct MoveAgentRequest {
}
/// Move an agent to a different site
/// Requires authentication.
pub async fn move_agent(
State(state): State<AppState>,
_user: AuthUser,
Path(id): Path<Uuid>,
Json(req): Json<MoveAgentRequest>,
) -> Result<Json<AgentResponse>, (StatusCode, String)> {
@@ -149,8 +176,10 @@ pub async fn move_agent(
}
/// List all agents with full details (site/client info)
/// Requires authentication.
pub async fn list_agents_with_details(
State(state): State<AppState>,
_user: AuthUser,
) -> Result<Json<Vec<db::AgentWithDetails>>, (StatusCode, String)> {
let agents = db::get_all_agents_with_details(&state.db)
.await
@@ -160,8 +189,10 @@ pub async fn list_agents_with_details(
}
/// List unassigned agents (not belonging to any site)
/// Requires authentication.
pub async fn list_unassigned_agents(
State(state): State<AppState>,
_user: AuthUser,
) -> Result<Json<Vec<AgentResponse>>, (StatusCode, String)> {
let agents = db::get_unassigned_agents(&state.db)
.await
@@ -172,8 +203,10 @@ pub async fn list_unassigned_agents(
}
/// Get extended state for an agent (network interfaces, uptime, etc.)
/// Requires authentication.
pub async fn get_agent_state(
State(state): State<AppState>,
_user: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<db::AgentState>, (StatusCode, String)> {
let agent_state = db::get_agent_state(&state.db, id)

View File

@@ -8,6 +8,7 @@ use axum::{
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::auth::AuthUser;
use crate::db::{self, Command};
use crate::ws::{CommandPayload, ServerMessage};
use crate::AppState;
@@ -43,23 +44,33 @@ pub struct CommandsQuery {
}
/// Send a command to an agent
/// Requires authentication. Logs the user who sent the command for audit trail.
pub async fn send_command(
State(state): State<AppState>,
user: AuthUser,
Path(agent_id): Path<Uuid>,
Json(req): Json<SendCommandRequest>,
) -> Result<Json<SendCommandResponse>, (StatusCode, String)> {
// Log the command being sent for audit trail
tracing::info!(
user_id = %user.user_id,
agent_id = %agent_id,
command_type = %req.command_type,
"Command sent by user"
);
// Verify agent exists
let agent = db::get_agent_by_id(&state.db, agent_id)
let _agent = db::get_agent_by_id(&state.db, agent_id)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((StatusCode::NOT_FOUND, "Agent not found".to_string()))?;
// Create command record
// Create command record with user ID for audit trail
let create = db::CreateCommand {
agent_id,
command_type: req.command_type.clone(),
command_text: req.command.clone(),
created_by: None, // TODO: Get from JWT
created_by: Some(user.user_id),
};
let command = db::create_command(&state.db, create)
@@ -100,8 +111,10 @@ pub async fn send_command(
}
/// List recent commands
/// Requires authentication.
pub async fn list_commands(
State(state): State<AppState>,
_user: AuthUser,
Query(query): Query<CommandsQuery>,
) -> Result<Json<Vec<Command>>, (StatusCode, String)> {
let limit = query.limit.unwrap_or(50).min(500);
@@ -114,8 +127,10 @@ pub async fn list_commands(
}
/// Get a specific command by ID
/// Requires authentication.
pub async fn get_command(
State(state): State<AppState>,
_user: AuthUser,
Path(id): Path<Uuid>,
) -> Result<Json<Command>, (StatusCode, String)> {
let command = db::get_command_by_id(&state.db, id)

View File

@@ -5,10 +5,11 @@ use axum::{
http::StatusCode,
Json,
};
use chrono::{DateTime, Duration, Utc};
use chrono::{DateTime, Utc};
use serde::Deserialize;
use uuid::Uuid;
use crate::auth::AuthUser;
use crate::db::{self, Metrics, MetricsSummary};
use crate::AppState;
@@ -26,13 +27,15 @@ pub struct MetricsQuery {
}
/// Get metrics for a specific agent
/// Requires authentication.
pub async fn get_agent_metrics(
State(state): State<AppState>,
_user: AuthUser,
Path(id): Path<Uuid>,
Query(query): Query<MetricsQuery>,
) -> Result<Json<Vec<Metrics>>, (StatusCode, String)> {
// First verify the agent exists
let agent = db::get_agent_by_id(&state.db, id)
let _agent = db::get_agent_by_id(&state.db, id)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
.ok_or((StatusCode::NOT_FOUND, "Agent not found".to_string()))?;
@@ -54,8 +57,10 @@ pub async fn get_agent_metrics(
}
/// Get summary metrics across all agents
/// Requires authentication.
pub async fn get_summary(
State(state): State<AppState>,
_user: AuthUser,
) -> Result<Json<MetricsSummary>, (StatusCode, String)> {
let summary = db::get_metrics_summary(&state.db)
.await

View File

@@ -24,7 +24,8 @@ use axum::{
};
use sqlx::postgres::PgPoolOptions;
use tokio::sync::RwLock;
use tower_http::cors::{Any, CorsLayer};
use http::HeaderValue;
use tower_http::cors::{AllowOrigin, CorsLayer};
use tower_http::trace::TraceLayer;
use tracing::info;
@@ -129,11 +130,34 @@ async fn main() -> Result<()> {
/// Build the application router
fn build_router(state: AppState) -> Router {
// CORS configuration (allow dashboard access)
// TODO: Add rate limiting for registration endpoints using tower-governor
// Currently, registration is protected by AuthUser authentication.
// For additional protection against brute-force attacks, consider adding:
// - tower-governor crate for per-IP rate limiting on /api/agents/register
// - Configurable limits via environment variables
// Reference: https://docs.rs/tower-governor/latest/tower_governor/
// CORS configuration - restrict to specific dashboard origin
let dashboard_origin = std::env::var("DASHBOARD_URL")
.unwrap_or_else(|_| "https://rmm.azcomputerguru.com".to_string());
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
.allow_origin(AllowOrigin::exact(
HeaderValue::from_str(&dashboard_origin).expect("Invalid DASHBOARD_URL"),
))
.allow_methods([
http::Method::GET,
http::Method::POST,
http::Method::PUT,
http::Method::DELETE,
http::Method::OPTIONS,
])
.allow_headers([
http::header::AUTHORIZATION,
http::header::CONTENT_TYPE,
http::header::ACCEPT,
])
.allow_credentials(true);
Router::new()
// Health check