Files
guru-connect/server/src/main.rs
Mike Swanson 1c5c1e78e7
Some checks failed
Build and Test / Build Server (Linux) (push) Failing after 2m59s
Build and Test / Build Agent (Windows) (push) Has started running
Build and Test / Security Audit (push) Has been cancelled
Build and Test / Build Summary (push) Has been cancelled
Run Tests / Test Server (push) Has been cancelled
Run Tests / Test Agent (push) Has been cancelled
Run Tests / Code Coverage (push) Has been cancelled
Run Tests / Lint and Format Check (push) Has been cancelled
style: cargo fmt --all — make codebase rustfmt-clean
First run of the build-and-test CI gate (cargo fmt --all -- --check) surfaced
pre-existing formatting drift across the agent and server crates. Apply rustfmt
across the workspace so the codebase meets its own CI gate. Pure formatting; no
logic changes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-05-29 15:02:12 +00:00

763 lines
27 KiB
Rust

//! GuruConnect Server - WebSocket Relay Server
//!
//! Handles connections from both agents and dashboard viewers,
//! relaying video frames and input events between them.
mod api;
mod auth;
mod config;
mod db;
mod metrics;
mod middleware;
mod relay;
mod session;
mod support_codes;
mod utils;
/// Code pulled in at compile time from `$OUT_DIR/guruconnect.rs`.
///
/// NOTE(review): presumably the protobuf message types emitted by the build
/// script — confirm against `build.rs`.
pub mod proto {
include!(concat!(env!("OUT_DIR"), "/guruconnect.rs"));
}
use anyhow::Result;
use axum::http::{HeaderValue, Method};
use axum::{
extract::{Json, Path, Query, Request, State},
http::StatusCode,
middleware::{self as axum_middleware, Next},
response::{Html, IntoResponse},
routing::{delete, get, post, put},
Router,
};
use serde::Deserialize;
use std::net::SocketAddr;
use std::sync::Arc;
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
use tower_http::services::ServeDir;
use tower_http::trace::TraceLayer;
use tracing::{info, Level};
use tracing_subscriber::FmtSubscriber;
use auth::{generate_random_password, hash_password, AuthenticatedUser, JwtConfig, TokenBlacklist};
use metrics::SharedMetrics;
use prometheus_client::registry::Registry;
use support_codes::{CodeValidation, CreateCodeRequest, SupportCode, SupportCodeManager};
/// Application state
///
/// Built once in `main` and handed to the router via `with_state`; a second
/// clone is given to the `auth_layer` middleware.
#[derive(Clone)]
pub struct AppState {
/// Session manager; `main` restores persisted machines into it at startup
sessions: session::SessionManager,
/// Support code manager used by the `/api/codes` handlers
support_codes: SupportCodeManager,
/// Database handle; `None` when no database URL is configured or the
/// initial connection failed (the server then runs without persistence)
db: Option<db::Database>,
/// JWT configuration built in `main` from `JWT_SECRET` / `JWT_EXPIRY_HOURS`
pub jwt_config: Arc<JwtConfig>,
/// Token blacklist; `auth_layer` injects a copy into request extensions
pub token_blacklist: TokenBlacklist,
/// Optional API key for persistent agents (env: AGENT_API_KEY)
pub agent_api_key: Option<String>,
/// Prometheus metrics
pub metrics: SharedMetrics,
/// Prometheus registry (for /metrics endpoint)
pub registry: Arc<std::sync::Mutex<Registry>>,
/// Server start time
pub start_time: Arc<std::time::Instant>,
}
/// Middleware that publishes the auth configuration to downstream handlers.
///
/// Places the shared `JwtConfig` handle and an `Arc`-wrapped clone of the token
/// blacklist into the request extensions, then forwards the request to `next`.
async fn auth_layer(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> impl IntoResponse {
    let jwt_config = Arc::clone(&state.jwt_config);
    let blacklist = Arc::new(state.token_blacklist.clone());

    let extensions = request.extensions_mut();
    extensions.insert(jwt_config);
    extensions.insert(blacklist);

    next.run(request).await
}
#[tokio::main]
async fn main() -> Result<()> {
// Initialize logging
let _subscriber = FmtSubscriber::builder()
.with_max_level(Level::INFO)
.with_target(true)
.init();
info!("GuruConnect Server v{}", env!("CARGO_PKG_VERSION"));
// Load configuration
let config = config::Config::load()?;
// Use port 3002 for GuruConnect
let listen_addr = std::env::var("LISTEN_ADDR").unwrap_or_else(|_| "0.0.0.0:3002".to_string());
info!("Loaded configuration, listening on {}", listen_addr);
// JWT configuration - REQUIRED for security
let jwt_secret = std::env::var("JWT_SECRET").expect(
"JWT_SECRET environment variable must be set! Generate one with: openssl rand -base64 64",
);
if jwt_secret.len() < 32 {
panic!("JWT_SECRET must be at least 32 characters long for security!");
}
let jwt_expiry_hours = std::env::var("JWT_EXPIRY_HOURS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(24i64);
let jwt_config = Arc::new(JwtConfig::new(jwt_secret, jwt_expiry_hours));
// Initialize database if configured
let database = if let Some(ref db_url) = config.database_url {
match db::Database::connect(db_url, config.database_max_connections).await {
Ok(db) => {
// Run migrations
if let Err(e) = db.migrate().await {
tracing::error!("Failed to run migrations: {}", e);
return Err(e);
}
Some(db)
}
Err(e) => {
tracing::warn!(
"Failed to connect to database: {}. Running without persistence.",
e
);
None
}
}
} else {
info!("No DATABASE_URL set, running without persistence");
None
};
// Create initial admin user if no users exist
if let Some(ref db) = database {
match db::count_users(db.pool()).await {
Ok(0) => {
info!("No users found, creating initial admin user...");
let password = generate_random_password(16);
let password_hash = hash_password(&password)?;
match db::create_user(db.pool(), "admin", &password_hash, None, "admin").await {
Ok(user) => {
// Set admin permissions
let perms = vec![
"view".to_string(),
"control".to_string(),
"transfer".to_string(),
"manage_users".to_string(),
"manage_clients".to_string(),
];
let _ = db::set_user_permissions(db.pool(), user.id, &perms).await;
// SEC-6: Write credentials to secure file instead of logging
let creds_file = ".admin-credentials";
match std::fs::write(creds_file, format!("Username: admin\nPassword: {}\n\nWARNING: Change this password immediately after first login!\nDelete this file after copying the password.\n", password)) {
Ok(_) => {
// Set restrictive permissions (Unix only)
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = std::fs::set_permissions(creds_file, std::fs::Permissions::from_mode(0o600));
}
info!("========================================");
info!(" INITIAL ADMIN USER CREATED");
info!(" Credentials written to: {}", creds_file);
info!(" (Read file, change password, then delete file)");
info!("========================================");
}
Err(e) => {
// Fallback to logging if file write fails (but warn about security)
tracing::warn!("Could not write credentials file: {}", e);
info!("========================================");
info!(" INITIAL ADMIN USER CREATED");
info!(" Username: admin");
info!(" Password: {}", password);
info!(" WARNING: Password logged due to file write failure!");
info!(" (Change this password immediately!)");
info!("========================================");
}
}
}
Err(e) => {
tracing::error!("Failed to create initial admin user: {}", e);
}
}
}
Ok(count) => {
info!("{} user(s) in database", count);
}
Err(e) => {
tracing::warn!("Could not check user count: {}", e);
}
}
}
// Create session manager
let sessions = session::SessionManager::new();
// Restore persistent machines from database
if let Some(ref db) = database {
match db::machines::get_all_machines(db.pool()).await {
Ok(machines) => {
info!(
"Restoring {} persistent machines from database",
machines.len()
);
for machine in machines {
sessions
.restore_offline_machine(&machine.agent_id, &machine.hostname)
.await;
}
}
Err(e) => {
tracing::warn!("Failed to restore machines: {}", e);
}
}
}
// Agent API key for persistent agents (optional)
let agent_api_key = std::env::var("AGENT_API_KEY").ok();
if let Some(ref key) = agent_api_key {
// Validate API key strength for security
utils::validation::validate_api_key_strength(key)?;
info!("AGENT_API_KEY configured for persistent agents (validated)");
} else {
info!("No AGENT_API_KEY set - persistent agents will need JWT token or support code");
}
// Initialize Prometheus metrics
let mut registry = Registry::default();
let metrics = Arc::new(metrics::Metrics::new(&mut registry));
let registry = Arc::new(std::sync::Mutex::new(registry));
let start_time = Arc::new(std::time::Instant::now());
// Spawn background task to update uptime metric
let metrics_for_uptime = metrics.clone();
let start_time_for_uptime = start_time.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(10));
loop {
interval.tick().await;
let uptime = start_time_for_uptime.elapsed().as_secs() as i64;
metrics_for_uptime.update_uptime(uptime);
}
});
// Create application state
let token_blacklist = TokenBlacklist::new();
let state = AppState {
sessions,
support_codes: SupportCodeManager::new(),
db: database,
jwt_config,
token_blacklist,
agent_api_key,
metrics,
registry,
start_time,
};
// Build router
let app = Router::new()
// Health check (no auth required)
.route("/health", get(health))
// Prometheus metrics (no auth required - for monitoring)
.route("/metrics", get(prometheus_metrics))
// Auth endpoints (TODO: Add rate limiting - see SEC2_RATE_LIMITING_TODO.md)
.route("/api/auth/login", post(api::auth::login))
.route(
"/api/auth/change-password",
post(api::auth::change_password),
)
.route("/api/auth/me", get(api::auth::get_me))
.route("/api/auth/logout", post(api::auth_logout::logout))
.route(
"/api/auth/revoke-token",
post(api::auth_logout::revoke_own_token),
)
.route(
"/api/auth/admin/revoke-user",
post(api::auth_logout::revoke_user_tokens),
)
.route(
"/api/auth/blacklist/stats",
get(api::auth_logout::get_blacklist_stats),
)
.route(
"/api/auth/blacklist/cleanup",
post(api::auth_logout::cleanup_blacklist),
)
// User management (admin only)
.route("/api/users", get(api::users::list_users))
.route("/api/users", post(api::users::create_user))
.route("/api/users/:id", get(api::users::get_user))
.route("/api/users/:id", put(api::users::update_user))
.route("/api/users/:id", delete(api::users::delete_user))
.route(
"/api/users/:id/permissions",
put(api::users::set_permissions),
)
.route("/api/users/:id/clients", put(api::users::set_client_access))
// Portal API - Support codes (TODO: Add rate limiting)
.route("/api/codes", post(create_code))
.route("/api/codes", get(list_codes))
.route("/api/codes/:code/validate", get(validate_code))
.route("/api/codes/:code/cancel", post(cancel_code))
// WebSocket endpoints
.route("/ws/agent", get(relay::agent_ws_handler))
.route("/ws/viewer", get(relay::viewer_ws_handler))
// REST API - Sessions
.route("/api/sessions", get(list_sessions))
.route("/api/sessions/:id", get(get_session))
.route("/api/sessions/:id", delete(disconnect_session))
// REST API - Machines
.route("/api/machines", get(list_machines))
.route("/api/machines/:agent_id", get(get_machine))
.route("/api/machines/:agent_id", delete(delete_machine))
.route("/api/machines/:agent_id/history", get(get_machine_history))
.route(
"/api/machines/:agent_id/update",
post(trigger_machine_update),
)
// REST API - Releases and Version
.route("/api/version", get(api::releases::get_version)) // No auth - for agent polling
.route("/api/releases", get(api::releases::list_releases))
.route("/api/releases", post(api::releases::create_release))
.route("/api/releases/:version", get(api::releases::get_release))
.route("/api/releases/:version", put(api::releases::update_release))
.route(
"/api/releases/:version",
delete(api::releases::delete_release),
)
// Changelog (no auth - public, like /api/version)
// Single route: version == "latest" selects the latest file; axum 0.7 / matchit 0.7
// panics if a static segment and a path param share this position, so do not split it.
.route(
"/api/changelog/:component/:version",
get(api::changelog::get),
)
// Agent downloads (no auth - public download links)
.route("/api/download/viewer", get(api::downloads::download_viewer))
.route(
"/api/download/support",
get(api::downloads::download_support),
)
.route("/api/download/agent", get(api::downloads::download_agent))
// HTML page routes (clean URLs)
.route("/login", get(serve_login))
.route("/dashboard", get(serve_dashboard))
.route("/users", get(serve_users))
// State and middleware
.with_state(state.clone())
.layer(axum_middleware::from_fn_with_state(state, auth_layer))
// Serve static files for portal (fallback)
.fallback_service(ServeDir::new("static").append_index_html_on_directories(true))
// Middleware
.layer(axum_middleware::from_fn(middleware::add_security_headers)) // SEC-7 & SEC-12
.layer(TraceLayer::new_for_http())
// SEC-11: Restricted CORS configuration
.layer({
let cors = CorsLayer::new()
// Allow requests from the production domain and localhost (for development)
.allow_origin([
"https://connect.azcomputerguru.com"
.parse::<HeaderValue>()
.unwrap(),
"http://localhost:3002".parse::<HeaderValue>().unwrap(),
"http://127.0.0.1:3002".parse::<HeaderValue>().unwrap(),
])
// Allow only necessary HTTP methods
.allow_methods([
Method::GET,
Method::POST,
Method::PUT,
Method::DELETE,
Method::OPTIONS,
])
// Allow common headers needed for API requests
.allow_headers([
axum::http::header::AUTHORIZATION,
axum::http::header::CONTENT_TYPE,
axum::http::header::ACCEPT,
])
// Allow credentials (cookies, auth headers)
.allow_credentials(true);
cors
});
// Start server
let addr: SocketAddr = listen_addr.parse()?;
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("Server listening on {}", addr);
// Use into_make_service_with_connect_info to enable IP address extraction
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await?;
Ok(())
}
/// Liveness probe: always answers with the static body `"OK"`.
async fn health() -> &'static str {
    const HEALTHY: &str = "OK";
    HEALTHY
}
/// Prometheus metrics endpoint
async fn prometheus_metrics(State(state): State<AppState>) -> String {
use prometheus_client::encoding::text::encode;
let registry = state.registry.lock().unwrap();
let mut buffer = String::new();
encode(&mut buffer, &registry).unwrap();
buffer
}
// Support code API handlers
/// Creates a support code from the JSON request body and returns it as JSON.
///
/// The `AuthenticatedUser` extractor restricts the endpoint to authenticated
/// callers; the new code is also written to the log.
async fn create_code(
    _user: AuthenticatedUser, // Require authentication
    State(state): State<AppState>,
    Json(request): Json<CreateCodeRequest>,
) -> Json<SupportCode> {
    let created = state.support_codes.create_code(request).await;
    info!("Created support code: {}", created.code);
    Json(created)
}
/// Returns the currently active support codes as JSON (authenticated callers only).
async fn list_codes(
    _user: AuthenticatedUser, // Require authentication
    State(state): State<AppState>,
) -> Json<Vec<SupportCode>> {
    let active_codes = state.support_codes.list_active_codes().await;
    Json(active_codes)
}
/// Deserializable wrapper around a single support code string.
///
/// NOTE(review): nothing in this file references this struct — `validate_code`
/// takes the code from the URL path instead. Confirm whether it is still needed.
#[derive(Deserialize)]
struct ValidateParams {
/// The support code to validate
code: String,
}
/// Validates the support code given in the URL path and returns the outcome as JSON.
///
/// This handler takes no `AuthenticatedUser` extractor, so it does not itself
/// require authentication.
async fn validate_code(
    State(state): State<AppState>,
    Path(code): Path<String>,
) -> Json<CodeValidation> {
    let outcome = state.support_codes.validate_code(&code).await;
    Json(outcome)
}
/// Cancels the support code given in the URL path.
///
/// Responds 200 "Code cancelled" when the manager reports success and
/// 400 "Cannot cancel code" otherwise.
async fn cancel_code(
    _user: AuthenticatedUser, // Require authentication
    State(state): State<AppState>,
    Path(code): Path<String>,
) -> impl IntoResponse {
    let cancelled = state.support_codes.cancel_code(&code).await;
    if !cancelled {
        return (StatusCode::BAD_REQUEST, "Cannot cancel code");
    }
    (StatusCode::OK, "Code cancelled")
}
// Session API handlers (updated to use AppState)
/// Lists the sessions known to the session manager, converted to `api::SessionInfo`.
async fn list_sessions(
    _user: AuthenticatedUser, // Require authentication
    State(state): State<AppState>,
) -> Json<Vec<api::SessionInfo>> {
    let infos: Vec<api::SessionInfo> = state
        .sessions
        .list_sessions()
        .await
        .into_iter()
        .map(api::SessionInfo::from)
        .collect();
    Json(infos)
}
/// Looks up a single session by its UUID path parameter.
///
/// # Errors
///
/// 400 "Invalid session ID" when the path segment is not a UUID,
/// 404 "Session not found" when the session manager has no such session.
async fn get_session(
    _user: AuthenticatedUser, // Require authentication
    State(state): State<AppState>,
    Path(id): Path<String>,
) -> Result<Json<api::SessionInfo>, (StatusCode, &'static str)> {
    let Ok(session_id) = uuid::Uuid::parse_str(&id) else {
        return Err((StatusCode::BAD_REQUEST, "Invalid session ID"));
    };

    match state.sessions.get_session(session_id).await {
        Some(session) => Ok(Json(api::SessionInfo::from(session))),
        None => Err((StatusCode::NOT_FOUND, "Session not found")),
    }
}
/// Disconnects the session identified by the UUID path parameter.
///
/// Responds 400 for a malformed ID, 404 when the session manager reports that
/// nothing was disconnected, and 200 (with a log entry) on success.
async fn disconnect_session(
    _user: AuthenticatedUser, // Require authentication
    State(state): State<AppState>,
    Path(id): Path<String>,
) -> impl IntoResponse {
    let Ok(session_id) = uuid::Uuid::parse_str(&id) else {
        return (StatusCode::BAD_REQUEST, "Invalid session ID");
    };

    let disconnected = state
        .sessions
        .disconnect_session(session_id, "Disconnected by administrator")
        .await;
    if !disconnected {
        return (StatusCode::NOT_FOUND, "Session not found");
    }

    info!("Session {} disconnected by admin", session_id);
    (StatusCode::OK, "Session disconnected")
}
// Machine API handlers
async fn list_machines(
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
) -> Result<Json<Vec<api::MachineInfo>>, (StatusCode, &'static str)> {
let db = state
.db
.as_ref()
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?;
let machines = db::machines::get_all_machines(db.pool())
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?;
Ok(Json(
machines.into_iter().map(api::MachineInfo::from).collect(),
))
}
async fn get_machine(
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
Path(agent_id): Path<String>,
) -> Result<Json<api::MachineInfo>, (StatusCode, &'static str)> {
let db = state
.db
.as_ref()
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?;
let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?
.ok_or((StatusCode::NOT_FOUND, "Machine not found"))?;
Ok(Json(api::MachineInfo::from(machine)))
}
async fn get_machine_history(
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
Path(agent_id): Path<String>,
) -> Result<Json<api::MachineHistory>, (StatusCode, &'static str)> {
let db = state
.db
.as_ref()
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?;
// Get machine
let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?
.ok_or((StatusCode::NOT_FOUND, "Machine not found"))?;
// Get sessions for this machine
let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?;
// Get events for this machine
let events = db::events::get_events_for_machine(db.pool(), machine.id)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?;
let history = api::MachineHistory {
machine: api::MachineInfo::from(machine),
sessions: sessions.into_iter().map(api::SessionRecord::from).collect(),
events: events.into_iter().map(api::EventRecord::from).collect(),
exported_at: chrono::Utc::now().to_rfc3339(),
};
Ok(Json(history))
}
async fn delete_machine(
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
Path(agent_id): Path<String>,
Query(params): Query<api::DeleteMachineParams>,
) -> Result<Json<api::DeleteMachineResponse>, (StatusCode, &'static str)> {
let db = state
.db
.as_ref()
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?;
// Get machine first
let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?
.ok_or((StatusCode::NOT_FOUND, "Machine not found"))?;
// Export history if requested
let history = if params.export {
let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?;
let events = db::events::get_events_for_machine(db.pool(), machine.id)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?;
Some(api::MachineHistory {
machine: api::MachineInfo::from(machine.clone()),
sessions: sessions.into_iter().map(api::SessionRecord::from).collect(),
events: events.into_iter().map(api::EventRecord::from).collect(),
exported_at: chrono::Utc::now().to_rfc3339(),
})
} else {
None
};
// Send uninstall command if requested and agent is online
let mut uninstall_sent = false;
if params.uninstall {
// Find session for this agent
if let Some(session) = state.sessions.get_session_by_agent(&agent_id).await {
if session.is_online {
uninstall_sent = state
.sessions
.send_admin_command(
session.id,
proto::AdminCommandType::AdminUninstall,
"Deleted by administrator",
)
.await;
if uninstall_sent {
info!("Sent uninstall command to agent {}", agent_id);
}
}
}
}
// Remove from session manager
state.sessions.remove_agent(&agent_id).await;
// Delete from database (cascades to sessions and events)
db::machines::delete_machine(db.pool(), &agent_id)
.await
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to delete machine",
)
})?;
info!(
"Deleted machine {} (uninstall_sent: {})",
agent_id, uninstall_sent
);
Ok(Json(api::DeleteMachineResponse {
success: true,
message: format!("Machine {} deleted", machine.hostname),
uninstall_sent,
history,
}))
}
// Update trigger request
/// JSON body accepted by `trigger_machine_update`.
#[derive(Deserialize)]
struct TriggerUpdateRequest {
/// Target version (optional, defaults to latest stable)
version: Option<String>,
}
/// Trigger update on a specific machine
async fn trigger_machine_update(
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
Path(agent_id): Path<String>,
Json(request): Json<TriggerUpdateRequest>,
) -> Result<impl IntoResponse, (StatusCode, &'static str)> {
let db = state
.db
.as_ref()
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?;
// Get the target release (either specified or latest stable)
let release = if let Some(version) = request.version {
db::releases::get_release_by_version(db.pool(), &version)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?
.ok_or((StatusCode::NOT_FOUND, "Release version not found"))?
} else {
db::releases::get_latest_stable_release(db.pool())
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?
.ok_or((StatusCode::NOT_FOUND, "No stable release available"))?
};
// Find session for this agent
let session = state
.sessions
.get_session_by_agent(&agent_id)
.await
.ok_or((StatusCode::NOT_FOUND, "Agent not found or offline"))?;
if !session.is_online {
return Err((StatusCode::BAD_REQUEST, "Agent is offline"));
}
// Send update command via WebSocket
// For now, we send admin command - later we'll include UpdateInfo in the message
let sent = state
.sessions
.send_admin_command(
session.id,
proto::AdminCommandType::AdminUpdate,
&format!("Update to version {}", release.version),
)
.await;
if sent {
info!(
"Sent update command to agent {} (version {})",
agent_id, release.version
);
// Update machine update status in database
let _ =
db::releases::update_machine_update_status(db.pool(), &agent_id, "downloading").await;
Ok((StatusCode::OK, "Update command sent"))
} else {
Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to send update command",
))
}
}
// Static page handlers
/// Serves `static/login.html` as HTML; any read failure yields 404 "Page not found".
async fn serve_login() -> impl IntoResponse {
    let Ok(page) = tokio::fs::read_to_string("static/login.html").await else {
        return (StatusCode::NOT_FOUND, "Page not found").into_response();
    };
    Html(page).into_response()
}
/// Serves `static/dashboard.html` as HTML; any read failure yields 404 "Page not found".
async fn serve_dashboard() -> impl IntoResponse {
    let Ok(page) = tokio::fs::read_to_string("static/dashboard.html").await else {
        return (StatusCode::NOT_FOUND, "Page not found").into_response();
    };
    Html(page).into_response()
}
/// Serves `static/users.html` as HTML; any read failure yields 404 "Page not found".
async fn serve_users() -> impl IntoResponse {
    let Ok(page) = tokio::fs::read_to_string("static/users.html").await else {
        return (StatusCode::NOT_FOUND, "Page not found").into_response();
    };
    Html(page).into_response()
}