Files
guru-connect/server/src/relay/mod.rs
Mike Swanson f6bf0cfd26 Add PostgreSQL database persistence
- Add connect_machines, connect_sessions, connect_session_events, connect_support_codes tables
- Implement db module with connection pooling (sqlx)
- Add machine persistence across server restarts
- Add audit logging for session/viewer events
- Support codes now persisted to database

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-28 19:51:01 -07:00

432 lines
16 KiB
Rust

//! WebSocket relay handlers
//!
//! Handles WebSocket connections from agents and viewers,
//! relaying video frames and input events between them.
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Query, State,
},
response::IntoResponse,
};
use futures_util::{SinkExt, StreamExt};
use prost::Message as ProstMessage;
use serde::Deserialize;
use tracing::{error, info, warn};
use uuid::Uuid;
use crate::proto;
use crate::session::SessionManager;
use crate::db::{self, Database};
use crate::AppState;
/// Query-string parameters accepted by the agent WebSocket endpoint
/// (extracted via `Query<AgentParams>` in `agent_ws_handler`).
#[derive(Debug, Deserialize)]
pub struct AgentParams {
/// Identifier of the connecting agent. Required (no serde default); also
/// used as the display name when neither `hostname` nor `agent_name` is given.
agent_id: String,
/// Optional display name; only consulted when `hostname` is absent.
#[serde(default)]
agent_name: Option<String>,
/// Optional support code. When absent, `handle_agent_connection` registers
/// the agent as persistent.
#[serde(default)]
support_code: Option<String>,
/// Optional hostname; takes precedence over `agent_name` as the display name.
#[serde(default)]
hostname: Option<String>,
}
/// Query-string parameters accepted by the viewer WebSocket endpoint
/// (extracted via `Query<ViewerParams>` in `viewer_ws_handler`).
#[derive(Debug, Deserialize)]
pub struct ViewerParams {
/// Session to join, as text; `handle_viewer_connection` parses it as a UUID
/// and rejects the connection if parsing fails.
session_id: String,
/// Display name of the viewer; falls back to `default_viewer_name()` when
/// the parameter is omitted.
#[serde(default = "default_viewer_name")]
viewer_name: String,
}
/// Serde default for `ViewerParams::viewer_name`: the name shown for a viewer
/// that did not supply one.
fn default_viewer_name() -> String {
    String::from("Technician")
}
/// WebSocket handler for agent connections
pub async fn agent_ws_handler(
ws: WebSocketUpgrade,
State(state): State<AppState>,
Query(params): Query<AgentParams>,
) -> impl IntoResponse {
let agent_id = params.agent_id;
let agent_name = params.hostname.or(params.agent_name).unwrap_or_else(|| agent_id.clone());
let support_code = params.support_code;
let sessions = state.sessions.clone();
let support_codes = state.support_codes.clone();
let db = state.db.clone();
ws.on_upgrade(move |socket| handle_agent_connection(socket, sessions, support_codes, db, agent_id, agent_name, support_code))
}
/// WebSocket handler for viewer connections
pub async fn viewer_ws_handler(
ws: WebSocketUpgrade,
State(state): State<AppState>,
Query(params): Query<ViewerParams>,
) -> impl IntoResponse {
let session_id = params.session_id;
let viewer_name = params.viewer_name;
let sessions = state.sessions.clone();
let db = state.db.clone();
ws.on_upgrade(move |socket| handle_viewer_connection(socket, sessions, db, session_id, viewer_name))
}
/// Handle an agent WebSocket connection
async fn handle_agent_connection(
socket: WebSocket,
sessions: SessionManager,
support_codes: crate::support_codes::SupportCodeManager,
db: Option<Database>,
agent_id: String,
agent_name: String,
support_code: Option<String>,
) {
info!("Agent connected: {} ({})", agent_name, agent_id);
let (mut ws_sender, mut ws_receiver) = socket.split();
// If a support code was provided, check if it's valid
if let Some(ref code) = support_code {
// Check if the code is cancelled or invalid
if support_codes.is_cancelled(code).await {
warn!("Agent tried to connect with cancelled code: {}", code);
// Send disconnect message to agent
let disconnect_msg = proto::Message {
payload: Some(proto::message::Payload::Disconnect(proto::Disconnect {
reason: "Support session was cancelled by technician".to_string(),
})),
};
let mut buf = Vec::new();
if prost::Message::encode(&disconnect_msg, &mut buf).is_ok() {
let _ = ws_sender.send(Message::Binary(buf.into())).await;
}
let _ = ws_sender.close().await;
return;
}
}
// Register the agent and get channels
// Persistent agents (no support code) keep their session when disconnected
let is_persistent = support_code.is_none();
let (session_id, frame_tx, mut input_rx) = sessions.register_agent(agent_id.clone(), agent_name.clone(), is_persistent).await;
info!("Session created: {} (agent in idle mode)", session_id);
// Database: upsert machine and create session record
let machine_id = if let Some(ref db) = db {
match db::machines::upsert_machine(db.pool(), &agent_id, &agent_name, is_persistent).await {
Ok(machine) => {
// Create session record
let _ = db::sessions::create_session(
db.pool(),
session_id,
machine.id,
support_code.is_some(),
support_code.as_deref(),
).await;
// Log session started event
let _ = db::events::log_event(
db.pool(),
session_id,
db::events::EventTypes::SESSION_STARTED,
None, None, None, None,
).await;
Some(machine.id)
}
Err(e) => {
warn!("Failed to upsert machine in database: {}", e);
None
}
}
} else {
None
};
// If a support code was provided, mark it as connected
if let Some(ref code) = support_code {
info!("Linking support code {} to session {}", code, session_id);
support_codes.mark_connected(code, Some(agent_name.clone()), Some(agent_id.clone())).await;
support_codes.link_session(code, session_id).await;
// Database: update support code
if let Some(ref db) = db {
let _ = db::support_codes::mark_code_connected(
db.pool(),
code,
Some(session_id),
Some(&agent_name),
Some(&agent_id),
).await;
}
}
// Use Arc<Mutex> for sender so we can use it from multiple places
let ws_sender = std::sync::Arc::new(tokio::sync::Mutex::new(ws_sender));
let ws_sender_input = ws_sender.clone();
let ws_sender_cancel = ws_sender.clone();
// Task to forward input events from viewers to agent
let input_forward = tokio::spawn(async move {
while let Some(input_data) = input_rx.recv().await {
let mut sender = ws_sender_input.lock().await;
if sender.send(Message::Binary(input_data.into())).await.is_err() {
break;
}
}
});
let sessions_cleanup = sessions.clone();
let sessions_status = sessions.clone();
let support_codes_cleanup = support_codes.clone();
let support_code_cleanup = support_code.clone();
let support_code_check = support_code.clone();
let support_codes_check = support_codes.clone();
// Task to check for cancellation every 2 seconds
let cancel_check = tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(2));
loop {
interval.tick().await;
if let Some(ref code) = support_code_check {
if support_codes_check.is_cancelled(code).await {
info!("Support code {} was cancelled, disconnecting agent", code);
// Send disconnect message
let disconnect_msg = proto::Message {
payload: Some(proto::message::Payload::Disconnect(proto::Disconnect {
reason: "Support session was cancelled by technician".to_string(),
})),
};
let mut buf = Vec::new();
if prost::Message::encode(&disconnect_msg, &mut buf).is_ok() {
let mut sender = ws_sender_cancel.lock().await;
let _ = sender.send(Message::Binary(buf.into())).await;
let _ = sender.close().await;
}
break;
}
}
}
});
// Main loop: receive messages from agent
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(Message::Binary(data)) => {
// Try to decode as protobuf message
match proto::Message::decode(data.as_ref()) {
Ok(proto_msg) => {
match &proto_msg.payload {
Some(proto::message::Payload::VideoFrame(_)) => {
// Broadcast frame to all viewers (only sent when streaming)
let _ = frame_tx.send(data.to_vec());
}
Some(proto::message::Payload::ChatMessage(chat)) => {
// Broadcast chat message to all viewers
info!("Chat from client: {}", chat.content);
let _ = frame_tx.send(data.to_vec());
}
Some(proto::message::Payload::AgentStatus(status)) => {
// Update session with agent status
sessions_status.update_agent_status(
session_id,
Some(status.os_version.clone()),
status.is_elevated,
status.uptime_secs,
status.display_count,
status.is_streaming,
).await;
info!("Agent status update: {} - streaming={}, uptime={}s",
status.hostname, status.is_streaming, status.uptime_secs);
}
Some(proto::message::Payload::Heartbeat(_)) => {
// Update heartbeat timestamp
sessions_status.update_heartbeat(session_id).await;
}
Some(proto::message::Payload::HeartbeatAck(_)) => {
// Agent acknowledged our heartbeat
sessions_status.update_heartbeat(session_id).await;
}
_ => {}
}
}
Err(e) => {
warn!("Failed to decode agent message: {}", e);
}
}
}
Ok(Message::Close(_)) => {
info!("Agent disconnected: {}", agent_id);
break;
}
Ok(Message::Ping(data)) => {
// Pong is handled automatically by axum
let _ = data;
}
Ok(_) => {}
Err(e) => {
error!("WebSocket error from agent {}: {}", agent_id, e);
break;
}
}
}
// Cleanup
input_forward.abort();
cancel_check.abort();
// Mark agent as disconnected (persistent agents stay in list as offline)
sessions_cleanup.mark_agent_disconnected(session_id).await;
// Database: end session and mark machine offline
if let Some(ref db) = db {
// End the session record
let _ = db::sessions::end_session(db.pool(), session_id, "ended").await;
// Mark machine as offline
let _ = db::machines::mark_machine_offline(db.pool(), &agent_id).await;
// Log session ended event
let _ = db::events::log_event(
db.pool(),
session_id,
db::events::EventTypes::SESSION_ENDED,
None, None, None, None,
).await;
}
// Mark support code as completed if one was used (unless cancelled)
if let Some(ref code) = support_code_cleanup {
if !support_codes_cleanup.is_cancelled(code).await {
support_codes_cleanup.mark_completed(code).await;
// Database: mark code as completed
if let Some(ref db) = db {
let _ = db::support_codes::mark_code_completed(db.pool(), code).await;
}
info!("Support code {} marked as completed", code);
}
}
info!("Session {} ended", session_id);
}
/// Handle a viewer WebSocket connection
async fn handle_viewer_connection(
socket: WebSocket,
sessions: SessionManager,
db: Option<Database>,
session_id_str: String,
viewer_name: String,
) {
// Parse session ID
let session_id = match uuid::Uuid::parse_str(&session_id_str) {
Ok(id) => id,
Err(_) => {
warn!("Invalid session ID: {}", session_id_str);
return;
}
};
// Generate unique viewer ID
let viewer_id = Uuid::new_v4().to_string();
// Join the session (this sends StartStream to agent if first viewer)
let (mut frame_rx, input_tx) = match sessions.join_session(session_id, viewer_id.clone(), viewer_name.clone()).await {
Some(channels) => channels,
None => {
warn!("Session not found: {}", session_id);
return;
}
};
info!("Viewer {} ({}) joined session: {}", viewer_name, viewer_id, session_id);
// Database: log viewer joined event
if let Some(ref db) = db {
let _ = db::events::log_event(
db.pool(),
session_id,
db::events::EventTypes::VIEWER_JOINED,
Some(&viewer_id),
Some(&viewer_name),
None, None,
).await;
}
let (mut ws_sender, mut ws_receiver) = socket.split();
// Task to forward frames from agent to this viewer
let frame_forward = tokio::spawn(async move {
while let Ok(frame_data) = frame_rx.recv().await {
if ws_sender.send(Message::Binary(frame_data.into())).await.is_err() {
break;
}
}
});
let sessions_cleanup = sessions.clone();
let viewer_id_cleanup = viewer_id.clone();
let viewer_name_cleanup = viewer_name.clone();
// Main loop: receive input from viewer and forward to agent
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(Message::Binary(data)) => {
// Try to decode as protobuf message
match proto::Message::decode(data.as_ref()) {
Ok(proto_msg) => {
match &proto_msg.payload {
Some(proto::message::Payload::MouseEvent(_)) |
Some(proto::message::Payload::KeyEvent(_)) |
Some(proto::message::Payload::SpecialKey(_)) => {
// Forward input to agent
let _ = input_tx.send(data.to_vec()).await;
}
Some(proto::message::Payload::ChatMessage(chat)) => {
// Forward chat message to agent
info!("Chat from technician: {}", chat.content);
let _ = input_tx.send(data.to_vec()).await;
}
_ => {}
}
}
Err(e) => {
warn!("Failed to decode viewer message: {}", e);
}
}
}
Ok(Message::Close(_)) => {
info!("Viewer {} disconnected from session: {}", viewer_id, session_id);
break;
}
Ok(_) => {}
Err(e) => {
error!("WebSocket error from viewer {}: {}", viewer_id, e);
break;
}
}
}
// Cleanup (this sends StopStream to agent if last viewer)
frame_forward.abort();
sessions_cleanup.leave_session(session_id, &viewer_id_cleanup).await;
// Database: log viewer left event
if let Some(ref db) = db {
let _ = db::events::log_event(
db.pool(),
session_id,
db::events::EventTypes::VIEWER_LEFT,
Some(&viewer_id_cleanup),
Some(&viewer_name_cleanup),
None, None,
).await;
}
info!("Viewer {} left session: {}", viewer_id_cleanup, session_id);
}