Add PostgreSQL database persistence

- Add connect_machines, connect_sessions, connect_session_events, connect_support_codes tables
- Implement db module with connection pooling (sqlx)
- Add machine persistence across server restarts
- Add audit logging for session/viewer events
- Support codes now persisted to database

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2025-12-28 19:51:01 -07:00
parent 448d3b75ac
commit f6bf0cfd26
10 changed files with 788 additions and 36 deletions

View File

@@ -0,0 +1,88 @@
-- GuruConnect Initial Schema
-- Machine persistence, session audit logging, and support codes
-- Enable UUID generation
CREATE EXTENSION IF NOT EXISTS "pgcrypto";
-- Machines table - persistent agent records that survive server restarts
CREATE TABLE connect_machines (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
agent_id VARCHAR(255) UNIQUE NOT NULL,
hostname VARCHAR(255) NOT NULL,
os_version VARCHAR(255),
is_elevated BOOLEAN DEFAULT FALSE,
is_persistent BOOLEAN DEFAULT TRUE,
first_seen TIMESTAMPTZ DEFAULT NOW(),
last_seen TIMESTAMPTZ DEFAULT NOW(),
last_session_id UUID,
status VARCHAR(20) DEFAULT 'offline',
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE INDEX idx_connect_machines_agent_id ON connect_machines(agent_id);
CREATE INDEX idx_connect_machines_status ON connect_machines(status);
-- Sessions table - connection history
CREATE TABLE connect_sessions (
id UUID PRIMARY KEY,
machine_id UUID REFERENCES connect_machines(id) ON DELETE CASCADE,
started_at TIMESTAMPTZ DEFAULT NOW(),
ended_at TIMESTAMPTZ,
duration_secs INTEGER,
is_support_session BOOLEAN DEFAULT FALSE,
support_code VARCHAR(10),
status VARCHAR(20) DEFAULT 'active'
);
CREATE INDEX idx_connect_sessions_machine ON connect_sessions(machine_id);
CREATE INDEX idx_connect_sessions_started ON connect_sessions(started_at DESC);
CREATE INDEX idx_connect_sessions_support_code ON connect_sessions(support_code);
-- Session events - comprehensive audit log
CREATE TABLE connect_session_events (
id BIGSERIAL PRIMARY KEY,
session_id UUID REFERENCES connect_sessions(id) ON DELETE CASCADE,
event_type VARCHAR(50) NOT NULL,
timestamp TIMESTAMPTZ DEFAULT NOW(),
viewer_id VARCHAR(255),
viewer_name VARCHAR(255),
details JSONB,
ip_address INET
);
CREATE INDEX idx_connect_events_session ON connect_session_events(session_id);
CREATE INDEX idx_connect_events_time ON connect_session_events(timestamp DESC);
CREATE INDEX idx_connect_events_type ON connect_session_events(event_type);
-- Support codes - persistent across restarts
CREATE TABLE connect_support_codes (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
code VARCHAR(10) UNIQUE NOT NULL,
session_id UUID,
created_by VARCHAR(255) NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW(),
expires_at TIMESTAMPTZ,
status VARCHAR(20) DEFAULT 'pending',
client_name VARCHAR(255),
client_machine VARCHAR(255),
connected_at TIMESTAMPTZ
);
CREATE INDEX idx_support_codes_code ON connect_support_codes(code);
CREATE INDEX idx_support_codes_status ON connect_support_codes(status);
CREATE INDEX idx_support_codes_session ON connect_support_codes(session_id);
-- Trigger to auto-update updated_at on machines
CREATE OR REPLACE FUNCTION update_connect_updated_at()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER update_connect_machines_updated_at
BEFORE UPDATE ON connect_machines
FOR EACH ROW
EXECUTE FUNCTION update_connect_updated_at();

View File

@@ -9,9 +9,12 @@ pub struct Config {
/// Address to listen on (e.g., "0.0.0.0:8080") /// Address to listen on (e.g., "0.0.0.0:8080")
pub listen_addr: String, pub listen_addr: String,
/// Database URL (optional for MVP) /// Database URL (optional - server works without it)
pub database_url: Option<String>, pub database_url: Option<String>,
/// Maximum database connections in pool
pub database_max_connections: u32,
/// JWT secret for authentication /// JWT secret for authentication
pub jwt_secret: Option<String>, pub jwt_secret: Option<String>,
@@ -25,6 +28,10 @@ impl Config {
Ok(Self { Ok(Self {
listen_addr: env::var("LISTEN_ADDR").unwrap_or_else(|_| "0.0.0.0:8080".to_string()), listen_addr: env::var("LISTEN_ADDR").unwrap_or_else(|_| "0.0.0.0:8080".to_string()),
database_url: env::var("DATABASE_URL").ok(), database_url: env::var("DATABASE_URL").ok(),
database_max_connections: env::var("DATABASE_MAX_CONNECTIONS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(5),
jwt_secret: env::var("JWT_SECRET").ok(), jwt_secret: env::var("JWT_SECRET").ok(),
debug: env::var("DEBUG") debug: env::var("DEBUG")
.map(|v| v == "1" || v.to_lowercase() == "true") .map(|v| v == "1" || v.to_lowercase() == "true")
@@ -38,6 +45,7 @@ impl Default for Config {
Self { Self {
listen_addr: "0.0.0.0:8080".to_string(), listen_addr: "0.0.0.0:8080".to_string(),
database_url: None, database_url: None,
database_max_connections: 5,
jwt_secret: None, jwt_secret: None,
debug: false, debug: false,
} }

107
server/src/db/events.rs Normal file
View File

@@ -0,0 +1,107 @@
//! Audit event logging
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use sqlx::PgPool;
use std::net::IpAddr;
use uuid::Uuid;
/// Session event record from database
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct SessionEvent {
pub id: i64,
pub session_id: Uuid,
pub event_type: String,
pub timestamp: DateTime<Utc>,
pub viewer_id: Option<String>,
pub viewer_name: Option<String>,
pub details: Option<JsonValue>,
pub ip_address: Option<String>,
}
/// Event types for session audit logging
pub struct EventTypes;
impl EventTypes {
pub const SESSION_STARTED: &'static str = "session_started";
pub const SESSION_ENDED: &'static str = "session_ended";
pub const SESSION_TIMEOUT: &'static str = "session_timeout";
pub const VIEWER_JOINED: &'static str = "viewer_joined";
pub const VIEWER_LEFT: &'static str = "viewer_left";
pub const STREAMING_STARTED: &'static str = "streaming_started";
pub const STREAMING_STOPPED: &'static str = "streaming_stopped";
}
/// Log a session event
pub async fn log_event(
pool: &PgPool,
session_id: Uuid,
event_type: &str,
viewer_id: Option<&str>,
viewer_name: Option<&str>,
details: Option<JsonValue>,
ip_address: Option<IpAddr>,
) -> Result<i64, sqlx::Error> {
let ip_str = ip_address.map(|ip| ip.to_string());
let result = sqlx::query_scalar::<_, i64>(
r#"
INSERT INTO connect_session_events
(session_id, event_type, viewer_id, viewer_name, details, ip_address)
VALUES ($1, $2, $3, $4, $5, $6::inet)
RETURNING id
"#,
)
.bind(session_id)
.bind(event_type)
.bind(viewer_id)
.bind(viewer_name)
.bind(details)
.bind(ip_str)
.fetch_one(pool)
.await?;
Ok(result)
}
/// Get events for a session
pub async fn get_session_events(
pool: &PgPool,
session_id: Uuid,
) -> Result<Vec<SessionEvent>, sqlx::Error> {
sqlx::query_as::<_, SessionEvent>(
"SELECT id, session_id, event_type, timestamp, viewer_id, viewer_name, details, ip_address::text as ip_address FROM connect_session_events WHERE session_id = $1 ORDER BY timestamp"
)
.bind(session_id)
.fetch_all(pool)
.await
}
/// Get recent events (for dashboard)
pub async fn get_recent_events(
pool: &PgPool,
limit: i64,
) -> Result<Vec<SessionEvent>, sqlx::Error> {
sqlx::query_as::<_, SessionEvent>(
"SELECT id, session_id, event_type, timestamp, viewer_id, viewer_name, details, ip_address::text as ip_address FROM connect_session_events ORDER BY timestamp DESC LIMIT $1"
)
.bind(limit)
.fetch_all(pool)
.await
}
/// Get events by type
pub async fn get_events_by_type(
pool: &PgPool,
event_type: &str,
limit: i64,
) -> Result<Vec<SessionEvent>, sqlx::Error> {
sqlx::query_as::<_, SessionEvent>(
"SELECT id, session_id, event_type, timestamp, viewer_id, viewer_name, details, ip_address::text as ip_address FROM connect_session_events WHERE event_type = $1 ORDER BY timestamp DESC LIMIT $2"
)
.bind(event_type)
.bind(limit)
.fetch_all(pool)
.await
}

118
server/src/db/machines.rs Normal file
View File

@@ -0,0 +1,118 @@
//! Machine/Agent database operations
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use uuid::Uuid;
/// Machine record from database
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct Machine {
pub id: Uuid,
pub agent_id: String,
pub hostname: String,
pub os_version: Option<String>,
pub is_elevated: bool,
pub is_persistent: bool,
pub first_seen: DateTime<Utc>,
pub last_seen: DateTime<Utc>,
pub last_session_id: Option<Uuid>,
pub status: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// Get or create a machine by agent_id (upsert)
pub async fn upsert_machine(
pool: &PgPool,
agent_id: &str,
hostname: &str,
is_persistent: bool,
) -> Result<Machine, sqlx::Error> {
sqlx::query_as::<_, Machine>(
r#"
INSERT INTO connect_machines (agent_id, hostname, is_persistent, status, last_seen)
VALUES ($1, $2, $3, 'online', NOW())
ON CONFLICT (agent_id) DO UPDATE SET
hostname = EXCLUDED.hostname,
status = 'online',
last_seen = NOW()
RETURNING *
"#,
)
.bind(agent_id)
.bind(hostname)
.bind(is_persistent)
.fetch_one(pool)
.await
}
/// Update machine status and info
pub async fn update_machine_status(
pool: &PgPool,
agent_id: &str,
status: &str,
os_version: Option<&str>,
is_elevated: bool,
session_id: Option<Uuid>,
) -> Result<(), sqlx::Error> {
sqlx::query(
r#"
UPDATE connect_machines SET
status = $1,
os_version = COALESCE($2, os_version),
is_elevated = $3,
last_seen = NOW(),
last_session_id = COALESCE($4, last_session_id)
WHERE agent_id = $5
"#,
)
.bind(status)
.bind(os_version)
.bind(is_elevated)
.bind(session_id)
.bind(agent_id)
.execute(pool)
.await?;
Ok(())
}
/// Get all persistent machines (for restore on startup)
pub async fn get_all_machines(pool: &PgPool) -> Result<Vec<Machine>, sqlx::Error> {
sqlx::query_as::<_, Machine>(
"SELECT * FROM connect_machines WHERE is_persistent = true ORDER BY hostname"
)
.fetch_all(pool)
.await
}
/// Get machine by agent_id
pub async fn get_machine_by_agent_id(
pool: &PgPool,
agent_id: &str,
) -> Result<Option<Machine>, sqlx::Error> {
sqlx::query_as::<_, Machine>(
"SELECT * FROM connect_machines WHERE agent_id = $1"
)
.bind(agent_id)
.fetch_optional(pool)
.await
}
/// Mark machine as offline
pub async fn mark_machine_offline(pool: &PgPool, agent_id: &str) -> Result<(), sqlx::Error> {
sqlx::query("UPDATE connect_machines SET status = 'offline', last_seen = NOW() WHERE agent_id = $1")
.bind(agent_id)
.execute(pool)
.await?;
Ok(())
}
/// Delete a machine record
pub async fn delete_machine(pool: &PgPool, agent_id: &str) -> Result<(), sqlx::Error> {
sqlx::query("DELETE FROM connect_machines WHERE agent_id = $1")
.bind(agent_id)
.execute(pool)
.await?;
Ok(())
}

View File

@@ -1,45 +1,52 @@
//! Database module //! Database module for GuruConnect
//! //!
//! Handles session logging and persistence. //! Handles persistence for machines, sessions, and audit logging.
//! Optional for MVP - sessions are kept in memory only. //! Optional - server works without database if DATABASE_URL not set.
pub mod machines;
pub mod sessions;
pub mod events;
pub mod support_codes;
use anyhow::Result; use anyhow::Result;
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
use tracing::info;
/// Database connection pool (placeholder) pub use machines::*;
pub use sessions::*;
pub use events::*;
pub use support_codes::*;
/// Database connection pool wrapper
#[derive(Clone)] #[derive(Clone)]
pub struct Database { pub struct Database {
// TODO: Add sqlx pool when PostgreSQL is needed pool: PgPool,
_placeholder: (),
} }
impl Database { impl Database {
/// Initialize database connection /// Initialize database connection pool
pub async fn init(_database_url: &str) -> Result<Self> { pub async fn connect(database_url: &str, max_connections: u32) -> Result<Self> {
// TODO: Initialize PostgreSQL connection pool info!("Connecting to database...");
Ok(Self { _placeholder: () }) let pool = PgPoolOptions::new()
.max_connections(max_connections)
.connect(database_url)
.await?;
info!("Database connection established");
Ok(Self { pool })
} }
}
/// Session event for audit logging /// Run database migrations
#[derive(Debug)] pub async fn migrate(&self) -> Result<()> {
pub struct SessionEvent { info!("Running database migrations...");
pub session_id: String, sqlx::migrate!("./migrations").run(&self.pool).await?;
pub event_type: SessionEventType, info!("Migrations complete");
pub details: Option<String>,
}
#[derive(Debug)]
pub enum SessionEventType {
Started,
ViewerJoined,
ViewerLeft,
Ended,
}
impl Database {
/// Log a session event (placeholder)
pub async fn log_session_event(&self, _event: SessionEvent) -> Result<()> {
// TODO: Insert into connect_session_events table
Ok(()) Ok(())
} }
/// Get reference to the connection pool
pub fn pool(&self) -> &PgPool {
&self.pool
}
} }

98
server/src/db/sessions.rs Normal file
View File

@@ -0,0 +1,98 @@
//! Session database operations
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use uuid::Uuid;
/// Session record from database
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct DbSession {
pub id: Uuid,
pub machine_id: Option<Uuid>,
pub started_at: DateTime<Utc>,
pub ended_at: Option<DateTime<Utc>>,
pub duration_secs: Option<i32>,
pub is_support_session: bool,
pub support_code: Option<String>,
pub status: String,
}
/// Create a new session record
pub async fn create_session(
pool: &PgPool,
session_id: Uuid,
machine_id: Uuid,
is_support_session: bool,
support_code: Option<&str>,
) -> Result<DbSession, sqlx::Error> {
sqlx::query_as::<_, DbSession>(
r#"
INSERT INTO connect_sessions (id, machine_id, is_support_session, support_code, status)
VALUES ($1, $2, $3, $4, 'active')
RETURNING *
"#,
)
.bind(session_id)
.bind(machine_id)
.bind(is_support_session)
.bind(support_code)
.fetch_one(pool)
.await
}
/// End a session
pub async fn end_session(
pool: &PgPool,
session_id: Uuid,
status: &str, // 'ended' or 'disconnected' or 'timeout'
) -> Result<(), sqlx::Error> {
sqlx::query(
r#"
UPDATE connect_sessions SET
ended_at = NOW(),
duration_secs = EXTRACT(EPOCH FROM (NOW() - started_at))::INTEGER,
status = $1
WHERE id = $2
"#,
)
.bind(status)
.bind(session_id)
.execute(pool)
.await?;
Ok(())
}
/// Get session by ID
pub async fn get_session(pool: &PgPool, session_id: Uuid) -> Result<Option<DbSession>, sqlx::Error> {
sqlx::query_as::<_, DbSession>("SELECT * FROM connect_sessions WHERE id = $1")
.bind(session_id)
.fetch_optional(pool)
.await
}
/// Get active sessions for a machine
pub async fn get_active_sessions_for_machine(
pool: &PgPool,
machine_id: Uuid,
) -> Result<Vec<DbSession>, sqlx::Error> {
sqlx::query_as::<_, DbSession>(
"SELECT * FROM connect_sessions WHERE machine_id = $1 AND status = 'active' ORDER BY started_at DESC"
)
.bind(machine_id)
.fetch_all(pool)
.await
}
/// Get recent sessions (for dashboard)
pub async fn get_recent_sessions(
pool: &PgPool,
limit: i64,
) -> Result<Vec<DbSession>, sqlx::Error> {
sqlx::query_as::<_, DbSession>(
"SELECT * FROM connect_sessions ORDER BY started_at DESC LIMIT $1"
)
.bind(limit)
.fetch_all(pool)
.await
}

View File

@@ -0,0 +1,141 @@
//! Support code database operations
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use uuid::Uuid;
/// Support code record from database
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct DbSupportCode {
pub id: Uuid,
pub code: String,
pub session_id: Option<Uuid>,
pub created_by: String,
pub created_at: DateTime<Utc>,
pub expires_at: Option<DateTime<Utc>>,
pub status: String,
pub client_name: Option<String>,
pub client_machine: Option<String>,
pub connected_at: Option<DateTime<Utc>>,
}
/// Create a new support code
pub async fn create_support_code(
pool: &PgPool,
code: &str,
created_by: &str,
) -> Result<DbSupportCode, sqlx::Error> {
sqlx::query_as::<_, DbSupportCode>(
r#"
INSERT INTO connect_support_codes (code, created_by, status)
VALUES ($1, $2, 'pending')
RETURNING *
"#,
)
.bind(code)
.bind(created_by)
.fetch_one(pool)
.await
}
/// Get support code by code string
pub async fn get_support_code(pool: &PgPool, code: &str) -> Result<Option<DbSupportCode>, sqlx::Error> {
sqlx::query_as::<_, DbSupportCode>(
"SELECT * FROM connect_support_codes WHERE code = $1"
)
.bind(code)
.fetch_optional(pool)
.await
}
/// Update support code when client connects
pub async fn mark_code_connected(
pool: &PgPool,
code: &str,
session_id: Option<Uuid>,
client_name: Option<&str>,
client_machine: Option<&str>,
) -> Result<(), sqlx::Error> {
sqlx::query(
r#"
UPDATE connect_support_codes SET
status = 'connected',
session_id = $1,
client_name = $2,
client_machine = $3,
connected_at = NOW()
WHERE code = $4
"#,
)
.bind(session_id)
.bind(client_name)
.bind(client_machine)
.bind(code)
.execute(pool)
.await?;
Ok(())
}
/// Mark support code as completed
pub async fn mark_code_completed(pool: &PgPool, code: &str) -> Result<(), sqlx::Error> {
sqlx::query("UPDATE connect_support_codes SET status = 'completed' WHERE code = $1")
.bind(code)
.execute(pool)
.await?;
Ok(())
}
/// Mark support code as cancelled
pub async fn mark_code_cancelled(pool: &PgPool, code: &str) -> Result<(), sqlx::Error> {
sqlx::query("UPDATE connect_support_codes SET status = 'cancelled' WHERE code = $1")
.bind(code)
.execute(pool)
.await?;
Ok(())
}
/// Get active support codes (pending or connected)
pub async fn get_active_support_codes(pool: &PgPool) -> Result<Vec<DbSupportCode>, sqlx::Error> {
sqlx::query_as::<_, DbSupportCode>(
"SELECT * FROM connect_support_codes WHERE status IN ('pending', 'connected') ORDER BY created_at DESC"
)
.fetch_all(pool)
.await
}
/// Check if code exists and is valid for connection
pub async fn is_code_valid(pool: &PgPool, code: &str) -> Result<bool, sqlx::Error> {
let result = sqlx::query_scalar::<_, bool>(
"SELECT EXISTS(SELECT 1 FROM connect_support_codes WHERE code = $1 AND status = 'pending')"
)
.bind(code)
.fetch_one(pool)
.await?;
Ok(result)
}
/// Check if code is cancelled
pub async fn is_code_cancelled(pool: &PgPool, code: &str) -> Result<bool, sqlx::Error> {
let result = sqlx::query_scalar::<_, bool>(
"SELECT EXISTS(SELECT 1 FROM connect_support_codes WHERE code = $1 AND status = 'cancelled')"
)
.bind(code)
.fetch_one(pool)
.await?;
Ok(result)
}
/// Link session to support code
pub async fn link_session_to_code(
pool: &PgPool,
code: &str,
session_id: Uuid,
) -> Result<(), sqlx::Error> {
sqlx::query("UPDATE connect_support_codes SET session_id = $1 WHERE code = $2")
.bind(session_id)
.bind(code)
.execute(pool)
.await?;
Ok(())
}

View File

@@ -38,6 +38,7 @@ use support_codes::{SupportCodeManager, CreateCodeRequest, SupportCode, CodeVali
pub struct AppState { pub struct AppState {
sessions: session::SessionManager, sessions: session::SessionManager,
support_codes: SupportCodeManager, support_codes: SupportCodeManager,
db: Option<db::Database>,
} }
#[tokio::main] #[tokio::main]
@@ -52,15 +53,55 @@ async fn main() -> Result<()> {
// Load configuration // Load configuration
let config = config::Config::load()?; let config = config::Config::load()?;
// Use port 3002 for GuruConnect // Use port 3002 for GuruConnect
let listen_addr = std::env::var("LISTEN_ADDR").unwrap_or_else(|_| "0.0.0.0:3002".to_string()); let listen_addr = std::env::var("LISTEN_ADDR").unwrap_or_else(|_| "0.0.0.0:3002".to_string());
info!("Loaded configuration, listening on {}", listen_addr); info!("Loaded configuration, listening on {}", listen_addr);
// Initialize database if configured
let database = if let Some(ref db_url) = config.database_url {
match db::Database::connect(db_url, config.database_max_connections).await {
Ok(db) => {
// Run migrations
if let Err(e) = db.migrate().await {
tracing::error!("Failed to run migrations: {}", e);
return Err(e);
}
Some(db)
}
Err(e) => {
tracing::warn!("Failed to connect to database: {}. Running without persistence.", e);
None
}
}
} else {
info!("No DATABASE_URL set, running without persistence");
None
};
// Create session manager
let sessions = session::SessionManager::new();
// Restore persistent machines from database
if let Some(ref db) = database {
match db::machines::get_all_machines(db.pool()).await {
Ok(machines) => {
info!("Restoring {} persistent machines from database", machines.len());
for machine in machines {
sessions.restore_offline_machine(&machine.agent_id, &machine.hostname).await;
}
}
Err(e) => {
tracing::warn!("Failed to restore machines: {}", e);
}
}
}
// Create application state // Create application state
let state = AppState { let state = AppState {
sessions: session::SessionManager::new(), sessions,
support_codes: SupportCodeManager::new(), support_codes: SupportCodeManager::new(),
db: database,
}; };
// Build router // Build router

View File

@@ -18,6 +18,7 @@ use uuid::Uuid;
use crate::proto; use crate::proto;
use crate::session::SessionManager; use crate::session::SessionManager;
use crate::db::{self, Database};
use crate::AppState; use crate::AppState;
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
@@ -53,8 +54,9 @@ pub async fn agent_ws_handler(
let support_code = params.support_code; let support_code = params.support_code;
let sessions = state.sessions.clone(); let sessions = state.sessions.clone();
let support_codes = state.support_codes.clone(); let support_codes = state.support_codes.clone();
let db = state.db.clone();
ws.on_upgrade(move |socket| handle_agent_connection(socket, sessions, support_codes, agent_id, agent_name, support_code)) ws.on_upgrade(move |socket| handle_agent_connection(socket, sessions, support_codes, db, agent_id, agent_name, support_code))
} }
/// WebSocket handler for viewer connections /// WebSocket handler for viewer connections
@@ -66,8 +68,9 @@ pub async fn viewer_ws_handler(
let session_id = params.session_id; let session_id = params.session_id;
let viewer_name = params.viewer_name; let viewer_name = params.viewer_name;
let sessions = state.sessions.clone(); let sessions = state.sessions.clone();
let db = state.db.clone();
ws.on_upgrade(move |socket| handle_viewer_connection(socket, sessions, session_id, viewer_name)) ws.on_upgrade(move |socket| handle_viewer_connection(socket, sessions, db, session_id, viewer_name))
} }
/// Handle an agent WebSocket connection /// Handle an agent WebSocket connection
@@ -75,6 +78,7 @@ async fn handle_agent_connection(
socket: WebSocket, socket: WebSocket,
sessions: SessionManager, sessions: SessionManager,
support_codes: crate::support_codes::SupportCodeManager, support_codes: crate::support_codes::SupportCodeManager,
db: Option<Database>,
agent_id: String, agent_id: String,
agent_name: String, agent_name: String,
support_code: Option<String>, support_code: Option<String>,
@@ -110,11 +114,54 @@ async fn handle_agent_connection(
info!("Session created: {} (agent in idle mode)", session_id); info!("Session created: {} (agent in idle mode)", session_id);
// Database: upsert machine and create session record
let machine_id = if let Some(ref db) = db {
match db::machines::upsert_machine(db.pool(), &agent_id, &agent_name, is_persistent).await {
Ok(machine) => {
// Create session record
let _ = db::sessions::create_session(
db.pool(),
session_id,
machine.id,
support_code.is_some(),
support_code.as_deref(),
).await;
// Log session started event
let _ = db::events::log_event(
db.pool(),
session_id,
db::events::EventTypes::SESSION_STARTED,
None, None, None, None,
).await;
Some(machine.id)
}
Err(e) => {
warn!("Failed to upsert machine in database: {}", e);
None
}
}
} else {
None
};
// If a support code was provided, mark it as connected // If a support code was provided, mark it as connected
if let Some(ref code) = support_code { if let Some(ref code) = support_code {
info!("Linking support code {} to session {}", code, session_id); info!("Linking support code {} to session {}", code, session_id);
support_codes.mark_connected(code, Some(agent_name.clone()), Some(agent_id.clone())).await; support_codes.mark_connected(code, Some(agent_name.clone()), Some(agent_id.clone())).await;
support_codes.link_session(code, session_id).await; support_codes.link_session(code, session_id).await;
// Database: update support code
if let Some(ref db) = db {
let _ = db::support_codes::mark_code_connected(
db.pool(),
code,
Some(session_id),
Some(&agent_name),
Some(&agent_id),
).await;
}
} }
// Use Arc<Mutex> for sender so we can use it from multiple places // Use Arc<Mutex> for sender so we can use it from multiple places
@@ -233,10 +280,33 @@ async fn handle_agent_connection(
// Mark agent as disconnected (persistent agents stay in list as offline) // Mark agent as disconnected (persistent agents stay in list as offline)
sessions_cleanup.mark_agent_disconnected(session_id).await; sessions_cleanup.mark_agent_disconnected(session_id).await;
// Database: end session and mark machine offline
if let Some(ref db) = db {
// End the session record
let _ = db::sessions::end_session(db.pool(), session_id, "ended").await;
// Mark machine as offline
let _ = db::machines::mark_machine_offline(db.pool(), &agent_id).await;
// Log session ended event
let _ = db::events::log_event(
db.pool(),
session_id,
db::events::EventTypes::SESSION_ENDED,
None, None, None, None,
).await;
}
// Mark support code as completed if one was used (unless cancelled) // Mark support code as completed if one was used (unless cancelled)
if let Some(ref code) = support_code_cleanup { if let Some(ref code) = support_code_cleanup {
if !support_codes_cleanup.is_cancelled(code).await { if !support_codes_cleanup.is_cancelled(code).await {
support_codes_cleanup.mark_completed(code).await; support_codes_cleanup.mark_completed(code).await;
// Database: mark code as completed
if let Some(ref db) = db {
let _ = db::support_codes::mark_code_completed(db.pool(), code).await;
}
info!("Support code {} marked as completed", code); info!("Support code {} marked as completed", code);
} }
} }
@@ -248,6 +318,7 @@ async fn handle_agent_connection(
async fn handle_viewer_connection( async fn handle_viewer_connection(
socket: WebSocket, socket: WebSocket,
sessions: SessionManager, sessions: SessionManager,
db: Option<Database>,
session_id_str: String, session_id_str: String,
viewer_name: String, viewer_name: String,
) { ) {
@@ -274,6 +345,18 @@ async fn handle_viewer_connection(
info!("Viewer {} ({}) joined session: {}", viewer_name, viewer_id, session_id); info!("Viewer {} ({}) joined session: {}", viewer_name, viewer_id, session_id);
// Database: log viewer joined event
if let Some(ref db) = db {
let _ = db::events::log_event(
db.pool(),
session_id,
db::events::EventTypes::VIEWER_JOINED,
Some(&viewer_id),
Some(&viewer_name),
None, None,
).await;
}
let (mut ws_sender, mut ws_receiver) = socket.split(); let (mut ws_sender, mut ws_receiver) = socket.split();
// Task to forward frames from agent to this viewer // Task to forward frames from agent to this viewer
@@ -287,6 +370,7 @@ async fn handle_viewer_connection(
let sessions_cleanup = sessions.clone(); let sessions_cleanup = sessions.clone();
let viewer_id_cleanup = viewer_id.clone(); let viewer_id_cleanup = viewer_id.clone();
let viewer_name_cleanup = viewer_name.clone();
// Main loop: receive input from viewer and forward to agent // Main loop: receive input from viewer and forward to agent
while let Some(msg) = ws_receiver.next().await { while let Some(msg) = ws_receiver.next().await {
@@ -330,5 +414,18 @@ async fn handle_viewer_connection(
// Cleanup (this sends StopStream to agent if last viewer) // Cleanup (this sends StopStream to agent if last viewer)
frame_forward.abort(); frame_forward.abort();
sessions_cleanup.leave_session(session_id, &viewer_id_cleanup).await; sessions_cleanup.leave_session(session_id, &viewer_id_cleanup).await;
// Database: log viewer left event
if let Some(ref db) = db {
let _ = db::events::log_event(
db.pool(),
session_id,
db::events::EventTypes::VIEWER_LEFT,
Some(&viewer_id_cleanup),
Some(&viewer_name_cleanup),
None, None,
).await;
}
info!("Viewer {} left session: {}", viewer_id_cleanup, session_id); info!("Viewer {} left session: {}", viewer_id_cleanup, session_id);
} }

View File

@@ -390,3 +390,50 @@ impl Default for SessionManager {
Self::new() Self::new()
} }
} }
impl SessionManager {
/// Restore a machine as an offline session (called on startup from database)
pub async fn restore_offline_machine(&self, agent_id: &str, hostname: &str) -> SessionId {
let session_id = Uuid::new_v4();
let now = chrono::Utc::now();
let session = Session {
id: session_id,
agent_id: agent_id.to_string(),
agent_name: hostname.to_string(),
started_at: now,
viewer_count: 0,
viewers: Vec::new(),
is_streaming: false,
is_online: false, // Offline until agent reconnects
is_persistent: true,
last_heartbeat: now,
os_version: None,
is_elevated: false,
uptime_secs: 0,
display_count: 1,
};
// Create placeholder channels (will be replaced on reconnect)
let (frame_tx, _) = broadcast::channel(16);
let (input_tx, input_rx) = tokio::sync::mpsc::channel(64);
let session_data = SessionData {
info: session,
frame_tx,
input_tx,
input_rx: Some(input_rx),
viewers: HashMap::new(),
last_heartbeat_instant: Instant::now(),
};
let mut sessions = self.sessions.write().await;
sessions.insert(session_id, session_data);
let mut agents = self.agents.write().await;
agents.insert(agent_id.to_string(), session_id);
tracing::info!("Restored offline machine: {} ({})", hostname, agent_id);
session_id
}
}