//! WebSocket handler for agent connections //! //! Handles real-time communication with agents including: //! - Authentication handshake //! - Metrics ingestion //! - Command dispatching //! - Watchdog event handling use std::collections::HashMap; use axum::{ extract::{ ws::{Message, WebSocket}, State, WebSocketUpgrade, }, response::Response, }; use futures_util::{SinkExt, StreamExt}; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc; use tracing::{debug, error, info, warn}; use uuid::Uuid; use crate::db; use crate::AppState; /// Connected agents manager pub struct AgentConnections { /// Map of agent ID to sender channel connections: HashMap>, } impl AgentConnections { pub fn new() -> Self { Self { connections: HashMap::new(), } } /// Add a new agent connection pub fn add(&mut self, agent_id: Uuid, tx: mpsc::Sender) { self.connections.insert(agent_id, tx); } /// Remove an agent connection pub fn remove(&mut self, agent_id: &Uuid) { self.connections.remove(agent_id); } /// Send a message to a specific agent pub async fn send_to(&self, agent_id: &Uuid, msg: ServerMessage) -> bool { if let Some(tx) = self.connections.get(agent_id) { tx.send(msg).await.is_ok() } else { false } } /// Check if an agent is connected pub fn is_connected(&self, agent_id: &Uuid) -> bool { self.connections.contains_key(agent_id) } /// Get count of connected agents pub fn count(&self) -> usize { self.connections.len() } } impl Default for AgentConnections { fn default() -> Self { Self::new() } } /// Messages from agent to server #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", content = "payload")] #[serde(rename_all = "snake_case")] pub enum AgentMessage { Auth(AuthPayload), Metrics(MetricsPayload), NetworkState(NetworkStatePayload), CommandResult(CommandResultPayload), WatchdogEvent(WatchdogEventPayload), UpdateResult(UpdateResultPayload), Heartbeat, } /// Messages from server to agent #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", content = "payload")] #[serde(rename_all = "snake_case")] pub enum ServerMessage { AuthAck(AuthAckPayload), Command(CommandPayload), ConfigUpdate(serde_json::Value), Update(UpdatePayload), Ack { message_id: Option }, Error { code: String, message: String }, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AuthPayload { pub api_key: String, /// Hardware-derived device ID (for site-based registration) #[serde(default)] pub device_id: Option, pub hostname: String, pub os_type: String, pub os_version: String, pub agent_version: String, /// Architecture (amd64, arm64, etc.) #[serde(default = "default_arch")] pub architecture: String, /// Previous version if reconnecting after update #[serde(default)] pub previous_version: Option, /// Update ID if reconnecting after update #[serde(default)] pub pending_update_id: Option, } fn default_arch() -> String { "amd64".to_string() } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AuthAckPayload { pub success: bool, pub agent_id: Option, pub error: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MetricsPayload { pub timestamp: chrono::DateTime, pub cpu_percent: f32, pub memory_percent: f32, pub memory_used_bytes: u64, pub memory_total_bytes: u64, pub disk_percent: f32, pub disk_used_bytes: u64, pub disk_total_bytes: u64, pub network_rx_bytes: u64, pub network_tx_bytes: u64, pub os_type: String, pub os_version: String, pub hostname: String, // Extended metrics (optional for backwards compatibility) #[serde(default)] pub uptime_seconds: Option, #[serde(default)] pub boot_time: Option, #[serde(default)] pub logged_in_user: Option, #[serde(default)] pub user_idle_seconds: Option, #[serde(default)] pub public_ip: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CommandResultPayload { pub command_id: Uuid, pub exit_code: i32, pub stdout: String, pub stderr: String, pub duration_ms: u64, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WatchdogEventPayload { pub name: String, pub event: String, pub details: Option, } /// Network interface information #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NetworkInterface { pub name: String, pub mac_address: Option, pub ipv4_addresses: Vec, pub ipv6_addresses: Vec, } /// Network state payload from agent #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NetworkStatePayload { pub timestamp: chrono::DateTime, pub interfaces: Vec, pub state_hash: String, } /// Types of commands that can be sent to agents. /// Must match the agent's CommandType enum serialization format. /// Uses snake_case to match the agent's #[serde(rename_all = "snake_case")]. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum CommandType { /// Shell command (cmd on Windows, sh on Unix) Shell, /// PowerShell command (Windows) PowerShell, /// Python script Python, /// Raw script execution Script, /// Claude Code task execution ClaudeTask { /// Task description for Claude Code task: String, /// Optional working directory working_directory: Option, /// Optional context files context_files: Option>, }, } impl CommandType { /// Parse a command type string from the API into the enum. /// Accepts both snake_case ("power_shell") and common formats ("powershell"). /// Note: ClaudeTask requires additional fields - use `new_claude_task()` instead. pub fn from_api_string(s: &str) -> Result { match s.to_lowercase().as_str() { "shell" => Ok(Self::Shell), "powershell" | "power_shell" => Ok(Self::PowerShell), "python" => Ok(Self::Python), "script" => Ok(Self::Script), "claude_task" | "claudetask" => Err( "claude_task type requires task field - use the claude_task-specific API fields".to_string() ), _ => Err(format!("Unknown command type: '{}'. Valid types: shell, powershell, python, script, claude_task", s)), } } /// Check if a command type string represents a claude_task. pub fn is_claude_task(s: &str) -> bool { matches!(s.to_lowercase().as_str(), "claude_task" | "claudetask") } /// Create a ClaudeTask command type with the required fields. pub fn new_claude_task(task: String, working_directory: Option, context_files: Option>) -> Self { Self::ClaudeTask { task, working_directory, context_files } } /// Convert back to the string format stored in the database. pub fn as_db_string(&self) -> &'static str { match self { Self::Shell => "shell", Self::PowerShell => "powershell", Self::Python => "python", Self::Script => "script", Self::ClaudeTask { .. } => "claude_task", } } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CommandPayload { pub id: Uuid, pub command_type: CommandType, pub command: String, pub timeout_seconds: Option, pub elevated: bool, } /// Update command payload from server to agent #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UpdatePayload { /// Unique update ID for tracking pub update_id: Uuid, /// Target version to update to pub target_version: String, /// Download URL for the new binary pub download_url: String, /// SHA256 checksum of the binary pub checksum_sha256: String, /// Whether to force update (skip version check) #[serde(default)] pub force: bool, } /// Update result payload from agent to server #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UpdateResultPayload { /// Update ID (from the server) pub update_id: Uuid, /// Update status pub status: String, /// Old version before update pub old_version: String, /// New version after update (if successful) pub new_version: Option, /// Error message if failed pub error: Option, } /// Result of successful agent authentication struct AuthResult { agent_id: Uuid, agent_version: String, os_type: String, architecture: String, previous_version: Option, pending_update_id: Option, } /// WebSocket upgrade handler pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State) -> Response { ws.on_upgrade(|socket| handle_socket(socket, state)) } /// Handle a WebSocket connection async fn handle_socket(socket: WebSocket, state: AppState) { let (mut sender, mut receiver) = socket.split(); // Create channel for outgoing protocol messages (ServerMessage) let (tx, mut rx) = mpsc::channel::(100); // Create separate channel for raw WebSocket frames (Pong responses) // This allows proper WebSocket protocol compliance without changing the public API let (pong_tx, mut pong_rx) = mpsc::channel::>(16); // Wait for authentication message let auth_result = match authenticate(&mut receiver, &mut sender, &state).await { Ok(result) => { info!("Agent authenticated: {}", result.agent_id); // Send auth success let ack = ServerMessage::AuthAck(AuthAckPayload { success: true, agent_id: Some(result.agent_id), error: None, }); if let Ok(json) = serde_json::to_string(&ack) { let _ = sender.send(Message::Text(json)).await; } // Register connection state.agents.write().await.add(result.agent_id, tx.clone()); // Update agent status let _ = db::update_agent_status(&state.db, result.agent_id, "online").await; // Check if this is a post-update reconnection if let Some(prev_version) = &result.previous_version { if prev_version != &result.agent_version { info!( "Agent {} reconnected after update: {} -> {}", result.agent_id, prev_version, result.agent_version ); // Mark update as completed let _ = db::complete_update_by_agent( &state.db, result.agent_id, result.pending_update_id, prev_version, &result.agent_version, ).await; } } // Check if agent needs update (auto-update enabled) if let Some(available) = state.updates.needs_update( &result.agent_version, &result.os_type, &result.architecture, ).await { info!( "Agent {} needs update: {} -> {}", result.agent_id, result.agent_version, available.version ); let update_id = Uuid::new_v4(); // Record update in database if let Err(e) = db::create_agent_update( &state.db, result.agent_id, update_id, &result.agent_version, &available.version.to_string(), &available.download_url, &available.checksum_sha256, ).await { error!("Failed to record update: {}", e); } else { // Send update command let update_msg = ServerMessage::Update(UpdatePayload { update_id, target_version: available.version.to_string(), download_url: available.download_url.clone(), checksum_sha256: available.checksum_sha256.clone(), force: false, }); if let Err(e) = tx.send(update_msg).await { error!("Failed to send update command: {}", e); } } } result } Err(e) => { error!("Authentication failed: {}", e); let ack = ServerMessage::AuthAck(AuthAckPayload { success: false, agent_id: None, error: Some(e.to_string()), }); if let Ok(json) = serde_json::to_string(&ack) { let _ = sender.send(Message::Text(json)).await; } return; } }; let agent_id = auth_result.agent_id; // Spawn task to forward outgoing messages // Handles both protocol messages (ServerMessage) and raw Pong frames let send_task = tokio::spawn(async move { loop { tokio::select! { // Handle protocol messages (ServerMessage -> JSON text) Some(msg) = rx.recv() => { if let Ok(json) = serde_json::to_string(&msg) { if sender.send(Message::Text(json)).await.is_err() { break; } } } // Handle Pong responses (WebSocket protocol compliance) Some(data) = pong_rx.recv() => { if sender.send(Message::Pong(data)).await.is_err() { break; } } // Both channels closed else => break, } } }); // Handle incoming messages while let Some(msg_result) = receiver.next().await { match msg_result { Ok(Message::Text(text)) => { if let Err(e) = handle_agent_message(&text, agent_id, &state).await { error!("Error handling agent message: {}", e); } } Ok(Message::Ping(data)) => { // WebSocket protocol requires Pong response with same payload // Send via pong channel to the send task if pong_tx.send(data).await.is_err() { warn!("Failed to send Pong response for agent {}", agent_id); break; } } Ok(Message::Close(_)) => { info!("Agent {} disconnected", agent_id); break; } Err(e) => { error!("WebSocket error for agent {}: {}", agent_id, e); break; } _ => {} } } // Cleanup state.agents.write().await.remove(&agent_id); let _ = db::update_agent_status(&state.db, agent_id, "offline").await; send_task.abort(); info!("Agent {} connection closed", agent_id); } /// Authenticate an agent connection /// /// Supports two modes: /// 1. Legacy: API key maps directly to an agent (api_key_hash in agents table) /// 2. Site-based: API key maps to a site, device_id identifies the specific agent async fn authenticate( receiver: &mut futures_util::stream::SplitStream, sender: &mut futures_util::stream::SplitSink, state: &AppState, ) -> anyhow::Result { use tokio::time::{timeout, Duration}; // Wait for auth message with timeout let msg = timeout(Duration::from_secs(10), receiver.next()) .await .map_err(|_| anyhow::anyhow!("Authentication timeout"))? .ok_or_else(|| anyhow::anyhow!("Connection closed before auth"))? .map_err(|e| anyhow::anyhow!("WebSocket error: {}", e))?; let text = match msg { Message::Text(t) => t, _ => return Err(anyhow::anyhow!("Expected text message for auth")), }; let agent_msg: AgentMessage = serde_json::from_str(&text).map_err(|e| anyhow::anyhow!("Invalid auth message: {}", e))?; let auth = match agent_msg { AgentMessage::Auth(a) => a, _ => return Err(anyhow::anyhow!("Expected auth message")), }; // Try site-based authentication first (if device_id is provided) if let Some(device_id) = &auth.device_id { // Check if api_key looks like a site code (WORD-WORD-NUMBER format) let site = if is_site_code_format(&auth.api_key) { info!("Attempting site code authentication: {}", auth.api_key); db::get_site_by_code(&state.db, &auth.api_key) .await .map_err(|e| anyhow::anyhow!("Database error: {}", e))? } else { // Hash the API key and look up by hash let api_key_hash = hash_api_key(&auth.api_key); db::get_site_by_api_key_hash(&state.db, &api_key_hash) .await .map_err(|e| anyhow::anyhow!("Database error: {}", e))? }; if let Some(site) = site { info!("Site-based auth: site={} ({})", site.name, site.id); // Look up or create agent by site_id + device_id let agent = match db::get_agent_by_site_and_device(&state.db, site.id, device_id) .await .map_err(|e| anyhow::anyhow!("Database error: {}", e))? { Some(agent) => { // Update existing agent info let _ = db::update_agent_info_full( &state.db, agent.id, Some(&auth.hostname), Some(device_id), Some(&auth.os_version), Some(&auth.agent_version), ) .await; agent } None => { // Auto-register new agent under this site info!( "Auto-registering new agent: hostname={}, device_id={}, site={}", auth.hostname, device_id, site.name ); db::create_agent_with_site( &state.db, db::CreateAgentWithSite { site_id: site.id, device_id: device_id.clone(), hostname: auth.hostname.clone(), os_type: auth.os_type.clone(), os_version: Some(auth.os_version.clone()), agent_version: Some(auth.agent_version.clone()), }, ) .await .map_err(|e| anyhow::anyhow!("Failed to create agent: {}", e))? } }; return Ok(AuthResult { agent_id: agent.id, agent_version: auth.agent_version.clone(), os_type: auth.os_type.clone(), architecture: auth.architecture.clone(), previous_version: auth.previous_version.clone(), pending_update_id: auth.pending_update_id, }); } } // Fall back to legacy: look up agent directly by API key hash let api_key_hash = hash_api_key(&auth.api_key); let agent = db::get_agent_by_api_key_hash(&state.db, &api_key_hash) .await .map_err(|e| anyhow::anyhow!("Database error: {}", e))? .ok_or_else(|| anyhow::anyhow!("Invalid API key"))?; // Update agent info (including hostname in case it changed) let _ = db::update_agent_info( &state.db, agent.id, Some(&auth.hostname), Some(&auth.os_version), Some(&auth.agent_version), ) .await; Ok(AuthResult { agent_id: agent.id, agent_version: auth.agent_version, os_type: auth.os_type, architecture: auth.architecture, previous_version: auth.previous_version, pending_update_id: auth.pending_update_id, }) } /// Handle a message from an authenticated agent async fn handle_agent_message( text: &str, agent_id: Uuid, state: &AppState, ) -> anyhow::Result<()> { let msg: AgentMessage = serde_json::from_str(text)?; match msg { AgentMessage::Metrics(metrics) => { debug!("Received metrics from agent {}: CPU={:.1}%", agent_id, metrics.cpu_percent); // Store metrics in database let create_metrics = db::CreateMetrics { agent_id, cpu_percent: Some(metrics.cpu_percent), memory_percent: Some(metrics.memory_percent), memory_used_bytes: Some(metrics.memory_used_bytes as i64), disk_percent: Some(metrics.disk_percent), disk_used_bytes: Some(metrics.disk_used_bytes as i64), network_rx_bytes: Some(metrics.network_rx_bytes as i64), network_tx_bytes: Some(metrics.network_tx_bytes as i64), // Extended metrics uptime_seconds: metrics.uptime_seconds.map(|v| v as i64), boot_time: metrics.boot_time, logged_in_user: metrics.logged_in_user.clone(), user_idle_seconds: metrics.user_idle_seconds.map(|v| v as i64), public_ip: metrics.public_ip.clone(), memory_total_bytes: Some(metrics.memory_total_bytes as i64), disk_total_bytes: Some(metrics.disk_total_bytes as i64), }; db::insert_metrics(&state.db, create_metrics).await?; // Also update agent_state for quick access to latest extended info let _ = db::upsert_agent_state( &state.db, agent_id, metrics.uptime_seconds.map(|v| v as i64), metrics.boot_time, metrics.logged_in_user.as_deref(), metrics.user_idle_seconds.map(|v| v as i64), metrics.public_ip.as_deref(), ).await; // Update last_seen db::update_agent_status(&state.db, agent_id, "online").await?; } AgentMessage::CommandResult(result) => { info!( "Received command result from agent {}: command={}, exit={}", agent_id, result.command_id, result.exit_code ); // Update command in database let cmd_result = db::CommandResult { exit_code: result.exit_code, stdout: result.stdout, stderr: result.stderr, }; db::update_command_result(&state.db, result.command_id, cmd_result).await?; } AgentMessage::WatchdogEvent(event) => { info!( "Received watchdog event from agent {}: {} - {}", agent_id, event.name, event.event ); // Store watchdog event (table exists but we'll add the insert function later) // For now, just log it } AgentMessage::NetworkState(network_state) => { debug!( "Received network state from agent {}: {} interfaces", agent_id, network_state.interfaces.len() ); // Log interface details at trace level for iface in &network_state.interfaces { tracing::trace!( " Interface {}: IPv4={:?}, IPv6={:?}", iface.name, iface.ipv4_addresses, iface.ipv6_addresses ); } // Store network state in database if let Ok(interfaces_json) = serde_json::to_value(&network_state.interfaces) { let _ = db::update_agent_network_state( &state.db, agent_id, &interfaces_json, &network_state.state_hash, ).await; } // Update last_seen db::update_agent_status(&state.db, agent_id, "online").await?; } AgentMessage::Heartbeat => { debug!("Received heartbeat from agent {}", agent_id); db::update_agent_status(&state.db, agent_id, "online").await?; } AgentMessage::UpdateResult(result) => { info!( "Received update result from agent {}: update_id={}, status={}", agent_id, result.update_id, result.status ); // Update the agent_updates record match result.status.as_str() { "completed" => { let _ = db::complete_agent_update(&state.db, result.update_id, result.new_version.as_deref()).await; info!("Agent {} successfully updated to {}", agent_id, result.new_version.unwrap_or_default()); } "failed" | "rolled_back" => { let _ = db::fail_agent_update(&state.db, result.update_id, result.error.as_deref()).await; warn!("Agent {} update failed: {}", agent_id, result.error.unwrap_or_default()); } _ => { debug!("Agent {} update status: {}", agent_id, result.status); } } } AgentMessage::Auth(_) => { warn!("Received unexpected auth message from already authenticated agent"); } } Ok(()) } /// Hash an API key for storage/lookup pub fn hash_api_key(api_key: &str) -> String { use sha2::{Digest, Sha256}; let mut hasher = Sha256::new(); hasher.update(api_key.as_bytes()); format!("{:x}", hasher.finalize()) } /// Generate a new API key pub fn generate_api_key(prefix: &str) -> String { use rand::Rng; let random_bytes: [u8; 24] = rand::thread_rng().gen(); let encoded = base64::Engine::encode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, random_bytes); format!("{}{}", prefix, encoded) } /// Check if a string looks like a site code (WORD-WORD-NUMBER format) /// Examples: SWIFT-CLOUD-6910, APPLE-GREEN-9145 fn is_site_code_format(s: &str) -> bool { let parts: Vec<&str> = s.split('-').collect(); if parts.len() != 3 { return false; } // First two parts should be alphabetic (words) // Third part should be numeric (4 digits) parts[0].chars().all(|c| c.is_ascii_alphabetic()) && parts[1].chars().all(|c| c.is_ascii_alphabetic()) && parts[2].chars().all(|c| c.is_ascii_digit()) && parts[2].len() == 4 }