Add VPN configuration tools and agent documentation

Created comprehensive VPN setup tooling for Peaceful Spirit L2TP/IPsec connection
and enhanced agent documentation framework.

VPN Configuration (PST-NW-VPN):
- Setup-PST-L2TP-VPN.ps1: Automated L2TP/IPsec setup with split-tunnel and DNS
- Connect-PST-VPN.ps1: Connection helper with PPP adapter detection, DNS (192.168.0.2), and route config (192.168.0.0/24)
- Connect-PST-VPN-Standalone.ps1: Self-contained connection script for remote deployment
- Fix-PST-VPN-Auth.ps1: Authentication troubleshooting for CHAP/MSChapv2
- Diagnose-VPN-Interface.ps1: Comprehensive VPN interface and routing diagnostic
- Quick-Test-VPN.ps1: Fast connectivity verification (DNS/router/routes)
- Add-PST-VPN-Route-Manual.ps1: Manual route configuration helper
- vpn-connect.bat, vpn-disconnect.bat: Simple batch file shortcuts
- OpenVPN config files (Windows-compatible, abandoned for L2TP)

Key VPN Implementation Details:
- L2TP creates PPP adapter with connection name as interface description
- UniFi auto-configures DNS (192.168.0.2) but requires manual route to 192.168.0.0/24
- Split-tunnel enabled (only remote traffic through VPN)
- All-user connection for pre-login auto-connect via scheduled task
- Authentication: CHAP + MSChapv2 for UniFi compatibility

Agent Documentation:
- AGENT_QUICK_REFERENCE.md: Quick reference for all specialized agents
- documentation-squire.md: Documentation and task management specialist agent
- Updated all agent markdown files with standardized formatting

Project Organization:
- Moved conversation logs to dedicated directories (guru-connect-conversation-logs, guru-rmm-conversation-logs)
- Cleaned up old session JSONL files from projects/msp-tools/
- Added guru-connect infrastructure (agent, dashboard, proto, scripts, .gitea workflows)
- Added guru-rmm server components and deployment configs

Technical Notes:
- VPN IP pool: 192.168.4.x (client gets 192.168.4.6)
- Remote network: 192.168.0.0/24 (router at 192.168.0.10)
- PSK: [REDACTED — pre-shared key removed from commit message; rotate this key and distribute it via a secrets manager, since the old value persists in Git history]
- Credentials: pst-admin / [REDACTED — password removed; rotate this account and share credentials out-of-band, never in commit messages]

Files: 15 VPN scripts, 2 agent docs, conversation log reorganization,
guru-connect/guru-rmm infrastructure additions

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-18 11:51:47 -07:00
parent b0a68d89bf
commit 6c316aa701
272 changed files with 37068 additions and 2 deletions

View File

@@ -0,0 +1,75 @@
# Crate manifest for the GuruRMM management server binary.
[package]
name = "gururmm-server"
version = "0.2.0"
edition = "2021"
description = "GuruRMM Server - RMM management server"
authors = ["GuruRMM"]

[dependencies]
# Web framework (WebSocket + macro routing support enabled)
axum = { version = "0.7", features = ["ws", "macros"] }
axum-extra = { version = "0.9", features = ["typed-header"] }
tower = { version = "0.5", features = ["util", "timeout"] }
tower-http = { version = "0.6", features = ["cors", "trace", "compression-gzip"] }
# Async runtime
tokio = { version = "1", features = ["full"] }
# Database (compile-time-checked queries against PostgreSQL)
sqlx = { version = "0.8", features = [
"runtime-tokio",
"tls-native-tls",
"postgres",
"uuid",
"chrono",
"migrate"
] }
# Serialization
serde = { version = "1", features = ["derive"] }
serde_json = "1"
# Configuration
config = "0.14"
# Authentication (JWT tokens + Argon2 password hashing)
jsonwebtoken = "9"
argon2 = "0.5"
# UUID
uuid = { version = "1", features = ["v4", "serde"] }
# Time
chrono = { version = "0.4", features = ["serde"] }
# Logging
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
# Error handling
anyhow = "1"
thiserror = "1"
# Random for API key generation
rand = "0.8"
base64 = "0.22"
# Hashing for API keys
sha2 = "0.10"
# Semantic versioning for agent updates
semver = "1"
# Environment variables
dotenvy = "0.15"
# Futures for WebSocket
futures-util = "0.3"
# Pin transitive dependencies to stable versions.
# NOTE(review): keep this pin in sync with the `cargo update home --precise`
# step in the Dockerfile; both exist because the build image is Rust 1.85.
home = "0.5.9" # 0.5.12 requires Rust 1.88

[profile.release]
opt-level = 3
lto = true
codegen-units = 1

View File

@@ -0,0 +1,69 @@
# GuruRMM Server Dockerfile
# Multi-stage build for minimal image size
# ============================================
# Build Stage
# ============================================
FROM rust:1.85-alpine AS builder
# Install build dependencies (musl + static OpenSSL for a fully static-ish Alpine build)
RUN apk add --no-cache musl-dev openssl-dev openssl-libs-static pkgconfig
# Create app directory
WORKDIR /app
# Copy manifests first for better caching (the wildcard tolerates a missing Cargo.lock)
COPY Cargo.toml Cargo.lock* ./
# Create dummy src to build dependencies
RUN mkdir src && echo "fn main() {}" > src/main.rs
# Pin home crate to version compatible with Rust 1.85 (0.5.12 requires Rust 1.88)
# NOTE(review): this duplicates the `home = "0.5.9"` pin in Cargo.toml — keep both in sync.
RUN cargo update home --precise 0.5.9
# Build dependencies only (this layer will be cached).
# The rm deletes the dummy binary's artifacts so the real source is recompiled below.
RUN cargo build --release && rm -rf src target/release/deps/gururmm*
# Copy actual source code
COPY src ./src
COPY migrations ./migrations
# Build the actual application
RUN cargo build --release
# ============================================
# Runtime Stage
# ============================================
FROM alpine:3.19
# Install runtime dependencies
RUN apk add --no-cache ca-certificates libgcc
# Create non-root user (fixed uid/gid 1000 so bind-mounted volumes get predictable ownership)
RUN addgroup -g 1000 gururmm && \
adduser -u 1000 -G gururmm -s /bin/sh -D gururmm
# Create app directory
WORKDIR /app
# Copy binary from builder
COPY --from=builder /app/target/release/gururmm-server /app/gururmm-server
# Copy migrations (for runtime migrations)
COPY --from=builder /app/migrations /app/migrations
# Set ownership
RUN chown -R gururmm:gururmm /app
# Switch to non-root user
USER gururmm
# Expose port
EXPOSE 3001
# Health check (use 127.0.0.1 instead of localhost to avoid IPv6 issues)
# NOTE(review): relies on BusyBox wget from the Alpine base image — confirm it is present if the base changes.
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget --no-verbose --tries=1 --spider http://127.0.0.1:3001/health || exit 1
# Run the server
CMD ["/app/gururmm-server"]

View File

@@ -0,0 +1,122 @@
-- GuruRMM Initial Schema
-- Creates tables for agents, metrics, commands, watchdog events, and users

-- Enable UUID extension (provides gen_random_uuid() on PostgreSQL < 13)
CREATE EXTENSION IF NOT EXISTS "pgcrypto";

-- Agents table
-- Stores registered agents and their current status
CREATE TABLE agents (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    hostname VARCHAR(255) NOT NULL,
    -- Hash of the agent's API key; the plaintext key is returned once at
    -- registration and never stored (see the server's hash_api_key).
    api_key_hash VARCHAR(255) NOT NULL,
    os_type VARCHAR(50) NOT NULL,
    os_version VARCHAR(100),
    agent_version VARCHAR(50),
    last_seen TIMESTAMPTZ,
    status VARCHAR(20) DEFAULT 'offline',
    created_at TIMESTAMPTZ DEFAULT NOW(),
    updated_at TIMESTAMPTZ DEFAULT NOW()
);
-- Index for looking up agents by hostname
CREATE INDEX idx_agents_hostname ON agents(hostname);
-- Index for finding online agents
CREATE INDEX idx_agents_status ON agents(status);

-- Metrics table
-- Time-series data for system metrics from agents
CREATE TABLE metrics (
    id BIGSERIAL PRIMARY KEY,
    agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
    timestamp TIMESTAMPTZ DEFAULT NOW(),
    cpu_percent REAL,
    memory_percent REAL,
    memory_used_bytes BIGINT,
    disk_percent REAL,
    disk_used_bytes BIGINT,
    network_rx_bytes BIGINT,
    network_tx_bytes BIGINT
);
-- Index for querying metrics by agent and time
CREATE INDEX idx_metrics_agent_time ON metrics(agent_id, timestamp DESC);
-- Index for finding recent metrics
CREATE INDEX idx_metrics_timestamp ON metrics(timestamp DESC);

-- Users table
-- Dashboard users for authentication
CREATE TABLE users (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    email VARCHAR(255) UNIQUE NOT NULL,
    -- Nullable: presumably SSO-only accounts carry no local password — confirm against auth flow
    password_hash VARCHAR(255),
    name VARCHAR(255),
    role VARCHAR(50) DEFAULT 'user',
    sso_provider VARCHAR(50),
    sso_id VARCHAR(255),
    created_at TIMESTAMPTZ DEFAULT NOW(),
    last_login TIMESTAMPTZ
);
-- Index for email lookups during login
-- NOTE(review): the UNIQUE constraint on users.email already creates an index
-- in PostgreSQL, so this one is redundant; harmless, but a future migration could drop it.
CREATE INDEX idx_users_email ON users(email);
-- Index for SSO lookups
CREATE INDEX idx_users_sso ON users(sso_provider, sso_id);

-- Commands table
-- Commands sent to agents and their results
CREATE TABLE commands (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
    command_type VARCHAR(50) NOT NULL,
    command_text TEXT NOT NULL,
    status VARCHAR(20) DEFAULT 'pending',
    exit_code INTEGER,
    stdout TEXT,
    stderr TEXT,
    created_at TIMESTAMPTZ DEFAULT NOW(),
    started_at TIMESTAMPTZ,
    completed_at TIMESTAMPTZ,
    -- Keep command history even if the issuing user is deleted
    created_by UUID REFERENCES users(id) ON DELETE SET NULL
);
-- Index for finding pending commands for an agent
CREATE INDEX idx_commands_agent_status ON commands(agent_id, status);
-- Index for command history queries
CREATE INDEX idx_commands_created ON commands(created_at DESC);

-- Watchdog events table
-- Events from agent watchdog monitoring
CREATE TABLE watchdog_events (
    id BIGSERIAL PRIMARY KEY,
    agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
    timestamp TIMESTAMPTZ DEFAULT NOW(),
    service_name VARCHAR(255) NOT NULL,
    event_type VARCHAR(50) NOT NULL,
    details TEXT
);
-- Index for querying events by agent and time
CREATE INDEX idx_watchdog_agent_time ON watchdog_events(agent_id, timestamp DESC);
-- Index for finding recent events
CREATE INDEX idx_watchdog_timestamp ON watchdog_events(timestamp DESC);

-- Function to update updated_at timestamp on any row UPDATE
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
    NEW.updated_at = NOW();
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
-- Trigger for agents table
CREATE TRIGGER update_agents_updated_at
    BEFORE UPDATE ON agents
    FOR EACH ROW
    EXECUTE FUNCTION update_updated_at_column();

View File

@@ -0,0 +1,100 @@
-- GuruRMM Clients and Sites Schema
-- Adds multi-tenant support with clients, sites, and site-based agent registration
-- Depends on migration 1 (agents, users tables and update_updated_at_column()).

-- Clients table (organizations/companies)
CREATE TABLE clients (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    name VARCHAR(255) NOT NULL,
    code VARCHAR(50) UNIQUE, -- Optional short code like "ACME"
    notes TEXT,
    is_active BOOLEAN DEFAULT TRUE,
    created_at TIMESTAMPTZ DEFAULT NOW(),
    updated_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE INDEX idx_clients_name ON clients(name);
CREATE INDEX idx_clients_code ON clients(code);
-- Trigger for clients updated_at
CREATE TRIGGER update_clients_updated_at
    BEFORE UPDATE ON clients
    FOR EACH ROW
    EXECUTE FUNCTION update_updated_at_column();

-- Sites table (locations under a client)
CREATE TABLE sites (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    client_id UUID NOT NULL REFERENCES clients(id) ON DELETE CASCADE,
    name VARCHAR(255) NOT NULL,
    -- Site code: human-friendly, used for agent registration (e.g., "BLUE-TIGER-4829")
    site_code VARCHAR(50) UNIQUE NOT NULL,
    -- API key hash for this site (all agents at site share this key)
    api_key_hash VARCHAR(255) NOT NULL,
    address TEXT,
    notes TEXT,
    is_active BOOLEAN DEFAULT TRUE,
    created_at TIMESTAMPTZ DEFAULT NOW(),
    updated_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE INDEX idx_sites_client ON sites(client_id);
CREATE INDEX idx_sites_code ON sites(site_code);
CREATE INDEX idx_sites_api_key ON sites(api_key_hash);
-- Trigger for sites updated_at
CREATE TRIGGER update_sites_updated_at
    BEFORE UPDATE ON sites
    FOR EACH ROW
    EXECUTE FUNCTION update_updated_at_column();

-- Add new columns to agents table
-- device_id: unique hardware-derived identifier for the machine
ALTER TABLE agents ADD COLUMN device_id VARCHAR(255);
-- site_id: which site this agent belongs to (nullable for legacy agents)
ALTER TABLE agents ADD COLUMN site_id UUID REFERENCES sites(id) ON DELETE SET NULL;
-- Make api_key_hash nullable (new agents will use site's api_key)
ALTER TABLE agents ALTER COLUMN api_key_hash DROP NOT NULL;
-- Partial unique index: one agent row per (site, device); legacy rows with
-- NULL site_id/device_id are exempt from the uniqueness rule.
CREATE UNIQUE INDEX idx_agents_site_device ON agents(site_id, device_id) WHERE site_id IS NOT NULL AND device_id IS NOT NULL;
-- Index for site lookups
CREATE INDEX idx_agents_site ON agents(site_id);

-- Registration tokens table (optional: for secure site code distribution)
CREATE TABLE registration_tokens (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    site_id UUID NOT NULL REFERENCES sites(id) ON DELETE CASCADE,
    token_hash VARCHAR(255) NOT NULL,
    description VARCHAR(255),
    uses_remaining INTEGER, -- NULL = unlimited
    expires_at TIMESTAMPTZ,
    created_at TIMESTAMPTZ DEFAULT NOW(),
    created_by UUID REFERENCES users(id) ON DELETE SET NULL
);
CREATE INDEX idx_reg_tokens_site ON registration_tokens(site_id);
CREATE INDEX idx_reg_tokens_hash ON registration_tokens(token_hash);

-- Function to generate a random site code (WORD-WORD-####)
-- This is just a helper; actual generation should be in application code
-- for better word lists
-- NOTE(review): 30x30x9000 combinations and no uniqueness check here —
-- callers must rely on the sites.site_code UNIQUE constraint and retry on conflict.
CREATE OR REPLACE FUNCTION generate_site_code() RETURNS VARCHAR(50) AS $$
DECLARE
    words TEXT[] := ARRAY['ALPHA', 'BETA', 'GAMMA', 'DELTA', 'ECHO', 'FOXTROT',
                          'BLUE', 'GREEN', 'RED', 'GOLD', 'SILVER', 'IRON',
                          'HAWK', 'EAGLE', 'TIGER', 'LION', 'WOLF', 'BEAR',
                          'NORTH', 'SOUTH', 'EAST', 'WEST', 'PEAK', 'VALLEY',
                          'RIVER', 'OCEAN', 'STORM', 'CLOUD', 'STAR', 'MOON'];
    word1 TEXT;
    word2 TEXT;
    num INTEGER;
BEGIN
    -- PostgreSQL arrays are 1-indexed; random() < 1.0 keeps the index in range.
    word1 := words[1 + floor(random() * array_length(words, 1))::int];
    word2 := words[1 + floor(random() * array_length(words, 1))::int];
    num := 1000 + floor(random() * 9000)::int;
    RETURN word1 || '-' || word2 || '-' || num::text;
END;
$$ LANGUAGE plpgsql;

View File

@@ -0,0 +1,34 @@
-- Extended metrics and agent state
-- Adds columns for uptime, user info, IPs, and network state storage
-- Idempotent (IF NOT EXISTS throughout) so it is safe to re-run.

-- Add extended columns to metrics table
ALTER TABLE metrics ADD COLUMN IF NOT EXISTS uptime_seconds BIGINT;
ALTER TABLE metrics ADD COLUMN IF NOT EXISTS boot_time BIGINT;
ALTER TABLE metrics ADD COLUMN IF NOT EXISTS logged_in_user VARCHAR(255);
ALTER TABLE metrics ADD COLUMN IF NOT EXISTS user_idle_seconds BIGINT;
ALTER TABLE metrics ADD COLUMN IF NOT EXISTS public_ip VARCHAR(45); -- Supports IPv6

-- Agent state table for current/latest agent information
-- This stores the latest snapshot of extended agent info (not time-series);
-- the PRIMARY KEY on agent_id guarantees one row per agent.
CREATE TABLE IF NOT EXISTS agent_state (
    agent_id UUID PRIMARY KEY REFERENCES agents(id) ON DELETE CASCADE,
    -- Network state
    network_interfaces JSONB,
    -- Short digest used to detect network-config changes without diffing the JSONB
    network_state_hash VARCHAR(32),
    -- Latest extended metrics (cached for quick access)
    uptime_seconds BIGINT,
    boot_time BIGINT,
    logged_in_user VARCHAR(255),
    user_idle_seconds BIGINT,
    public_ip VARCHAR(45),
    -- Timestamps
    network_updated_at TIMESTAMPTZ,
    metrics_updated_at TIMESTAMPTZ DEFAULT NOW()
);
-- Index for finding agents by public IP (useful for diagnostics)
CREATE INDEX IF NOT EXISTS idx_agent_state_public_ip ON agent_state(public_ip);

-- Add memory_total_bytes and disk_total_bytes to metrics for completeness
ALTER TABLE metrics ADD COLUMN IF NOT EXISTS memory_total_bytes BIGINT;
ALTER TABLE metrics ADD COLUMN IF NOT EXISTS disk_total_bytes BIGINT;

View File

@@ -0,0 +1,30 @@
-- Agent update tracking
-- Tracks update commands sent to agents and their results
CREATE TABLE agent_updates (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    agent_id UUID NOT NULL REFERENCES agents(id) ON DELETE CASCADE,
    -- Correlation id shared with the agent-side update process; UNIQUE so
    -- progress reports can be matched back to exactly one row.
    update_id UUID NOT NULL UNIQUE,
    old_version VARCHAR(50) NOT NULL,
    target_version VARCHAR(50) NOT NULL,
    status VARCHAR(20) DEFAULT 'pending', -- pending, downloading, installing, completed, failed, rolled_back
    download_url TEXT,
    -- Hex-encoded SHA-256 of the update binary (64 hex chars)
    checksum_sha256 VARCHAR(64),
    started_at TIMESTAMPTZ DEFAULT NOW(),
    completed_at TIMESTAMPTZ,
    error_message TEXT,
    created_at TIMESTAMPTZ DEFAULT NOW()
);
-- Index for finding updates by agent
CREATE INDEX idx_agent_updates_agent ON agent_updates(agent_id);
-- Index for finding updates by status (for monitoring)
CREATE INDEX idx_agent_updates_status ON agent_updates(status);
-- Partial index for finding pending/in-progress updates (for timeout detection)
CREATE INDEX idx_agent_updates_pending ON agent_updates(agent_id, status)
    WHERE status IN ('pending', 'downloading', 'installing');
-- Add architecture column to agents table for proper binary matching
ALTER TABLE agents ADD COLUMN IF NOT EXISTS architecture VARCHAR(20) DEFAULT 'amd64';

View File

@@ -0,0 +1,327 @@
//! Agent management API endpoints
use axum::{
extract::{Path, State},
http::StatusCode,
Json,
};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::db::{self, AgentResponse, AgentStats};
use crate::ws::{generate_api_key, hash_api_key};
use crate::AppState;
/// Response for agent registration.
#[derive(Debug, Serialize)]
pub struct RegisterAgentResponse {
    /// UUID assigned to the newly created agent row.
    pub agent_id: Uuid,
    /// Plaintext API key — returned exactly once; only its hash is persisted.
    pub api_key: String,
    /// Human-readable confirmation, including the save-your-key warning.
    pub message: String,
}
/// Request to register a new agent.
#[derive(Debug, Deserialize)]
pub struct RegisterAgentRequest {
    pub hostname: String,
    /// OS family identifier (free-form string; format not validated here — TODO confirm expected values).
    pub os_type: String,
    pub os_version: Option<String>,
}
/// Register a new agent and mint its API key.
///
/// The plaintext key appears only in this response; the database stores
/// nothing but its hash, so it cannot be recovered later.
pub async fn register_agent(
    State(state): State<AppState>,
    Json(req): Json<RegisterAgentRequest>,
) -> Result<Json<RegisterAgentResponse>, (StatusCode, String)> {
    // Mint a fresh key and derive the hash that will actually be stored.
    let plain_key = generate_api_key(&state.config.auth.api_key_prefix);
    let hashed_key = hash_api_key(&plain_key);

    let new_agent = db::CreateAgent {
        hostname: req.hostname,
        api_key_hash: hashed_key,
        os_type: req.os_type,
        os_version: req.os_version,
        agent_version: None,
    };

    match db::create_agent(&state.db, new_agent).await {
        Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())),
        Ok(agent) => Ok(Json(RegisterAgentResponse {
            agent_id: agent.id,
            // Hand back the plaintext key — this is the caller's only chance to see it.
            api_key: plain_key,
            message: "Agent registered successfully. Save the API key - it will not be shown again."
                .to_string(),
        })),
    }
}
/// Return every registered agent as an API-facing `AgentResponse`.
pub async fn list_agents(
    State(state): State<AppState>,
) -> Result<Json<Vec<AgentResponse>>, (StatusCode, String)> {
    let rows = db::get_all_agents(&state.db)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    let list: Vec<AgentResponse> = rows.into_iter().map(Into::into).collect();
    Ok(Json(list))
}
/// Look up a single agent by id; 404 when it does not exist.
pub async fn get_agent(
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
) -> Result<Json<AgentResponse>, (StatusCode, String)> {
    match db::get_agent_by_id(&state.db, id).await {
        Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())),
        Ok(None) => Err((StatusCode::NOT_FOUND, "Agent not found".to_string())),
        Ok(Some(agent)) => Ok(Json(agent.into())),
    }
}
/// Delete an agent, dropping its live connection first if it has one.
///
/// Returns 204 on success, 404 when the id is unknown.
pub async fn delete_agent(
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
) -> Result<StatusCode, (StatusCode, String)> {
    // Fix: the original took a read lock to call is_connected() and then a
    // separate write lock to remove — a check-then-act race where another task
    // could change the registry between the two locks. Do both under one
    // write lock instead.
    {
        let mut agents = state.agents.write().await;
        if agents.is_connected(&id) {
            // TODO: send the agent an explicit disconnect message before removal.
            agents.remove(&id);
        }
    }
    let deleted = db::delete_agent(&state.db, id)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    if deleted {
        Ok(StatusCode::NO_CONTENT)
    } else {
        Err((StatusCode::NOT_FOUND, "Agent not found".to_string()))
    }
}
/// Get agent statistics
pub async fn get_stats(
State(state): State<AppState>,
) -> Result<Json<AgentStats>, (StatusCode, String)> {
let stats = db::get_agent_stats(&state.db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
Ok(Json(stats))
}
/// Request to move an agent to a different site.
#[derive(Debug, Deserialize)]
pub struct MoveAgentRequest {
    /// Target site, or `None` to unassign the agent from any site.
    pub site_id: Option<Uuid>, // None to unassign from site
}
/// Reassign an agent to a site (or to no site at all).
///
/// Validates the target site before touching the agent so a typo'd site id
/// fails with 404 rather than silently unassigning.
pub async fn move_agent(
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
    Json(req): Json<MoveAgentRequest>,
) -> Result<Json<AgentResponse>, (StatusCode, String)> {
    // Only validate when a concrete target site was supplied.
    if let Some(site_id) = req.site_id {
        let found = db::get_site_by_id(&state.db, site_id)
            .await
            .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
        if found.is_none() {
            return Err((StatusCode::NOT_FOUND, "Site not found".to_string()));
        }
    }

    let moved = db::move_agent_to_site(&state.db, id, req.site_id)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    match moved {
        Some(agent) => Ok(Json(agent.into())),
        None => Err((StatusCode::NOT_FOUND, "Agent not found".to_string())),
    }
}
/// List all agents joined with their site/client details.
pub async fn list_agents_with_details(
    State(state): State<AppState>,
) -> Result<Json<Vec<db::AgentWithDetails>>, (StatusCode, String)> {
    match db::get_all_agents_with_details(&state.db).await {
        Ok(rows) => Ok(Json(rows)),
        Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())),
    }
}

/// List agents that are not assigned to any site.
pub async fn list_unassigned_agents(
    State(state): State<AppState>,
) -> Result<Json<Vec<AgentResponse>>, (StatusCode, String)> {
    let rows = db::get_unassigned_agents(&state.db)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    let list: Vec<AgentResponse> = rows.into_iter().map(Into::into).collect();
    Ok(Json(list))
}

/// Fetch the latest extended-state snapshot (network interfaces, uptime, etc.)
/// for one agent; 404 when no snapshot has been stored yet.
pub async fn get_agent_state(
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
) -> Result<Json<db::AgentState>, (StatusCode, String)> {
    match db::get_agent_state(&state.db, id).await {
        Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())),
        Ok(None) => Err((StatusCode::NOT_FOUND, "Agent state not found".to_string())),
        Ok(Some(snapshot)) => Ok(Json(snapshot)),
    }
}
// ============================================================================
// Legacy Agent Endpoints (PowerShell agent for 2008 R2)
// ============================================================================
/// Request to register a legacy agent with site code.
#[derive(Debug, Deserialize)]
pub struct RegisterLegacyRequest {
    /// Human-friendly site code (e.g. "BLUE-TIGER-4829") identifying the site.
    pub site_code: String,
    pub hostname: String,
    pub os_type: String,
    pub os_version: Option<String>,
    pub agent_version: Option<String>,
    /// Accepted for forward compatibility; not currently read by register_legacy().
    pub agent_type: Option<String>,
}
/// Response for legacy agent registration.
#[derive(Debug, Serialize)]
pub struct RegisterLegacyResponse {
    pub agent_id: Uuid,
    /// Plaintext per-agent API key — shown once; only the hash is stored.
    pub api_key: String,
    /// Resolved site/client names, echoed back for operator confirmation.
    pub site_name: String,
    pub client_name: String,
    pub message: String,
}
/// Register a legacy agent using site code.
///
/// Resolves the site code to a site and client, mints a per-agent API key,
/// creates the agent row, and assigns it to the site.
///
/// NOTE(review): the create + assign below are two separate DB calls with no
/// transaction — if move_agent_to_site fails, an unassigned agent row is left
/// behind. Consider wrapping both in a transaction if the db layer supports it.
pub async fn register_legacy(
    State(state): State<AppState>,
    Json(req): Json<RegisterLegacyRequest>,
) -> Result<Json<RegisterLegacyResponse>, (StatusCode, String)> {
    // Look up site by code; unknown codes surface as 404 with the code echoed back.
    let site = db::get_site_by_code(&state.db, &req.site_code)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
        .ok_or((StatusCode::NOT_FOUND, format!("Site code '{}' not found", req.site_code)))?;
    // Get client info (needed only to echo the client name in the response).
    let client = db::get_client_by_id(&state.db, site.client_id)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
        .ok_or((StatusCode::NOT_FOUND, "Client not found".to_string()))?;
    // Generate API key for this agent; only its hash is persisted.
    let api_key = generate_api_key(&state.config.auth.api_key_prefix);
    let api_key_hash = hash_api_key(&api_key);
    // Create the agent (req.agent_type is intentionally not persisted here).
    let create = db::CreateAgent {
        hostname: req.hostname,
        api_key_hash,
        os_type: req.os_type,
        os_version: req.os_version,
        agent_version: req.agent_version,
    };
    let agent = db::create_agent(&state.db, create)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    // Assign agent to site (second, non-transactional step — see NOTE above).
    db::move_agent_to_site(&state.db, agent.id, Some(site.id))
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    tracing::info!(
        "Legacy agent registered: {} ({}) -> Site: {} ({})",
        agent.hostname,
        agent.id,
        site.name,
        req.site_code
    );
    Ok(Json(RegisterLegacyResponse {
        agent_id: agent.id,
        api_key,
        site_name: site.name,
        client_name: client.name,
        message: "Agent registered successfully".to_string(),
    }))
}
/// Heartbeat request from legacy agent.
#[derive(Debug, Deserialize)]
pub struct HeartbeatRequest {
    /// Agent UUID issued at registration.
    pub agent_id: Uuid,
    /// Client-reported timestamp; not parsed or validated server-side — TODO confirm expected format.
    pub timestamp: String,
    /// Free-form metrics payload; currently accepted but not persisted (see heartbeat()'s TODO).
    pub system_info: serde_json::Value,
}
/// Heartbeat response with pending commands.
#[derive(Debug, Serialize)]
pub struct HeartbeatResponse {
    pub success: bool,
    /// Commands queued for this agent; always empty until command dispatch is implemented.
    pub pending_commands: Vec<PendingCommand>,
}
/// One queued command delivered to a legacy agent in a heartbeat response.
#[derive(Debug, Serialize)]
pub struct PendingCommand {
    pub id: Uuid,
    // Serialized as "type" because `type` is a Rust keyword.
    #[serde(rename = "type")]
    pub cmd_type: String,
    pub script: String,
}
/// Record a legacy agent's heartbeat and hand back any queued commands.
///
/// Currently only bumps last_seen; the metrics payload is accepted but dropped.
pub async fn heartbeat(
    State(state): State<AppState>,
    Json(req): Json<HeartbeatRequest>,
) -> Result<Json<HeartbeatResponse>, (StatusCode, String)> {
    // Liveness bookkeeping: mark the agent as recently seen.
    if let Err(e) = db::update_agent_last_seen(&state.db, req.agent_id).await {
        return Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string()));
    }
    // TODO: Store system_info metrics, get pending commands
    // For now, return empty pending commands
    let reply = HeartbeatResponse {
        success: true,
        pending_commands: Vec::new(),
    };
    Ok(Json(reply))
}
/// Command result from legacy agent.
#[derive(Debug, Deserialize)]
pub struct CommandResultRequest {
    /// Id of the command this result belongs to (matches commands.id).
    pub command_id: Uuid,
    /// Client-reported start/end timestamps; not parsed server-side — TODO confirm format.
    pub started_at: String,
    pub completed_at: String,
    pub success: bool,
    pub output: String,
    pub error: Option<String>,
}
/// Receive command execution result.
///
/// Stub: acknowledges the POST but discards the payload until persistence
/// is implemented (see TODO). Callers should not assume the result is stored.
pub async fn command_result(
    State(_state): State<AppState>,
    Json(_req): Json<CommandResultRequest>,
) -> Result<StatusCode, (StatusCode, String)> {
    // TODO: Store command result in database
    Ok(StatusCode::OK)
}

View File

@@ -0,0 +1,152 @@
//! Authentication API endpoints
use axum::{
extract::State,
http::StatusCode,
Json,
};
use serde::{Deserialize, Serialize};
use crate::auth::{create_jwt, verify_password, hash_password, Claims, AuthUser};
use crate::db::{self, UserResponse};
use crate::AppState;
/// Login request.
#[derive(Debug, Deserialize)]
pub struct LoginRequest {
    pub email: String,
    /// Plaintext password; compared against the stored Argon2 hash via verify_password.
    pub password: String,
}
/// Login response.
#[derive(Debug, Serialize)]
pub struct LoginResponse {
    /// Signed JWT for subsequent authenticated requests.
    pub token: String,
    pub user: UserResponse,
}
/// Register request (for initial admin setup).
#[derive(Debug, Deserialize)]
pub struct RegisterRequest {
    pub email: String,
    pub password: String,
    pub name: Option<String>,
}
/// Register response.
#[derive(Debug, Serialize)]
pub struct RegisterResponse {
    /// JWT for the freshly created account, so the caller is logged in immediately.
    pub token: String,
    pub user: UserResponse,
    pub message: String,
}
/// Authenticate with email and password, returning a JWT on success.
///
/// Every failure mode (unknown email, SSO-only account, wrong password)
/// deliberately returns the same "Invalid credentials" message so responses
/// do not reveal which part was wrong.
pub async fn login(
    State(state): State<AppState>,
    Json(req): Json<LoginRequest>,
) -> Result<Json<LoginResponse>, (StatusCode, String)> {
    let unauthorized = || (StatusCode::UNAUTHORIZED, "Invalid credentials".to_string());

    let user = db::get_user_by_email(&state.db, &req.email)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
        .ok_or_else(unauthorized)?;

    // Accounts without a local hash (e.g. SSO-only) are treated like a bad password.
    let stored_hash = user.password_hash.as_ref().ok_or_else(unauthorized)?;
    if !verify_password(&req.password, stored_hash) {
        return Err(unauthorized());
    }

    // Best effort — a failed last-login write must not block the login itself.
    let _ = db::update_last_login(&state.db, user.id).await;

    let token = create_jwt(
        user.id,
        &user.role,
        &state.config.auth.jwt_secret,
        state.config.auth.jwt_expiry_hours,
    )
    .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;

    Ok(Json(LoginResponse {
        token,
        user: user.into(),
    }))
}
/// Register a new user (only works if no users exist - for initial admin setup).
///
/// The first account created is granted the admin role and is issued a JWT
/// immediately; once any user exists, this endpoint always returns 403.
///
/// NOTE(review): has_users + create_user is a check-then-act sequence — two
/// concurrent first-time registrations could in principle both pass the check;
/// the unique email constraint limits the damage but both could become admins
/// with distinct emails. Consider an atomic guard if this matters.
pub async fn register(
    State(state): State<AppState>,
    Json(req): Json<RegisterRequest>,
) -> Result<Json<RegisterResponse>, (StatusCode, String)> {
    // Check if any users exist
    let has_users = db::has_users(&state.db)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    // If users exist, only admins can create new users (would need auth check here)
    // For now, only allow registration if no users exist
    if has_users {
        return Err((
            StatusCode::FORBIDDEN,
            "Registration is disabled. Contact an administrator.".to_string(),
        ));
    }
    // Hash password (Argon2 via hash_password)
    let password_hash = hash_password(&req.password)
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    // Create user (first user is admin)
    let create = db::CreateUser {
        email: req.email,
        password_hash,
        name: req.name,
        role: Some("admin".to_string()),
    };
    let user = db::create_user(&state.db, create)
        .await
        .map_err(|e| {
            // NOTE(review): matching on the error string is fragile — a typed
            // unique-violation check from the db layer would be more robust.
            if e.to_string().contains("duplicate key") {
                (StatusCode::CONFLICT, "Email already registered".to_string())
            } else {
                (StatusCode::INTERNAL_SERVER_ERROR, e.to_string())
            }
        })?;
    // Generate JWT so the new admin is logged in straight away
    let token = create_jwt(
        user.id,
        &user.role,
        &state.config.auth.jwt_secret,
        state.config.auth.jwt_expiry_hours,
    )
    .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    Ok(Json(RegisterResponse {
        token,
        user: user.into(),
        message: "Admin account created successfully".to_string(),
    }))
}
/// Return the authenticated caller's own user record.
///
/// The `AuthUser` extractor has already validated the JWT; this just
/// re-reads the row so the response reflects current DB state.
pub async fn me(
    State(state): State<AppState>,
    auth: AuthUser,
) -> Result<Json<UserResponse>, (StatusCode, String)> {
    match db::get_user_by_id(&state.db, auth.user_id).await {
        Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())),
        Ok(None) => Err((StatusCode::NOT_FOUND, "User not found".to_string())),
        Ok(Some(user)) => Ok(Json(user.into())),
    }
}

View File

@@ -0,0 +1,168 @@
//! Client management API endpoints
use axum::{
extract::{Path, State},
http::StatusCode,
Json,
};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::auth::AuthUser;
use crate::db;
use crate::AppState;
/// Response for client operations.
#[derive(Debug, Serialize)]
pub struct ClientResponse {
    pub id: Uuid,
    pub name: String,
    /// Optional short code (e.g. "ACME").
    pub code: Option<String>,
    pub notes: Option<String>,
    pub is_active: bool,
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Number of sites under this client. Populated accurately by list/get;
    /// some write paths return a placeholder — check each handler.
    pub site_count: i64,
}
/// Request to create a new client.
#[derive(Debug, Deserialize)]
pub struct CreateClientRequest {
    pub name: String,
    pub code: Option<String>,
    pub notes: Option<String>,
}
/// Request to update a client; all fields optional (absent = unchanged).
#[derive(Debug, Deserialize)]
pub struct UpdateClientRequest {
    pub name: Option<String>,
    pub code: Option<String>,
    pub notes: Option<String>,
    pub is_active: Option<bool>,
}
/// List every client together with its site count (auth required).
pub async fn list_clients(
    _user: AuthUser,
    State(state): State<AppState>,
) -> Result<Json<Vec<ClientResponse>>, (StatusCode, String)> {
    let rows = db::get_all_clients_with_counts(&state.db)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    let out = rows
        .into_iter()
        .map(|row| ClientResponse {
            id: row.id,
            name: row.name,
            code: row.code,
            notes: row.notes,
            is_active: row.is_active,
            created_at: row.created_at,
            site_count: row.site_count,
        })
        .collect::<Vec<_>>();
    Ok(Json(out))
}
/// Create a new client
pub async fn create_client(
_user: AuthUser,
State(state): State<AppState>,
Json(req): Json<CreateClientRequest>,
) -> Result<Json<ClientResponse>, (StatusCode, String)> {
let create = db::CreateClient {
name: req.name,
code: req.code,
notes: req.notes,
};
let client = db::create_client(&state.db, create)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
Ok(Json(ClientResponse {
id: client.id,
name: client.name,
code: client.code,
notes: client.notes,
is_active: client.is_active,
created_at: client.created_at,
site_count: 0,
}))
}
/// Fetch one client by id, including its current site count (auth required).
pub async fn get_client(
    _user: AuthUser,
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
) -> Result<Json<ClientResponse>, (StatusCode, String)> {
    let client = match db::get_client_by_id(&state.db, id).await {
        Err(e) => return Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())),
        Ok(None) => return Err((StatusCode::NOT_FOUND, "Client not found".to_string())),
        Ok(Some(c)) => c,
    };
    // Count sites by fetching them for this client.
    let sites = db::get_sites_by_client(&state.db, id)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    Ok(Json(ClientResponse {
        id: client.id,
        name: client.name,
        code: client.code,
        notes: client.notes,
        is_active: client.is_active,
        created_at: client.created_at,
        site_count: sites.len() as i64,
    }))
}
/// Update a client's mutable fields (auth required).
///
/// Returns the updated record with its real site count; 404 when the id is unknown.
pub async fn update_client(
    _user: AuthUser,
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
    Json(req): Json<UpdateClientRequest>,
) -> Result<Json<ClientResponse>, (StatusCode, String)> {
    let update = db::UpdateClient {
        name: req.name,
        code: req.code,
        notes: req.notes,
        is_active: req.is_active,
    };
    let client = db::update_client(&state.db, id, update)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
        .ok_or((StatusCode::NOT_FOUND, "Client not found".to_string()))?;
    // Fix: report the actual site count instead of the previous hard-coded 0,
    // which disagreed with get_client and could mislead dashboard consumers.
    let sites = db::get_sites_by_client(&state.db, id)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    Ok(Json(ClientResponse {
        id: client.id,
        name: client.name,
        code: client.code,
        notes: client.notes,
        is_active: client.is_active,
        created_at: client.created_at,
        site_count: sites.len() as i64,
    }))
}
/// Delete a client (auth required); cascades to its sites at the DB level.
///
/// Returns 204 on success, 404 when the id is unknown.
pub async fn delete_client(
    _user: AuthUser,
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
) -> Result<StatusCode, (StatusCode, String)> {
    match db::delete_client(&state.db, id).await {
        Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())),
        Ok(true) => Ok(StatusCode::NO_CONTENT),
        Ok(false) => Err((StatusCode::NOT_FOUND, "Client not found".to_string())),
    }
}

View File

@@ -0,0 +1,127 @@
//! Commands API endpoints
use axum::{
extract::{Path, Query, State},
http::StatusCode,
Json,
};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::db::{self, Command};
use crate::ws::{CommandPayload, ServerMessage};
use crate::AppState;
/// Request to send a command to an agent
///
/// Deserialized from the JSON body of `POST /agents/:id/command`.
#[derive(Debug, Deserialize)]
pub struct SendCommandRequest {
    /// Command type (shell, powershell, python, script)
    pub command_type: String,
    /// Command text to execute
    pub command: String,
    /// Timeout in seconds (optional, default 300)
    pub timeout_seconds: Option<u64>,
    /// Run as elevated/admin (optional, default false)
    pub elevated: Option<bool>,
}
/// Response after sending a command
///
/// `status` is "running" when the command was delivered over the live
/// WebSocket, or "pending" when it was queued for an offline agent
/// (see `send_command`).
#[derive(Debug, Serialize)]
pub struct SendCommandResponse {
    /// ID of the persisted command record
    pub command_id: Uuid,
    /// "running" or "pending"
    pub status: String,
    /// Human-readable explanation of what happened
    pub message: String,
}
/// Query parameters for listing commands
#[derive(Debug, Deserialize)]
pub struct CommandsQuery {
    /// Maximum number of commands to return (default 50, capped at 500
    /// by `list_commands`)
    pub limit: Option<i64>,
}
/// Send a command to an agent
///
/// The command is always persisted first so it survives a server
/// restart. If the agent currently holds a WebSocket connection the
/// command is pushed immediately and marked "running"; otherwise it
/// stays "pending" and is delivered when the agent reconnects.
pub async fn send_command(
    State(state): State<AppState>,
    Path(agent_id): Path<Uuid>,
    Json(req): Json<SendCommandRequest>,
) -> Result<Json<SendCommandResponse>, (StatusCode, String)> {
    // Verify agent exists (404 otherwise). Only existence matters, so
    // the record is discarded — the original bound an unused `agent`
    // variable, which triggers a compiler warning.
    db::get_agent_by_id(&state.db, agent_id)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
        .ok_or((StatusCode::NOT_FOUND, "Agent not found".to_string()))?;
    // Create command record before attempting delivery, so the command
    // is durable regardless of the WebSocket outcome.
    let create = db::CreateCommand {
        agent_id,
        command_type: req.command_type.clone(),
        command_text: req.command.clone(),
        created_by: None, // TODO: Get from JWT
    };
    let command = db::create_command(&state.db, create)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    // Check if agent is connected (the read guard spans the send below).
    let agents = state.agents.read().await;
    if agents.is_connected(&agent_id) {
        // Send command via WebSocket
        let cmd_msg = ServerMessage::Command(CommandPayload {
            id: command.id,
            command_type: req.command_type,
            command: req.command,
            timeout_seconds: req.timeout_seconds,
            elevated: req.elevated.unwrap_or(false),
        });
        if agents.send_to(&agent_id, cmd_msg).await {
            // Mark as running; a failure here is deliberately ignored —
            // the command was already delivered to the agent.
            let _ = db::mark_command_running(&state.db, command.id).await;
            return Ok(Json(SendCommandResponse {
                command_id: command.id,
                status: "running".to_string(),
                message: "Command sent to agent".to_string(),
            }));
        }
    }
    // Agent not connected or send failed - command is queued
    Ok(Json(SendCommandResponse {
        command_id: command.id,
        status: "pending".to_string(),
        message: "Agent is offline. Command queued for execution when agent reconnects."
            .to_string(),
    }))
}
/// List recent commands
///
/// The `limit` query parameter defaults to 50 and is clamped to the
/// 1..=500 range. The original used `.min(500)` only, so a negative
/// limit was forwarded to the database, where a negative LIMIT is a
/// query error (surfacing as a 500).
pub async fn list_commands(
    State(state): State<AppState>,
    Query(query): Query<CommandsQuery>,
) -> Result<Json<Vec<Command>>, (StatusCode, String)> {
    let limit = query.limit.unwrap_or(50).clamp(1, 500);
    let commands = db::get_recent_commands(&state.db, limit)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    Ok(Json(commands))
}
/// Get a specific command by ID
///
/// Returns 404 when the command record does not exist.
pub async fn get_command(
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
) -> Result<Json<Command>, (StatusCode, String)> {
    let lookup = db::get_command_by_id(&state.db, id)
        .await
        .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
    match lookup {
        Some(command) => Ok(Json(command)),
        None => Err((StatusCode::NOT_FOUND, "Command not found".to_string())),
    }
}

View File

@@ -0,0 +1,65 @@
//! Metrics API endpoints
use axum::{
extract::{Path, Query, State},
http::StatusCode,
Json,
};
use chrono::{DateTime, Duration, Utc};
use serde::Deserialize;
use uuid::Uuid;
use crate::db::{self, Metrics, MetricsSummary};
use crate::AppState;
/// Query parameters for metrics
///
/// When both `start` and `end` are present, `get_agent_metrics` runs a
/// range query and `limit` is ignored; otherwise the latest `limit`
/// records are returned.
#[derive(Debug, Deserialize)]
pub struct MetricsQuery {
    /// Number of records to return (default: 100)
    pub limit: Option<i64>,
    /// Start time for range query
    pub start: Option<DateTime<Utc>>,
    /// End time for range query
    pub end: Option<DateTime<Utc>>,
}
/// Get metrics for a specific agent
///
/// If both `start` and `end` are supplied the full range is returned;
/// otherwise the most recent `limit` records (default 100, clamped to
/// 1..=1000). Supplying only one of `start`/`end` silently falls back
/// to the limit query.
pub async fn get_agent_metrics(
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
    Query(query): Query<MetricsQuery>,
) -> Result<Json<Vec<Metrics>>, (StatusCode, String)> {
    // Existence check only — the record itself is not needed, so it is
    // discarded (the original bound an unused `agent` variable, which
    // triggers a compiler warning).
    db::get_agent_by_id(&state.db, id)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
        .ok_or((StatusCode::NOT_FOUND, "Agent not found".to_string()))?;
    let metrics = if let (Some(start), Some(end)) = (query.start, query.end) {
        // Range query
        db::get_agent_metrics_range(&state.db, id, start, end)
            .await
            .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
    } else {
        // Simple limit query; clamp so a non-positive limit cannot
        // reach SQL (was `.min(1000)` only).
        let limit = query.limit.unwrap_or(100).clamp(1, 1000);
        db::get_agent_metrics(&state.db, id, limit)
            .await
            .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
    };
    Ok(Json(metrics))
}
/// Get summary metrics across all agents
pub async fn get_summary(
    State(state): State<AppState>,
) -> Result<Json<MetricsSummary>, (StatusCode, String)> {
    match db::get_metrics_summary(&state.db).await {
        Ok(summary) => Ok(Json(summary)),
        Err(err) => Err((StatusCode::INTERNAL_SERVER_ERROR, err.to_string())),
    }
}

View File

@@ -0,0 +1,65 @@
//! REST API routes
//!
//! Provides endpoints for:
//! - Agent management (registration, listing, deletion)
//! - Client and site management
//! - Metrics retrieval
//! - Command execution
//! - User authentication
pub mod agents;
pub mod auth;
pub mod clients;
pub mod commands;
pub mod metrics;
pub mod sites;
use axum::{
routing::{delete, get, post, put},
Router,
};
use crate::AppState;
/// Build all API routes
pub fn routes() -> Router<AppState> {
Router::new()
// Authentication
.route("/auth/login", post(auth::login))
.route("/auth/register", post(auth::register))
.route("/auth/me", get(auth::me))
// Clients
.route("/clients", get(clients::list_clients))
.route("/clients", post(clients::create_client))
.route("/clients/:id", get(clients::get_client))
.route("/clients/:id", put(clients::update_client))
.route("/clients/:id", delete(clients::delete_client))
.route("/clients/:id/sites", get(sites::list_sites_by_client))
// Sites
.route("/sites", get(sites::list_sites))
.route("/sites", post(sites::create_site))
.route("/sites/:id", get(sites::get_site))
.route("/sites/:id", put(sites::update_site))
.route("/sites/:id", delete(sites::delete_site))
.route("/sites/:id/regenerate-key", post(sites::regenerate_api_key))
// Agents
.route("/agents", get(agents::list_agents_with_details))
.route("/agents", post(agents::register_agent))
.route("/agents/stats", get(agents::get_stats))
.route("/agents/unassigned", get(agents::list_unassigned_agents))
.route("/agents/:id", get(agents::get_agent))
.route("/agents/:id", delete(agents::delete_agent))
.route("/agents/:id/move", post(agents::move_agent))
.route("/agents/:id/state", get(agents::get_agent_state))
// Metrics
.route("/agents/:id/metrics", get(metrics::get_agent_metrics))
.route("/metrics/summary", get(metrics::get_summary))
// Commands
.route("/agents/:id/command", post(commands::send_command))
.route("/commands", get(commands::list_commands))
.route("/commands/:id", get(commands::get_command))
// Legacy Agent (PowerShell for 2008 R2)
.route("/agent/register-legacy", post(agents::register_legacy))
.route("/agent/heartbeat", post(agents::heartbeat))
.route("/agent/command-result", post(agents::command_result))
}

View File

@@ -0,0 +1,280 @@
//! Site management API endpoints
use axum::{
extract::{Path, State},
http::StatusCode,
Json,
};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::auth::AuthUser;
use crate::db;
use crate::ws::{generate_api_key, hash_api_key};
use crate::AppState;
/// Response for site operations
///
/// `client_name` and `agent_count` are enrichment fields; some
/// endpoints leave them as None/0 when the extra lookup is not done.
#[derive(Debug, Serialize)]
pub struct SiteResponse {
    pub id: Uuid,
    pub client_id: Uuid,
    /// Owning client's display name, when the handler looked it up
    pub client_name: Option<String>,
    pub name: String,
    /// Server-generated unique code identifying the site
    pub site_code: String,
    pub address: Option<String>,
    pub notes: Option<String>,
    pub is_active: bool,
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Number of agents registered at this site (0 when not computed)
    pub agent_count: i64,
}
/// Response when creating a site (includes one-time API key)
#[derive(Debug, Serialize)]
pub struct CreateSiteResponse {
    pub site: SiteResponse,
    /// The API key for agents at this site (shown only once!)
    pub api_key: String,
    /// Reminder to the caller that the key cannot be retrieved again
    pub message: String,
}
/// Request to create a new site
#[derive(Debug, Deserialize)]
pub struct CreateSiteRequest {
    /// Client that will own the site (must already exist)
    pub client_id: Uuid,
    pub name: String,
    pub address: Option<String>,
    pub notes: Option<String>,
}
/// Request to update a site
///
/// All fields are optional; absent fields leave the current value
/// unchanged (partial update).
#[derive(Debug, Deserialize)]
pub struct UpdateSiteRequest {
    pub name: Option<String>,
    pub address: Option<String>,
    pub notes: Option<String>,
    pub is_active: Option<bool>,
}
/// List all sites
///
/// Each row already carries its client name and agent count from the
/// detailed query, so no additional lookups are needed here.
pub async fn list_sites(
    _user: AuthUser,
    State(state): State<AppState>,
) -> Result<Json<Vec<SiteResponse>>, (StatusCode, String)> {
    let rows = db::get_all_sites_with_details(&state.db)
        .await
        .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
    let mut responses = Vec::with_capacity(rows.len());
    for row in rows {
        responses.push(SiteResponse {
            id: row.id,
            client_id: row.client_id,
            client_name: Some(row.client_name),
            name: row.name,
            site_code: row.site_code,
            address: row.address,
            notes: row.notes,
            is_active: row.is_active,
            created_at: row.created_at,
            agent_count: row.agent_count,
        });
    }
    Ok(Json(responses))
}
/// List sites for a specific client
///
/// `client_name` and `agent_count` are not populated by this endpoint:
/// the caller already knows the client, and counts would require an
/// extra query per site.
pub async fn list_sites_by_client(
    _user: AuthUser,
    State(state): State<AppState>,
    Path(client_id): Path<Uuid>,
) -> Result<Json<Vec<SiteResponse>>, (StatusCode, String)> {
    let rows = db::get_sites_by_client(&state.db, client_id)
        .await
        .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
    let mut responses = Vec::with_capacity(rows.len());
    for row in rows {
        responses.push(SiteResponse {
            id: row.id,
            client_id: row.client_id,
            client_name: None,
            name: row.name,
            site_code: row.site_code,
            address: row.address,
            notes: row.notes,
            is_active: row.is_active,
            created_at: row.created_at,
            agent_count: 0, // Would need separate query
        });
    }
    Ok(Json(responses))
}
/// Create a new site
///
/// Verifies the owning client exists, generates a unique site code and
/// a fresh API key, and stores only the key's hash — the plaintext key
/// in the response is shown exactly once.
///
/// Improvement: the original cloned `site_code` into the insert struct
/// even though the local was never used again; the clone is removed.
pub async fn create_site(
    _user: AuthUser,
    State(state): State<AppState>,
    Json(req): Json<CreateSiteRequest>,
) -> Result<Json<CreateSiteResponse>, (StatusCode, String)> {
    // Verify client exists
    let client = db::get_client_by_id(&state.db, req.client_id)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
        .ok_or((StatusCode::NOT_FOUND, "Client not found".to_string()))?;
    // Generate unique site code and API key
    let site_code = db::generate_unique_site_code(&state.db)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    let api_key = generate_api_key(&state.config.auth.api_key_prefix);
    // Only the hash is persisted; the plaintext goes back to the caller.
    let api_key_hash = hash_api_key(&api_key);
    let create = db::CreateSiteInternal {
        client_id: req.client_id,
        name: req.name,
        site_code,
        api_key_hash,
        address: req.address,
        notes: req.notes,
    };
    let site = db::create_site(&state.db, create)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    Ok(Json(CreateSiteResponse {
        site: SiteResponse {
            id: site.id,
            client_id: site.client_id,
            client_name: Some(client.name),
            name: site.name,
            site_code: site.site_code,
            address: site.address,
            notes: site.notes,
            is_active: site.is_active,
            created_at: site.created_at,
            agent_count: 0,
        },
        api_key,
        message: "Site created. Save the API key - it will not be shown again.".to_string(),
    }))
}
/// Get a specific site
///
/// Enriches the row with the owning client's name and the number of
/// agents registered at the site.
pub async fn get_site(
    _user: AuthUser,
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
) -> Result<Json<SiteResponse>, (StatusCode, String)> {
    let maybe_site = db::get_site_by_id(&state.db, id)
        .await
        .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
    let site = match maybe_site {
        Some(s) => s,
        None => return Err((StatusCode::NOT_FOUND, "Site not found".to_string())),
    };
    // Look up the owning client (name may be absent if the client row
    // is gone) and count the site's agents.
    let owner = db::get_client_by_id(&state.db, site.client_id)
        .await
        .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
    let agent_count = db::get_agents_by_site(&state.db, id)
        .await
        .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?
        .len() as i64;
    Ok(Json(SiteResponse {
        id: site.id,
        client_id: site.client_id,
        client_name: owner.map(|c| c.name),
        name: site.name,
        site_code: site.site_code,
        address: site.address,
        notes: site.notes,
        is_active: site.is_active,
        created_at: site.created_at,
        agent_count,
    }))
}
/// Update a site
///
/// Applies the partial update, then enriches the response with the
/// client name and agent count so the payload matches `get_site` (the
/// original handler returned `client_name: None` / `agent_count: 0`).
pub async fn update_site(
    _user: AuthUser,
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
    Json(req): Json<UpdateSiteRequest>,
) -> Result<Json<SiteResponse>, (StatusCode, String)> {
    let update = db::UpdateSite {
        name: req.name,
        address: req.address,
        notes: req.notes,
        is_active: req.is_active,
    };
    let site = db::update_site(&state.db, id, update)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?
        .ok_or((StatusCode::NOT_FOUND, "Site not found".to_string()))?;
    // Enrich the response the same way get_site does.
    let client = db::get_client_by_id(&state.db, site.client_id)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    let agents = db::get_agents_by_site(&state.db, id)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
    Ok(Json(SiteResponse {
        id: site.id,
        client_id: site.client_id,
        client_name: client.map(|c| c.name),
        name: site.name,
        site_code: site.site_code,
        address: site.address,
        notes: site.notes,
        is_active: site.is_active,
        created_at: site.created_at,
        agent_count: agents.len() as i64,
    }))
}
/// Response returned by `regenerate_api_key` — carries the new
/// plaintext key, which is shown only once.
#[derive(Debug, Serialize)]
pub struct RegenerateApiKeyResponse {
    /// Fresh plaintext API key (only the hash is stored server-side)
    pub api_key: String,
    /// Reminder that the key cannot be retrieved again
    pub message: String,
}
/// Regenerate the API key for a site.
///
/// Invalidates the old key: only the new hash is stored, so agents
/// configured with the previous key must be reconfigured.
pub async fn regenerate_api_key(
    _user: AuthUser,
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
) -> Result<Json<RegenerateApiKeyResponse>, (StatusCode, String)> {
    // Verify the site exists before rotating its key.
    db::get_site_by_id(&state.db, id)
        .await
        .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?
        .ok_or((StatusCode::NOT_FOUND, "Site not found".to_string()))?;
    // Mint a fresh key and persist only its hash.
    let api_key = generate_api_key(&state.config.auth.api_key_prefix);
    let hashed = hash_api_key(&api_key);
    db::regenerate_site_api_key(&state.db, id, &hashed)
        .await
        .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
    Ok(Json(RegenerateApiKeyResponse {
        api_key,
        message: "API key regenerated. Save it - it will not be shown again. Existing agents will need to be reconfigured.".to_string(),
    }))
}
/// Delete a site
///
/// Responds 204 No Content on success, 404 when no row was removed.
pub async fn delete_site(
    _user: AuthUser,
    State(state): State<AppState>,
    Path(id): Path<Uuid>,
) -> Result<StatusCode, (StatusCode, String)> {
    match db::delete_site(&state.db, id).await {
        Ok(true) => Ok(StatusCode::NO_CONTENT),
        Ok(false) => Err((StatusCode::NOT_FOUND, "Site not found".to_string())),
        Err(err) => Err((StatusCode::INTERNAL_SERVER_ERROR, err.to_string())),
    }
}

View File

@@ -0,0 +1,161 @@
//! Authentication utilities
//!
//! Provides JWT token handling, password hashing, and Axum extractors.
use anyhow::Result;
use axum::{
async_trait,
extract::FromRequestParts,
http::{request::Parts, StatusCode, header::AUTHORIZATION},
};
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::AppState;
/// JWT claims structure
///
/// Serialized into the token payload by `create_jwt` and parsed back by
/// `verify_jwt`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
    /// Subject (user ID)
    pub sub: Uuid,
    /// User role
    pub role: String,
    /// Expiration time (Unix timestamp)
    pub exp: i64,
    /// Issued at (Unix timestamp)
    pub iat: i64,
}
/// Create a new JWT token
///
/// The token carries the user id as `sub`, the role, and `iat`/`exp`
/// timestamps; `exp` is `expiry_hours` from now. It is signed with the
/// jsonwebtoken default header using `secret`.
pub fn create_jwt(user_id: Uuid, role: &str, secret: &str, expiry_hours: u64) -> Result<String> {
    let issued_at = Utc::now();
    let expires_at = issued_at + Duration::hours(expiry_hours as i64);
    let claims = Claims {
        sub: user_id,
        role: role.to_owned(),
        exp: expires_at.timestamp(),
        iat: issued_at.timestamp(),
    };
    let key = EncodingKey::from_secret(secret.as_bytes());
    Ok(encode(&Header::default(), &claims, &key)?)
}
/// Verify and decode a JWT token
///
/// Fails when the signature does not match `secret` or validation
/// (per `Validation::default()`) rejects the token.
pub fn verify_jwt(token: &str, secret: &str) -> Result<Claims> {
    let key = DecodingKey::from_secret(secret.as_bytes());
    let data = decode::<Claims>(token, &key, &Validation::default())?;
    Ok(data.claims)
}
/// Hash a password using Argon2
///
/// A fresh random salt is generated per call, so hashing the same
/// password twice yields different strings; use `verify_password` to
/// check a candidate against a stored hash.
pub fn hash_password(password: &str) -> Result<String> {
    use argon2::{
        password_hash::{rand_core::OsRng, PasswordHasher, SaltString},
        Argon2,
    };
    let salt = SaltString::generate(&mut OsRng);
    let hashed = Argon2::default()
        .hash_password(password.as_bytes(), &salt)
        .map_err(|err| anyhow::anyhow!("Failed to hash password: {}", err))?;
    Ok(hashed.to_string())
}
/// Verify a password against a hash
///
/// Returns false (rather than erroring) when the stored hash string is
/// malformed or the password does not match.
pub fn verify_password(password: &str, hash: &str) -> bool {
    use argon2::{
        password_hash::{PasswordHash, PasswordVerifier},
        Argon2,
    };
    PasswordHash::new(hash)
        .map(|parsed| {
            Argon2::default()
                .verify_password(password.as_bytes(), &parsed)
                .is_ok()
        })
        .unwrap_or(false)
}
/// Authenticated user extractor for Axum
/// Extracts and validates JWT from Authorization header
#[derive(Debug, Clone)]
pub struct AuthUser {
    /// User ID taken from the token's `sub` claim
    pub user_id: Uuid,
    /// Role string taken from the token's `role` claim
    pub role: String,
}
#[async_trait]
impl FromRequestParts<AppState> for AuthUser {
    type Rejection = (StatusCode, String);

    /// Pull the `Authorization: Bearer <jwt>` header, validate the
    /// token against the configured secret, and expose the caller's id
    /// and role. Every failure maps to 401 with a specific message.
    async fn from_request_parts(parts: &mut Parts, state: &AppState) -> Result<Self, Self::Rejection> {
        let unauthorized = |msg: &str| (StatusCode::UNAUTHORIZED, msg.to_string());
        let header_value = parts
            .headers
            .get(AUTHORIZATION)
            .and_then(|value| value.to_str().ok())
            .ok_or_else(|| unauthorized("Missing authorization header"))?;
        let token = header_value
            .strip_prefix("Bearer ")
            .ok_or_else(|| unauthorized("Invalid authorization format"))?;
        let claims = verify_jwt(token, &state.config.auth.jwt_secret)
            .map_err(|_| unauthorized("Invalid or expired token"))?;
        Ok(Self {
            user_id: claims.sub,
            role: claims.role,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Round-trip: a freshly hashed password verifies, and a wrong
    // password does not.
    #[test]
    fn test_password_hash_and_verify() {
        let password = "test_password_123";
        let hash = hash_password(password).unwrap();
        assert!(verify_password(password, &hash));
        assert!(!verify_password("wrong_password", &hash));
    }
    // Round-trip: claims encoded into a token decode back intact when
    // the same secret is used.
    #[test]
    fn test_jwt_create_and_verify() {
        let user_id = Uuid::new_v4();
        let secret = "test_secret_key";
        let token = create_jwt(user_id, "admin", secret, 24).unwrap();
        let claims = verify_jwt(&token, secret).unwrap();
        assert_eq!(claims.sub, user_id);
        assert_eq!(claims.role, "admin");
    }
}

View File

@@ -0,0 +1,161 @@
//! Server configuration
//!
//! Configuration is loaded from environment variables.
//! Required variables:
//! - DATABASE_URL: PostgreSQL connection string
//! - JWT_SECRET: Secret key for JWT token signing
//!
//! Optional variables:
//! - SERVER_HOST: Host to bind to (default: 0.0.0.0)
//! - SERVER_PORT: Port to bind to (default: 3001)
//! - DB_MAX_CONNECTIONS: Max database connections (default: 10)
//! - DOWNLOADS_DIR: Directory containing agent binaries (default: /var/www/downloads)
//! - DOWNLOADS_BASE_URL: Base URL for downloads (default: http://localhost:3001/downloads)
//! - AUTO_UPDATE_ENABLED: Enable automatic agent updates (default: true)
//! - UPDATE_TIMEOUT_SECS: Timeout for agent updates (default: 180)
use std::path::PathBuf;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
/// Root server configuration
///
/// Built from environment variables by `ServerConfig::from_env`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Server binding configuration
    pub server: ServerBindConfig,
    /// Database configuration
    pub database: DatabaseConfig,
    /// Authentication configuration
    pub auth: AuthConfig,
    /// Agent updates configuration
    pub updates: UpdatesConfig,
}
/// Server binding configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerBindConfig {
    /// Host to bind to (SERVER_HOST, default 0.0.0.0)
    pub host: String,
    /// Port to bind to (SERVER_PORT, default 3001)
    pub port: u16,
}
/// Database configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// PostgreSQL connection URL (DATABASE_URL, required)
    pub url: String,
    /// Maximum number of connections in the pool (DB_MAX_CONNECTIONS, default 10)
    pub max_connections: u32,
}
/// Authentication configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// JWT signing secret (JWT_SECRET, required)
    pub jwt_secret: String,
    /// JWT token expiry in hours (JWT_EXPIRY_HOURS, default 24)
    pub jwt_expiry_hours: u64,
    /// API key prefix for agents (API_KEY_PREFIX, default "grmm_")
    pub api_key_prefix: String,
}
/// Agent updates configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdatesConfig {
    /// Directory containing agent binaries (DOWNLOADS_DIR)
    pub downloads_dir: PathBuf,
    /// Base URL for agent downloads (DOWNLOADS_BASE_URL)
    pub downloads_base_url: String,
    /// Enable automatic agent updates (AUTO_UPDATE_ENABLED, default true)
    pub auto_update_enabled: bool,
    /// Timeout for agent updates in seconds (UPDATE_TIMEOUT_SECS, default 180)
    pub update_timeout_secs: u64,
    /// Interval for scanning downloads directory in seconds (SCAN_INTERVAL_SECS, default 300)
    pub scan_interval_secs: u64,
}
impl ServerConfig {
    /// Parse an optional environment variable, falling back to
    /// `default` when the variable is unset or fails to parse
    /// (same semantics as the repeated `.ok().and_then(parse)` chains
    /// this helper replaces).
    fn env_or<T: std::str::FromStr>(key: &str, default: T) -> T {
        std::env::var(key)
            .ok()
            .and_then(|raw| raw.parse().ok())
            .unwrap_or(default)
    }

    /// Load configuration from environment variables
    ///
    /// # Errors
    /// Fails when the required `DATABASE_URL` or `JWT_SECRET` variables
    /// are missing; every other setting falls back to its default.
    pub fn from_env() -> Result<Self> {
        let database_url =
            std::env::var("DATABASE_URL").context("DATABASE_URL environment variable not set")?;
        let jwt_secret =
            std::env::var("JWT_SECRET").context("JWT_SECRET environment variable not set")?;
        let host = std::env::var("SERVER_HOST").unwrap_or_else(|_| "0.0.0.0".to_string());
        let port: u16 = Self::env_or("SERVER_PORT", 3001);
        let max_connections = Self::env_or("DB_MAX_CONNECTIONS", 10);
        let jwt_expiry_hours = Self::env_or("JWT_EXPIRY_HOURS", 24);
        let api_key_prefix =
            std::env::var("API_KEY_PREFIX").unwrap_or_else(|_| "grmm_".to_string());
        // Updates configuration
        let downloads_dir = std::env::var("DOWNLOADS_DIR")
            .map(PathBuf::from)
            .unwrap_or_else(|_| PathBuf::from("/var/www/downloads"));
        let downloads_base_url = std::env::var("DOWNLOADS_BASE_URL")
            .unwrap_or_else(|_| format!("http://{}:{}/downloads", host, port));
        // "true" (any case) or "1" enables; unset keeps the default of true.
        let auto_update_enabled = std::env::var("AUTO_UPDATE_ENABLED")
            .map(|v| v.to_lowercase() == "true" || v == "1")
            .unwrap_or(true);
        let update_timeout_secs = Self::env_or("UPDATE_TIMEOUT_SECS", 180);
        let scan_interval_secs = Self::env_or("SCAN_INTERVAL_SECS", 300); // 5 minutes
        Ok(Self {
            server: ServerBindConfig { host, port },
            database: DatabaseConfig {
                url: database_url,
                max_connections,
            },
            auth: AuthConfig {
                jwt_secret,
                jwt_expiry_hours,
                api_key_prefix,
            },
            updates: UpdatesConfig {
                downloads_dir,
                downloads_base_url,
                auto_update_enabled,
                update_timeout_secs,
                scan_interval_secs,
            },
        })
    }
}

View File

@@ -0,0 +1,375 @@
//! Agent database operations
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use uuid::Uuid;
/// Agent record from database
///
/// Maps 1:1 to a row in the `agents` table.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct Agent {
    pub id: Uuid,
    pub hostname: String,
    // Never serialized into API responses.
    #[serde(skip_serializing)]
    pub api_key_hash: Option<String>, // Nullable: new agents use site's api_key
    pub os_type: String,
    pub os_version: Option<String>,
    pub agent_version: Option<String>,
    /// None until the agent has connected or heartbeated at least once
    pub last_seen: Option<DateTime<Utc>>,
    /// 'online' / 'offline' — maintained by the status-update helpers below
    pub status: String,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
    // New fields for site-based registration
    pub device_id: Option<String>, // Hardware-derived unique ID
    pub site_id: Option<Uuid>, // Which site this agent belongs to
}
/// Agent without sensitive fields (for API responses)
///
/// `site_name`/`client_name` start as None in the `From<Agent>` impl
/// and are filled in by callers that have the join data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentResponse {
    pub id: Uuid,
    pub hostname: String,
    pub os_type: String,
    pub os_version: Option<String>,
    pub agent_version: Option<String>,
    pub last_seen: Option<DateTime<Utc>>,
    pub status: String,
    pub created_at: DateTime<Utc>,
    pub device_id: Option<String>,
    pub site_id: Option<Uuid>,
    pub site_name: Option<String>,
    pub client_name: Option<String>,
}
impl From<Agent> for AgentResponse {
    /// Strip the sensitive `api_key_hash` field; the site/client name
    /// enrichment fields start out as None.
    fn from(row: Agent) -> Self {
        Self {
            site_name: None,
            client_name: None,
            id: row.id,
            hostname: row.hostname,
            os_type: row.os_type,
            os_version: row.os_version,
            agent_version: row.agent_version,
            last_seen: row.last_seen,
            status: row.status,
            created_at: row.created_at,
            device_id: row.device_id,
            site_id: row.site_id,
        }
    }
}
/// Create a new agent registration
///
/// Legacy (pre-site) registration payload: the agent carries its own
/// API key hash instead of authenticating via a site key.
#[derive(Debug, Clone, Deserialize)]
pub struct CreateAgent {
    pub hostname: String,
    pub api_key_hash: String,
    pub os_type: String,
    pub os_version: Option<String>,
    pub agent_version: Option<String>,
}
/// Insert a new agent into the database
///
/// The row is created with status 'offline'; it flips to 'online' when
/// the agent connects (see `update_agent_info`). Bind order must match
/// the positional `$n` parameters in the SQL.
pub async fn create_agent(pool: &PgPool, agent: CreateAgent) -> Result<Agent, sqlx::Error> {
    sqlx::query_as::<_, Agent>(
        r#"
        INSERT INTO agents (hostname, api_key_hash, os_type, os_version, agent_version, status)
        VALUES ($1, $2, $3, $4, $5, 'offline')
        RETURNING *
        "#,
    )
    .bind(&agent.hostname)
    .bind(&agent.api_key_hash)
    .bind(&agent.os_type)
    .bind(&agent.os_version)
    .bind(&agent.agent_version)
    .fetch_one(pool)
    .await
}
/// Get an agent by ID
///
/// Returns `Ok(None)` when no row matches.
pub async fn get_agent_by_id(pool: &PgPool, id: Uuid) -> Result<Option<Agent>, sqlx::Error> {
    const SQL: &str = "SELECT * FROM agents WHERE id = $1";
    sqlx::query_as::<_, Agent>(SQL).bind(id).fetch_optional(pool).await
}
/// Look up an agent by the hash of its API key.
///
/// Returns `Ok(None)` when no agent uses the given key hash.
pub async fn get_agent_by_api_key_hash(
    pool: &PgPool,
    api_key_hash: &str,
) -> Result<Option<Agent>, sqlx::Error> {
    const SQL: &str = "SELECT * FROM agents WHERE api_key_hash = $1";
    sqlx::query_as::<_, Agent>(SQL)
        .bind(api_key_hash)
        .fetch_optional(pool)
        .await
}
/// Fetch every agent, ordered by hostname.
pub async fn get_all_agents(pool: &PgPool) -> Result<Vec<Agent>, sqlx::Error> {
    const SQL: &str = "SELECT * FROM agents ORDER BY hostname";
    sqlx::query_as::<_, Agent>(SQL).fetch_all(pool).await
}
/// Fetch agents whose status matches `status`, ordered by hostname.
pub async fn get_agents_by_status(
    pool: &PgPool,
    status: &str,
) -> Result<Vec<Agent>, sqlx::Error> {
    const SQL: &str = "SELECT * FROM agents WHERE status = $1 ORDER BY hostname";
    sqlx::query_as::<_, Agent>(SQL)
        .bind(status)
        .fetch_all(pool)
        .await
}
/// Set an agent's status and refresh its last_seen timestamp.
pub async fn update_agent_status(
    pool: &PgPool,
    id: Uuid,
    status: &str,
) -> Result<(), sqlx::Error> {
    const SQL: &str = "UPDATE agents SET status = $1, last_seen = NOW() WHERE id = $2";
    sqlx::query(SQL)
        .bind(status)
        .bind(id)
        .execute(pool)
        .await
        .map(|_| ())
}
/// Update agent info (on connection)
///
/// COALESCE keeps the existing column value whenever the corresponding
/// argument is None, so callers can update only the fields they know.
/// Also bumps last_seen and forces status to 'online'. Bind order must
/// match the positional `$n` parameters.
pub async fn update_agent_info(
    pool: &PgPool,
    id: Uuid,
    hostname: Option<&str>,
    os_version: Option<&str>,
    agent_version: Option<&str>,
) -> Result<(), sqlx::Error> {
    sqlx::query(
        r#"
        UPDATE agents
        SET hostname = COALESCE($1, hostname),
            os_version = COALESCE($2, os_version),
            agent_version = COALESCE($3, agent_version),
            last_seen = NOW(),
            status = 'online'
        WHERE id = $4
        "#,
    )
    .bind(hostname)
    .bind(os_version)
    .bind(agent_version)
    .bind(id)
    .execute(pool)
    .await?;
    Ok(())
}
/// Delete an agent
///
/// Returns true only when a row was actually removed.
pub async fn delete_agent(pool: &PgPool, id: Uuid) -> Result<bool, sqlx::Error> {
    const SQL: &str = "DELETE FROM agents WHERE id = $1";
    let outcome = sqlx::query(SQL).bind(id).execute(pool).await?;
    Ok(outcome.rows_affected() > 0)
}
/// Agent count statistics returned by `get_agent_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentStats {
    /// Total number of agent rows
    pub total: i64,
    /// Agents currently marked 'online'
    pub online: i64,
    /// total - online
    pub offline: i64,
}
/// Compute agent count statistics.
///
/// Uses a single aggregate query so `total` and `online` come from one
/// consistent snapshot — the original issued two separate COUNT
/// queries, which could disagree if rows changed between them.
pub async fn get_agent_stats(pool: &PgPool) -> Result<AgentStats, sqlx::Error> {
    let (total, online): (i64, i64) = sqlx::query_as(
        "SELECT COUNT(*), COUNT(*) FILTER (WHERE status = 'online') FROM agents",
    )
    .fetch_one(pool)
    .await?;
    Ok(AgentStats {
        total,
        online,
        offline: total - online,
    })
}
// ============================================================================
// Site-based agent operations
// ============================================================================
/// Data for creating an agent via site registration
///
/// Unlike `CreateAgent`, there is no per-agent API key: the agent
/// authenticates with its site's key and is identified by
/// (site_id, device_id).
#[derive(Debug, Clone)]
pub struct CreateAgentWithSite {
    pub site_id: Uuid,
    /// Hardware-derived unique ID reported by the agent
    pub device_id: String,
    pub hostname: String,
    pub os_type: String,
    pub os_version: Option<String>,
    pub agent_version: Option<String>,
}
/// Create a new agent under a site (site-based registration)
///
/// Like `create_agent`, the row starts 'offline' until the agent
/// connects. Bind order must match the positional `$n` parameters.
pub async fn create_agent_with_site(
    pool: &PgPool,
    agent: CreateAgentWithSite,
) -> Result<Agent, sqlx::Error> {
    sqlx::query_as::<_, Agent>(
        r#"
        INSERT INTO agents (site_id, device_id, hostname, os_type, os_version, agent_version, status)
        VALUES ($1, $2, $3, $4, $5, $6, 'offline')
        RETURNING *
        "#,
    )
    .bind(&agent.site_id)
    .bind(&agent.device_id)
    .bind(&agent.hostname)
    .bind(&agent.os_type)
    .bind(&agent.os_version)
    .bind(&agent.agent_version)
    .fetch_one(pool)
    .await
}
/// Look up an agent by (site_id, device_id) — used for site-based auth.
///
/// Returns `Ok(None)` when the pair is unknown.
pub async fn get_agent_by_site_and_device(
    pool: &PgPool,
    site_id: Uuid,
    device_id: &str,
) -> Result<Option<Agent>, sqlx::Error> {
    const SQL: &str = "SELECT * FROM agents WHERE site_id = $1 AND device_id = $2";
    sqlx::query_as::<_, Agent>(SQL)
        .bind(site_id)
        .bind(device_id)
        .fetch_optional(pool)
        .await
}
/// Fetch every agent belonging to a site, ordered by hostname.
pub async fn get_agents_by_site(pool: &PgPool, site_id: Uuid) -> Result<Vec<Agent>, sqlx::Error> {
    const SQL: &str = "SELECT * FROM agents WHERE site_id = $1 ORDER BY hostname";
    sqlx::query_as::<_, Agent>(SQL)
        .bind(site_id)
        .fetch_all(pool)
        .await
}
/// Agent with site and client details for API responses
///
/// Produced by the LEFT-JOIN query in `get_all_agents_with_details`;
/// site/client fields are None for unassigned agents.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
pub struct AgentWithDetails {
    pub id: Uuid,
    pub hostname: String,
    pub os_type: String,
    pub os_version: Option<String>,
    pub agent_version: Option<String>,
    pub last_seen: Option<DateTime<Utc>>,
    pub status: String,
    pub created_at: DateTime<Utc>,
    pub device_id: Option<String>,
    pub site_id: Option<Uuid>,
    pub site_name: Option<String>,
    pub client_id: Option<Uuid>,
    pub client_name: Option<String>,
}
/// Fetch every agent joined with its site and client names.
///
/// LEFT JOINs keep agents that have no site (or whose site has no
/// client); the missing columns come back as NULL/None.
pub async fn get_all_agents_with_details(pool: &PgPool) -> Result<Vec<AgentWithDetails>, sqlx::Error> {
    const SQL: &str = r#"
        SELECT
            a.id, a.hostname, a.os_type, a.os_version, a.agent_version,
            a.last_seen, a.status, a.created_at, a.device_id, a.site_id,
            s.name as site_name,
            c.id as client_id,
            c.name as client_name
        FROM agents a
        LEFT JOIN sites s ON a.site_id = s.id
        LEFT JOIN clients c ON s.client_id = c.id
        ORDER BY c.name, s.name, a.hostname
    "#;
    sqlx::query_as::<_, AgentWithDetails>(SQL).fetch_all(pool).await
}
/// Update agent info including device_id (on connection)
///
/// COALESCE keeps the existing column value whenever the corresponding
/// argument is None. Also bumps last_seen and forces status to
/// 'online'. Bind order must match the positional `$n` parameters.
pub async fn update_agent_info_full(
    pool: &PgPool,
    id: Uuid,
    hostname: Option<&str>,
    device_id: Option<&str>,
    os_version: Option<&str>,
    agent_version: Option<&str>,
) -> Result<(), sqlx::Error> {
    sqlx::query(
        r#"
        UPDATE agents
        SET hostname = COALESCE($1, hostname),
            device_id = COALESCE($2, device_id),
            os_version = COALESCE($3, os_version),
            agent_version = COALESCE($4, agent_version),
            last_seen = NOW(),
            status = 'online'
        WHERE id = $5
        "#,
    )
    .bind(hostname)
    .bind(device_id)
    .bind(os_version)
    .bind(agent_version)
    .bind(id)
    .execute(pool)
    .await?;
    Ok(())
}
/// Move an agent to a different site (pass None to detach it).
///
/// Returns the updated row, or `Ok(None)` when the agent does not
/// exist.
pub async fn move_agent_to_site(
    pool: &PgPool,
    agent_id: Uuid,
    site_id: Option<Uuid>,
) -> Result<Option<Agent>, sqlx::Error> {
    const SQL: &str = r#"
        UPDATE agents
        SET site_id = $1,
            updated_at = NOW()
        WHERE id = $2
        RETURNING *
    "#;
    sqlx::query_as::<_, Agent>(SQL)
        .bind(site_id)
        .bind(agent_id)
        .fetch_optional(pool)
        .await
}
/// Fetch agents not assigned to any site, ordered by hostname.
pub async fn get_unassigned_agents(pool: &PgPool) -> Result<Vec<Agent>, sqlx::Error> {
    const SQL: &str = "SELECT * FROM agents WHERE site_id IS NULL ORDER BY hostname";
    sqlx::query_as::<_, Agent>(SQL).fetch_all(pool).await
}
/// Heartbeat: refresh last_seen and force status to 'online'.
pub async fn update_agent_last_seen(
    pool: &PgPool,
    id: Uuid,
) -> Result<(), sqlx::Error> {
    const SQL: &str = "UPDATE agents SET last_seen = NOW(), status = 'online' WHERE id = $1";
    sqlx::query(SQL).bind(id).execute(pool).await.map(|_| ())
}

View File

@@ -0,0 +1,157 @@
//! Client (organization) database operations
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use uuid::Uuid;
/// Client record from database
///
/// Maps 1:1 to a row in the `clients` table.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct Client {
    pub id: Uuid,
    pub name: String,
    /// Optional short code for the client
    pub code: Option<String>,
    pub notes: Option<String>,
    pub is_active: bool,
    pub created_at: chrono::DateTime<chrono::Utc>,
    pub updated_at: chrono::DateTime<chrono::Utc>,
}
/// Client response for API
///
/// `site_count`/`agent_count` are None unless the caller computed the
/// aggregates (see `From<Client>` below and `ClientWithCounts`).
#[derive(Debug, Clone, Serialize)]
pub struct ClientResponse {
    pub id: Uuid,
    pub name: String,
    pub code: Option<String>,
    pub notes: Option<String>,
    pub is_active: bool,
    pub created_at: chrono::DateTime<chrono::Utc>,
    pub site_count: Option<i64>,
    pub agent_count: Option<i64>,
}
impl From<Client> for ClientResponse {
fn from(c: Client) -> Self {
ClientResponse {
id: c.id,
name: c.name,
code: c.code,
notes: c.notes,
is_active: c.is_active,
created_at: c.created_at,
site_count: None,
agent_count: None,
}
}
}
/// Data for creating a new client
///
/// Deserialized from the API request body; `id`/timestamps are generated by
/// the database on insert.
#[derive(Debug, Deserialize)]
pub struct CreateClient {
    pub name: String,
    pub code: Option<String>,
    pub notes: Option<String>,
}
/// Data for updating a client
///
/// All fields optional: `None` means "leave unchanged" (enforced by the
/// COALESCE in `update_client`).
#[derive(Debug, Deserialize)]
pub struct UpdateClient {
    pub name: Option<String>,
    pub code: Option<String>,
    pub notes: Option<String>,
    pub is_active: Option<bool>,
}
/// Create a new client
///
/// Inserts a row (DB supplies `id`, `is_active` default, timestamps) and
/// returns the full created record via `RETURNING *`.
pub async fn create_client(pool: &PgPool, client: CreateClient) -> Result<Client, sqlx::Error> {
    sqlx::query_as::<_, Client>(
        r#"
        INSERT INTO clients (name, code, notes)
        VALUES ($1, $2, $3)
        RETURNING *
        "#,
    )
    .bind(&client.name)
    .bind(&client.code)
    .bind(&client.notes)
    .fetch_one(pool)
    .await
}
/// Get a client by ID
///
/// Returns `Ok(None)` when no client with `id` exists.
pub async fn get_client_by_id(pool: &PgPool, id: Uuid) -> Result<Option<Client>, sqlx::Error> {
    sqlx::query_as::<_, Client>("SELECT * FROM clients WHERE id = $1")
        .bind(id)
        .fetch_optional(pool)
        .await
}
/// Get all clients
///
/// Includes inactive clients; ordered by name.
pub async fn get_all_clients(pool: &PgPool) -> Result<Vec<Client>, sqlx::Error> {
    sqlx::query_as::<_, Client>("SELECT * FROM clients ORDER BY name")
        .fetch_all(pool)
        .await
}
/// Get all clients with counts
///
/// Row shape for `get_all_clients_with_counts`: the full `clients` columns
/// plus aggregate `site_count` / `agent_count` computed by the query.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct ClientWithCounts {
    pub id: Uuid,
    pub name: String,
    pub code: Option<String>,
    pub notes: Option<String>,
    pub is_active: bool,
    pub created_at: chrono::DateTime<chrono::Utc>,
    pub updated_at: chrono::DateTime<chrono::Utc>,
    pub site_count: i64,
    pub agent_count: i64,
}
/// Fetch all clients together with their site and agent counts.
///
/// Uses correlated subqueries per client row; `agent_count` traverses
/// agents -> sites -> client. Fine at typical MSP client counts; revisit
/// with GROUP BY joins if the table grows large.
pub async fn get_all_clients_with_counts(pool: &PgPool) -> Result<Vec<ClientWithCounts>, sqlx::Error> {
    sqlx::query_as::<_, ClientWithCounts>(
        r#"
        SELECT
            c.*,
            COALESCE((SELECT COUNT(*) FROM sites WHERE client_id = c.id), 0) as site_count,
            COALESCE((SELECT COUNT(*) FROM agents a JOIN sites s ON a.site_id = s.id WHERE s.client_id = c.id), 0) as agent_count
        FROM clients c
        ORDER BY c.name
        "#,
    )
    .fetch_all(pool)
    .await
}
/// Update a client
///
/// `None` fields in `update` leave the existing column untouched (COALESCE),
/// which also means a field cannot be cleared back to NULL through this path.
/// Bumps `updated_at`, matching the convention used by the agent update
/// queries. Returns the updated row, or `Ok(None)` if `id` does not exist.
pub async fn update_client(
    pool: &PgPool,
    id: Uuid,
    update: UpdateClient,
) -> Result<Option<Client>, sqlx::Error> {
    sqlx::query_as::<_, Client>(
        r#"
        UPDATE clients
        SET name = COALESCE($1, name),
            code = COALESCE($2, code),
            notes = COALESCE($3, notes),
            is_active = COALESCE($4, is_active),
            updated_at = NOW()
        WHERE id = $5
        RETURNING *
        "#,
    )
    .bind(&update.name)
    .bind(&update.code)
    .bind(&update.notes)
    .bind(&update.is_active)
    .bind(id)
    .fetch_optional(pool)
    .await
}
/// Delete a client
///
/// Returns `true` if a row was removed. Behavior for dependent sites/agents
/// depends on the FK constraints in the migrations (cascade vs. restrict) —
/// not visible here.
pub async fn delete_client(pool: &PgPool, id: Uuid) -> Result<bool, sqlx::Error> {
    let result = sqlx::query("DELETE FROM clients WHERE id = $1")
        .bind(id)
        .execute(pool)
        .await?;
    Ok(result.rows_affected() > 0)
}

View File

@@ -0,0 +1,163 @@
//! Commands database operations
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use uuid::Uuid;
/// Command record from database
///
/// One remote command dispatched to an agent. `status` moves through
/// 'pending' -> 'running' -> 'completed'/'failed'; output fields are filled
/// in when the result arrives.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct Command {
    pub id: Uuid,
    pub agent_id: Uuid,
    pub command_type: String,
    pub command_text: String,
    pub status: String,
    pub exit_code: Option<i32>,
    pub stdout: Option<String>,
    pub stderr: Option<String>,
    pub created_at: DateTime<Utc>,
    pub started_at: Option<DateTime<Utc>>,
    pub completed_at: Option<DateTime<Utc>>,
    pub created_by: Option<Uuid>,
}
/// Create a new command
///
/// Input for `create_command`; `created_by` is the dashboard user that
/// issued it, when known.
#[derive(Debug, Clone, Deserialize)]
pub struct CreateCommand {
    pub agent_id: Uuid,
    pub command_type: String,
    pub command_text: String,
    pub created_by: Option<Uuid>,
}
/// Insert a new command
///
/// New commands always start in status 'pending'; returns the created row.
pub async fn create_command(pool: &PgPool, cmd: CreateCommand) -> Result<Command, sqlx::Error> {
    sqlx::query_as::<_, Command>(
        r#"
        INSERT INTO commands (agent_id, command_type, command_text, status, created_by)
        VALUES ($1, $2, $3, 'pending', $4)
        RETURNING *
        "#,
    )
    .bind(cmd.agent_id)
    .bind(&cmd.command_type)
    .bind(&cmd.command_text)
    .bind(cmd.created_by)
    .fetch_one(pool)
    .await
}
/// Get a command by ID
///
/// Returns `Ok(None)` when no command with `id` exists.
pub async fn get_command_by_id(pool: &PgPool, id: Uuid) -> Result<Option<Command>, sqlx::Error> {
    sqlx::query_as::<_, Command>("SELECT * FROM commands WHERE id = $1")
        .bind(id)
        .fetch_optional(pool)
        .await
}
/// Get pending commands for an agent
///
/// Oldest first, so commands are delivered in submission order.
pub async fn get_pending_commands(
    pool: &PgPool,
    agent_id: Uuid,
) -> Result<Vec<Command>, sqlx::Error> {
    sqlx::query_as::<_, Command>(
        r#"
        SELECT * FROM commands
        WHERE agent_id = $1 AND status = 'pending'
        ORDER BY created_at ASC
        "#,
    )
    .bind(agent_id)
    .fetch_all(pool)
    .await
}
/// Get command history for an agent
///
/// Newest first, capped at `limit` rows.
pub async fn get_agent_commands(
    pool: &PgPool,
    agent_id: Uuid,
    limit: i64,
) -> Result<Vec<Command>, sqlx::Error> {
    sqlx::query_as::<_, Command>(
        r#"
        SELECT * FROM commands
        WHERE agent_id = $1
        ORDER BY created_at DESC
        LIMIT $2
        "#,
    )
    .bind(agent_id)
    .bind(limit)
    .fetch_all(pool)
    .await
}
/// Get all recent commands
///
/// Across every agent, newest first, capped at `limit` rows.
pub async fn get_recent_commands(pool: &PgPool, limit: i64) -> Result<Vec<Command>, sqlx::Error> {
    sqlx::query_as::<_, Command>(
        r#"
        SELECT * FROM commands
        ORDER BY created_at DESC
        LIMIT $1
        "#,
    )
    .bind(limit)
    .fetch_all(pool)
    .await
}
/// Update command status to running
///
/// Records `started_at = NOW()`; no-op (zero rows) if `id` does not exist.
pub async fn mark_command_running(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error> {
    sqlx::query("UPDATE commands SET status = 'running', started_at = NOW() WHERE id = $1")
        .bind(id)
        .execute(pool)
        .await?;
    Ok(())
}
/// Update command result
///
/// Result payload reported by the agent after executing a command.
#[derive(Debug, Clone, Deserialize)]
pub struct CommandResult {
    pub exit_code: i32,
    pub stdout: String,
    pub stderr: String,
}
pub async fn update_command_result(
pool: &PgPool,
id: Uuid,
result: CommandResult,
) -> Result<(), sqlx::Error> {
let status = if result.exit_code == 0 {
"completed"
} else {
"failed"
};
sqlx::query(
r#"
UPDATE commands
SET status = $1, exit_code = $2, stdout = $3, stderr = $4, completed_at = NOW()
WHERE id = $5
"#,
)
.bind(status)
.bind(result.exit_code)
.bind(&result.stdout)
.bind(&result.stderr)
.bind(id)
.execute(pool)
.await?;
Ok(())
}
/// Delete a command
///
/// Returns `true` if a row was removed.
pub async fn delete_command(pool: &PgPool, id: Uuid) -> Result<bool, sqlx::Error> {
    let result = sqlx::query("DELETE FROM commands WHERE id = $1")
        .bind(id)
        .execute(pool)
        .await?;
    Ok(result.rows_affected() > 0)
}

View File

@@ -0,0 +1,284 @@
//! Metrics database operations
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use uuid::Uuid;
/// Metrics record from database
///
/// One time-series sample for an agent. All measurement columns are
/// nullable — agents may report a subset. The "extended" fields carry
/// host/session info snapshotted alongside the resource metrics.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct Metrics {
    pub id: i64,
    pub agent_id: Uuid,
    pub timestamp: DateTime<Utc>,
    pub cpu_percent: Option<f32>,
    pub memory_percent: Option<f32>,
    pub memory_used_bytes: Option<i64>,
    pub disk_percent: Option<f32>,
    pub disk_used_bytes: Option<i64>,
    pub network_rx_bytes: Option<i64>,
    pub network_tx_bytes: Option<i64>,
    // Extended metrics
    pub uptime_seconds: Option<i64>,
    pub boot_time: Option<i64>,
    pub logged_in_user: Option<String>,
    pub user_idle_seconds: Option<i64>,
    pub public_ip: Option<String>,
    pub memory_total_bytes: Option<i64>,
    pub disk_total_bytes: Option<i64>,
}
/// Create metrics data from agent
///
/// Same shape as `Metrics` minus the DB-generated `id`/`timestamp`.
#[derive(Debug, Clone, Deserialize)]
pub struct CreateMetrics {
    pub agent_id: Uuid,
    pub cpu_percent: Option<f32>,
    pub memory_percent: Option<f32>,
    pub memory_used_bytes: Option<i64>,
    pub disk_percent: Option<f32>,
    pub disk_used_bytes: Option<i64>,
    pub network_rx_bytes: Option<i64>,
    pub network_tx_bytes: Option<i64>,
    // Extended metrics
    pub uptime_seconds: Option<i64>,
    pub boot_time: Option<i64>,
    pub logged_in_user: Option<String>,
    pub user_idle_seconds: Option<i64>,
    pub public_ip: Option<String>,
    pub memory_total_bytes: Option<i64>,
    pub disk_total_bytes: Option<i64>,
}
/// Insert metrics into the database
///
/// Returns the generated row `id`. `timestamp` is left to the column
/// default (insertion time).
pub async fn insert_metrics(pool: &PgPool, metrics: CreateMetrics) -> Result<i64, sqlx::Error> {
    let result: (i64,) = sqlx::query_as(
        r#"
        INSERT INTO metrics (
            agent_id, cpu_percent, memory_percent, memory_used_bytes,
            disk_percent, disk_used_bytes, network_rx_bytes, network_tx_bytes,
            uptime_seconds, boot_time, logged_in_user, user_idle_seconds,
            public_ip, memory_total_bytes, disk_total_bytes
        )
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
        RETURNING id
        "#,
    )
    .bind(metrics.agent_id)
    .bind(metrics.cpu_percent)
    .bind(metrics.memory_percent)
    .bind(metrics.memory_used_bytes)
    .bind(metrics.disk_percent)
    .bind(metrics.disk_used_bytes)
    .bind(metrics.network_rx_bytes)
    .bind(metrics.network_tx_bytes)
    .bind(metrics.uptime_seconds)
    .bind(metrics.boot_time)
    .bind(&metrics.logged_in_user)
    .bind(metrics.user_idle_seconds)
    .bind(&metrics.public_ip)
    .bind(metrics.memory_total_bytes)
    .bind(metrics.disk_total_bytes)
    .fetch_one(pool)
    .await?;
    Ok(result.0)
}
/// Get recent metrics for an agent
///
/// Newest first, capped at `limit` samples.
pub async fn get_agent_metrics(
    pool: &PgPool,
    agent_id: Uuid,
    limit: i64,
) -> Result<Vec<Metrics>, sqlx::Error> {
    sqlx::query_as::<_, Metrics>(
        r#"
        SELECT * FROM metrics
        WHERE agent_id = $1
        ORDER BY timestamp DESC
        LIMIT $2
        "#,
    )
    .bind(agent_id)
    .bind(limit)
    .fetch_all(pool)
    .await
}
/// Get metrics for an agent within a time range
///
/// Bounds are inclusive on both ends; rows come back oldest first (ready
/// for charting).
pub async fn get_agent_metrics_range(
    pool: &PgPool,
    agent_id: Uuid,
    start: DateTime<Utc>,
    end: DateTime<Utc>,
) -> Result<Vec<Metrics>, sqlx::Error> {
    sqlx::query_as::<_, Metrics>(
        r#"
        SELECT * FROM metrics
        WHERE agent_id = $1 AND timestamp >= $2 AND timestamp <= $3
        ORDER BY timestamp ASC
        "#,
    )
    .bind(agent_id)
    .bind(start)
    .bind(end)
    .fetch_all(pool)
    .await
}
/// Get latest metrics for an agent
///
/// Returns `Ok(None)` when the agent has never reported metrics.
pub async fn get_latest_metrics(
    pool: &PgPool,
    agent_id: Uuid,
) -> Result<Option<Metrics>, sqlx::Error> {
    sqlx::query_as::<_, Metrics>(
        r#"
        SELECT * FROM metrics
        WHERE agent_id = $1
        ORDER BY timestamp DESC
        LIMIT 1
        "#,
    )
    .bind(agent_id)
    .fetch_optional(pool)
    .await
}
/// Delete old metrics (for cleanup jobs)
///
/// Removes every sample strictly older than `older_than`; returns the
/// number of rows deleted.
pub async fn delete_old_metrics(
    pool: &PgPool,
    older_than: DateTime<Utc>,
) -> Result<u64, sqlx::Error> {
    let result = sqlx::query("DELETE FROM metrics WHERE timestamp < $1")
        .bind(older_than)
        .execute(pool)
        .await?;
    Ok(result.rows_affected())
}
/// Summary statistics for dashboard
///
/// Fleet-wide aggregates over a recent window (see `get_metrics_summary`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsSummary {
    pub avg_cpu: f32,
    pub avg_memory: f32,
    pub avg_disk: f32,
    pub total_network_rx: i64,
    pub total_network_tx: i64,
}
/// Get summary metrics across all agents (last hour)
///
/// Aggregates are NULL when no samples exist in the window, so every field
/// falls back to zero. Averages are computed as float8 in SQL and narrowed
/// to f32 for the response.
pub async fn get_metrics_summary(pool: &PgPool) -> Result<MetricsSummary, sqlx::Error> {
    type SummaryRow = (Option<f64>, Option<f64>, Option<f64>, Option<i64>, Option<i64>);
    let (cpu, mem, disk, rx, tx): SummaryRow = sqlx::query_as(
        r#"
        SELECT
            AVG(cpu_percent)::float8,
            AVG(memory_percent)::float8,
            AVG(disk_percent)::float8,
            SUM(network_rx_bytes),
            SUM(network_tx_bytes)
        FROM metrics
        WHERE timestamp > NOW() - INTERVAL '1 hour'
        "#,
    )
    .fetch_one(pool)
    .await?;
    Ok(MetricsSummary {
        avg_cpu: cpu.unwrap_or_default() as f32,
        avg_memory: mem.unwrap_or_default() as f32,
        avg_disk: disk.unwrap_or_default() as f32,
        total_network_rx: rx.unwrap_or_default(),
        total_network_tx: tx.unwrap_or_default(),
    })
}
// ============================================================================
// Agent State (network interfaces, extended info snapshot)
// ============================================================================
/// Agent state record from database
///
/// One row per agent (keyed by `agent_id`) holding the latest snapshot
/// rather than a time series. Network info and metrics info are updated by
/// separate upserts, each with its own `*_updated_at` stamp.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct AgentState {
    pub agent_id: Uuid,
    // JSON blob as reported by the agent; schema defined by the agent side.
    pub network_interfaces: Option<serde_json::Value>,
    // Hash of the interface set, used to detect changes cheaply.
    pub network_state_hash: Option<String>,
    pub uptime_seconds: Option<i64>,
    pub boot_time: Option<i64>,
    pub logged_in_user: Option<String>,
    pub user_idle_seconds: Option<i64>,
    pub public_ip: Option<String>,
    pub network_updated_at: Option<DateTime<Utc>>,
    pub metrics_updated_at: Option<DateTime<Utc>>,
}
/// Update or insert agent state (upsert)
///
/// Writes the metrics-side snapshot fields; the network-side fields are
/// owned by `update_agent_network_state` and are left untouched on conflict.
/// Note: `None` inputs overwrite existing values with NULL (EXCLUDED takes
/// precedence), so callers should pass complete snapshots.
pub async fn upsert_agent_state(
    pool: &PgPool,
    agent_id: Uuid,
    uptime_seconds: Option<i64>,
    boot_time: Option<i64>,
    logged_in_user: Option<&str>,
    user_idle_seconds: Option<i64>,
    public_ip: Option<&str>,
) -> Result<(), sqlx::Error> {
    sqlx::query(
        r#"
        INSERT INTO agent_state (
            agent_id, uptime_seconds, boot_time, logged_in_user,
            user_idle_seconds, public_ip, metrics_updated_at
        )
        VALUES ($1, $2, $3, $4, $5, $6, NOW())
        ON CONFLICT (agent_id) DO UPDATE SET
            uptime_seconds = EXCLUDED.uptime_seconds,
            boot_time = EXCLUDED.boot_time,
            logged_in_user = EXCLUDED.logged_in_user,
            user_idle_seconds = EXCLUDED.user_idle_seconds,
            public_ip = EXCLUDED.public_ip,
            metrics_updated_at = NOW()
        "#,
    )
    .bind(agent_id)
    .bind(uptime_seconds)
    .bind(boot_time)
    .bind(logged_in_user)
    .bind(user_idle_seconds)
    .bind(public_ip)
    .execute(pool)
    .await?;
    Ok(())
}
/// Update network state for an agent
///
/// Upserts only the network columns (interfaces JSON + change-detection
/// hash); metrics-side columns are left untouched on conflict.
pub async fn update_agent_network_state(
    pool: &PgPool,
    agent_id: Uuid,
    interfaces: &serde_json::Value,
    state_hash: &str,
) -> Result<(), sqlx::Error> {
    sqlx::query(
        r#"
        INSERT INTO agent_state (agent_id, network_interfaces, network_state_hash, network_updated_at)
        VALUES ($1, $2, $3, NOW())
        ON CONFLICT (agent_id) DO UPDATE SET
            network_interfaces = EXCLUDED.network_interfaces,
            network_state_hash = EXCLUDED.network_state_hash,
            network_updated_at = NOW()
        "#,
    )
    .bind(agent_id)
    .bind(interfaces)
    .bind(state_hash)
    .execute(pool)
    .await?;
    Ok(())
}
/// Get agent state by agent ID
///
/// Returns `Ok(None)` when the agent has no state snapshot yet.
pub async fn get_agent_state(pool: &PgPool, agent_id: Uuid) -> Result<Option<AgentState>, sqlx::Error> {
    sqlx::query_as::<_, AgentState>("SELECT * FROM agent_state WHERE agent_id = $1")
        .bind(agent_id)
        .fetch_optional(pool)
        .await
}

View File

@@ -0,0 +1,19 @@
//! Database models and queries
//!
//! Provides database access for clients, sites, agents, metrics, commands, users, and updates.
pub mod agents;
pub mod clients;
pub mod commands;
pub mod metrics;
pub mod sites;
pub mod updates;
pub mod users;
// Flat re-exports so callers can write `db::get_client_by_id` etc.
// NOTE(review): glob re-exports require that no two submodules export the
// same public name — a collision would be a compile error here.
pub use agents::*;
pub use clients::*;
pub use commands::*;
pub use metrics::*;
pub use sites::*;
pub use updates::*;
pub use users::*;

View File

@@ -0,0 +1,264 @@
//! Site database operations
use rand::Rng;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use uuid::Uuid;
/// Word lists for generating site codes
///
/// 24 adjectives x 24 nouns x 9000 numbers ~= 5.2M distinct codes; collisions
/// are handled by `generate_unique_site_code`.
const ADJECTIVES: &[&str] = &[
    "BLUE", "GREEN", "RED", "GOLD", "SILVER", "IRON", "COPPER", "BRONZE",
    "SWIFT", "BRIGHT", "DARK", "LIGHT", "BOLD", "CALM", "WILD", "WARM",
    "NORTH", "SOUTH", "EAST", "WEST", "UPPER", "LOWER", "INNER", "OUTER",
];
const NOUNS: &[&str] = &[
    "HAWK", "EAGLE", "TIGER", "LION", "WOLF", "BEAR", "FALCON", "PHOENIX",
    "PEAK", "VALLEY", "RIVER", "OCEAN", "STORM", "CLOUD", "STAR", "MOON",
    "TOWER", "BRIDGE", "GATE", "FORGE", "CASTLE", "HARBOR", "MEADOW", "GROVE",
];
/// Generate a human-friendly site code (e.g., "BLUE-TIGER-4829")
///
/// Format is ADJECTIVE-NOUN-NNNN with a 4-digit number. Uses an inclusive
/// range (1000..=9999); the previous exclusive upper bound meant 9999 could
/// never be generated.
pub fn generate_site_code() -> String {
    let mut rng = rand::thread_rng();
    let adj = ADJECTIVES[rng.gen_range(0..ADJECTIVES.len())];
    let noun = NOUNS[rng.gen_range(0..NOUNS.len())];
    let num: u16 = rng.gen_range(1000..=9999);
    format!("{}-{}-{}", adj, noun, num)
}
/// Site record from database
///
/// A physical/logical location under a client. `site_code` is the
/// human-friendly enrollment code; `api_key_hash` stores only the hash of
/// the site API key (the plaintext key is never persisted).
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct Site {
    pub id: Uuid,
    pub client_id: Uuid,
    pub name: String,
    pub site_code: String,
    pub api_key_hash: String,
    pub address: Option<String>,
    pub notes: Option<String>,
    pub is_active: bool,
    pub created_at: chrono::DateTime<chrono::Utc>,
    pub updated_at: chrono::DateTime<chrono::Utc>,
}
/// Site response for API
///
/// API shape: omits `api_key_hash`/`updated_at` and adds `client_name` and
/// `agent_count`, which are only populated by endpoints that join for them.
#[derive(Debug, Clone, Serialize)]
pub struct SiteResponse {
    pub id: Uuid,
    pub client_id: Uuid,
    pub client_name: Option<String>,
    pub name: String,
    pub site_code: String,
    pub address: Option<String>,
    pub notes: Option<String>,
    pub is_active: bool,
    pub created_at: chrono::DateTime<chrono::Utc>,
    pub agent_count: Option<i64>,
}
impl From<Site> for SiteResponse {
fn from(s: Site) -> Self {
SiteResponse {
id: s.id,
client_id: s.client_id,
client_name: None,
name: s.name,
site_code: s.site_code,
address: s.address,
notes: s.notes,
is_active: s.is_active,
created_at: s.created_at,
agent_count: None,
}
}
}
/// Data for creating a new site
///
/// Public API input; `site_code` and the API key are generated server-side
/// (see `CreateSiteInternal`).
#[derive(Debug, Deserialize)]
pub struct CreateSite {
    pub client_id: Uuid,
    pub name: String,
    pub address: Option<String>,
    pub notes: Option<String>,
}
/// Internal create with all fields
///
/// `CreateSite` plus the server-generated `site_code` and `api_key_hash`;
/// this is what `create_site` actually persists.
pub struct CreateSiteInternal {
    pub client_id: Uuid,
    pub name: String,
    pub site_code: String,
    pub api_key_hash: String,
    pub address: Option<String>,
    pub notes: Option<String>,
}
/// Data for updating a site
///
/// All fields optional: `None` means "leave unchanged" (COALESCE in
/// `update_site`).
#[derive(Debug, Deserialize)]
pub struct UpdateSite {
    pub name: Option<String>,
    pub address: Option<String>,
    pub notes: Option<String>,
    pub is_active: Option<bool>,
}
/// Create a new site
pub async fn create_site(pool: &PgPool, site: CreateSiteInternal) -> Result<Site, sqlx::Error> {
sqlx::query_as::<_, Site>(
r#"
INSERT INTO sites (client_id, name, site_code, api_key_hash, address, notes)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING *
"#,
)
.bind(&site.client_id)
.bind(&site.name)
.bind(&site.site_code)
.bind(&site.api_key_hash)
.bind(&site.address)
.bind(&site.notes)
.fetch_one(pool)
.await
}
/// Get a site by ID
///
/// Returns `Ok(None)` when no site with `id` exists (active or not).
pub async fn get_site_by_id(pool: &PgPool, id: Uuid) -> Result<Option<Site>, sqlx::Error> {
    sqlx::query_as::<_, Site>("SELECT * FROM sites WHERE id = $1")
        .bind(id)
        .fetch_optional(pool)
        .await
}
/// Get a site by site code
///
/// Case-insensitive on input (codes are stored uppercase, so the lookup
/// uppercases first). Only matches active sites — used for agent enrollment.
pub async fn get_site_by_code(pool: &PgPool, site_code: &str) -> Result<Option<Site>, sqlx::Error> {
    sqlx::query_as::<_, Site>("SELECT * FROM sites WHERE site_code = $1 AND is_active = true")
        .bind(site_code.to_uppercase())
        .fetch_optional(pool)
        .await
}
/// Get a site by API key hash
///
/// Authentication lookup: the caller hashes the presented key and matches
/// on the stored hash. Only active sites authenticate.
pub async fn get_site_by_api_key_hash(pool: &PgPool, api_key_hash: &str) -> Result<Option<Site>, sqlx::Error> {
    sqlx::query_as::<_, Site>("SELECT * FROM sites WHERE api_key_hash = $1 AND is_active = true")
        .bind(api_key_hash)
        .fetch_optional(pool)
        .await
}
/// Get all sites for a client
///
/// Includes inactive sites; ordered by name.
pub async fn get_sites_by_client(pool: &PgPool, client_id: Uuid) -> Result<Vec<Site>, sqlx::Error> {
    sqlx::query_as::<_, Site>("SELECT * FROM sites WHERE client_id = $1 ORDER BY name")
        .bind(client_id)
        .fetch_all(pool)
        .await
}
/// Get all sites
///
/// Across all clients, including inactive; ordered by name.
pub async fn get_all_sites(pool: &PgPool) -> Result<Vec<Site>, sqlx::Error> {
    sqlx::query_as::<_, Site>("SELECT * FROM sites ORDER BY name")
        .fetch_all(pool)
        .await
}
/// Site with client name and agent count
///
/// Row shape for `get_all_sites_with_details`: all `sites` columns plus the
/// joined `client_name` and computed `agent_count`.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct SiteWithDetails {
    pub id: Uuid,
    pub client_id: Uuid,
    pub name: String,
    pub site_code: String,
    pub api_key_hash: String,
    pub address: Option<String>,
    pub notes: Option<String>,
    pub is_active: bool,
    pub created_at: chrono::DateTime<chrono::Utc>,
    pub updated_at: chrono::DateTime<chrono::Utc>,
    pub client_name: String,
    pub agent_count: i64,
}
/// Fetch all sites with their owning client's name and an agent count.
///
/// Inner join means a site whose client row is missing would be silently
/// excluded (shouldn't happen under FK constraints). Ordered by client name
/// then site name for grouped display.
pub async fn get_all_sites_with_details(pool: &PgPool) -> Result<Vec<SiteWithDetails>, sqlx::Error> {
    sqlx::query_as::<_, SiteWithDetails>(
        r#"
        SELECT
            s.*,
            c.name as client_name,
            COALESCE((SELECT COUNT(*) FROM agents WHERE site_id = s.id), 0) as agent_count
        FROM sites s
        JOIN clients c ON s.client_id = c.id
        ORDER BY c.name, s.name
        "#,
    )
    .fetch_all(pool)
    .await
}
/// Update a site
///
/// `None` fields in `update` leave the existing column untouched (COALESCE),
/// so a field cannot be cleared back to NULL through this path. Bumps
/// `updated_at`, matching the convention used by the agent update queries.
/// Returns the updated row, or `Ok(None)` if `id` does not exist.
pub async fn update_site(
    pool: &PgPool,
    id: Uuid,
    update: UpdateSite,
) -> Result<Option<Site>, sqlx::Error> {
    sqlx::query_as::<_, Site>(
        r#"
        UPDATE sites
        SET name = COALESCE($1, name),
            address = COALESCE($2, address),
            notes = COALESCE($3, notes),
            is_active = COALESCE($4, is_active),
            updated_at = NOW()
        WHERE id = $5
        RETURNING *
        "#,
    )
    .bind(&update.name)
    .bind(&update.address)
    .bind(&update.notes)
    .bind(&update.is_active)
    .bind(id)
    .fetch_optional(pool)
    .await
}
/// Regenerate API key for a site
///
/// Stores the new key hash; returns `true` if the site existed. Agents
/// holding the old key will fail auth on their next connection.
pub async fn regenerate_site_api_key(
    pool: &PgPool,
    id: Uuid,
    new_api_key_hash: &str,
) -> Result<bool, sqlx::Error> {
    let result = sqlx::query("UPDATE sites SET api_key_hash = $1 WHERE id = $2")
        .bind(new_api_key_hash)
        .bind(id)
        .execute(pool)
        .await?;
    Ok(result.rows_affected() > 0)
}
/// Delete a site
///
/// Returns `true` if a row was removed. Effect on the site's agents depends
/// on the FK constraints in the migrations (not visible here).
pub async fn delete_site(pool: &PgPool, id: Uuid) -> Result<bool, sqlx::Error> {
    let result = sqlx::query("DELETE FROM sites WHERE id = $1")
        .bind(id)
        .execute(pool)
        .await?;
    Ok(result.rows_affected() > 0)
}
/// Check if a site code is unique
///
/// Compares against the uppercased code (codes are stored uppercase).
/// NOTE(review): check-then-insert is racy without a UNIQUE constraint on
/// `site_code` — presumably the migration adds one; verify.
pub async fn is_site_code_unique(pool: &PgPool, site_code: &str) -> Result<bool, sqlx::Error> {
    let result: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM sites WHERE site_code = $1")
        .bind(site_code.to_uppercase())
        .fetch_one(pool)
        .await?;
    Ok(result.0 == 0)
}
/// Generate a unique site code (tries up to 10 times)
///
/// Draws random codes until one is unused (up to 10 attempts). The fallback
/// appends a 3-digit suffix; the range is inclusive (100..=999) so 999 is
/// reachable — the previous exclusive bound skipped it.
/// NOTE(review): the fallback code is not re-checked for uniqueness;
/// collision odds are tiny but nonzero.
pub async fn generate_unique_site_code(pool: &PgPool) -> Result<String, sqlx::Error> {
    for _ in 0..10 {
        let code = generate_site_code();
        if is_site_code_unique(pool, &code).await? {
            return Ok(code);
        }
    }
    // Fallback: add random suffix
    Ok(format!(
        "{}-{}",
        generate_site_code(),
        rand::thread_rng().gen_range(100..=999)
    ))
}

View File

@@ -0,0 +1,217 @@
//! Database operations for agent updates
use anyhow::Result;
use sqlx::{PgPool, Row};
use uuid::Uuid;
/// Agent update record
///
/// One row per attempted agent self-update. `status` moves through
/// 'pending' -> 'downloading' -> 'installing' -> 'completed'/'failed'
/// (the in-flight set queried by `get_pending_update`/`get_stale_updates`).
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct AgentUpdateRecord {
    pub id: Uuid,
    pub agent_id: Uuid,
    pub update_id: Uuid,
    pub old_version: String,
    pub target_version: String,
    pub download_url: Option<String>,
    pub checksum_sha256: Option<String>,
    pub status: Option<String>,
    pub started_at: Option<chrono::DateTime<chrono::Utc>>,
    pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
    pub error_message: Option<String>,
}
/// Create a new agent update record
///
/// Inserts in status 'pending'. `started_at` is not set here — it relies on
/// the column default, if any (see `get_stale_updates` caveat).
pub async fn create_agent_update(
    pool: &PgPool,
    agent_id: Uuid,
    update_id: Uuid,
    old_version: &str,
    target_version: &str,
    download_url: &str,
    checksum_sha256: &str,
) -> Result<()> {
    sqlx::query(
        r#"
        INSERT INTO agent_updates (agent_id, update_id, old_version, target_version, download_url, checksum_sha256, status)
        VALUES ($1, $2, $3, $4, $5, $6, 'pending')
        "#,
    )
    .bind(agent_id)
    .bind(update_id)
    .bind(old_version)
    .bind(target_version)
    .bind(download_url)
    .bind(checksum_sha256)
    .execute(pool)
    .await?;
    Ok(())
}
/// Mark an agent update as completed
///
/// Marks the `agent_updates` row completed and, when `new_version` is given,
/// bumps the owning agent's `agent_version`. Both statements run in one
/// transaction so the update record and the agent row cannot diverge if the
/// second statement fails mid-way.
pub async fn complete_agent_update(
    pool: &PgPool,
    update_id: Uuid,
    new_version: Option<&str>,
) -> Result<()> {
    let mut tx = pool.begin().await?;
    sqlx::query(
        r#"
        UPDATE agent_updates
        SET status = 'completed', completed_at = NOW()
        WHERE update_id = $1
        "#,
    )
    .bind(update_id)
    .execute(&mut *tx)
    .await?;
    // If new_version provided, update the agent's version
    if let Some(version) = new_version {
        sqlx::query(
            r#"
            UPDATE agents
            SET agent_version = $2, updated_at = NOW()
            WHERE id = (SELECT agent_id FROM agent_updates WHERE update_id = $1)
            "#,
        )
        .bind(update_id)
        .bind(version)
        .execute(&mut *tx)
        .await?;
    }
    tx.commit().await?;
    Ok(())
}
/// Mark an agent update as failed
///
/// Sets status 'failed', stamps `completed_at`, and records the optional
/// `error_message` (NULL when none was reported).
pub async fn fail_agent_update(
    pool: &PgPool,
    update_id: Uuid,
    error_message: Option<&str>,
) -> Result<()> {
    sqlx::query(
        r#"
        UPDATE agent_updates
        SET status = 'failed', completed_at = NOW(), error_message = $2
        WHERE update_id = $1
        "#,
    )
    .bind(update_id)
    .bind(error_message)
    .execute(pool)
    .await?;
    Ok(())
}
/// Update the status of an agent update (for progress tracking)
///
/// Sets `status` as-is (e.g. 'downloading', 'installing'); no validation of
/// the transition is performed here.
pub async fn update_agent_update_status(
    pool: &PgPool,
    update_id: Uuid,
    status: &str,
) -> Result<()> {
    sqlx::query(
        r#"
        UPDATE agent_updates
        SET status = $2
        WHERE update_id = $1
        "#,
    )
    .bind(update_id)
    .bind(status)
    .execute(pool)
    .await?;
    Ok(())
}
/// Get pending update for an agent (if any)
///
/// Returns the most recently started in-flight update ('pending',
/// 'downloading', or 'installing'), or `Ok(None)`.
/// NOTE(review): ORDER BY started_at DESC — rows with NULL started_at sort
/// first in Postgres (NULLS LAST is the DESC default... it is NULLS FIRST
/// for DESC), so an unstarted 'pending' row wins over a started one; confirm
/// that is the intended priority.
pub async fn get_pending_update(
    pool: &PgPool,
    agent_id: Uuid,
) -> Result<Option<AgentUpdateRecord>> {
    let record = sqlx::query_as::<_, AgentUpdateRecord>(
        r#"
        SELECT id, agent_id, update_id, old_version, target_version,
               download_url, checksum_sha256, status, started_at, completed_at, error_message
        FROM agent_updates
        WHERE agent_id = $1 AND status IN ('pending', 'downloading', 'installing')
        ORDER BY started_at DESC
        LIMIT 1
        "#,
    )
    .bind(agent_id)
    .fetch_optional(pool)
    .await?;
    Ok(record)
}
/// Get stale updates (started but not completed within timeout)
///
/// Finds in-flight updates whose `started_at` is older than `timeout_secs`.
/// The i64 timeout is bound as float8 to satisfy `INTERVAL * double` in
/// Postgres.
/// NOTE(review): rows where `started_at` IS NULL (e.g. 'pending' rows that
/// were inserted without it — see `create_agent_update`) never satisfy the
/// comparison and so are never flagged stale; confirm the column has a
/// NOW() default in the migration.
pub async fn get_stale_updates(
    pool: &PgPool,
    timeout_secs: i64,
) -> Result<Vec<AgentUpdateRecord>> {
    let records = sqlx::query_as::<_, AgentUpdateRecord>(
        r#"
        SELECT id, agent_id, update_id, old_version, target_version,
               download_url, checksum_sha256, status, started_at, completed_at, error_message
        FROM agent_updates
        WHERE status IN ('pending', 'downloading', 'installing')
          AND started_at < NOW() - INTERVAL '1 second' * $1
        "#,
    )
    .bind(timeout_secs as f64)
    .fetch_all(pool)
    .await?;
    Ok(records)
}
/// Complete a pending update by matching agent reconnection
/// Called when an agent reconnects with a previous_version different from agent_version
///
/// Tries the precise `update_id` match first (when the agent echoed it back),
/// then falls back to matching the (old_version, target_version) pair.
/// Returns `true` if any in-flight record was closed out.
pub async fn complete_update_by_agent(
    pool: &PgPool,
    agent_id: Uuid,
    pending_update_id: Option<Uuid>,
    old_version: &str,
    new_version: &str,
) -> Result<bool> {
    // First try by update_id if provided
    if let Some(update_id) = pending_update_id {
        let result = sqlx::query(
            r#"
            UPDATE agent_updates
            SET status = 'completed', completed_at = NOW()
            WHERE update_id = $1 AND agent_id = $2 AND status IN ('pending', 'downloading', 'installing')
            "#,
        )
        .bind(update_id)
        .bind(agent_id)
        .execute(pool)
        .await?;
        if result.rows_affected() > 0 {
            return Ok(true);
        }
    }
    // Fall back to finding by old_version match
    let result = sqlx::query(
        r#"
        UPDATE agent_updates
        SET status = 'completed', completed_at = NOW()
        WHERE agent_id = $1
          AND old_version = $2
          AND target_version = $3
          AND status IN ('pending', 'downloading', 'installing')
        "#,
    )
    .bind(agent_id)
    .bind(old_version)
    .bind(new_version)
    .execute(pool)
    .await?;
    Ok(result.rows_affected() > 0)
}

View File

@@ -0,0 +1,177 @@
//! Users database operations
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use uuid::Uuid;
/// User record from database
///
/// Supports both local auth (`password_hash` set) and SSO
/// (`sso_provider`/`sso_id` set). `password_hash` is skipped during
/// serialization so it can never leak through an API response.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct User {
    pub id: Uuid,
    pub email: String,
    #[serde(skip_serializing)]
    pub password_hash: Option<String>,
    pub name: Option<String>,
    pub role: String,
    pub sso_provider: Option<String>,
    pub sso_id: Option<String>,
    pub created_at: DateTime<Utc>,
    pub last_login: Option<DateTime<Utc>>,
}
/// User response without sensitive fields
///
/// API shape: drops `password_hash` and `sso_id` entirely.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserResponse {
    pub id: Uuid,
    pub email: String,
    pub name: Option<String>,
    pub role: String,
    pub sso_provider: Option<String>,
    pub created_at: DateTime<Utc>,
    pub last_login: Option<DateTime<Utc>>,
}
impl From<User> for UserResponse {
fn from(user: User) -> Self {
Self {
id: user.id,
email: user.email,
name: user.name,
role: user.role,
sso_provider: user.sso_provider,
created_at: user.created_at,
last_login: user.last_login,
}
}
}
/// Create a new user (local auth)
///
/// `password_hash` must already be hashed by the caller; `role` defaults to
/// 'user' in SQL when `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct CreateUser {
    pub email: String,
    pub password_hash: String,
    pub name: Option<String>,
    pub role: Option<String>,
}
/// Create a new local user
///
/// Role falls back to 'user' via COALESCE. Fails with a database error on
/// duplicate email (unique constraint).
pub async fn create_user(pool: &PgPool, user: CreateUser) -> Result<User, sqlx::Error> {
    sqlx::query_as::<_, User>(
        r#"
        INSERT INTO users (email, password_hash, name, role)
        VALUES ($1, $2, $3, COALESCE($4, 'user'))
        RETURNING *
        "#,
    )
    .bind(&user.email)
    .bind(&user.password_hash)
    .bind(&user.name)
    .bind(&user.role)
    .fetch_one(pool)
    .await
}
/// Create or update SSO user
///
/// Identity as asserted by the external provider; keyed on email for the
/// upsert in `upsert_sso_user`.
#[derive(Debug, Clone, Deserialize)]
pub struct UpsertSsoUser {
    pub email: String,
    pub name: Option<String>,
    pub sso_provider: String,
    pub sso_id: String,
}
/// Create-or-update a user from an SSO login, keyed on email.
///
/// On conflict the provider/id are overwritten (a local account becomes
/// SSO-linked), the name is only replaced when the provider supplied one,
/// and `last_login` is stamped. New users get role 'user'; an existing
/// user's role is preserved.
pub async fn upsert_sso_user(pool: &PgPool, user: UpsertSsoUser) -> Result<User, sqlx::Error> {
    sqlx::query_as::<_, User>(
        r#"
        INSERT INTO users (email, name, sso_provider, sso_id, role)
        VALUES ($1, $2, $3, $4, 'user')
        ON CONFLICT (email)
        DO UPDATE SET
            name = COALESCE(EXCLUDED.name, users.name),
            sso_provider = EXCLUDED.sso_provider,
            sso_id = EXCLUDED.sso_id,
            last_login = NOW()
        RETURNING *
        "#,
    )
    .bind(&user.email)
    .bind(&user.name)
    .bind(&user.sso_provider)
    .bind(&user.sso_id)
    .fetch_one(pool)
    .await
}
/// Get user by ID
///
/// Returns `Ok(None)` when no user with `id` exists.
pub async fn get_user_by_id(pool: &PgPool, id: Uuid) -> Result<Option<User>, sqlx::Error> {
    sqlx::query_as::<_, User>("SELECT * FROM users WHERE id = $1")
        .bind(id)
        .fetch_optional(pool)
        .await
}
/// Get user by email
///
/// Exact (case-sensitive) match unless the column uses CITEXT —
/// NOTE(review): confirm against the migration; login lookups usually want
/// case-insensitive email.
pub async fn get_user_by_email(pool: &PgPool, email: &str) -> Result<Option<User>, sqlx::Error> {
    sqlx::query_as::<_, User>("SELECT * FROM users WHERE email = $1")
        .bind(email)
        .fetch_optional(pool)
        .await
}
/// Get all users
///
/// Ordered by email.
pub async fn get_all_users(pool: &PgPool) -> Result<Vec<User>, sqlx::Error> {
    sqlx::query_as::<_, User>("SELECT * FROM users ORDER BY email")
        .fetch_all(pool)
        .await
}
/// Update last login timestamp
///
/// No-op (zero rows) if `id` does not exist.
pub async fn update_last_login(pool: &PgPool, id: Uuid) -> Result<(), sqlx::Error> {
    sqlx::query("UPDATE users SET last_login = NOW() WHERE id = $1")
        .bind(id)
        .execute(pool)
        .await?;
    Ok(())
}
/// Update user role
///
/// `role` is stored as-is; validation of allowed role values happens at the
/// API layer (not visible here).
pub async fn update_user_role(pool: &PgPool, id: Uuid, role: &str) -> Result<(), sqlx::Error> {
    sqlx::query("UPDATE users SET role = $1 WHERE id = $2")
        .bind(role)
        .bind(id)
        .execute(pool)
        .await?;
    Ok(())
}
/// Update user password
///
/// `password_hash` must already be hashed by the caller; never pass a
/// plaintext password here.
pub async fn update_user_password(
    pool: &PgPool,
    id: Uuid,
    password_hash: &str,
) -> Result<(), sqlx::Error> {
    sqlx::query("UPDATE users SET password_hash = $1 WHERE id = $2")
        .bind(password_hash)
        .bind(id)
        .execute(pool)
        .await?;
    Ok(())
}
/// Delete a user
///
/// Returns `true` if a row was removed.
pub async fn delete_user(pool: &PgPool, id: Uuid) -> Result<bool, sqlx::Error> {
    let result = sqlx::query("DELETE FROM users WHERE id = $1")
        .bind(id)
        .execute(pool)
        .await?;
    Ok(result.rows_affected() > 0)
}
/// Check if any users exist (for initial setup)
///
/// Uses `EXISTS` so Postgres can stop at the first row instead of counting
/// the whole table.
pub async fn has_users(pool: &PgPool) -> Result<bool, sqlx::Error> {
    let result: (bool,) = sqlx::query_as("SELECT EXISTS(SELECT 1 FROM users)")
        .fetch_one(pool)
        .await?;
    Ok(result.0)
}

View File

@@ -0,0 +1,155 @@
//! GuruRMM Server - RMM Management Server
//!
//! Provides the backend API and WebSocket endpoint for GuruRMM agents.
//! Features:
//! - Agent registration and management
//! - Real-time WebSocket communication with agents
//! - Metrics storage and retrieval
//! - Command execution via agents
//! - Dashboard authentication
mod api;
mod auth;
mod config;
mod db;
mod updates;
mod ws;
use std::sync::Arc;
use anyhow::Result;
use axum::{
routing::{get, post},
Router,
};
use sqlx::postgres::PgPoolOptions;
use tokio::sync::RwLock;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use tracing::info;
use crate::config::ServerConfig;
use crate::updates::UpdateManager;
use crate::ws::AgentConnections;
/// Shared application state
///
/// Cloned per request by axum; every field is either a pool or `Arc`, so
/// clones are cheap handle copies.
#[derive(Clone)]
pub struct AppState {
    /// Database connection pool
    pub db: sqlx::PgPool,
    /// Server configuration
    pub config: Arc<ServerConfig>,
    /// Connected agents (WebSocket connections)
    pub agents: Arc<RwLock<AgentConnections>>,
    /// Agent update manager
    pub updates: Arc<UpdateManager>,
}
/// Server entry point: config -> database -> migrations -> update manager ->
/// router -> listen. Startup order matters: migrations must complete before
/// any request can touch the schema, and the version scan failure is
/// deliberately non-fatal (auto-update just stays idle).
#[tokio::main]
async fn main() -> Result<()> {
    // Load environment variables from .env file (missing file is fine)
    dotenvy::dotenv().ok();
    // Initialize logging; RUST_LOG overrides, otherwise the directives below apply
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::from_default_env()
                .add_directive("gururmm_server=info".parse()?)
                .add_directive("tower_http=debug".parse()?)
                .add_directive("info".parse()?),
        )
        .init();
    info!("GuruRMM Server starting...");
    // Load configuration
    let config = ServerConfig::from_env()?;
    info!("Server configuration loaded");
    // Connect to database
    info!("Connecting to database...");
    let db = PgPoolOptions::new()
        .max_connections(config.database.max_connections)
        .connect(&config.database.url)
        .await?;
    info!("Database connected");
    // Run migrations (embedded at compile time from ./migrations)
    info!("Running database migrations...");
    sqlx::migrate!("./migrations").run(&db).await?;
    info!("Migrations complete");
    // Initialize update manager
    let update_manager = UpdateManager::new(
        config.updates.downloads_dir.clone(),
        config.updates.downloads_base_url.clone(),
        config.updates.auto_update_enabled,
        config.updates.update_timeout_secs,
    );
    // Initial scan for available versions — non-fatal on failure
    info!("Scanning for available agent versions...");
    if let Err(e) = update_manager.scan_versions().await {
        tracing::warn!("Failed to scan agent versions: {} (continuing without auto-update)", e);
    }
    let update_manager = Arc::new(update_manager);
    // Spawn background scanner (handle is intentionally not awaited - runs until server shutdown)
    let _scanner_handle = update_manager.spawn_scanner(config.updates.scan_interval_secs);
    info!(
        "Auto-update: {} (scan interval: {}s)",
        if config.updates.auto_update_enabled { "enabled" } else { "disabled" },
        config.updates.scan_interval_secs
    );
    // Create shared state
    let state = AppState {
        db,
        config: Arc::new(config.clone()),
        agents: Arc::new(RwLock::new(AgentConnections::new())),
        updates: update_manager,
    };
    // Build router
    let app = build_router(state);
    // Start server — blocks until the process is stopped
    let addr = format!("{}:{}", config.server.host, config.server.port);
    info!("Starting server on {}", addr);
    let listener = tokio::net::TcpListener::bind(&addr).await?;
    axum::serve(listener, app).await?;
    Ok(())
}
/// Build the application router.
///
/// Wires up the health check, the agent WebSocket endpoint, and the
/// nested REST API, then layers HTTP tracing and a fully permissive
/// CORS policy on top before attaching the shared state.
fn build_router(state: AppState) -> Router {
    // Permissive CORS so the dashboard can reach the API from any origin.
    let cors = CorsLayer::new()
        .allow_headers(Any)
        .allow_methods(Any)
        .allow_origin(Any);

    // Route table: health probe, agent WebSocket, and the /api tree.
    let routes = Router::new()
        .route("/health", get(health_check))
        .route("/ws", get(ws::ws_handler))
        .nest("/api", api::routes());

    routes
        .layer(TraceLayer::new_for_http())
        .layer(cors)
        .with_state(state)
}
/// Health check endpoint
///
/// Returns a constant plain-text "OK" body; mounted at `/health` for
/// liveness probes.
async fn health_check() -> &'static str {
    "OK"
}

View File

@@ -0,0 +1,10 @@
//! Agent update management
//!
//! Handles:
//! - Scanning downloads directory for available agent versions
//! - Version comparison to determine if agents need updates
//! - Update tracking and timeout monitoring
mod scanner;
pub use scanner::{UpdateManager, AvailableVersion};

View File

@@ -0,0 +1,311 @@
//! Version scanner for available agent binaries
//!
//! Scans a downloads directory for agent binaries and parses version info
//! from filenames in the format: gururmm-agent-{os}-{arch}-{version}[.exe]
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use anyhow::Result;
use semver::Version;
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
/// Information about an available agent version
///
/// One entry per binary found in the downloads directory; produced by
/// `UpdateManager::scan_versions` and handed to agents as an update offer.
#[derive(Debug, Clone)]
pub struct AvailableVersion {
    /// Semantic version
    pub version: Version,
    /// Operating system (linux, windows)
    pub os: String,
    /// Architecture (amd64, arm64)
    pub arch: String,
    /// Filename on disk
    pub filename: String,
    /// Full download URL
    pub download_url: String,
    /// SHA256 checksum (lowercase hex, read from the companion .sha256 file)
    pub checksum_sha256: String,
    /// File size in bytes
    pub file_size: u64,
}
/// Manages available agent versions
///
/// Scans a downloads directory for agent binaries and caches the results
/// behind an async `RwLock` so WebSocket handlers can query them cheaply.
pub struct UpdateManager {
    /// Directory containing agent binaries
    downloads_dir: PathBuf,
    /// Base URL for downloads
    base_url: String,
    /// Cached available versions, keyed by "os-arch"
    /// (each list is sorted newest-first by `scan_versions`)
    versions: Arc<RwLock<HashMap<String, Vec<AvailableVersion>>>>,
    /// Whether auto-updates are enabled
    pub auto_update_enabled: bool,
    /// Update timeout in seconds
    pub update_timeout_secs: u64,
}
impl UpdateManager {
    /// Create a new UpdateManager
    ///
    /// * `downloads_dir` - directory scanned for agent binaries
    /// * `base_url` - base URL that download links are built from
    /// * `auto_update_enabled` - whether `needs_update` may offer updates
    /// * `update_timeout_secs` - timeout used for update tracking
    pub fn new(
        downloads_dir: PathBuf,
        base_url: String,
        auto_update_enabled: bool,
        update_timeout_secs: u64,
    ) -> Self {
        Self {
            downloads_dir,
            base_url,
            versions: Arc::new(RwLock::new(HashMap::new())),
            auto_update_enabled,
            update_timeout_secs,
        }
    }

    /// Scan the downloads directory for available agent binaries
    ///
    /// Binaries without a valid companion `.sha256` file are skipped so
    /// agents are never offered an unverifiable download. The cached map
    /// is replaced in one write at the end, so readers never observe a
    /// half-finished scan. A missing directory is non-fatal (warn + Ok).
    ///
    /// # Errors
    /// Returns an error if the directory exists but cannot be iterated.
    pub async fn scan_versions(&self) -> Result<()> {
        let mut versions: HashMap<String, Vec<AvailableVersion>> = HashMap::new();

        // Use tokio's async fs so the scan never blocks the runtime
        // (the rest of this type is already async-fs based).
        let mut entries = match tokio::fs::read_dir(&self.downloads_dir).await {
            Ok(entries) => entries,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                warn!("Downloads directory does not exist: {:?}", self.downloads_dir);
                return Ok(());
            }
            Err(e) => return Err(e.into()),
        };

        while let Some(entry) = entries.next_entry().await? {
            let path = entry.path();
            // Only regular files can be agent binaries.
            let is_file = entry
                .file_type()
                .await
                .map(|t| t.is_file())
                .unwrap_or(false);
            if !is_file {
                continue;
            }
            let filename = match path.file_name().and_then(|n| n.to_str()) {
                Some(name) => name.to_string(),
                None => continue,
            };
            // Skip checksum files; they are metadata for the binaries.
            if filename.ends_with(".sha256") {
                continue;
            }
            // Try to parse as agent binary
            if let Some((os, arch, version)) = Self::parse_filename(&filename) {
                // Read checksum from companion file
                let checksum = self.read_checksum(&path).await.unwrap_or_default();
                if checksum.is_empty() {
                    warn!("No checksum found for {}, skipping", filename);
                    continue;
                }
                // Get file size (0 if metadata is unreadable)
                let file_size = entry.metadata().await.map(|m| m.len()).unwrap_or(0);
                let download_url =
                    format!("{}/{}", self.base_url.trim_end_matches('/'), filename);
                let available = AvailableVersion {
                    version,
                    os: os.clone(),
                    arch: arch.clone(),
                    filename: filename.clone(),
                    download_url,
                    checksum_sha256: checksum,
                    file_size,
                };
                let key = format!("{}-{}", os, arch);
                debug!("Found agent binary: {} (v{})", filename, available.version);
                versions.entry(key).or_default().push(available);
            }
        }

        // Sort each list by version descending (newest first) so
        // `get_latest_version` can simply take the head.
        for list in versions.values_mut() {
            list.sort_by(|a, b| b.version.cmp(&a.version));
        }

        let total: usize = versions.values().map(|v| v.len()).sum();
        info!("Scanned {} agent binaries across {} platform/arch combinations", total, versions.len());

        *self.versions.write().await = versions;
        Ok(())
    }

    /// Parse a filename to extract OS, architecture, and version
    ///
    /// Expected format: gururmm-agent-{os}-{arch}-{version}[.exe]
    /// Examples:
    ///   - gururmm-agent-linux-amd64-0.2.0
    ///   - gururmm-agent-windows-amd64-0.2.0.exe
    ///   - gururmm-agent-linux-amd64-0.3.0-beta.1 (semver pre-release)
    fn parse_filename(filename: &str) -> Option<(String, String, Version)> {
        // Remove .exe extension if present
        let name = filename.strip_suffix(".exe").unwrap_or(filename);
        // Split by dashes
        let parts: Vec<&str> = name.split('-').collect();
        if parts.len() < 5 || parts[0] != "gururmm" || parts[1] != "agent" {
            return None;
        }
        let os = parts[2].to_string();
        let arch = parts[3].to_string();
        // Common case: a single version part with dots ("0.2.0").
        if parts.len() == 5 {
            let version = Version::parse(parts[4]).ok()?;
            return Some((os, arch, version));
        }
        // Trailing parts are ambiguous: "0-2-0" is a dot-less version
        // ("0", "2", "0"), while "0.3.0-beta.1" is a semver pre-release
        // that legitimately contains a dash. Try the historical dot-join
        // first (backward compatible), then fall back to dash-join.
        let version = Version::parse(&parts[4..].join("."))
            .ok()
            .or_else(|| Version::parse(&parts[4..].join("-")).ok())?;
        Some((os, arch, version))
    }

    /// Read checksum from companion .sha256 file
    ///
    /// The companion file sits next to the binary with ".sha256" appended
    /// to the full filename (including any ".exe").
    ///
    /// # Errors
    /// Fails if the file is missing, unreadable, empty, or does not
    /// contain a 64-character hex digest.
    async fn read_checksum(&self, binary_path: &Path) -> Result<String> {
        let checksum_path = PathBuf::from(format!("{}.sha256", binary_path.display()));
        if !checksum_path.exists() {
            return Err(anyhow::anyhow!("Checksum file not found"));
        }
        let content = tokio::fs::read_to_string(&checksum_path).await?;
        // Checksum file format: "<hash> <filename>" or just "<hash>"
        let checksum = content
            .split_whitespace()
            .next()
            .ok_or_else(|| anyhow::anyhow!("Empty checksum file"))?
            .to_lowercase();
        // Validate it looks like a SHA256 hash (64 hex chars)
        if checksum.len() != 64 || !checksum.chars().all(|c| c.is_ascii_hexdigit()) {
            return Err(anyhow::anyhow!("Invalid checksum format"));
        }
        Ok(checksum)
    }

    /// Get the latest version available for a given OS/arch
    pub async fn get_latest_version(&self, os: &str, arch: &str) -> Option<AvailableVersion> {
        let versions = self.versions.read().await;
        let key = format!("{}-{}", os, arch);
        // Lists are sorted newest-first by scan_versions.
        versions.get(&key).and_then(|list| list.first().cloned())
    }

    /// Check if an agent with the given version needs an update
    /// Returns the available update if one exists
    ///
    /// Always `None` when auto-update is disabled or the agent's version
    /// string is not valid semver (logged as a warning).
    pub async fn needs_update(
        &self,
        current_version: &str,
        os: &str,
        arch: &str,
    ) -> Option<AvailableVersion> {
        if !self.auto_update_enabled {
            return None;
        }
        let current = match Version::parse(current_version) {
            Ok(v) => v,
            Err(e) => {
                warn!("Failed to parse current version '{}': {}", current_version, e);
                return None;
            }
        };
        let latest = self.get_latest_version(os, arch).await?;
        if latest.version > current {
            info!(
                "Agent needs update: {} -> {} ({}-{})",
                current, latest.version, os, arch
            );
            Some(latest)
        } else {
            None
        }
    }

    /// Get all available versions (for dashboard display)
    pub async fn get_all_versions(&self) -> HashMap<String, Vec<AvailableVersion>> {
        self.versions.read().await.clone()
    }

    /// Spawn a background task to periodically rescan versions
    ///
    /// The returned handle is the task itself; it runs until aborted or
    /// the runtime shuts down.
    pub fn spawn_scanner(&self, interval_secs: u64) -> tokio::task::JoinHandle<()> {
        // Build a private handle for the task: the cache is shared via
        // Arc, the rest of the fields are cheap clones/copies.
        let manager = UpdateManager {
            downloads_dir: self.downloads_dir.clone(),
            base_url: self.base_url.clone(),
            versions: self.versions.clone(),
            auto_update_enabled: self.auto_update_enabled,
            update_timeout_secs: self.update_timeout_secs,
        };
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
            // A tokio interval's first tick completes immediately; consume
            // it so the first background scan happens after a full period
            // (callers perform their own initial scan at startup).
            interval.tick().await;
            loop {
                interval.tick().await;
                if let Err(e) = manager.scan_versions().await {
                    error!("Failed to scan versions: {}", e);
                }
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain Linux binary name parses into os/arch/version.
    #[test]
    fn test_parse_filename_linux() {
        let result = UpdateManager::parse_filename("gururmm-agent-linux-amd64-0.2.0");
        assert!(result.is_some());
        let (os, arch, version) = result.unwrap();
        assert_eq!(os, "linux");
        assert_eq!(arch, "amd64");
        assert_eq!(version, Version::new(0, 2, 0));
    }

    /// Windows binaries carry a trailing ".exe" that must be stripped.
    #[test]
    fn test_parse_filename_windows() {
        let result = UpdateManager::parse_filename("gururmm-agent-windows-amd64-0.2.0.exe");
        assert!(result.is_some());
        let (os, arch, version) = result.unwrap();
        assert_eq!(os, "windows");
        assert_eq!(arch, "amd64");
        assert_eq!(version, Version::new(0, 2, 0));
    }

    /// arm64 is parsed like any other architecture token.
    #[test]
    fn test_parse_filename_arm64() {
        let result = UpdateManager::parse_filename("gururmm-agent-linux-arm64-1.0.0");
        assert!(result.is_some());
        let (os, arch, version) = result.unwrap();
        assert_eq!(os, "linux");
        assert_eq!(arch, "arm64");
        assert_eq!(version, Version::new(1, 0, 0));
    }

    /// Non-agent filenames (wrong prefix, wrong product, too few parts)
    /// must be rejected.
    #[test]
    fn test_parse_filename_invalid() {
        assert!(UpdateManager::parse_filename("random-file.txt").is_none());
        assert!(UpdateManager::parse_filename("gururmm-server-linux-amd64-0.1.0").is_none());
        assert!(UpdateManager::parse_filename("gururmm-agent-linux").is_none());
    }
}

View File

@@ -0,0 +1,705 @@
//! WebSocket handler for agent connections
//!
//! Handles real-time communication with agents including:
//! - Authentication handshake
//! - Metrics ingestion
//! - Command dispatching
//! - Watchdog event handling
use std::collections::HashMap;
use std::sync::Arc;
use axum::{
extract::{
ws::{Message, WebSocket},
State, WebSocketUpgrade,
},
response::Response,
};
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use tokio::sync::{mpsc, RwLock};
use tracing::{debug, error, info, warn};
use uuid::Uuid;
use crate::db;
use crate::AppState;
/// Registry of currently connected agents.
///
/// Maps an agent's UUID to the mpsc sender used to push `ServerMessage`s
/// into that agent's WebSocket send task.
pub struct AgentConnections {
    /// Map of agent ID to sender channel
    connections: HashMap<Uuid, mpsc::Sender<ServerMessage>>,
}

impl AgentConnections {
    /// Create an empty registry.
    pub fn new() -> Self {
        Self { connections: HashMap::new() }
    }

    /// Register the outgoing channel for a newly connected agent.
    pub fn add(&mut self, agent_id: Uuid, tx: mpsc::Sender<ServerMessage>) {
        self.connections.insert(agent_id, tx);
    }

    /// Drop the channel for a disconnected agent.
    pub fn remove(&mut self, agent_id: &Uuid) {
        self.connections.remove(agent_id);
    }

    /// Send a message to a specific agent.
    ///
    /// Returns `false` when the agent is not connected or its channel
    /// has been closed.
    pub async fn send_to(&self, agent_id: &Uuid, msg: ServerMessage) -> bool {
        match self.connections.get(agent_id) {
            Some(tx) => tx.send(msg).await.is_ok(),
            None => false,
        }
    }

    /// Whether the given agent currently has a live connection.
    pub fn is_connected(&self, agent_id: &Uuid) -> bool {
        self.connections.contains_key(agent_id)
    }

    /// Number of currently connected agents.
    pub fn count(&self) -> usize {
        self.connections.len()
    }
}

impl Default for AgentConnections {
    fn default() -> Self {
        Self::new()
    }
}
/// Messages from agent to server
///
/// Serialized as adjacently-tagged JSON:
/// `{"type": "<snake_case variant>", "payload": ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "payload")]
#[serde(rename_all = "snake_case")]
pub enum AgentMessage {
    /// Authentication handshake; must be the first message on a connection.
    Auth(AuthPayload),
    /// Periodic system metrics sample.
    Metrics(MetricsPayload),
    /// Snapshot of the agent's network interfaces.
    NetworkState(NetworkStatePayload),
    /// Result of a previously dispatched command.
    CommandResult(CommandResultPayload),
    /// Watchdog state-change notification.
    WatchdogEvent(WatchdogEventPayload),
    /// Outcome of a self-update attempt.
    UpdateResult(UpdateResultPayload),
    /// Keep-alive; refreshes the agent's online status.
    Heartbeat,
}
/// Messages from server to agent
///
/// Uses the same adjacently-tagged JSON encoding as [`AgentMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "payload")]
#[serde(rename_all = "snake_case")]
pub enum ServerMessage {
    /// Response to the authentication handshake.
    AuthAck(AuthAckPayload),
    /// Command for the agent to execute.
    Command(CommandPayload),
    /// Configuration pushed to the agent (free-form JSON).
    ConfigUpdate(serde_json::Value),
    /// Instruct the agent to self-update to a new binary.
    Update(UpdatePayload),
    /// Generic acknowledgement, optionally tied to a message ID.
    Ack { message_id: Option<String> },
    /// Error report with a machine-readable code and human-readable message.
    Error { code: String, message: String },
}
/// Authentication handshake payload sent by an agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthPayload {
    /// Credential: a per-agent API key, a site API key, or a site code
    /// (WORD-WORD-NUMBER); see `authenticate` for how each is resolved.
    pub api_key: String,
    /// Hardware-derived device ID (for site-based registration)
    #[serde(default)]
    pub device_id: Option<String>,
    /// Hostname reported by the agent.
    pub hostname: String,
    /// Operating system type (e.g. linux, windows).
    pub os_type: String,
    /// OS version string.
    pub os_version: String,
    /// Running agent version (semver string).
    pub agent_version: String,
    /// Architecture (amd64, arm64, etc.)
    #[serde(default = "default_arch")]
    pub architecture: String,
    /// Previous version if reconnecting after update
    #[serde(default)]
    pub previous_version: Option<String>,
    /// Update ID if reconnecting after update
    #[serde(default)]
    pub pending_update_id: Option<Uuid>,
}
/// Serde default for `AuthPayload::architecture`: agents that omit the
/// field are assumed to be amd64.
fn default_arch() -> String {
    String::from("amd64")
}
/// Server response to an [`AgentMessage::Auth`] handshake.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthAckPayload {
    /// Whether authentication succeeded.
    pub success: bool,
    /// Assigned agent ID when `success` is true.
    pub agent_id: Option<Uuid>,
    /// Failure reason when `success` is false.
    pub error: Option<String>,
}
/// Periodic system metrics sample reported by an agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsPayload {
    /// Collection timestamp (UTC).
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// CPU usage percentage.
    pub cpu_percent: f32,
    /// Memory usage percentage.
    pub memory_percent: f32,
    pub memory_used_bytes: u64,
    pub memory_total_bytes: u64,
    /// Disk usage percentage.
    pub disk_percent: f32,
    pub disk_used_bytes: u64,
    pub disk_total_bytes: u64,
    pub network_rx_bytes: u64,
    pub network_tx_bytes: u64,
    pub os_type: String,
    pub os_version: String,
    pub hostname: String,
    // Extended metrics (optional for backwards compatibility)
    #[serde(default)]
    pub uptime_seconds: Option<u64>,
    // Boot time — assumed to be Unix epoch seconds; confirm against agent.
    #[serde(default)]
    pub boot_time: Option<i64>,
    /// Currently logged-in user, if any.
    #[serde(default)]
    pub logged_in_user: Option<String>,
    #[serde(default)]
    pub user_idle_seconds: Option<u64>,
    /// Public IP as observed by the agent.
    #[serde(default)]
    pub public_ip: Option<String>,
}
/// Result of a command executed by an agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandResultPayload {
    /// ID of the command this result answers (from `CommandPayload::id`).
    pub command_id: Uuid,
    /// Process exit code.
    pub exit_code: i32,
    /// Captured standard output.
    pub stdout: String,
    /// Captured standard error.
    pub stderr: String,
    /// Execution duration in milliseconds.
    pub duration_ms: u64,
}
/// Watchdog state-change notification from an agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WatchdogEventPayload {
    /// Name of the watched target.
    pub name: String,
    /// Event kind (free-form string supplied by the agent).
    pub event: String,
    /// Optional extra detail.
    pub details: Option<String>,
}
/// Network interface information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkInterface {
    /// Interface name (e.g. eth0).
    pub name: String,
    /// MAC address, if available.
    pub mac_address: Option<String>,
    /// Assigned IPv4 addresses.
    pub ipv4_addresses: Vec<String>,
    /// Assigned IPv6 addresses.
    pub ipv6_addresses: Vec<String>,
}
/// Network state payload from agent
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkStatePayload {
    /// Snapshot timestamp (UTC).
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// All interfaces known to the agent at snapshot time.
    pub interfaces: Vec<NetworkInterface>,
    // Agent-computed hash of the interface set — presumably used to
    // detect changes cheaply; confirm against the agent implementation.
    pub state_hash: String,
}
/// Command dispatched from server to agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommandPayload {
    /// Unique command ID; echoed back in `CommandResultPayload::command_id`.
    pub id: Uuid,
    /// Command category (interpreted by the agent).
    pub command_type: String,
    /// Command text to execute.
    pub command: String,
    /// Optional execution timeout.
    pub timeout_seconds: Option<u64>,
    /// Whether to run with elevated privileges.
    pub elevated: bool,
}
/// Update command payload from server to agent
///
/// Sent after authentication when the server detects that a newer
/// binary is available for the agent's OS/architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdatePayload {
    /// Unique update ID for tracking
    pub update_id: Uuid,
    /// Target version to update to
    pub target_version: String,
    /// Download URL for the new binary
    pub download_url: String,
    /// SHA256 checksum of the binary
    pub checksum_sha256: String,
    /// Whether to force update (skip version check)
    #[serde(default)]
    pub force: bool,
}
/// Update result payload from agent to server
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateResultPayload {
    /// Update ID (from the server)
    pub update_id: Uuid,
    /// Update status; the server recognizes "completed", "failed" and
    /// "rolled_back" — anything else is logged as an in-progress status.
    pub status: String,
    /// Old version before update
    pub old_version: String,
    /// New version after update (if successful)
    pub new_version: Option<String>,
    /// Error message if failed
    pub error: Option<String>,
}
/// Result of successful agent authentication
///
/// Internal to this module; carries what `handle_socket` needs after the
/// handshake: the agent's identity plus update-related reconnect info.
struct AuthResult {
    /// Database ID of the authenticated agent.
    agent_id: Uuid,
    /// Version the agent reported at connect time.
    agent_version: String,
    /// Reported OS type (used to select an update binary).
    os_type: String,
    /// Reported CPU architecture (used to select an update binary).
    architecture: String,
    /// Version before a just-completed self-update, if reconnecting.
    previous_version: Option<String>,
    /// Update ID being finalized, if reconnecting after an update.
    pending_update_id: Option<Uuid>,
}
/// WebSocket upgrade handler
///
/// Upgrades the HTTP request at `/ws` and hands the socket to
/// `handle_socket`, which owns the rest of the connection lifecycle.
pub async fn ws_handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
    ws.on_upgrade(|socket| handle_socket(socket, state))
}
/// Handle a WebSocket connection
///
/// Connection lifecycle:
/// 1. Authenticate the agent (failure sends a negative `AuthAck` and returns).
/// 2. Register the connection and mark the agent online.
/// 3. If reconnecting after a self-update, mark that update completed;
///    if a newer binary is available, record and push an update command.
/// 4. Pump incoming messages until disconnect, then deregister and mark
///    the agent offline.
async fn handle_socket(socket: WebSocket, state: AppState) {
    let (mut sender, mut receiver) = socket.split();

    // Create channel for outgoing messages (drained by the send task below)
    let (tx, mut rx) = mpsc::channel::<ServerMessage>(100);

    // Wait for authentication message
    let auth_result = match authenticate(&mut receiver, &mut sender, &state).await {
        Ok(result) => {
            info!("Agent authenticated: {}", result.agent_id);
            // Send auth success
            let ack = ServerMessage::AuthAck(AuthAckPayload {
                success: true,
                agent_id: Some(result.agent_id),
                error: None,
            });
            if let Ok(json) = serde_json::to_string(&ack) {
                let _ = sender.send(Message::Text(json)).await;
            }
            // Register connection so other handlers can push messages to it
            state.agents.write().await.add(result.agent_id, tx.clone());
            // Update agent status (best-effort; DB errors are ignored here)
            let _ = db::update_agent_status(&state.db, result.agent_id, "online").await;
            // Check if this is a post-update reconnection: a differing
            // previous_version means the agent restarted on a new binary.
            if let Some(prev_version) = &result.previous_version {
                if prev_version != &result.agent_version {
                    info!(
                        "Agent {} reconnected after update: {} -> {}",
                        result.agent_id, prev_version, result.agent_version
                    );
                    // Mark update as completed
                    let _ = db::complete_update_by_agent(
                        &state.db,
                        result.agent_id,
                        result.pending_update_id,
                        prev_version,
                        &result.agent_version,
                    ).await;
                }
            }
            // Check if agent needs update (auto-update enabled)
            if let Some(available) = state.updates.needs_update(
                &result.agent_version,
                &result.os_type,
                &result.architecture,
            ).await {
                info!(
                    "Agent {} needs update: {} -> {}",
                    result.agent_id, result.agent_version, available.version
                );
                let update_id = Uuid::new_v4();
                // Record update in database; the command is only sent if
                // the record was written, so tracking can't be lost.
                if let Err(e) = db::create_agent_update(
                    &state.db,
                    result.agent_id,
                    update_id,
                    &result.agent_version,
                    &available.version.to_string(),
                    &available.download_url,
                    &available.checksum_sha256,
                ).await {
                    error!("Failed to record update: {}", e);
                } else {
                    // Send update command
                    let update_msg = ServerMessage::Update(UpdatePayload {
                        update_id,
                        target_version: available.version.to_string(),
                        download_url: available.download_url.clone(),
                        checksum_sha256: available.checksum_sha256.clone(),
                        force: false,
                    });
                    if let Err(e) = tx.send(update_msg).await {
                        error!("Failed to send update command: {}", e);
                    }
                }
            }
            result
        }
        Err(e) => {
            error!("Authentication failed: {}", e);
            let ack = ServerMessage::AuthAck(AuthAckPayload {
                success: false,
                agent_id: None,
                error: Some(e.to_string()),
            });
            if let Ok(json) = serde_json::to_string(&ack) {
                let _ = sender.send(Message::Text(json)).await;
            }
            return;
        }
    };
    let agent_id = auth_result.agent_id;

    // Spawn task to forward outgoing messages from the channel to the socket
    let send_task = tokio::spawn(async move {
        while let Some(msg) = rx.recv().await {
            if let Ok(json) = serde_json::to_string(&msg) {
                if sender.send(Message::Text(json)).await.is_err() {
                    break;
                }
            }
        }
    });

    // Handle incoming messages until the socket closes or errors
    while let Some(msg_result) = receiver.next().await {
        match msg_result {
            Ok(Message::Text(text)) => {
                if let Err(e) = handle_agent_message(&text, agent_id, &state).await {
                    error!("Error handling agent message: {}", e);
                }
            }
            Ok(Message::Ping(data)) => {
                // NOTE(review): replies with an application-level Ack, not
                // a WebSocket Pong (the ws layer usually auto-pongs — TODO
                // confirm); the ping payload `data` is currently unused.
                if tx
                    .send(ServerMessage::Ack { message_id: None })
                    .await
                    .is_err()
                {
                    break;
                }
            }
            Ok(Message::Close(_)) => {
                info!("Agent {} disconnected", agent_id);
                break;
            }
            Err(e) => {
                error!("WebSocket error for agent {}: {}", agent_id, e);
                break;
            }
            // Binary and Pong frames are ignored.
            _ => {}
        }
    }

    // Cleanup: deregister, mark offline (best-effort), stop the send task
    state.agents.write().await.remove(&agent_id);
    let _ = db::update_agent_status(&state.db, agent_id, "offline").await;
    send_task.abort();
    info!("Agent {} connection closed", agent_id);
}
/// Authenticate an agent connection
///
/// Supports two modes:
/// 1. Legacy: API key maps directly to an agent (api_key_hash in agents table)
/// 2. Site-based: API key maps to a site, device_id identifies the specific agent
///
/// Site-based auth is tried first whenever `device_id` is present: the
/// key is resolved either as a human-friendly site code (WORD-WORD-NUMBER)
/// or by its SHA-256 hash, and unknown devices under a valid site are
/// auto-registered. If no site matches, falls back to legacy per-agent
/// key lookup.
///
/// # Errors
/// Fails when no text auth message arrives within 10 seconds, the message
/// is malformed, the credential is invalid, or a database call fails.
///
/// NOTE(review): `sender` is currently unused — failure acks are sent by
/// the caller (`handle_socket`).
async fn authenticate(
    receiver: &mut futures_util::stream::SplitStream<WebSocket>,
    sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
    state: &AppState,
) -> anyhow::Result<AuthResult> {
    use tokio::time::{timeout, Duration};

    // Wait for auth message with timeout
    let msg = timeout(Duration::from_secs(10), receiver.next())
        .await
        .map_err(|_| anyhow::anyhow!("Authentication timeout"))?
        .ok_or_else(|| anyhow::anyhow!("Connection closed before auth"))?
        .map_err(|e| anyhow::anyhow!("WebSocket error: {}", e))?;

    let text = match msg {
        Message::Text(t) => t,
        _ => return Err(anyhow::anyhow!("Expected text message for auth")),
    };

    let agent_msg: AgentMessage =
        serde_json::from_str(&text).map_err(|e| anyhow::anyhow!("Invalid auth message: {}", e))?;

    let auth = match agent_msg {
        AgentMessage::Auth(a) => a,
        _ => return Err(anyhow::anyhow!("Expected auth message")),
    };

    // Try site-based authentication first (if device_id is provided)
    if let Some(device_id) = &auth.device_id {
        // Check if api_key looks like a site code (WORD-WORD-NUMBER format)
        let site = if is_site_code_format(&auth.api_key) {
            info!("Attempting site code authentication: {}", auth.api_key);
            db::get_site_by_code(&state.db, &auth.api_key)
                .await
                .map_err(|e| anyhow::anyhow!("Database error: {}", e))?
        } else {
            // Hash the API key and look up by hash
            let api_key_hash = hash_api_key(&auth.api_key);
            db::get_site_by_api_key_hash(&state.db, &api_key_hash)
                .await
                .map_err(|e| anyhow::anyhow!("Database error: {}", e))?
        };

        // If no site matched, fall through to the legacy lookup below.
        if let Some(site) = site {
            info!("Site-based auth: site={} ({})", site.name, site.id);
            // Look up or create agent by site_id + device_id
            let agent = match db::get_agent_by_site_and_device(&state.db, site.id, device_id)
                .await
                .map_err(|e| anyhow::anyhow!("Database error: {}", e))?
            {
                Some(agent) => {
                    // Update existing agent info (best-effort)
                    let _ = db::update_agent_info_full(
                        &state.db,
                        agent.id,
                        Some(&auth.hostname),
                        Some(device_id),
                        Some(&auth.os_version),
                        Some(&auth.agent_version),
                    )
                    .await;
                    agent
                }
                None => {
                    // Auto-register new agent under this site
                    info!(
                        "Auto-registering new agent: hostname={}, device_id={}, site={}",
                        auth.hostname, device_id, site.name
                    );
                    db::create_agent_with_site(
                        &state.db,
                        db::CreateAgentWithSite {
                            site_id: site.id,
                            device_id: device_id.clone(),
                            hostname: auth.hostname.clone(),
                            os_type: auth.os_type.clone(),
                            os_version: Some(auth.os_version.clone()),
                            agent_version: Some(auth.agent_version.clone()),
                        },
                    )
                    .await
                    .map_err(|e| anyhow::anyhow!("Failed to create agent: {}", e))?
                }
            };
            return Ok(AuthResult {
                agent_id: agent.id,
                agent_version: auth.agent_version.clone(),
                os_type: auth.os_type.clone(),
                architecture: auth.architecture.clone(),
                previous_version: auth.previous_version.clone(),
                pending_update_id: auth.pending_update_id,
            });
        }
    }

    // Fall back to legacy: look up agent directly by API key hash
    let api_key_hash = hash_api_key(&auth.api_key);
    let agent = db::get_agent_by_api_key_hash(&state.db, &api_key_hash)
        .await
        .map_err(|e| anyhow::anyhow!("Database error: {}", e))?
        .ok_or_else(|| anyhow::anyhow!("Invalid API key"))?;

    // Update agent info (including hostname in case it changed); best-effort
    let _ = db::update_agent_info(
        &state.db,
        agent.id,
        Some(&auth.hostname),
        Some(&auth.os_version),
        Some(&auth.agent_version),
    )
    .await;

    Ok(AuthResult {
        agent_id: agent.id,
        agent_version: auth.agent_version,
        os_type: auth.os_type,
        architecture: auth.architecture,
        previous_version: auth.previous_version,
        pending_update_id: auth.pending_update_id,
    })
}
/// Handle a message from an authenticated agent
///
/// Deserializes the text frame into an [`AgentMessage`] and dispatches:
/// metrics and network state are persisted, command and update results
/// update their database records, heartbeats refresh `last_seen`, and
/// watchdog events are currently only logged.
///
/// # Errors
/// Returns an error if the frame is not valid JSON or a required
/// database write fails.
async fn handle_agent_message(
    text: &str,
    agent_id: Uuid,
    state: &AppState,
) -> anyhow::Result<()> {
    let msg: AgentMessage = serde_json::from_str(text)?;
    match msg {
        AgentMessage::Metrics(metrics) => {
            debug!("Received metrics from agent {}: CPU={:.1}%", agent_id, metrics.cpu_percent);
            // Store metrics in database (unsigned counters are narrowed
            // to i64 for the database columns)
            let create_metrics = db::CreateMetrics {
                agent_id,
                cpu_percent: Some(metrics.cpu_percent),
                memory_percent: Some(metrics.memory_percent),
                memory_used_bytes: Some(metrics.memory_used_bytes as i64),
                disk_percent: Some(metrics.disk_percent),
                disk_used_bytes: Some(metrics.disk_used_bytes as i64),
                network_rx_bytes: Some(metrics.network_rx_bytes as i64),
                network_tx_bytes: Some(metrics.network_tx_bytes as i64),
                // Extended metrics
                uptime_seconds: metrics.uptime_seconds.map(|v| v as i64),
                boot_time: metrics.boot_time,
                logged_in_user: metrics.logged_in_user.clone(),
                user_idle_seconds: metrics.user_idle_seconds.map(|v| v as i64),
                public_ip: metrics.public_ip.clone(),
                memory_total_bytes: Some(metrics.memory_total_bytes as i64),
                disk_total_bytes: Some(metrics.disk_total_bytes as i64),
            };
            db::insert_metrics(&state.db, create_metrics).await?;
            // Also update agent_state for quick access to latest extended
            // info (best-effort; failure does not fail the message)
            let _ = db::upsert_agent_state(
                &state.db,
                agent_id,
                metrics.uptime_seconds.map(|v| v as i64),
                metrics.boot_time,
                metrics.logged_in_user.as_deref(),
                metrics.user_idle_seconds.map(|v| v as i64),
                metrics.public_ip.as_deref(),
            ).await;
            // Update last_seen
            db::update_agent_status(&state.db, agent_id, "online").await?;
        }
        AgentMessage::CommandResult(result) => {
            info!(
                "Received command result from agent {}: command={}, exit={}",
                agent_id, result.command_id, result.exit_code
            );
            // Update command in database
            let cmd_result = db::CommandResult {
                exit_code: result.exit_code,
                stdout: result.stdout,
                stderr: result.stderr,
            };
            db::update_command_result(&state.db, result.command_id, cmd_result).await?;
        }
        AgentMessage::WatchdogEvent(event) => {
            info!(
                "Received watchdog event from agent {}: {} - {}",
                agent_id, event.name, event.event
            );
            // Store watchdog event (table exists but we'll add the insert function later)
            // For now, just log it
        }
        AgentMessage::NetworkState(network_state) => {
            debug!(
                "Received network state from agent {}: {} interfaces",
                agent_id,
                network_state.interfaces.len()
            );
            // Log interface details at trace level
            for iface in &network_state.interfaces {
                tracing::trace!(
                    "  Interface {}: IPv4={:?}, IPv6={:?}",
                    iface.name,
                    iface.ipv4_addresses,
                    iface.ipv6_addresses
                );
            }
            // Store network state in database (best-effort)
            if let Ok(interfaces_json) = serde_json::to_value(&network_state.interfaces) {
                let _ = db::update_agent_network_state(
                    &state.db,
                    agent_id,
                    &interfaces_json,
                    &network_state.state_hash,
                ).await;
            }
            // Update last_seen
            db::update_agent_status(&state.db, agent_id, "online").await?;
        }
        AgentMessage::Heartbeat => {
            debug!("Received heartbeat from agent {}", agent_id);
            db::update_agent_status(&state.db, agent_id, "online").await?;
        }
        AgentMessage::UpdateResult(result) => {
            info!(
                "Received update result from agent {}: update_id={}, status={}",
                agent_id, result.update_id, result.status
            );
            // Update the agent_updates record (best-effort writes)
            match result.status.as_str() {
                "completed" => {
                    let _ = db::complete_agent_update(&state.db, result.update_id, result.new_version.as_deref()).await;
                    info!("Agent {} successfully updated to {}", agent_id, result.new_version.unwrap_or_default());
                }
                "failed" | "rolled_back" => {
                    let _ = db::fail_agent_update(&state.db, result.update_id, result.error.as_deref()).await;
                    warn!("Agent {} update failed: {}", agent_id, result.error.unwrap_or_default());
                }
                _ => {
                    // Any other status is treated as progress information.
                    debug!("Agent {} update status: {}", agent_id, result.status);
                }
            }
        }
        AgentMessage::Auth(_) => {
            warn!("Received unexpected auth message from already authenticated agent");
        }
    }
    Ok(())
}
/// Hash an API key for storage/lookup.
///
/// Produces the lowercase hex SHA-256 digest of the key; the database
/// stores and compares only this hash, never the raw key.
pub fn hash_api_key(api_key: &str) -> String {
    use sha2::{Digest, Sha256};
    let digest = Sha256::digest(api_key.as_bytes());
    format!("{:x}", digest)
}
/// Generate a new API key.
///
/// Produces `prefix` followed by 24 random bytes encoded as unpadded
/// URL-safe base64 (32 characters of entropy-bearing text).
pub fn generate_api_key(prefix: &str) -> String {
    use rand::Rng;
    let mut random_bytes = [0u8; 24];
    rand::thread_rng().fill(&mut random_bytes);
    let encoded = base64::Engine::encode(
        &base64::engine::general_purpose::URL_SAFE_NO_PAD,
        random_bytes,
    );
    format!("{}{}", prefix, encoded)
}
/// Check if a string looks like a site code (WORD-WORD-NUMBER format)
/// Examples: SWIFT-CLOUD-6910, APPLE-GREEN-9145
///
/// The first two segments must be non-empty ASCII-alphabetic words and
/// the third must be exactly four ASCII digits. Used by `authenticate`
/// to decide whether an incoming credential should be looked up as a
/// site code instead of hashed as an API key.
fn is_site_code_format(s: &str) -> bool {
    let parts: Vec<&str> = s.split('-').collect();
    if parts.len() != 3 {
        return false;
    }
    let (word_a, word_b, digits) = (parts[0], parts[1], parts[2]);
    // Words must be non-empty: `all()` is vacuously true on an empty
    // string, which would otherwise let inputs like "AB--1234" match.
    !word_a.is_empty()
        && !word_b.is_empty()
        && word_a.chars().all(|c| c.is_ascii_alphabetic())
        && word_b.chars().all(|c| c.is_ascii_alphabetic())
        && digits.len() == 4
        && digits.chars().all(|c| c.is_ascii_digit())
}