Files
guru-connect/server/src/relay/mod.rs
Mike Swanson 611bc00d06 Add support codes API and portal server changes
- support_codes.rs: 6-digit code management
- main.rs: Portal routes, static file serving, AppState
- relay/mod.rs: Updated for AppState
- Cargo.toml: Added rand, tower-http fs feature

Generated with Claude Code
2025-12-28 17:54:05 +00:00

202 lines
6.2 KiB
Rust

//! WebSocket relay handlers
//!
//! Handles WebSocket connections from agents and viewers,
//! relaying video frames and input events between them.
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Query, State,
},
response::IntoResponse,
};
use futures_util::{SinkExt, StreamExt};
use prost::Message as ProstMessage;
use serde::Deserialize;
use tracing::{error, info, warn};
use crate::proto;
use crate::session::SessionManager;
use crate::AppState;
/// Query-string parameters extracted by `agent_ws_handler` via `Query<AgentParams>`.
#[derive(Debug, Deserialize)]
pub struct AgentParams {
/// Identifier supplied by the connecting agent; passed on to session registration
/// and used in log messages.
agent_id: String,
/// Optional human-readable name. `#[serde(default)]` lets the parameter be omitted;
/// `agent_ws_handler` then falls back to a copy of `agent_id`.
#[serde(default)]
agent_name: Option<String>,
}
/// Query-string parameters extracted by `viewer_ws_handler` via `Query<ViewerParams>`.
#[derive(Debug, Deserialize)]
pub struct ViewerParams {
/// Session to join, as text. It is only validated later:
/// `handle_viewer_connection` parses it with `uuid::Uuid::parse_str`.
session_id: String,
}
/// WebSocket handler for agent connections
pub async fn agent_ws_handler(
ws: WebSocketUpgrade,
State(state): State<AppState>,
Query(params): Query<AgentParams>,
) -> impl IntoResponse {
let agent_id = params.agent_id;
let agent_name = params.agent_name.unwrap_or_else(|| agent_id.clone());
let sessions = state.sessions.clone();
ws.on_upgrade(move |socket| handle_agent_connection(socket, sessions, agent_id, agent_name))
}
/// WebSocket handler for viewer connections
pub async fn viewer_ws_handler(
ws: WebSocketUpgrade,
State(state): State<AppState>,
Query(params): Query<ViewerParams>,
) -> impl IntoResponse {
let session_id = params.session_id;
let sessions = state.sessions.clone();
ws.on_upgrade(move |socket| handle_viewer_connection(socket, sessions, session_id))
}
/// Handle an agent WebSocket connection.
///
/// Registers the agent with the session manager and then runs two halves
/// until the agent goes away:
///
/// * a spawned task that drains `input_rx` and writes every payload to the
///   agent as a binary WebSocket message;
/// * the main loop, which reads messages from the agent, decodes binary ones
///   as `proto::Message`, and forwards the raw bytes of video frames on
///   `frame_tx`.
///
/// The loop ends on a Close frame, a socket error, or end of stream. The
/// forwarding task is then aborted and the session is removed.
///
/// # Parameters
/// * `socket` - the upgraded WebSocket for this agent.
/// * `sessions` - session manager used to register and later remove the session.
/// * `agent_id` - identifier passed to registration and used in log lines.
/// * `agent_name` - display name passed to registration.
async fn handle_agent_connection(
    socket: WebSocket,
    sessions: SessionManager,
    agent_id: String,
    agent_name: String,
) {
    info!("Agent connected: {} ({})", agent_name, agent_id);

    // Register the agent and get channels. `agent_name` is not used again, so
    // it is moved; `agent_id` is still needed for logging and must be cloned.
    let (session_id, frame_tx, mut input_rx) =
        sessions.register_agent(agent_id.clone(), agent_name).await;
    info!("Session created: {}", session_id);

    let (mut ws_sender, mut ws_receiver) = socket.split();

    // Task to forward input events from viewers to agent. It stops when the
    // input channel yields `None` or the socket rejects a write.
    let input_forward = tokio::spawn(async move {
        while let Some(input_data) = input_rx.recv().await {
            if ws_sender
                .send(Message::Binary(input_data.into()))
                .await
                .is_err()
            {
                break;
            }
        }
    });

    // Main loop: receive frames from agent and broadcast to viewers.
    while let Some(msg) = ws_receiver.next().await {
        match msg {
            Ok(Message::Binary(data)) => {
                // Decode only to learn the payload type; the original bytes
                // are what gets relayed.
                match proto::Message::decode(data.as_ref()) {
                    Ok(proto_msg) => {
                        if matches!(
                            proto_msg.payload,
                            Some(proto::message::Payload::VideoFrame(_))
                        ) {
                            // The send result is deliberately ignored
                            // (presumably it only fails when no viewer is
                            // subscribed - confirm against the session module).
                            let _ = frame_tx.send(data.to_vec());
                        }
                    }
                    Err(e) => {
                        warn!("Failed to decode agent message: {}", e);
                    }
                }
            }
            Ok(Message::Close(_)) => {
                info!("Agent disconnected: {}", agent_id);
                break;
            }
            // Text, Ping and Pong need no action here; per the original
            // author's note, Pong replies are handled automatically by axum.
            Ok(_) => {}
            Err(e) => {
                error!("WebSocket error from agent {}: {}", agent_id, e);
                break;
            }
        }
    }

    // Cleanup: stop the forwarding task and drop the session.
    input_forward.abort();
    sessions.remove_session(session_id).await;
    info!("Session {} ended", session_id);
}
/// Handle a viewer WebSocket connection
async fn handle_viewer_connection(
socket: WebSocket,
sessions: SessionManager,
session_id_str: String,
) {
// Parse session ID
let session_id = match uuid::Uuid::parse_str(&session_id_str) {
Ok(id) => id,
Err(_) => {
warn!("Invalid session ID: {}", session_id_str);
return;
}
};
// Join the session
let (mut frame_rx, input_tx) = match sessions.join_session(session_id).await {
Some(channels) => channels,
None => {
warn!("Session not found: {}", session_id);
return;
}
};
info!("Viewer joined session: {}", session_id);
let (mut ws_sender, mut ws_receiver) = socket.split();
// Task to forward frames from agent to this viewer
let frame_forward = tokio::spawn(async move {
while let Ok(frame_data) = frame_rx.recv().await {
if ws_sender.send(Message::Binary(frame_data.into())).await.is_err() {
break;
}
}
});
let sessions_cleanup = sessions.clone();
// Main loop: receive input from viewer and forward to agent
while let Some(msg) = ws_receiver.next().await {
match msg {
Ok(Message::Binary(data)) => {
// Try to decode as protobuf message
match proto::Message::decode(data.as_ref()) {
Ok(proto_msg) => {
match &proto_msg.payload {
Some(proto::message::Payload::MouseEvent(_)) |
Some(proto::message::Payload::KeyEvent(_)) => {
// Forward input to agent
let _ = input_tx.send(data.to_vec()).await;
}
_ => {}
}
}
Err(e) => {
warn!("Failed to decode viewer message: {}", e);
}
}
}
Ok(Message::Close(_)) => {
info!("Viewer disconnected from session: {}", session_id);
break;
}
Ok(_) => {}
Err(e) => {
error!("WebSocket error from viewer: {}", e);
break;
}
}
}
// Cleanup
frame_forward.abort();
sessions_cleanup.leave_session(session_id).await;
info!("Viewer left session: {}", session_id);
}