Implementing API Rate Limiting in Rust

Hello world! We’re going to talk about implementing rate limiting for your API in Rust. When it comes to services in production, you want to ensure that bad actors aren’t abusing your APIs - this is where API rate limiting comes in.

For this tutorial we will implement the “sliding window” algorithm, which checks each client's request history over a moving time period rather than over fixed intervals. We will use a basic in-memory hashmap to store users and their request times. We will also look at using tower-governor to configure rate limiting for you.

Implementing a naive sliding window rate limiter

To figure out how we can do this from the ground up, let’s write a naive sliding window IP-based rate limiter from scratch.

To get started, we’re going to initialise a regular project using cargo init and follow the prompt, picking Axum as our framework of choice.

We’re going to need some extra dependencies, so let’s install them with this shell snippet:

cargo add serde@1.0.196 -F derive
cargo add chrono@0.4.34 -F serde,clock

We’ll declare a new struct that holds a HashMap of IpAddr keys with the values being Vec<DateTime<Utc>> (a Vector of UTC-timezone timestamps).

use std::sync::{Arc, Mutex};
use std::collections::HashMap;
use std::net::IpAddr;
use chrono::{DateTime, Utc};

// This will be the request limit (per minute) for a user to access an endpoint
// If the user attempts to go beyond this limit, we should return an error
const REQUEST_LIMIT: usize = 120;

#[derive(Clone, Default)]
pub struct RateLimiter {
    requests: Arc<Mutex<HashMap<IpAddr, Vec<DateTime<Utc>>>>>,
}

To get started, we’ll lock our hashmap with .lock(), which gives us exclusive (write) access. We then use the .entry() function to fetch the list of timestamps stored for the IP address, creating an empty list if there is none yet. Next, we drop every timestamp that is older than the window with .retain() and push the current timestamp. Finally, we check whether the number of stored timestamps is higher than the request limit: if so, we return an error; if not, we return Ok(()).

impl RateLimiter {
    fn check_if_rate_limited(&self, ip_addr: IpAddr) -> Result<(), String> {
        // we only want to keep timestamps from up to 60 seconds ago
        let throttle_time_limit = Utc::now() - std::time::Duration::from_secs(60);

        let mut requests_hashmap = self.requests.lock().unwrap();

        let requests_for_ip = requests_hashmap
            // grab the entry here and allow us to modify it in place
            .entry(ip_addr)
            // if there is no entry for this IP yet, insert an empty vec
            .or_default();

        // drop timestamps that have fallen out of the 60-second window
        requests_for_ip.retain(|x| *x > throttle_time_limit);
        // record this request (rejected requests are recorded too)
        requests_for_ip.push(Utc::now());

        if requests_for_ip.len() > REQUEST_LIMIT {
            return Err("IP is rate limited :(".to_string());
        }

        Ok(())
    }
}
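
One way to make the window logic easier to test is to pass the current time in as an argument instead of reading the clock inside the function. The following is a minimal sketch of that idea using only the standard library. SlidingWindow and allow are illustrative names that are not part of the code above, and unlike the version above this sketch does not record rejected requests.

```rust
use std::collections::{HashMap, VecDeque};
use std::net::IpAddr;
use std::time::{Duration, Instant};

/// Sliding-window limiter that takes the current time as an argument,
/// so the window logic can be exercised without sleeping.
pub struct SlidingWindow {
    limit: usize,
    window: Duration,
    hits: HashMap<IpAddr, VecDeque<Instant>>,
}

impl SlidingWindow {
    pub fn new(limit: usize, window: Duration) -> Self {
        Self { limit, window, hits: HashMap::new() }
    }

    /// Returns true if a request from `ip` at time `now` is allowed.
    pub fn allow(&mut self, ip: IpAddr, now: Instant) -> bool {
        let hits = self.hits.entry(ip).or_default();

        // Timestamps are pushed in order, so expired ones sit at the front.
        while let Some(&oldest) = hits.front() {
            if now.saturating_duration_since(oldest) >= self.window {
                hits.pop_front();
            } else {
                break;
            }
        }

        if hits.len() >= self.limit {
            // Over the limit: reject without recording the request.
            return false;
        }

        hits.push_back(now);
        true
    }
}
```

With a limit of 3 and a 60-second window, three calls to allow at the same instant succeed and a fourth is rejected until the oldest timestamp is at least 60 seconds old.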

Here is a basic example of how you might use this:

use std::net::Ipv4Addr;

fn main() {
    let rate_limiter = RateLimiter::default();

    let localhost_v4 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));

    // here we make 80 requests - two thirds of our request limit
    for _ in 0..80 {
        assert!(rate_limiter.check_if_rate_limited(localhost_v4).is_ok());
    }

    // wait 30 seconds
    std::thread::sleep(std::time::Duration::from_secs(30));

    // make another 40 requests here to use up the request quota
    for _ in 0..40 {
        assert!(rate_limiter.check_if_rate_limited(localhost_v4).is_ok());
    }

    // wait another 30 seconds so that the first 80 requests leave the window
    std::thread::sleep(std::time::Duration::from_secs(30));

    // now we can make another 80 requests
    for _ in 0..80 {
        assert!(rate_limiter.check_if_rate_limited(localhost_v4).is_ok());
    }
}

From here if we wanted to extend this to work with Axum, we could. However, production-ready rate-limiting systems are typically much more advanced than this. We’ll be discussing how you can utilize crates for rate limiting below, including usage of user-based rate limiting.

Implementing user-based rate limiting

For external-facing websites without a login, IP addresses are the only thing you can use (besides browser information) to track users. However, it can be much more useful to rate limit based on authenticated users rather than IP addresses. While working with IP addresses, you may run into the following issues:

  • Multiple users may have the same IP address
  • Users can simply change the IP address they use if you block them (via proxy or other methods)

Working with user-based rate limiting allows us to solve these issues. While a user can have more than one IP address, we can attribute all of their requests to the same user.
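
To illustrate the idea, here is a minimal sketch of keying request tracking by user rather than by IP address, using only the standard library. ClientKey, from_request and RequestCounts are hypothetical names made up for this example, and the counter stands in for the timestamp list a real limiter would keep.

```rust
use std::collections::HashMap;
use std::net::IpAddr;

/// Hypothetical key type: prefer the authenticated user, fall back to the IP.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ClientKey {
    User(String),
    Ip(IpAddr),
}

impl ClientKey {
    pub fn from_request(user_id: Option<&str>, ip: IpAddr) -> Self {
        match user_id {
            Some(id) => ClientKey::User(id.to_owned()),
            None => ClientKey::Ip(ip),
        }
    }
}

/// Counts requests per key; a real limiter would store timestamps instead.
#[derive(Default)]
pub struct RequestCounts {
    counts: HashMap<ClientKey, usize>,
}

impl RequestCounts {
    /// Records one request for `key` and returns the running total.
    pub fn record(&mut self, key: ClientKey) -> usize {
        let count = self.counts.entry(key).or_insert(0);
        *count += 1;
        *count
    }
}
```

Requests from the same user on two different IP addresses land in the same bucket, while two users behind one shared IP address are counted separately.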

Getting started

To initialise our web service we'll use shuttle init (requires cargo-shuttle installed) to create our project, making sure to pick Axum as the framework.

Before adding the rate limiter itself, we’re going to create a custom header key! This will be used in routes where we require user authentication. We can also use the header when implementing our custom key extractor for the rate limiter later on. We’ll want to start by adding axum-extra with the typed-header feature:

cargo add axum-extra -F typed-header

Next we’ll want to create a struct that will hold a String and implement the axum_extra re-export of headers::Header. You can see the Header implementation below, where it decodes the value by iterating over HeaderValue and creates the CustomHeader struct.

We can start by defining a HeaderName:

use axum::http::{HeaderName, HeaderValue};
use axum_extra::headers::Header;

static X: HeaderName = HeaderName::from_static("x-custom-key");
static CUSTOM_HEADER: &HeaderName = &X;

pub struct CustomHeader(String);

impl CustomHeader {
    pub fn key(self) -> String {
        self.0
    }
}

Now that we’ve defined our custom header name (which will be used as the header key), we can implement axum_extra::headers::Header for CustomHeader:


impl Header for CustomHeader {
    fn name() -> &'static HeaderName {
        CUSTOM_HEADER
    }

    fn decode<'i, I>(values: &mut I) -> Result<Self, axum_extra::headers::Error>
    where
        I: Iterator<Item = &'i HeaderValue>,
    {
        let value = values
            .next()
            .ok_or_else(axum_extra::headers::Error::invalid)?;

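        // reject values that are not valid visible ASCII instead of panicking
        let value = value
            .to_str()
            .map_err(|_| axum_extra::headers::Error::invalid())?;

        Ok(CustomHeader(value.to_owned()))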
    }

    fn encode<E>(&self, values: &mut E)
    where
        E: Extend<HeaderValue>,
    {
        // skip encoding if the stored string is not a valid header value,
        // rather than panicking
        if let Ok(value) = HeaderValue::from_str(&self.0) {
            values.extend(std::iter::once(value));
        }
    }
}

To use CustomHeader as an Axum extractor, we need to wrap it in TypedHeader like so:

use axum::response::IntoResponse;
use axum_extra::TypedHeader;

async fn register(
    TypedHeader(header): TypedHeader<CustomHeader>,
) -> impl IntoResponse {
    // .. your code goes here
}

This is all well and good, but how does this relate to rate limiting?

While we could use this header in our own middleware, a better alternative is to use tower_governor. This crate is a Tower service backed by the governor crate (a general-purpose rate limiting crate) and makes it much easier to implement rate limiting. The crate uses the Generic Cell Rate Algorithm (GCRA), which is a more sophisticated version of a leaky bucket. You can read much more about GCRA here.
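
To give a feel for how GCRA works, here is a minimal sketch of the textbook form of the algorithm for a single key, using only the standard library. It is not the governor crate's implementation; Gcra, allow and the field names are illustrative, and time is passed in as an offset from an arbitrary start so the behaviour is deterministic.

```rust
use std::time::Duration;

/// Textbook GCRA state for a single key.
/// All times are offsets from an arbitrary starting point.
pub struct Gcra {
    emission_interval: Duration, // time needed to replenish one request
    tolerance: Duration,         // emission_interval * (burst - 1)
    tat: Duration,               // theoretical arrival time of the next request
}

impl Gcra {
    pub fn new(emission_interval: Duration, burst: u32) -> Self {
        Self {
            emission_interval,
            tolerance: emission_interval * burst.saturating_sub(1),
            tat: Duration::ZERO,
        }
    }

    /// Returns true if a request arriving at `now` conforms to the limit.
    pub fn allow(&mut self, now: Duration) -> bool {
        // Too early: the request arrives before TAT minus the burst tolerance.
        if now + self.tolerance < self.tat {
            return false;
        }
        // Conforming: push the theoretical arrival time forward by one interval.
        self.tat = self.tat.max(now) + self.emission_interval;
        true
    }
}
```

Instead of storing every request timestamp like the sliding window above, this approach keeps a single value per key, which is one reason it scales better.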

To get started, we’ll add the crate to our Rust program:

cargo add tower-governor

To add it to our main function, we build a config with GovernorConfigBuilder and then pass it to GovernorLayer. Note that while GovernorConfigBuilder doesn’t implement Clone, adding a Tower service layer requires it to implement Clone. This means that we need to box the finished config and then use Box::leak to leak the box, which gives us a &'static GovernorConfig for use with our axum::Router:

use axum::{Router, routing::get};
use tower_governor::{governor::GovernorConfigBuilder, GovernorLayer};

#[shuttle_runtime::main]
async fn main() -> shuttle_axum::ShuttleAxum {
    let governor_conf = Box::new(
        GovernorConfigBuilder::default()
            .per_second(2)
            .burst_size(5)
            .finish()
            .unwrap(),
    );

    let router = Router::new()
        .route("/", get(hello_world))
        .layer(GovernorLayer {
            // We can leak this because it is created once and then never needs to be destructed
            config: Box::leak(governor_conf),
        });

    Ok(router.into())
}

By default, GovernorConfigBuilder uses a type called PeerIpKeyExtractor, which attempts to use the IP address of a connecting client as the key. To use our header as the extracted key instead, we can implement tower_governor::key_extractor::KeyExtractor. We’ll implement it on a unit struct, since the extractor doesn’t currently need to hold any extra data when we add it to GovernorConfigBuilder later on:

use axum::http::Request;
use tower_governor::{key_extractor::KeyExtractor, GovernorError};

#[derive(Clone)]
pub struct CustomHeaderExtractor;

impl KeyExtractor for CustomHeaderExtractor {
    type Key = String;

    fn extract<T>(&self, req: &Request<T>) -> Result<Self::Key, GovernorError> {
        let headers = req.headers();

        match headers.get(CUSTOM_HEADER) {
            Some(res) => {
                let res = res.to_str()
                    .map_err(|_| GovernorError::UnableToExtractKey)?;

                Ok(res.to_owned())
            },
            None => Err(GovernorError::UnableToExtractKey)
        }
    }
}

This allows us to add CustomHeaderExtractor to our GovernorConfigBuilder in our main function.

let governor_conf = Box::new(
    GovernorConfigBuilder::default()
        .per_second(2)
        .burst_size(5)
        .key_extractor(CustomHeaderExtractor)
        .finish()
        .unwrap(),
);

When a user attempts to access any route that is layered with the GovernorLayer, it will now look for a header named x-custom-key; if the header is not present, the request is rejected with an error. With this configuration, each key can send a burst of up to 5 requests, after which one more request becomes available every 2 seconds.

Note that in the builder, per_second() sets the interval in seconds after which one element of the quota is replenished, and burst_size() sets the size of the quota before tower-governor starts blocking requests from a given IP address (or API key, in our case). The interval can also be set with per_millisecond() and per_nanosecond(): if you want to replenish the quota every half a second, for example, you can use per_millisecond(500) in the builder.
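
As a rough worked example of how these two settings interact, the sketch below computes an upper bound on the number of requests one key can get through in a given amount of time. It assumes the quota starts full and that one request is replenished per interval. max_requests is a hypothetical helper for illustration and is not part of tower-governor.

```rust
use std::time::Duration;

/// Upper bound on the requests accepted from one key within `elapsed`,
/// assuming the quota starts full and refills one request per `period`.
pub fn max_requests(burst: u32, period: Duration, elapsed: Duration) -> u64 {
    // Number of whole replenishment intervals that fit into `elapsed`.
    let replenished = elapsed.as_nanos() / period.as_nanos().max(1);
    u64::from(burst) + replenished as u64
}
```

With a burst size of 5 and a 2-second interval, this gives 5 requests immediately and 10 over a 10-second span. With the same burst size and a 500 millisecond interval, it gives 7 requests over one second.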

Deploying

Now that we’re done, you can deploy using shuttle deploy (add --ad if on a dirty Git branch) and watch the magic happen. Once finished, Shuttle will output the details of your deployment in the terminal.

Finishing Up

Thanks for reading! With this guide, implementing rate limiting in a Rust web service should be much easier to do. Productionizing Rust web services has never been easier!

Read more:

  • Try out implementing JWT authentication here.
  • Learn more about the Tracing ecosystem for logging here.