Implementing Middleware in Rust

In this post we will take a general look into what middleware in Rust is, the benefits of using middleware and then how to use middleware in a Rust server application.

What is middleware?

A web server generally provides responses to requests. Very often, the protocol of choice is HTTP. A handler (sometimes called a response callback) is a function which takes a request's data and returns a response.

Most server frameworks have a system called a 'router' which routes requests based on various parameters, usually the URL path. In HTTP, routing is typically based on a combination of the path and the request method (GET, POST, PUT etc.). The benefit of a router is that it allows splitting logic up by path, which makes large systems with lots of endpoints easier to manage.
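As a rough illustration of the routing idea, here is a minimal sketch of a router keyed on method and path. The Router, Handler and dispatch names are made up for this example and don't come from any framework.

```rust
use std::collections::HashMap;

// Hypothetical handler type: takes the request body, returns a response body.
type Handler = fn(&str) -> String;

// A toy router keyed on (method, path). Real routers also handle path
// parameters, wildcards and fallbacks.
struct Router {
    routes: HashMap<(String, String), Handler>,
}

impl Router {
    fn new() -> Self {
        Router { routes: HashMap::new() }
    }

    fn route(mut self, method: &str, path: &str, handler: Handler) -> Self {
        self.routes.insert((method.to_string(), path.to_string()), handler);
        self
    }

    // Returns (status code, body) for the given request.
    fn dispatch(&self, method: &str, path: &str, body: &str) -> (u16, String) {
        match self.routes.get(&(method.to_string(), path.to_string())) {
            Some(handler) => (200, handler(body)),
            None => (404, "not found".to_string()),
        }
    }
}

fn index(_body: &str) -> String {
    "home page".to_string()
}

fn create_user(body: &str) -> String {
    format!("created {body}")
}

fn example_router() -> Router {
    Router::new()
        .route("GET", "/", index)
        .route("POST", "/users", create_user)
}
```

Note that the same path can map to different handlers depending on the method, and anything unregistered falls through to a 404.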

Individual path handlers are great, but sometimes you want logic which applies to a group of paths or indeed all paths. This is where middleware comes in. Unlike a handler, middleware is called on every request and path that it's registered on. Like handlers, middleware are functions.
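To make the difference concrete, here is a small sketch in which one logging function runs for every request, whichever handler ends up being picked. The Request type and the with_logging function are hypothetical and only exist for this example.

```rust
// Hypothetical minimal request type for this sketch.
struct Request {
    path: String,
}

type Handler = fn(&Request) -> String;

fn index(_req: &Request) -> String {
    "home".to_string()
}

fn about(_req: &Request) -> String {
    "about us".to_string()
}

// Path-specific logic: pick one handler per path.
fn route(path: &str) -> Option<Handler> {
    match path {
        "/" => Some(index as Handler),
        "/about" => Some(about as Handler),
        _ => None,
    }
}

// Middleware: runs for every request, whichever handler is chosen.
// Here it records the path in `log` before delegating to the handler.
fn with_logging(req: &Request, log: &mut Vec<String>) -> String {
    log.push(format!("request to {}", req.path));
    match route(&req.path) {
        Some(handler) => handler(req),
        None => "not found".to_string(),
    }
}
```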

Middleware is very much implementation-dependent. We will have a look at some concrete examples, but different frameworks have opted for different tradeoffs in their middleware implementation. Some middleware implementations work on immutable state and act as a transformer on requests and responses. Other frameworks treat the inputs as mutable, so middleware can freely modify the request objects. Some frameworks implement middleware that can fail or short-circuit.

Middleware as a stack

Middleware tends to be well-ordered. That is, a request or response passes through middleware in a well-defined order, as each layer processes the request or response and passes it onto the next layer:

        requests
           |
           v
+----- layer_three -----+
| +---- layer_two ----+ |
| | +-- layer_one --+ | |
| | |               | | |
| | |    handler    | | |
| | |               | | |
| | +-- layer_one --+ | |
| +---- layer_two ----+ |
+----- layer_three -----+
           |
           v
        responses
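The nesting in the diagram can be sketched with plain closures. This is a toy model, not any framework's API: the layer function and the Handler alias are made up here, and each layer simply tags the request on the way in and the response on the way out.

```rust
// A boxed handler: takes a request string, returns a response string.
type Handler = Box<dyn Fn(String) -> String>;

// Wrap `inner` in a layer that tags the request on the way in and the
// response on the way out.
fn layer(name: &'static str, inner: Handler) -> Handler {
    Box::new(move |request: String| {
        let response = inner(format!("{request} > {name}"));
        format!("{response} > {name}")
    })
}

fn build_stack() -> Handler {
    let handler: Handler = Box::new(|request: String| format!("{request} > handler"));
    // layer_one is innermost, layer_three is outermost, as in the diagram.
    layer("layer_three", layer("layer_two", layer("layer_one", handler)))
}
```

Calling the stack with "request" passes through layer_three, layer_two and layer_one before reaching the handler, and then back out through the same layers in the opposite order.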

Applications of middleware

Authentication

Many routes may want user information. The incoming request may contain user information via cookies or HTTP authentication. Rather than every path handler having to deal with extracting this information, we can move the logic into request middleware and pass the result down to subsequent handlers.

Logging

Information about which paths users are going to and when can be very useful. With logging middleware we can log and store request information for later analysis.

Similar to logging is server response timing. Server-Timing is a standardized HTTP header for holding timing information about a request. Here our middleware can record the start time of an incoming request and the end time when the response is ready, then modify the outgoing response to include the timing. This header is often highlighted in browser developer tools, which can be useful while debugging. It can also be used with chunked / streamed responses, where the headers of the response may already have been sent, by sending it as a trailer.
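As a minimal sketch of the timing idea, the function below times a wrapped handler and reports the duration in a Server-Timing header. The Response type, the with_server_timing name and the "app" metric name are assumptions made for this example.

```rust
use std::time::Instant;

// Hypothetical minimal response type for this sketch.
struct Response {
    headers: Vec<(String, String)>,
    body: String,
}

// Timing middleware: time the wrapped handler and report the duration in a
// `Server-Timing` header. The metric name "app" is an arbitrary choice.
fn with_server_timing(handler: impl Fn() -> Response) -> Response {
    let start = Instant::now();
    let mut response = handler();
    let millis = start.elapsed().as_secs_f64() * 1000.0;
    response
        .headers
        .push(("Server-Timing".to_string(), format!("app;dur={millis:.1}")));
    response
}

// A handler that pretends to do 5 ms of work.
fn slow_handler() -> Response {
    std::thread::sleep(std::time::Duration::from_millis(5));
    Response { headers: Vec::new(), body: "done".to_string() }
}
```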

Compression and other response optimizations

Middleware can also amend outgoing responses and compress the output via algorithms like gzip and brotli. This takes the responsibility away from handlers and provides a convenient default for all responses.
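Before compressing anything, such middleware has to decide which encoding the client accepts. Here is a hedged sketch of that step only: choose_encoding is a made-up helper, it ignores q-values entirely, and the actual compression would be handed off to a gzip or brotli library.

```rust
// Pick a response encoding from an `Accept-Encoding` header value.
// Simplification: q-values such as `gzip;q=0.5` are stripped and ignored,
// and brotli is always preferred over gzip when both are offered.
fn choose_encoding(accept_encoding: &str) -> Option<&'static str> {
    let offered: Vec<&str> = accept_encoding
        .split(',')
        .map(|item| item.split(';').next().unwrap_or("").trim())
        .collect();

    if offered.contains(&"br") {
        Some("br")
    } else if offered.contains(&"gzip") {
        Some("gzip")
    } else {
        None
    }
}
```

When this returns None the middleware would leave the response body untouched.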

Response optimization isn't limited to compression: another use case is image resizing. By identifying mobile viewports using information on the request, outgoing responses can return smaller images rather than huge 4k images, which reduces bandwidth.

Structuring applications

As mentioned above, the benefit of a middleware system is that, while it is possible to do this work individually in each handler, abstracting it moves the responsibility away from the handlers. This can make management simpler and mean fewer lines of code!

fn index() {
    let index_page = "...";
    return compress(index_page);
}

fn about() {
    let about_page = "...";
    return compress(about_page);
}

fn search() {
    let search_page = "...";
    return compress(search_page);
}

Application::build()
    .routes([index, about, search])

vs

fn index() { return "..."; }
fn about() { return "..."; }
fn search() { return "..."; }

Application::build()
    .routes([index, about, search])
    .add_middleware(CompressionMiddleware::new())

Separating out code

The benefit of middleware just being functions is that they can be separated out into different modules or even crates. Many 3rd party services may choose to expose their service as middleware rather than as a set of complicated functions that rely on users passing the correct state into them.

    .add_middleware(hot_new_server_logging_framework_start_up::Middleware::new())

Comparing middleware implementations in libraries

Rocket

Rocket is a server framework. Rocket's middleware implementation is known as fairings (yes there are many rocket related puns in the crate).

From Rocket's fairing documentation:

Rocket’s fairings are a lot like middleware from other frameworks, but they bear a few key distinctions:

  • Fairings cannot terminate or respond to an incoming request directly.
  • Fairings cannot inject arbitrary, non-request data into a request.
  • Fairings can prevent an application from launching.
  • Fairings can inspect and modify the application's configuration.

To make a fairing in Rocket you have to implement the Fairing trait:

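use std::sync::atomic::{AtomicUsize, Ordering};

use rocket::fairing::{Fairing, Info, Kind};
use rocket::http::Method;
use rocket::{Data, Request};

struct MyCounterFairing {
    get_requests: AtomicUsize,
}

#[rocket::async_trait]
impl Fairing for MyCounterFairing {
    fn info(&self) -> Info {
        Info {
            name: "GET Counter",
            kind: Kind::Request,
        }
    }

    // Count every incoming GET request.
    async fn on_request(&self, request: &mut Request<'_>, _: &mut Data<'_>) {
        if let Method::Get = request.method() {
            self.get_requests.fetch_add(1, Ordering::Relaxed);
        }
    }
}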

Using the .attach method it's really simple to add a fairing to an application.

#[launch]
fn rocket() -> _ {
    rocket::build()
        .attach(MyCounterFairing {
            get_requests: AtomicUsize::new(0),
        })
        .attach(other_fairing)
}

Rocket fairings have several hooks. Each hook has a default implementation, so any of them can be left out (you don't have to explicitly write a method for each hook).
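The "default implementation" part is ordinary Rust trait behaviour, and can be sketched without Rocket at all. The Hooks trait below is a simplified stand-in invented for this example, not Rocket's real Fairing trait.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// A simplified stand-in for a fairing-style trait (NOT Rocket's real API):
// every hook has a default no-op body, so implementors override only what
// they need.
trait Hooks {
    fn on_request(&self, _method: &str, _path: &str) {}
    fn on_response(&self, _status: u16) {}
    fn on_shutdown(&self) {}
}

struct GetCounter {
    get_requests: AtomicUsize,
}

// Only `on_request` is written out; the other hooks fall back to the defaults.
impl Hooks for GetCounter {
    fn on_request(&self, method: &str, _path: &str) {
        if method == "GET" {
            self.get_requests.fetch_add(1, Ordering::Relaxed);
        }
    }
}

// Drive a few fake requests through the hooks and return the GET count.
fn count_gets(requests: &[(&str, &str)]) -> usize {
    let counter = GetCounter { get_requests: AtomicUsize::new(0) };
    for (method, path) in requests {
        counter.on_request(method, path);
        counter.on_response(200); // default no-op
    }
    counter.on_shutdown(); // default no-op
    counter.get_requests.load(Ordering::Relaxed)
}
```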

Requests using on_request

This fires when a request is received. This hook has a mutable reference to the request and so can modify it. From the docs: "It cannot abort or respond directly to the request; these issues are better handled via request guards or via response callbacks."

As an aside, Rocket has a separate, non-middleware mechanism called request guards, which is better suited to logic that might short-circuit with an error rather than run the handler afterwards. We won't go into it here, but if your middleware is fallible, request guards might be a better option.

Response using on_response

Similar to on_request, this has mutable access to the response object (it also has immutable access to the request). Using this hook you can inject headers or amend responses (for example, a 404).

General server hooks

Rocket's fairings go beyond request and responses and can act as hooks into application startup and closing:

  • Ignite (on_ignite). Runs before starting the server. Can validate config values, set initial state or abort.
  • Liftoff (on_liftoff). Runs after the server has launched (started). From the docs: "A liftoff callback can be a convenient hook for launching services related to the Rocket application being launched."
  • Shutdown (on_shutdown). This hook can be used to wind down services and save state before the application closes. Shutdown fairings run concurrently, and by the time they run the server has stopped accepting new requests.

All Rocket fairings have an info method. The kind field of the Info it returns decides which hooks Rocket will call on the fairing.

Ad hoc fairings

Simpler middleware can be written as functions using ad hoc fairings. If the fairing doesn't carry any state / data, you can skip creating a structure and writing a trait implementation for it, and write a function instead.

Using AdHoc and the name of any of the hooks mentioned above, we can create a fairing from a function (plus a name string):

.attach(AdHoc::on_liftoff("Liftoff Printer", |_| Box::pin(async move {
    println!("...annnddd we have liftoff!");
})))

Axum

Similar to Rocket, Axum is an HTTP framework for web applications. Axum middleware is based on tower, a separate crate which provides lower-level building blocks for networking in Rust. Axum and tower middleware are referred to as 'layers'.

There are several ways to write middleware in Axum. Similar to standard fairings, you can create a type that implements the Layer trait. The Layer trait decorates / acts upon the Service trait.

This demo was taken from the Tower docs and before you get scared off we will see a much simpler way to implement middleware shortly.

use std::fmt;
use std::task::{Context, Poll};

use tower::{Layer, Service};

pub struct LogLayer {
    target: &'static str,
}

impl<S> Layer<S> for LogLayer {
    type Service = LogService<S>;

    fn layer(&self, service: S) -> Self::Service {
        LogService {
            target: self.target,
            service
        }
    }
}

// This service implements the Log behavior
pub struct LogService<S> {
    target: &'static str,
    service: S,
}

impl<S, Request> Service<Request> for LogService<S>
where
    S: Service<Request>,
    Request: fmt::Debug,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    fn call(&mut self, request: Request) -> Self::Future {
        // Insert log statement here or other functionality
        println!("request = {:?}, target = {:?}", request, self.target);
        self.service.call(request)
    }
}
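If the async machinery obscures the shape of the pattern, a stripped-down synchronous analogue may help. The Service and Layer traits below are simplified versions written for this sketch (no futures, no poll_ready, no error type), not tower's real definitions, and CountLayer is a made-up example.

```rust
// Simplified, synchronous analogues of tower's traits. The real `Service` is
// asynchronous and also has `poll_ready`; both are left out here.
trait Service {
    fn call(&mut self, request: String) -> String;
}

trait Layer<S> {
    type Service;
    fn layer(&self, inner: S) -> Self::Service;
}

// The innermost service: echoes the request back.
struct Echo;

impl Service for Echo {
    fn call(&mut self, request: String) -> String {
        format!("echo: {request}")
    }
}

// A middleware service that counts calls before delegating to `inner`.
struct CountService<S> {
    inner: S,
    calls: usize,
}

impl<S: Service> Service for CountService<S> {
    fn call(&mut self, request: String) -> String {
        self.calls += 1;
        self.inner.call(request)
    }
}

// The layer is a small factory that wraps any service in a `CountService`.
struct CountLayer;

impl<S> Layer<S> for CountLayer {
    type Service = CountService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        CountService { inner, calls: 0 }
    }
}
```

The split into two types is the key design choice: the layer describes how to wrap, and the service it produces does the per-request work.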

We can register our new layer (middleware) on an Axum application using .layer (similar to .attach in Rocket).

use axum::{routing::get, Router};

async fn handler() {}

let app = Router::new()
    .route("/", get(handler))
    .layer(LogLayer { target: "our site" })
    // `.route_layer` will only run the middleware if a route is matched
    .route_layer(TimeOutLayer)

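There is also ServiceBuilder, which is the recommended way to chain layers. Layers added to a ServiceBuilder run in the order they are added (top to bottom), so in the example below layer_three sees the request first and layer_one sees it last, just before the handler. This is the opposite of chaining several .layer calls directly on the Router, where the last layer added runs first.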

Router::new()
    .route("/", get(handler))
    .layer(
        ServiceBuilder::new()
            .layer(layer_three)
            .layer(layer_two)
            .layer(layer_one)
    )

A simpler way

Similar to how Rocket has both trait-based fairings and ad hoc fairings, Axum offers a second, function-based way to write middleware: middleware::from_fn.

Here is a demo from the Axum docs:

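// Imports for the axum 0.6 style API used here (`Next<B>`)
use axum::{
    http::{self, Request, StatusCode},
    middleware::{self, Next},
    response::Response,
    routing::get,
    Router,
};

// `token_is_valid` is an application-defined helper and is not shown here
async fn auth<B>(req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
    let auth_header = req.headers()
        .get(http::header::AUTHORIZATION)
        .and_then(|header| header.to_str().ok());

    match auth_header {
        Some(auth_header) if token_is_valid(auth_header) => {
            Ok(next.run(req).await)
        }
        _ => Err(StatusCode::UNAUTHORIZED),
    }
}

let app = Router::new()
    .route("/", get(|| async { /* ... */ }))
    .route_layer(middleware::from_fn(auth));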
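The important idea in the demo is that the middleware decides whether to call next at all. The same control flow can be sketched synchronously with the standard library only. Everything here (the Request and Response types, the "secret-token" check) is a hypothetical stand-in, not Axum's API.

```rust
// Hypothetical minimal request/response types for this sketch.
struct Request {
    authorization: Option<String>,
}

#[derive(Debug, PartialEq)]
struct Response {
    status: u16,
    body: String,
}

// Placeholder check; a real application would verify a signature or look
// the token up in a session store.
fn token_is_valid(token: &str) -> bool {
    token == "secret-token"
}

// `next` stands for the rest of the stack (inner middleware plus handler).
// Returning without calling it short-circuits the request.
fn auth(req: Request, next: impl Fn(Request) -> Response) -> Response {
    let authorized = req
        .authorization
        .as_deref()
        .map_or(false, token_is_valid);

    if authorized {
        next(req)
    } else {
        Response { status: 401, body: "unauthorized".to_string() }
    }
}

fn handler(_req: Request) -> Response {
    Response { status: 200, body: "welcome".to_string() }
}
```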

Existing ready to use layers:

As Axum is built on tower there are some great readily importable middleware that can be added as layers.

One of those is TraceLayer, which logs requests coming in and responses going out:

Mar 05 20:50:28.523 DEBUG request{method=GET path="/foo"}: tower_http::trace::on_request: started processing request

Mar 05 20:50:28.524 DEBUG request{method=GET path="/foo"}: tower_http::trace::on_response: finished processing request latency=1 ms status=200

There are a bunch of layers in the tower_http crate that can be used instead of writing your own.

Building authentication using our own middleware

Let's play around with a realistic example and build a middleware layer for our own application that manages authentication. In our route handlers we might want to know detailed information about the user that made the request. Rather than having to deal with passing around request information we can encapsulate this logic in middleware.

We'll be using Axum for this demo. The demo is not public at the moment, look out for a future post about authentication for when the full demo is made public!

Cookies as user state

Cookies can be used for maintaining user state. When a user cookie is set on the frontend it's sent with every request. We'll skip over how the cookie got there 😅 and leave it for a future tutorial.

Either way, we want to add middleware which injects the following struct into the current request.

#[derive(Clone)]
struct AuthState(Option<(SessionId, Arc<OnceCell<User>>, Database)>);

We have got a bit fancy here. Rather than querying the database on every request, we store the session id and a handle to the database pool alongside a write-once cell (OnceCell) that caches the user. With all this information, getting the user can be done lazily, or not at all if no handler asks for it.
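The lazy part can be sketched with the standard library alone. This assumes std::sync::OnceLock (Rust 1.70+) as a synchronous stand-in for the OnceCell above, and the Database, User and LazyUser types are made up for illustration, with a fake database that just counts lookups.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock};

// Hypothetical stand-ins for the application's own types.
#[derive(Clone, Debug, PartialEq)]
struct User {
    name: String,
}

// A fake database that counts how many lookups were made.
#[derive(Clone, Default)]
struct Database {
    lookups: Arc<AtomicUsize>,
}

impl Database {
    fn find_user(&self, session_id: u128) -> User {
        self.lookups.fetch_add(1, Ordering::Relaxed);
        User { name: format!("user-{session_id}") }
    }
}

// Session id, a write-once cache for the user, and the database handle.
#[derive(Clone)]
struct LazyUser {
    session_id: u128,
    cached: Arc<OnceLock<User>>,
    database: Database,
}

impl LazyUser {
    // The database is queried on the first call only; later calls (including
    // calls on clones, which share the cell) reuse the cached value.
    fn get(&self) -> &User {
        self.cached
            .get_or_init(|| self.database.find_user(self.session_id))
    }
}
```

If no handler ever calls get, the lookup counter stays at zero, which is the whole point of deferring the query.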

We will build an auth function which builds up this lazy AuthState struct with the required information by parsing the headers of a request.

async fn auth<B>(
    mut req: Request<B>,
    next: Next<B>,
    database: Database,
) -> axum::response::Response {
    // Assuming we only have one cookie
    let key_pair_opt = req
        .headers()
        .get("Cookie")
        .and_then(|value| value.to_str().ok())
        .map(|value|
            value
                .split_once(';')
                .map(|(left, _)| left)
                .unwrap_or(value)
        )
        .and_then(|kv| kv.split_once('='));

    let auth_state = if let Some((key, value)) = key_pair_opt {
        if key != USER_COOKIE_NAME {
            None
        } else if let Ok(value) = value.parse::<u128>() {
            Some(value)
        } else {
            None
        }
    } else {
        None
    };

    req.extensions_mut().insert(AuthState(
        auth_state
            .map(|v| (
                v,
                Arc::new(OnceCell::new()),
                database
            )),
    ));
    next.run(req).await
}

This parsing is a bit ad hoc. Proper parsing should account for multiple cookies etc. and could be neater 😆.
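For what it's worth, handling several cookies in one header doesn't take much more code. Here is a hedged sketch: find_cookie and session_id are made-up helpers, "user" is a placeholder standing in for USER_COOKIE_NAME, and quoted values and percent-decoding are still ignored.

```rust
// Look up one cookie by name in a `Cookie` header value such as
// "theme=dark; user=42". Simplification: no handling of quoted values or
// percent-decoding.
fn find_cookie<'a>(header: &'a str, name: &str) -> Option<&'a str> {
    header
        .split(';')
        .filter_map(|pair| pair.trim().split_once('='))
        .find(|(key, _)| *key == name)
        .map(|(_, value)| value)
}

// Parse the session id out of the header, if the cookie is present and numeric.
// "user" is a made-up cookie name standing in for USER_COOKIE_NAME.
fn session_id(header: &str) -> Option<u128> {
    find_cookie(header, "user")?.parse().ok()
}
```

For anything beyond a demo, a dedicated cookie parsing crate is likely the safer choice.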

At the end of auth we do two important things. First, we extend the request with this lazy auth state: req.extensions_mut().insert(...). Secondly, we run the rest of the request stack: next.run(req).await.

Unlike Rocket fairings, in Axum we could return our own Response from the middleware and not run the handler by skipping next.run(req).await.

Attaching the middleware

We first attach it to our Axum application using:

let middleware_database = database_pool.clone();

Router::new()
    .layer(middleware::from_fn(move |req, next| {
        auth(req, next, middleware_database.clone())
    }))

Because our middleware also needs application state (in this case the database pool), we create an intermediate closure which pulls that in.

Using the middleware

We can now use the state injected by the middleware using the Extension parameter.

async fn me(
    Extension(current_user): Extension<AuthState>,
) -> Result<impl IntoResponse, impl IntoResponse> {
    if let Some(user) = current_user.get_user().await {
        Ok(show_user(user))
    } else {
        // No trailing semicolon: the Err value is the function's result
        Err(error_page("Not logged in"))
    }
}

I was actually surprised when this worked. Axum's handler parameter system is quite magic.
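Part of the magic is less mysterious than it looks: request extensions behave like a map keyed by type, so the handler can ask for "the AuthState" without naming a key. Below is a simplified model of that idea using only the standard library. It is a sketch of the concept, not the actual implementation used by Axum or the http crate, and SessionId here is just an example value.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

// A simplified model of a request "extensions" map: at most one value per
// type, keyed by the type itself.
#[derive(Default)]
struct Extensions {
    values: HashMap<TypeId, Box<dyn Any>>,
}

impl Extensions {
    // Store a value; a later insert of the same type replaces it.
    fn insert<T: Any>(&mut self, value: T) {
        self.values.insert(TypeId::of::<T>(), Box::new(value));
    }

    // Fetch the value of type `T`, if middleware stored one.
    fn get<T: Any>(&self) -> Option<&T> {
        self.values
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }
}

// Hypothetical value that middleware would attach to the request.
#[derive(Debug, PartialEq)]
struct SessionId(u128);
```

Middleware calls insert, and a handler further down the stack calls get with the same type to read the value back.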

Conclusion

I hope you enjoyed reading this guide to using middleware in Rust! In summary, middleware helps you abstract common logic for paths into reusable stateful and stateless objects. Middleware might not be applicable to every scenario, but when you need it, it is super useful!

Did this article help you? Feel free to give us a star on GitHub!

This blog post is powered by shuttle - The Rust-native, open source, cloud development platform. If you have any questions, or want to provide feedback, join our Discord server!