From 611bc00d06c4a1a9b121b3c16ed87537538861e2 Mon Sep 17 00:00:00 2001 From: Mike Swanson Date: Sun, 28 Dec 2025 17:54:05 +0000 Subject: [PATCH] Add support codes API and portal server changes - support_codes.rs: 6-digit code management - main.rs: Portal routes, static file serving, AppState - relay/mod.rs: Updated for AppState - Cargo.toml: Added rand, tower-http fs feature Generated with Claude Code --- Cargo.lock | 30 ++++++ server/Cargo.toml | 3 +- server/src/main.rs | 117 ++++++++++++++++++--- server/src/relay/mod.rs | 15 ++- server/src/support_codes.rs | 199 ++++++++++++++++++++++++++++++++++++ 5 files changed, 347 insertions(+), 17 deletions(-) create mode 100644 server/src/support_codes.rs diff --git a/Cargo.lock b/Cargo.lock index 0096198..86cd7d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -691,6 +691,7 @@ dependencies = [ "prost", "prost-build", "prost-types", + "rand", "ring", "serde", "serde_json", @@ -814,6 +815,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9171a2ea8a68358193d15dd5d70c1c10a2afc3e7e4c5bc92bc9f025cebd7359c" + [[package]] name = "httparse" version = "1.10.1" @@ -1155,6 +1162,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "miniz_oxide" version = "0.8.9" @@ -2419,8 +2436,15 @@ dependencies = [ "bitflags", "bytes", "futures-core", + "futures-util", "http", "http-body", + "http-body-util", + "http-range-header", + "httpdate", + "mime", + "mime_guess", + "percent-encoding", "pin-project-lite", "tokio", "tokio-util", @@ -2528,6 +2552,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-bidi" version = "0.3.18" diff --git a/server/Cargo.toml b/server/Cargo.toml index 15f3932..da4f629 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -12,7 +12,7 @@ tokio = { version = "1", features = ["full", "sync", "time", "rt-multi-thread", # Web framework axum = { version = "0.7", features = ["ws", "macros"] } tower = "0.5" -tower-http = { version = "0.6", features = ["cors", "trace", "compression-gzip"] } +tower-http = { version = "0.6", features = ["cors", "trace", "compression-gzip", "fs"] } # WebSocket futures-util = "0.3" @@ -52,6 +52,7 @@ uuid = { version = "1", features = ["v4", "serde"] } # Time chrono = { version = "0.4", features = ["serde"] } +rand = "0.8" [build-dependencies] prost-build = "0.13" diff --git a/server/src/main.rs b/server/src/main.rs index 183660d..521895a 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -9,6 +9,7 @@ mod session; mod auth; mod api; mod db; +mod support_codes; pub mod proto { include!(concat!(env!("OUT_DIR"), "/guruconnect.rs")); @@ -17,13 +18,27 @@ pub mod proto { use anyhow::Result; use axum::{ Router, - routing::get, + routing::{get, post}, + extract::{Path, State, Json}, + response::{Html, IntoResponse}, + http::StatusCode, }; use std::net::SocketAddr; use tower_http::cors::{Any, CorsLayer}; use tower_http::trace::TraceLayer; +use tower_http::services::ServeDir; use tracing::{info, Level}; use tracing_subscriber::FmtSubscriber; +use serde::Deserialize; + +use support_codes::{SupportCodeManager, CreateCodeRequest, SupportCode, CodeValidation}; + +/// Application state +#[derive(Clone)] +pub struct AppState { + sessions: session::SessionManager, + support_codes: SupportCodeManager, +} #[tokio::main] async fn main() -> Result<()> { @@ -37,26 +52,42 @@ async fn main() -> Result<()> { // Load configuration let config = config::Config::load()?; - info!("Loaded configuration, listening on {}", config.listen_addr); + + // Use port 3002 for GuruConnect + let listen_addr = std::env::var("LISTEN_ADDR").unwrap_or_else(|_| "0.0.0.0:3002".to_string()); + info!("Loaded configuration, listening on {}", listen_addr); - // Initialize database connection (optional for MVP) - // let db = db::init(&config.database_url).await?; - - // Create session manager - let sessions = session::SessionManager::new(); + // Create application state + let state = AppState { + sessions: session::SessionManager::new(), + support_codes: SupportCodeManager::new(), + }; // Build router let app = Router::new() // Health check .route("/health", get(health)) + + // Portal API - Support codes + .route("/api/codes", post(create_code)) + .route("/api/codes", get(list_codes)) + .route("/api/codes/:code/validate", get(validate_code)) + .route("/api/codes/:code/cancel", post(cancel_code)) + // WebSocket endpoints .route("/ws/agent", get(relay::agent_ws_handler)) .route("/ws/viewer", get(relay::viewer_ws_handler)) - // REST API - .route("/api/sessions", get(api::list_sessions)) - .route("/api/sessions/:id", get(api::get_session)) + + // REST API - Sessions + .route("/api/sessions", get(list_sessions)) + .route("/api/sessions/:id", get(get_session)) + // State - .with_state(sessions) + .with_state(state) + + // Serve static files for portal (fallback) + .fallback_service(ServeDir::new("static").append_index_html_on_directories(true)) + // Middleware .layer(TraceLayer::new_for_http()) .layer( @@ -67,7 +98,7 @@ async fn main() -> Result<()> { ); // Start server - let addr: SocketAddr = config.listen_addr.parse()?; + let addr: SocketAddr = listen_addr.parse()?; let listener = tokio::net::TcpListener::bind(addr).await?; info!("Server listening on {}", addr); @@ -80,3 +111,65 @@ async fn main() -> Result<()> { async fn health() -> &'static str { "OK" } + +// Support code API handlers + +async fn create_code( + State(state): State, + Json(request): Json, +) -> Json { + let code = state.support_codes.create_code(request).await; + info!("Created support code: {}", code.code); + Json(code) +} + +async fn list_codes( + State(state): State, +) -> Json> { + Json(state.support_codes.list_active_codes().await) +} + +#[derive(Deserialize)] +struct ValidateParams { + code: String, +} + +async fn validate_code( + State(state): State, + Path(code): Path, +) -> Json { + Json(state.support_codes.validate_code(&code).await) +} + +async fn cancel_code( + State(state): State, + Path(code): Path, +) -> impl IntoResponse { + if state.support_codes.cancel_code(&code).await { + (StatusCode::OK, "Code cancelled") + } else { + (StatusCode::BAD_REQUEST, "Cannot cancel code") + } +} + +// Session API handlers (updated to use AppState) + +async fn list_sessions( + State(state): State, +) -> Json> { + let sessions = state.sessions.list_sessions().await; + Json(sessions.into_iter().map(api::SessionInfo::from).collect()) +} + +async fn get_session( + State(state): State, + Path(id): Path, +) -> Result, (StatusCode, &'static str)> { + let session_id = uuid::Uuid::parse_str(&id) + .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID"))?; + + let session = state.sessions.get_session(session_id).await + .ok_or((StatusCode::NOT_FOUND, "Session not found"))?; + + Ok(Json(api::SessionInfo::from(session))) +} diff --git a/server/src/relay/mod.rs b/server/src/relay/mod.rs index 5ef72a8..f895dff 100644 --- a/server/src/relay/mod.rs +++ b/server/src/relay/mod.rs @@ -17,6 +17,7 @@ use tracing::{error, info, warn}; use crate::proto; use crate::session::SessionManager; +use crate::AppState; #[derive(Debug, Deserialize)] pub struct AgentParams { @@ -33,11 +34,12 @@ pub struct ViewerParams { /// WebSocket handler for agent connections pub async fn agent_ws_handler( ws: WebSocketUpgrade, - State(sessions): State, + State(state): State, Query(params): Query, ) -> impl IntoResponse { let agent_id = params.agent_id; let agent_name = params.agent_name.unwrap_or_else(|| agent_id.clone()); + let sessions = state.sessions.clone(); ws.on_upgrade(move |socket| handle_agent_connection(socket, sessions, agent_id, agent_name)) } @@ -45,10 +47,11 @@ pub async fn agent_ws_handler( /// WebSocket handler for viewer connections pub async fn viewer_ws_handler( ws: WebSocketUpgrade, - State(sessions): State, + State(state): State, Query(params): Query, ) -> impl IntoResponse { let session_id = params.session_id; + let sessions = state.sessions.clone(); ws.on_upgrade(move |socket| handle_viewer_connection(socket, sessions, session_id)) } @@ -78,6 +81,8 @@ async fn handle_agent_connection( } }); + let sessions_cleanup = sessions.clone(); + // Main loop: receive frames from agent and broadcast to viewers while let Some(msg) = ws_receiver.next().await { match msg { @@ -113,7 +118,7 @@ async fn handle_agent_connection( // Cleanup input_forward.abort(); - sessions.remove_session(session_id).await; + sessions_cleanup.remove_session(session_id).await; info!("Session {} ended", session_id); } @@ -154,6 +159,8 @@ async fn handle_viewer_connection( } }); + let sessions_cleanup = sessions.clone(); + // Main loop: receive input from viewer and forward to agent while let Some(msg) = ws_receiver.next().await { match msg { @@ -189,6 +196,6 @@ async fn handle_viewer_connection( // Cleanup frame_forward.abort(); - sessions.leave_session(session_id).await; + sessions_cleanup.leave_session(session_id).await; info!("Viewer left session: {}", session_id); } diff --git a/server/src/support_codes.rs b/server/src/support_codes.rs new file mode 100644 index 0000000..a8f93d2 --- /dev/null +++ b/server/src/support_codes.rs @@ -0,0 +1,199 @@ +//! Support session codes management +//! +//! Handles generation and validation of 6-digit support codes +//! for one-time remote support sessions. + +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use chrono::{DateTime, Utc}; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// A support session code +#[derive(Debug, Clone, Serialize)] +pub struct SupportCode { + pub code: String, + pub session_id: Uuid, + pub created_by: String, + pub created_at: DateTime, + pub status: CodeStatus, + pub client_name: Option, + pub client_machine: Option, + pub connected_at: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum CodeStatus { + Pending, // Waiting for client to connect + Connected, // Client connected, session active + Completed, // Session ended normally + Cancelled, // Code cancelled by tech +} + +/// Request to create a new support code +#[derive(Debug, Deserialize)] +pub struct CreateCodeRequest { + pub technician_id: Option, + pub technician_name: Option, +} + +/// Response when a code is validated +#[derive(Debug, Serialize)] +pub struct CodeValidation { + pub valid: bool, + pub session_id: Option, + pub server_url: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +/// Manages support codes +#[derive(Clone)] +pub struct SupportCodeManager { + codes: Arc>>, + session_to_code: Arc>>, +} + +impl SupportCodeManager { + pub fn new() -> Self { + Self { + codes: Arc::new(RwLock::new(HashMap::new())), + session_to_code: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Generate a unique 6-digit code + async fn generate_unique_code(&self) -> String { + let codes = self.codes.read().await; + let mut rng = rand::thread_rng(); + + loop { + let code: u32 = rng.gen_range(100000..999999); + let code_str = code.to_string(); + + if !codes.contains_key(&code_str) { + return code_str; + } + } + } + + /// Create a new support code + pub async fn create_code(&self, request: CreateCodeRequest) -> SupportCode { + let code = self.generate_unique_code().await; + let session_id = Uuid::new_v4(); + + let support_code = SupportCode { + code: code.clone(), + session_id, + created_by: request.technician_name.unwrap_or_else(|| "Unknown".to_string()), + created_at: Utc::now(), + status: CodeStatus::Pending, + client_name: None, + client_machine: None, + connected_at: None, + }; + + let mut codes = self.codes.write().await; + codes.insert(code.clone(), support_code.clone()); + + let mut session_to_code = self.session_to_code.write().await; + session_to_code.insert(session_id, code); + + support_code + } + + /// Validate a code and return session info + pub async fn validate_code(&self, code: &str) -> CodeValidation { + let codes = self.codes.read().await; + + match codes.get(code) { + Some(support_code) => { + if support_code.status == CodeStatus::Pending || support_code.status == CodeStatus::Connected { + CodeValidation { + valid: true, + session_id: Some(support_code.session_id.to_string()), + server_url: Some("wss://connect.azcomputerguru.com/ws/support".to_string()), + error: None, + } + } else { + CodeValidation { + valid: false, + session_id: None, + server_url: None, + error: Some("This code has expired or been used".to_string()), + } + } + } + None => CodeValidation { + valid: false, + session_id: None, + server_url: None, + error: Some("Invalid code".to_string()), + }, + } + } + + /// Mark a code as connected + pub async fn mark_connected(&self, code: &str, client_name: Option, client_machine: Option) { + let mut codes = self.codes.write().await; + if let Some(support_code) = codes.get_mut(code) { + support_code.status = CodeStatus::Connected; + support_code.client_name = client_name; + support_code.client_machine = client_machine; + support_code.connected_at = Some(Utc::now()); + } + } + + /// Mark a code as completed + pub async fn mark_completed(&self, code: &str) { + let mut codes = self.codes.write().await; + if let Some(support_code) = codes.get_mut(code) { + support_code.status = CodeStatus::Completed; + } + } + + /// Cancel a code + pub async fn cancel_code(&self, code: &str) -> bool { + let mut codes = self.codes.write().await; + if let Some(support_code) = codes.get_mut(code) { + if support_code.status == CodeStatus::Pending { + support_code.status = CodeStatus::Cancelled; + return true; + } + } + false + } + + /// List all codes (for dashboard) + pub async fn list_codes(&self) -> Vec { + let codes = self.codes.read().await; + codes.values().cloned().collect() + } + + /// List active codes only + pub async fn list_active_codes(&self) -> Vec { + let codes = self.codes.read().await; + codes.values() + .filter(|c| c.status == CodeStatus::Pending || c.status == CodeStatus::Connected) + .cloned() + .collect() + } + + /// Get code by session ID + pub async fn get_by_session(&self, session_id: Uuid) -> Option { + let session_to_code = self.session_to_code.read().await; + let code = session_to_code.get(&session_id)?; + + let codes = self.codes.read().await; + codes.get(code).cloned() + } +} + +impl Default for SupportCodeManager { + fn default() -> Self { + Self::new() + } +}