fix(security): Implement Phase 1 critical security fixes
CORS: - Restrict CORS to DASHBOARD_URL environment variable - Default to production dashboard domain Authentication: - Add AuthUser requirement to all agent management endpoints - Add AuthUser requirement to all command endpoints - Add AuthUser requirement to all metrics endpoints - Add audit logging for command execution (user_id tracked) Agent Security: - Replace Unicode characters with ASCII markers [OK]/[ERROR]/[WARNING] - Add certificate pinning for update downloads (allowlist domains) - Fix insecure temp file creation (use /var/run/gururmm with 0700 perms) - Fix rollback script backgrounding (use setsid instead of literal &) Dashboard Security: - Move token storage from localStorage to sessionStorage - Add proper TypeScript types (remove 'any' from error handlers) - Centralize token management functions Legacy Agent: - Add -AllowInsecureTLS parameter (opt-in required) - Add Windows Event Log audit trail when insecure mode used - Update documentation with security warnings Closes: Phase 1 items in issue #1 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -11,6 +11,7 @@ axum = { version = "0.7", features = ["ws", "macros"] }
|
||||
axum-extra = { version = "0.9", features = ["typed-header"] }
|
||||
tower = { version = "0.5", features = ["util", "timeout"] }
|
||||
tower-http = { version = "0.6", features = ["cors", "trace", "compression-gzip"] }
|
||||
http = "1"
|
||||
|
||||
# Async runtime
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
|
||||
@@ -8,6 +8,7 @@ use axum::{
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::AuthUser;
|
||||
use crate::db::{self, AgentResponse, AgentStats};
|
||||
use crate::ws::{generate_api_key, hash_api_key};
|
||||
use crate::AppState;
|
||||
@@ -29,10 +30,20 @@ pub struct RegisterAgentRequest {
|
||||
}
|
||||
|
||||
/// Register a new agent (generates API key)
|
||||
/// Requires authentication to prevent unauthorized agent registration.
|
||||
pub async fn register_agent(
|
||||
State(state): State<AppState>,
|
||||
user: AuthUser,
|
||||
Json(req): Json<RegisterAgentRequest>,
|
||||
) -> Result<Json<RegisterAgentResponse>, (StatusCode, String)> {
|
||||
// Log who is registering the agent
|
||||
tracing::info!(
|
||||
user_id = %user.user_id,
|
||||
hostname = %req.hostname,
|
||||
os_type = %req.os_type,
|
||||
"Agent registration initiated by user"
|
||||
);
|
||||
|
||||
// Generate a new API key
|
||||
let api_key = generate_api_key(&state.config.auth.api_key_prefix);
|
||||
let api_key_hash = hash_api_key(&api_key);
|
||||
@@ -50,6 +61,12 @@ pub async fn register_agent(
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
tracing::info!(
|
||||
user_id = %user.user_id,
|
||||
agent_id = %agent.id,
|
||||
"Agent registered successfully"
|
||||
);
|
||||
|
||||
Ok(Json(RegisterAgentResponse {
|
||||
agent_id: agent.id,
|
||||
api_key, // Return the plain API key (only shown once!)
|
||||
@@ -59,8 +76,10 @@ pub async fn register_agent(
|
||||
}
|
||||
|
||||
/// List all agents
|
||||
/// Requires authentication.
|
||||
pub async fn list_agents(
|
||||
State(state): State<AppState>,
|
||||
_user: AuthUser,
|
||||
) -> Result<Json<Vec<AgentResponse>>, (StatusCode, String)> {
|
||||
let agents = db::get_all_agents(&state.db)
|
||||
.await
|
||||
@@ -71,8 +90,10 @@ pub async fn list_agents(
|
||||
}
|
||||
|
||||
/// Get a specific agent
|
||||
/// Requires authentication.
|
||||
pub async fn get_agent(
|
||||
State(state): State<AppState>,
|
||||
_user: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<AgentResponse>, (StatusCode, String)> {
|
||||
let agent = db::get_agent_by_id(&state.db, id)
|
||||
@@ -84,8 +105,10 @@ pub async fn get_agent(
|
||||
}
|
||||
|
||||
/// Delete an agent
|
||||
/// Requires authentication.
|
||||
pub async fn delete_agent(
|
||||
State(state): State<AppState>,
|
||||
_user: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<StatusCode, (StatusCode, String)> {
|
||||
// Check if agent is connected and disconnect it
|
||||
@@ -106,8 +129,10 @@ pub async fn delete_agent(
|
||||
}
|
||||
|
||||
/// Get agent statistics
|
||||
/// Requires authentication.
|
||||
pub async fn get_stats(
|
||||
State(state): State<AppState>,
|
||||
_user: AuthUser,
|
||||
) -> Result<Json<AgentStats>, (StatusCode, String)> {
|
||||
let stats = db::get_agent_stats(&state.db)
|
||||
.await
|
||||
@@ -123,8 +148,10 @@ pub struct MoveAgentRequest {
|
||||
}
|
||||
|
||||
/// Move an agent to a different site
|
||||
/// Requires authentication.
|
||||
pub async fn move_agent(
|
||||
State(state): State<AppState>,
|
||||
_user: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
Json(req): Json<MoveAgentRequest>,
|
||||
) -> Result<Json<AgentResponse>, (StatusCode, String)> {
|
||||
@@ -149,8 +176,10 @@ pub async fn move_agent(
|
||||
}
|
||||
|
||||
/// List all agents with full details (site/client info)
|
||||
/// Requires authentication.
|
||||
pub async fn list_agents_with_details(
|
||||
State(state): State<AppState>,
|
||||
_user: AuthUser,
|
||||
) -> Result<Json<Vec<db::AgentWithDetails>>, (StatusCode, String)> {
|
||||
let agents = db::get_all_agents_with_details(&state.db)
|
||||
.await
|
||||
@@ -160,8 +189,10 @@ pub async fn list_agents_with_details(
|
||||
}
|
||||
|
||||
/// List unassigned agents (not belonging to any site)
|
||||
/// Requires authentication.
|
||||
pub async fn list_unassigned_agents(
|
||||
State(state): State<AppState>,
|
||||
_user: AuthUser,
|
||||
) -> Result<Json<Vec<AgentResponse>>, (StatusCode, String)> {
|
||||
let agents = db::get_unassigned_agents(&state.db)
|
||||
.await
|
||||
@@ -172,8 +203,10 @@ pub async fn list_unassigned_agents(
|
||||
}
|
||||
|
||||
/// Get extended state for an agent (network interfaces, uptime, etc.)
|
||||
/// Requires authentication.
|
||||
pub async fn get_agent_state(
|
||||
State(state): State<AppState>,
|
||||
_user: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<db::AgentState>, (StatusCode, String)> {
|
||||
let agent_state = db::get_agent_state(&state.db, id)
|
||||
|
||||
@@ -8,6 +8,7 @@ use axum::{
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::AuthUser;
|
||||
use crate::db::{self, Command};
|
||||
use crate::ws::{CommandPayload, ServerMessage};
|
||||
use crate::AppState;
|
||||
@@ -43,23 +44,33 @@ pub struct CommandsQuery {
|
||||
}
|
||||
|
||||
/// Send a command to an agent
|
||||
/// Requires authentication. Logs the user who sent the command for audit trail.
|
||||
pub async fn send_command(
|
||||
State(state): State<AppState>,
|
||||
user: AuthUser,
|
||||
Path(agent_id): Path<Uuid>,
|
||||
Json(req): Json<SendCommandRequest>,
|
||||
) -> Result<Json<SendCommandResponse>, (StatusCode, String)> {
|
||||
// Log the command being sent for audit trail
|
||||
tracing::info!(
|
||||
user_id = %user.user_id,
|
||||
agent_id = %agent_id,
|
||||
command_type = %req.command_type,
|
||||
"Command sent by user"
|
||||
);
|
||||
|
||||
// Verify agent exists
|
||||
let agent = db::get_agent_by_id(&state.db, agent_id)
|
||||
let _agent = db::get_agent_by_id(&state.db, agent_id)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
|
||||
.ok_or((StatusCode::NOT_FOUND, "Agent not found".to_string()))?;
|
||||
|
||||
// Create command record
|
||||
// Create command record with user ID for audit trail
|
||||
let create = db::CreateCommand {
|
||||
agent_id,
|
||||
command_type: req.command_type.clone(),
|
||||
command_text: req.command.clone(),
|
||||
created_by: None, // TODO: Get from JWT
|
||||
created_by: Some(user.user_id),
|
||||
};
|
||||
|
||||
let command = db::create_command(&state.db, create)
|
||||
@@ -100,8 +111,10 @@ pub async fn send_command(
|
||||
}
|
||||
|
||||
/// List recent commands
|
||||
/// Requires authentication.
|
||||
pub async fn list_commands(
|
||||
State(state): State<AppState>,
|
||||
_user: AuthUser,
|
||||
Query(query): Query<CommandsQuery>,
|
||||
) -> Result<Json<Vec<Command>>, (StatusCode, String)> {
|
||||
let limit = query.limit.unwrap_or(50).min(500);
|
||||
@@ -114,8 +127,10 @@ pub async fn list_commands(
|
||||
}
|
||||
|
||||
/// Get a specific command by ID
|
||||
/// Requires authentication.
|
||||
pub async fn get_command(
|
||||
State(state): State<AppState>,
|
||||
_user: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
) -> Result<Json<Command>, (StatusCode, String)> {
|
||||
let command = db::get_command_by_id(&state.db, id)
|
||||
|
||||
@@ -5,10 +5,11 @@ use axum::{
|
||||
http::StatusCode,
|
||||
Json,
|
||||
};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::Deserialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::AuthUser;
|
||||
use crate::db::{self, Metrics, MetricsSummary};
|
||||
use crate::AppState;
|
||||
|
||||
@@ -26,13 +27,15 @@ pub struct MetricsQuery {
|
||||
}
|
||||
|
||||
/// Get metrics for a specific agent
|
||||
/// Requires authentication.
|
||||
pub async fn get_agent_metrics(
|
||||
State(state): State<AppState>,
|
||||
_user: AuthUser,
|
||||
Path(id): Path<Uuid>,
|
||||
Query(query): Query<MetricsQuery>,
|
||||
) -> Result<Json<Vec<Metrics>>, (StatusCode, String)> {
|
||||
// First verify the agent exists
|
||||
let agent = db::get_agent_by_id(&state.db, id)
|
||||
let _agent = db::get_agent_by_id(&state.db, id)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
|
||||
.ok_or((StatusCode::NOT_FOUND, "Agent not found".to_string()))?;
|
||||
@@ -54,8 +57,10 @@ pub async fn get_agent_metrics(
|
||||
}
|
||||
|
||||
/// Get summary metrics across all agents
|
||||
/// Requires authentication.
|
||||
pub async fn get_summary(
|
||||
State(state): State<AppState>,
|
||||
_user: AuthUser,
|
||||
) -> Result<Json<MetricsSummary>, (StatusCode, String)> {
|
||||
let summary = db::get_metrics_summary(&state.db)
|
||||
.await
|
||||
|
||||
@@ -24,7 +24,8 @@ use axum::{
|
||||
};
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
use tokio::sync::RwLock;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use http::HeaderValue;
|
||||
use tower_http::cors::{AllowOrigin, CorsLayer};
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing::info;
|
||||
|
||||
@@ -129,11 +130,34 @@ async fn main() -> Result<()> {
|
||||
|
||||
/// Build the application router
|
||||
fn build_router(state: AppState) -> Router {
|
||||
// CORS configuration (allow dashboard access)
|
||||
// TODO: Add rate limiting for registration endpoints using tower-governor
|
||||
// Currently, registration is protected by AuthUser authentication.
|
||||
// For additional protection against brute-force attacks, consider adding:
|
||||
// - tower-governor crate for per-IP rate limiting on /api/agents/register
|
||||
// - Configurable limits via environment variables
|
||||
// Reference: https://docs.rs/tower-governor/latest/tower_governor/
|
||||
|
||||
// CORS configuration - restrict to specific dashboard origin
|
||||
let dashboard_origin = std::env::var("DASHBOARD_URL")
|
||||
.unwrap_or_else(|_| "https://rmm.azcomputerguru.com".to_string());
|
||||
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
.allow_origin(AllowOrigin::exact(
|
||||
HeaderValue::from_str(&dashboard_origin).expect("Invalid DASHBOARD_URL"),
|
||||
))
|
||||
.allow_methods([
|
||||
http::Method::GET,
|
||||
http::Method::POST,
|
||||
http::Method::PUT,
|
||||
http::Method::DELETE,
|
||||
http::Method::OPTIONS,
|
||||
])
|
||||
.allow_headers([
|
||||
http::header::AUTHORIZATION,
|
||||
http::header::CONTENT_TYPE,
|
||||
http::header::ACCEPT,
|
||||
])
|
||||
.allow_credentials(true);
|
||||
|
||||
Router::new()
|
||||
// Health check
|
||||
|
||||
Reference in New Issue
Block a user