//! In-memory rate limiting middleware (Task 4). //! //! The v1 implementation leaned on `tower_governor` and never compiled (the //! `GovernorLayer` generic signature it tried to name does not exist in the //! crate's public API). Rather than fight those generics, this is a small, //! self-contained per-IP limiter behind a `Mutex>` — no new //! dependency, easy to reason about, and trivially unit-testable. //! //! Two protections are provided, both keyed by client IP: //! //! 1. **Fixed-window rate limit** ([`RateLimiter`]): at most `max_requests` //! per `window`. Over the cap → `429 Too Many Requests` (standard error //! envelope). Wired onto `POST /api/auth/login`, //! `POST /api/auth/change-password`, and the support-code validate route. //! //! 2. **Consecutive-failure lockout** ([`FailureLockout`]): after //! `max_failures` consecutive *failed* attempts from one IP, that IP is //! locked out for `cooldown`. A single success resets the counter. This is //! the brute-force defense for the support-code space (the code-validate //! route reports per-attempt success/failure into it). //! //! Client IP is the REAL client IP from the shared trusted-proxy-aware extractor //! ([`crate::utils::ip_extract::client_ip`]) — the same source the relay and the //! audit/event log use, so all three never drift. The extractor honors //! `X-Forwarded-For` / `X-Real-IP` ONLY when the TCP peer is a configured trusted //! proxy (default: loopback, since NPM runs on the same host); a header from an //! untrusted peer is attacker-spoofable and is ignored. Keying on the real client //! IP is what makes the per-IP limiter and the failure lockout per-actual-client //! rather than per-proxy — without it, every external client buckets under the //! proxy's loopback address and one abuser could lock out the whole fleet. //! //! Memory is bounded by pruning expired entries opportunistically on each call //! and capping the map size; an unbounded attacker rotating source IPs cannot //! 
grow the maps without bound. use std::collections::HashMap; use std::net::IpAddr; use std::sync::Mutex; use std::time::{Duration, Instant}; use axum::{ extract::{ConnectInfo, State}, http::StatusCode, response::{IntoResponse, Response}, Json, }; use serde::Serialize; use std::net::SocketAddr; // ============================================================================ // Tunables (named constants — no magic numbers at the call sites) // ============================================================================ /// Login: window length for the fixed-window counter. pub const LOGIN_WINDOW: Duration = Duration::from_secs(60); /// Login: max requests per window per IP. Comfortable for a human retyping a /// password, hostile for a credential-stuffing loop. pub const LOGIN_MAX_PER_WINDOW: u32 = 8; /// Change-password: window length. pub const CHANGE_PASSWORD_WINDOW: Duration = Duration::from_secs(60); /// Change-password: max requests per window per IP. Tighter than login — a user /// changes their password rarely, and this endpoint already requires a valid /// session, so a low cap costs nothing legitimate. pub const CHANGE_PASSWORD_MAX_PER_WINDOW: u32 = 5; /// Support-code validate: window length. pub const CODE_VALIDATE_WINDOW: Duration = Duration::from_secs(60); /// Support-code validate: max requests per window per IP. Tight, because the /// code space is small relative to a password and this route is the brute-force /// surface for it. pub const CODE_VALIDATE_MAX_PER_WINDOW: u32 = 15; /// Support-code validate: consecutive failed validations from one IP that trip /// the lockout. pub const CODE_VALIDATE_MAX_FAILURES: u32 = 10; /// Support-code validate: how long an IP stays locked out once tripped. pub const CODE_VALIDATE_LOCKOUT: Duration = Duration::from_secs(15 * 60); /// Enroll (`POST /api/enroll`, SPEC-016): window length. pub const ENROLL_WINDOW: Duration = Duration::from_secs(60); /// Enroll: max requests per window per `(site_code, IP)`. 
A zero-touch site push /// drives N machines through enroll near-simultaneously, so this is generous /// (mass-deploy friendly) while still capping a runaway loop. Defense-in-depth: the /// 256-bit enrollment key is the load-bearing gate, not this cap. pub const ENROLL_MAX_PER_WINDOW: u32 = 60; /// Enroll: consecutive FAILED enroll attempts (bad/inactive key, unknown site) from /// one `(site_code, IP)` that trip the lockout. pub const ENROLL_MAX_FAILURES: u32 = 20; /// Enroll: how long a `(site_code, IP)` stays locked out once tripped. pub const ENROLL_LOCKOUT: Duration = Duration::from_secs(15 * 60); /// Hard cap on the number of distinct IPs tracked by any single limiter map. /// Prevents an IP-rotating attacker from growing memory without bound. When the /// cap is hit, the oldest-windowed entries are pruned. Generous for a real MSP /// fleet; an attacker hitting it is already being throttled per-IP. const MAX_TRACKED_IPS: usize = 100_000; // ============================================================================ // Fixed-window rate limiter // ============================================================================ /// One IP's fixed-window counter. #[derive(Debug, Clone, Copy)] struct Window { /// When the current window started. started: Instant, /// Requests counted in the current window. count: u32, } /// Per-IP fixed-window rate limiter. Cheap, lock-guarded, self-pruning. /// /// Cloneable: the inner state is shared via `Arc`-less `&'static`/`State` /// ownership in this app, but we keep an explicit `Arc` so it can live in /// `AppState` and be cloned with it. #[derive(Clone)] pub struct RateLimiter { inner: std::sync::Arc>>, max_requests: u32, window: Duration, } impl RateLimiter { /// Create a limiter allowing `max_requests` per `window` per IP. 
pub fn new(max_requests: u32, window: Duration) -> Self { Self { inner: std::sync::Arc::new(Mutex::new(HashMap::new())), max_requests, window, } } /// Record a request from `ip` and report whether it is allowed. /// /// Returns `true` if the request is within the cap (and counts it), `false` /// if the IP is over the cap for the current window. Uses `now` as the clock /// so the window logic is unit-testable without sleeping. fn check_at(&self, ip: IpAddr, now: Instant) -> bool { let mut map = self.inner.lock().unwrap_or_else(|e| e.into_inner()); // Opportunistic prune: if the map has grown large, drop entries whose // window has fully elapsed (they would reset on next touch anyway). if map.len() >= MAX_TRACKED_IPS { let window = self.window; map.retain(|_, w| now.duration_since(w.started) < window); } let entry = map.entry(ip).or_insert(Window { started: now, count: 0, }); // Roll the window forward if it has elapsed. if now.duration_since(entry.started) >= self.window { entry.started = now; entry.count = 0; } if entry.count >= self.max_requests { false } else { entry.count += 1; true } } /// Record a request from `ip` (using the real clock) and report whether it /// is allowed. pub fn check(&self, ip: IpAddr) -> bool { self.check_at(ip, Instant::now()) } } // ============================================================================ // Consecutive-failure lockout // ============================================================================ /// One IP's failure-streak state. #[derive(Debug, Clone, Copy)] struct FailState { /// Consecutive failures since the last success / lockout. failures: u32, /// If `Some`, the IP is locked out until this instant. locked_until: Option, /// Last time this entry was touched (for pruning). last_seen: Instant, } /// Per-IP consecutive-failure lockout. After `max_failures` consecutive /// failures the IP is locked out for `cooldown`; a success resets the streak. 
#[derive(Clone)] pub struct FailureLockout { inner: std::sync::Arc>>, max_failures: u32, cooldown: Duration, } impl FailureLockout { /// Create a lockout that trips after `max_failures` consecutive failures and /// holds for `cooldown`. pub fn new(max_failures: u32, cooldown: Duration) -> Self { Self { inner: std::sync::Arc::new(Mutex::new(HashMap::new())), max_failures, cooldown, } } /// Is `ip` currently locked out? (clock injected for tests) fn is_locked_at(&self, ip: IpAddr, now: Instant) -> bool { let mut map = self.inner.lock().unwrap_or_else(|e| e.into_inner()); match map.get(&ip) { Some(state) => match state.locked_until { Some(until) if now < until => true, Some(_) => { // Lockout elapsed — clear it so the IP gets a fresh start. if let Some(s) = map.get_mut(&ip) { s.locked_until = None; s.failures = 0; } false } None => false, }, None => false, } } /// Is `ip` currently locked out? (real clock) pub fn is_locked(&self, ip: IpAddr) -> bool { self.is_locked_at(ip, Instant::now()) } /// Record a failed attempt from `ip`. Trips the lockout once the streak /// reaches `max_failures`. (clock injected for tests) fn record_failure_at(&self, ip: IpAddr, now: Instant) { let mut map = self.inner.lock().unwrap_or_else(|e| e.into_inner()); if map.len() >= MAX_TRACKED_IPS { let cooldown = self.cooldown; map.retain(|_, s| { s.locked_until.map(|u| now < u).unwrap_or(false) || now.duration_since(s.last_seen) < cooldown }); } let state = map.entry(ip).or_insert(FailState { failures: 0, locked_until: None, last_seen: now, }); state.last_seen = now; state.failures = state.failures.saturating_add(1); if state.failures >= self.max_failures { state.locked_until = Some(now + self.cooldown); } } /// Record a failed attempt from `ip` (real clock). pub fn record_failure(&self, ip: IpAddr) { self.record_failure_at(ip, Instant::now()); } /// Record a successful attempt from `ip`, resetting its failure streak. 
pub fn record_success(&self, ip: IpAddr) { let mut map = self.inner.lock().unwrap_or_else(|e| e.into_inner()); if let Some(state) = map.get_mut(&ip) { state.failures = 0; state.locked_until = None; state.last_seen = Instant::now(); } } } // ============================================================================ // Composite-key limiter for enrollment (keyed by (site_code, IP)) — SPEC-016 // ============================================================================ // // The login / change-password / code-validate limiters above key purely on IP. // SPEC-016 §3 wants the enroll defense keyed on `(site_code, source-IP)` so a noisy // site push from one office IP cannot lock out a different site enrolling from the // same egress IP. Rather than overload the IP-only maps, this is a small dedicated // composite-key limiter + lockout. It is invoked from the enroll HANDLER (not a // `from_fn` layer) because the `site_code` lives in the JSON body, which a // pre-handler middleware cannot read without consuming it. Documented as // defense-in-depth: the 256-bit enrollment key is the real gate. /// Composite limiter key: the site_code and the real client IP. type EnrollKey = (String, IpAddr); /// Per-`(site_code, IP)` fixed-window limiter + consecutive-failure lockout. /// /// Combines both protections behind one lock-guarded map so the enroll handler /// makes a single allow/deny decision and reports success/failure into the same /// structure. Self-pruning and size-capped, like the IP-only limiters. 
#[derive(Clone)] pub struct EnrollLimiter { inner: std::sync::Arc>>, max_per_window: u32, window: Duration, max_failures: u32, cooldown: Duration, } #[derive(Debug, Clone, Copy)] struct EnrollEntry { window_started: Instant, count: u32, failures: u32, locked_until: Option, last_seen: Instant, } impl EnrollLimiter { pub fn new( max_per_window: u32, window: Duration, max_failures: u32, cooldown: Duration, ) -> Self { Self { inner: std::sync::Arc::new(Mutex::new(HashMap::new())), max_per_window, window, max_failures, cooldown, } } fn entry_now() -> EnrollEntry { let now = Instant::now(); EnrollEntry { window_started: now, count: 0, failures: 0, locked_until: None, last_seen: now, } } /// Admit one enroll attempt for `(site_code, ip)`. Returns `true` if allowed /// (and counts it). Returns `false` if the key is currently locked out OR over /// the per-window request cap. Clock injected for tests. fn check_at(&self, site_code: &str, ip: IpAddr, now: Instant) -> bool { let mut map = self.inner.lock().unwrap_or_else(|e| e.into_inner()); if map.len() >= MAX_TRACKED_IPS { let window = self.window; let cooldown = self.cooldown; map.retain(|_, e| { e.locked_until.map(|u| now < u).unwrap_or(false) || now.duration_since(e.window_started) < window || now.duration_since(e.last_seen) < cooldown }); } let key = (site_code.to_string(), ip); let e = map.entry(key).or_insert_with(Self::entry_now); e.last_seen = now; // Lockout takes precedence. if let Some(until) = e.locked_until { if now < until { return false; } // Cooldown elapsed — clear it for a fresh start. e.locked_until = None; e.failures = 0; } // Roll the fixed window forward if elapsed. if now.duration_since(e.window_started) >= self.window { e.window_started = now; e.count = 0; } if e.count >= self.max_per_window { false } else { e.count += 1; true } } /// Admit one enroll attempt (real clock). 
pub fn check(&self, site_code: &str, ip: IpAddr) -> bool { self.check_at(site_code, ip, Instant::now()) } fn record_failure_at(&self, site_code: &str, ip: IpAddr, now: Instant) { let mut map = self.inner.lock().unwrap_or_else(|e| e.into_inner()); let key = (site_code.to_string(), ip); let e = map.entry(key).or_insert_with(Self::entry_now); e.last_seen = now; e.failures = e.failures.saturating_add(1); if e.failures >= self.max_failures { e.locked_until = Some(now + self.cooldown); } } /// Record a FAILED enroll attempt (bad key / unknown site) for the key, /// tripping the lockout once the streak reaches `max_failures`. pub fn record_failure(&self, site_code: &str, ip: IpAddr) { self.record_failure_at(site_code, ip, Instant::now()); } /// Record a SUCCESSFUL enroll for the key, resetting its failure streak. pub fn record_success(&self, site_code: &str, ip: IpAddr) { let mut map = self.inner.lock().unwrap_or_else(|e| e.into_inner()); let key = (site_code.to_string(), ip); if let Some(e) = map.get_mut(&key) { e.failures = 0; e.locked_until = None; e.last_seen = Instant::now(); } } } // ============================================================================ // Shared rate-limit state (lives in AppState) // ============================================================================ /// Bundle of limiters carried in `AppState` and consumed by the middleware. #[derive(Clone)] pub struct RateLimitState { /// `POST /api/auth/login` pub login: RateLimiter, /// `POST /api/auth/change-password` pub change_password: RateLimiter, /// `GET /api/codes/:code/validate` (request-rate cap) pub code_validate: RateLimiter, /// Per-IP lockout on repeated failed code validations (brute-force defense). pub code_validate_lockout: FailureLockout, /// `POST /api/enroll` (SPEC-016): per-`(site_code, IP)` request cap + /// consecutive-failure lockout. Invoked from the enroll handler. 
pub enroll: EnrollLimiter, } impl RateLimitState { pub fn new() -> Self { Self { login: RateLimiter::new(LOGIN_MAX_PER_WINDOW, LOGIN_WINDOW), change_password: RateLimiter::new( CHANGE_PASSWORD_MAX_PER_WINDOW, CHANGE_PASSWORD_WINDOW, ), code_validate: RateLimiter::new(CODE_VALIDATE_MAX_PER_WINDOW, CODE_VALIDATE_WINDOW), code_validate_lockout: FailureLockout::new( CODE_VALIDATE_MAX_FAILURES, CODE_VALIDATE_LOCKOUT, ), enroll: EnrollLimiter::new( ENROLL_MAX_PER_WINDOW, ENROLL_WINDOW, ENROLL_MAX_FAILURES, ENROLL_LOCKOUT, ), } } } impl Default for RateLimitState { fn default() -> Self { Self::new() } } // ============================================================================ // 429 response (standard error envelope) // ============================================================================ #[derive(Debug, Serialize)] struct RateLimitError { detail: String, error_code: String, status_code: u16, } /// Build a `429 Too Many Requests` response using the standard error envelope. fn too_many_requests(detail: &str, error_code: &str) -> Response { ( StatusCode::TOO_MANY_REQUESTS, Json(RateLimitError { detail: detail.to_string(), error_code: error_code.to_string(), status_code: StatusCode::TOO_MANY_REQUESTS.as_u16(), }), ) .into_response() } // ============================================================================ // Axum middleware functions (one per protected route) // ============================================================================ // Selects which limiter from `RateLimitState` a middleware uses. // // Each protected route gets its own `from_fn_with_state` middleware pointing at // the matching limiter; keeping them as distinct functions avoids threading an // extra "which limiter" parameter through the layer and keeps the wiring in // `main.rs` self-documenting. /// Rate-limit middleware for `POST /api/auth/login`. 
///
/// Resolves the client IP with the shared trusted-proxy-aware extractor and
/// consults `state.rate_limits.login`. When that limiter refuses the request,
/// a warning is logged and a `429` with error code `RATE_LIMITED` is returned
/// without running the handler; otherwise the request continues unchanged.
pub async fn login_rate_limit(
    State(state): State,
    ConnectInfo(addr): ConnectInfo,
    request: axum::extract::Request,
    next: axum::middleware::Next,
) -> Response {
    let ip = crate::utils::ip_extract::client_ip(&addr, request.headers(), &state.trusted_proxies);
    if !state.rate_limits.login.check(ip) {
        tracing::warn!("Rate limit exceeded on /api/auth/login from {}", ip);
        return too_many_requests(
            "Too many login attempts. Please wait a minute and try again.",
            "RATE_LIMITED",
        );
    }
    next.run(request).await
}

/// Rate-limit middleware for `POST /api/auth/change-password`.
///
/// Same shape as [`login_rate_limit`], but backed by
/// `state.rate_limits.change_password`. Over the cap it logs a warning and
/// returns `429` / `RATE_LIMITED` without running the handler.
pub async fn change_password_rate_limit(
    State(state): State,
    ConnectInfo(addr): ConnectInfo,
    request: axum::extract::Request,
    next: axum::middleware::Next,
) -> Response {
    let ip = crate::utils::ip_extract::client_ip(&addr, request.headers(), &state.trusted_proxies);
    if !state.rate_limits.change_password.check(ip) {
        tracing::warn!(
            "Rate limit exceeded on /api/auth/change-password from {}",
            ip
        );
        return too_many_requests(
            "Too many password-change attempts. Please wait a minute and try again.",
            "RATE_LIMITED",
        );
    }
    next.run(request).await
}

/// Rate-limit + brute-force-lockout middleware for the support-code validate
/// route (`GET /api/codes/:code/validate`).
///
/// Two gates run here:
/// 1. If the IP is currently locked out (too many consecutive failed
///    validations), reject immediately with 429 (`RATE_LIMITED_LOCKOUT`) —
///    before the handler runs, so the code is never even looked up.
/// 2. Otherwise apply the per-window request cap (429 `RATE_LIMITED`).
///
/// The success/failure that drives the lockout is reported by the handler
/// itself (it knows whether the code was valid), via
/// [`RateLimitState::code_validate_lockout`].
pub async fn code_validate_rate_limit(
    State(state): State,
    ConnectInfo(addr): ConnectInfo,
    request: axum::extract::Request,
    next: axum::middleware::Next,
) -> Response {
    let ip = crate::utils::ip_extract::client_ip(&addr, request.headers(), &state.trusted_proxies);
    // 1. Brute-force lockout takes precedence. A locked-out request is
    //    rejected before the window check, so it does not consume window budget.
    if state.rate_limits.code_validate_lockout.is_locked(ip) {
        tracing::warn!(
            "Code-validate request from locked-out IP {} (too many failed attempts)",
            ip
        );
        return too_many_requests(
            "Too many invalid codes from this address. Try again later.",
            "RATE_LIMITED_LOCKOUT",
        );
    }
    // 2. Per-window request cap.
    if !state.rate_limits.code_validate.check(ip) {
        tracing::warn!("Rate limit exceeded on code-validate from {}", ip);
        return too_many_requests(
            "Too many code validation attempts. Please wait a minute and try again.",
            "RATE_LIMITED",
        );
    }
    next.run(request).await
}

// ============================================================================
// Tests
// ============================================================================
//
// The tests drive the private `*_at` methods with an injected `Instant`
// (`t0`, `t0 + Duration`), so window and cooldown expiry are exercised
// without sleeping.

#[cfg(test)]
mod tests {
    use super::*;

    /// Test helper: the address `10.0.0.n`.
    fn ip(n: u8) -> IpAddr {
        IpAddr::from([10, 0, 0, n])
    }

    // -- RateLimiter (per-IP fixed window) -----------------------------------

    #[test]
    fn fixed_window_allows_up_to_cap_then_blocks() {
        let limiter = RateLimiter::new(3, Duration::from_secs(60));
        let t0 = Instant::now();
        let a = ip(1);
        assert!(limiter.check_at(a, t0)); // 1
        assert!(limiter.check_at(a, t0)); // 2
        assert!(limiter.check_at(a, t0)); // 3
        assert!(!limiter.check_at(a, t0)); // 4 -> blocked
        assert!(!limiter.check_at(a, t0)); // still blocked
    }

    #[test]
    fn fixed_window_resets_after_window_elapses() {
        let limiter = RateLimiter::new(2, Duration::from_secs(60));
        let t0 = Instant::now();
        let a = ip(2);
        assert!(limiter.check_at(a, t0));
        assert!(limiter.check_at(a, t0));
        assert!(!limiter.check_at(a, t0)); // over cap
        // Advance past the window — counter resets.
        let t1 = t0 + Duration::from_secs(61);
        assert!(limiter.check_at(a, t1));
        assert!(limiter.check_at(a, t1));
        assert!(!limiter.check_at(a, t1));
    }

    #[test]
    fn fixed_window_is_per_ip() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        let t0 = Instant::now();
        assert!(limiter.check_at(ip(3), t0));
        assert!(!limiter.check_at(ip(3), t0)); // ip3 over cap
        assert!(limiter.check_at(ip(4), t0)); // ip4 independent
    }

    // -- FailureLockout (per-IP consecutive failures) ------------------------

    #[test]
    fn lockout_trips_after_consecutive_failures() {
        let lockout = FailureLockout::new(3, Duration::from_secs(600));
        let t0 = Instant::now();
        let a = ip(5);
        assert!(!lockout.is_locked_at(a, t0));
        lockout.record_failure_at(a, t0); // 1
        assert!(!lockout.is_locked_at(a, t0));
        lockout.record_failure_at(a, t0); // 2
        assert!(!lockout.is_locked_at(a, t0));
        lockout.record_failure_at(a, t0); // 3 -> trips
        assert!(lockout.is_locked_at(a, t0));
    }

    #[test]
    fn lockout_success_resets_streak() {
        let lockout = FailureLockout::new(3, Duration::from_secs(600));
        let t0 = Instant::now();
        let a = ip(6);
        lockout.record_failure_at(a, t0);
        lockout.record_failure_at(a, t0);
        lockout.record_success(a); // streak reset
        lockout.record_failure_at(a, t0);
        lockout.record_failure_at(a, t0);
        // Only two failures since the reset — not yet locked.
        assert!(!lockout.is_locked_at(a, t0));
    }

    #[test]
    fn lockout_expires_after_cooldown() {
        let lockout = FailureLockout::new(2, Duration::from_secs(600));
        let t0 = Instant::now();
        let a = ip(7);
        lockout.record_failure_at(a, t0);
        lockout.record_failure_at(a, t0); // trips
        assert!(lockout.is_locked_at(a, t0));
        // After the cooldown the lock clears.
        let t1 = t0 + Duration::from_secs(601);
        assert!(!lockout.is_locked_at(a, t1));
    }

    #[test]
    fn lockout_is_per_ip() {
        let lockout = FailureLockout::new(1, Duration::from_secs(600));
        let t0 = Instant::now();
        lockout.record_failure_at(ip(8), t0); // trips ip8
        assert!(lockout.is_locked_at(ip(8), t0));
        assert!(!lockout.is_locked_at(ip(9), t0)); // ip9 unaffected
    }

    // -- EnrollLimiter (composite (site_code, IP) key) --------------------------

    #[test]
    fn enroll_window_allows_up_to_cap_then_blocks() {
        let lim = EnrollLimiter::new(2, Duration::from_secs(60), 100, Duration::from_secs(600));
        let t0 = Instant::now();
        assert!(lim.check_at("SITE-A", ip(1), t0)); // 1
        assert!(lim.check_at("SITE-A", ip(1), t0)); // 2
        assert!(!lim.check_at("SITE-A", ip(1), t0)); // over cap
    }

    #[test]
    fn enroll_is_keyed_by_site_and_ip() {
        let lim = EnrollLimiter::new(1, Duration::from_secs(60), 100, Duration::from_secs(600));
        let t0 = Instant::now();
        assert!(lim.check_at("SITE-A", ip(1), t0));
        assert!(!lim.check_at("SITE-A", ip(1), t0)); // same key over cap
        // Different site, same IP -> independent bucket.
        assert!(lim.check_at("SITE-B", ip(1), t0));
        // Same site, different IP -> independent bucket.
        assert!(lim.check_at("SITE-A", ip(2), t0));
    }

    #[test]
    fn enroll_lockout_trips_after_failures_and_blocks_check() {
        let lim = EnrollLimiter::new(100, Duration::from_secs(60), 3, Duration::from_secs(600));
        let t0 = Instant::now();
        lim.record_failure_at("SITE-A", ip(1), t0);
        lim.record_failure_at("SITE-A", ip(1), t0);
        // Not yet tripped: a check still admits.
        assert!(lim.check_at("SITE-A", ip(1), t0));
        lim.record_failure_at("SITE-A", ip(1), t0); // 3rd -> trips
        // Now locked out: check denies even though under the request cap.
        assert!(!lim.check_at("SITE-A", ip(1), t0));
    }

    #[test]
    fn enroll_success_resets_failure_streak() {
        let lim = EnrollLimiter::new(100, Duration::from_secs(60), 2, Duration::from_secs(600));
        let t0 = Instant::now();
        lim.record_failure_at("SITE-A", ip(1), t0);
        lim.record_success("SITE-A", ip(1)); // reset
        lim.record_failure_at("SITE-A", ip(1), t0);
        // Only one failure since reset -> not locked.
        assert!(lim.check_at("SITE-A", ip(1), t0));
    }
}