Prompting AWS Bedrock with Rust

Hey there! In this article we're going to talk about using AWS Bedrock with Rust. By the end, you’ll have an API that takes a JSON prompt from an HTTP request and returns an answer from AWS Bedrock, either streamed or as a full response.

Interested in the full code? You can find the repository here.

What is AWS Bedrock?

AWS Bedrock is one of the AI-based services that Amazon offers. It allows you to use models directly for inference and generative AI.

Compared to other AWS offerings like SageMaker, you pay only for what you use: on-demand pricing is based on the tokens that each API call processes. This can make it much cheaper to run in a real application than SageMaker, which charges for instance uptime and can let costs snowball very quickly. Bedrock also comes with tools like guardrails, which let you customise topic and word filters (to mitigate model abuse), and it lets you add your own training data.

Getting Started

Setting up your foundational model

Before we get started, you’ll need to set up access to the foundational model you want to use and an IAM user. We'll use the Titan Text G1 Lite model as an example.

To request access to a model, do the following:

  • Log into AWS Console and go to the AWS Bedrock section (it can also be found using the search bar)
  • Click "Model Access" on the left-hand side (it's near the bottom of the menu).
  • Click "Manage Model access" (top right hand side of the table).
  • Find the model(s) you want access to, click the appropriate checkbox then click Request Access.

Note that some models require you to elaborate on your use case before AWS will approve access. The Titan Text models will generally grant you immediate access. Once done, you'll need to find (and save) the ID of the model that you're using!

You can find the endpoint URL you need here, as well as the model ID here.

Setting up an IAM user

You’ll also need an access key ID and secret access key from an IAM user:

  • AWS_ACCESS_KEY_ID (your Access Key)
  • AWS_SECRET_ACCESS_KEY (your Secret Access Key)

Both can be found in your IAM user credentials if you’ve already set a user up. If you don’t have an appropriate user with policies, you can get started quickly by doing the following:

  • Go to the Users menu
  • Start creating a user and go to the “Attach policies” section (then search “Bedrock”)
  • Here you can either use the “AmazonBedrockFullAccess” policy which gives you full access to Bedrock on that user, or you can create a custom policy. Select one and finish creating your user. Access to Bedrock is required, as otherwise you won’t be able to use it!
  • Go back to the Users menu and click on your newly created user
  • Go to “Access keys” and follow the prompt (clicking “Application outside of AWS”). Don’t forget to store your Access Key ID and Secret Access Key!

In production, you may want to go a step further and create a Group that you can then attach policies and users to.

Initialisation

To get started, we're going to use shuttle init (requires cargo-shuttle installed) to initialise our template, picking Axum as the framework.

Next, we're going to add our dependencies:

cargo add aws-config -F behavior-version-latest
cargo add aws-credential-types -F hardcoded-credentials
cargo add aws-sdk-bedrockruntime -F behavior-version-latest
cargo add axum-streams -F text
cargo add futures
cargo add serde -F derive
cargo add serde_json

You’ll also want to store your secrets in a Secrets.toml file (in the root of your project) like so:

AWS_ACCESS_KEY_ID = "your-access-key-id"
AWS_SECRET_ACCESS_KEY = "your-secret-access-key"
AWS_URL = "your-endpoint-url"

Setting up the AWS config

To get started, we need access to the secrets from the Secrets.toml file that we created earlier.

This can be done by adding the #[shuttle_runtime::Secrets] attribute to a parameter of our main function:

use shuttle_runtime::SecretStore;

#[shuttle_runtime::main]
async fn main(
    #[shuttle_runtime::Secrets] secrets: SecretStore
) -> shuttle_axum::ShuttleAxum {
    // .. your code here
}

On a local or deployment run, the secrets macro will allow the Shuttle runtime to read the Secrets.toml file!

Next, we'll grab our secrets, build an AWS Credentials struct from them, and then load an aws_config configuration. That configuration lets us create the AWS Bedrock Runtime client, as well as any other client from the AWS Rust SDK that we need.

use aws_credential_types::Credentials;
use aws_sdk_bedrockruntime::{config::Region, Client};
use shuttle_runtime::SecretStore;

async fn create_client(secrets: SecretStore) -> Client {
    let access_key_id = secrets
        .get("AWS_ACCESS_KEY_ID")
        .expect("AWS_ACCESS_KEY_ID not set in Secrets.toml");
    let secret_access_key = secrets
        .get("AWS_SECRET_ACCESS_KEY")
        .expect("AWS_SECRET_ACCESS_KEY not set in Secrets.toml");
    let aws_url = secrets
        .get("AWS_URL")
        .expect("AWS_URL not set in Secrets.toml");

    // note here that the "None" is in place of a session token
    let creds = Credentials::from_keys(access_key_id, secret_access_key, None);

    let cfg = aws_config::from_env()
        .endpoint_url(aws_url)
        // Note: Don't forget to set this to the appropriate region!
        .region(Region::new("us-east-1"))
        .credentials_provider(creds)
        .load()
        .await;

    Client::new(&cfg)
}
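
The expect calls above panic on the first missing secret, so a Secrets.toml with several gaps has to be fixed one run at a time. As a rough sketch (not part of the original code), a small helper could collect every missing key up front. The names missing_secrets and REQUIRED_SECRETS are hypothetical, and the lookup is passed in as a closure so the sketch only needs the standard library; with Shuttle you would presumably pass something like |key| secrets.get(key).

```rust
/// Keys this article expects to find in Secrets.toml.
const REQUIRED_SECRETS: [&str; 3] = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_URL"];

/// Hypothetical helper: returns the names of all required secrets that the
/// lookup closure cannot find, or that are blank, in declaration order.
fn missing_secrets<F>(lookup: F) -> Vec<&'static str>
where
    F: Fn(&str) -> Option<String>,
{
    REQUIRED_SECRETS
        .iter()
        .copied()
        // a secret counts as missing when it is absent or only whitespace
        .filter(|key| lookup(*key).map_or(true, |value| value.trim().is_empty()))
        .collect()
}
```

If the returned list is not empty, the application can report all of the missing keys in a single error message before attempting to build the client.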

Onto using the runtime itself!

Using the Bedrock runtime

Pre-requisites

Before we start using the Bedrock runtime, you'll need:

  • a model ID (for a model that you have access to)
  • the identifier of the guardrail you want to use (if you're using one)
  • the content type (JSON by default)

You can check out the model IDs here.

Getting a Prompt Response

For this example, we'll be using the Titan Text G1 Lite model. Although some models may differ in their request bodies and/or response body shape, the process is largely the same.

Before we can write our endpoint, we'll need to define a few structs:

  • A JSON input (that contains the prompt)
  • A struct that models the HTTP response from Bedrock for our model
  • A struct that models the HTTP request to Bedrock for our model

use serde::{Deserialize, Serialize};

#[derive(Deserialize, Serialize)]
struct Prompt {
    prompt: String,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct TitanResponse {
    input_text_token_count: i32,
    results: Vec<TitanTextResult>,
}

#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct TitanTextResult {
    token_count: i32,
    output_text: String,
    completion_reason: String,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct TextGenConfig {
    temperature: f32,
    top_p: f32,
    max_token_count: i32,
    stop_sequences: Vec<String>,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct TitanRequest {
    input_text: String,
    text_generation_config: TextGenConfig,
}

impl TitanRequest {
    fn new(prompt: String) -> Self {
        Self {
            input_text: prompt,
            text_generation_config: TextGenConfig {
                // a higher temperature allows for more LLM creativity;
                // the minimum value, 0.0, gives the most predictable
                // responses
                temperature: 0.2,
                // nucleus sampling probability - the model samples from the
                // smallest set of tokens whose cumulative probability
                // exceeds the "top_p" threshold
                top_p: 0.0,
                // note here that 1 token is roughly 4 characters of English
                // text; we have kept the max token count low here
                // to avoid high costs
                max_token_count: 100,
                stop_sequences: vec!["|".to_string()],
            },
        }
    }
}
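
Since the generation parameters are sent to Bedrock as plain numbers, it can help to reject out-of-range values before making a paid API call. The following is a minimal sketch that assumes the limits for Titan Text G1 Lite are 0.0 to 1.0 for temperature and top_p and at most 4096 output tokens; treat those numbers as assumptions to check against the AWS documentation for your model. GenParams, MAX_TOKEN_LIMIT and validate are hypothetical names, and GenParams is a serde-free stand-in for TextGenConfig.

```rust
/// Hypothetical, serde-free mirror of the article's `TextGenConfig`.
struct GenParams {
    temperature: f32,
    top_p: f32,
    max_token_count: i32,
}

/// Assumed upper bound for Titan Text G1 Lite; verify against the AWS docs.
const MAX_TOKEN_LIMIT: i32 = 4096;

/// Returns a description of the first out-of-range field, if any.
fn validate(params: &GenParams) -> Result<(), String> {
    // `contains` is false for NaN, so NaN is rejected as well
    if !(0.0..=1.0).contains(&params.temperature) {
        return Err(format!("temperature {} is outside 0.0..=1.0", params.temperature));
    }
    if !(0.0..=1.0).contains(&params.top_p) {
        return Err(format!("top_p {} is outside 0.0..=1.0", params.top_p));
    }
    if !(1..=MAX_TOKEN_LIMIT).contains(&params.max_token_count) {
        return Err(format!(
            "max_token_count {} is outside 1..={}",
            params.max_token_count, MAX_TOKEN_LIMIT
        ));
    }
    Ok(())
}
```

A check like this could run inside TitanRequest::new or at the top of the handler, so that a bad value is reported to the caller directly rather than surfacing as a Bedrock validation error.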

Now onto making our prompt handler! We'll set up a function like so (note that we use destructuring to get access to the inner variable from our JSON prompt struct):

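use axum::{extract::State, response::IntoResponse, Json};

// AppState is our shared application state, which holds the Bedrock client
async fn prompt(
    State(state): State<AppState>,
    Json(Prompt { prompt }): Json<Prompt>,
) -> Result<impl IntoResponse, impl IntoResponse> {
    // .. code below
}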

Next, we'll need to use our client from shared state to send a request to Bedrock.

use aws_sdk_bedrockruntime::primitives::Blob;
use axum::http::StatusCode;

let titan_req = TitanRequest::new(prompt);
let Ok(prompt) = serde_json::to_vec(&titan_req) else {
    return Err(StatusCode::BAD_REQUEST);
};

let blob = Blob::new(prompt);

let res = state.client
    .invoke_model()
    .body(blob)
    .model_id("amazon.titan-text-lite-v1:0:4k")
    .send().await
    .unwrap();

After this, we need to take the response, convert its body to a &[u8] and deserialize it into our response body struct.

Because the text results come back as a Vec, we'll use .first() to get the first result and return its text as the body of our HTTP response:

let res: &[u8] = &res.body.into_inner();
let Ok(response_body) = serde_json::from_slice::<TitanResponse>(res) else {
    return Err(StatusCode::INTERNAL_SERVER_ERROR);
};

let Some(TitanTextResult { output_text, .. }) = response_body.results.first() else {
    return Err(StatusCode::INTERNAL_SERVER_ERROR);
};

Ok(output_text.to_owned())
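
The handler above returns bare status codes and still calls unwrap() on the Bedrock request, which would panic if the call fails. One possible way to tidy this up, sketched here with the standard library only, is to funnel the failure points into a single error type and map each variant to a status code in one place. PromptError, its variants, the chosen status codes and first_text are all hypothetical; in an Axum application you would more likely implement IntoResponse for the error type than return a raw u16.

```rust
/// Hypothetical error type covering the failure points of the prompt handler.
#[derive(Debug, PartialEq)]
enum PromptError {
    /// The call to Bedrock itself failed (network, auth, throttling, ...).
    Upstream(String),
    /// Bedrock answered, but the body did not have the expected shape.
    BadResponse,
    /// The response parsed, but the results list was empty.
    EmptyResults,
}

impl PromptError {
    /// Status code to send back to the caller of our API.
    fn status_code(&self) -> u16 {
        match self {
            // the upstream service failed, so report a bad gateway
            PromptError::Upstream(_) => 502,
            PromptError::BadResponse | PromptError::EmptyResults => 500,
        }
    }
}

/// Picks the first result's text, mirroring the `.first()` call above.
fn first_text(results: &[String]) -> Result<&str, PromptError> {
    results
        .first()
        .map(String::as_str)
        .ok_or(PromptError::EmptyResults)
}
```

With a type like this, each fallible step can use the ? operator, and the decision about which status code to return lives in a single match.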

Streamed Prompt Responses

Often, it is better to get a streamed response from a model. Models can take a long time to formulate a full answer, so streaming the response helps with user retention: users see text as it is generated instead of waiting for Bedrock to finish the whole answer.

As you can see below, switching to a response stream requires very little change to the request itself:

let res = state.client
    .invoke_model_with_response_stream()
    .body(blob)
    .model_id("amazon.titan-text-lite-v1:0:4k")
    .send().await
    .unwrap();

However, we do need to change how we handle the response. To create a stream, we need a value that implements the futures::stream::Stream trait. We can use the stream::unfold function, which takes an initial state and a closure. Each time the stream is polled, the closure receives the state and returns either Some((item, state)), which emits the item and passes the state along so the stream can progress, or None, which ends the stream.

This would look something like this:

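use aws_sdk_bedrockruntime::types::ResponseStream;
use futures::stream;

// Assumption: according to the AWS documentation, Titan text models stream
// flat chunks ({"outputText": "...", ...}) rather than the "results" array
// of the non-streamed response, so we model a chunk with its own struct
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct TitanStreamChunk {
    output_text: String,
}

let stream = stream::unfold(res.body, |mut state| async move {
    // end the stream on a receive error instead of panicking
    let message = state.recv().await.ok()?;

    match message {
        Some(ResponseStream::Chunk(chunk)) => {
            let bytes = chunk.bytes?.into_inner();
            let Ok(chunk) = serde_json::from_slice::<TitanStreamChunk>(&bytes) else {
                println!("Unable to deserialize response body :(");
                return None;
            };

            // emit the text and hand the receiver back for the next poll
            Some((chunk.output_text, state))
        }
        // any other event (or the end of the stream) finishes our stream
        _ => None,
    }
});

// no trailing semicolon: this is the handler's return value
Ok(axum_streams::StreamBodyAs::text(stream))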
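
To make the contract of unfold easier to see, here is a synchronous analogue built on the standard library's Iterator trait. It is only an illustration of the "state in, (item, state) out" pattern and not how the futures crate implements stream::unfold; the Unfold struct and the free function unfold are hypothetical names defined here.

```rust
/// Synchronous stand-in for the unfold pattern: holds the current state and
/// the step function that produces the next item.
struct Unfold<S, F> {
    state: Option<S>,
    step: F,
}

impl<S, T, F> Iterator for Unfold<S, F>
where
    F: FnMut(S) -> Option<(T, S)>,
{
    type Item = T;

    fn next(&mut self) -> Option<T> {
        // take ownership of the state; if it is gone, the sequence has ended
        let state = self.state.take()?;
        // the step function returns None to finish, or the item plus new state
        let (item, next_state) = (self.step)(state)?;
        self.state = Some(next_state);
        Some(item)
    }
}

/// Builds an iterator from an initial state and a step function.
fn unfold<S, T, F>(init: S, step: F) -> Unfold<S, F>
where
    F: FnMut(S) -> Option<(T, S)>,
{
    Unfold { state: Some(init), step }
}
```

For example, calling unfold with a queue of text chunks and a closure that pops the front chunk and returns it together with the queue yields the chunks in order and stops once the queue is empty. This mirrors how the handler keeps returning the Bedrock receiver until it has no more chunks.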

Interacting with your API

To make sure that the previous endpoint works, you can use curl on your service:

curl http://localhost:8000/prompt \
-H 'Content-Type: application/json' \
-d '{"prompt":"Hello world!"}'

Note that the above snippet is for a non-streamed response. If you want to receive a streamed response using curl, you need to add the --no-buffer flag:

curl --no-buffer http://localhost:8000/prompt/streamed \
-H 'Content-Type: application/json' \
-d '{"prompt":"Hello world!"}'

You need this flag because curl buffers its output by default. Disabling the buffer lets you see the text as soon as it arrives.

If you want to display the response in a web page, you need to read the response body with a stream reader and decode it with a TextDecoder in JavaScript. This is the whole frontend HTML file that you need:

<div class="wrapper">
    <h1>Shuttle AWS Bedrock Prompt</h1>
    <input type="text" id="input" value="">
    <p id="output"></p>
</div>
<script>
    /* create a function to await reading from a streamed response */
    async function update(reader) {
        const decoder = new TextDecoder();
        while (true) {
            /* attempt to get a value from the stream */
            const {value, done} = await reader.read();
            if (done) break;
            /* decode the obtained value from the streamed response;
               stream: true handles characters split across chunks */
            const chunk = decoder.decode(value, { stream: true });
            document.querySelector("#output").innerText += chunk;
        }
    }

    document.querySelector("#input")?.addEventListener("keypress", (e) => {
    /*
       this checks whether or not you pressed the Return key
       if so, send a HTTP request to the prompt endpoint and
       output the response using our previous update function
    */
        if (e.key === "Enter") {
            const input = document.querySelector("#input").value;
            document.querySelector("#output").innerText = "Waiting...";
            fetch("/prompt/streamed", {
                method: "POST",
                body: JSON.stringify({prompt: input}),
                headers: {
                    "Content-Type": "application/json",
                },
            }).then(res => res.body.getReader())
                .then(reader => {
                    document.querySelector("#output").innerText = "";
                    update(reader);
                });
        }
    })
</script>
</body>
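
One detail worth knowing when reading a streamed body is that a network chunk can end in the middle of a multi-byte UTF-8 character, so each chunk cannot always be decoded on its own. The sketch below shows the idea in Rust using only the standard library: bytes are buffered, and only the valid prefix is released as text. Utf8Chunks and push are hypothetical names, and the sketch simply holds back invalid bytes rather than replacing them, which a production decoder would handle differently.

```rust
/// Hypothetical incremental decoder: buffers bytes until they form complete
/// UTF-8 and releases the text that is ready so far.
#[derive(Default)]
struct Utf8Chunks {
    pending: Vec<u8>,
}

impl Utf8Chunks {
    /// Feeds one chunk of bytes and returns the text that is complete.
    fn push(&mut self, bytes: &[u8]) -> String {
        self.pending.extend_from_slice(bytes);
        let valid_len = match std::str::from_utf8(&self.pending) {
            Ok(text) => text.len(),
            // bytes before `valid_up_to` are valid; the remainder may be a
            // character that continues in the next chunk
            Err(err) => err.valid_up_to(),
        };
        // keep the unfinished tail for the next call
        let rest = self.pending.split_off(valid_len);
        let ready = std::mem::replace(&mut self.pending, rest);
        String::from_utf8(ready).expect("prefix was validated above")
    }
}
```

Feeding the bytes of "héllo" in two pieces that split the "é" returns "h" from the first call and "éllo" from the second, instead of producing a broken character at the boundary.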

If you want to add this file to your Axum service, you’ll want to install tower-http with the fs feature enabled:

cargo add tower-http -F fs

This will allow you to serve a whole directory (or file) on your web service!

We can do this by adding the HTML file to a subfolder of our project root and then declaring it in Shuttle.toml. Generally, we can use a wildcard - if we have a folder called static (aptly named to hold all our static assets), we can declare it like so:

assets = ["static/*"]

Then in our router, we would add tower_http::services::ServeDir as a nested service in our Axum application:

use tower_http::services::ServeDir;

let router = Router::new()
    .route("/prompt", post(prompt))
    .route("/prompt/streamed", post(streamed_prompt))
    .nest_service("/", ServeDir::new("static"))
    .with_state(appstate);

Wrapping it all up

Now it’s time to hook it all up! Our axum::Router only needs the two prompt routes, plus the static assets we talked about earlier (you can remove that line if you’re not using them).

Your Shuttle main function should look like this:

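use axum::{routing::post, Router};

// assumed definition of the shared state used by our handlers
#[derive(Clone)]
struct AppState {
    client: aws_sdk_bedrockruntime::Client,
}

impl AppState {
    fn new(client: aws_sdk_bedrockruntime::Client) -> Self {
        Self { client }
    }
}

#[shuttle_runtime::main]
async fn main(
    #[shuttle_runtime::Secrets] secrets: SecretStore
) -> shuttle_axum::ShuttleAxum {
    // create AWS client
    let client = create_client(secrets).await;
    // create application shareable state from AWS client
    let appstate = AppState::new(client);
    // set up the router
    let router = Router::new()
        .route("/prompt", post(prompt))
        .route("/prompt/streamed", post(streamed_prompt))
        // serve the frontend (remove this if you're not using static assets)
        .nest_service("/", ServeDir::new("static"))
        .with_state(appstate);

    Ok(router.into())
}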

Deployment

To deploy, all you need to do is run shuttle deploy (with the --allow-dirty flag if you're working from a Git branch with uncommitted changes) and watch the magic happen! Once finished, you’ll get a message containing information about your deployment, including the URL where you can reach it.

Shuttle will also cache your dependencies, so if you need to re-deploy you won’t have to worry about waiting to re-compile!

Finishing up

Thanks for reading! By integrating Rust with AWS Bedrock, you can harness the power of both technologies to build scalable, reliable, and maintainable systems.
