SolverBot: - Inject active project path into agent system prompts so agents know which directory to scope file operations to GuruRMM: - Bump agent version to 0.6.0 - Add serde aliases for PowerShell/ClaudeTask command types - Add typed CommandType enum on server for proper serialization - Support claude_task command type in send_command API Dataforth: - Fix SCP space-escaping in Sync-FromNAS.ps1 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
786 lines
27 KiB
Rust
786 lines
27 KiB
Rust
//! WebSocket handler for agent connections
|
|
//!
|
|
//! Handles real-time communication with agents including:
|
|
//! - Authentication handshake
|
|
//! - Metrics ingestion
|
|
//! - Command dispatching
|
|
//! - Watchdog event handling
|
|
|
|
use std::collections::HashMap;
|
|
|
|
use axum::{
|
|
extract::{
|
|
ws::{Message, WebSocket},
|
|
State, WebSocketUpgrade,
|
|
},
|
|
response::Response,
|
|
};
|
|
use futures_util::{SinkExt, StreamExt};
|
|
use serde::{Deserialize, Serialize};
|
|
use tokio::sync::mpsc;
|
|
use tracing::{debug, error, info, warn};
|
|
use uuid::Uuid;
|
|
|
|
use crate::db;
|
|
use crate::AppState;
|
|
|
|
/// Connected agents manager
|
|
pub struct AgentConnections {
|
|
/// Map of agent ID to sender channel
|
|
connections: HashMap<Uuid, mpsc::Sender<ServerMessage>>,
|
|
}
|
|
|
|
impl AgentConnections {
|
|
pub fn new() -> Self {
|
|
Self {
|
|
connections: HashMap::new(),
|
|
}
|
|
}
|
|
|
|
/// Add a new agent connection
|
|
pub fn add(&mut self, agent_id: Uuid, tx: mpsc::Sender<ServerMessage>) {
|
|
self.connections.insert(agent_id, tx);
|
|
}
|
|
|
|
/// Remove an agent connection
|
|
pub fn remove(&mut self, agent_id: &Uuid) {
|
|
self.connections.remove(agent_id);
|
|
}
|
|
|
|
/// Send a message to a specific agent
|
|
pub async fn send_to(&self, agent_id: &Uuid, msg: ServerMessage) -> bool {
|
|
if let Some(tx) = self.connections.get(agent_id) {
|
|
tx.send(msg).await.is_ok()
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
/// Check if an agent is connected
|
|
pub fn is_connected(&self, agent_id: &Uuid) -> bool {
|
|
self.connections.contains_key(agent_id)
|
|
}
|
|
|
|
/// Get count of connected agents
|
|
pub fn count(&self) -> usize {
|
|
self.connections.len()
|
|
}
|
|
}
|
|
|
|
impl Default for AgentConnections {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
/// Messages from agent to server
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(tag = "type", content = "payload")]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum AgentMessage {
|
|
Auth(AuthPayload),
|
|
Metrics(MetricsPayload),
|
|
NetworkState(NetworkStatePayload),
|
|
CommandResult(CommandResultPayload),
|
|
WatchdogEvent(WatchdogEventPayload),
|
|
UpdateResult(UpdateResultPayload),
|
|
Heartbeat,
|
|
}
|
|
|
|
/// Messages from server to agent
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(tag = "type", content = "payload")]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum ServerMessage {
|
|
AuthAck(AuthAckPayload),
|
|
Command(CommandPayload),
|
|
ConfigUpdate(serde_json::Value),
|
|
Update(UpdatePayload),
|
|
Ack { message_id: Option<String> },
|
|
Error { code: String, message: String },
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct AuthPayload {
|
|
pub api_key: String,
|
|
/// Hardware-derived device ID (for site-based registration)
|
|
#[serde(default)]
|
|
pub device_id: Option<String>,
|
|
pub hostname: String,
|
|
pub os_type: String,
|
|
pub os_version: String,
|
|
pub agent_version: String,
|
|
/// Architecture (amd64, arm64, etc.)
|
|
#[serde(default = "default_arch")]
|
|
pub architecture: String,
|
|
/// Previous version if reconnecting after update
|
|
#[serde(default)]
|
|
pub previous_version: Option<String>,
|
|
/// Update ID if reconnecting after update
|
|
#[serde(default)]
|
|
pub pending_update_id: Option<Uuid>,
|
|
}
|
|
|
|
/// Serde default for `AuthPayload::architecture` when an agent omits the field.
fn default_arch() -> String {
    String::from("amd64")
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct AuthAckPayload {
|
|
pub success: bool,
|
|
pub agent_id: Option<Uuid>,
|
|
pub error: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct MetricsPayload {
|
|
pub timestamp: chrono::DateTime<chrono::Utc>,
|
|
pub cpu_percent: f32,
|
|
pub memory_percent: f32,
|
|
pub memory_used_bytes: u64,
|
|
pub memory_total_bytes: u64,
|
|
pub disk_percent: f32,
|
|
pub disk_used_bytes: u64,
|
|
pub disk_total_bytes: u64,
|
|
pub network_rx_bytes: u64,
|
|
pub network_tx_bytes: u64,
|
|
pub os_type: String,
|
|
pub os_version: String,
|
|
pub hostname: String,
|
|
// Extended metrics (optional for backwards compatibility)
|
|
#[serde(default)]
|
|
pub uptime_seconds: Option<u64>,
|
|
#[serde(default)]
|
|
pub boot_time: Option<i64>,
|
|
#[serde(default)]
|
|
pub logged_in_user: Option<String>,
|
|
#[serde(default)]
|
|
pub user_idle_seconds: Option<u64>,
|
|
#[serde(default)]
|
|
pub public_ip: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct CommandResultPayload {
|
|
pub command_id: Uuid,
|
|
pub exit_code: i32,
|
|
pub stdout: String,
|
|
pub stderr: String,
|
|
pub duration_ms: u64,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct WatchdogEventPayload {
|
|
pub name: String,
|
|
pub event: String,
|
|
pub details: Option<String>,
|
|
}
|
|
|
|
/// Network interface information
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct NetworkInterface {
|
|
pub name: String,
|
|
pub mac_address: Option<String>,
|
|
pub ipv4_addresses: Vec<String>,
|
|
pub ipv6_addresses: Vec<String>,
|
|
}
|
|
|
|
/// Network state payload from agent
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct NetworkStatePayload {
|
|
pub timestamp: chrono::DateTime<chrono::Utc>,
|
|
pub interfaces: Vec<NetworkInterface>,
|
|
pub state_hash: String,
|
|
}
|
|
|
|
/// Types of commands that can be sent to agents.
|
|
/// Must match the agent's CommandType enum serialization format.
|
|
/// Uses snake_case to match the agent's #[serde(rename_all = "snake_case")].
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum CommandType {
|
|
/// Shell command (cmd on Windows, sh on Unix)
|
|
Shell,
|
|
/// PowerShell command (Windows)
|
|
PowerShell,
|
|
/// Python script
|
|
Python,
|
|
/// Raw script execution
|
|
Script,
|
|
/// Claude Code task execution
|
|
ClaudeTask {
|
|
/// Task description for Claude Code
|
|
task: String,
|
|
/// Optional working directory
|
|
working_directory: Option<String>,
|
|
/// Optional context files
|
|
context_files: Option<Vec<String>>,
|
|
},
|
|
}
|
|
|
|
impl CommandType {
|
|
/// Parse a command type string from the API into the enum.
|
|
/// Accepts both snake_case ("power_shell") and common formats ("powershell").
|
|
/// Note: ClaudeTask requires additional fields - use `new_claude_task()` instead.
|
|
pub fn from_api_string(s: &str) -> Result<Self, String> {
|
|
match s.to_lowercase().as_str() {
|
|
"shell" => Ok(Self::Shell),
|
|
"powershell" | "power_shell" => Ok(Self::PowerShell),
|
|
"python" => Ok(Self::Python),
|
|
"script" => Ok(Self::Script),
|
|
"claude_task" | "claudetask" => Err(
|
|
"claude_task type requires task field - use the claude_task-specific API fields".to_string()
|
|
),
|
|
_ => Err(format!("Unknown command type: '{}'. Valid types: shell, powershell, python, script, claude_task", s)),
|
|
}
|
|
}
|
|
|
|
/// Check if a command type string represents a claude_task.
|
|
pub fn is_claude_task(s: &str) -> bool {
|
|
matches!(s.to_lowercase().as_str(), "claude_task" | "claudetask")
|
|
}
|
|
|
|
/// Create a ClaudeTask command type with the required fields.
|
|
pub fn new_claude_task(task: String, working_directory: Option<String>, context_files: Option<Vec<String>>) -> Self {
|
|
Self::ClaudeTask { task, working_directory, context_files }
|
|
}
|
|
|
|
/// Convert back to the string format stored in the database.
|
|
pub fn as_db_string(&self) -> &'static str {
|
|
match self {
|
|
Self::Shell => "shell",
|
|
Self::PowerShell => "powershell",
|
|
Self::Python => "python",
|
|
Self::Script => "script",
|
|
Self::ClaudeTask { .. } => "claude_task",
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct CommandPayload {
|
|
pub id: Uuid,
|
|
pub command_type: CommandType,
|
|
pub command: String,
|
|
pub timeout_seconds: Option<u64>,
|
|
pub elevated: bool,
|
|
}
|
|
|
|
/// Update command payload from server to agent
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct UpdatePayload {
|
|
/// Unique update ID for tracking
|
|
pub update_id: Uuid,
|
|
/// Target version to update to
|
|
pub target_version: String,
|
|
/// Download URL for the new binary
|
|
pub download_url: String,
|
|
/// SHA256 checksum of the binary
|
|
pub checksum_sha256: String,
|
|
/// Whether to force update (skip version check)
|
|
#[serde(default)]
|
|
pub force: bool,
|
|
}
|
|
|
|
/// Update result payload from agent to server
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct UpdateResultPayload {
|
|
/// Update ID (from the server)
|
|
pub update_id: Uuid,
|
|
/// Update status
|
|
pub status: String,
|
|
/// Old version before update
|
|
pub old_version: String,
|
|
/// New version after update (if successful)
|
|
pub new_version: Option<String>,
|
|
/// Error message if failed
|
|
pub error: Option<String>,
|
|
}
|
|
|
|
/// Result of successful agent authentication
|
|
struct AuthResult {
|
|
agent_id: Uuid,
|
|
agent_version: String,
|
|
os_type: String,
|
|
architecture: String,
|
|
previous_version: Option<String>,
|
|
pending_update_id: Option<Uuid>,
|
|
}
|
|
|
|
/// WebSocket upgrade handler
|
|
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
|
|
ws.on_upgrade(|socket| handle_socket(socket, state))
|
|
}
|
|
|
|
/// Handle a WebSocket connection
|
|
async fn handle_socket(socket: WebSocket, state: AppState) {
|
|
let (mut sender, mut receiver) = socket.split();
|
|
|
|
// Create channel for outgoing protocol messages (ServerMessage)
|
|
let (tx, mut rx) = mpsc::channel::<ServerMessage>(100);
|
|
|
|
// Create separate channel for raw WebSocket frames (Pong responses)
|
|
// This allows proper WebSocket protocol compliance without changing the public API
|
|
let (pong_tx, mut pong_rx) = mpsc::channel::<Vec<u8>>(16);
|
|
|
|
// Wait for authentication message
|
|
let auth_result = match authenticate(&mut receiver, &mut sender, &state).await {
|
|
Ok(result) => {
|
|
info!("Agent authenticated: {}", result.agent_id);
|
|
|
|
// Send auth success
|
|
let ack = ServerMessage::AuthAck(AuthAckPayload {
|
|
success: true,
|
|
agent_id: Some(result.agent_id),
|
|
error: None,
|
|
});
|
|
|
|
if let Ok(json) = serde_json::to_string(&ack) {
|
|
let _ = sender.send(Message::Text(json)).await;
|
|
}
|
|
|
|
// Register connection
|
|
state.agents.write().await.add(result.agent_id, tx.clone());
|
|
|
|
// Update agent status
|
|
let _ = db::update_agent_status(&state.db, result.agent_id, "online").await;
|
|
|
|
// Check if this is a post-update reconnection
|
|
if let Some(prev_version) = &result.previous_version {
|
|
if prev_version != &result.agent_version {
|
|
info!(
|
|
"Agent {} reconnected after update: {} -> {}",
|
|
result.agent_id, prev_version, result.agent_version
|
|
);
|
|
// Mark update as completed
|
|
let _ = db::complete_update_by_agent(
|
|
&state.db,
|
|
result.agent_id,
|
|
result.pending_update_id,
|
|
prev_version,
|
|
&result.agent_version,
|
|
).await;
|
|
}
|
|
}
|
|
|
|
// Check if agent needs update (auto-update enabled)
|
|
if let Some(available) = state.updates.needs_update(
|
|
&result.agent_version,
|
|
&result.os_type,
|
|
&result.architecture,
|
|
).await {
|
|
info!(
|
|
"Agent {} needs update: {} -> {}",
|
|
result.agent_id, result.agent_version, available.version
|
|
);
|
|
|
|
let update_id = Uuid::new_v4();
|
|
|
|
// Record update in database
|
|
if let Err(e) = db::create_agent_update(
|
|
&state.db,
|
|
result.agent_id,
|
|
update_id,
|
|
&result.agent_version,
|
|
&available.version.to_string(),
|
|
&available.download_url,
|
|
&available.checksum_sha256,
|
|
).await {
|
|
error!("Failed to record update: {}", e);
|
|
} else {
|
|
// Send update command
|
|
let update_msg = ServerMessage::Update(UpdatePayload {
|
|
update_id,
|
|
target_version: available.version.to_string(),
|
|
download_url: available.download_url.clone(),
|
|
checksum_sha256: available.checksum_sha256.clone(),
|
|
force: false,
|
|
});
|
|
|
|
if let Err(e) = tx.send(update_msg).await {
|
|
error!("Failed to send update command: {}", e);
|
|
}
|
|
}
|
|
}
|
|
|
|
result
|
|
}
|
|
Err(e) => {
|
|
error!("Authentication failed: {}", e);
|
|
|
|
let ack = ServerMessage::AuthAck(AuthAckPayload {
|
|
success: false,
|
|
agent_id: None,
|
|
error: Some(e.to_string()),
|
|
});
|
|
|
|
if let Ok(json) = serde_json::to_string(&ack) {
|
|
let _ = sender.send(Message::Text(json)).await;
|
|
}
|
|
|
|
return;
|
|
}
|
|
};
|
|
|
|
let agent_id = auth_result.agent_id;
|
|
|
|
// Spawn task to forward outgoing messages
|
|
// Handles both protocol messages (ServerMessage) and raw Pong frames
|
|
let send_task = tokio::spawn(async move {
|
|
loop {
|
|
tokio::select! {
|
|
// Handle protocol messages (ServerMessage -> JSON text)
|
|
Some(msg) = rx.recv() => {
|
|
if let Ok(json) = serde_json::to_string(&msg) {
|
|
if sender.send(Message::Text(json)).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
// Handle Pong responses (WebSocket protocol compliance)
|
|
Some(data) = pong_rx.recv() => {
|
|
if sender.send(Message::Pong(data)).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
// Both channels closed
|
|
else => break,
|
|
}
|
|
}
|
|
});
|
|
|
|
// Handle incoming messages
|
|
while let Some(msg_result) = receiver.next().await {
|
|
match msg_result {
|
|
Ok(Message::Text(text)) => {
|
|
if let Err(e) = handle_agent_message(&text, agent_id, &state).await {
|
|
error!("Error handling agent message: {}", e);
|
|
}
|
|
}
|
|
Ok(Message::Ping(data)) => {
|
|
// WebSocket protocol requires Pong response with same payload
|
|
// Send via pong channel to the send task
|
|
if pong_tx.send(data).await.is_err() {
|
|
warn!("Failed to send Pong response for agent {}", agent_id);
|
|
break;
|
|
}
|
|
}
|
|
Ok(Message::Close(_)) => {
|
|
info!("Agent {} disconnected", agent_id);
|
|
break;
|
|
}
|
|
Err(e) => {
|
|
error!("WebSocket error for agent {}: {}", agent_id, e);
|
|
break;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
// Cleanup
|
|
state.agents.write().await.remove(&agent_id);
|
|
let _ = db::update_agent_status(&state.db, agent_id, "offline").await;
|
|
send_task.abort();
|
|
|
|
info!("Agent {} connection closed", agent_id);
|
|
}
|
|
|
|
/// Authenticate an agent connection
|
|
///
|
|
/// Supports two modes:
|
|
/// 1. Legacy: API key maps directly to an agent (api_key_hash in agents table)
|
|
/// 2. Site-based: API key maps to a site, device_id identifies the specific agent
|
|
async fn authenticate(
|
|
receiver: &mut futures_util::stream::SplitStream<WebSocket>,
|
|
sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
|
|
state: &AppState,
|
|
) -> anyhow::Result<AuthResult> {
|
|
use tokio::time::{timeout, Duration};
|
|
|
|
// Wait for auth message with timeout
|
|
let msg = timeout(Duration::from_secs(10), receiver.next())
|
|
.await
|
|
.map_err(|_| anyhow::anyhow!("Authentication timeout"))?
|
|
.ok_or_else(|| anyhow::anyhow!("Connection closed before auth"))?
|
|
.map_err(|e| anyhow::anyhow!("WebSocket error: {}", e))?;
|
|
|
|
let text = match msg {
|
|
Message::Text(t) => t,
|
|
_ => return Err(anyhow::anyhow!("Expected text message for auth")),
|
|
};
|
|
|
|
let agent_msg: AgentMessage =
|
|
serde_json::from_str(&text).map_err(|e| anyhow::anyhow!("Invalid auth message: {}", e))?;
|
|
|
|
let auth = match agent_msg {
|
|
AgentMessage::Auth(a) => a,
|
|
_ => return Err(anyhow::anyhow!("Expected auth message")),
|
|
};
|
|
|
|
// Try site-based authentication first (if device_id is provided)
|
|
if let Some(device_id) = &auth.device_id {
|
|
// Check if api_key looks like a site code (WORD-WORD-NUMBER format)
|
|
let site = if is_site_code_format(&auth.api_key) {
|
|
info!("Attempting site code authentication: {}", auth.api_key);
|
|
db::get_site_by_code(&state.db, &auth.api_key)
|
|
.await
|
|
.map_err(|e| anyhow::anyhow!("Database error: {}", e))?
|
|
} else {
|
|
// Hash the API key and look up by hash
|
|
let api_key_hash = hash_api_key(&auth.api_key);
|
|
db::get_site_by_api_key_hash(&state.db, &api_key_hash)
|
|
.await
|
|
.map_err(|e| anyhow::anyhow!("Database error: {}", e))?
|
|
};
|
|
|
|
if let Some(site) = site {
|
|
info!("Site-based auth: site={} ({})", site.name, site.id);
|
|
|
|
// Look up or create agent by site_id + device_id
|
|
let agent = match db::get_agent_by_site_and_device(&state.db, site.id, device_id)
|
|
.await
|
|
.map_err(|e| anyhow::anyhow!("Database error: {}", e))?
|
|
{
|
|
Some(agent) => {
|
|
// Update existing agent info
|
|
let _ = db::update_agent_info_full(
|
|
&state.db,
|
|
agent.id,
|
|
Some(&auth.hostname),
|
|
Some(device_id),
|
|
Some(&auth.os_version),
|
|
Some(&auth.agent_version),
|
|
)
|
|
.await;
|
|
agent
|
|
}
|
|
None => {
|
|
// Auto-register new agent under this site
|
|
info!(
|
|
"Auto-registering new agent: hostname={}, device_id={}, site={}",
|
|
auth.hostname, device_id, site.name
|
|
);
|
|
|
|
db::create_agent_with_site(
|
|
&state.db,
|
|
db::CreateAgentWithSite {
|
|
site_id: site.id,
|
|
device_id: device_id.clone(),
|
|
hostname: auth.hostname.clone(),
|
|
os_type: auth.os_type.clone(),
|
|
os_version: Some(auth.os_version.clone()),
|
|
agent_version: Some(auth.agent_version.clone()),
|
|
},
|
|
)
|
|
.await
|
|
.map_err(|e| anyhow::anyhow!("Failed to create agent: {}", e))?
|
|
}
|
|
};
|
|
|
|
return Ok(AuthResult {
|
|
agent_id: agent.id,
|
|
agent_version: auth.agent_version.clone(),
|
|
os_type: auth.os_type.clone(),
|
|
architecture: auth.architecture.clone(),
|
|
previous_version: auth.previous_version.clone(),
|
|
pending_update_id: auth.pending_update_id,
|
|
});
|
|
}
|
|
}
|
|
|
|
// Fall back to legacy: look up agent directly by API key hash
|
|
let api_key_hash = hash_api_key(&auth.api_key);
|
|
let agent = db::get_agent_by_api_key_hash(&state.db, &api_key_hash)
|
|
.await
|
|
.map_err(|e| anyhow::anyhow!("Database error: {}", e))?
|
|
.ok_or_else(|| anyhow::anyhow!("Invalid API key"))?;
|
|
|
|
// Update agent info (including hostname in case it changed)
|
|
let _ = db::update_agent_info(
|
|
&state.db,
|
|
agent.id,
|
|
Some(&auth.hostname),
|
|
Some(&auth.os_version),
|
|
Some(&auth.agent_version),
|
|
)
|
|
.await;
|
|
|
|
Ok(AuthResult {
|
|
agent_id: agent.id,
|
|
agent_version: auth.agent_version,
|
|
os_type: auth.os_type,
|
|
architecture: auth.architecture,
|
|
previous_version: auth.previous_version,
|
|
pending_update_id: auth.pending_update_id,
|
|
})
|
|
}
|
|
|
|
/// Handle a message from an authenticated agent
|
|
async fn handle_agent_message(
|
|
text: &str,
|
|
agent_id: Uuid,
|
|
state: &AppState,
|
|
) -> anyhow::Result<()> {
|
|
let msg: AgentMessage = serde_json::from_str(text)?;
|
|
|
|
match msg {
|
|
AgentMessage::Metrics(metrics) => {
|
|
debug!("Received metrics from agent {}: CPU={:.1}%", agent_id, metrics.cpu_percent);
|
|
|
|
// Store metrics in database
|
|
let create_metrics = db::CreateMetrics {
|
|
agent_id,
|
|
cpu_percent: Some(metrics.cpu_percent),
|
|
memory_percent: Some(metrics.memory_percent),
|
|
memory_used_bytes: Some(metrics.memory_used_bytes as i64),
|
|
disk_percent: Some(metrics.disk_percent),
|
|
disk_used_bytes: Some(metrics.disk_used_bytes as i64),
|
|
network_rx_bytes: Some(metrics.network_rx_bytes as i64),
|
|
network_tx_bytes: Some(metrics.network_tx_bytes as i64),
|
|
// Extended metrics
|
|
uptime_seconds: metrics.uptime_seconds.map(|v| v as i64),
|
|
boot_time: metrics.boot_time,
|
|
logged_in_user: metrics.logged_in_user.clone(),
|
|
user_idle_seconds: metrics.user_idle_seconds.map(|v| v as i64),
|
|
public_ip: metrics.public_ip.clone(),
|
|
memory_total_bytes: Some(metrics.memory_total_bytes as i64),
|
|
disk_total_bytes: Some(metrics.disk_total_bytes as i64),
|
|
};
|
|
|
|
db::insert_metrics(&state.db, create_metrics).await?;
|
|
|
|
// Also update agent_state for quick access to latest extended info
|
|
let _ = db::upsert_agent_state(
|
|
&state.db,
|
|
agent_id,
|
|
metrics.uptime_seconds.map(|v| v as i64),
|
|
metrics.boot_time,
|
|
metrics.logged_in_user.as_deref(),
|
|
metrics.user_idle_seconds.map(|v| v as i64),
|
|
metrics.public_ip.as_deref(),
|
|
).await;
|
|
|
|
// Update last_seen
|
|
db::update_agent_status(&state.db, agent_id, "online").await?;
|
|
}
|
|
|
|
AgentMessage::CommandResult(result) => {
|
|
info!(
|
|
"Received command result from agent {}: command={}, exit={}",
|
|
agent_id, result.command_id, result.exit_code
|
|
);
|
|
|
|
// Update command in database
|
|
let cmd_result = db::CommandResult {
|
|
exit_code: result.exit_code,
|
|
stdout: result.stdout,
|
|
stderr: result.stderr,
|
|
};
|
|
|
|
db::update_command_result(&state.db, result.command_id, cmd_result).await?;
|
|
}
|
|
|
|
AgentMessage::WatchdogEvent(event) => {
|
|
info!(
|
|
"Received watchdog event from agent {}: {} - {}",
|
|
agent_id, event.name, event.event
|
|
);
|
|
|
|
// Store watchdog event (table exists but we'll add the insert function later)
|
|
// For now, just log it
|
|
}
|
|
|
|
AgentMessage::NetworkState(network_state) => {
|
|
debug!(
|
|
"Received network state from agent {}: {} interfaces",
|
|
agent_id,
|
|
network_state.interfaces.len()
|
|
);
|
|
// Log interface details at trace level
|
|
for iface in &network_state.interfaces {
|
|
tracing::trace!(
|
|
" Interface {}: IPv4={:?}, IPv6={:?}",
|
|
iface.name,
|
|
iface.ipv4_addresses,
|
|
iface.ipv6_addresses
|
|
);
|
|
}
|
|
// Store network state in database
|
|
if let Ok(interfaces_json) = serde_json::to_value(&network_state.interfaces) {
|
|
let _ = db::update_agent_network_state(
|
|
&state.db,
|
|
agent_id,
|
|
&interfaces_json,
|
|
&network_state.state_hash,
|
|
).await;
|
|
}
|
|
// Update last_seen
|
|
db::update_agent_status(&state.db, agent_id, "online").await?;
|
|
}
|
|
|
|
AgentMessage::Heartbeat => {
|
|
debug!("Received heartbeat from agent {}", agent_id);
|
|
db::update_agent_status(&state.db, agent_id, "online").await?;
|
|
}
|
|
|
|
AgentMessage::UpdateResult(result) => {
|
|
info!(
|
|
"Received update result from agent {}: update_id={}, status={}",
|
|
agent_id, result.update_id, result.status
|
|
);
|
|
|
|
// Update the agent_updates record
|
|
match result.status.as_str() {
|
|
"completed" => {
|
|
let _ = db::complete_agent_update(&state.db, result.update_id, result.new_version.as_deref()).await;
|
|
info!("Agent {} successfully updated to {}", agent_id, result.new_version.unwrap_or_default());
|
|
}
|
|
"failed" | "rolled_back" => {
|
|
let _ = db::fail_agent_update(&state.db, result.update_id, result.error.as_deref()).await;
|
|
warn!("Agent {} update failed: {}", agent_id, result.error.unwrap_or_default());
|
|
}
|
|
_ => {
|
|
debug!("Agent {} update status: {}", agent_id, result.status);
|
|
}
|
|
}
|
|
}
|
|
|
|
AgentMessage::Auth(_) => {
|
|
warn!("Received unexpected auth message from already authenticated agent");
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Hash an API key for storage/lookup
|
|
pub fn hash_api_key(api_key: &str) -> String {
|
|
use sha2::{Digest, Sha256};
|
|
let mut hasher = Sha256::new();
|
|
hasher.update(api_key.as_bytes());
|
|
format!("{:x}", hasher.finalize())
|
|
}
|
|
|
|
/// Generate a new API key
|
|
pub fn generate_api_key(prefix: &str) -> String {
|
|
use rand::Rng;
|
|
let random_bytes: [u8; 24] = rand::thread_rng().gen();
|
|
let encoded = base64::Engine::encode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, random_bytes);
|
|
format!("{}{}", prefix, encoded)
|
|
}
|
|
|
|
/// Reports whether `s` has the shape of a site code rather than an API key.
///
/// A site code is `WORD-WORD-NNNN`: two non-empty runs of ASCII letters,
/// then exactly four ASCII digits, separated by single hyphens.
/// Examples: `SWIFT-CLOUD-6910`, `APPLE-GREEN-9145`.
///
/// Empty words are rejected. Previously an input such as `"--1234"` passed,
/// because checking "all characters are letters" on an empty string is
/// vacuously true.
fn is_site_code_format(s: &str) -> bool {
    /// A word is a non-empty run of ASCII letters.
    fn is_word(part: &str) -> bool {
        !part.is_empty() && part.bytes().all(|b| b.is_ascii_alphabetic())
    }

    let mut parts = s.split('-');
    match (parts.next(), parts.next(), parts.next(), parts.next()) {
        // Exactly three hyphen-separated parts.
        (Some(first), Some(second), Some(number), None) => {
            is_word(first)
                && is_word(second)
                && number.len() == 4
                && number.bytes().all(|b| b.is_ascii_digit())
        }
        _ => false,
    }
}
|