//! GuruConnect Server - WebSocket Relay Server //! //! Handles connections from both agents and dashboard viewers, //! relaying video frames and input events between them. mod config; mod relay; mod session; mod auth; mod api; mod db; mod support_codes; mod middleware; mod utils; mod metrics; pub mod proto { include!(concat!(env!("OUT_DIR"), "/guruconnect.rs")); } use anyhow::Result; use axum::{ Router, routing::{get, post, put, delete}, extract::{Path, State, Json, Query, Request}, response::{Html, IntoResponse}, http::StatusCode, middleware::{self as axum_middleware, Next}, }; use std::net::SocketAddr; use std::sync::Arc; use tower_http::cors::{Any, CorsLayer, AllowOrigin}; use axum::http::{Method, HeaderValue}; use tower_http::trace::TraceLayer; use tower_http::services::ServeDir; use tracing::{info, Level}; use tracing_subscriber::FmtSubscriber; use serde::Deserialize; use support_codes::{SupportCodeManager, CreateCodeRequest, SupportCode, CodeValidation}; use auth::{JwtConfig, TokenBlacklist, hash_password, generate_random_password, AuthenticatedUser}; use metrics::SharedMetrics; use prometheus_client::registry::Registry; /// Application state #[derive(Clone)] pub struct AppState { sessions: session::SessionManager, support_codes: SupportCodeManager, db: Option, pub jwt_config: Arc, pub token_blacklist: TokenBlacklist, /// Optional API key for persistent agents (env: AGENT_API_KEY) pub agent_api_key: Option, /// Prometheus metrics pub metrics: SharedMetrics, /// Prometheus registry (for /metrics endpoint) pub registry: Arc>, /// Server start time pub start_time: Arc, } /// Middleware to inject JWT config and token blacklist into request extensions async fn auth_layer( State(state): State, mut request: Request, next: Next, ) -> impl IntoResponse { request.extensions_mut().insert(state.jwt_config.clone()); request.extensions_mut().insert(Arc::new(state.token_blacklist.clone())); next.run(request).await } #[tokio::main] async fn main() -> Result<()> { // 
Initialize logging let _subscriber = FmtSubscriber::builder() .with_max_level(Level::INFO) .with_target(true) .init(); info!("GuruConnect Server v{}", env!("CARGO_PKG_VERSION")); // Load configuration let config = config::Config::load()?; // Use port 3002 for GuruConnect let listen_addr = std::env::var("LISTEN_ADDR").unwrap_or_else(|_| "0.0.0.0:3002".to_string()); info!("Loaded configuration, listening on {}", listen_addr); // JWT configuration - REQUIRED for security let jwt_secret = std::env::var("JWT_SECRET") .expect("JWT_SECRET environment variable must be set! Generate one with: openssl rand -base64 64"); if jwt_secret.len() < 32 { panic!("JWT_SECRET must be at least 32 characters long for security!"); } let jwt_expiry_hours = std::env::var("JWT_EXPIRY_HOURS") .ok() .and_then(|s| s.parse().ok()) .unwrap_or(24i64); let jwt_config = Arc::new(JwtConfig::new(jwt_secret, jwt_expiry_hours)); // Initialize database if configured let database = if let Some(ref db_url) = config.database_url { match db::Database::connect(db_url, config.database_max_connections).await { Ok(db) => { // Run migrations if let Err(e) = db.migrate().await { tracing::error!("Failed to run migrations: {}", e); return Err(e); } Some(db) } Err(e) => { tracing::warn!("Failed to connect to database: {}. 
Running without persistence.", e); None } } } else { info!("No DATABASE_URL set, running without persistence"); None }; // Create initial admin user if no users exist if let Some(ref db) = database { match db::count_users(db.pool()).await { Ok(0) => { info!("No users found, creating initial admin user..."); let password = generate_random_password(16); let password_hash = hash_password(&password)?; match db::create_user(db.pool(), "admin", &password_hash, None, "admin").await { Ok(user) => { // Set admin permissions let perms = vec![ "view".to_string(), "control".to_string(), "transfer".to_string(), "manage_users".to_string(), "manage_clients".to_string(), ]; let _ = db::set_user_permissions(db.pool(), user.id, &perms).await; // SEC-6: Write credentials to secure file instead of logging let creds_file = ".admin-credentials"; match std::fs::write(creds_file, format!("Username: admin\nPassword: {}\n\nWARNING: Change this password immediately after first login!\nDelete this file after copying the password.\n", password)) { Ok(_) => { // Set restrictive permissions (Unix only) #[cfg(unix)] { use std::os::unix::fs::PermissionsExt; let _ = std::fs::set_permissions(creds_file, std::fs::Permissions::from_mode(0o600)); } info!("========================================"); info!(" INITIAL ADMIN USER CREATED"); info!(" Credentials written to: {}", creds_file); info!(" (Read file, change password, then delete file)"); info!("========================================"); } Err(e) => { // Fallback to logging if file write fails (but warn about security) tracing::warn!("Could not write credentials file: {}", e); info!("========================================"); info!(" INITIAL ADMIN USER CREATED"); info!(" Username: admin"); info!(" Password: {}", password); info!(" WARNING: Password logged due to file write failure!"); info!(" (Change this password immediately!)"); info!("========================================"); } } } Err(e) => { tracing::error!("Failed to create initial admin 
user: {}", e); } } } Ok(count) => { info!("{} user(s) in database", count); } Err(e) => { tracing::warn!("Could not check user count: {}", e); } } } // Create session manager let sessions = session::SessionManager::new(); // Restore persistent machines from database if let Some(ref db) = database { match db::machines::get_all_machines(db.pool()).await { Ok(machines) => { info!("Restoring {} persistent machines from database", machines.len()); for machine in machines { sessions.restore_offline_machine(&machine.agent_id, &machine.hostname).await; } } Err(e) => { tracing::warn!("Failed to restore machines: {}", e); } } } // Agent API key for persistent agents (optional) let agent_api_key = std::env::var("AGENT_API_KEY").ok(); if let Some(ref key) = agent_api_key { // Validate API key strength for security utils::validation::validate_api_key_strength(key)?; info!("AGENT_API_KEY configured for persistent agents (validated)"); } else { info!("No AGENT_API_KEY set - persistent agents will need JWT token or support code"); } // Initialize Prometheus metrics let mut registry = Registry::default(); let metrics = Arc::new(metrics::Metrics::new(&mut registry)); let registry = Arc::new(std::sync::Mutex::new(registry)); let start_time = Arc::new(std::time::Instant::now()); // Spawn background task to update uptime metric let metrics_for_uptime = metrics.clone(); let start_time_for_uptime = start_time.clone(); tokio::spawn(async move { let mut interval = tokio::time::interval(std::time::Duration::from_secs(10)); loop { interval.tick().await; let uptime = start_time_for_uptime.elapsed().as_secs() as i64; metrics_for_uptime.update_uptime(uptime); } }); // Create application state let token_blacklist = TokenBlacklist::new(); let state = AppState { sessions, support_codes: SupportCodeManager::new(), db: database, jwt_config, token_blacklist, agent_api_key, metrics, registry, start_time, }; // Build router let app = Router::new() // Health check (no auth required) .route("/health", 
get(health)) // Prometheus metrics (no auth required - for monitoring) .route("/metrics", get(prometheus_metrics)) // Auth endpoints (TODO: Add rate limiting - see SEC2_RATE_LIMITING_TODO.md) .route("/api/auth/login", post(api::auth::login)) .route("/api/auth/change-password", post(api::auth::change_password)) .route("/api/auth/me", get(api::auth::get_me)) .route("/api/auth/logout", post(api::auth_logout::logout)) .route("/api/auth/revoke-token", post(api::auth_logout::revoke_own_token)) .route("/api/auth/admin/revoke-user", post(api::auth_logout::revoke_user_tokens)) .route("/api/auth/blacklist/stats", get(api::auth_logout::get_blacklist_stats)) .route("/api/auth/blacklist/cleanup", post(api::auth_logout::cleanup_blacklist)) // User management (admin only) .route("/api/users", get(api::users::list_users)) .route("/api/users", post(api::users::create_user)) .route("/api/users/:id", get(api::users::get_user)) .route("/api/users/:id", put(api::users::update_user)) .route("/api/users/:id", delete(api::users::delete_user)) .route("/api/users/:id/permissions", put(api::users::set_permissions)) .route("/api/users/:id/clients", put(api::users::set_client_access)) // Portal API - Support codes (TODO: Add rate limiting) .route("/api/codes", post(create_code)) .route("/api/codes", get(list_codes)) .route("/api/codes/:code/validate", get(validate_code)) .route("/api/codes/:code/cancel", post(cancel_code)) // WebSocket endpoints .route("/ws/agent", get(relay::agent_ws_handler)) .route("/ws/viewer", get(relay::viewer_ws_handler)) // REST API - Sessions .route("/api/sessions", get(list_sessions)) .route("/api/sessions/:id", get(get_session)) .route("/api/sessions/:id", delete(disconnect_session)) // REST API - Machines .route("/api/machines", get(list_machines)) .route("/api/machines/:agent_id", get(get_machine)) .route("/api/machines/:agent_id", delete(delete_machine)) .route("/api/machines/:agent_id/history", get(get_machine_history)) .route("/api/machines/:agent_id/update", 
post(trigger_machine_update)) // REST API - Releases and Version .route("/api/version", get(api::releases::get_version)) // No auth - for agent polling .route("/api/releases", get(api::releases::list_releases)) .route("/api/releases", post(api::releases::create_release)) .route("/api/releases/:version", get(api::releases::get_release)) .route("/api/releases/:version", put(api::releases::update_release)) .route("/api/releases/:version", delete(api::releases::delete_release)) // Agent downloads (no auth - public download links) .route("/api/download/viewer", get(api::downloads::download_viewer)) .route("/api/download/support", get(api::downloads::download_support)) .route("/api/download/agent", get(api::downloads::download_agent)) // HTML page routes (clean URLs) .route("/login", get(serve_login)) .route("/dashboard", get(serve_dashboard)) .route("/users", get(serve_users)) // State and middleware .with_state(state.clone()) .layer(axum_middleware::from_fn_with_state(state, auth_layer)) // Serve static files for portal (fallback) .fallback_service(ServeDir::new("static").append_index_html_on_directories(true)) // Middleware .layer(axum_middleware::from_fn(middleware::add_security_headers)) // SEC-7 & SEC-12 .layer(TraceLayer::new_for_http()) // SEC-11: Restricted CORS configuration .layer({ let cors = CorsLayer::new() // Allow requests from the production domain and localhost (for development) .allow_origin([ "https://connect.azcomputerguru.com".parse::().unwrap(), "http://localhost:3002".parse::().unwrap(), "http://127.0.0.1:3002".parse::().unwrap(), ]) // Allow only necessary HTTP methods .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS]) // Allow common headers needed for API requests .allow_headers([ axum::http::header::AUTHORIZATION, axum::http::header::CONTENT_TYPE, axum::http::header::ACCEPT, ]) // Allow credentials (cookies, auth headers) .allow_credentials(true); cors }); // Start server let addr: SocketAddr = 
listen_addr.parse()?; let listener = tokio::net::TcpListener::bind(addr).await?; info!("Server listening on {}", addr); // Use into_make_service_with_connect_info to enable IP address extraction axum::serve( listener, app.into_make_service_with_connect_info::() ).await?; Ok(()) } async fn health() -> &'static str { "OK" } /// Prometheus metrics endpoint async fn prometheus_metrics( State(state): State, ) -> String { use prometheus_client::encoding::text::encode; let registry = state.registry.lock().unwrap(); let mut buffer = String::new(); encode(&mut buffer, ®istry).unwrap(); buffer } // Support code API handlers async fn create_code( _user: AuthenticatedUser, // Require authentication State(state): State, Json(request): Json, ) -> Json { let code = state.support_codes.create_code(request).await; info!("Created support code: {}", code.code); Json(code) } async fn list_codes( _user: AuthenticatedUser, // Require authentication State(state): State, ) -> Json> { Json(state.support_codes.list_active_codes().await) } #[derive(Deserialize)] struct ValidateParams { code: String, } async fn validate_code( State(state): State, Path(code): Path, ) -> Json { Json(state.support_codes.validate_code(&code).await) } async fn cancel_code( _user: AuthenticatedUser, // Require authentication State(state): State, Path(code): Path, ) -> impl IntoResponse { if state.support_codes.cancel_code(&code).await { (StatusCode::OK, "Code cancelled") } else { (StatusCode::BAD_REQUEST, "Cannot cancel code") } } // Session API handlers (updated to use AppState) async fn list_sessions( _user: AuthenticatedUser, // Require authentication State(state): State, ) -> Json> { let sessions = state.sessions.list_sessions().await; Json(sessions.into_iter().map(api::SessionInfo::from).collect()) } async fn get_session( _user: AuthenticatedUser, // Require authentication State(state): State, Path(id): Path, ) -> Result, (StatusCode, &'static str)> { let session_id = uuid::Uuid::parse_str(&id) .map_err(|_| 
(StatusCode::BAD_REQUEST, "Invalid session ID"))?; let session = state.sessions.get_session(session_id).await .ok_or((StatusCode::NOT_FOUND, "Session not found"))?; Ok(Json(api::SessionInfo::from(session))) } async fn disconnect_session( _user: AuthenticatedUser, // Require authentication State(state): State, Path(id): Path, ) -> impl IntoResponse { let session_id = match uuid::Uuid::parse_str(&id) { Ok(id) => id, Err(_) => return (StatusCode::BAD_REQUEST, "Invalid session ID"), }; if state.sessions.disconnect_session(session_id, "Disconnected by administrator").await { info!("Session {} disconnected by admin", session_id); (StatusCode::OK, "Session disconnected") } else { (StatusCode::NOT_FOUND, "Session not found") } } // Machine API handlers async fn list_machines( _user: AuthenticatedUser, // Require authentication State(state): State, ) -> Result>, (StatusCode, &'static str)> { let db = state.db.as_ref() .ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?; let machines = db::machines::get_all_machines(db.pool()).await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?; Ok(Json(machines.into_iter().map(api::MachineInfo::from).collect())) } async fn get_machine( _user: AuthenticatedUser, // Require authentication State(state): State, Path(agent_id): Path, ) -> Result, (StatusCode, &'static str)> { let db = state.db.as_ref() .ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?; let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id).await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))? 
.ok_or((StatusCode::NOT_FOUND, "Machine not found"))?; Ok(Json(api::MachineInfo::from(machine))) } async fn get_machine_history( _user: AuthenticatedUser, // Require authentication State(state): State, Path(agent_id): Path, ) -> Result, (StatusCode, &'static str)> { let db = state.db.as_ref() .ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?; // Get machine let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id).await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))? .ok_or((StatusCode::NOT_FOUND, "Machine not found"))?; // Get sessions for this machine let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id).await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?; // Get events for this machine let events = db::events::get_events_for_machine(db.pool(), machine.id).await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?; let history = api::MachineHistory { machine: api::MachineInfo::from(machine), sessions: sessions.into_iter().map(api::SessionRecord::from).collect(), events: events.into_iter().map(api::EventRecord::from).collect(), exported_at: chrono::Utc::now().to_rfc3339(), }; Ok(Json(history)) } async fn delete_machine( _user: AuthenticatedUser, // Require authentication State(state): State, Path(agent_id): Path, Query(params): Query, ) -> Result, (StatusCode, &'static str)> { let db = state.db.as_ref() .ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?; // Get machine first let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id).await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))? 
.ok_or((StatusCode::NOT_FOUND, "Machine not found"))?; // Export history if requested let history = if params.export { let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id).await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?; let events = db::events::get_events_for_machine(db.pool(), machine.id).await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?; Some(api::MachineHistory { machine: api::MachineInfo::from(machine.clone()), sessions: sessions.into_iter().map(api::SessionRecord::from).collect(), events: events.into_iter().map(api::EventRecord::from).collect(), exported_at: chrono::Utc::now().to_rfc3339(), }) } else { None }; // Send uninstall command if requested and agent is online let mut uninstall_sent = false; if params.uninstall { // Find session for this agent if let Some(session) = state.sessions.get_session_by_agent(&agent_id).await { if session.is_online { uninstall_sent = state.sessions.send_admin_command( session.id, proto::AdminCommandType::AdminUninstall, "Deleted by administrator", ).await; if uninstall_sent { info!("Sent uninstall command to agent {}", agent_id); } } } } // Remove from session manager state.sessions.remove_agent(&agent_id).await; // Delete from database (cascades to sessions and events) db::machines::delete_machine(db.pool(), &agent_id).await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Failed to delete machine"))?; info!("Deleted machine {} (uninstall_sent: {})", agent_id, uninstall_sent); Ok(Json(api::DeleteMachineResponse { success: true, message: format!("Machine {} deleted", machine.hostname), uninstall_sent, history, })) } // Update trigger request #[derive(Deserialize)] struct TriggerUpdateRequest { /// Target version (optional, defaults to latest stable) version: Option, } /// Trigger update on a specific machine async fn trigger_machine_update( _user: AuthenticatedUser, // Require authentication State(state): State, Path(agent_id): Path, Json(request): 
Json, ) -> Result { let db = state.db.as_ref() .ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?; // Get the target release (either specified or latest stable) let release = if let Some(version) = request.version { db::releases::get_release_by_version(db.pool(), &version).await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))? .ok_or((StatusCode::NOT_FOUND, "Release version not found"))? } else { db::releases::get_latest_stable_release(db.pool()).await .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))? .ok_or((StatusCode::NOT_FOUND, "No stable release available"))? }; // Find session for this agent let session = state.sessions.get_session_by_agent(&agent_id).await .ok_or((StatusCode::NOT_FOUND, "Agent not found or offline"))?; if !session.is_online { return Err((StatusCode::BAD_REQUEST, "Agent is offline")); } // Send update command via WebSocket // For now, we send admin command - later we'll include UpdateInfo in the message let sent = state.sessions.send_admin_command( session.id, proto::AdminCommandType::AdminUpdate, &format!("Update to version {}", release.version), ).await; if sent { info!("Sent update command to agent {} (version {})", agent_id, release.version); // Update machine update status in database let _ = db::releases::update_machine_update_status(db.pool(), &agent_id, "downloading").await; Ok((StatusCode::OK, "Update command sent")) } else { Err((StatusCode::INTERNAL_SERVER_ERROR, "Failed to send update command")) } } // Static page handlers async fn serve_login() -> impl IntoResponse { match tokio::fs::read_to_string("static/login.html").await { Ok(content) => Html(content).into_response(), Err(_) => (StatusCode::NOT_FOUND, "Page not found").into_response(), } } async fn serve_dashboard() -> impl IntoResponse { match tokio::fs::read_to_string("static/dashboard.html").await { Ok(content) => Html(content).into_response(), Err(_) => (StatusCode::NOT_FOUND, "Page not found").into_response(), } } 
/// Serves the user-management page.
///
/// Reads `static/users.html` on every request and returns it as HTML; any
/// read error (missing file, permissions, ...) yields `404 Page not found`.
async fn serve_users() -> impl IntoResponse {
    tokio::fs::read_to_string("static/users.html")
        .await
        .map(|page| Html(page).into_response())
        .unwrap_or_else(|_| (StatusCode::NOT_FOUND, "Page not found").into_response())
}