chore: sync repository to current working state
Some checks failed
Build and Test / Build Server (Linux) (push) Has been cancelled
Build and Test / Build Agent (Windows) (push) Has been cancelled
Build and Test / Security Audit (push) Has been cancelled
Build and Test / Build Summary (push) Has been cancelled
Run Tests / Test Server (push) Has been cancelled
Run Tests / Test Agent (push) Has been cancelled
Run Tests / Code Coverage (push) Has been cancelled
Run Tests / Lint and Format Check (push) Has been cancelled
Some checks failed
Build and Test / Build Server (Linux) (push) Has been cancelled
Build and Test / Build Agent (Windows) (push) Has been cancelled
Build and Test / Security Audit (push) Has been cancelled
Build and Test / Build Summary (push) Has been cancelled
Run Tests / Test Server (push) Has been cancelled
Run Tests / Test Agent (push) Has been cancelled
Run Tests / Code Coverage (push) Has been cancelled
Run Tests / Lint and Format Check (push) Has been cancelled
Brings azcomputerguru/guru-connect up to the authoritative working copy that had been maintained in the claudetools monorepo: Phase 1 security and infrastructure (middleware, metrics, utils, token blacklist, deployment scripts, security audits) plus the native-remote-control integration spec. Preserves the repo .gitignore, .cargo, and server/static/downloads. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
//! Authentication API endpoints
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
extract::{State, Request},
|
||||
http::StatusCode,
|
||||
Json,
|
||||
};
|
||||
|
||||
191
server/src/api/auth_logout.rs
Normal file
191
server/src/api/auth_logout.rs
Normal file
@@ -0,0 +1,191 @@
|
||||
//! Logout and token revocation endpoints
|
||||
|
||||
use axum::{
|
||||
extract::{Request, State, Path},
|
||||
http::{StatusCode, HeaderMap},
|
||||
Json,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
use serde::Serialize;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::auth::AuthenticatedUser;
|
||||
use crate::AppState;
|
||||
|
||||
use super::auth::ErrorResponse;
|
||||
|
||||
/// Extract JWT token from Authorization header
|
||||
fn extract_token_from_headers(headers: &HeaderMap) -> Result<String, (StatusCode, Json<ErrorResponse>)> {
|
||||
let auth_header = headers
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ErrorResponse {
|
||||
error: "Missing Authorization header".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
let token = auth_header
|
||||
.strip_prefix("Bearer ")
|
||||
.ok_or_else(|| {
|
||||
(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(ErrorResponse {
|
||||
error: "Invalid Authorization format".to_string(),
|
||||
}),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(token.to_string())
|
||||
}
|
||||
|
||||
/// Logout response
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct LogoutResponse {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
/// POST /api/auth/logout - Revoke current token (logout)
|
||||
///
|
||||
/// Adds the user's current JWT token to the blacklist, effectively logging them out.
|
||||
/// The token will no longer be valid for any requests.
|
||||
pub async fn logout(
|
||||
State(state): State<AppState>,
|
||||
user: AuthenticatedUser,
|
||||
request: Request,
|
||||
) -> Result<Json<LogoutResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||
// Extract token from headers
|
||||
let token = extract_token_from_headers(request.headers())?;
|
||||
|
||||
// Add token to blacklist
|
||||
state.token_blacklist.revoke(&token).await;
|
||||
|
||||
info!("User {} logged out (token revoked)", user.username);
|
||||
|
||||
Ok(Json(LogoutResponse {
|
||||
message: "Logged out successfully".to_string(),
|
||||
}))
|
||||
}
|
||||
|
||||
/// POST /api/auth/revoke-token - Revoke own token (same as logout)
|
||||
///
|
||||
/// Alias for logout endpoint for consistency with revocation terminology.
|
||||
pub async fn revoke_own_token(
|
||||
State(state): State<AppState>,
|
||||
user: AuthenticatedUser,
|
||||
request: Request,
|
||||
) -> Result<Json<LogoutResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||
logout(State(state), user, request).await
|
||||
}
|
||||
|
||||
/// Revoke user request
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct RevokeUserRequest {
|
||||
pub user_id: Uuid,
|
||||
}
|
||||
|
||||
/// POST /api/auth/admin/revoke-user - Admin endpoint to revoke all tokens for a user
|
||||
///
|
||||
/// WARNING: This currently only revokes the admin's own token as a demonstration.
|
||||
/// Full implementation would require:
|
||||
/// 1. Session tracking table to store active JWT tokens
|
||||
/// 2. Query to find all tokens for the target user
|
||||
/// 3. Add all found tokens to blacklist
|
||||
///
|
||||
/// For MVP, we're implementing the foundation but not the full user tracking.
|
||||
pub async fn revoke_user_tokens(
|
||||
State(state): State<AppState>,
|
||||
admin: AuthenticatedUser,
|
||||
Json(req): Json<RevokeUserRequest>,
|
||||
) -> Result<Json<LogoutResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||
// Verify admin permission
|
||||
if !admin.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse {
|
||||
error: "Admin access required".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
warn!(
|
||||
"Admin {} attempted to revoke tokens for user {} - NOT IMPLEMENTED (requires session tracking)",
|
||||
admin.username, req.user_id
|
||||
);
|
||||
|
||||
// TODO: Implement session tracking
|
||||
// 1. Query active_sessions table for all tokens belonging to user_id
|
||||
// 2. Add each token to blacklist
|
||||
// 3. Delete session records from database
|
||||
|
||||
Err((
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
Json(ErrorResponse {
|
||||
error: "User token revocation not yet implemented - requires session tracking table".to_string(),
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
/// GET /api/auth/blacklist/stats - Get blacklist statistics (admin only)
|
||||
///
|
||||
/// Returns information about the current token blacklist for monitoring.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct BlacklistStatsResponse {
|
||||
pub revoked_tokens_count: usize,
|
||||
}
|
||||
|
||||
pub async fn get_blacklist_stats(
|
||||
State(state): State<AppState>,
|
||||
admin: AuthenticatedUser,
|
||||
) -> Result<Json<BlacklistStatsResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||
if !admin.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse {
|
||||
error: "Admin access required".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
let count = state.token_blacklist.len().await;
|
||||
|
||||
Ok(Json(BlacklistStatsResponse {
|
||||
revoked_tokens_count: count,
|
||||
}))
|
||||
}
|
||||
|
||||
/// POST /api/auth/blacklist/cleanup - Clean up expired tokens from blacklist (admin only)
|
||||
///
|
||||
/// Removes expired tokens from the blacklist to prevent memory buildup.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct CleanupResponse {
|
||||
pub removed_count: usize,
|
||||
pub remaining_count: usize,
|
||||
}
|
||||
|
||||
pub async fn cleanup_blacklist(
|
||||
State(state): State<AppState>,
|
||||
admin: AuthenticatedUser,
|
||||
) -> Result<Json<CleanupResponse>, (StatusCode, Json<ErrorResponse>)> {
|
||||
if !admin.is_admin() {
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
Json(ErrorResponse {
|
||||
error: "Admin access required".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
let removed = state.token_blacklist.cleanup_expired(&state.jwt_config).await;
|
||||
let remaining = state.token_blacklist.len().await;
|
||||
|
||||
info!("Admin {} cleaned up blacklist: {} tokens removed, {} remaining", admin.username, removed, remaining);
|
||||
|
||||
Ok(Json(CleanupResponse {
|
||||
removed_count: removed,
|
||||
remaining_count: remaining,
|
||||
}))
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
//! REST API endpoints
|
||||
|
||||
pub mod auth;
|
||||
pub mod auth_logout;
|
||||
pub mod users;
|
||||
pub mod releases;
|
||||
pub mod downloads;
|
||||
|
||||
@@ -88,26 +88,37 @@ impl JwtConfig {
|
||||
}
|
||||
|
||||
/// Validate and decode a JWT token
|
||||
///
|
||||
/// SEC-13: Explicitly enforces token expiration
|
||||
/// - Validates signature against secret
|
||||
/// - Checks exp claim (expiration time)
|
||||
/// - Checks iat claim (issued at time)
|
||||
/// - Rejects expired tokens
|
||||
pub fn validate_token(&self, token: &str) -> Result<Claims> {
|
||||
// SEC-13: Explicit validation configuration
|
||||
let mut validation = Validation::default();
|
||||
validation.validate_exp = true; // Enforce expiration check
|
||||
validation.validate_nbf = false; // Not using "not before" claim
|
||||
validation.leeway = 0; // No clock skew tolerance
|
||||
|
||||
let token_data = decode::<Claims>(
|
||||
token,
|
||||
&DecodingKey::from_secret(self.secret.as_bytes()),
|
||||
&Validation::default(),
|
||||
&validation,
|
||||
)
|
||||
.map_err(|e| anyhow!("Invalid token: {}", e))?;
|
||||
|
||||
// Additional check: Ensure token hasn't expired (redundant but explicit)
|
||||
let now = Utc::now().timestamp();
|
||||
if token_data.claims.exp < now {
|
||||
return Err(anyhow!("Token has expired"));
|
||||
}
|
||||
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
}
|
||||
|
||||
/// Default JWT secret if not configured (NOT for production!)
|
||||
pub fn default_jwt_secret() -> String {
|
||||
// In production, this should come from environment variable
|
||||
std::env::var("JWT_SECRET").unwrap_or_else(|_| {
|
||||
tracing::warn!("JWT_SECRET not set, using default (INSECURE!)");
|
||||
"guruconnect-dev-secret-change-me-in-production".to_string()
|
||||
})
|
||||
}
|
||||
// Removed insecure default_jwt_secret() function - JWT_SECRET must be set via environment variable
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
@@ -5,9 +5,11 @@
|
||||
|
||||
pub mod jwt;
|
||||
pub mod password;
|
||||
pub mod token_blacklist;
|
||||
|
||||
pub use jwt::{Claims, JwtConfig};
|
||||
pub use password::{hash_password, verify_password, generate_random_password};
|
||||
pub use token_blacklist::TokenBlacklist;
|
||||
|
||||
use axum::{
|
||||
extract::FromRequestParts,
|
||||
@@ -98,6 +100,17 @@ where
|
||||
.get::<Arc<JwtConfig>>()
|
||||
.ok_or((StatusCode::INTERNAL_SERVER_ERROR, "Auth not configured"))?;
|
||||
|
||||
// Get token blacklist from extensions (set by middleware)
|
||||
let blacklist = parts
|
||||
.extensions
|
||||
.get::<Arc<TokenBlacklist>>()
|
||||
.ok_or((StatusCode::INTERNAL_SERVER_ERROR, "Auth not configured"))?;
|
||||
|
||||
// Check if token is revoked
|
||||
if blacklist.is_revoked(token).await {
|
||||
return Err((StatusCode::UNAUTHORIZED, "Token has been revoked"));
|
||||
}
|
||||
|
||||
// Validate token
|
||||
let claims = jwt_config
|
||||
.validate_token(token)
|
||||
|
||||
@@ -1,15 +1,32 @@
|
||||
//! Password hashing using Argon2id
|
||||
//!
|
||||
//! SEC-9: Explicitly uses Argon2id (hybrid variant) for password hashing
|
||||
//! Argon2id provides resistance against both side-channel and GPU attacks
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use argon2::{
|
||||
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||||
Argon2,
|
||||
Argon2, Algorithm, Version, Params,
|
||||
};
|
||||
|
||||
/// Hash a password using Argon2id
|
||||
///
|
||||
/// SEC-9: Explicitly configured to use Argon2id variant
|
||||
/// - Algorithm: Argon2id (hybrid of Argon2i and Argon2d)
|
||||
/// - Version: 0x13 (latest version)
|
||||
/// - Memory: 19456 KiB (default)
|
||||
/// - Iterations: 2 (default)
|
||||
/// - Parallelism: 1 (default)
|
||||
pub fn hash_password(password: &str) -> Result<String> {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let argon2 = Argon2::default();
|
||||
|
||||
// Explicitly use Argon2id (Algorithm::Argon2id)
|
||||
let argon2 = Argon2::new(
|
||||
Algorithm::Argon2id, // SEC-9: Explicit Argon2id variant
|
||||
Version::V0x13, // Latest version
|
||||
Params::default(), // Default params (19456 KiB, 2 iterations, 1 parallelism)
|
||||
);
|
||||
|
||||
let hash = argon2
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.map_err(|e| anyhow!("Failed to hash password: {}", e))?;
|
||||
@@ -20,6 +37,8 @@ pub fn hash_password(password: &str) -> Result<String> {
|
||||
pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
|
||||
let parsed_hash = PasswordHash::new(hash)
|
||||
.map_err(|e| anyhow!("Invalid password hash format: {}", e))?;
|
||||
|
||||
// Argon2::default() uses Argon2id, but we verify against the hash's embedded algorithm
|
||||
let argon2 = Argon2::default();
|
||||
Ok(argon2.verify_password(password.as_bytes(), &parsed_hash).is_ok())
|
||||
}
|
||||
|
||||
164
server/src/auth/token_blacklist.rs
Normal file
164
server/src/auth/token_blacklist.rs
Normal file
@@ -0,0 +1,164 @@
|
||||
//! Token blacklist for JWT revocation
|
||||
//!
|
||||
//! Provides in-memory token blacklist for immediate revocation of JWTs.
|
||||
//! Tokens are automatically cleaned up after expiration.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{info, debug};
|
||||
|
||||
/// Token blacklist for revocation
|
||||
///
|
||||
/// Maintains a set of revoked token signatures. When a token is revoked
|
||||
/// (e.g., on logout or admin action), it's added to this blacklist and
|
||||
/// all subsequent validation attempts will fail.
|
||||
#[derive(Clone)]
|
||||
pub struct TokenBlacklist {
|
||||
/// Set of revoked token strings
|
||||
tokens: Arc<RwLock<HashSet<String>>>,
|
||||
}
|
||||
|
||||
impl TokenBlacklist {
|
||||
/// Create a new empty blacklist
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tokens: Arc::new(RwLock::new(HashSet::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a token to the blacklist (revoke it)
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `token` - The full JWT token string to revoke
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// blacklist.revoke("eyJ...").await;
|
||||
/// ```
|
||||
pub async fn revoke(&self, token: &str) {
|
||||
let mut tokens = self.tokens.write().await;
|
||||
let was_new = tokens.insert(token.to_string());
|
||||
|
||||
if was_new {
|
||||
debug!("Token revoked and added to blacklist (length: {})", token.len());
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a token has been revoked
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `token` - The JWT token string to check
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if the token is in the blacklist (revoked), `false` otherwise
|
||||
pub async fn is_revoked(&self, token: &str) -> bool {
|
||||
let tokens = self.tokens.read().await;
|
||||
tokens.contains(token)
|
||||
}
|
||||
|
||||
/// Get the number of tokens currently in the blacklist
|
||||
pub async fn len(&self) -> usize {
|
||||
let tokens = self.tokens.read().await;
|
||||
tokens.len()
|
||||
}
|
||||
|
||||
/// Check if the blacklist is empty
|
||||
pub async fn is_empty(&self) -> bool {
|
||||
let tokens = self.tokens.read().await;
|
||||
tokens.is_empty()
|
||||
}
|
||||
|
||||
/// Remove expired tokens from blacklist (cleanup)
|
||||
///
|
||||
/// This should be called periodically to prevent memory buildup.
|
||||
/// Tokens that can no longer be validated (expired) are removed.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `jwt_config` - JWT configuration for validating token expiration
|
||||
///
|
||||
/// # Returns
|
||||
/// Number of tokens removed from blacklist
|
||||
pub async fn cleanup_expired(&self, jwt_config: &super::JwtConfig) -> usize {
|
||||
let mut tokens = self.tokens.write().await;
|
||||
let original_len = tokens.len();
|
||||
|
||||
// Remove tokens that fail validation (expired)
|
||||
tokens.retain(|token| {
|
||||
// If token is expired (validation fails), remove it from blacklist
|
||||
jwt_config.validate_token(token).is_ok()
|
||||
});
|
||||
|
||||
let removed = original_len - tokens.len();
|
||||
|
||||
if removed > 0 {
|
||||
info!("Cleaned {} expired tokens from blacklist ({} remaining)", removed, tokens.len());
|
||||
}
|
||||
|
||||
removed
|
||||
}
|
||||
|
||||
/// Clear all tokens from the blacklist
|
||||
///
|
||||
/// WARNING: This removes all revoked tokens. Use with caution.
|
||||
pub async fn clear(&self) {
|
||||
let mut tokens = self.tokens.write().await;
|
||||
let count = tokens.len();
|
||||
tokens.clear();
|
||||
info!("Cleared {} tokens from blacklist", count);
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TokenBlacklist {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_revoke_and_check() {
|
||||
let blacklist = TokenBlacklist::new();
|
||||
let token = "test.token.here";
|
||||
|
||||
assert!(!blacklist.is_revoked(token).await);
|
||||
|
||||
blacklist.revoke(token).await;
|
||||
|
||||
assert!(blacklist.is_revoked(token).await);
|
||||
assert_eq!(blacklist.len().await, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_revocations() {
|
||||
let blacklist = TokenBlacklist::new();
|
||||
|
||||
blacklist.revoke("token1").await;
|
||||
blacklist.revoke("token2").await;
|
||||
blacklist.revoke("token3").await;
|
||||
|
||||
assert_eq!(blacklist.len().await, 3);
|
||||
assert!(blacklist.is_revoked("token1").await);
|
||||
assert!(blacklist.is_revoked("token2").await);
|
||||
assert!(blacklist.is_revoked("token3").await);
|
||||
assert!(!blacklist.is_revoked("token4").await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_clear() {
|
||||
let blacklist = TokenBlacklist::new();
|
||||
|
||||
blacklist.revoke("token1").await;
|
||||
blacklist.revoke("token2").await;
|
||||
|
||||
assert_eq!(blacklist.len().await, 2);
|
||||
|
||||
blacklist.clear().await;
|
||||
|
||||
assert_eq!(blacklist.len().await, 0);
|
||||
assert!(blacklist.is_empty().await);
|
||||
}
|
||||
}
|
||||
@@ -31,6 +31,13 @@ impl EventTypes {
|
||||
pub const VIEWER_LEFT: &'static str = "viewer_left";
|
||||
pub const STREAMING_STARTED: &'static str = "streaming_started";
|
||||
pub const STREAMING_STOPPED: &'static str = "streaming_stopped";
|
||||
|
||||
// Failed connection events (security audit trail)
|
||||
pub const CONNECTION_REJECTED_NO_AUTH: &'static str = "connection_rejected_no_auth";
|
||||
pub const CONNECTION_REJECTED_INVALID_CODE: &'static str = "connection_rejected_invalid_code";
|
||||
pub const CONNECTION_REJECTED_EXPIRED_CODE: &'static str = "connection_rejected_expired_code";
|
||||
pub const CONNECTION_REJECTED_INVALID_API_KEY: &'static str = "connection_rejected_invalid_api_key";
|
||||
pub const CONNECTION_REJECTED_CANCELLED_CODE: &'static str = "connection_rejected_cancelled_code";
|
||||
}
|
||||
|
||||
/// Log a session event
|
||||
|
||||
@@ -10,6 +10,9 @@ mod auth;
|
||||
mod api;
|
||||
mod db;
|
||||
mod support_codes;
|
||||
mod middleware;
|
||||
mod utils;
|
||||
mod metrics;
|
||||
|
||||
pub mod proto {
|
||||
include!(concat!(env!("OUT_DIR"), "/guruconnect.rs"));
|
||||
@@ -22,11 +25,12 @@ use axum::{
|
||||
extract::{Path, State, Json, Query, Request},
|
||||
response::{Html, IntoResponse},
|
||||
http::StatusCode,
|
||||
middleware::{self, Next},
|
||||
middleware::{self as axum_middleware, Next},
|
||||
};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tower_http::cors::{Any, CorsLayer, AllowOrigin};
|
||||
use axum::http::{Method, HeaderValue};
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tower_http::services::ServeDir;
|
||||
use tracing::{info, Level};
|
||||
@@ -34,7 +38,9 @@ use tracing_subscriber::FmtSubscriber;
|
||||
use serde::Deserialize;
|
||||
|
||||
use support_codes::{SupportCodeManager, CreateCodeRequest, SupportCode, CodeValidation};
|
||||
use auth::{JwtConfig, hash_password, generate_random_password, AuthenticatedUser};
|
||||
use auth::{JwtConfig, TokenBlacklist, hash_password, generate_random_password, AuthenticatedUser};
|
||||
use metrics::SharedMetrics;
|
||||
use prometheus_client::registry::Registry;
|
||||
|
||||
/// Application state
|
||||
#[derive(Clone)]
|
||||
@@ -43,17 +49,25 @@ pub struct AppState {
|
||||
support_codes: SupportCodeManager,
|
||||
db: Option<db::Database>,
|
||||
pub jwt_config: Arc<JwtConfig>,
|
||||
pub token_blacklist: TokenBlacklist,
|
||||
/// Optional API key for persistent agents (env: AGENT_API_KEY)
|
||||
pub agent_api_key: Option<String>,
|
||||
/// Prometheus metrics
|
||||
pub metrics: SharedMetrics,
|
||||
/// Prometheus registry (for /metrics endpoint)
|
||||
pub registry: Arc<std::sync::Mutex<Registry>>,
|
||||
/// Server start time
|
||||
pub start_time: Arc<std::time::Instant>,
|
||||
}
|
||||
|
||||
/// Middleware to inject JWT config into request extensions
|
||||
/// Middleware to inject JWT config and token blacklist into request extensions
|
||||
async fn auth_layer(
|
||||
State(state): State<AppState>,
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> impl IntoResponse {
|
||||
request.extensions_mut().insert(state.jwt_config.clone());
|
||||
request.extensions_mut().insert(Arc::new(state.token_blacklist.clone()));
|
||||
next.run(request).await
|
||||
}
|
||||
|
||||
@@ -74,11 +88,14 @@ async fn main() -> Result<()> {
|
||||
let listen_addr = std::env::var("LISTEN_ADDR").unwrap_or_else(|_| "0.0.0.0:3002".to_string());
|
||||
info!("Loaded configuration, listening on {}", listen_addr);
|
||||
|
||||
// JWT configuration
|
||||
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| {
|
||||
tracing::warn!("JWT_SECRET not set, using default (INSECURE for production!)");
|
||||
"guruconnect-dev-secret-change-me-in-production".to_string()
|
||||
});
|
||||
// JWT configuration - REQUIRED for security
|
||||
let jwt_secret = std::env::var("JWT_SECRET")
|
||||
.expect("JWT_SECRET environment variable must be set! Generate one with: openssl rand -base64 64");
|
||||
|
||||
if jwt_secret.len() < 32 {
|
||||
panic!("JWT_SECRET must be at least 32 characters long for security!");
|
||||
}
|
||||
|
||||
let jwt_expiry_hours = std::env::var("JWT_EXPIRY_HOURS")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
@@ -126,12 +143,35 @@ async fn main() -> Result<()> {
|
||||
];
|
||||
let _ = db::set_user_permissions(db.pool(), user.id, &perms).await;
|
||||
|
||||
info!("========================================");
|
||||
info!(" INITIAL ADMIN USER CREATED");
|
||||
info!(" Username: admin");
|
||||
info!(" Password: {}", password);
|
||||
info!(" (Change this password after first login!)");
|
||||
info!("========================================");
|
||||
// SEC-6: Write credentials to secure file instead of logging
|
||||
let creds_file = ".admin-credentials";
|
||||
match std::fs::write(creds_file, format!("Username: admin\nPassword: {}\n\nWARNING: Change this password immediately after first login!\nDelete this file after copying the password.\n", password)) {
|
||||
Ok(_) => {
|
||||
// Set restrictive permissions (Unix only)
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = std::fs::set_permissions(creds_file, std::fs::Permissions::from_mode(0o600));
|
||||
}
|
||||
|
||||
info!("========================================");
|
||||
info!(" INITIAL ADMIN USER CREATED");
|
||||
info!(" Credentials written to: {}", creds_file);
|
||||
info!(" (Read file, change password, then delete file)");
|
||||
info!("========================================");
|
||||
}
|
||||
Err(e) => {
|
||||
// Fallback to logging if file write fails (but warn about security)
|
||||
tracing::warn!("Could not write credentials file: {}", e);
|
||||
info!("========================================");
|
||||
info!(" INITIAL ADMIN USER CREATED");
|
||||
info!(" Username: admin");
|
||||
info!(" Password: {}", password);
|
||||
info!(" WARNING: Password logged due to file write failure!");
|
||||
info!(" (Change this password immediately!)");
|
||||
info!("========================================");
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to create initial admin user: {}", e);
|
||||
@@ -167,32 +207,63 @@ async fn main() -> Result<()> {
|
||||
|
||||
// Agent API key for persistent agents (optional)
|
||||
let agent_api_key = std::env::var("AGENT_API_KEY").ok();
|
||||
if agent_api_key.is_some() {
|
||||
info!("AGENT_API_KEY configured for persistent agents");
|
||||
if let Some(ref key) = agent_api_key {
|
||||
// Validate API key strength for security
|
||||
utils::validation::validate_api_key_strength(key)?;
|
||||
info!("AGENT_API_KEY configured for persistent agents (validated)");
|
||||
} else {
|
||||
info!("No AGENT_API_KEY set - persistent agents will need JWT token or support code");
|
||||
}
|
||||
|
||||
// Initialize Prometheus metrics
|
||||
let mut registry = Registry::default();
|
||||
let metrics = Arc::new(metrics::Metrics::new(&mut registry));
|
||||
let registry = Arc::new(std::sync::Mutex::new(registry));
|
||||
let start_time = Arc::new(std::time::Instant::now());
|
||||
|
||||
// Spawn background task to update uptime metric
|
||||
let metrics_for_uptime = metrics.clone();
|
||||
let start_time_for_uptime = start_time.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(10));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let uptime = start_time_for_uptime.elapsed().as_secs() as i64;
|
||||
metrics_for_uptime.update_uptime(uptime);
|
||||
}
|
||||
});
|
||||
|
||||
// Create application state
|
||||
let token_blacklist = TokenBlacklist::new();
|
||||
|
||||
let state = AppState {
|
||||
sessions,
|
||||
support_codes: SupportCodeManager::new(),
|
||||
db: database,
|
||||
jwt_config,
|
||||
token_blacklist,
|
||||
agent_api_key,
|
||||
metrics,
|
||||
registry,
|
||||
start_time,
|
||||
};
|
||||
|
||||
// Build router
|
||||
let app = Router::new()
|
||||
// Health check (no auth required)
|
||||
.route("/health", get(health))
|
||||
// Prometheus metrics (no auth required - for monitoring)
|
||||
.route("/metrics", get(prometheus_metrics))
|
||||
|
||||
// Auth endpoints (no auth required for login)
|
||||
// Auth endpoints (TODO: Add rate limiting - see SEC2_RATE_LIMITING_TODO.md)
|
||||
.route("/api/auth/login", post(api::auth::login))
|
||||
|
||||
// Auth endpoints (auth required)
|
||||
.route("/api/auth/me", get(api::auth::get_me))
|
||||
.route("/api/auth/change-password", post(api::auth::change_password))
|
||||
.route("/api/auth/me", get(api::auth::get_me))
|
||||
.route("/api/auth/logout", post(api::auth_logout::logout))
|
||||
.route("/api/auth/revoke-token", post(api::auth_logout::revoke_own_token))
|
||||
.route("/api/auth/admin/revoke-user", post(api::auth_logout::revoke_user_tokens))
|
||||
.route("/api/auth/blacklist/stats", get(api::auth_logout::get_blacklist_stats))
|
||||
.route("/api/auth/blacklist/cleanup", post(api::auth_logout::cleanup_blacklist))
|
||||
|
||||
// User management (admin only)
|
||||
.route("/api/users", get(api::users::list_users))
|
||||
@@ -203,7 +274,7 @@ async fn main() -> Result<()> {
|
||||
.route("/api/users/:id/permissions", put(api::users::set_permissions))
|
||||
.route("/api/users/:id/clients", put(api::users::set_client_access))
|
||||
|
||||
// Portal API - Support codes
|
||||
// Portal API - Support codes (TODO: Add rate limiting)
|
||||
.route("/api/codes", post(create_code))
|
||||
.route("/api/codes", get(list_codes))
|
||||
.route("/api/codes/:code/validate", get(validate_code))
|
||||
@@ -245,19 +316,35 @@ async fn main() -> Result<()> {
|
||||
|
||||
// State and middleware
|
||||
.with_state(state.clone())
|
||||
.layer(middleware::from_fn_with_state(state, auth_layer))
|
||||
.layer(axum_middleware::from_fn_with_state(state, auth_layer))
|
||||
|
||||
// Serve static files for portal (fallback)
|
||||
.fallback_service(ServeDir::new("static").append_index_html_on_directories(true))
|
||||
|
||||
// Middleware
|
||||
.layer(axum_middleware::from_fn(middleware::add_security_headers)) // SEC-7 & SEC-12
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(
|
||||
CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any),
|
||||
);
|
||||
// SEC-11: Restricted CORS configuration
|
||||
.layer({
|
||||
let cors = CorsLayer::new()
|
||||
// Allow requests from the production domain and localhost (for development)
|
||||
.allow_origin([
|
||||
"https://connect.azcomputerguru.com".parse::<HeaderValue>().unwrap(),
|
||||
"http://localhost:3002".parse::<HeaderValue>().unwrap(),
|
||||
"http://127.0.0.1:3002".parse::<HeaderValue>().unwrap(),
|
||||
])
|
||||
// Allow only necessary HTTP methods
|
||||
.allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::OPTIONS])
|
||||
// Allow common headers needed for API requests
|
||||
.allow_headers([
|
||||
axum::http::header::AUTHORIZATION,
|
||||
axum::http::header::CONTENT_TYPE,
|
||||
axum::http::header::ACCEPT,
|
||||
])
|
||||
// Allow credentials (cookies, auth headers)
|
||||
.allow_credentials(true);
|
||||
cors
|
||||
});
|
||||
|
||||
// Start server
|
||||
let addr: SocketAddr = listen_addr.parse()?;
|
||||
@@ -265,7 +352,11 @@ async fn main() -> Result<()> {
|
||||
|
||||
info!("Server listening on {}", addr);
|
||||
|
||||
axum::serve(listener, app).await?;
|
||||
// Use into_make_service_with_connect_info to enable IP address extraction
|
||||
axum::serve(
|
||||
listener,
|
||||
app.into_make_service_with_connect_info::<SocketAddr>()
|
||||
).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -274,6 +365,18 @@ async fn health() -> &'static str {
|
||||
"OK"
|
||||
}
|
||||
|
||||
/// Prometheus metrics endpoint
|
||||
async fn prometheus_metrics(
|
||||
State(state): State<AppState>,
|
||||
) -> String {
|
||||
use prometheus_client::encoding::text::encode;
|
||||
|
||||
let registry = state.registry.lock().unwrap();
|
||||
let mut buffer = String::new();
|
||||
encode(&mut buffer, ®istry).unwrap();
|
||||
buffer
|
||||
}
|
||||
|
||||
// Support code API handlers
|
||||
|
||||
async fn create_code(
|
||||
|
||||
290
server/src/metrics/mod.rs
Normal file
290
server/src/metrics/mod.rs
Normal file
@@ -0,0 +1,290 @@
|
||||
//! Prometheus metrics for GuruConnect server
|
||||
//!
|
||||
//! This module exposes metrics for monitoring server health, performance, and usage.
|
||||
//! Metrics are exposed at the `/metrics` endpoint in Prometheus format.
|
||||
|
||||
use prometheus_client::encoding::EncodeLabelSet;
|
||||
use prometheus_client::metrics::counter::Counter;
|
||||
use prometheus_client::metrics::family::Family;
|
||||
use prometheus_client::metrics::gauge::Gauge;
|
||||
use prometheus_client::metrics::histogram::{exponential_buckets, Histogram};
|
||||
use prometheus_client::registry::Registry;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Metrics labels for HTTP requests
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
|
||||
pub struct RequestLabels {
|
||||
pub method: String,
|
||||
pub path: String,
|
||||
pub status: u16,
|
||||
}
|
||||
|
||||
/// Metrics labels for session events
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
|
||||
pub struct SessionLabels {
|
||||
pub status: String, // created, closed, failed, expired
|
||||
}
|
||||
|
||||
/// Metrics labels for connection events
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
|
||||
pub struct ConnectionLabels {
|
||||
pub conn_type: String, // agent, viewer, dashboard
|
||||
}
|
||||
|
||||
/// Metrics labels for error tracking
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
|
||||
pub struct ErrorLabels {
|
||||
pub error_type: String, // auth, database, websocket, protocol, internal
|
||||
}
|
||||
|
||||
/// Metrics labels for database operations
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
|
||||
pub struct DatabaseLabels {
|
||||
pub operation: String, // select, insert, update, delete
|
||||
pub status: String, // success, error
|
||||
}
|
||||
|
||||
/// GuruConnect server metrics
|
||||
#[derive(Clone)]
|
||||
pub struct Metrics {
|
||||
// Request metrics
|
||||
pub requests_total: Family<RequestLabels, Counter>,
|
||||
pub request_duration_seconds: Family<RequestLabels, Histogram>,
|
||||
|
||||
// Session metrics
|
||||
pub sessions_total: Family<SessionLabels, Counter>,
|
||||
pub active_sessions: Gauge,
|
||||
pub session_duration_seconds: Histogram,
|
||||
|
||||
// Connection metrics
|
||||
pub connections_total: Family<ConnectionLabels, Counter>,
|
||||
pub active_connections: Family<ConnectionLabels, Gauge>,
|
||||
|
||||
// Error metrics
|
||||
pub errors_total: Family<ErrorLabels, Counter>,
|
||||
|
||||
// Database metrics
|
||||
pub db_operations_total: Family<DatabaseLabels, Counter>,
|
||||
pub db_query_duration_seconds: Family<DatabaseLabels, Histogram>,
|
||||
|
||||
// System metrics
|
||||
pub uptime_seconds: Gauge,
|
||||
}
|
||||
|
||||
impl Metrics {
|
||||
/// Create a new metrics instance and register all metrics
|
||||
pub fn new(registry: &mut Registry) -> Self {
|
||||
// Request metrics
|
||||
let requests_total = Family::<RequestLabels, Counter>::default();
|
||||
registry.register(
|
||||
"guruconnect_requests_total",
|
||||
"Total number of HTTP requests",
|
||||
requests_total.clone(),
|
||||
);
|
||||
|
||||
let request_duration_seconds = Family::<RequestLabels, Histogram>::new_with_constructor(|| {
|
||||
Histogram::new(exponential_buckets(0.001, 2.0, 10)) // 1ms to ~1s
|
||||
});
|
||||
registry.register(
|
||||
"guruconnect_request_duration_seconds",
|
||||
"HTTP request duration in seconds",
|
||||
request_duration_seconds.clone(),
|
||||
);
|
||||
|
||||
// Session metrics
|
||||
let sessions_total = Family::<SessionLabels, Counter>::default();
|
||||
registry.register(
|
||||
"guruconnect_sessions_total",
|
||||
"Total number of sessions",
|
||||
sessions_total.clone(),
|
||||
);
|
||||
|
||||
let active_sessions = Gauge::default();
|
||||
registry.register(
|
||||
"guruconnect_active_sessions",
|
||||
"Number of currently active sessions",
|
||||
active_sessions.clone(),
|
||||
);
|
||||
|
||||
let session_duration_seconds = Histogram::new(exponential_buckets(1.0, 2.0, 15)); // 1s to ~9 hours
|
||||
registry.register(
|
||||
"guruconnect_session_duration_seconds",
|
||||
"Session duration in seconds",
|
||||
session_duration_seconds.clone(),
|
||||
);
|
||||
|
||||
// Connection metrics
|
||||
let connections_total = Family::<ConnectionLabels, Counter>::default();
|
||||
registry.register(
|
||||
"guruconnect_connections_total",
|
||||
"Total number of WebSocket connections",
|
||||
connections_total.clone(),
|
||||
);
|
||||
|
||||
let active_connections = Family::<ConnectionLabels, Gauge>::default();
|
||||
registry.register(
|
||||
"guruconnect_active_connections",
|
||||
"Number of active WebSocket connections by type",
|
||||
active_connections.clone(),
|
||||
);
|
||||
|
||||
// Error metrics
|
||||
let errors_total = Family::<ErrorLabels, Counter>::default();
|
||||
registry.register(
|
||||
"guruconnect_errors_total",
|
||||
"Total number of errors by type",
|
||||
errors_total.clone(),
|
||||
);
|
||||
|
||||
// Database metrics
|
||||
let db_operations_total = Family::<DatabaseLabels, Counter>::default();
|
||||
registry.register(
|
||||
"guruconnect_db_operations_total",
|
||||
"Total number of database operations",
|
||||
db_operations_total.clone(),
|
||||
);
|
||||
|
||||
let db_query_duration_seconds = Family::<DatabaseLabels, Histogram>::new_with_constructor(|| {
|
||||
Histogram::new(exponential_buckets(0.0001, 2.0, 12)) // 0.1ms to ~400ms
|
||||
});
|
||||
registry.register(
|
||||
"guruconnect_db_query_duration_seconds",
|
||||
"Database query duration in seconds",
|
||||
db_query_duration_seconds.clone(),
|
||||
);
|
||||
|
||||
// System metrics
|
||||
let uptime_seconds = Gauge::default();
|
||||
registry.register(
|
||||
"guruconnect_uptime_seconds",
|
||||
"Server uptime in seconds",
|
||||
uptime_seconds.clone(),
|
||||
);
|
||||
|
||||
Self {
|
||||
requests_total,
|
||||
request_duration_seconds,
|
||||
sessions_total,
|
||||
active_sessions,
|
||||
session_duration_seconds,
|
||||
connections_total,
|
||||
active_connections,
|
||||
errors_total,
|
||||
db_operations_total,
|
||||
db_query_duration_seconds,
|
||||
uptime_seconds,
|
||||
}
|
||||
}
|
||||
|
||||
/// Increment request counter
|
||||
pub fn record_request(&self, method: &str, path: &str, status: u16) {
|
||||
self.requests_total
|
||||
.get_or_create(&RequestLabels {
|
||||
method: method.to_string(),
|
||||
path: path.to_string(),
|
||||
status,
|
||||
})
|
||||
.inc();
|
||||
}
|
||||
|
||||
/// Record request duration
|
||||
pub fn record_request_duration(&self, method: &str, path: &str, status: u16, duration_secs: f64) {
|
||||
self.request_duration_seconds
|
||||
.get_or_create(&RequestLabels {
|
||||
method: method.to_string(),
|
||||
path: path.to_string(),
|
||||
status,
|
||||
})
|
||||
.observe(duration_secs);
|
||||
}
|
||||
|
||||
/// Record session creation
|
||||
pub fn record_session_created(&self) {
|
||||
self.sessions_total
|
||||
.get_or_create(&SessionLabels {
|
||||
status: "created".to_string(),
|
||||
})
|
||||
.inc();
|
||||
self.active_sessions.inc();
|
||||
}
|
||||
|
||||
/// Record session closure
|
||||
pub fn record_session_closed(&self) {
|
||||
self.sessions_total
|
||||
.get_or_create(&SessionLabels {
|
||||
status: "closed".to_string(),
|
||||
})
|
||||
.inc();
|
||||
self.active_sessions.dec();
|
||||
}
|
||||
|
||||
/// Record session failure
|
||||
pub fn record_session_failed(&self) {
|
||||
self.sessions_total
|
||||
.get_or_create(&SessionLabels {
|
||||
status: "failed".to_string(),
|
||||
})
|
||||
.inc();
|
||||
}
|
||||
|
||||
/// Record session duration
|
||||
pub fn record_session_duration(&self, duration_secs: f64) {
|
||||
self.session_duration_seconds.observe(duration_secs);
|
||||
}
|
||||
|
||||
/// Record connection created
|
||||
pub fn record_connection_created(&self, conn_type: &str) {
|
||||
self.connections_total
|
||||
.get_or_create(&ConnectionLabels {
|
||||
conn_type: conn_type.to_string(),
|
||||
})
|
||||
.inc();
|
||||
self.active_connections
|
||||
.get_or_create(&ConnectionLabels {
|
||||
conn_type: conn_type.to_string(),
|
||||
})
|
||||
.inc();
|
||||
}
|
||||
|
||||
/// Record connection closed
|
||||
pub fn record_connection_closed(&self, conn_type: &str) {
|
||||
self.active_connections
|
||||
.get_or_create(&ConnectionLabels {
|
||||
conn_type: conn_type.to_string(),
|
||||
})
|
||||
.dec();
|
||||
}
|
||||
|
||||
/// Record an error
|
||||
pub fn record_error(&self, error_type: &str) {
|
||||
self.errors_total
|
||||
.get_or_create(&ErrorLabels {
|
||||
error_type: error_type.to_string(),
|
||||
})
|
||||
.inc();
|
||||
}
|
||||
|
||||
/// Record database operation
|
||||
pub fn record_db_operation(&self, operation: &str, status: &str, duration_secs: f64) {
|
||||
let labels = DatabaseLabels {
|
||||
operation: operation.to_string(),
|
||||
status: status.to_string(),
|
||||
};
|
||||
|
||||
self.db_operations_total
|
||||
.get_or_create(&labels.clone())
|
||||
.inc();
|
||||
|
||||
self.db_query_duration_seconds
|
||||
.get_or_create(&labels)
|
||||
.observe(duration_secs);
|
||||
}
|
||||
|
||||
/// Update uptime metric
|
||||
pub fn update_uptime(&self, uptime_secs: i64) {
|
||||
self.uptime_seconds.set(uptime_secs);
|
||||
}
|
||||
}
|
||||
|
||||
/// Global metrics state wrapped in Arc for sharing across threads
|
||||
pub type SharedMetrics = Arc<Metrics>;
|
||||
16
server/src/middleware/mod.rs
Normal file
16
server/src/middleware/mod.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
//! Middleware modules
|
||||
|
||||
// DISABLED: Rate limiting not yet functional due to type signature issues
|
||||
// See SEC2_RATE_LIMITING_TODO.md
|
||||
// pub mod rate_limit;
|
||||
//
|
||||
// pub use rate_limit::{
|
||||
// auth_rate_limiter,
|
||||
// support_code_rate_limiter,
|
||||
// api_rate_limiter,
|
||||
// };
|
||||
|
||||
// SEC-7 & SEC-12: Security headers middleware
|
||||
pub mod security_headers;
|
||||
|
||||
pub use security_headers::add_security_headers;
|
||||
59
server/src/middleware/rate_limit.rs
Normal file
59
server/src/middleware/rate_limit.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
//! Rate limiting middleware using tower-governor
|
||||
//!
|
||||
//! Protects against brute force attacks on authentication endpoints.
|
||||
|
||||
use tower_governor::{
|
||||
governor::GovernorConfigBuilder,
|
||||
GovernorLayer,
|
||||
};
|
||||
|
||||
/// Create rate limiting layer for authentication endpoints
|
||||
///
|
||||
/// Allows 5 requests per minute per IP address
|
||||
pub fn auth_rate_limiter() -> impl tower::Layer<tower::service_fn::ServiceFn<impl Fn(axum::http::Request<axum::body::Body>) -> std::future::Future<Output = Result<axum::http::Response<axum::body::Body>, std::convert::Infallible>>>> {
|
||||
let governor_conf = Box::new(
|
||||
GovernorConfigBuilder::default()
|
||||
.per_millisecond(60000 / 5) // 5 requests per minute
|
||||
.burst_size(5)
|
||||
.finish()
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
GovernorLayer {
|
||||
config: Box::leak(governor_conf),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create rate limiting layer for support code validation
|
||||
///
|
||||
/// Allows 10 requests per minute per IP address
|
||||
pub fn support_code_rate_limiter() -> impl tower::Layer<tower::service_fn::ServiceFn<impl Fn(axum::http::Request<axum::body::Body>) -> std::future::Future<Output = Result<axum::http::Response<axum::body::Body>, std::convert::Infallible>>>> {
|
||||
let governor_conf = Box::new(
|
||||
GovernorConfigBuilder::default()
|
||||
.per_millisecond(60000 / 10) // 10 requests per minute
|
||||
.burst_size(10)
|
||||
.finish()
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
GovernorLayer {
|
||||
config: Box::leak(governor_conf),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create rate limiting layer for API endpoints
|
||||
///
|
||||
/// Allows 60 requests per minute per IP address
|
||||
pub fn api_rate_limiter() -> impl tower::Layer<tower::service_fn::ServiceFn<impl Fn(axum::http::Request<axum::body::Body>) -> std::future::Future<Output = Result<axum::http::Response<axum::body::Body>, std::convert::Infallible>>>> {
|
||||
let governor_conf = Box::new(
|
||||
GovernorConfigBuilder::default()
|
||||
.per_millisecond(1000) // 1 request per second
|
||||
.burst_size(60)
|
||||
.finish()
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
GovernorLayer {
|
||||
config: Box::leak(governor_conf),
|
||||
}
|
||||
}
|
||||
75
server/src/middleware/security_headers.rs
Normal file
75
server/src/middleware/security_headers.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
//! Security headers middleware
|
||||
//!
|
||||
//! SEC-7: XSS Prevention via Content-Security-Policy
|
||||
//! SEC-12: Additional security headers
|
||||
|
||||
use axum::{
|
||||
extract::Request,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
|
||||
/// Add security headers to all responses
|
||||
pub async fn add_security_headers(
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let mut response = next.run(request).await;
|
||||
let headers = response.headers_mut();
|
||||
|
||||
// SEC-7: Content Security Policy (XSS Prevention)
|
||||
// This CSP allows inline scripts/styles (needed for dashboard) but blocks external resources
|
||||
headers.insert(
|
||||
"Content-Security-Policy",
|
||||
"default-src 'self'; \
|
||||
script-src 'self' 'unsafe-inline'; \
|
||||
style-src 'self' 'unsafe-inline'; \
|
||||
img-src 'self' data:; \
|
||||
font-src 'self'; \
|
||||
connect-src 'self' ws: wss:; \
|
||||
frame-ancestors 'none'; \
|
||||
base-uri 'self'; \
|
||||
form-action 'self'"
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
// SEC-12: X-Frame-Options (Clickjacking protection)
|
||||
headers.insert(
|
||||
"X-Frame-Options",
|
||||
"DENY".parse().unwrap(),
|
||||
);
|
||||
|
||||
// SEC-12: X-Content-Type-Options (MIME sniffing protection)
|
||||
headers.insert(
|
||||
"X-Content-Type-Options",
|
||||
"nosniff".parse().unwrap(),
|
||||
);
|
||||
|
||||
// SEC-12: X-XSS-Protection (Legacy XSS filter - deprecated but still useful)
|
||||
headers.insert(
|
||||
"X-XSS-Protection",
|
||||
"1; mode=block".parse().unwrap(),
|
||||
);
|
||||
|
||||
// SEC-12: Referrer-Policy (Control referrer information)
|
||||
headers.insert(
|
||||
"Referrer-Policy",
|
||||
"strict-origin-when-cross-origin".parse().unwrap(),
|
||||
);
|
||||
|
||||
// SEC-12: Permissions-Policy (Feature policy)
|
||||
headers.insert(
|
||||
"Permissions-Policy",
|
||||
"geolocation=(), microphone=(), camera=()".parse().unwrap(),
|
||||
);
|
||||
|
||||
// SEC-10: Strict-Transport-Security (HSTS - only when using HTTPS)
|
||||
// Uncomment when HTTPS is enabled:
|
||||
// headers.insert(
|
||||
// "Strict-Transport-Security",
|
||||
// "max-age=31536000; includeSubDomains; preload".parse().unwrap(),
|
||||
// );
|
||||
|
||||
response
|
||||
}
|
||||
@@ -6,11 +6,12 @@
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket, WebSocketUpgrade},
|
||||
Query, State,
|
||||
Query, State, ConnectInfo,
|
||||
},
|
||||
response::IntoResponse,
|
||||
http::StatusCode,
|
||||
};
|
||||
use std::net::SocketAddr;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use prost::Message as ProstMessage;
|
||||
use serde::Deserialize;
|
||||
@@ -54,19 +55,38 @@ fn default_viewer_name() -> String {
|
||||
pub async fn agent_ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<AppState>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
Query(params): Query<AgentParams>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let agent_id = params.agent_id.clone();
|
||||
let agent_name = params.hostname.clone().or(params.agent_name.clone()).unwrap_or_else(|| agent_id.clone());
|
||||
let support_code = params.support_code.clone();
|
||||
let api_key = params.api_key.clone();
|
||||
let client_ip = addr.ip();
|
||||
|
||||
// SECURITY: Agent must provide either a support code OR an API key
|
||||
// Support code = ad-hoc support session (technician generated code)
|
||||
// API key = persistent managed agent
|
||||
|
||||
if support_code.is_none() && api_key.is_none() {
|
||||
warn!("Agent connection rejected: {} - no support code or API key", agent_id);
|
||||
warn!("Agent connection rejected: {} from {} - no support code or API key", agent_id, client_ip);
|
||||
|
||||
// Log failed connection attempt to database
|
||||
if let Some(ref db) = state.db {
|
||||
let _ = db::events::log_event(
|
||||
db.pool(),
|
||||
Uuid::new_v4(), // Temporary UUID for failed attempt
|
||||
db::events::EventTypes::CONNECTION_REJECTED_NO_AUTH,
|
||||
None,
|
||||
Some(&agent_id),
|
||||
Some(serde_json::json!({
|
||||
"reason": "no_auth_method",
|
||||
"agent_id": agent_id
|
||||
})),
|
||||
Some(client_ip),
|
||||
).await;
|
||||
}
|
||||
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
@@ -75,15 +95,57 @@ pub async fn agent_ws_handler(
|
||||
// Check if it's a valid, pending support code
|
||||
let code_info = state.support_codes.get_status(code).await;
|
||||
if code_info.is_none() {
|
||||
warn!("Agent connection rejected: {} - invalid support code {}", agent_id, code);
|
||||
warn!("Agent connection rejected: {} from {} - invalid support code {}", agent_id, client_ip, code);
|
||||
|
||||
// Log failed connection attempt
|
||||
if let Some(ref db) = state.db {
|
||||
let _ = db::events::log_event(
|
||||
db.pool(),
|
||||
Uuid::new_v4(),
|
||||
db::events::EventTypes::CONNECTION_REJECTED_INVALID_CODE,
|
||||
None,
|
||||
Some(&agent_id),
|
||||
Some(serde_json::json!({
|
||||
"reason": "invalid_code",
|
||||
"support_code": code,
|
||||
"agent_id": agent_id
|
||||
})),
|
||||
Some(client_ip),
|
||||
).await;
|
||||
}
|
||||
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
let status = code_info.unwrap();
|
||||
if status != "pending" && status != "connected" {
|
||||
warn!("Agent connection rejected: {} - support code {} has status {}", agent_id, code, status);
|
||||
warn!("Agent connection rejected: {} from {} - support code {} has status {}", agent_id, client_ip, code, status);
|
||||
|
||||
// Log failed connection attempt (expired/cancelled code)
|
||||
if let Some(ref db) = state.db {
|
||||
let event_type = if status == "cancelled" {
|
||||
db::events::EventTypes::CONNECTION_REJECTED_CANCELLED_CODE
|
||||
} else {
|
||||
db::events::EventTypes::CONNECTION_REJECTED_EXPIRED_CODE
|
||||
};
|
||||
|
||||
let _ = db::events::log_event(
|
||||
db.pool(),
|
||||
Uuid::new_v4(),
|
||||
event_type,
|
||||
None,
|
||||
Some(&agent_id),
|
||||
Some(serde_json::json!({
|
||||
"reason": status,
|
||||
"support_code": code,
|
||||
"agent_id": agent_id
|
||||
})),
|
||||
Some(client_ip),
|
||||
).await;
|
||||
}
|
||||
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
info!("Agent {} authenticated via support code {}", agent_id, code);
|
||||
info!("Agent {} from {} authenticated via support code {}", agent_id, client_ip, code);
|
||||
}
|
||||
|
||||
// Validate API key if provided (for persistent agents)
|
||||
@@ -91,17 +153,34 @@ pub async fn agent_ws_handler(
|
||||
// For now, we'll accept API keys that match the JWT secret or a configured agent key
|
||||
// In production, this should validate against a database of registered agents
|
||||
if !validate_agent_api_key(&state, key).await {
|
||||
warn!("Agent connection rejected: {} - invalid API key", agent_id);
|
||||
warn!("Agent connection rejected: {} from {} - invalid API key", agent_id, client_ip);
|
||||
|
||||
// Log failed connection attempt
|
||||
if let Some(ref db) = state.db {
|
||||
let _ = db::events::log_event(
|
||||
db.pool(),
|
||||
Uuid::new_v4(),
|
||||
db::events::EventTypes::CONNECTION_REJECTED_INVALID_API_KEY,
|
||||
None,
|
||||
Some(&agent_id),
|
||||
Some(serde_json::json!({
|
||||
"reason": "invalid_api_key",
|
||||
"agent_id": agent_id
|
||||
})),
|
||||
Some(client_ip),
|
||||
).await;
|
||||
}
|
||||
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
info!("Agent {} authenticated via API key", agent_id);
|
||||
info!("Agent {} from {} authenticated via API key", agent_id, client_ip);
|
||||
}
|
||||
|
||||
let sessions = state.sessions.clone();
|
||||
let support_codes = state.support_codes.clone();
|
||||
let db = state.db.clone();
|
||||
|
||||
Ok(ws.on_upgrade(move |socket| handle_agent_connection(socket, sessions, support_codes, db, agent_id, agent_name, support_code)))
|
||||
Ok(ws.on_upgrade(move |socket| handle_agent_connection(socket, sessions, support_codes, db, agent_id, agent_name, support_code, Some(client_ip))))
|
||||
}
|
||||
|
||||
/// Validate an agent API key
|
||||
@@ -126,28 +205,31 @@ async fn validate_agent_api_key(state: &AppState, api_key: &str) -> bool {
|
||||
pub async fn viewer_ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<AppState>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
Query(params): Query<ViewerParams>,
|
||||
) -> Result<impl IntoResponse, StatusCode> {
|
||||
let client_ip = addr.ip();
|
||||
|
||||
// Require JWT token for viewers
|
||||
let token = params.token.ok_or_else(|| {
|
||||
warn!("Viewer connection rejected: missing token");
|
||||
warn!("Viewer connection rejected from {}: missing token", client_ip);
|
||||
StatusCode::UNAUTHORIZED
|
||||
})?;
|
||||
|
||||
// Validate the token
|
||||
let claims = state.jwt_config.validate_token(&token).map_err(|e| {
|
||||
warn!("Viewer connection rejected: invalid token: {}", e);
|
||||
warn!("Viewer connection rejected from {}: invalid token: {}", client_ip, e);
|
||||
StatusCode::UNAUTHORIZED
|
||||
})?;
|
||||
|
||||
info!("Viewer {} authenticated via JWT", claims.username);
|
||||
info!("Viewer {} authenticated via JWT from {}", claims.username, client_ip);
|
||||
|
||||
let session_id = params.session_id;
|
||||
let viewer_name = params.viewer_name;
|
||||
let sessions = state.sessions.clone();
|
||||
let db = state.db.clone();
|
||||
|
||||
Ok(ws.on_upgrade(move |socket| handle_viewer_connection(socket, sessions, db, session_id, viewer_name)))
|
||||
Ok(ws.on_upgrade(move |socket| handle_viewer_connection(socket, sessions, db, session_id, viewer_name, Some(client_ip))))
|
||||
}
|
||||
|
||||
/// Handle an agent WebSocket connection
|
||||
@@ -159,8 +241,9 @@ async fn handle_agent_connection(
|
||||
agent_id: String,
|
||||
agent_name: String,
|
||||
support_code: Option<String>,
|
||||
client_ip: Option<std::net::IpAddr>,
|
||||
) {
|
||||
info!("Agent connected: {} ({})", agent_name, agent_id);
|
||||
info!("Agent connected: {} ({}) from {:?}", agent_name, agent_id, client_ip);
|
||||
|
||||
let (mut ws_sender, mut ws_receiver) = socket.split();
|
||||
|
||||
@@ -209,7 +292,7 @@ async fn handle_agent_connection(
|
||||
db.pool(),
|
||||
session_id,
|
||||
db::events::EventTypes::SESSION_STARTED,
|
||||
None, None, None, None,
|
||||
None, None, None, client_ip,
|
||||
).await;
|
||||
|
||||
Some(machine.id)
|
||||
@@ -406,7 +489,7 @@ async fn handle_agent_connection(
|
||||
db.pool(),
|
||||
session_id,
|
||||
db::events::EventTypes::SESSION_ENDED,
|
||||
None, None, None, None,
|
||||
None, None, None, client_ip,
|
||||
).await;
|
||||
}
|
||||
|
||||
@@ -434,6 +517,7 @@ async fn handle_viewer_connection(
|
||||
db: Option<Database>,
|
||||
session_id_str: String,
|
||||
viewer_name: String,
|
||||
client_ip: Option<std::net::IpAddr>,
|
||||
) {
|
||||
// Parse session ID
|
||||
let session_id = match uuid::Uuid::parse_str(&session_id_str) {
|
||||
@@ -456,7 +540,7 @@ async fn handle_viewer_connection(
|
||||
}
|
||||
};
|
||||
|
||||
info!("Viewer {} ({}) joined session: {}", viewer_name, viewer_id, session_id);
|
||||
info!("Viewer {} ({}) joined session: {} from {:?}", viewer_name, viewer_id, session_id, client_ip);
|
||||
|
||||
// Database: log viewer joined event
|
||||
if let Some(ref db) = db {
|
||||
@@ -466,7 +550,7 @@ async fn handle_viewer_connection(
|
||||
db::events::EventTypes::VIEWER_JOINED,
|
||||
Some(&viewer_id),
|
||||
Some(&viewer_name),
|
||||
None, None,
|
||||
None, client_ip,
|
||||
).await;
|
||||
}
|
||||
|
||||
@@ -536,7 +620,7 @@ async fn handle_viewer_connection(
|
||||
db::events::EventTypes::VIEWER_LEFT,
|
||||
Some(&viewer_id_cleanup),
|
||||
Some(&viewer_name_cleanup),
|
||||
None, None,
|
||||
None, client_ip,
|
||||
).await;
|
||||
}
|
||||
|
||||
|
||||
22
server/src/utils/ip_extract.rs
Normal file
22
server/src/utils/ip_extract.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
//! IP address extraction from WebSocket connections
|
||||
|
||||
use axum::extract::ConnectInfo;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
|
||||
/// Extract IP address from Axum ConnectInfo
|
||||
///
|
||||
/// # Example
|
||||
/// ```rust
|
||||
/// pub async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) {
|
||||
/// let ip = extract_ip(&addr);
|
||||
/// // Use ip for logging
|
||||
/// }
|
||||
/// ```
|
||||
pub fn extract_ip(addr: &SocketAddr) -> IpAddr {
|
||||
addr.ip()
|
||||
}
|
||||
|
||||
/// Extract IP address as string
|
||||
pub fn extract_ip_string(addr: &SocketAddr) -> String {
|
||||
addr.ip().to_string()
|
||||
}
|
||||
4
server/src/utils/mod.rs
Normal file
4
server/src/utils/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! Utility functions
|
||||
|
||||
pub mod ip_extract;
|
||||
pub mod validation;
|
||||
58
server/src/utils/validation.rs
Normal file
58
server/src/utils/validation.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
//! Input validation and security checks
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
|
||||
/// Validate API key meets minimum security requirements
|
||||
///
|
||||
/// Requirements:
|
||||
/// - Minimum 32 characters
|
||||
/// - Not a common weak key
|
||||
/// - Sufficient character diversity
|
||||
pub fn validate_api_key_strength(api_key: &str) -> Result<()> {
|
||||
// Minimum length check
|
||||
if api_key.len() < 32 {
|
||||
return Err(anyhow!("API key must be at least 32 characters long for security"));
|
||||
}
|
||||
|
||||
// Check for common weak keys
|
||||
let weak_keys = [
|
||||
"password", "12345", "admin", "test", "api_key",
|
||||
"secret", "changeme", "default", "guruconnect"
|
||||
];
|
||||
let lowercase_key = api_key.to_lowercase();
|
||||
for weak in &weak_keys {
|
||||
if lowercase_key.contains(weak) {
|
||||
return Err(anyhow!("API key contains weak/common patterns and is not secure"));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for sufficient entropy (basic diversity check)
|
||||
let unique_chars: std::collections::HashSet<char> = api_key.chars().collect();
|
||||
if unique_chars.len() < 10 {
|
||||
return Err(anyhow!(
|
||||
"API key has insufficient character diversity (need at least 10 unique characters)"
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_validate_api_key_strength() {
|
||||
// Too short
|
||||
assert!(validate_api_key_strength("short").is_err());
|
||||
|
||||
// Weak pattern
|
||||
assert!(validate_api_key_strength("password_but_long_enough_now_123456789").is_err());
|
||||
|
||||
// Low entropy
|
||||
assert!(validate_api_key_strength("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa").is_err());
|
||||
|
||||
// Good key
|
||||
assert!(validate_api_key_strength("KfPrjjC3J6YMx9q1yjPxZAYkHLM2JdFy1XRxHJ9oPnw0NU3xH074ufHk7fj").is_ok());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user