Files
claudetools/projects/msp-tools/guru-rmm/server/src/ws/mod.rs
Mike Swanson 8b6f0bcc96 sync: Multi-project updates - SolverBot, GuruRMM, Dataforth
SolverBot:
- Inject active project path into agent system prompts so agents
  know which directory to scope file operations to

GuruRMM:
- Bump agent version to 0.6.0
- Add serde aliases for PowerShell/ClaudeTask command types
- Add typed CommandType enum on server for proper serialization
- Support claude_task command type in send_command API

Dataforth:
- Fix SCP space-escaping in Sync-FromNAS.ps1

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 16:16:18 -07:00

786 lines
27 KiB
Rust

//! WebSocket handler for agent connections
//!
//! Handles real-time communication with agents including:
//! - Authentication handshake
//! - Metrics ingestion
//! - Command dispatching
//! - Watchdog event handling
use std::collections::HashMap;
use axum::{
extract::{
ws::{Message, WebSocket},
State, WebSocketUpgrade,
},
response::Response,
};
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
use crate::db;
use crate::AppState;
/// Registry of the agents that currently hold an authenticated WebSocket.
///
/// Maps each agent's ID to the sending half of the bounded channel that the
/// connection's writer task drains, so other code can push [`ServerMessage`]s
/// to a specific agent.
pub struct AgentConnections {
    /// Agent ID -> outgoing message channel for that agent's socket
    connections: HashMap<Uuid, mpsc::Sender<ServerMessage>>,
}

impl AgentConnections {
    /// Create an empty registry.
    pub fn new() -> Self {
        AgentConnections {
            connections: HashMap::new(),
        }
    }

    /// Register `tx` as the channel for `agent_id`, replacing any channel
    /// previously registered under the same ID.
    pub fn add(&mut self, agent_id: Uuid, tx: mpsc::Sender<ServerMessage>) {
        self.connections.insert(agent_id, tx);
    }

    /// Forget the channel registered for `agent_id` (no-op if absent).
    pub fn remove(&mut self, agent_id: &Uuid) {
        let _ = self.connections.remove(agent_id);
    }

    /// Queue `msg` for delivery to `agent_id`.
    ///
    /// Returns `true` once the message is queued. Returns `false` when no
    /// channel is registered for the agent, or when its receiving end has
    /// been dropped. Waits for buffer space if the channel is full.
    pub async fn send_to(&self, agent_id: &Uuid, msg: ServerMessage) -> bool {
        match self.connections.get(agent_id) {
            Some(channel) => channel.send(msg).await.is_ok(),
            None => false,
        }
    }

    /// Whether a channel is currently registered for `agent_id`.
    pub fn is_connected(&self, agent_id: &Uuid) -> bool {
        self.connections.get(agent_id).is_some()
    }

    /// Number of registered agent channels.
    pub fn count(&self) -> usize {
        self.connections.len()
    }
}

impl Default for AgentConnections {
    /// Same as [`AgentConnections::new`]: an empty registry.
    fn default() -> Self {
        AgentConnections::new()
    }
}
/// Messages from agent to server.
///
/// Adjacently tagged on the wire: `{"type": "<snake_case variant>", "payload": ...}`,
/// e.g. `{"type": "command_result", "payload": {...}}`.
/// `Auth` is expected only as the first frame of a connection; a later `Auth`
/// is treated as unexpected and merely logged.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "payload")]
#[serde(rename_all = "snake_case")]
pub enum AgentMessage {
/// Handshake credentials and agent identity (first frame only).
Auth(AuthPayload),
/// System metrics snapshot.
Metrics(MetricsPayload),
/// Snapshot of the agent's network interfaces.
NetworkState(NetworkStatePayload),
/// Output of a command previously sent by the server.
CommandResult(CommandResultPayload),
/// Watchdog event reported by the agent (currently only logged by the server).
WatchdogEvent(WatchdogEventPayload),
/// Outcome of a server-initiated agent update.
UpdateResult(UpdateResultPayload),
/// Keep-alive; carries no payload.
Heartbeat,
}
/// Messages from server to agent.
///
/// Same adjacent tagging as the agent-to-server messages:
/// `{"type": "<snake_case variant>", "payload": ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "payload")]
#[serde(rename_all = "snake_case")]
pub enum ServerMessage {
/// Reply to the agent's `Auth` frame.
AuthAck(AuthAckPayload),
/// Command for the agent to execute.
Command(CommandPayload),
/// Configuration update; the payload is passed through as arbitrary JSON.
ConfigUpdate(serde_json::Value),
/// Instruction to fetch and install a new agent build.
Update(UpdatePayload),
/// Generic acknowledgement, optionally referencing a message ID.
Ack { message_id: Option<String> },
/// Error report consisting of a `code` and a descriptive `message`.
Error { code: String, message: String },
}
/// Credentials and identity sent by the agent in its first (`Auth`) frame.
///
/// NOTE(review): the derived `Debug` prints `api_key` verbatim; avoid
/// debug-logging this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthPayload {
/// Credential. Depending on the authentication path this is a site code
/// (`WORD-WORD-1234`), a site API key, or a legacy per-agent API key.
pub api_key: String,
/// Hardware-derived device ID (for site-based registration).
/// When present, site-based authentication is attempted first.
#[serde(default)]
pub device_id: Option<String>,
/// Hostname reported by the agent.
pub hostname: String,
/// Operating-system family reported by the agent.
pub os_type: String,
/// Operating-system version string.
pub os_version: String,
/// Version of the running agent binary.
pub agent_version: String,
/// Architecture (amd64, arm64, etc.); defaults to "amd64" when omitted.
#[serde(default = "default_arch")]
pub architecture: String,
/// Previous version if reconnecting after update
#[serde(default)]
pub previous_version: Option<String>,
/// Update ID if reconnecting after update
#[serde(default)]
pub pending_update_id: Option<Uuid>,
}
/// Serde default for `AuthPayload::architecture` when an agent omits it.
fn default_arch() -> String {
    String::from("amd64")
}
/// Server's answer to an `Auth` frame.
///
/// On success `agent_id` is set and `error` is `None`; on failure `agent_id`
/// is `None` and `error` carries the reason.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthAckPayload {
pub success: bool,
pub agent_id: Option<Uuid>,
pub error: Option<String>,
}
/// System metrics snapshot reported by an agent.
///
/// Byte counters are `u64`; they are converted to `i64` before being stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsPayload {
/// When the agent took the sample.
pub timestamp: chrono::DateTime<chrono::Utc>,
pub cpu_percent: f32,
pub memory_percent: f32,
pub memory_used_bytes: u64,
pub memory_total_bytes: u64,
pub disk_percent: f32,
pub disk_used_bytes: u64,
pub disk_total_bytes: u64,
pub network_rx_bytes: u64,
pub network_tx_bytes: u64,
pub os_type: String,
pub os_version: String,
pub hostname: String,
// Extended metrics (optional for backwards compatibility)
// Older agents omit these; `#[serde(default)]` turns them into `None`.
#[serde(default)]
pub uptime_seconds: Option<u64>,
/// Presumably a Unix timestamp in seconds — TODO confirm with the agent.
#[serde(default)]
pub boot_time: Option<i64>,
#[serde(default)]
pub logged_in_user: Option<String>,
#[serde(default)]
pub user_idle_seconds: Option<u64>,
#[serde(default)]
pub public_ip: Option<String>,
}
/// Result of a command the agent executed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandResultPayload {
/// ID of the command row to update (presumably the `id` the server sent in
/// `CommandPayload` — confirm with the agent).
pub command_id: Uuid,
/// Process exit code.
pub exit_code: i32,
/// Captured standard output.
pub stdout: String,
/// Captured standard error.
pub stderr: String,
/// Execution time in milliseconds.
pub duration_ms: u64,
}
/// Event raised by the agent's watchdog.
///
/// NOTE(review): the meaning of `name`/`event` is defined by the agent; from
/// this file they are only logged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WatchdogEventPayload {
pub name: String,
pub event: String,
/// Optional free-form detail text.
pub details: Option<String>,
}
/// Network interface information reported by an agent.
///
/// Serialized as-is to JSON when the agent's network state is stored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkInterface {
/// Interface name as reported by the OS.
pub name: String,
/// Hardware address, if the interface has one.
pub mac_address: Option<String>,
/// Assigned IPv4 addresses (string form).
pub ipv4_addresses: Vec<String>,
/// Assigned IPv6 addresses (string form).
pub ipv6_addresses: Vec<String>,
}
/// Network state payload from agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkStatePayload {
/// When the agent captured this snapshot.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// All interfaces visible to the agent.
pub interfaces: Vec<NetworkInterface>,
/// Hash stored alongside the interfaces; presumably computed by the agent
/// to detect changes — TODO confirm.
pub state_hash: String,
}
/// Kind of command the server asks an agent to run.
///
/// Serialized with `rename_all = "snake_case"`, so the wire names are
/// `shell`, `power_shell`, `python`, `script` and `claude_task`. They must
/// stay in sync with the agent's own `CommandType` deserialization.
/// `ClaudeTask` is a struct variant and therefore serializes as an object
/// keyed by `claude_task` (serde's default external tagging).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CommandType {
    /// Shell command (cmd on Windows, sh on Unix)
    Shell,
    /// PowerShell command (Windows)
    PowerShell,
    /// Python script
    Python,
    /// Raw script execution
    Script,
    /// Claude Code task execution
    ClaudeTask {
        /// Task description for Claude Code
        task: String,
        /// Optional working directory
        working_directory: Option<String>,
        /// Optional context files
        context_files: Option<Vec<String>>,
    },
}

impl CommandType {
    /// Parse an API-supplied command type name (case-insensitive).
    ///
    /// Accepts `shell`, `powershell`/`power_shell`, `python` and `script`.
    /// A `claude_task`/`claudetask` name is rejected because that variant
    /// needs extra fields; build it with [`CommandType::new_claude_task`].
    ///
    /// # Errors
    /// Returns a human-readable message for `claude_task` or unknown names.
    pub fn from_api_string(s: &str) -> Result<Self, String> {
        if Self::is_claude_task(s) {
            return Err(
                "claude_task type requires task field - use the claude_task-specific API fields"
                    .to_string(),
            );
        }
        let parsed = match s.to_lowercase().as_str() {
            "shell" => CommandType::Shell,
            "powershell" | "power_shell" => CommandType::PowerShell,
            "python" => CommandType::Python,
            "script" => CommandType::Script,
            _ => {
                return Err(format!(
                    "Unknown command type: '{}'. Valid types: shell, powershell, python, script, claude_task",
                    s
                ))
            }
        };
        Ok(parsed)
    }

    /// Whether `s` names the Claude task type (case-insensitive
    /// `claude_task` or `claudetask`).
    pub fn is_claude_task(s: &str) -> bool {
        let lowered = s.to_lowercase();
        lowered == "claude_task" || lowered == "claudetask"
    }

    /// Build a `ClaudeTask` command type from its fields.
    pub fn new_claude_task(
        task: String,
        working_directory: Option<String>,
        context_files: Option<Vec<String>>,
    ) -> Self {
        CommandType::ClaudeTask {
            task,
            working_directory,
            context_files,
        }
    }

    /// Name used for this command type in the database.
    ///
    /// Note this differs from the serde wire name for `PowerShell`
    /// (`"powershell"` here vs `"power_shell"` on the wire).
    pub fn as_db_string(&self) -> &'static str {
        match self {
            CommandType::Shell => "shell",
            CommandType::PowerShell => "powershell",
            CommandType::Python => "python",
            CommandType::Script => "script",
            CommandType::ClaudeTask { .. } => "claude_task",
        }
    }
}
/// Command sent from the server to an agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandPayload {
/// Command ID; presumably echoed back as `command_id` in the agent's result — confirm.
pub id: Uuid,
/// How the agent should interpret `command`.
pub command_type: CommandType,
/// Command text / script body.
pub command: String,
/// Optional execution timeout in seconds (behaviour when `None` is up to the agent).
pub timeout_seconds: Option<u64>,
/// Whether to run with elevated privileges.
pub elevated: bool,
}
/// Update command payload from server to agent.
///
/// Sent right after authentication when a newer build is available; the same
/// `update_id` is recorded in the database and reported back by the agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdatePayload {
/// Unique update ID for tracking
pub update_id: Uuid,
/// Target version to update to
pub target_version: String,
/// Download URL for the new binary
pub download_url: String,
/// SHA256 checksum of the binary
pub checksum_sha256: String,
/// Whether to force update (skip version check); defaults to false
#[serde(default)]
pub force: bool,
}
/// Update result payload from agent to server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateResultPayload {
/// Update ID (from the server)
pub update_id: Uuid,
/// Update status. The server acts on "completed", "failed" and
/// "rolled_back"; any other value is only logged.
pub status: String,
/// Old version before update
pub old_version: String,
/// New version after update (if successful)
pub new_version: Option<String>,
/// Error message if failed
pub error: Option<String>,
}
/// Result of successful agent authentication.
///
/// Carries the resolved agent ID plus the version and platform details needed
/// for post-update bookkeeping and auto-update checks.
struct AuthResult {
agent_id: Uuid,
agent_version: String,
os_type: String,
architecture: String,
previous_version: Option<String>,
pending_update_id: Option<Uuid>,
}
/// WebSocket upgrade handler
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
ws.on_upgrade(|socket| handle_socket(socket, state))
}
/// Drive a single agent WebSocket connection from handshake to teardown.
///
/// 1. Authenticate the agent and reply with an `AuthAck` (on failure the ack
///    carries the error text and the connection is dropped).
/// 2. Register the connection in `state.agents` and mark the agent online.
/// 3. Record a finished update if the agent reports a version change, and
///    queue an auto-update if a newer build is available.
/// 4. Pump messages. A spawned writer task sends queued `ServerMessage`s and
///    Pong frames, while this task reads and dispatches agent frames.
/// 5. On disconnect, unregister and mark offline. This happens only if this
///    connection is still the registered one; a reconnect may have replaced it.
async fn handle_socket(socket: WebSocket, state: AppState) {
    let (mut sender, mut receiver) = socket.split();

    // Channel for outgoing protocol messages (ServerMessage -> JSON text).
    let (tx, mut rx) = mpsc::channel::<ServerMessage>(100);
    // Separate channel for raw Pong payloads, so the read loop can answer
    // Pings while the writer task keeps exclusive ownership of the sink.
    let (pong_tx, mut pong_rx) = mpsc::channel::<Vec<u8>>(16);

    // Authenticate before anything else; on failure, report why and hang up.
    let auth_result = match authenticate(&mut receiver, &mut sender, &state).await {
        Ok(result) => result,
        Err(e) => {
            error!("Authentication failed: {}", e);
            let ack = ServerMessage::AuthAck(AuthAckPayload {
                success: false,
                agent_id: None,
                error: Some(e.to_string()),
            });
            if let Ok(json) = serde_json::to_string(&ack) {
                let _ = sender.send(Message::Text(json)).await;
            }
            return;
        }
    };
    let agent_id = auth_result.agent_id;
    info!("Agent authenticated: {}", agent_id);

    let ack = ServerMessage::AuthAck(AuthAckPayload {
        success: true,
        agent_id: Some(agent_id),
        error: None,
    });
    if let Ok(json) = serde_json::to_string(&ack) {
        // Best-effort: a failed write here is not fatal on its own.
        let _ = sender.send(Message::Text(json)).await;
    }

    // Register the connection (replacing any older one) and mark it online.
    state.agents.write().await.add(agent_id, tx.clone());
    if let Err(e) = db::update_agent_status(&state.db, agent_id, "online").await {
        warn!("Failed to mark agent {} online: {}", agent_id, e);
    }

    // An agent reconnecting after a self-update reports the version it ran
    // before; a differing version means the update finished.
    if let Some(prev_version) = &auth_result.previous_version {
        if prev_version != &auth_result.agent_version {
            info!(
                "Agent {} reconnected after update: {} -> {}",
                agent_id, prev_version, auth_result.agent_version
            );
            if let Err(e) = db::complete_update_by_agent(
                &state.db,
                agent_id,
                auth_result.pending_update_id,
                prev_version,
                &auth_result.agent_version,
            )
            .await
            {
                warn!("Failed to record completed update for agent {}: {}", agent_id, e);
            }
        }
    }

    // Queue an auto-update if a newer build exists for this platform. The
    // message waits in `tx` until the writer task below starts draining it.
    if let Some(available) = state
        .updates
        .needs_update(
            &auth_result.agent_version,
            &auth_result.os_type,
            &auth_result.architecture,
        )
        .await
    {
        info!(
            "Agent {} needs update: {} -> {}",
            agent_id, auth_result.agent_version, available.version
        );
        let update_id = Uuid::new_v4();
        // Record the update first so the agent's UpdateResult has a row to close.
        if let Err(e) = db::create_agent_update(
            &state.db,
            agent_id,
            update_id,
            &auth_result.agent_version,
            &available.version.to_string(),
            &available.download_url,
            &available.checksum_sha256,
        )
        .await
        {
            error!("Failed to record update: {}", e);
        } else {
            let update_msg = ServerMessage::Update(UpdatePayload {
                update_id,
                target_version: available.version.to_string(),
                download_url: available.download_url.clone(),
                checksum_sha256: available.checksum_sha256.clone(),
                force: false,
            });
            if let Err(e) = tx.send(update_msg).await {
                error!("Failed to send update command: {}", e);
            }
        }
    }

    // Writer task: owns the sink and forwards protocol messages and Pong
    // frames until a write fails or both channels are closed.
    let send_task = tokio::spawn(async move {
        loop {
            tokio::select! {
                Some(msg) = rx.recv() => {
                    if let Ok(json) = serde_json::to_string(&msg) {
                        if sender.send(Message::Text(json)).await.is_err() {
                            break;
                        }
                    }
                }
                Some(data) = pong_rx.recv() => {
                    if sender.send(Message::Pong(data)).await.is_err() {
                        break;
                    }
                }
                else => break,
            }
        }
    });

    // Reader loop: dispatch agent frames until close or error.
    while let Some(msg_result) = receiver.next().await {
        match msg_result {
            Ok(Message::Text(text)) => {
                if let Err(e) = handle_agent_message(&text, agent_id, &state).await {
                    error!("Error handling agent message: {}", e);
                }
            }
            Ok(Message::Ping(data)) => {
                // Echo the Ping payload back as a Pong via the writer task.
                if pong_tx.send(data).await.is_err() {
                    warn!("Failed to send Pong response for agent {}", agent_id);
                    break;
                }
            }
            Ok(Message::Close(_)) => {
                info!("Agent {} disconnected", agent_id);
                break;
            }
            Err(e) => {
                error!("WebSocket error for agent {}: {}", agent_id, e);
                break;
            }
            _ => {}
        }
    }

    // Cleanup. If the agent reconnected while this socket was still open,
    // `add` has already replaced our sender with the new connection's. In that
    // case this stale connection must neither unregister the new one nor mark
    // the (still connected) agent offline.
    let still_registered = {
        let mut agents = state.agents.write().await;
        let ours = agents
            .connections
            .get(&agent_id)
            .map_or(false, |current| current.same_channel(&tx));
        if ours {
            agents.remove(&agent_id);
        }
        ours
    };
    if still_registered {
        if let Err(e) = db::update_agent_status(&state.db, agent_id, "offline").await {
            warn!("Failed to mark agent {} offline: {}", agent_id, e);
        }
    } else {
        info!(
            "Agent {} has a newer connection; leaving it registered",
            agent_id
        );
    }
    send_task.abort();
    info!("Agent {} connection closed", agent_id);
}
/// Authenticate an agent connection
///
/// Supports two modes:
/// 1. Legacy: API key maps directly to an agent (api_key_hash in agents table)
/// 2. Site-based: API key maps to a site, device_id identifies the specific agent
async fn authenticate(
receiver: &mut futures_util::stream::SplitStream<WebSocket>,
sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
state: &AppState,
) -> anyhow::Result<AuthResult> {
use tokio::time::{timeout, Duration};
// Wait for auth message with timeout
let msg = timeout(Duration::from_secs(10), receiver.next())
.await
.map_err(|_| anyhow::anyhow!("Authentication timeout"))?
.ok_or_else(|| anyhow::anyhow!("Connection closed before auth"))?
.map_err(|e| anyhow::anyhow!("WebSocket error: {}", e))?;
let text = match msg {
Message::Text(t) => t,
_ => return Err(anyhow::anyhow!("Expected text message for auth")),
};
let agent_msg: AgentMessage =
serde_json::from_str(&text).map_err(|e| anyhow::anyhow!("Invalid auth message: {}", e))?;
let auth = match agent_msg {
AgentMessage::Auth(a) => a,
_ => return Err(anyhow::anyhow!("Expected auth message")),
};
// Try site-based authentication first (if device_id is provided)
if let Some(device_id) = &auth.device_id {
// Check if api_key looks like a site code (WORD-WORD-NUMBER format)
let site = if is_site_code_format(&auth.api_key) {
info!("Attempting site code authentication: {}", auth.api_key);
db::get_site_by_code(&state.db, &auth.api_key)
.await
.map_err(|e| anyhow::anyhow!("Database error: {}", e))?
} else {
// Hash the API key and look up by hash
let api_key_hash = hash_api_key(&auth.api_key);
db::get_site_by_api_key_hash(&state.db, &api_key_hash)
.await
.map_err(|e| anyhow::anyhow!("Database error: {}", e))?
};
if let Some(site) = site {
info!("Site-based auth: site={} ({})", site.name, site.id);
// Look up or create agent by site_id + device_id
let agent = match db::get_agent_by_site_and_device(&state.db, site.id, device_id)
.await
.map_err(|e| anyhow::anyhow!("Database error: {}", e))?
{
Some(agent) => {
// Update existing agent info
let _ = db::update_agent_info_full(
&state.db,
agent.id,
Some(&auth.hostname),
Some(device_id),
Some(&auth.os_version),
Some(&auth.agent_version),
)
.await;
agent
}
None => {
// Auto-register new agent under this site
info!(
"Auto-registering new agent: hostname={}, device_id={}, site={}",
auth.hostname, device_id, site.name
);
db::create_agent_with_site(
&state.db,
db::CreateAgentWithSite {
site_id: site.id,
device_id: device_id.clone(),
hostname: auth.hostname.clone(),
os_type: auth.os_type.clone(),
os_version: Some(auth.os_version.clone()),
agent_version: Some(auth.agent_version.clone()),
},
)
.await
.map_err(|e| anyhow::anyhow!("Failed to create agent: {}", e))?
}
};
return Ok(AuthResult {
agent_id: agent.id,
agent_version: auth.agent_version.clone(),
os_type: auth.os_type.clone(),
architecture: auth.architecture.clone(),
previous_version: auth.previous_version.clone(),
pending_update_id: auth.pending_update_id,
});
}
}
// Fall back to legacy: look up agent directly by API key hash
let api_key_hash = hash_api_key(&auth.api_key);
let agent = db::get_agent_by_api_key_hash(&state.db, &api_key_hash)
.await
.map_err(|e| anyhow::anyhow!("Database error: {}", e))?
.ok_or_else(|| anyhow::anyhow!("Invalid API key"))?;
// Update agent info (including hostname in case it changed)
let _ = db::update_agent_info(
&state.db,
agent.id,
Some(&auth.hostname),
Some(&auth.os_version),
Some(&auth.agent_version),
)
.await;
Ok(AuthResult {
agent_id: agent.id,
agent_version: auth.agent_version,
os_type: auth.os_type,
architecture: auth.architecture,
previous_version: auth.previous_version,
pending_update_id: auth.pending_update_id,
})
}
/// Dispatch one text frame received from an already-authenticated agent.
///
/// The frame is parsed as an [`AgentMessage`] and handled per variant:
/// - metrics are stored, and the latest-state row is upserted;
/// - command results update the command row;
/// - network state is stored;
/// - update results close out the matching update record;
/// - watchdog events are only logged for now.
///
/// Metrics, network-state and heartbeat messages also re-mark the agent
/// `online` to refresh its `last_seen`.
///
/// # Errors
/// Returns an error for malformed JSON, and when inserting metrics, updating a
/// command result, or updating the online status fails. The remaining
/// database writes are best-effort: failures are logged, not propagated.
async fn handle_agent_message(
    text: &str,
    agent_id: Uuid,
    state: &AppState,
) -> anyhow::Result<()> {
    let msg: AgentMessage = serde_json::from_str(text)?;
    match msg {
        AgentMessage::Metrics(metrics) => {
            debug!("Received metrics from agent {}: CPU={:.1}%", agent_id, metrics.cpu_percent);
            let create_metrics = db::CreateMetrics {
                agent_id,
                cpu_percent: Some(metrics.cpu_percent),
                memory_percent: Some(metrics.memory_percent),
                memory_used_bytes: Some(saturating_i64(metrics.memory_used_bytes)),
                disk_percent: Some(metrics.disk_percent),
                disk_used_bytes: Some(saturating_i64(metrics.disk_used_bytes)),
                network_rx_bytes: Some(saturating_i64(metrics.network_rx_bytes)),
                network_tx_bytes: Some(saturating_i64(metrics.network_tx_bytes)),
                // Extended metrics
                uptime_seconds: metrics.uptime_seconds.map(saturating_i64),
                boot_time: metrics.boot_time,
                logged_in_user: metrics.logged_in_user.clone(),
                user_idle_seconds: metrics.user_idle_seconds.map(saturating_i64),
                public_ip: metrics.public_ip.clone(),
                memory_total_bytes: Some(saturating_i64(metrics.memory_total_bytes)),
                disk_total_bytes: Some(saturating_i64(metrics.disk_total_bytes)),
            };
            db::insert_metrics(&state.db, create_metrics).await?;

            // Keep agent_state in sync for quick access to the latest extended info.
            if let Err(e) = db::upsert_agent_state(
                &state.db,
                agent_id,
                metrics.uptime_seconds.map(saturating_i64),
                metrics.boot_time,
                metrics.logged_in_user.as_deref(),
                metrics.user_idle_seconds.map(saturating_i64),
                metrics.public_ip.as_deref(),
            )
            .await
            {
                warn!("Failed to update agent_state for agent {}: {}", agent_id, e);
            }

            // Update last_seen
            db::update_agent_status(&state.db, agent_id, "online").await?;
        }
        AgentMessage::CommandResult(result) => {
            info!(
                "Received command result from agent {}: command={}, exit={}",
                agent_id, result.command_id, result.exit_code
            );
            let cmd_result = db::CommandResult {
                exit_code: result.exit_code,
                stdout: result.stdout,
                stderr: result.stderr,
            };
            db::update_command_result(&state.db, result.command_id, cmd_result).await?;
        }
        AgentMessage::WatchdogEvent(event) => {
            // Persisting watchdog events is not implemented yet; log only.
            info!(
                "Received watchdog event from agent {}: {} - {}",
                agent_id, event.name, event.event
            );
        }
        AgentMessage::NetworkState(network_state) => {
            debug!(
                "Received network state from agent {}: {} interfaces",
                agent_id,
                network_state.interfaces.len()
            );
            for iface in &network_state.interfaces {
                tracing::trace!(
                    "  Interface {}: IPv4={:?}, IPv6={:?}",
                    iface.name,
                    iface.ipv4_addresses,
                    iface.ipv6_addresses
                );
            }

            match serde_json::to_value(&network_state.interfaces) {
                Ok(interfaces_json) => {
                    if let Err(e) = db::update_agent_network_state(
                        &state.db,
                        agent_id,
                        &interfaces_json,
                        &network_state.state_hash,
                    )
                    .await
                    {
                        warn!("Failed to store network state for agent {}: {}", agent_id, e);
                    }
                }
                Err(e) => {
                    warn!("Failed to serialize network state for agent {}: {}", agent_id, e);
                }
            }

            // Update last_seen
            db::update_agent_status(&state.db, agent_id, "online").await?;
        }
        AgentMessage::Heartbeat => {
            debug!("Received heartbeat from agent {}", agent_id);
            db::update_agent_status(&state.db, agent_id, "online").await?;
        }
        AgentMessage::UpdateResult(result) => {
            info!(
                "Received update result from agent {}: update_id={}, status={}",
                agent_id, result.update_id, result.status
            );
            match result.status.as_str() {
                "completed" => {
                    if let Err(e) = db::complete_agent_update(
                        &state.db,
                        result.update_id,
                        result.new_version.as_deref(),
                    )
                    .await
                    {
                        warn!("Failed to mark update {} completed: {}", result.update_id, e);
                    }
                    info!(
                        "Agent {} successfully updated to {}",
                        agent_id,
                        result.new_version.unwrap_or_default()
                    );
                }
                "failed" | "rolled_back" => {
                    if let Err(e) = db::fail_agent_update(
                        &state.db,
                        result.update_id,
                        result.error.as_deref(),
                    )
                    .await
                    {
                        warn!("Failed to mark update {} failed: {}", result.update_id, e);
                    }
                    warn!(
                        "Agent {} update failed: {}",
                        agent_id,
                        result.error.unwrap_or_default()
                    );
                }
                _ => {
                    debug!("Agent {} update status: {}", agent_id, result.status);
                }
            }
        }
        AgentMessage::Auth(_) => {
            warn!(
                "Received unexpected auth message from already authenticated agent {}",
                agent_id
            );
        }
    }
    Ok(())
}

/// Convert an agent-reported `u64` counter to the `i64` the database stores.
///
/// Values above `i64::MAX` are clamped instead of wrapping to a negative number.
fn saturating_i64(value: u64) -> i64 {
    i64::try_from(value).unwrap_or(i64::MAX)
}
/// Hash an API key for storage and lookup.
///
/// Returns the SHA-256 digest of the key's UTF-8 bytes as 64 lowercase hex
/// characters.
pub fn hash_api_key(api_key: &str) -> String {
    use sha2::{Digest, Sha256};
    use std::fmt::Write as _;

    let digest = Sha256::digest(api_key.as_bytes());
    digest
        .iter()
        .fold(String::with_capacity(64), |mut hex, byte| {
            let _ = write!(hex, "{:02x}", byte);
            hex
        })
}
/// Generate a new API key
pub fn generate_api_key(prefix: &str) -> String {
use rand::Rng;
let random_bytes: [u8; 24] = rand::thread_rng().gen();
let encoded = base64::Engine::encode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, random_bytes);
format!("{}{}", prefix, encoded)
}
/// Check whether `s` has the shape of a human-friendly site code.
///
/// A site code is three `-`-separated parts:
/// - two non-empty ASCII-alphabetic words (letter case is not checked);
/// - followed by exactly four ASCII digits.
///
/// Examples: `SWIFT-CLOUD-6910`, `APPLE-GREEN-9145`.
///
/// This is only a shape check used to pick the lookup strategy; it does not
/// validate the code against the database.
fn is_site_code_format(s: &str) -> bool {
    // `all` is vacuously true on an empty part, so words must be checked
    // for emptiness explicitly (otherwise "--1234" would match).
    let is_word = |part: &str| !part.is_empty() && part.chars().all(|c| c.is_ascii_alphabetic());
    let is_number = |part: &str| part.len() == 4 && part.bytes().all(|b| b.is_ascii_digit());

    let mut parts = s.split('-');
    match (parts.next(), parts.next(), parts.next(), parts.next()) {
        (Some(first), Some(second), Some(number), None) => {
            is_word(first) && is_word(second) && is_number(number)
        }
        _ => false,
    }
}