guru-connect/server/src/middleware/rate_limit.rs
Mike Swanson bfcdbb5379
Some checks failed
Build and Test / Build Server (Linux) (push) Failing after 6m12s
Build and Test / Build Agent (Windows) (push) Successful in 6m43s
Build and Test / Security Audit (push) Successful in 4m23s
Build and Test / Build Summary (push) Has been skipped
feat(server): v2 secure-session-core Task 4 - rate limit + single-use codes
SPEC-002 Phase 1 Task 4 (the final keystone task), code-reviewed APPROVED.
Closes the audit's reusable-code HIGH and rate-limiting-disabled HIGH.

- Rebuilt rate limiting as a self-contained in-memory per-IP limiter (replaces
  the non-compiling tower_governor; removed that dep). Fixed-window caps wired
  to login (8/min), change-password (5/min), code-validate (15/min) -> 429;
  per-IP lockout after 10 consecutive failed code validations (15-min cooldown).
- Single-use support codes: atomic consume on first agent bind (in-memory
  Pending->Connected under write lock + DB conditional UPDATE), rejecting a
  second presenter; validate/preview does not consume.
- Widened code format: XXX-XXX-XXX, 31-char unambiguous alphabet (no 0/O/1/I/L),
  CSPRNG + rejection sampling, ~44.6 bits (replaces 6-digit numeric); migration
  006 widens the code columns to TEXT.
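The widened code format (third bullet) can be sketched as follows. This is a minimal illustration under stated assumptions, not the server's generator:

- The 31-symbol alphabet is assumed to be digits 2-9 plus A-Z without O, I and L. This is inferred from "no 0/O/1/I/L"; the real constant is not shown here.
- `generate_code`, `entropy_bits`, `ALPHABET` and `ACCEPT_BELOW` are hypothetical names.
- The random byte source is passed in as a closure so the rejection-sampling logic can be exercised deterministically. In production it would be backed by a CSPRNG.

```rust
/// Assumed 31-symbol alphabet: digits 2-9 plus A-Z minus O, I and L.
/// Inferred from "no 0/O/1/I/L"; the server's real constant may differ.
const ALPHABET: &[u8; 31] = b"23456789ABCDEFGHJKMNPQRSTUVWXYZ";

/// 8 * 31 = 248 is the largest multiple of 31 within a byte's 256 values.
/// Bytes at or above it are redrawn so every symbol is equally likely.
const ACCEPT_BELOW: u8 = 248;

/// Hypothetical generator for an `XXX-XXX-XXX` support code.
///
/// `next_byte` must yield uniformly random bytes; in production it should be
/// backed by a CSPRNG (e.g. the OS RNG). It is a closure here only so the
/// rejection-sampling logic can be tested with a fixed byte sequence.
fn generate_code(mut next_byte: impl FnMut() -> u8) -> String {
    let mut code = String::with_capacity(11);
    for i in 0..9 {
        // Group separator before the 4th and 7th symbols.
        if i > 0 && i % 3 == 0 {
            code.push('-');
        }
        // Rejection sampling: discard biased bytes instead of folding them in.
        let b = loop {
            let b = next_byte();
            if b < ACCEPT_BELOW {
                break b;
            }
        };
        code.push(ALPHABET[usize::from(b % 31)] as char);
    }
    code
}

/// Entropy in bits of `len` symbols drawn uniformly from `alphabet_len`.
fn entropy_bits(alphabet_len: u32, len: u32) -> f64 {
    f64::from(len) * f64::from(alphabet_len).log2()
}
```

Bytes 248-255 are redrawn because 256 is not a multiple of 31. Over the accepted range 0-247, `b % 31` hits each symbol exactly 8 times. Nine symbols give 9 × log2(31) ≈ 44.59 bits, a space of 31^9 = 26,439,622,160,671 codes, which matches the "~44.6 bits" above.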

Completes the keystone (Tasks 1-4): every audit CRITICAL + HIGH in the secure
auth/session core is now addressed. Known follow-up todos (not blocking): (1)
trusted-proxy client-IP extraction (NPM-on-loopback collapses clients to
127.0.0.1); (2) multi-instance fail-closed DB single-use gate. Not
cargo-check-verified locally - build-host/CI verification follows this commit.
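Follow-up (1) could take roughly the following shape. This is a sketch under assumptions, not part of this commit:

- `client_ip` and the `trusted` allowlist are hypothetical.
- Only a single `X-Forwarded-For` value is handled.
- The proxy is assumed to append the address it saw to the right of the header, as nginx's `$proxy_add_x_forwarded_for` does.

```rust
use std::net::{IpAddr, SocketAddr};

/// Hypothetical client-IP resolver for follow-up (1).
///
/// `X-Forwarded-For` is honored only when the direct peer is in `trusted`.
/// The header is walked right to left: trusted proxies append the address they
/// saw, so the first entry that is not itself a trusted proxy is the client.
/// Anything further left was written by the client and is ignored.
fn client_ip(peer: SocketAddr, xff: Option<&str>, trusted: &[IpAddr]) -> IpAddr {
    let peer_ip = peer.ip();
    if !trusted.contains(&peer_ip) {
        // Not via a trusted proxy: the header is attacker-controlled.
        return peer_ip;
    }
    let Some(header) = xff else {
        return peer_ip;
    };
    for entry in header.rsplit(',') {
        match entry.trim().parse::<IpAddr>() {
            // Another of our own proxies: keep walking left.
            Ok(ip) if trusted.contains(&ip) => continue,
            // First untrusted hop: treat it as the client.
            Ok(ip) => return ip,
            // Unparseable entry: stop and fall back to the peer address.
            Err(_) => return peer_ip,
        }
    }
    // Every listed hop was a trusted proxy; fall back to the peer address.
    peer_ip
}
```

Walking from the right is the important design choice. Entries left of the first untrusted hop are client-supplied, so taking the leftmost entry would reopen the key-rotation hole that the rate limiter deliberately avoids.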

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-05-29 21:04:54 -07:00

//! In-memory rate limiting middleware (Task 4).
//!
//! The v1 implementation leaned on `tower_governor` and never compiled (the
//! `GovernorLayer` generic signature it tried to name does not exist in the
//! crate's public API). Rather than fight those generics, this is a small,
//! self-contained per-IP limiter behind a `Mutex<HashMap<…>>`: no new
//! dependency, easy to reason about, and trivially unit-testable.
//!
//! Two protections are provided, both keyed by client IP:
//!
//! 1. **Fixed-window rate limit** ([`RateLimiter`]): at most `max_requests`
//!    per `window`. Over the cap → `429 Too Many Requests` (standard error
//!    envelope). Wired onto `POST /api/auth/login`,
//!    `POST /api/auth/change-password`, and the support-code validate route.
//!
//! 2. **Consecutive-failure lockout** ([`FailureLockout`]): after
//!    `max_failures` consecutive *failed* attempts from one IP, that IP is
//!    locked out for `cooldown`. A single success resets the counter. This is
//!    the brute-force defense for the support-code space (the code-validate
//!    handler reports per-attempt success/failure into it).
//!
//! Client IP is taken from axum's [`ConnectInfo<SocketAddr>`] (the same source
//! the relay uses for `client_ip`). `X-Forwarded-For` is intentionally NOT
//! trusted here: the server terminates behind a known reverse proxy (NPM), and
//! honoring a client-settable header would let an attacker trivially rotate the
//! limiter key. If/when per-proxy XFF handling is needed it must be gated on a
//! trusted-proxy allowlist; that is tracked as a follow-up, not done blindly
//! here. Until then, when the proxy connects over loopback every client is
//! seen as `127.0.0.1`, so all clients share one bucket and one lockout.
//!
//! Memory is kept in check by pruning no-longer-live entries whenever a map
//! reaches `MAX_TRACKED_IPS`. That cap is a prune trigger, not a hard ceiling:
//! live entries are never evicted. The maps are therefore bounded by how many
//! distinct source IPs can be active within one window (or one lockout
//! cooldown). While a map sits at or above the cap, every `check` /
//! `record_failure` call pays a full `retain` scan under the lock.

use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::Mutex;
use std::time::{Duration, Instant};

use axum::{
    extract::{ConnectInfo, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use serde::Serialize;

// ============================================================================
// Tunables (named constants: no magic numbers at the call sites)
// ============================================================================

/// Login: window length for the fixed-window counter.
pub const LOGIN_WINDOW: Duration = Duration::from_secs(60);
/// Login: max requests per window per IP. Comfortable for a human retyping a
/// password, hostile for a credential-stuffing loop.
pub const LOGIN_MAX_PER_WINDOW: u32 = 8;

/// Change-password: window length.
pub const CHANGE_PASSWORD_WINDOW: Duration = Duration::from_secs(60);
/// Change-password: max requests per window per IP. Tighter than login: a user
/// changes their password rarely, and this endpoint already requires a valid
/// session, so a low cap costs nothing legitimate.
pub const CHANGE_PASSWORD_MAX_PER_WINDOW: u32 = 5;

/// Support-code validate: window length.
pub const CODE_VALIDATE_WINDOW: Duration = Duration::from_secs(60);
/// Support-code validate: max requests per window per IP. Tight, because the
/// code space is small relative to a password and this route is the brute-force
/// surface for it.
pub const CODE_VALIDATE_MAX_PER_WINDOW: u32 = 15;
/// Support-code validate: consecutive failed validations from one IP that trip
/// the lockout.
pub const CODE_VALIDATE_MAX_FAILURES: u32 = 10;
/// Support-code validate: how long an IP stays locked out once tripped.
pub const CODE_VALIDATE_LOCKOUT: Duration = Duration::from_secs(15 * 60);

/// Size at which a limiter map prunes itself. Once a map holds this many
/// distinct IPs, each `check` / `record_failure` call first drops entries that
/// are no longer live (window elapsed; or lockout over and not seen for a full
/// cooldown). Live entries are never evicted, so this is a prune trigger rather
/// than a hard ceiling. Generous for a real MSP fleet; an attacker reaching it
/// is already being throttled per IP.
const MAX_TRACKED_IPS: usize = 100_000;

// ============================================================================
// Fixed-window rate limiter
// ============================================================================

/// One IP's fixed-window counter.
#[derive(Debug, Clone, Copy)]
struct Window {
    /// When the current window started.
    started: Instant,
    /// Requests counted in the current window.
    count: u32,
}

/// Per-IP fixed-window rate limiter. Cheap, lock-guarded, self-pruning.
///
/// Like any fixed-window counter, it admits a burst of up to
/// `2 * max_requests` straddling a window boundary.
///
/// Cloning is cheap and shares state: the map sits behind an `Arc`, so every
/// clone of `AppState` (axum clones it for each request) sees the same
/// counters.
#[derive(Clone)]
pub struct RateLimiter {
    inner: std::sync::Arc<Mutex<HashMap<IpAddr, Window>>>,
    max_requests: u32,
    window: Duration,
}

impl RateLimiter {
    /// Create a limiter allowing `max_requests` per `window` per IP.
    pub fn new(max_requests: u32, window: Duration) -> Self {
        Self {
            inner: std::sync::Arc::new(Mutex::new(HashMap::new())),
            max_requests,
            window,
        }
    }

    /// Record a request from `ip` and report whether it is allowed.
    ///
    /// Returns `true` if the request is within the cap (and counts it), `false`
    /// if the IP is over the cap for the current window. Uses `now` as the clock
    /// so the window logic is unit-testable without sleeping.
    fn check_at(&self, ip: IpAddr, now: Instant) -> bool {
        let mut map = self.inner.lock().unwrap_or_else(|e| e.into_inner());

        // Opportunistic prune: if the map has grown large, drop entries whose
        // window has fully elapsed (they would reset on next touch anyway).
        if map.len() >= MAX_TRACKED_IPS {
            let window = self.window;
            map.retain(|_, w| now.duration_since(w.started) < window);
        }

        let entry = map.entry(ip).or_insert(Window {
            started: now,
            count: 0,
        });

        // Roll the window forward if it has elapsed.
        if now.duration_since(entry.started) >= self.window {
            entry.started = now;
            entry.count = 0;
        }

        if entry.count >= self.max_requests {
            false
        } else {
            entry.count += 1;
            true
        }
    }

    /// Record a request from `ip` (using the real clock) and report whether it
    /// is allowed.
    pub fn check(&self, ip: IpAddr) -> bool {
        self.check_at(ip, Instant::now())
    }
}

// ============================================================================
// Consecutive-failure lockout
// ============================================================================

/// One IP's failure-streak state.
#[derive(Debug, Clone, Copy)]
struct FailState {
    /// Consecutive failures since the last success or expired lockout.
    failures: u32,
    /// If `Some`, the IP is locked out until this instant.
    locked_until: Option<Instant>,
    /// Last time this entry was touched (for pruning).
    last_seen: Instant,
}

/// Per-IP consecutive-failure lockout. After `max_failures` consecutive
/// failures the IP is locked out for `cooldown`; a success resets the streak.
///
/// A partial streak does not decay with time: it persists until a success,
/// an observed lockout expiry, or the entry being pruned.
#[derive(Clone)]
pub struct FailureLockout {
    inner: std::sync::Arc<Mutex<HashMap<IpAddr, FailState>>>,
    max_failures: u32,
    cooldown: Duration,
}

impl FailureLockout {
    /// Create a lockout that trips after `max_failures` consecutive failures and
    /// holds for `cooldown`.
    pub fn new(max_failures: u32, cooldown: Duration) -> Self {
        Self {
            inner: std::sync::Arc::new(Mutex::new(HashMap::new())),
            max_failures,
            cooldown,
        }
    }

    /// Is `ip` currently locked out? (clock injected for tests)
    ///
    /// An elapsed lockout is cleared here, giving the IP a fresh streak.
    fn is_locked_at(&self, ip: IpAddr, now: Instant) -> bool {
        let mut map = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        let Some(state) = map.get_mut(&ip) else {
            return false;
        };
        match state.locked_until {
            Some(until) if now < until => true,
            Some(_) => {
                // Lockout elapsed: clear it so the IP gets a fresh start.
                state.locked_until = None;
                state.failures = 0;
                false
            }
            None => false,
        }
    }

    /// Is `ip` currently locked out? (real clock)
    pub fn is_locked(&self, ip: IpAddr) -> bool {
        self.is_locked_at(ip, Instant::now())
    }

    /// Record a failed attempt from `ip`. Trips the lockout once the streak
    /// reaches `max_failures`; each further failure at or past the threshold
    /// re-arms it for a full `cooldown` from `now`. (clock injected for tests)
    fn record_failure_at(&self, ip: IpAddr, now: Instant) {
        let mut map = self.inner.lock().unwrap_or_else(|e| e.into_inner());

        // Prune when large: keep entries still locked out or seen within the
        // last `cooldown`; drop the rest.
        if map.len() >= MAX_TRACKED_IPS {
            let cooldown = self.cooldown;
            map.retain(|_, s| {
                s.locked_until.map(|u| now < u).unwrap_or(false)
                    || now.duration_since(s.last_seen) < cooldown
            });
        }

        let state = map.entry(ip).or_insert(FailState {
            failures: 0,
            locked_until: None,
            last_seen: now,
        });
        state.last_seen = now;
        state.failures = state.failures.saturating_add(1);
        if state.failures >= self.max_failures {
            state.locked_until = Some(now + self.cooldown);
        }
    }

    /// Record a failed attempt from `ip` (real clock).
    pub fn record_failure(&self, ip: IpAddr) {
        self.record_failure_at(ip, Instant::now());
    }

    /// Record a successful attempt from `ip`, resetting its failure streak.
    /// A success never allocates an entry for an IP that is not already tracked.
    pub fn record_success(&self, ip: IpAddr) {
        let mut map = self.inner.lock().unwrap_or_else(|e| e.into_inner());
        if let Some(state) = map.get_mut(&ip) {
            state.failures = 0;
            state.locked_until = None;
            state.last_seen = Instant::now();
        }
    }
}

// ============================================================================
// Shared rate-limit state (lives in AppState)
// ============================================================================

/// Bundle of limiters carried in `AppState` and consumed by the middleware.
#[derive(Clone)]
pub struct RateLimitState {
    /// `POST /api/auth/login`
    pub login: RateLimiter,
    /// `POST /api/auth/change-password`
    pub change_password: RateLimiter,
    /// `GET /api/codes/:code/validate` (request-rate cap)
    pub code_validate: RateLimiter,
    /// Per-IP lockout on repeated failed code validations (brute-force defense).
    pub code_validate_lockout: FailureLockout,
}

impl RateLimitState {
    /// Build every limiter from the tunables at the top of this module.
    pub fn new() -> Self {
        Self {
            login: RateLimiter::new(LOGIN_MAX_PER_WINDOW, LOGIN_WINDOW),
            change_password: RateLimiter::new(
                CHANGE_PASSWORD_MAX_PER_WINDOW,
                CHANGE_PASSWORD_WINDOW,
            ),
            code_validate: RateLimiter::new(CODE_VALIDATE_MAX_PER_WINDOW, CODE_VALIDATE_WINDOW),
            code_validate_lockout: FailureLockout::new(
                CODE_VALIDATE_MAX_FAILURES,
                CODE_VALIDATE_LOCKOUT,
            ),
        }
    }
}

impl Default for RateLimitState {
    fn default() -> Self {
        Self::new()
    }
}

// ============================================================================
// 429 response (standard error envelope)
// ============================================================================

#[derive(Debug, Serialize)]
struct RateLimitError {
    detail: String,
    error_code: String,
    status_code: u16,
}

/// Build a `429 Too Many Requests` response using the standard error envelope.
fn too_many_requests(detail: &str, error_code: &str) -> Response {
    (
        StatusCode::TOO_MANY_REQUESTS,
        Json(RateLimitError {
            detail: detail.to_string(),
            error_code: error_code.to_string(),
            status_code: StatusCode::TOO_MANY_REQUESTS.as_u16(),
        }),
    )
        .into_response()
}

// ============================================================================
// Axum middleware functions (one per protected route)
// ============================================================================
//
// Each protected route gets its own `from_fn_with_state` middleware pointing at
// the matching limiter in `RateLimitState`. Keeping them as distinct functions
// avoids threading an extra "which limiter" parameter through the layer and
// keeps the wiring in `main.rs` self-documenting. All three extract
// `ConnectInfo<SocketAddr>`, so the router must be served with
// `into_make_service_with_connect_info::<SocketAddr>()`; otherwise the
// extractor fails every request.

/// Rate-limit middleware for `POST /api/auth/login`.
pub async fn login_rate_limit(
    State(state): State<crate::AppState>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    request: axum::extract::Request,
    next: axum::middleware::Next,
) -> Response {
    let ip = addr.ip();
    if !state.rate_limits.login.check(ip) {
        tracing::warn!("Rate limit exceeded on /api/auth/login from {}", ip);
        return too_many_requests(
            "Too many login attempts. Please wait a minute and try again.",
            "RATE_LIMITED",
        );
    }
    next.run(request).await
}

/// Rate-limit middleware for `POST /api/auth/change-password`.
pub async fn change_password_rate_limit(
    State(state): State<crate::AppState>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    request: axum::extract::Request,
    next: axum::middleware::Next,
) -> Response {
    let ip = addr.ip();
    if !state.rate_limits.change_password.check(ip) {
        tracing::warn!(
            "Rate limit exceeded on /api/auth/change-password from {}",
            ip
        );
        return too_many_requests(
            "Too many password-change attempts. Please wait a minute and try again.",
            "RATE_LIMITED",
        );
    }
    next.run(request).await
}

/// Rate-limit + brute-force-lockout middleware for the support-code validate
/// route (`GET /api/codes/:code/validate`).
///
/// Two gates run here:
/// 1. If the IP is currently locked out (too many consecutive failed
///    validations), reject immediately with 429, before the handler runs, so
///    the code is never even looked up.
/// 2. Otherwise apply the per-window request cap.
///
/// The success/failure that drives the lockout is reported by the handler
/// itself (it knows whether the code was valid), via
/// [`RateLimitState::code_validate_lockout`].
pub async fn code_validate_rate_limit(
    State(state): State<crate::AppState>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    request: axum::extract::Request,
    next: axum::middleware::Next,
) -> Response {
    let ip = addr.ip();

    // 1. Brute-force lockout takes precedence.
    if state.rate_limits.code_validate_lockout.is_locked(ip) {
        tracing::warn!(
            "Code-validate request from locked-out IP {} (too many failed attempts)",
            ip
        );
        return too_many_requests(
            "Too many invalid codes from this address. Try again later.",
            "RATE_LIMITED_LOCKOUT",
        );
    }

    // 2. Per-window request cap.
    if !state.rate_limits.code_validate.check(ip) {
        tracing::warn!("Rate limit exceeded on code-validate from {}", ip);
        return too_many_requests(
            "Too many code validation attempts. Please wait a minute and try again.",
            "RATE_LIMITED",
        );
    }

    next.run(request).await
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    fn ip(n: u8) -> IpAddr {
        IpAddr::from([10, 0, 0, n])
    }

    #[test]
    fn fixed_window_allows_up_to_cap_then_blocks() {
        let limiter = RateLimiter::new(3, Duration::from_secs(60));
        let t0 = Instant::now();
        let a = ip(1);
        assert!(limiter.check_at(a, t0)); // 1
        assert!(limiter.check_at(a, t0)); // 2
        assert!(limiter.check_at(a, t0)); // 3
        assert!(!limiter.check_at(a, t0)); // 4 -> blocked
        assert!(!limiter.check_at(a, t0)); // still blocked
    }

    #[test]
    fn fixed_window_resets_after_window_elapses() {
        let limiter = RateLimiter::new(2, Duration::from_secs(60));
        let t0 = Instant::now();
        let a = ip(2);
        assert!(limiter.check_at(a, t0));
        assert!(limiter.check_at(a, t0));
        assert!(!limiter.check_at(a, t0)); // over cap
        // Advance past the window: the counter resets.
        let t1 = t0 + Duration::from_secs(61);
        assert!(limiter.check_at(a, t1));
        assert!(limiter.check_at(a, t1));
        assert!(!limiter.check_at(a, t1));
    }

    #[test]
    fn fixed_window_is_per_ip() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        let t0 = Instant::now();
        assert!(limiter.check_at(ip(3), t0));
        assert!(!limiter.check_at(ip(3), t0)); // ip3 over cap
        assert!(limiter.check_at(ip(4), t0)); // ip4 independent
    }

    #[test]
    fn lockout_trips_after_consecutive_failures() {
        let lockout = FailureLockout::new(3, Duration::from_secs(600));
        let t0 = Instant::now();
        let a = ip(5);
        assert!(!lockout.is_locked_at(a, t0));
        lockout.record_failure_at(a, t0); // 1
        assert!(!lockout.is_locked_at(a, t0));
        lockout.record_failure_at(a, t0); // 2
        assert!(!lockout.is_locked_at(a, t0));
        lockout.record_failure_at(a, t0); // 3 -> trips
        assert!(lockout.is_locked_at(a, t0));
    }

    #[test]
    fn lockout_success_resets_streak() {
        let lockout = FailureLockout::new(3, Duration::from_secs(600));
        let t0 = Instant::now();
        let a = ip(6);
        lockout.record_failure_at(a, t0);
        lockout.record_failure_at(a, t0);
        lockout.record_success(a); // streak reset
        lockout.record_failure_at(a, t0);
        lockout.record_failure_at(a, t0);
        // Only two failures since the reset, so not yet locked.
        assert!(!lockout.is_locked_at(a, t0));
    }

    #[test]
    fn lockout_expires_after_cooldown() {
        let lockout = FailureLockout::new(2, Duration::from_secs(600));
        let t0 = Instant::now();
        let a = ip(7);
        lockout.record_failure_at(a, t0);
        lockout.record_failure_at(a, t0); // trips
        assert!(lockout.is_locked_at(a, t0));
        // After the cooldown the lock clears.
        let t1 = t0 + Duration::from_secs(601);
        assert!(!lockout.is_locked_at(a, t1));
    }

    #[test]
    fn lockout_is_per_ip() {
        let lockout = FailureLockout::new(1, Duration::from_secs(600));
        let t0 = Instant::now();
        lockout.record_failure_at(ip(8), t0); // trips ip8
        assert!(lockout.is_locked_at(ip(8), t0));
        assert!(!lockout.is_locked_at(ip(9), t0)); // ip9 unaffected
    }
}