Hello world! We’re going to talk about implementing rate limiting for your API in Rust. When it comes to services in production, you want to ensure that bad actors aren’t abusing your APIs - this is where API rate limiting comes in.
For this tutorial we will implement the “sliding window” algorithm, which checks each user’s request history over a moving time period, using a basic in-memory hashmap to store users and their request times. We will also look at using tower-governor
to configure rate limiting for you.
Implementing a naive sliding window rate limiter
To figure out how we can do this from the ground up, let’s write a naive sliding window IP based rate limiter from scratch.
To get started, we’re going to initialise a regular project using cargo init
and follow the prompt, picking Axum as our framework of choice.
We’re going to need some extra dependencies, so let’s install them with this shell snippet:
cargo add serde@1.0.196 -F derive
cargo add chrono@0.4.34 -F serde,clock
We’ll declare a new struct that holds a HashMap
of IpAddr
keys with the values being Vec<DateTime<Utc>>
(a Vector of UTC-timezone timestamps).
use std::sync::{Arc, Mutex};
use std::collections::HashMap;
use std::net::IpAddr;
use chrono::{DateTime, Utc};
// This will be the request limit (per minute) for a user to access an endpoint
// If the user attempts to go beyond this limit, we should return an error
const REQUEST_LIMIT: usize = 120;
#[derive(Clone, Default)]
pub struct RateLimiter {
requests: Arc<Mutex<HashMap<IpAddr, Vec<DateTime<Utc>>>>>,
}
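To see why the map is wrapped in Arc<Mutex<...>>, here is a small std-only sketch. The SharedLog type and its record method are hypothetical names made up for this illustration, and it stores std::time::Instant values rather than chrono timestamps so that it needs no dependencies. Cloning the handle only clones the Arc, so every clone (including ones moved into other threads) sees the same map.

```rust
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Instant;

// Hypothetical stand-in for RateLimiter: same shape, std-only timestamps.
#[derive(Clone, Default)]
pub struct SharedLog {
    requests: Arc<Mutex<HashMap<IpAddr, Vec<Instant>>>>,
}

impl SharedLog {
    // Record one request for `ip` and return how many are now stored for it.
    pub fn record(&self, ip: IpAddr) -> usize {
        let mut map = self.requests.lock().unwrap();
        let entries = map.entry(ip).or_default();
        entries.push(Instant::now());
        entries.len()
    }
}

// Spawn `threads` workers that each record `per_thread` requests through
// their own clone of the handle, then return the total stored for `ip`.
pub fn record_from_threads(
    log: &SharedLog,
    ip: IpAddr,
    threads: usize,
    per_thread: usize,
) -> usize {
    let handles: Vec<_> = (0..threads)
        .map(|_| {
            let log = log.clone(); // clones the Arc, not the map
            thread::spawn(move || {
                for _ in 0..per_thread {
                    log.record(ip);
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    log.requests.lock().unwrap().get(&ip).map_or(0, Vec::len)
}
```

The Mutex is what makes the concurrent pushes safe: each call to record holds the lock only for as long as it takes to update one entry.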
To get started, we’ll lock our hashmap by using .lock()
which gives us exclusive write access. Then we’ll look up the entry for the IP address with the .entry()
function, creating an empty list if the IP address hasn’t been seen before. We drop any timestamps older than 60 seconds, push the current timestamp (even for requests that end up rejected, so a client that keeps retrying stays limited), and then check whether the number of stored timestamps exceeds the request limit: if so, we return an error; if not, we return Ok(())
.
impl RateLimiter {
fn check_if_rate_limited(&self, ip_addr: IpAddr) -> Result<(), String> {
// we only want to keep timestamps from up to 60 seconds ago
let throttle_time_limit = Utc::now() - std::time::Duration::from_secs(60);
let mut requests_hashmap = self.requests.lock().unwrap();
let requests_for_ip = requests_hashmap
// grab the entry here and allow us to modify it in place
.entry(ip_addr)
// if there is no entry yet, insert an empty vec
.or_default();
// drop timestamps that have fallen out of the 60 second window
requests_for_ip.retain(|x| *x > throttle_time_limit);
requests_for_ip.push(Utc::now());
if requests_for_ip.len() > REQUEST_LIMIT {
return Err("IP is rate limited :(".to_string());
}
Ok(())
}
}
Here is a basic example of how you might use this:
use std::net::Ipv4Addr;
fn main() {
let rate_limiter = RateLimiter::default();
let localhost_v4 = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
// here we make 80 requests, well under our limit of 120
for _ in 0..80 {
assert!(rate_limiter.check_if_rate_limited(localhost_v4).is_ok())
}
// wait 30 seconds
std::thread::sleep(std::time::Duration::from_secs(30));
// make another 40 requests here to use up the rest of the quota (120 in total)
for _ in 0..40 {
assert!(rate_limiter.check_if_rate_limited(localhost_v4).is_ok())
}
// wait another 30 seconds
std::thread::sleep(std::time::Duration::from_secs(30));
// the first 80 requests are now more than 60 seconds old, so we can make another 80
for _ in 0..80 {
assert!(rate_limiter.check_if_rate_limited(localhost_v4).is_ok())
}
}
From here if we wanted to extend this to work with Axum, we could. However, production-ready rate-limiting systems are typically much more advanced than this. We’ll be discussing how you can utilize crates for rate limiting below, including usage of user-based rate limiting.
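As one small step in that direction, here is a hedged sketch of a variant with a configurable limit and window. WindowLimiter and check_at are hypothetical names, not part of any crate. It differs from the version above in three ways, all of them assumptions about what you might want: it uses std::time::Instant instead of chrono, it takes the current time as a parameter so the behaviour can be tested without sleeping, and it does not record requests that were rejected.

```rust
use std::collections::{HashMap, VecDeque};
use std::net::IpAddr;
use std::sync::Mutex;
use std::time::{Duration, Instant};

// Hypothetical variant of the naive limiter with a configurable limit and window.
pub struct WindowLimiter {
    limit: usize,
    window: Duration,
    requests: Mutex<HashMap<IpAddr, VecDeque<Instant>>>,
}

impl WindowLimiter {
    pub fn new(limit: usize, window: Duration) -> Self {
        Self {
            limit,
            window,
            requests: Mutex::new(HashMap::new()),
        }
    }

    // Check a request arriving at `now`. Assumes `now` never goes backwards
    // for a given IP, which holds when callers pass Instant::now().
    pub fn check_at(&self, ip: IpAddr, now: Instant) -> Result<(), String> {
        let mut map = self.requests.lock().unwrap();
        let history = map.entry(ip).or_default();
        // Timestamps are pushed in order, so expired ones are always at the front.
        while let Some(&oldest) = history.front() {
            if now.saturating_duration_since(oldest) >= self.window {
                history.pop_front();
            } else {
                break;
            }
        }
        if history.len() >= self.limit {
            // Rejected requests are not recorded in this sketch.
            return Err("IP is rate limited :(".to_string());
        }
        history.push_back(now);
        Ok(())
    }

    // Convenience wrapper that uses the real clock.
    pub fn check(&self, ip: IpAddr) -> Result<(), String> {
        self.check_at(ip, Instant::now())
    }
}
```

Because timestamps arrive in order, a VecDeque lets us drop expired entries from the front without scanning the whole history on every request.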
Implementing user-based rate limiting
For external-facing websites without a login, IP addresses are the only thing you can use (besides browser information) to track users. However, it can be much more useful to rate limit based on authenticated users rather than IP addresses. While working with IP addresses, you may run into the following issues:
- Multiple users may have the same IP address
- Users can simply change the IP address they use if you block them (via proxy or other methods)
Working with user-based rate limiting allows us to solve these issues. While a user can have more than one IP address, we can attribute all of their requests to the same user.
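To make the idea concrete, here is a minimal sketch of choosing what to rate limit on. ClientKey and client_key are hypothetical names for this illustration: the function prefers an authenticated user ID when one is present and falls back to the IP address otherwise, so the same user is counted once no matter which address they connect from.

```rust
use std::net::IpAddr;

// Hypothetical key type: what a request is counted against.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ClientKey {
    User(String),
    Ip(IpAddr),
}

// Prefer the authenticated user ID; fall back to the IP address when the
// request is anonymous or the ID is blank.
pub fn client_key(user_id: Option<&str>, ip: IpAddr) -> ClientKey {
    match user_id.map(str::trim) {
        Some(id) if !id.is_empty() => ClientKey::User(id.to_owned()),
        _ => ClientKey::Ip(ip),
    }
}
```

Since ClientKey derives Hash and Eq, it could be used as the key of the same kind of HashMap we used for IP addresses earlier.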
Getting started
To initialise our web service we'll use shuttle init
(requires cargo-shuttle
installed) to create our project, making sure to pick Axum as the framework.
Before adding the rate limiter itself, we’re going to create a custom header key! This will be used in routes where we require user authentication. We can also use the header when implementing our custom key extractor for the rate limiter later on. We’ll want to start by adding axum-extra
with the typed-header
feature:
cargo add axum-extra -F typed-header
Next we’ll want to create a struct that will hold a String and implement the axum_extra
re-export of headers::Header
. You can see the Header
implementation below, where it decodes the value by iterating over HeaderValue
and creates the CustomHeader
struct.
We can start by defining a HeaderName
:
use axum_extra::headers::{Header, HeaderName, HeaderValue};
static X: HeaderName = HeaderName::from_static("x-custom-key");
static CUSTOM_HEADER: &HeaderName = &X;
pub struct CustomHeader(String);
impl CustomHeader {
pub fn key(self) -> String {
self.0
}
}
Now that we’ve defined our custom header name (which will be used as the header key), we can implement axum_extra::headers::Header
for CustomHeader
:
impl Header for CustomHeader {
fn name() -> &'static HeaderName {
CUSTOM_HEADER
}
fn decode<'i, I>(values: &mut I) -> Result<Self, axum_extra::headers::Error>
where
I: Iterator<Item = &'i HeaderValue>,
{
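let value = values
.next()
.ok_or_else(axum_extra::headers::Error::invalid)?;
// reject values that aren't visible ASCII instead of panicking on them
let key = value
.to_str()
.map_err(|_| axum_extra::headers::Error::invalid())?;
Ok(CustomHeader(key.to_owned()))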
}
fn encode<E>(&self, values: &mut E)
where
E: Extend<HeaderValue>,
{
let s = &self.0;
let value = HeaderValue::from_str(s).unwrap();
values.extend(std::iter::once(value));
}
}
To use CustomHeader
as an Axum extractor, we need to wrap it in TypedHeader
like so:
use axum::response::IntoResponse;
use axum_extra::TypedHeader;
async fn register(
TypedHeader(header): TypedHeader<CustomHeader>,
) -> impl IntoResponse {
// .. your code goes here
}
This is all well and good, but how does this relate to rate limiting?
While we could use this header in our own middleware, a better option is to use tower_governor
. This crate is a Tower service backed by the governor
crate (a library for rate limiting) and makes rate limiting much easier to implement. It uses the Generic Cell Rate Algorithm (GCRA), which is a more sophisticated version of a leaky bucket. You can read much more about GCRA here.
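To give a feel for how GCRA works, here is a minimal single-key sketch. It is not governor’s actual implementation: the Gcra type and its fields are hypothetical names, and time is passed in as an offset from an arbitrary start so the behaviour is easy to follow. The limiter stores only one value per key, the “theoretical arrival time” (TAT), which moves forward by one emission interval for every accepted request. A request is rejected when accepting it would push the TAT further ahead of the current time than the burst allows.

```rust
use std::time::Duration;

// Minimal single-key GCRA sketch. All times are offsets from an arbitrary start.
pub struct Gcra {
    emission_interval: Duration, // time needed to replenish one request
    tolerance: Duration,         // emission_interval * burst
    tat: Duration,               // theoretical arrival time
}

impl Gcra {
    pub fn new(emission_interval: Duration, burst: u32) -> Self {
        Self {
            emission_interval,
            tolerance: emission_interval * burst,
            tat: Duration::ZERO,
        }
    }

    // Returns true if a request arriving at `now` is allowed.
    pub fn check(&mut self, now: Duration) -> bool {
        // If the key has been idle, the TAT catches up to the present.
        let new_tat = self.tat.max(now) + self.emission_interval;
        // new_tat is always later than now, so this subtraction cannot underflow.
        if new_tat - now > self.tolerance {
            return false;
        }
        self.tat = new_tat;
        true
    }
}
```

With an emission interval of 2 seconds and a burst of 5, this sketch accepts 5 requests immediately, rejects the 6th, and accepts one more once 2 seconds have passed. Unlike the sliding window, it never needs to store a list of timestamps.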
To get started, we’ll add the crate to our Rust program:
cargo add tower-governor
When we want to add it to our main function, we can do it by using GovernorConfigBuilder
and then adding it into GovernorLayer
. Note that while GovernorConfigBuilder
doesn’t implement Clone
, adding a Tower service layer requires it to implement Clone
. This means that we need to box the finished config and then later on, we can use Box::leak
to leak the box to get a &'static
lifetime GovernorConfig
for usage with our axum::Router
:
use axum::{Router, routing::get};
use tower_governor::{governor::GovernorConfigBuilder, GovernorLayer};
#[shuttle_runtime::main]
async fn main() -> shuttle_axum::ShuttleAxum {
let governor_conf = Box::new(
GovernorConfigBuilder::default()
.per_second(2)
.burst_size(5)
.finish()
.unwrap(),
);
let router = Router::new()
.route("/", get(hello_world))
.layer(GovernorLayer {
// We can leak this because it is created once and then never needs to be destructed
config: Box::leak(governor_conf),
});
Ok(router.into())
}
By default, GovernorConfigBuilder
uses a type called PeerIpKeyExtractor
which uses the IP address of the connecting client as the key. To use our header value as the key instead, we can implement tower_governor::key_extractor::KeyExtractor
. A unit struct is enough for this, since the extractor doesn’t need to hold any data when we add it to GovernorConfigBuilder
later on:
use tower_governor::GovernorError;
use tower_governor::key_extractor::KeyExtractor;
use axum::http::Request;
#[derive(Clone)]
pub struct CustomHeaderExtractor;
impl KeyExtractor for CustomHeaderExtractor {
type Key = String;
fn extract<T>(&self, req: &Request<T>) -> Result<Self::Key, GovernorError> {
let headers = req.headers();
match headers.get(CUSTOM_HEADER) {
Some(res) => {
let res = res.to_str()
.map_err(|_| GovernorError::UnableToExtractKey)?;
Ok(res.to_owned())
},
None => Err(GovernorError::UnableToExtractKey)
}
}
}
This allows us to add CustomHeaderExtractor
to our GovernorConfigBuilder
in our main function.
let governor_conf = Box::new(
GovernorConfigBuilder::default()
.per_second(2)
.burst_size(5)
.key_extractor(CustomHeaderExtractor)
.finish()
.unwrap(),
);
When a user attempts to access any route that is layered with the GovernorLayer
, now it’ll attempt to get a header with the header name x-custom-key
- if it’s not present, the request will be rejected with an error. Here we have allowed each key a burst of up to 5 requests, with one request replenished every 2 seconds.
Note that in the builder, the per_second()
function sets the interval, in seconds, after which one element of the quota is replenished, while burst_size
sets the size of the quota, i.e. how many requests can be made in quick succession before tower-governor
will start blocking requests from a given IP address (or API key, in our case). We can also use per_millisecond()
and per_nanosecond()
instead, so if you want to replenish one element of the quota every half a second, for example, you can use per_millisecond(500)
in the builder.
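As a rough worked example of how these two settings combine, the sketch below estimates the most requests a single key could get through in a given stretch of time. It assumes the key starts with a full quota and that exactly one element is replenished per interval. The max_requests function is a hypothetical helper for this illustration and not part of tower-governor.

```rust
use std::time::Duration;

// Upper bound on accepted requests for one key over `elapsed`, assuming it
// starts with a full quota of `burst` and regains one element per `period`.
pub fn max_requests(burst: u32, period: Duration, elapsed: Duration) -> u64 {
    assert!(!period.is_zero(), "period must be non-zero");
    let replenished = elapsed.as_nanos() / period.as_nanos();
    u64::from(burst) + replenished as u64
}
```

For the configuration above (burst_size(5) with per_second(2)), a key could make at most 5 + 60 / 2 = 35 requests in a minute. Switching to per_millisecond(500) would raise that to 5 + 120 = 125.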
Deploying
Now that we’re done, you can deploy using shuttle deploy
(add --ad
if on a dirty Git branch) and watch the magic happen. Once finished, Shuttle will output the details of your deployment in the terminal.
Finishing Up
Thanks for reading! With this guide, implementing rate limiting in a Rust web service should be much easier to do. Productionizing Rust web services has never been easier!