//! In-memory rate limiting middleware (Task 4).
//!
//! The v1 implementation leaned on `tower_governor` and never compiled (the
//! `GovernorLayer` generic signature it tried to name does not exist in the
//! crate's public API). Rather than fight those generics, this is a small,
//! self-contained per-IP limiter behind a `Mutex<HashMap<IpAddr, _>>` — no new
//! dependency, easy to reason about, and trivially unit-testable.
//!
//! Two protections are provided, both keyed by client IP:
//!
//! 1. **Fixed-window rate limit** ([`RateLimiter`]): at most `max_requests`
//!    per `window`. Over the cap → `429 Too Many Requests` (standard error
//!    envelope). Wired onto `POST /api/auth/login`,
//!    `POST /api/auth/change-password`, and the support-code validate route.
//!
//! 2. **Consecutive-failure lockout** ([`FailureLockout`]): after
//!    `max_failures` consecutive *failed* attempts from one IP, that IP is
//!    locked out for `cooldown`. A single success resets the counter. This is
//!    the brute-force defense for the support-code space (the code-validate
//!    route reports per-attempt success/failure into it).
//!
//! Client IP is taken from axum's [`ConnectInfo<SocketAddr>`] (the same source
//! the relay uses for `client_ip`). `X-Forwarded-For` is intentionally NOT
//! trusted here: the server terminates behind a known reverse proxy (NPM), and
//! honoring a client-settable header would let an attacker trivially rotate the
//! limiter key. If/when per-proxy XFF handling is needed it must be gated on a
//! trusted-proxy allowlist — tracked as a follow-up, not done blindly here.
//!
//! NOTE(review): when requests really do arrive through the reverse proxy, the
//! TCP peer reported by `ConnectInfo` is presumably the proxy's address, not
//! the end client's (unless the proxy preserves the client address at the
//! transport level). In that deployment every client shares one limiter key,
//! so the caps become global and a single abusive client can trip the lockout
//! for everyone. Confirm against the actual deployment before relying on the
//! per-IP guarantees.
//!
//! Memory is bounded by capping the number of IPs each map tracks and pruning
//! expired entries opportunistically once a map reaches that cap, so that an
//! attacker rotating source IPs cannot grow the maps without bound.
use std::collections::HashMap; use std::net::IpAddr; use std::sync::Mutex; use std::time::{Duration, Instant}; use axum::{ extract::{ConnectInfo, State}, http::StatusCode, response::{IntoResponse, Response}, Json, }; use serde::Serialize; use std::net::SocketAddr; // ============================================================================ // Tunables (named constants — no magic numbers at the call sites) // ============================================================================ /// Login: window length for the fixed-window counter. pub const LOGIN_WINDOW: Duration = Duration::from_secs(60); /// Login: max requests per window per IP. Comfortable for a human retyping a /// password, hostile for a credential-stuffing loop. pub const LOGIN_MAX_PER_WINDOW: u32 = 8; /// Change-password: window length. pub const CHANGE_PASSWORD_WINDOW: Duration = Duration::from_secs(60); /// Change-password: max requests per window per IP. Tighter than login — a user /// changes their password rarely, and this endpoint already requires a valid /// session, so a low cap costs nothing legitimate. pub const CHANGE_PASSWORD_MAX_PER_WINDOW: u32 = 5; /// Support-code validate: window length. pub const CODE_VALIDATE_WINDOW: Duration = Duration::from_secs(60); /// Support-code validate: max requests per window per IP. Tight, because the /// code space is small relative to a password and this route is the brute-force /// surface for it. pub const CODE_VALIDATE_MAX_PER_WINDOW: u32 = 15; /// Support-code validate: consecutive failed validations from one IP that trip /// the lockout. pub const CODE_VALIDATE_MAX_FAILURES: u32 = 10; /// Support-code validate: how long an IP stays locked out once tripped. pub const CODE_VALIDATE_LOCKOUT: Duration = Duration::from_secs(15 * 60); /// Hard cap on the number of distinct IPs tracked by any single limiter map. /// Prevents an IP-rotating attacker from growing memory without bound. When the /// cap is hit, the oldest-windowed entries are pruned. 
Generous for a real MSP /// fleet; an attacker hitting it is already being throttled per-IP. const MAX_TRACKED_IPS: usize = 100_000; // ============================================================================ // Fixed-window rate limiter // ============================================================================ /// One IP's fixed-window counter. #[derive(Debug, Clone, Copy)] struct Window { /// When the current window started. started: Instant, /// Requests counted in the current window. count: u32, } /// Per-IP fixed-window rate limiter. Cheap, lock-guarded, self-pruning. /// /// Cloneable: the inner state is shared via `Arc`-less `&'static`/`State` /// ownership in this app, but we keep an explicit `Arc` so it can live in /// `AppState` and be cloned with it. #[derive(Clone)] pub struct RateLimiter { inner: std::sync::Arc>>, max_requests: u32, window: Duration, } impl RateLimiter { /// Create a limiter allowing `max_requests` per `window` per IP. pub fn new(max_requests: u32, window: Duration) -> Self { Self { inner: std::sync::Arc::new(Mutex::new(HashMap::new())), max_requests, window, } } /// Record a request from `ip` and report whether it is allowed. /// /// Returns `true` if the request is within the cap (and counts it), `false` /// if the IP is over the cap for the current window. Uses `now` as the clock /// so the window logic is unit-testable without sleeping. fn check_at(&self, ip: IpAddr, now: Instant) -> bool { let mut map = self.inner.lock().unwrap_or_else(|e| e.into_inner()); // Opportunistic prune: if the map has grown large, drop entries whose // window has fully elapsed (they would reset on next touch anyway). if map.len() >= MAX_TRACKED_IPS { let window = self.window; map.retain(|_, w| now.duration_since(w.started) < window); } let entry = map.entry(ip).or_insert(Window { started: now, count: 0, }); // Roll the window forward if it has elapsed. 
if now.duration_since(entry.started) >= self.window { entry.started = now; entry.count = 0; } if entry.count >= self.max_requests { false } else { entry.count += 1; true } } /// Record a request from `ip` (using the real clock) and report whether it /// is allowed. pub fn check(&self, ip: IpAddr) -> bool { self.check_at(ip, Instant::now()) } } // ============================================================================ // Consecutive-failure lockout // ============================================================================ /// One IP's failure-streak state. #[derive(Debug, Clone, Copy)] struct FailState { /// Consecutive failures since the last success / lockout. failures: u32, /// If `Some`, the IP is locked out until this instant. locked_until: Option, /// Last time this entry was touched (for pruning). last_seen: Instant, } /// Per-IP consecutive-failure lockout. After `max_failures` consecutive /// failures the IP is locked out for `cooldown`; a success resets the streak. #[derive(Clone)] pub struct FailureLockout { inner: std::sync::Arc>>, max_failures: u32, cooldown: Duration, } impl FailureLockout { /// Create a lockout that trips after `max_failures` consecutive failures and /// holds for `cooldown`. pub fn new(max_failures: u32, cooldown: Duration) -> Self { Self { inner: std::sync::Arc::new(Mutex::new(HashMap::new())), max_failures, cooldown, } } /// Is `ip` currently locked out? (clock injected for tests) fn is_locked_at(&self, ip: IpAddr, now: Instant) -> bool { let mut map = self.inner.lock().unwrap_or_else(|e| e.into_inner()); match map.get(&ip) { Some(state) => match state.locked_until { Some(until) if now < until => true, Some(_) => { // Lockout elapsed — clear it so the IP gets a fresh start. if let Some(s) = map.get_mut(&ip) { s.locked_until = None; s.failures = 0; } false } None => false, }, None => false, } } /// Is `ip` currently locked out? 
(real clock) pub fn is_locked(&self, ip: IpAddr) -> bool { self.is_locked_at(ip, Instant::now()) } /// Record a failed attempt from `ip`. Trips the lockout once the streak /// reaches `max_failures`. (clock injected for tests) fn record_failure_at(&self, ip: IpAddr, now: Instant) { let mut map = self.inner.lock().unwrap_or_else(|e| e.into_inner()); if map.len() >= MAX_TRACKED_IPS { let cooldown = self.cooldown; map.retain(|_, s| { s.locked_until.map(|u| now < u).unwrap_or(false) || now.duration_since(s.last_seen) < cooldown }); } let state = map.entry(ip).or_insert(FailState { failures: 0, locked_until: None, last_seen: now, }); state.last_seen = now; state.failures = state.failures.saturating_add(1); if state.failures >= self.max_failures { state.locked_until = Some(now + self.cooldown); } } /// Record a failed attempt from `ip` (real clock). pub fn record_failure(&self, ip: IpAddr) { self.record_failure_at(ip, Instant::now()); } /// Record a successful attempt from `ip`, resetting its failure streak. pub fn record_success(&self, ip: IpAddr) { let mut map = self.inner.lock().unwrap_or_else(|e| e.into_inner()); if let Some(state) = map.get_mut(&ip) { state.failures = 0; state.locked_until = None; state.last_seen = Instant::now(); } } } // ============================================================================ // Shared rate-limit state (lives in AppState) // ============================================================================ /// Bundle of limiters carried in `AppState` and consumed by the middleware. #[derive(Clone)] pub struct RateLimitState { /// `POST /api/auth/login` pub login: RateLimiter, /// `POST /api/auth/change-password` pub change_password: RateLimiter, /// `GET /api/codes/:code/validate` (request-rate cap) pub code_validate: RateLimiter, /// Per-IP lockout on repeated failed code validations (brute-force defense). 
pub code_validate_lockout: FailureLockout, } impl RateLimitState { pub fn new() -> Self { Self { login: RateLimiter::new(LOGIN_MAX_PER_WINDOW, LOGIN_WINDOW), change_password: RateLimiter::new( CHANGE_PASSWORD_MAX_PER_WINDOW, CHANGE_PASSWORD_WINDOW, ), code_validate: RateLimiter::new(CODE_VALIDATE_MAX_PER_WINDOW, CODE_VALIDATE_WINDOW), code_validate_lockout: FailureLockout::new( CODE_VALIDATE_MAX_FAILURES, CODE_VALIDATE_LOCKOUT, ), } } } impl Default for RateLimitState { fn default() -> Self { Self::new() } } // ============================================================================ // 429 response (standard error envelope) // ============================================================================ #[derive(Debug, Serialize)] struct RateLimitError { detail: String, error_code: String, status_code: u16, } /// Build a `429 Too Many Requests` response using the standard error envelope. fn too_many_requests(detail: &str, error_code: &str) -> Response { ( StatusCode::TOO_MANY_REQUESTS, Json(RateLimitError { detail: detail.to_string(), error_code: error_code.to_string(), status_code: StatusCode::TOO_MANY_REQUESTS.as_u16(), }), ) .into_response() } // ============================================================================ // Axum middleware functions (one per protected route) // ============================================================================ /// Selects which limiter from [`RateLimitState`] a middleware uses. /// /// Each protected route gets its own `from_fn_with_state` middleware pointing at /// the matching limiter; keeping them as distinct functions avoids threading an /// extra "which limiter" parameter through the layer and keeps the wiring in /// `main.rs` self-documenting. /// Rate-limit middleware for `POST /api/auth/login`. 
pub async fn login_rate_limit( State(state): State, ConnectInfo(addr): ConnectInfo, request: axum::extract::Request, next: axum::middleware::Next, ) -> Response { let ip = addr.ip(); if !state.rate_limits.login.check(ip) { tracing::warn!("Rate limit exceeded on /api/auth/login from {}", ip); return too_many_requests( "Too many login attempts. Please wait a minute and try again.", "RATE_LIMITED", ); } next.run(request).await } /// Rate-limit middleware for `POST /api/auth/change-password`. pub async fn change_password_rate_limit( State(state): State, ConnectInfo(addr): ConnectInfo, request: axum::extract::Request, next: axum::middleware::Next, ) -> Response { let ip = addr.ip(); if !state.rate_limits.change_password.check(ip) { tracing::warn!( "Rate limit exceeded on /api/auth/change-password from {}", ip ); return too_many_requests( "Too many password-change attempts. Please wait a minute and try again.", "RATE_LIMITED", ); } next.run(request).await } /// Rate-limit + brute-force-lockout middleware for the support-code validate /// route (`GET /api/codes/:code/validate`). /// /// Two gates run here: /// 1. If the IP is currently locked out (too many consecutive failed /// validations), reject immediately with 429 — before the handler runs, so /// the code is never even looked up. /// 2. Otherwise apply the per-window request cap. /// /// The success/failure that drives the lockout is reported by the handler /// itself (it knows whether the code was valid), via /// [`RateLimitState::code_validate_lockout`]. pub async fn code_validate_rate_limit( State(state): State, ConnectInfo(addr): ConnectInfo, request: axum::extract::Request, next: axum::middleware::Next, ) -> Response { let ip = addr.ip(); // 1. Brute-force lockout takes precedence. if state.rate_limits.code_validate_lockout.is_locked(ip) { tracing::warn!( "Code-validate request from locked-out IP {} (too many failed attempts)", ip ); return too_many_requests( "Too many invalid codes from this address. 
Try again later.", "RATE_LIMITED_LOCKOUT", ); } // 2. Per-window request cap. if !state.rate_limits.code_validate.check(ip) { tracing::warn!("Rate limit exceeded on code-validate from {}", ip); return too_many_requests( "Too many code validation attempts. Please wait a minute and try again.", "RATE_LIMITED", ); } next.run(request).await } // ============================================================================ // Tests // ============================================================================ #[cfg(test)] mod tests { use super::*; fn ip(n: u8) -> IpAddr { IpAddr::from([10, 0, 0, n]) } #[test] fn fixed_window_allows_up_to_cap_then_blocks() { let limiter = RateLimiter::new(3, Duration::from_secs(60)); let t0 = Instant::now(); let a = ip(1); assert!(limiter.check_at(a, t0)); // 1 assert!(limiter.check_at(a, t0)); // 2 assert!(limiter.check_at(a, t0)); // 3 assert!(!limiter.check_at(a, t0)); // 4 -> blocked assert!(!limiter.check_at(a, t0)); // still blocked } #[test] fn fixed_window_resets_after_window_elapses() { let limiter = RateLimiter::new(2, Duration::from_secs(60)); let t0 = Instant::now(); let a = ip(2); assert!(limiter.check_at(a, t0)); assert!(limiter.check_at(a, t0)); assert!(!limiter.check_at(a, t0)); // over cap // Advance past the window — counter resets. 
let t1 = t0 + Duration::from_secs(61); assert!(limiter.check_at(a, t1)); assert!(limiter.check_at(a, t1)); assert!(!limiter.check_at(a, t1)); } #[test] fn fixed_window_is_per_ip() { let limiter = RateLimiter::new(1, Duration::from_secs(60)); let t0 = Instant::now(); assert!(limiter.check_at(ip(3), t0)); assert!(!limiter.check_at(ip(3), t0)); // ip3 over cap assert!(limiter.check_at(ip(4), t0)); // ip4 independent } #[test] fn lockout_trips_after_consecutive_failures() { let lockout = FailureLockout::new(3, Duration::from_secs(600)); let t0 = Instant::now(); let a = ip(5); assert!(!lockout.is_locked_at(a, t0)); lockout.record_failure_at(a, t0); // 1 assert!(!lockout.is_locked_at(a, t0)); lockout.record_failure_at(a, t0); // 2 assert!(!lockout.is_locked_at(a, t0)); lockout.record_failure_at(a, t0); // 3 -> trips assert!(lockout.is_locked_at(a, t0)); } #[test] fn lockout_success_resets_streak() { let lockout = FailureLockout::new(3, Duration::from_secs(600)); let t0 = Instant::now(); let a = ip(6); lockout.record_failure_at(a, t0); lockout.record_failure_at(a, t0); lockout.record_success(a); // streak reset lockout.record_failure_at(a, t0); lockout.record_failure_at(a, t0); // Only two failures since the reset — not yet locked. assert!(!lockout.is_locked_at(a, t0)); } #[test] fn lockout_expires_after_cooldown() { let lockout = FailureLockout::new(2, Duration::from_secs(600)); let t0 = Instant::now(); let a = ip(7); lockout.record_failure_at(a, t0); lockout.record_failure_at(a, t0); // trips assert!(lockout.is_locked_at(a, t0)); // After the cooldown the lock clears. let t1 = t0 + Duration::from_secs(601); assert!(!lockout.is_locked_at(a, t1)); } #[test] fn lockout_is_per_ip() { let lockout = FailureLockout::new(1, Duration::from_secs(600)); let t0 = Instant::now(); lockout.record_failure_at(ip(8), t0); // trips ip8 assert!(lockout.is_locked_at(ip(8), t0)); assert!(!lockout.is_locked_at(ip(9), t0)); // ip9 unaffected } }