style: cargo fmt --all — make codebase rustfmt-clean
Some checks failed
Build and Test / Build Server (Linux) (push) Failing after 2m59s
Build and Test / Build Agent (Windows) (push) Has started running
Build and Test / Security Audit (push) Has been cancelled
Build and Test / Build Summary (push) Has been cancelled
Run Tests / Test Server (push) Has been cancelled
Run Tests / Test Agent (push) Has been cancelled
Run Tests / Code Coverage (push) Has been cancelled
Run Tests / Lint and Format Check (push) Has been cancelled

First run of the build-and-test CI gate (cargo fmt --all -- --check) surfaced
pre-existing formatting drift across the agent and server crates. Apply rustfmt
across the workspace so the codebase meets its own CI gate. Pure formatting; no
logic changes.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-29 15:02:12 +00:00
parent f2e0456f8d
commit 1c5c1e78e7
48 changed files with 1174 additions and 797 deletions

View File

@@ -1,15 +1,13 @@
//! Authentication API endpoints
use axum::{
extract::{State, Request},
extract::{Request, State},
http::StatusCode,
Json,
};
use serde::{Deserialize, Serialize};
use crate::auth::{
verify_password, AuthenticatedUser, JwtConfig,
};
use crate::auth::{verify_password, AuthenticatedUser, JwtConfig};
use crate::db;
use crate::AppState;
@@ -89,16 +87,15 @@ pub async fn login(
}
// Verify password
let password_valid = verify_password(&request.password, &user.password_hash)
.map_err(|e| {
tracing::error!("Password verification error: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Internal server error".to_string(),
}),
)
})?;
let password_valid = verify_password(&request.password, &user.password_hash).map_err(|e| {
tracing::error!("Password verification error: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Internal server error".to_string(),
}),
)
})?;
if !password_valid {
return Err((
@@ -118,21 +115,18 @@ pub async fn login(
let _ = db::update_last_login(db.pool(), user.id).await;
// Create JWT token
let token = state.jwt_config.create_token(
user.id,
&user.username,
&user.role,
permissions.clone(),
)
.map_err(|e| {
tracing::error!("Token creation error: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Failed to create token".to_string(),
}),
)
})?;
let token = state
.jwt_config
.create_token(user.id, &user.username, &user.role, permissions.clone())
.map_err(|e| {
tracing::error!("Token creation error: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Failed to create token".to_string(),
}),
)
})?;
tracing::info!("User {} logged in successfully", user.username);
@@ -288,16 +282,15 @@ pub async fn change_password(
}
// Hash new password
let new_hash = crate::auth::hash_password(&request.new_password)
.map_err(|e| {
tracing::error!("Password hashing error: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Failed to hash password".to_string(),
}),
)
})?;
let new_hash = crate::auth::hash_password(&request.new_password).map_err(|e| {
tracing::error!("Password hashing error: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Failed to hash password".to_string(),
}),
)
})?;
// Update password
db::update_user_password(db.pool(), user_id, &new_hash)

View File

@@ -1,13 +1,13 @@
//! Logout and token revocation endpoints
use axum::{
extract::{Request, State, Path},
http::{StatusCode, HeaderMap},
extract::{Path, Request, State},
http::{HeaderMap, StatusCode},
Json,
};
use uuid::Uuid;
use serde::Serialize;
use tracing::{info, warn};
use uuid::Uuid;
use crate::auth::AuthenticatedUser;
use crate::AppState;
@@ -15,7 +15,9 @@ use crate::AppState;
use super::auth::ErrorResponse;
/// Extract JWT token from Authorization header
fn extract_token_from_headers(headers: &HeaderMap) -> Result<String, (StatusCode, Json<ErrorResponse>)> {
fn extract_token_from_headers(
headers: &HeaderMap,
) -> Result<String, (StatusCode, Json<ErrorResponse>)> {
let auth_header = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
@@ -28,16 +30,14 @@ fn extract_token_from_headers(headers: &HeaderMap) -> Result<String, (StatusCode
)
})?;
let token = auth_header
.strip_prefix("Bearer ")
.ok_or_else(|| {
(
StatusCode::UNAUTHORIZED,
Json(ErrorResponse {
error: "Invalid Authorization format".to_string(),
}),
)
})?;
let token = auth_header.strip_prefix("Bearer ").ok_or_else(|| {
(
StatusCode::UNAUTHORIZED,
Json(ErrorResponse {
error: "Invalid Authorization format".to_string(),
}),
)
})?;
Ok(token.to_string())
}
@@ -124,7 +124,8 @@ pub async fn revoke_user_tokens(
Err((
StatusCode::NOT_IMPLEMENTED,
Json(ErrorResponse {
error: "User token revocation not yet implemented - requires session tracking table".to_string(),
error: "User token revocation not yet implemented - requires session tracking table"
.to_string(),
}),
))
}
@@ -179,10 +180,16 @@ pub async fn cleanup_blacklist(
));
}
let removed = state.token_blacklist.cleanup_expired(&state.jwt_config).await;
let removed = state
.token_blacklist
.cleanup_expired(&state.jwt_config)
.await;
let remaining = state.token_blacklist.len().await;
info!("Admin {} cleaned up blacklist: {} tokens removed, {} remaining", admin.username, removed, remaining);
info!(
"Admin {} cleaned up blacklist: {} tokens removed, {} remaining",
admin.username, removed, remaining
);
Ok(Json(CleanupResponse {
removed_count: removed,

View File

@@ -13,7 +13,7 @@ use axum::{
};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use tracing::{info, warn, error};
use tracing::{error, info, warn};
/// Magic marker for embedded configuration (must match agent)
const MAGIC_MARKER: &[u8] = b"GURUCONFIG";
@@ -87,7 +87,7 @@ pub async fn download_viewer() -> impl IntoResponse {
.header(header::CONTENT_TYPE, "application/octet-stream")
.header(
header::CONTENT_DISPOSITION,
"attachment; filename=\"GuruConnect-Viewer.exe\""
"attachment; filename=\"GuruConnect-Viewer.exe\"",
)
.header(header::CONTENT_LENGTH, binary_data.len())
.body(Body::from(binary_data))
@@ -104,9 +104,7 @@ pub async fn download_viewer() -> impl IntoResponse {
}
/// Download support session binary (code embedded in filename)
pub async fn download_support(
Query(params): Query<SupportDownloadParams>,
) -> impl IntoResponse {
pub async fn download_support(Query(params): Query<SupportDownloadParams>) -> impl IntoResponse {
// Validate support code (must be 6 digits)
let code = params.code.trim();
if code.len() != 6 || !code.chars().all(|c| c.is_ascii_digit()) {
@@ -120,7 +118,11 @@ pub async fn download_support(
match std::fs::read(&binary_path) {
Ok(binary_data) => {
info!("Serving support session download for code {} ({} bytes)", code, binary_data.len());
info!(
"Serving support session download for code {} ({} bytes)",
code,
binary_data.len()
);
// Filename includes the support code
let filename = format!("GuruConnect-{}.exe", code);
@@ -130,7 +132,7 @@ pub async fn download_support(
.header(header::CONTENT_TYPE, "application/octet-stream")
.header(
header::CONTENT_DISPOSITION,
format!("attachment; filename=\"{}\"", filename)
format!("attachment; filename=\"{}\"", filename),
)
.header(header::CONTENT_LENGTH, binary_data.len())
.body(Body::from(binary_data))
@@ -147,9 +149,7 @@ pub async fn download_support(
}
/// Download permanent agent binary with embedded configuration
pub async fn download_agent(
Query(params): Query<AgentDownloadParams>,
) -> impl IntoResponse {
pub async fn download_agent(Query(params): Query<AgentDownloadParams>) -> impl IntoResponse {
let binary_path = get_base_binary_path();
// Read base binary
@@ -167,10 +167,13 @@ pub async fn download_agent(
// Build embedded config
let config = EmbeddedConfig {
server_url: "wss://connect.azcomputerguru.com/ws/agent".to_string(),
api_key: params.api_key.unwrap_or_else(|| "managed-agent".to_string()),
api_key: params
.api_key
.unwrap_or_else(|| "managed-agent".to_string()),
company: params.company.clone(),
site: params.site.clone(),
tags: params.tags
tags: params
.tags
.as_ref()
.map(|t| t.split(',').map(|s| s.trim().to_string()).collect())
.unwrap_or_default(),
@@ -196,18 +199,25 @@ pub async fn download_agent(
info!(
"Serving permanent agent download: company={:?}, site={:?}, tags={:?} ({} bytes)",
config.company, config.site, config.tags, binary_data.len()
config.company,
config.site,
config.tags,
binary_data.len()
);
// Generate filename based on company/site
let filename = match (&params.company, &params.site) {
(Some(company), Some(site)) => {
format!("GuruConnect-{}-{}-Setup.exe", sanitize_filename(company), sanitize_filename(site))
format!(
"GuruConnect-{}-{}-Setup.exe",
sanitize_filename(company),
sanitize_filename(site)
)
}
(Some(company), None) => {
format!("GuruConnect-{}-Setup.exe", sanitize_filename(company))
}
_ => "GuruConnect-Setup.exe".to_string()
_ => "GuruConnect-Setup.exe".to_string(),
};
Response::builder()
@@ -215,7 +225,7 @@ pub async fn download_agent(
.header(header::CONTENT_TYPE, "application/octet-stream")
.header(
header::CONTENT_DISPOSITION,
format!("attachment; filename=\"{}\"", filename)
format!("attachment; filename=\"{}\"", filename),
)
.header(header::CONTENT_LENGTH, binary_data.len())
.body(Body::from(binary_data))

View File

@@ -3,19 +3,19 @@
pub mod auth;
pub mod auth_logout;
pub mod changelog;
pub mod users;
pub mod releases;
pub mod downloads;
pub mod releases;
pub mod users;
use axum::{
extract::{Path, State, Query},
extract::{Path, Query, State},
Json,
};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::session::SessionManager;
use crate::db;
use crate::session::SessionManager;
/// Viewer info returned by API
#[derive(Debug, Serialize)]
@@ -78,9 +78,7 @@ impl From<crate::session::Session> for SessionInfo {
}
/// List all active sessions
pub async fn list_sessions(
State(sessions): State<SessionManager>,
) -> Json<Vec<SessionInfo>> {
pub async fn list_sessions(State(sessions): State<SessionManager>) -> Json<Vec<SessionInfo>> {
let sessions = sessions.list_sessions().await;
Json(sessions.into_iter().map(SessionInfo::from).collect())
}
@@ -93,7 +91,9 @@ pub async fn get_session(
let session_id = Uuid::parse_str(&id)
.map_err(|_| (axum::http::StatusCode::BAD_REQUEST, "Invalid session ID"))?;
let session = sessions.get_session(session_id).await
let session = sessions
.get_session(session_id)
.await
.ok_or((axum::http::StatusCode::NOT_FOUND, "Session not found"))?;
Ok(Json(SessionInfo::from(session)))

View File

@@ -129,17 +129,15 @@ pub async fn list_releases(
)
})?;
let releases = db::get_all_releases(db.pool())
.await
.map_err(|e| {
tracing::error!("Database error: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Failed to fetch releases".to_string(),
}),
)
})?;
let releases = db::get_all_releases(db.pool()).await.map_err(|e| {
tracing::error!("Database error: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Failed to fetch releases".to_string(),
}),
)
})?;
Ok(Json(releases.into_iter().map(ReleaseInfo::from).collect()))
}
@@ -171,7 +169,10 @@ pub async fn create_release(
// Validate checksum format (64 hex chars for SHA-256)
if request.checksum_sha256.len() != 64
|| !request.checksum_sha256.chars().all(|c| c.is_ascii_hexdigit())
|| !request
.checksum_sha256
.chars()
.all(|c| c.is_ascii_hexdigit())
{
return Err((
StatusCode::BAD_REQUEST,
@@ -349,17 +350,15 @@ pub async fn delete_release(
)
})?;
let deleted = db::delete_release(db.pool(), &version)
.await
.map_err(|e| {
tracing::error!("Database error: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Failed to delete release".to_string(),
}),
)
})?;
let deleted = db::delete_release(db.pool(), &version).await.map_err(|e| {
tracing::error!("Database error: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Failed to delete release".to_string(),
}),
)
})?;
if deleted {
tracing::info!("Deleted release: {}", version);

View File

@@ -72,17 +72,15 @@ pub async fn list_users(
)
})?;
let users = db::get_all_users(db.pool())
.await
.map_err(|e| {
tracing::error!("Database error: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Failed to fetch users".to_string(),
}),
)
})?;
let users = db::get_all_users(db.pool()).await.map_err(|e| {
tracing::error!("Database error: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Failed to fetch users".to_string(),
}),
)
})?;
let mut result = Vec::new();
for user in users {
@@ -210,7 +208,13 @@ pub async fn create_user(
} else {
// Default permissions based on role
let default_perms = match request.role.as_str() {
"admin" => vec!["view", "control", "transfer", "manage_users", "manage_clients"],
"admin" => vec![
"view",
"control",
"transfer",
"manage_users",
"manage_clients",
],
"operator" => vec!["view", "control", "transfer"],
"viewer" => vec!["view"],
_ => vec!["view"],
@@ -455,17 +459,15 @@ pub async fn delete_user(
));
}
let deleted = db::delete_user(db.pool(), user_id)
.await
.map_err(|e| {
tracing::error!("Database error: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Failed to delete user".to_string(),
}),
)
})?;
let deleted = db::delete_user(db.pool(), user_id).await.map_err(|e| {
tracing::error!("Database error: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Failed to delete user".to_string(),
}),
)
})?;
if deleted {
tracing::info!("Deleted user: {}", id);
@@ -506,13 +508,22 @@ pub async fn set_permissions(
})?;
// Validate permissions
let valid_permissions = ["view", "control", "transfer", "manage_users", "manage_clients"];
let valid_permissions = [
"view",
"control",
"transfer",
"manage_users",
"manage_clients",
];
for perm in &request.permissions {
if !valid_permissions.contains(&perm.as_str()) {
return Err((
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!("Invalid permission: {}. Valid: {:?}", perm, valid_permissions),
error: format!(
"Invalid permission: {}. Valid: {:?}",
perm, valid_permissions
),
}),
));
}

View File

@@ -54,7 +54,10 @@ pub struct JwtConfig {
impl JwtConfig {
/// Create new JWT config
pub fn new(secret: String, expiry_hours: i64) -> Self {
Self { secret, expiry_hours }
Self {
secret,
expiry_hours,
}
}
/// Create a JWT token for a user
@@ -97,9 +100,9 @@ impl JwtConfig {
pub fn validate_token(&self, token: &str) -> Result<Claims> {
// SEC-13: Explicit validation configuration
let mut validation = Validation::default();
validation.validate_exp = true; // Enforce expiration check
validation.validate_exp = true; // Enforce expiration check
validation.validate_nbf = false; // Not using "not before" claim
validation.leeway = 0; // No clock skew tolerance
validation.leeway = 0; // No clock skew tolerance
let token_data = decode::<Claims>(
token,
@@ -129,12 +132,14 @@ mod tests {
let config = JwtConfig::new("test-secret".to_string(), 24);
let user_id = Uuid::new_v4();
let token = config.create_token(
user_id,
"testuser",
"admin",
vec!["view".to_string(), "control".to_string()],
).unwrap();
let token = config
.create_token(
user_id,
"testuser",
"admin",
vec!["view".to_string(), "control".to_string()],
)
.unwrap();
let claims = config.validate_token(&token).unwrap();
assert_eq!(claims.username, "testuser");

View File

@@ -8,7 +8,7 @@ pub mod password;
pub mod token_blacklist;
pub use jwt::{Claims, JwtConfig};
pub use password::{hash_password, verify_password, generate_random_password};
pub use password::{generate_random_password, hash_password, verify_password};
pub use token_blacklist::TokenBlacklist;
use axum::{

View File

@@ -6,7 +6,7 @@
use anyhow::{anyhow, Result};
use argon2::{
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Argon2, Algorithm, Version, Params,
Algorithm, Argon2, Params, Version,
};
/// Hash a password using Argon2id
@@ -22,9 +22,9 @@ pub fn hash_password(password: &str) -> Result<String> {
// Explicitly use Argon2id (Algorithm::Argon2id)
let argon2 = Argon2::new(
Algorithm::Argon2id, // SEC-9: Explicit Argon2id variant
Version::V0x13, // Latest version
Params::default(), // Default params (19456 KiB, 2 iterations, 1 parallelism)
Algorithm::Argon2id, // SEC-9: Explicit Argon2id variant
Version::V0x13, // Latest version
Params::default(), // Default params (19456 KiB, 2 iterations, 1 parallelism)
);
let hash = argon2
@@ -35,12 +35,14 @@ pub fn hash_password(password: &str) -> Result<String> {
/// Verify a password against a stored hash
pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
let parsed_hash = PasswordHash::new(hash)
.map_err(|e| anyhow!("Invalid password hash format: {}", e))?;
let parsed_hash =
PasswordHash::new(hash).map_err(|e| anyhow!("Invalid password hash format: {}", e))?;
// Argon2::default() uses Argon2id, but we verify against the hash's embedded algorithm
let argon2 = Argon2::default();
Ok(argon2.verify_password(password.as_bytes(), &parsed_hash).is_ok())
Ok(argon2
.verify_password(password.as_bytes(), &parsed_hash)
.is_ok())
}
/// Generate a random password (for initial admin)

View File

@@ -6,7 +6,7 @@
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{info, debug};
use tracing::{debug, info};
/// Token blacklist for revocation
///
@@ -41,7 +41,10 @@ impl TokenBlacklist {
let was_new = tokens.insert(token.to_string());
if was_new {
debug!("Token revoked and added to blacklist (length: {})", token.len());
debug!(
"Token revoked and added to blacklist (length: {})",
token.len()
);
}
}
@@ -92,7 +95,11 @@ impl TokenBlacklist {
let removed = original_len - tokens.len();
if removed > 0 {
info!("Cleaned {} expired tokens from blacklist ({} remaining)", removed, tokens.len());
info!(
"Cleaned {} expired tokens from blacklist ({} remaining)",
removed,
tokens.len()
);
}
removed

View File

@@ -36,8 +36,10 @@ impl EventTypes {
pub const CONNECTION_REJECTED_NO_AUTH: &'static str = "connection_rejected_no_auth";
pub const CONNECTION_REJECTED_INVALID_CODE: &'static str = "connection_rejected_invalid_code";
pub const CONNECTION_REJECTED_EXPIRED_CODE: &'static str = "connection_rejected_expired_code";
pub const CONNECTION_REJECTED_INVALID_API_KEY: &'static str = "connection_rejected_invalid_api_key";
pub const CONNECTION_REJECTED_CANCELLED_CODE: &'static str = "connection_rejected_cancelled_code";
pub const CONNECTION_REJECTED_INVALID_API_KEY: &'static str =
"connection_rejected_invalid_api_key";
pub const CONNECTION_REJECTED_CANCELLED_CODE: &'static str =
"connection_rejected_cancelled_code";
}
/// Log a session event

View File

@@ -80,7 +80,7 @@ pub async fn update_machine_status(
/// Get all persistent machines (for restore on startup)
pub async fn get_all_machines(pool: &PgPool) -> Result<Vec<Machine>, sqlx::Error> {
sqlx::query_as::<_, Machine>(
"SELECT * FROM connect_machines WHERE is_persistent = true ORDER BY hostname"
"SELECT * FROM connect_machines WHERE is_persistent = true ORDER BY hostname",
)
.fetch_all(pool)
.await
@@ -91,20 +91,20 @@ pub async fn get_machine_by_agent_id(
pool: &PgPool,
agent_id: &str,
) -> Result<Option<Machine>, sqlx::Error> {
sqlx::query_as::<_, Machine>(
"SELECT * FROM connect_machines WHERE agent_id = $1"
)
.bind(agent_id)
.fetch_optional(pool)
.await
sqlx::query_as::<_, Machine>("SELECT * FROM connect_machines WHERE agent_id = $1")
.bind(agent_id)
.fetch_optional(pool)
.await
}
/// Mark machine as offline
pub async fn mark_machine_offline(pool: &PgPool, agent_id: &str) -> Result<(), sqlx::Error> {
sqlx::query("UPDATE connect_machines SET status = 'offline', last_seen = NOW() WHERE agent_id = $1")
.bind(agent_id)
.execute(pool)
.await?;
sqlx::query(
"UPDATE connect_machines SET status = 'offline', last_seen = NOW() WHERE agent_id = $1",
)
.bind(agent_id)
.execute(pool)
.await?;
Ok(())
}

View File

@@ -3,24 +3,24 @@
//! Handles persistence for machines, sessions, and audit logging.
//! Optional - server works without database if DATABASE_URL not set.
pub mod machines;
pub mod sessions;
pub mod events;
pub mod machines;
pub mod releases;
pub mod sessions;
pub mod support_codes;
pub mod users;
pub mod releases;
use anyhow::Result;
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
use tracing::info;
pub use machines::*;
pub use sessions::*;
pub use events::*;
pub use machines::*;
pub use releases::*;
pub use sessions::*;
pub use support_codes::*;
pub use users::*;
pub use releases::*;
/// Database connection pool wrapper
#[derive(Clone)]

View File

@@ -45,7 +45,7 @@ pub async fn create_session(
pub async fn end_session(
pool: &PgPool,
session_id: Uuid,
status: &str, // 'ended' or 'disconnected' or 'timeout'
status: &str, // 'ended' or 'disconnected' or 'timeout'
) -> Result<(), sqlx::Error> {
sqlx::query(
r#"
@@ -64,7 +64,10 @@ pub async fn end_session(
}
/// Get session by ID
pub async fn get_session(pool: &PgPool, session_id: Uuid) -> Result<Option<DbSession>, sqlx::Error> {
pub async fn get_session(
pool: &PgPool,
session_id: Uuid,
) -> Result<Option<DbSession>, sqlx::Error> {
sqlx::query_as::<_, DbSession>("SELECT * FROM connect_sessions WHERE id = $1")
.bind(session_id)
.fetch_optional(pool)
@@ -85,12 +88,9 @@ pub async fn get_active_sessions_for_machine(
}
/// Get recent sessions (for dashboard)
pub async fn get_recent_sessions(
pool: &PgPool,
limit: i64,
) -> Result<Vec<DbSession>, sqlx::Error> {
pub async fn get_recent_sessions(pool: &PgPool, limit: i64) -> Result<Vec<DbSession>, sqlx::Error> {
sqlx::query_as::<_, DbSession>(
"SELECT * FROM connect_sessions ORDER BY started_at DESC LIMIT $1"
"SELECT * FROM connect_sessions ORDER BY started_at DESC LIMIT $1",
)
.bind(limit)
.fetch_all(pool)
@@ -103,7 +103,7 @@ pub async fn get_sessions_for_machine(
machine_id: Uuid,
) -> Result<Vec<DbSession>, sqlx::Error> {
sqlx::query_as::<_, DbSession>(
"SELECT * FROM connect_sessions WHERE machine_id = $1 ORDER BY started_at DESC"
"SELECT * FROM connect_sessions WHERE machine_id = $1 ORDER BY started_at DESC",
)
.bind(machine_id)
.fetch_all(pool)

View File

@@ -40,13 +40,14 @@ pub async fn create_support_code(
}
/// Get support code by code string
pub async fn get_support_code(pool: &PgPool, code: &str) -> Result<Option<DbSupportCode>, sqlx::Error> {
sqlx::query_as::<_, DbSupportCode>(
"SELECT * FROM connect_support_codes WHERE code = $1"
)
.bind(code)
.fetch_optional(pool)
.await
pub async fn get_support_code(
pool: &PgPool,
code: &str,
) -> Result<Option<DbSupportCode>, sqlx::Error> {
sqlx::query_as::<_, DbSupportCode>("SELECT * FROM connect_support_codes WHERE code = $1")
.bind(code)
.fetch_optional(pool)
.await
}
/// Update support code when client connects
@@ -107,7 +108,7 @@ pub async fn get_active_support_codes(pool: &PgPool) -> Result<Vec<DbSupportCode
/// Check if code exists and is valid for connection
pub async fn is_code_valid(pool: &PgPool, code: &str) -> Result<bool, sqlx::Error> {
let result = sqlx::query_scalar::<_, bool>(
"SELECT EXISTS(SELECT 1 FROM connect_support_codes WHERE code = $1 AND status = 'pending')"
"SELECT EXISTS(SELECT 1 FROM connect_support_codes WHERE code = $1 AND status = 'pending')",
)
.bind(code)
.fetch_one(pool)

View File

@@ -49,33 +49,27 @@ impl From<User> for UserInfo {
/// Get user by username
pub async fn get_user_by_username(pool: &PgPool, username: &str) -> Result<Option<User>> {
let user = sqlx::query_as::<_, User>(
"SELECT * FROM users WHERE username = $1"
)
.bind(username)
.fetch_optional(pool)
.await?;
let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE username = $1")
.bind(username)
.fetch_optional(pool)
.await?;
Ok(user)
}
/// Get user by ID
pub async fn get_user_by_id(pool: &PgPool, id: Uuid) -> Result<Option<User>> {
let user = sqlx::query_as::<_, User>(
"SELECT * FROM users WHERE id = $1"
)
.bind(id)
.fetch_optional(pool)
.await?;
let user = sqlx::query_as::<_, User>("SELECT * FROM users WHERE id = $1")
.bind(id)
.fetch_optional(pool)
.await?;
Ok(user)
}
/// Get all users
pub async fn get_all_users(pool: &PgPool) -> Result<Vec<User>> {
let users = sqlx::query_as::<_, User>(
"SELECT * FROM users ORDER BY username"
)
.fetch_all(pool)
.await?;
let users = sqlx::query_as::<_, User>("SELECT * FROM users ORDER BY username")
.fetch_all(pool)
.await?;
Ok(users)
}
@@ -92,7 +86,7 @@ pub async fn create_user(
INSERT INTO users (username, password_hash, email, role)
VALUES ($1, $2, $3, $4)
RETURNING *
"#
"#,
)
.bind(username)
.bind(password_hash)
@@ -117,7 +111,7 @@ pub async fn update_user(
SET email = $2, role = $3, enabled = $4, updated_at = NOW()
WHERE id = $1
RETURNING *
"#
"#,
)
.bind(id)
.bind(email)
@@ -129,18 +123,13 @@ pub async fn update_user(
}
/// Update user password
pub async fn update_user_password(
pool: &PgPool,
id: Uuid,
password_hash: &str,
) -> Result<bool> {
let result = sqlx::query(
"UPDATE users SET password_hash = $2, updated_at = NOW() WHERE id = $1"
)
.bind(id)
.bind(password_hash)
.execute(pool)
.await?;
pub async fn update_user_password(pool: &PgPool, id: Uuid, password_hash: &str) -> Result<bool> {
let result =
sqlx::query("UPDATE users SET password_hash = $2, updated_at = NOW() WHERE id = $1")
.bind(id)
.bind(password_hash)
.execute(pool)
.await?;
Ok(result.rows_affected() > 0)
}
@@ -172,12 +161,11 @@ pub async fn count_users(pool: &PgPool) -> Result<i64> {
/// Get user permissions
pub async fn get_user_permissions(pool: &PgPool, user_id: Uuid) -> Result<Vec<String>> {
let perms: Vec<(String,)> = sqlx::query_as(
"SELECT permission FROM user_permissions WHERE user_id = $1"
)
.bind(user_id)
.fetch_all(pool)
.await?;
let perms: Vec<(String,)> =
sqlx::query_as("SELECT permission FROM user_permissions WHERE user_id = $1")
.bind(user_id)
.fetch_all(pool)
.await?;
Ok(perms.into_iter().map(|p| p.0).collect())
}
@@ -195,25 +183,22 @@ pub async fn set_user_permissions(
// Insert new
for perm in permissions {
sqlx::query(
"INSERT INTO user_permissions (user_id, permission) VALUES ($1, $2)"
)
.bind(user_id)
.bind(perm)
.execute(pool)
.await?;
sqlx::query("INSERT INTO user_permissions (user_id, permission) VALUES ($1, $2)")
.bind(user_id)
.bind(perm)
.execute(pool)
.await?;
}
Ok(())
}
/// Get user's accessible client IDs (empty = all access)
pub async fn get_user_client_access(pool: &PgPool, user_id: Uuid) -> Result<Vec<Uuid>> {
let clients: Vec<(Uuid,)> = sqlx::query_as(
"SELECT client_id FROM user_client_access WHERE user_id = $1"
)
.bind(user_id)
.fetch_all(pool)
.await?;
let clients: Vec<(Uuid,)> =
sqlx::query_as("SELECT client_id FROM user_client_access WHERE user_id = $1")
.bind(user_id)
.fetch_all(pool)
.await?;
Ok(clients.into_iter().map(|c| c.0).collect())
}
@@ -231,23 +216,17 @@ pub async fn set_user_client_access(
// Insert new
for client_id in client_ids {
sqlx::query(
"INSERT INTO user_client_access (user_id, client_id) VALUES ($1, $2)"
)
.bind(user_id)
.bind(client_id)
.execute(pool)
.await?;
sqlx::query("INSERT INTO user_client_access (user_id, client_id) VALUES ($1, $2)")
.bind(user_id)
.bind(client_id)
.execute(pool)
.await?;
}
Ok(())
}
/// Check if user has access to a specific client
pub async fn user_has_client_access(
pool: &PgPool,
user_id: Uuid,
client_id: Uuid,
) -> Result<bool> {
pub async fn user_has_client_access(pool: &PgPool, user_id: Uuid, client_id: Uuid) -> Result<bool> {
// Admins have access to all
let user = get_user_by_id(pool, user_id).await?;
if let Some(u) = user {
@@ -258,7 +237,7 @@ pub async fn user_has_client_access(
// Check explicit access
let access: Option<(Uuid,)> = sqlx::query_as(
"SELECT client_id FROM user_client_access WHERE user_id = $1 AND client_id = $2"
"SELECT client_id FROM user_client_access WHERE user_id = $1 AND client_id = $2",
)
.bind(user_id)
.bind(client_id)
@@ -271,12 +250,11 @@ pub async fn user_has_client_access(
}
// Check if user has ANY access restrictions
let count: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM user_client_access WHERE user_id = $1"
)
.bind(user_id)
.fetch_one(pool)
.await?;
let count: (i64,) =
sqlx::query_as("SELECT COUNT(*) FROM user_client_access WHERE user_id = $1")
.bind(user_id)
.fetch_one(pool)
.await?;
// No restrictions means access to all
Ok(count.0 == 0)

View File

@@ -3,44 +3,44 @@
//! Handles connections from both agents and dashboard viewers,
//! relaying video frames and input events between them.
mod api;
mod auth;
mod config;
mod db;
mod metrics;
mod middleware;
mod relay;
mod session;
mod auth;
mod api;
mod db;
mod support_codes;
mod middleware;
mod utils;
mod metrics;
pub mod proto {
include!(concat!(env!("OUT_DIR"), "/guruconnect.rs"));
}
use anyhow::Result;
use axum::http::{HeaderValue, Method};
use axum::{
Router,
routing::{get, post, put, delete},
extract::{Path, State, Json, Query, Request},
response::{Html, IntoResponse},
extract::{Json, Path, Query, Request, State},
http::StatusCode,
middleware::{self as axum_middleware, Next},
response::{Html, IntoResponse},
routing::{delete, get, post, put},
Router,
};
use serde::Deserialize;
use std::net::SocketAddr;
use std::sync::Arc;
use tower_http::cors::{Any, CorsLayer, AllowOrigin};
use axum::http::{Method, HeaderValue};
use tower_http::trace::TraceLayer;
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
use tower_http::services::ServeDir;
use tower_http::trace::TraceLayer;
use tracing::{info, Level};
use tracing_subscriber::FmtSubscriber;
use serde::Deserialize;
use support_codes::{SupportCodeManager, CreateCodeRequest, SupportCode, CodeValidation};
use auth::{JwtConfig, TokenBlacklist, hash_password, generate_random_password, AuthenticatedUser};
use auth::{generate_random_password, hash_password, AuthenticatedUser, JwtConfig, TokenBlacklist};
use metrics::SharedMetrics;
use prometheus_client::registry::Registry;
use support_codes::{CodeValidation, CreateCodeRequest, SupportCode, SupportCodeManager};
/// Application state
#[derive(Clone)]
@@ -67,7 +67,9 @@ async fn auth_layer(
next: Next,
) -> impl IntoResponse {
request.extensions_mut().insert(state.jwt_config.clone());
request.extensions_mut().insert(Arc::new(state.token_blacklist.clone()));
request
.extensions_mut()
.insert(Arc::new(state.token_blacklist.clone()));
next.run(request).await
}
@@ -89,8 +91,9 @@ async fn main() -> Result<()> {
info!("Loaded configuration, listening on {}", listen_addr);
// JWT configuration - REQUIRED for security
let jwt_secret = std::env::var("JWT_SECRET")
.expect("JWT_SECRET environment variable must be set! Generate one with: openssl rand -base64 64");
let jwt_secret = std::env::var("JWT_SECRET").expect(
"JWT_SECRET environment variable must be set! Generate one with: openssl rand -base64 64",
);
if jwt_secret.len() < 32 {
panic!("JWT_SECRET must be at least 32 characters long for security!");
@@ -114,7 +117,10 @@ async fn main() -> Result<()> {
Some(db)
}
Err(e) => {
tracing::warn!("Failed to connect to database: {}. Running without persistence.", e);
tracing::warn!(
"Failed to connect to database: {}. Running without persistence.",
e
);
None
}
}
@@ -194,9 +200,14 @@ async fn main() -> Result<()> {
if let Some(ref db) = database {
match db::machines::get_all_machines(db.pool()).await {
Ok(machines) => {
info!("Restoring {} persistent machines from database", machines.len());
info!(
"Restoring {} persistent machines from database",
machines.len()
);
for machine in machines {
sessions.restore_offline_machine(&machine.agent_id, &machine.hostname).await;
sessions
.restore_offline_machine(&machine.agent_id, &machine.hostname)
.await;
}
}
Err(e) => {
@@ -254,92 +265,117 @@ async fn main() -> Result<()> {
.route("/health", get(health))
// Prometheus metrics (no auth required - for monitoring)
.route("/metrics", get(prometheus_metrics))
// Auth endpoints (TODO: Add rate limiting - see SEC2_RATE_LIMITING_TODO.md)
.route("/api/auth/login", post(api::auth::login))
.route("/api/auth/change-password", post(api::auth::change_password))
.route(
"/api/auth/change-password",
post(api::auth::change_password),
)
.route("/api/auth/me", get(api::auth::get_me))
.route("/api/auth/logout", post(api::auth_logout::logout))
.route("/api/auth/revoke-token", post(api::auth_logout::revoke_own_token))
.route("/api/auth/admin/revoke-user", post(api::auth_logout::revoke_user_tokens))
.route("/api/auth/blacklist/stats", get(api::auth_logout::get_blacklist_stats))
.route("/api/auth/blacklist/cleanup", post(api::auth_logout::cleanup_blacklist))
.route(
"/api/auth/revoke-token",
post(api::auth_logout::revoke_own_token),
)
.route(
"/api/auth/admin/revoke-user",
post(api::auth_logout::revoke_user_tokens),
)
.route(
"/api/auth/blacklist/stats",
get(api::auth_logout::get_blacklist_stats),
)
.route(
"/api/auth/blacklist/cleanup",
post(api::auth_logout::cleanup_blacklist),
)
// User management (admin only)
.route("/api/users", get(api::users::list_users))
.route("/api/users", post(api::users::create_user))
.route("/api/users/:id", get(api::users::get_user))
.route("/api/users/:id", put(api::users::update_user))
.route("/api/users/:id", delete(api::users::delete_user))
.route("/api/users/:id/permissions", put(api::users::set_permissions))
.route(
"/api/users/:id/permissions",
put(api::users::set_permissions),
)
.route("/api/users/:id/clients", put(api::users::set_client_access))
// Portal API - Support codes (TODO: Add rate limiting)
.route("/api/codes", post(create_code))
.route("/api/codes", get(list_codes))
.route("/api/codes/:code/validate", get(validate_code))
.route("/api/codes/:code/cancel", post(cancel_code))
// WebSocket endpoints
.route("/ws/agent", get(relay::agent_ws_handler))
.route("/ws/viewer", get(relay::viewer_ws_handler))
// REST API - Sessions
.route("/api/sessions", get(list_sessions))
.route("/api/sessions/:id", get(get_session))
.route("/api/sessions/:id", delete(disconnect_session))
// REST API - Machines
.route("/api/machines", get(list_machines))
.route("/api/machines/:agent_id", get(get_machine))
.route("/api/machines/:agent_id", delete(delete_machine))
.route("/api/machines/:agent_id/history", get(get_machine_history))
.route("/api/machines/:agent_id/update", post(trigger_machine_update))
.route(
"/api/machines/:agent_id/update",
post(trigger_machine_update),
)
// REST API - Releases and Version
.route("/api/version", get(api::releases::get_version)) // No auth - for agent polling
.route("/api/version", get(api::releases::get_version)) // No auth - for agent polling
.route("/api/releases", get(api::releases::list_releases))
.route("/api/releases", post(api::releases::create_release))
.route("/api/releases/:version", get(api::releases::get_release))
.route("/api/releases/:version", put(api::releases::update_release))
.route("/api/releases/:version", delete(api::releases::delete_release))
.route(
"/api/releases/:version",
delete(api::releases::delete_release),
)
// Changelog (no auth - public, like /api/version)
// Single route: version == "latest" selects the latest file; axum 0.7 / matchit 0.7
// panics if a static segment and a path param share this position, so do not split it.
.route("/api/changelog/:component/:version", get(api::changelog::get))
.route(
"/api/changelog/:component/:version",
get(api::changelog::get),
)
// Agent downloads (no auth - public download links)
.route("/api/download/viewer", get(api::downloads::download_viewer))
.route("/api/download/support", get(api::downloads::download_support))
.route(
"/api/download/support",
get(api::downloads::download_support),
)
.route("/api/download/agent", get(api::downloads::download_agent))
// HTML page routes (clean URLs)
.route("/login", get(serve_login))
.route("/dashboard", get(serve_dashboard))
.route("/users", get(serve_users))
// State and middleware
.with_state(state.clone())
.layer(axum_middleware::from_fn_with_state(state, auth_layer))
// Serve static files for portal (fallback)
.fallback_service(ServeDir::new("static").append_index_html_on_directories(true))
// Middleware
.layer(axum_middleware::from_fn(middleware::add_security_headers)) // SEC-7 & SEC-12
.layer(axum_middleware::from_fn(middleware::add_security_headers)) // SEC-7 & SEC-12
.layer(TraceLayer::new_for_http())
// SEC-11: Restricted CORS configuration
.layer({
let cors = CorsLayer::new()
// Allow requests from the production domain and localhost (for development)
.allow_origin([
"https://connect.azcomputerguru.com".parse::<HeaderValue>().unwrap(),
"https://connect.azcomputerguru.com"
.parse::<HeaderValue>()
.unwrap(),
"http://localhost:3002".parse::<HeaderValue>().unwrap(),
"http://127.0.0.1:3002".parse::<HeaderValue>().unwrap(),
])
// Allow only necessary HTTP methods
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS])
.allow_methods([
Method::GET,
Method::POST,
Method::PUT,
Method::DELETE,
Method::OPTIONS,
])
// Allow common headers needed for API requests
.allow_headers([
axum::http::header::AUTHORIZATION,
@@ -360,8 +396,9 @@ async fn main() -> Result<()> {
// Use into_make_service_with_connect_info to enable IP address extraction
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>()
).await?;
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await?;
Ok(())
}
@@ -371,9 +408,7 @@ async fn health() -> &'static str {
}
/// Prometheus metrics endpoint
async fn prometheus_metrics(
State(state): State<AppState>,
) -> String {
async fn prometheus_metrics(State(state): State<AppState>) -> String {
use prometheus_client::encoding::text::encode;
let registry = state.registry.lock().unwrap();
@@ -385,7 +420,7 @@ async fn prometheus_metrics(
// Support code API handlers
async fn create_code(
_user: AuthenticatedUser, // Require authentication
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
Json(request): Json<CreateCodeRequest>,
) -> Json<SupportCode> {
@@ -395,7 +430,7 @@ async fn create_code(
}
async fn list_codes(
_user: AuthenticatedUser, // Require authentication
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
) -> Json<Vec<SupportCode>> {
Json(state.support_codes.list_active_codes().await)
@@ -414,7 +449,7 @@ async fn validate_code(
}
async fn cancel_code(
_user: AuthenticatedUser, // Require authentication
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
Path(code): Path<String>,
) -> impl IntoResponse {
@@ -428,7 +463,7 @@ async fn cancel_code(
// Session API handlers (updated to use AppState)
async fn list_sessions(
_user: AuthenticatedUser, // Require authentication
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
) -> Json<Vec<api::SessionInfo>> {
let sessions = state.sessions.list_sessions().await;
@@ -436,21 +471,24 @@ async fn list_sessions(
}
async fn get_session(
_user: AuthenticatedUser, // Require authentication
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
Path(id): Path<String>,
) -> Result<Json<api::SessionInfo>, (StatusCode, &'static str)> {
let session_id = uuid::Uuid::parse_str(&id)
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID"))?;
let session_id =
uuid::Uuid::parse_str(&id).map_err(|_| (StatusCode::BAD_REQUEST, "Invalid session ID"))?;
let session = state.sessions.get_session(session_id).await
let session = state
.sessions
.get_session(session_id)
.await
.ok_or((StatusCode::NOT_FOUND, "Session not found"))?;
Ok(Json(api::SessionInfo::from(session)))
}
async fn disconnect_session(
_user: AuthenticatedUser, // Require authentication
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
Path(id): Path<String>,
) -> impl IntoResponse {
@@ -459,7 +497,11 @@ async fn disconnect_session(
Err(_) => return (StatusCode::BAD_REQUEST, "Invalid session ID"),
};
if state.sessions.disconnect_session(session_id, "Disconnected by administrator").await {
if state
.sessions
.disconnect_session(session_id, "Disconnected by administrator")
.await
{
info!("Session {} disconnected by admin", session_id);
(StatusCode::OK, "Session disconnected")
} else {
@@ -470,27 +512,35 @@ async fn disconnect_session(
// Machine API handlers
async fn list_machines(
_user: AuthenticatedUser, // Require authentication
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
) -> Result<Json<Vec<api::MachineInfo>>, (StatusCode, &'static str)> {
let db = state.db.as_ref()
let db = state
.db
.as_ref()
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?;
let machines = db::machines::get_all_machines(db.pool()).await
let machines = db::machines::get_all_machines(db.pool())
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?;
Ok(Json(machines.into_iter().map(api::MachineInfo::from).collect()))
Ok(Json(
machines.into_iter().map(api::MachineInfo::from).collect(),
))
}
async fn get_machine(
_user: AuthenticatedUser, // Require authentication
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
Path(agent_id): Path<String>,
) -> Result<Json<api::MachineInfo>, (StatusCode, &'static str)> {
let db = state.db.as_ref()
let db = state
.db
.as_ref()
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?;
let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id).await
let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?
.ok_or((StatusCode::NOT_FOUND, "Machine not found"))?;
@@ -498,24 +548,29 @@ async fn get_machine(
}
async fn get_machine_history(
_user: AuthenticatedUser, // Require authentication
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
Path(agent_id): Path<String>,
) -> Result<Json<api::MachineHistory>, (StatusCode, &'static str)> {
let db = state.db.as_ref()
let db = state
.db
.as_ref()
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?;
// Get machine
let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id).await
let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?
.ok_or((StatusCode::NOT_FOUND, "Machine not found"))?;
// Get sessions for this machine
let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id).await
let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?;
// Get events for this machine
let events = db::events::get_events_for_machine(db.pool(), machine.id).await
let events = db::events::get_events_for_machine(db.pool(), machine.id)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?;
let history = api::MachineHistory {
@@ -529,24 +584,29 @@ async fn get_machine_history(
}
async fn delete_machine(
_user: AuthenticatedUser, // Require authentication
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
Path(agent_id): Path<String>,
Query(params): Query<api::DeleteMachineParams>,
) -> Result<Json<api::DeleteMachineResponse>, (StatusCode, &'static str)> {
let db = state.db.as_ref()
let db = state
.db
.as_ref()
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?;
// Get machine first
let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id).await
let machine = db::machines::get_machine_by_agent_id(db.pool(), &agent_id)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?
.ok_or((StatusCode::NOT_FOUND, "Machine not found"))?;
// Export history if requested
let history = if params.export {
let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id).await
let sessions = db::sessions::get_sessions_for_machine(db.pool(), machine.id)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?;
let events = db::events::get_events_for_machine(db.pool(), machine.id).await
let events = db::events::get_events_for_machine(db.pool(), machine.id)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?;
Some(api::MachineHistory {
@@ -565,11 +625,14 @@ async fn delete_machine(
// Find session for this agent
if let Some(session) = state.sessions.get_session_by_agent(&agent_id).await {
if session.is_online {
uninstall_sent = state.sessions.send_admin_command(
session.id,
proto::AdminCommandType::AdminUninstall,
"Deleted by administrator",
).await;
uninstall_sent = state
.sessions
.send_admin_command(
session.id,
proto::AdminCommandType::AdminUninstall,
"Deleted by administrator",
)
.await;
if uninstall_sent {
info!("Sent uninstall command to agent {}", agent_id);
}
@@ -581,10 +644,19 @@ async fn delete_machine(
state.sessions.remove_agent(&agent_id).await;
// Delete from database (cascades to sessions and events)
db::machines::delete_machine(db.pool(), &agent_id).await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Failed to delete machine"))?;
db::machines::delete_machine(db.pool(), &agent_id)
.await
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to delete machine",
)
})?;
info!("Deleted machine {} (uninstall_sent: {})", agent_id, uninstall_sent);
info!(
"Deleted machine {} (uninstall_sent: {})",
agent_id, uninstall_sent
);
Ok(Json(api::DeleteMachineResponse {
success: true,
@@ -603,27 +675,34 @@ struct TriggerUpdateRequest {
/// Trigger update on a specific machine
async fn trigger_machine_update(
_user: AuthenticatedUser, // Require authentication
_user: AuthenticatedUser, // Require authentication
State(state): State<AppState>,
Path(agent_id): Path<String>,
Json(request): Json<TriggerUpdateRequest>,
) -> Result<impl IntoResponse, (StatusCode, &'static str)> {
let db = state.db.as_ref()
let db = state
.db
.as_ref()
.ok_or((StatusCode::SERVICE_UNAVAILABLE, "Database not available"))?;
// Get the target release (either specified or latest stable)
let release = if let Some(version) = request.version {
db::releases::get_release_by_version(db.pool(), &version).await
db::releases::get_release_by_version(db.pool(), &version)
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?
.ok_or((StatusCode::NOT_FOUND, "Release version not found"))?
} else {
db::releases::get_latest_stable_release(db.pool()).await
db::releases::get_latest_stable_release(db.pool())
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "Database error"))?
.ok_or((StatusCode::NOT_FOUND, "No stable release available"))?
};
// Find session for this agent
let session = state.sessions.get_session_by_agent(&agent_id).await
let session = state
.sessions
.get_session_by_agent(&agent_id)
.await
.ok_or((StatusCode::NOT_FOUND, "Agent not found or offline"))?;
if !session.is_online {
@@ -632,21 +711,31 @@ async fn trigger_machine_update(
// Send update command via WebSocket
// For now, we send admin command - later we'll include UpdateInfo in the message
let sent = state.sessions.send_admin_command(
session.id,
proto::AdminCommandType::AdminUpdate,
&format!("Update to version {}", release.version),
).await;
let sent = state
.sessions
.send_admin_command(
session.id,
proto::AdminCommandType::AdminUpdate,
&format!("Update to version {}", release.version),
)
.await;
if sent {
info!("Sent update command to agent {} (version {})", agent_id, release.version);
info!(
"Sent update command to agent {} (version {})",
agent_id, release.version
);
// Update machine update status in database
let _ = db::releases::update_machine_update_status(db.pool(), &agent_id, "downloading").await;
let _ =
db::releases::update_machine_update_status(db.pool(), &agent_id, "downloading").await;
Ok((StatusCode::OK, "Update command sent"))
} else {
Err((StatusCode::INTERNAL_SERVER_ERROR, "Failed to send update command"))
Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to send update command",
))
}
}

View File

@@ -22,26 +22,26 @@ pub struct RequestLabels {
/// Metrics labels for session events
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
pub struct SessionLabels {
pub status: String, // created, closed, failed, expired
pub status: String, // created, closed, failed, expired
}
/// Metrics labels for connection events
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
pub struct ConnectionLabels {
pub conn_type: String, // agent, viewer, dashboard
pub conn_type: String, // agent, viewer, dashboard
}
/// Metrics labels for error tracking
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
pub struct ErrorLabels {
pub error_type: String, // auth, database, websocket, protocol, internal
pub error_type: String, // auth, database, websocket, protocol, internal
}
/// Metrics labels for database operations
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
pub struct DatabaseLabels {
pub operation: String, // select, insert, update, delete
pub status: String, // success, error
pub operation: String, // select, insert, update, delete
pub status: String, // success, error
}
/// GuruConnect server metrics
@@ -82,9 +82,10 @@ impl Metrics {
requests_total.clone(),
);
let request_duration_seconds = Family::<RequestLabels, Histogram>::new_with_constructor(|| {
Histogram::new(exponential_buckets(0.001, 2.0, 10)) // 1ms to ~1s
});
let request_duration_seconds =
Family::<RequestLabels, Histogram>::new_with_constructor(|| {
Histogram::new(exponential_buckets(0.001, 2.0, 10)) // 1ms to ~1s
});
registry.register(
"guruconnect_request_duration_seconds",
"HTTP request duration in seconds",
@@ -106,7 +107,7 @@ impl Metrics {
active_sessions.clone(),
);
let session_duration_seconds = Histogram::new(exponential_buckets(1.0, 2.0, 15)); // 1s to ~9 hours
let session_duration_seconds = Histogram::new(exponential_buckets(1.0, 2.0, 15)); // 1s to ~9 hours
registry.register(
"guruconnect_session_duration_seconds",
"Session duration in seconds",
@@ -144,9 +145,10 @@ impl Metrics {
db_operations_total.clone(),
);
let db_query_duration_seconds = Family::<DatabaseLabels, Histogram>::new_with_constructor(|| {
Histogram::new(exponential_buckets(0.0001, 2.0, 12)) // 0.1ms to ~400ms
});
let db_query_duration_seconds =
Family::<DatabaseLabels, Histogram>::new_with_constructor(|| {
Histogram::new(exponential_buckets(0.0001, 2.0, 12)) // 0.1ms to ~400ms
});
registry.register(
"guruconnect_db_query_duration_seconds",
"Database query duration in seconds",
@@ -188,7 +190,13 @@ impl Metrics {
}
/// Record request duration
pub fn record_request_duration(&self, method: &str, path: &str, status: u16, duration_secs: f64) {
pub fn record_request_duration(
&self,
method: &str,
path: &str,
status: u16,
duration_secs: f64,
) {
self.request_duration_seconds
.get_or_create(&RequestLabels {
method: method.to_string(),

View File

@@ -3,17 +3,10 @@
//! SEC-7: XSS Prevention via Content-Security-Policy
//! SEC-12: Additional security headers
use axum::{
extract::Request,
middleware::Next,
response::Response,
};
use axum::{extract::Request, middleware::Next, response::Response};
/// Add security headers to all responses
pub async fn add_security_headers(
request: Request,
next: Next,
) -> Response {
pub async fn add_security_headers(request: Request, next: Next) -> Response {
let mut response = next.run(request).await;
let headers = response.headers_mut();
@@ -35,22 +28,13 @@ pub async fn add_security_headers(
);
// SEC-12: X-Frame-Options (Clickjacking protection)
headers.insert(
"X-Frame-Options",
"DENY".parse().unwrap(),
);
headers.insert("X-Frame-Options", "DENY".parse().unwrap());
// SEC-12: X-Content-Type-Options (MIME sniffing protection)
headers.insert(
"X-Content-Type-Options",
"nosniff".parse().unwrap(),
);
headers.insert("X-Content-Type-Options", "nosniff".parse().unwrap());
// SEC-12: X-XSS-Protection (Legacy XSS filter - deprecated but still useful)
headers.insert(
"X-XSS-Protection",
"1; mode=block".parse().unwrap(),
);
headers.insert("X-XSS-Protection", "1; mode=block".parse().unwrap());
// SEC-12: Referrer-Policy (Control referrer information)
headers.insert(

View File

@@ -6,21 +6,21 @@
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Query, State, ConnectInfo,
ConnectInfo, Query, State,
},
response::IntoResponse,
http::StatusCode,
response::IntoResponse,
};
use std::net::SocketAddr;
use futures_util::{SinkExt, StreamExt};
use prost::Message as ProstMessage;
use serde::Deserialize;
use std::net::SocketAddr;
use tracing::{error, info, warn};
use uuid::Uuid;
use crate::db::{self, Database};
use crate::proto;
use crate::session::SessionManager;
use crate::db::{self, Database};
use crate::AppState;
#[derive(Debug, Deserialize)]
@@ -59,7 +59,11 @@ pub async fn agent_ws_handler(
Query(params): Query<AgentParams>,
) -> Result<impl IntoResponse, StatusCode> {
let agent_id = params.agent_id.clone();
let agent_name = params.hostname.clone().or(params.agent_name.clone()).unwrap_or_else(|| agent_id.clone());
let agent_name = params
.hostname
.clone()
.or(params.agent_name.clone())
.unwrap_or_else(|| agent_id.clone());
let support_code = params.support_code.clone();
let api_key = params.api_key.clone();
let client_ip = addr.ip();
@@ -69,7 +73,10 @@ pub async fn agent_ws_handler(
// API key = persistent managed agent
if support_code.is_none() && api_key.is_none() {
warn!("Agent connection rejected: {} from {} - no support code or API key", agent_id, client_ip);
warn!(
"Agent connection rejected: {} from {} - no support code or API key",
agent_id, client_ip
);
// Log failed connection attempt to database
if let Some(ref db) = state.db {
@@ -84,7 +91,8 @@ pub async fn agent_ws_handler(
"agent_id": agent_id
})),
Some(client_ip),
).await;
)
.await;
}
return Err(StatusCode::UNAUTHORIZED);
@@ -95,7 +103,10 @@ pub async fn agent_ws_handler(
// Check if it's a valid, pending support code
let code_info = state.support_codes.get_status(code).await;
if code_info.is_none() {
warn!("Agent connection rejected: {} from {} - invalid support code {}", agent_id, client_ip, code);
warn!(
"Agent connection rejected: {} from {} - invalid support code {}",
agent_id, client_ip, code
);
// Log failed connection attempt
if let Some(ref db) = state.db {
@@ -111,14 +122,18 @@ pub async fn agent_ws_handler(
"agent_id": agent_id
})),
Some(client_ip),
).await;
)
.await;
}
return Err(StatusCode::UNAUTHORIZED);
}
let status = code_info.unwrap();
if status != "pending" && status != "connected" {
warn!("Agent connection rejected: {} from {} - support code {} has status {}", agent_id, client_ip, code, status);
warn!(
"Agent connection rejected: {} from {} - support code {} has status {}",
agent_id, client_ip, code, status
);
// Log failed connection attempt (expired/cancelled code)
if let Some(ref db) = state.db {
@@ -140,12 +155,16 @@ pub async fn agent_ws_handler(
"agent_id": agent_id
})),
Some(client_ip),
).await;
)
.await;
}
return Err(StatusCode::UNAUTHORIZED);
}
info!("Agent {} from {} authenticated via support code {}", agent_id, client_ip, code);
info!(
"Agent {} from {} authenticated via support code {}",
agent_id, client_ip, code
);
}
// Validate API key if provided (for persistent agents)
@@ -153,7 +172,10 @@ pub async fn agent_ws_handler(
// For now, we'll accept API keys that match the JWT secret or a configured agent key
// In production, this should validate against a database of registered agents
if !validate_agent_api_key(&state, key).await {
warn!("Agent connection rejected: {} from {} - invalid API key", agent_id, client_ip);
warn!(
"Agent connection rejected: {} from {} - invalid API key",
agent_id, client_ip
);
// Log failed connection attempt
if let Some(ref db) = state.db {
@@ -168,19 +190,34 @@ pub async fn agent_ws_handler(
"agent_id": agent_id
})),
Some(client_ip),
).await;
)
.await;
}
return Err(StatusCode::UNAUTHORIZED);
}
info!("Agent {} from {} authenticated via API key", agent_id, client_ip);
info!(
"Agent {} from {} authenticated via API key",
agent_id, client_ip
);
}
let sessions = state.sessions.clone();
let support_codes = state.support_codes.clone();
let db = state.db.clone();
Ok(ws.on_upgrade(move |socket| handle_agent_connection(socket, sessions, support_codes, db, agent_id, agent_name, support_code, Some(client_ip))))
Ok(ws.on_upgrade(move |socket| {
handle_agent_connection(
socket,
sessions,
support_codes,
db,
agent_id,
agent_name,
support_code,
Some(client_ip),
)
}))
}
/// Validate an agent API key
@@ -212,24 +249,42 @@ pub async fn viewer_ws_handler(
// Require JWT token for viewers
let token = params.token.ok_or_else(|| {
warn!("Viewer connection rejected from {}: missing token", client_ip);
warn!(
"Viewer connection rejected from {}: missing token",
client_ip
);
StatusCode::UNAUTHORIZED
})?;
// Validate the token
let claims = state.jwt_config.validate_token(&token).map_err(|e| {
warn!("Viewer connection rejected from {}: invalid token: {}", client_ip, e);
warn!(
"Viewer connection rejected from {}: invalid token: {}",
client_ip, e
);
StatusCode::UNAUTHORIZED
})?;
info!("Viewer {} authenticated via JWT from {}", claims.username, client_ip);
info!(
"Viewer {} authenticated via JWT from {}",
claims.username, client_ip
);
let session_id = params.session_id;
let viewer_name = params.viewer_name;
let sessions = state.sessions.clone();
let db = state.db.clone();
Ok(ws.on_upgrade(move |socket| handle_viewer_connection(socket, sessions, db, session_id, viewer_name, Some(client_ip))))
Ok(ws.on_upgrade(move |socket| {
handle_viewer_connection(
socket,
sessions,
db,
session_id,
viewer_name,
Some(client_ip),
)
}))
}
/// Handle an agent WebSocket connection
@@ -243,7 +298,10 @@ async fn handle_agent_connection(
support_code: Option<String>,
client_ip: Option<std::net::IpAddr>,
) {
info!("Agent connected: {} ({}) from {:?}", agent_name, agent_id, client_ip);
info!(
"Agent connected: {} ({}) from {:?}",
agent_name, agent_id, client_ip
);
let (mut ws_sender, mut ws_receiver) = socket.split();
@@ -270,7 +328,9 @@ async fn handle_agent_connection(
// Register the agent and get channels
// Persistent agents (no support code) keep their session when disconnected
let is_persistent = support_code.is_none();
let (session_id, frame_tx, mut input_rx) = sessions.register_agent(agent_id.clone(), agent_name.clone(), is_persistent).await;
let (session_id, frame_tx, mut input_rx) = sessions
.register_agent(agent_id.clone(), agent_name.clone(), is_persistent)
.await;
info!("Session created: {} (agent in idle mode)", session_id);
@@ -285,15 +345,20 @@ async fn handle_agent_connection(
machine.id,
support_code.is_some(),
support_code.as_deref(),
).await;
)
.await;
// Log session started event
let _ = db::events::log_event(
db.pool(),
session_id,
db::events::EventTypes::SESSION_STARTED,
None, None, None, client_ip,
).await;
None,
None,
None,
client_ip,
)
.await;
Some(machine.id)
}
@@ -309,7 +374,9 @@ async fn handle_agent_connection(
// If a support code was provided, mark it as connected
if let Some(ref code) = support_code {
info!("Linking support code {} to session {}", code, session_id);
support_codes.mark_connected(code, Some(agent_name.clone()), Some(agent_id.clone())).await;
support_codes
.mark_connected(code, Some(agent_name.clone()), Some(agent_id.clone()))
.await;
support_codes.link_session(code, session_id).await;
// Database: update support code
@@ -320,7 +387,8 @@ async fn handle_agent_connection(
Some(session_id),
Some(&agent_name),
Some(&agent_id),
).await;
)
.await;
}
}
@@ -333,7 +401,11 @@ async fn handle_agent_connection(
let input_forward = tokio::spawn(async move {
while let Some(input_data) = input_rx.recv().await {
let mut sender = ws_sender_input.lock().await;
if sender.send(Message::Binary(input_data.into())).await.is_err() {
if sender
.send(Message::Binary(input_data.into()))
.await
.is_err()
{
break;
}
}
@@ -406,22 +478,29 @@ async fn handle_agent_connection(
} else {
Some(status.site.clone())
};
sessions_status.update_agent_status(
session_id,
Some(status.os_version.clone()),
status.is_elevated,
status.uptime_secs,
status.display_count,
status.is_streaming,
agent_version.clone(),
organization.clone(),
site.clone(),
status.tags.clone(),
).await;
sessions_status
.update_agent_status(
session_id,
Some(status.os_version.clone()),
status.is_elevated,
status.uptime_secs,
status.display_count,
status.is_streaming,
agent_version.clone(),
organization.clone(),
site.clone(),
status.tags.clone(),
)
.await;
// Update version in database if present
if let (Some(ref db), Some(ref version)) = (&db, &agent_version) {
let _ = crate::db::releases::update_machine_version(db.pool(), &agent_id, version).await;
let _ = crate::db::releases::update_machine_version(
db.pool(),
&agent_id,
version,
)
.await;
}
// Update organization/site/tags in database if present
@@ -432,7 +511,8 @@ async fn handle_agent_connection(
organization.as_deref(),
site.as_deref(),
&status.tags,
).await;
)
.await;
}
info!("Agent status update: {} - streaming={}, uptime={}s, version={:?}, org={:?}, site={:?}",
@@ -489,8 +569,12 @@ async fn handle_agent_connection(
db.pool(),
session_id,
db::events::EventTypes::SESSION_ENDED,
None, None, None, client_ip,
).await;
None,
None,
None,
client_ip,
)
.await;
}
// Mark support code as completed if one was used (unless cancelled)
@@ -532,7 +616,10 @@ async fn handle_viewer_connection(
let viewer_id = Uuid::new_v4().to_string();
// Join the session (this sends StartStream to agent if first viewer)
let (mut frame_rx, input_tx) = match sessions.join_session(session_id, viewer_id.clone(), viewer_name.clone()).await {
let (mut frame_rx, input_tx) = match sessions
.join_session(session_id, viewer_id.clone(), viewer_name.clone())
.await
{
Some(channels) => channels,
None => {
warn!("Session not found: {}", session_id);
@@ -540,7 +627,10 @@ async fn handle_viewer_connection(
}
};
info!("Viewer {} ({}) joined session: {} from {:?}", viewer_name, viewer_id, session_id, client_ip);
info!(
"Viewer {} ({}) joined session: {} from {:?}",
viewer_name, viewer_id, session_id, client_ip
);
// Database: log viewer joined event
if let Some(ref db) = db {
@@ -550,8 +640,10 @@ async fn handle_viewer_connection(
db::events::EventTypes::VIEWER_JOINED,
Some(&viewer_id),
Some(&viewer_name),
None, client_ip,
).await;
None,
client_ip,
)
.await;
}
let (mut ws_sender, mut ws_receiver) = socket.split();
@@ -559,7 +651,11 @@ async fn handle_viewer_connection(
// Task to forward frames from agent to this viewer
let frame_forward = tokio::spawn(async move {
while let Ok(frame_data) = frame_rx.recv().await {
if ws_sender.send(Message::Binary(frame_data.into())).await.is_err() {
if ws_sender
.send(Message::Binary(frame_data.into()))
.await
.is_err()
{
break;
}
}
@@ -577,9 +673,9 @@ async fn handle_viewer_connection(
match proto::Message::decode(data.as_ref()) {
Ok(proto_msg) => {
match &proto_msg.payload {
Some(proto::message::Payload::MouseEvent(_)) |
Some(proto::message::Payload::KeyEvent(_)) |
Some(proto::message::Payload::SpecialKey(_)) => {
Some(proto::message::Payload::MouseEvent(_))
| Some(proto::message::Payload::KeyEvent(_))
| Some(proto::message::Payload::SpecialKey(_)) => {
// Forward input to agent
let _ = input_tx.send(data.to_vec()).await;
}
@@ -597,7 +693,10 @@ async fn handle_viewer_connection(
}
}
Ok(Message::Close(_)) => {
info!("Viewer {} disconnected from session: {}", viewer_id, session_id);
info!(
"Viewer {} disconnected from session: {}",
viewer_id, session_id
);
break;
}
Ok(_) => {}
@@ -610,7 +709,9 @@ async fn handle_viewer_connection(
// Cleanup (this sends StopStream to agent if last viewer)
frame_forward.abort();
sessions_cleanup.leave_session(session_id, &viewer_id_cleanup).await;
sessions_cleanup
.leave_session(session_id, &viewer_id_cleanup)
.await;
// Database: log viewer left event
if let Some(ref db) = db {
@@ -620,8 +721,10 @@ async fn handle_viewer_connection(
db::events::EventTypes::VIEWER_LEFT,
Some(&viewer_id_cleanup),
Some(&viewer_name_cleanup),
None, client_ip,
).await;
None,
client_ip,
)
.await;
}
info!("Viewer {} left session: {}", viewer_id_cleanup, session_id);

View File

@@ -37,20 +37,20 @@ pub struct Session {
pub agent_name: String,
pub started_at: chrono::DateTime<chrono::Utc>,
pub viewer_count: usize,
pub viewers: Vec<ViewerInfo>, // List of connected technicians
pub viewers: Vec<ViewerInfo>, // List of connected technicians
pub is_streaming: bool,
pub is_online: bool, // Whether agent is currently connected
pub is_persistent: bool, // Persistent agent (no support code) vs support session
pub is_online: bool, // Whether agent is currently connected
pub is_persistent: bool, // Persistent agent (no support code) vs support session
pub last_heartbeat: chrono::DateTime<chrono::Utc>,
// Agent status info
pub os_version: Option<String>,
pub is_elevated: bool,
pub uptime_secs: i64,
pub display_count: i32,
pub agent_version: Option<String>, // Agent software version
pub organization: Option<String>, // Company/organization name
pub site: Option<String>, // Site/location name
pub tags: Vec<String>, // Tags for categorization
pub agent_version: Option<String>, // Agent software version
pub organization: Option<String>, // Company/organization name
pub site: Option<String>, // Site/location name
pub tags: Vec<String>, // Tags for categorization
}
/// Channel for sending frames from agent to viewers
@@ -92,7 +92,12 @@ impl SessionManager {
/// Register a new agent and create a session
/// If agent was previously connected (offline session exists), reuse that session
pub async fn register_agent(&self, agent_id: AgentId, agent_name: String, is_persistent: bool) -> (SessionId, FrameSender, InputReceiver) {
pub async fn register_agent(
&self,
agent_id: AgentId,
agent_name: String,
is_persistent: bool,
) -> (SessionId, FrameSender, InputReceiver) {
// Check if this agent already has an offline session (reconnecting)
{
let agents = self.agents.read().await;
@@ -101,7 +106,11 @@ impl SessionManager {
if let Some(session_data) = sessions.get_mut(&existing_session_id) {
if !session_data.info.is_online {
// Reuse existing session - mark as online and create new channels
tracing::info!("Agent {} reconnecting to existing session {}", agent_id, existing_session_id);
tracing::info!(
"Agent {} reconnecting to existing session {}",
agent_id,
existing_session_id
);
let (frame_tx, _) = broadcast::channel(16);
let (input_tx, input_rx) = tokio::sync::mpsc::channel(64);
@@ -230,7 +239,9 @@ impl SessionManager {
let sessions = self.sessions.read().await;
sessions
.iter()
.filter(|(_, data)| data.last_heartbeat_instant.elapsed().as_secs() > HEARTBEAT_TIMEOUT_SECS)
.filter(|(_, data)| {
data.last_heartbeat_instant.elapsed().as_secs() > HEARTBEAT_TIMEOUT_SECS
})
.map(|(id, _)| *id)
.collect()
}
@@ -251,7 +262,12 @@ impl SessionManager {
}
/// Join a session as a viewer, returns channels and sends StartStream to agent
pub async fn join_session(&self, session_id: SessionId, viewer_id: ViewerId, viewer_name: String) -> Option<(FrameReceiver, InputSender)> {
pub async fn join_session(
&self,
session_id: SessionId,
viewer_id: ViewerId,
viewer_name: String,
) -> Option<(FrameReceiver, InputSender)> {
let mut sessions = self.sessions.write().await;
let session_data = sessions.get_mut(&session_id)?;
@@ -274,10 +290,20 @@ impl SessionManager {
// If this is the first viewer, send StartStream to agent
if was_empty {
tracing::info!("Viewer {} ({}) joined session {}, sending StartStream", viewer_name, viewer_id, session_id);
tracing::info!(
"Viewer {} ({}) joined session {}, sending StartStream",
viewer_name,
viewer_id,
session_id
);
Self::send_start_stream_internal(session_data, &viewer_id).await;
} else {
tracing::info!("Viewer {} ({}) joined session {}", viewer_name, viewer_id, session_id);
tracing::info!(
"Viewer {} ({}) joined session {}",
viewer_name,
viewer_id,
session_id
);
}
Some((frame_rx, input_tx))
@@ -312,12 +338,20 @@ impl SessionManager {
// If no more viewers, send StopStream to agent
if session_data.viewers.is_empty() {
tracing::info!("Last viewer {} ({}) left session {}, sending StopStream",
viewer_name.as_deref().unwrap_or("unknown"), viewer_id, session_id);
tracing::info!(
"Last viewer {} ({}) left session {}, sending StopStream",
viewer_name.as_deref().unwrap_or("unknown"),
viewer_id,
session_id
);
Self::send_stop_stream_internal(session_data, viewer_id).await;
} else {
tracing::info!("Viewer {} ({}) left session {}",
viewer_name.as_deref().unwrap_or("unknown"), viewer_id, session_id);
tracing::info!(
"Viewer {} ({}) left session {}",
viewer_name.as_deref().unwrap_or("unknown"),
viewer_id,
session_id
);
}
}
}
@@ -347,8 +381,11 @@ impl SessionManager {
if let Some(session_data) = sessions.get_mut(&session_id) {
if session_data.info.is_persistent {
// Persistent agent - keep session but mark as offline
tracing::info!("Persistent agent {} marked offline (session {} preserved)",
session_data.info.agent_id, session_id);
tracing::info!(
"Persistent agent {} marked offline (session {} preserved)",
session_data.info.agent_id,
session_id
);
session_data.info.is_online = false;
session_data.info.is_streaming = false;
session_data.info.viewer_count = 0;
@@ -410,7 +447,12 @@ impl SessionManager {
/// Send an admin command to an agent (uninstall, restart, etc.)
/// Returns true if the message was sent successfully
pub async fn send_admin_command(&self, session_id: SessionId, command: crate::proto::AdminCommandType, reason: &str) -> bool {
pub async fn send_admin_command(
&self,
session_id: SessionId,
command: crate::proto::AdminCommandType,
reason: &str,
) -> bool {
let sessions = self.sessions.read().await;
if let Some(session_data) = sessions.get(&session_id) {
if !session_data.info.is_online {
@@ -471,7 +513,7 @@ impl SessionManager {
viewer_count: 0,
viewers: Vec::new(),
is_streaming: false,
is_online: false, // Offline until agent reconnects
is_online: false, // Offline until agent reconnects
is_persistent: true,
last_heartbeat: now,
os_version: None,

View File

@@ -3,12 +3,12 @@
//! Handles generation and validation of 6-digit support codes
//! for one-time remote support sessions.
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use chrono::{DateTime, Utc};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// A support session code
@@ -27,10 +27,10 @@ pub struct SupportCode {
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum CodeStatus {
Pending, // Waiting for client to connect
Connected, // Client connected, session active
Completed, // Session ended normally
Cancelled, // Code cancelled by tech
Pending, // Waiting for client to connect
Connected, // Client connected, session active
Completed, // Session ended normally
Cancelled, // Code cancelled by tech
}
/// Request to create a new support code
@@ -69,11 +69,11 @@ impl SupportCodeManager {
async fn generate_unique_code(&self) -> String {
let codes = self.codes.read().await;
let mut rng = rand::thread_rng();
loop {
let code: u32 = rng.gen_range(100000..999999);
let code_str = code.to_string();
if !codes.contains_key(&code_str) {
return code_str;
}
@@ -84,11 +84,13 @@ impl SupportCodeManager {
pub async fn create_code(&self, request: CreateCodeRequest) -> SupportCode {
let code = self.generate_unique_code().await;
let session_id = Uuid::new_v4();
let support_code = SupportCode {
code: code.clone(),
session_id,
created_by: request.technician_name.unwrap_or_else(|| "Unknown".to_string()),
created_by: request
.technician_name
.unwrap_or_else(|| "Unknown".to_string()),
created_at: Utc::now(),
status: CodeStatus::Pending,
client_name: None,
@@ -108,10 +110,12 @@ impl SupportCodeManager {
/// Validate a code and return session info
pub async fn validate_code(&self, code: &str) -> CodeValidation {
let codes = self.codes.read().await;
match codes.get(code) {
Some(support_code) => {
if support_code.status == CodeStatus::Pending || support_code.status == CodeStatus::Connected {
if support_code.status == CodeStatus::Pending
|| support_code.status == CodeStatus::Connected
{
CodeValidation {
valid: true,
session_id: Some(support_code.session_id.to_string()),
@@ -137,7 +141,12 @@ impl SupportCodeManager {
}
/// Mark a code as connected
pub async fn mark_connected(&self, code: &str, client_name: Option<String>, client_machine: Option<String>) {
pub async fn mark_connected(
&self,
code: &str,
client_name: Option<String>,
client_machine: Option<String>,
) {
let mut codes = self.codes.write().await;
if let Some(support_code) = codes.get_mut(code) {
support_code.status = CodeStatus::Connected;
@@ -180,7 +189,9 @@ impl SupportCodeManager {
pub async fn cancel_code(&self, code: &str) -> bool {
let mut codes = self.codes.write().await;
if let Some(support_code) = codes.get_mut(code) {
if support_code.status == CodeStatus::Pending || support_code.status == CodeStatus::Connected {
if support_code.status == CodeStatus::Pending
|| support_code.status == CodeStatus::Connected
{
support_code.status = CodeStatus::Cancelled;
return true;
}
@@ -191,13 +202,19 @@ impl SupportCodeManager {
/// Check if a code is cancelled
pub async fn is_cancelled(&self, code: &str) -> bool {
let codes = self.codes.read().await;
codes.get(code).map(|c| c.status == CodeStatus::Cancelled).unwrap_or(false)
codes
.get(code)
.map(|c| c.status == CodeStatus::Cancelled)
.unwrap_or(false)
}
/// Check if a code is valid for connection (exists and is pending)
pub async fn is_valid_for_connection(&self, code: &str) -> bool {
let codes = self.codes.read().await;
codes.get(code).map(|c| c.status == CodeStatus::Pending).unwrap_or(false)
codes
.get(code)
.map(|c| c.status == CodeStatus::Pending)
.unwrap_or(false)
}
/// List all codes (for dashboard)
@@ -209,7 +226,8 @@ impl SupportCodeManager {
/// List active codes only
pub async fn list_active_codes(&self) -> Vec<SupportCode> {
let codes = self.codes.read().await;
codes.values()
codes
.values()
.filter(|c| c.status == CodeStatus::Pending || c.status == CodeStatus::Connected)
.cloned()
.collect()

View File

@@ -11,18 +11,29 @@ use anyhow::{anyhow, Result};
pub fn validate_api_key_strength(api_key: &str) -> Result<()> {
// Minimum length check
if api_key.len() < 32 {
return Err(anyhow!("API key must be at least 32 characters long for security"));
return Err(anyhow!(
"API key must be at least 32 characters long for security"
));
}
// Check for common weak keys
let weak_keys = [
"password", "12345", "admin", "test", "api_key",
"secret", "changeme", "default", "guruconnect"
"password",
"12345",
"admin",
"test",
"api_key",
"secret",
"changeme",
"default",
"guruconnect",
];
let lowercase_key = api_key.to_lowercase();
for weak in &weak_keys {
if lowercase_key.contains(weak) {
return Err(anyhow!("API key contains weak/common patterns and is not secure"));
return Err(anyhow!(
"API key contains weak/common patterns and is not secure"
));
}
}
@@ -53,6 +64,9 @@ mod tests {
assert!(validate_api_key_strength("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa").is_err());
// Good key
assert!(validate_api_key_strength("KfPrjjC3J6YMx9q1yjPxZAYkHLM2JdFy1XRxHJ9oPnw0NU3xH074ufHk7fj").is_ok());
assert!(validate_api_key_strength(
"KfPrjjC3J6YMx9q1yjPxZAYkHLM2JdFy1XRxHJ9oPnw0NU3xH074ufHk7fj"
)
.is_ok());
}
}