Building an arXiv Agent with Rig & Rust

Hi there! In this tutorial we're going to be using the Rig AI framework to create an AI agent for helping you learn by suggesting research papers based on a given subject. We're also going to use Shuttle to deploy it on the web so that other people can try it out!

Interested in just looking at the code? Check out the GitHub repository.

Why Rig?

Rig (by Playgrounds, who also maintain arc.fun) is an up-and-coming Rust framework designed to make AI agent creation as easy as possible. It can currently create agentic pipelines (i.e. pipelines that execute a number of LLM-assisted steps), integrate RAG with AI agents, and expose an API that you can use to create your own tools.

The framework itself is growing quite rapidly with the maintainer team using it in production, so it will be receiving updates for quite some time! They're also holding an event called the ARC Handshake, which will be a showcase of AI agents built using Rig.

Pre-requisites

Before we start, make sure you have the Rust programming language installed and an OpenAI API key. If you don't have a key already, you will need to sign in to the OpenAI dashboard and create one, as the agent won't work without it.

You'll also want the cargo-shuttle crate installed (our CLI) - check out our installation instructions for more info.

Let's get building!

If you haven't already, don't forget to create a new Shuttle project via cargo-shuttle:

shuttle init --template axum

Once you've followed the prompt and it's set up, we'll need to add our required dependencies to build the project:

cargo add reqwest serde serde_json anyhow thiserror quick-xml rig-core \
urlencoding tracing tera tower-http \
-F serde/derive,quick-xml/serialize,tower-http/cors

Creating our Rig AI agent

The first step will be to create our AI agent. We can do this by creating a tool that the model (in this case, one from OpenAI) can call, and that runs some functionality of our own when called.

Setup

The arXiv export endpoint (http://export.arxiv.org/api/query) takes a few different query parameters that are relevant to us:

  • search_query (the actual search query we want to search for)
  • start (the index of the first result to return; we always use 0)
  • max_results (the maximum number of papers we want in the result)
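To make the parameters concrete, here is a minimal sketch of how they end up in the request URL, using only the standard library. The helper names percent_encode and build_arxiv_url are made up for this illustration. Later in the tutorial, reqwest's .query() does this job for us, and it may encode some characters slightly differently (a space as + rather than %20, for example).

```rust
/// Percent-encode a string for use as a URL query value.
/// Unreserved characters (RFC 3986) pass through; every other byte becomes %XX.
fn percent_encode(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    for byte in input.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                out.push(byte as char)
            }
            _ => out.push_str(&format!("%{:02X}", byte)),
        }
    }
    out
}

/// Build the full query URL (hypothetical helper, for illustration only).
/// The `all:` prefix asks arXiv to search across every field.
fn build_arxiv_url(query: &str, start: u32, max_results: u32) -> String {
    format!(
        "http://export.arxiv.org/api/query?search_query={}&start={}&max_results={}",
        percent_encode(&format!("all:{query}")),
        start,
        max_results
    )
}
```

Calling build_arxiv_url("graph neural networks", 0, 5) produces a URL you can paste into a browser to see the raw XML feed that we parse later on.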

Before we write the implementation block, let's declare the relevant structs we want. We want one for holding paper results, one for search arguments and one for the tool itself.

// Struct to hold paper metadata
// (Clone is needed because the parser later returns a copy of its papers)
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct Paper {
    pub title: String,
    pub authors: Vec<String>,
    pub abstract_text: String,
    pub url: String,
    pub categories: Vec<String>,
}

impl Paper {
    fn new() -> Self {
        Self {
            title: String::new(),
            authors: Vec::new(),
            abstract_text: String::new(),
            url: String::new(),
            categories: Vec::new(),
        }
    }
}

#[derive(serde::Deserialize)]
pub struct SearchArgs {
    query: String,
    max_results: Option<i32>,
}

// Tool to search for papers
#[derive(serde::Deserialize, serde::Serialize)]
pub struct ArxivSearchTool;

Error Handling

While the AI agent is carrying out work, it can run into several different kinds of errors. We should aim to represent these with an enum. We also derive thiserror::Error on it (which requires Debug): the #[error] attributes generate the Display implementation, and each #[from] attribute generates a From implementation for the wrapped error type.

#[derive(Debug, thiserror::Error)]
pub enum ArxivError {
    #[error("Network error: {0}")]
    Network(#[from] reqwest::Error),
    #[error("XML parsing error: {0}")]
    XmlParsing(#[from] quick_xml::Error),
    #[error("No results found")]
    NoResults,
    #[error("UTF-8 decoding error: {0}")]
    Utf8Error(#[from] std::str::Utf8Error),
}

By doing this, a couple of cool things happen:

  • It enables use of the ? operator which increases readability and allows error propagation up the call stack
  • Makes it obvious why and/or how something has failed (we can check the enum variant)
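If you're curious what the derive saves us from writing, here is a rough, standard-library-only sketch of the same idea. DemoError and first_word are hypothetical names used only for this illustration, and the code thiserror actually generates will differ in detail.

```rust
use std::fmt;
use std::str::Utf8Error;

// Hypothetical, cut-down error type used only for this illustration.
#[derive(Debug)]
enum DemoError {
    Utf8(Utf8Error),
    NoResults,
}

// Roughly what the `#[error("...")]` attributes give us: a Display impl.
impl fmt::Display for DemoError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DemoError::Utf8(err) => write!(f, "UTF-8 decoding error: {err}"),
            DemoError::NoResults => write!(f, "No results found"),
        }
    }
}

impl std::error::Error for DemoError {}

// Roughly what `#[from]` gives us: a From impl that wraps the source error.
impl From<Utf8Error> for DemoError {
    fn from(err: Utf8Error) -> Self {
        DemoError::Utf8(err)
    }
}

// `?` calls `From::from` on the error, so a Utf8Error becomes a DemoError.
fn first_word(bytes: &[u8]) -> Result<String, DemoError> {
    let text = std::str::from_utf8(bytes)?;
    text.split_whitespace()
        .next()
        .map(str::to_owned)
        .ok_or(DemoError::NoResults)
}
```

A caller can then match on the variant to decide what to do, for example treating DemoError::NoResults as an empty result rather than a failure.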

Rig Tool Definitions

Next, we want to write the implementation block. It requires us to provide functionality for two methods - definition() which provides the tool definition and prompting for the model, as well as call() which is the actual functionality for the tool.

use rig::{completion::ToolDefinition, tool::Tool};
use serde_json::json;

const ARXIV_URL: &str = "http://export.arxiv.org/api/query";

impl Tool for ArxivSearchTool {
    const NAME: &'static str = "search_arxiv";
    type Error = ArxivError;
    type Args = SearchArgs;
    type Output = Vec<Paper>;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            name: "search_arxiv".to_string(),
            description: "Search for academic papers on arXiv".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Search query for papers"
                    },
                    "max_results": {
                        "type": "integer",
                        "description": "Maximum number of results to return (default: 5)"
                    }
                },
                "required": ["query"]
            }),
        }
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let max_results = args.max_results.unwrap_or(5);
        let client = reqwest::Client::new();

        let response = client
            .get(ARXIV_URL)
            .query(&[
                ("search_query", format!("all:{}", args.query)),
                ("start", 0.to_string()),
                ("max_results", max_results.to_string()),
            ])
            .send()
            .await?
            .text()
            .await?;

        // ArxivParser is implemented in the next section
        ArxivParser::default().parse_response(&response)
    }
}

Note that a tool is simply something the model can call - it can be added to an agent by itself, or it can be added to a toolset alongside other tools to provide a more comprehensive user experience.

Parsing the arXiv response

Next, we need to parse the response from arXiv. The actual response format is XML - this is what a typical entry looks like:

<entry>
<id>http://arxiv.org/abs/2407.11861v1</id>
<updated>2024-07-16T15:48:36Z</updated>
<published>2024-07-16T15:48:36Z</published>
<title>What Makes a Meme a Meme? Identifying Memes for Memetics-Aware Dataset Creation</title>
<summary> <!-- Snipped for brevity --> </summary>
<author>
<name>Muzhaffar Hazman</name>
</author>
<author>
<name>Susan McKeever</name>
</author>
<author>
<name>Josephine Griffith</name>
</author>
<arxiv:comment xmlns:arxiv="http://arxiv.org/schemas/atom">Accepted for Publication at AAAI-ICWSM 2025</arxiv:comment>
<link href="http://arxiv.org/abs/2407.11861v1" rel="alternate" type="text/html"/>
<link title="pdf" href="http://arxiv.org/pdf/2407.11861v1" rel="related" type="application/pdf"/>
<arxiv:primary_category xmlns:arxiv="http://arxiv.org/schemas/atom" term="cs.LG" scheme="http://arxiv.org/schemas/atom"/>
<category term="cs.LG" scheme="http://arxiv.org/schemas/atom"/>
</entry>

In order to parse it, we will use the quick-xml crate, which offers good performance while still being relatively easy to use.

To ensure codebase readability, we will create a struct that will act as our parser and hold the parser state in it:

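use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event};
use quick_xml::Reader;
use std::str;

#[derive(Default)]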
struct ArxivParser<'a> {
    papers: Vec<Paper>,
    current_paper: Option<Paper>,
    current_authors: Vec<String>,
    current_categories: Vec<String>,
    in_entry: bool,
    current_field: Option<&'a str>,
}

We will implement several methods on this struct:

  • Methods for handling different types of XML events
  • A public method for parsing the whole text

To start with, we'll create methods for handling the start of an XML tag as well as any text that is contained within the tag:

impl<'a> ArxivParser<'a> {
    fn parse_start_event(&mut self, event: &BytesStart) {
        match event.name().as_ref() {
            // if the tag is "entry", this means we're at the start of a new xml block
            // so we can clear related variables and start anew
            b"entry" => {
                self.in_entry = true;
                self.current_paper = Some(Paper::new());
                self.current_authors.clear();
                self.current_categories.clear();
            }
            // otherwise, change the parsing state
            b"title" if self.in_entry => self.current_field = Some("title"),
            b"author" if self.in_entry => self.current_field = Some("author"),
            b"summary" if self.in_entry => self.current_field = Some("abstract"),
            b"link" if self.in_entry => self.current_field = Some("link"),
            b"category" if self.in_entry => self.current_field = Some("category"),
            _ => (),
        };
    }

    fn parse_text_event(&mut self, event: &BytesText) -> Result<(), ArxivError> {
        // if there's no current paper, just don't return anything
        let Some(paper) = self.current_paper.as_mut() else {
            return Ok(());
        };
        // otherwise, attempt to get the text and fill in the relevant field
        let text = str::from_utf8(event.as_ref())?.to_owned();
        match self.current_field {
            Some("title") => paper.title = text,
            Some("author") => self.current_authors.push(text),
            Some("abstract") => paper.abstract_text = text,
            _ => (),
        }
        Ok(())
    }
}

Before we continue, we need a small function that converts the links we get from an arXiv XML entry into PDF links. See below, where we replace arxiv.org/abs with arxiv.org/pdf and upgrade the scheme to HTTPS:

fn convert_pdf_url(url: &str) -> String {
    if url.contains("arxiv.org/abs/") {
        // Convert abstract URL to PDF URL
        url.replace("arxiv.org/abs/", "arxiv.org/pdf/")
            .replace("http://", "https://")
            + ".pdf"
    } else if url.contains("arxiv.org/pdf/") {
        // Ensure PDF URL uses HTTPS
        url.replace("http://", "https://")
    } else {
        // Fallback for other URLs
        url.replace("http://", "https://")
    }
}

Next, we want to parse empty XML elements. If the XML element is a Link or Category, we attempt to parse them like below and add the relevant parts to the parser state:

impl<'a> ArxivParser<'a> {
    // .. other methods here
    fn parse_empty_event(&mut self, event: &BytesStart) -> Result<(), ArxivError> {
        // if we're not in an entry, just don't do anything
        if !self.in_entry {
            return Ok(());
        }
        // if the element is a link, convert the URL to the relevant format
        // and add the URL to the paper
        if event.name().as_ref() == b"link" {
            if let Some(paper) = self.current_paper.as_mut() {
                for attr in event.attributes().flatten() {
                    if attr.key.as_ref() == b"href" {
                        let url = str::from_utf8(&attr.value)?;
                        // Convert to HTTPS and ensure PDF URL
                        paper.url = convert_pdf_url(url);
                    }
                }
            }
        }
        // if the element is a Category, push the category terms
        // into the parser's list of current categories
        if event.name().as_ref() == b"category" {
            for attr in event.attributes().flatten() {
                if attr.key.as_ref() == b"term" {
                    self.current_categories
                        .push(str::from_utf8(&attr.value)?.to_owned());
                }
            }
        }

        Ok(())
    }
}

We also need to ensure that we are correctly resetting the parser state when we reach the end of an element:

  • If we reach the end of an entry, we need to push our current Paper to the list of papers being generated by the parser
  • If we reach the end of another element, we just reset current_field, as there is nothing left to parse within that element.

impl<'a> ArxivParser<'a> {
    // .. other methods
    fn parse_end_event(&mut self, event: &BytesEnd) -> Result<(), ArxivError> {
        // this is an end event - if the end tag is for an entry
        // add the current paper to the list of papers
        match event.name().as_ref() {
            b"entry" => {
                if let Some(mut paper) = self.current_paper.take() {
                    paper.authors.clone_from(&self.current_authors);
                    paper.categories.clone_from(&self.current_categories);
                    self.papers.push(paper);
                }
                self.in_entry = false;
            }
            // else, just change the currently parsed field to None
            // as there is now nothing to parse
            b"title" | b"author" | b"summary" | b"link" | b"category" => {
                self.current_field = None;
            }
            _ => (),
        }
        Ok(())
    }
}

Finally, we need to write our method for parsing the whole response. This function simply loops until the end of the XML input has been reached (signalled by Event::Eof). If no papers have been parsed (no results were returned), we return an error.

impl<'a> ArxivParser<'a> {
    // .. other methods
    fn parse_response(&mut self, input: &str) -> Result<Vec<Paper>, ArxivError> {
        let mut reader = Reader::from_str(input);
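        // Note: on newer quick-xml releases this setter has moved to the
        // reader config: `reader.config_mut().trim_text(true);`
        reader.trim_text(true);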

        let mut buf = Vec::new();
        loop {
            match reader.read_event_into(&mut buf) {
                Ok(Event::Start(ref e)) => self.parse_start_event(e),
                Ok(Event::Text(ref e)) => self.parse_text_event(e)?,
                Ok(Event::Empty(ref e)) => self.parse_empty_event(e)?,
                Ok(Event::End(ref e)) => self.parse_end_event(e)?,
                // EoF means end of file - we can stop trying to parse here
                Ok(Event::Eof) => break,
                Err(e) => return Err(ArxivError::XmlParsing(e)),
                _ => (),
            }
        }

        if self.papers.is_empty() {
            return Err(ArxivError::NoResults);
        }

        Ok(self.papers.clone())
    }
}
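To see the state handling in isolation, here is a stripped-down sketch of the same event-driven idea without quick-xml. XmlEvent, TitleCollector and collect_titles are hypothetical stand-ins invented for this illustration. The sketch only collects titles, but it shows why the in_entry flag matters: an Atom feed also has a feed-level title element that sits outside any entry, and we don't want to mistake it for a paper title.

```rust
/// Hypothetical, simplified stand-in for quick-xml's event type.
enum XmlEvent<'a> {
    Start(&'a str),
    Text(&'a str),
    End(&'a str),
}

/// Parser state: what we have collected so far, and where we are in the document.
#[derive(Default)]
struct TitleCollector {
    titles: Vec<String>,
    in_entry: bool,
    in_title: bool,
}

impl TitleCollector {
    /// Update the state for a single event.
    fn handle(&mut self, event: &XmlEvent) {
        match *event {
            XmlEvent::Start("entry") => self.in_entry = true,
            // Only track <title> when it appears inside an <entry>
            XmlEvent::Start("title") if self.in_entry => self.in_title = true,
            XmlEvent::Text(text) if self.in_title => self.titles.push(text.to_string()),
            XmlEvent::End("title") => self.in_title = false,
            XmlEvent::End("entry") => self.in_entry = false,
            _ => {}
        }
    }
}

/// Feed every event through the collector and return the titles it found.
fn collect_titles(events: &[XmlEvent]) -> Vec<String> {
    let mut collector = TitleCollector::default();
    for event in events {
        collector.handle(event);
    }
    collector.titles
}
```

The real parser follows the same shape: one method per event kind, each of which only reads or updates the fields on the struct.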

Now we can move onto writing our web service!

Writing our Rust web service

In terms of application code for the web service itself, there are a few things:

  • Setting up application state
  • Setting up endpoints
  • Serving a frontend (of our choice)

Setting up application state

We can define the application state as a struct that holds our OpenAI client. Note that it needs to derive Clone: Axum hands each request handler its own copy of the state, so the state type must implement Clone.

use rig::providers::openai;

#[derive(Clone)]
struct AppState {
    openai_client: openai::Client,
}

Error handling

Before we implement our endpoint, let's again take a minute to consider error handling. We want error handling to be as idiomatic as possible, and we use anyhow::Error for most of our errors as it's easy to use. Therefore, we wrap it in our own AppError type and implement From<E> for any E: Into<anyhow::Error>, so that the ? operator works with most error types. We also want to implement axum::response::IntoResponse for our error type, as this allows it to be used in the return type of an Axum handler function.

use axum::response::{IntoResponse, Response};

struct AppError(anyhow::Error);

impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        (
            axum::http::StatusCode::INTERNAL_SERVER_ERROR,
            format!("Something went wrong: {}", self.0),
        )
            .into_response()
    }
}

impl<E> From<E> for AppError
where
    E: Into<anyhow::Error>,
{
    fn from(err: E) -> Self {
        Self(err.into())
    }
}
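To show why the blanket From implementation is enough for ? to work on unrelated error types, here is a small sketch that uses only the standard library. It swaps anyhow::Error for Box<dyn Error>, and the names DemoAppError and decode_and_parse are made up for this illustration, so treat it as a model of the pattern rather than the code used in the service.

```rust
use std::error::Error;

/// Stand-in for the tutorial's AppError, wrapping a boxed error
/// instead of `anyhow::Error`.
#[derive(Debug)]
struct DemoAppError(Box<dyn Error>);

// The wrapper deliberately does NOT implement `Error` itself. If it did,
// this blanket impl would overlap with the standard `impl<T> From<T> for T`.
impl<E: Error + 'static> From<E> for DemoAppError {
    fn from(err: E) -> Self {
        DemoAppError(Box::new(err))
    }
}

/// Two different error types flow through `?` into the same wrapper.
fn decode_and_parse(bytes: &[u8]) -> Result<u32, DemoAppError> {
    let text = std::str::from_utf8(bytes)?; // Utf8Error -> DemoAppError
    let number = text.trim().parse::<u32>()?; // ParseIntError -> DemoAppError
    Ok(number)
}
```

The trade-off is that the concrete error type is erased, which is fine for a handler that only needs to turn the error into a 500 response.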

Endpoints

There is only one endpoint we need to write - the handler that receives a search query and carries out the required work. The code snippet below is relatively simple: we grab the query, create an AI agent that uses GPT-4 with a preamble (extra context) and our arXiv tool, and then send the original query as the prompt. We then deserialize the resulting string and return an HTML response.

use axum::{extract::State, Json};
use rig::completion::Prompt;
use rig::providers::openai::GPT_4;
use serde::Deserialize;

// here we create a struct that we can use as the required request body shape

#[derive(Deserialize)]
struct SearchRequest {
    query: String,
}

async fn search_papers(
    State(state): State<AppState>,
    Json(request): Json<SearchRequest>,
) -> Result<impl IntoResponse, AppError> {
    let paper_agent = state.openai_client
        .agent(GPT_4)
        .preamble(
            "You are a helpful research assistant that can search and analyze academic papers from arXiv. \
             When asked about a research topic, use the search_arxiv tool to find relevant papers and \
             return only the raw JSON response from the tool."
        )
        .tool(ArxivSearchTool)
        .build();

    let response = paper_agent
        .prompt(&request.query)
        .await?;

    // return the response as HTML
    // note that if you want to return just a JSON response
    // you can return `Ok(axum::Json(papers))`
    let papers: Vec<Paper> = serde_json::from_str(&response)?;

    let html = tools::format_papers_as_html(&papers)?; // see below!
    Ok(Html(html))
}

Frontend

While we won't dive into writing the frontend in this article, you can have a look at the HTML files yourself and copy them in (we use a folder called static, which is in the project root).

You'll need to create a handler to serve it:

use axum::response::Html;

// Handler for serving the static index.html
async fn serve_index() -> impl IntoResponse {
    Html(include_str!("../static/index.html"))
}

We also return the table of papers (from our API endpoint) as HTML. The following function combines an HTML template with the papers returned by our AI agent to generate the result:

// HTML formatting function for papers
// Returns anyhow::Result so that `?` in the handler can convert the error into AppError
pub fn format_papers_as_html(papers: &[Paper]) -> anyhow::Result<String> {
    let tpl = std::fs::read_to_string("static/table.html")?;
    let mut context = tera::Context::new();
    context.insert("papers", papers);

    let result = tera::Tera::one_off(&tpl, &context, true)?;

    Ok(result)
}
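The true passed to one_off turns on Tera's autoescaping, which matters here because paper titles and abstracts come from an external source. As a rough illustration of what escaping protects against, here is a hand-rolled, standard-library-only sketch. PaperRow, escape_html and render_rows are hypothetical names, and this is not how Tera is implemented.

```rust
/// Minimal stand-in for the tutorial's Paper struct (two fields only).
struct PaperRow {
    title: String,
    url: String,
}

/// Replace the characters that have special meaning in HTML.
fn escape_html(input: &str) -> String {
    let mut out = String::with_capacity(input.len());
    for ch in input.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#x27;"),
            _ => out.push(ch),
        }
    }
    out
}

/// Render one table row per paper, escaping every value that is interpolated.
fn render_rows(papers: &[PaperRow]) -> String {
    papers
        .iter()
        .map(|paper| {
            format!(
                "<tr><td><a href=\"{}\">{}</a></td></tr>",
                escape_html(&paper.url),
                escape_html(&paper.title)
            )
        })
        .collect()
}
```

With a template engine you get this for free, which is one reason to prefer it over building HTML strings by hand.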

Putting it all together

Next, we'll set up our main function. We'll need to add the Secrets annotation (see below) to our function parameters, which reads OPENAI_API_KEY from a Secrets.toml file in the project root. We'll also add a CorsLayer and our routes.

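use anyhow::Context;
use axum::{
    routing::{get, post},
    Router,
};
use shuttle_runtime::SecretStore;
use tower_http::cors::{Any, CorsLayer};

#[shuttle_runtime::main]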
async fn axum(
    // this annotation provides our secrets from Secrets.toml
    #[shuttle_runtime::Secrets] secrets: SecretStore,
) -> shuttle_axum::ShuttleAxum {
    // Initialize OpenAI client from secrets
    let openai_key = secrets
        .get("OPENAI_API_KEY")
        .context("OPENAI_API_KEY secret not found")?;

    let openai_client = openai::Client::new(&openai_key);

    // Create shared state
    let state = AppState { openai_client };

    // Set up CORS
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods([axum::http::Method::GET, axum::http::Method::POST])
        .allow_headers(Any);

    // Create router
    let router = Router::new()
        .route("/", get(serve_index))
        .route("/api/search", post(search_papers))
        .layer(cors)
        .with_state(state);

    Ok(router.into())
}

Deploying

If you've added any frontend assets, make sure you add them to a Shuttle.toml file in your root folder (this will allow the frontend templates folder to be included in deployment):

[build]
assets = ["<folder-name>/*"]

If you've followed this tutorial from start to finish, the folder name should be static. We add /* at the end to signify that we want to include the whole directory and all files inside it.

To deploy, all you need to do is run shuttle deploy in the terminal and watch the magic happen! If you'd like to run this locally, simply run shuttle run and then visit localhost:8000.

Finishing up

Thanks for reading! Hopefully this article has increased your understanding of how to implement your own AI agent using the Rig AI framework, as well as writing a web service with a frontend.
