January 9, 2024

Taking Candle for a spin by 'Building GPT From Scratch'

Learning a new framework can be done by building something new or by porting an existing piece of code. The advantage of porting is that you can assess up front how much of the new framework you are likely to explore. This blog post details the experience of porting Andrej Karpathy's Let's Build GPT tutorial to Rust using the Candle tensor library. The code of the port itself is on GitHub.

This post will not explore the basics of transformer models and (self-)attention. For those, please refer to this excellent explainer by Josh Starmer, or Karpathy's video linked above.

Because Rust is stricter than Python/PyTorch and Candle is still a young framework, there were a few hurdles to overcome, so let's dive in!

Weights and Training

Training works by backpropagating the error gradient and correcting the weights of the network accordingly. One difference between PyTorch and Candle is that Candle requires the developer to keep track of all trainable weights in a data structure called a VarMap. (This does not apply when loading pretrained weights; in that case the VarBuilder uses a different backend.)

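use candle_core::DType;
use candle_nn::{embedding, VarBuilder, VarMap};
let var_map = VarMap::new();
let var_builder = VarBuilder::from_varmap(&var_map, DType::F32, device);
// `vocab_size` and `embedding_size` are hypothetical names; the prefix string is illustrative
let token_embedding_table = embedding(vocab_size, embedding_size, var_builder.push_prefix("token_embedding_table"))?;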

The snippet above wraps the VarMap inside a VarBuilder, which is what is actually used to wire all the layers together. It does so by setting up a layer hierarchy in which each layer's weight tensor gets a prefix, much like a file system path, with the parts separated by dots. The documentation of the push_prefix() method therefore describes it as "This can be think [sic] of as cd into a directory."
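To make the directory analogy concrete, here is a minimal std-only sketch of how nested prefixes turn into dotted variable names. `ToyPrefix`, its `push_prefix()` and `path()` methods are hypothetical stand-ins written for illustration, not Candle's VarBuilder API; the real VarBuilder also creates and stores the tensors behind those names.

```rust
/// Hypothetical toy that mimics only the naming side of a VarBuilder:
/// a stack of path segments that is joined with dots.
#[derive(Clone, Debug, Default)]
pub struct ToyPrefix {
    parts: Vec<String>,
}

impl ToyPrefix {
    /// Return a new prefix with `segment` appended ("cd segment").
    /// The original prefix is left untouched, so siblings can branch off it.
    pub fn push_prefix(&self, segment: impl ToString) -> ToyPrefix {
        let mut parts = self.parts.clone();
        parts.push(segment.to_string());
        ToyPrefix { parts }
    }

    /// Full dotted name of a variable called `name` under this prefix.
    pub fn path(&self, name: &str) -> String {
        let mut parts = self.parts.clone();
        parts.push(name.to_string());
        parts.join(".")
    }
}
```

Because `push_prefix()` returns a new value instead of mutating, each layer constructor can receive its own prefixed builder while the parent stays at its original level.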

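Because every weight is stored under its full dotted name, names must be unique. With a VarMap backend, requesting a name that already exists does not create a new variable: the builder hands back the existing tensor and only returns an error if the requested shape differs. Parallel layers therefore need different prefixes, otherwise they would silently share weights. Here this is done with the block_index in the name: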

for block_index in 0..num_blocks {
    blocks = blocks.add(Block::new(
        var_builder.push_prefix(format!("block_{}", block_index)),
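The std-only sketch below shows why the block index matters: it lists the variable paths produced by a few parallel blocks and counts how many are distinct. `block_param_paths()` and `distinct_count()` are hypothetical helpers, and the parameter names are made up for illustration.

```rust
use std::collections::HashSet;

/// Hypothetical helper: the variable paths of `num_blocks` parallel blocks that
/// each own the parameters in `params`. With `unique_prefix == false` every block
/// uses the same prefix, mimicking a forgotten block index.
pub fn block_param_paths(num_blocks: usize, params: &[&str], unique_prefix: bool) -> Vec<String> {
    let mut paths = Vec::new();
    for block_index in 0..num_blocks {
        let prefix = if unique_prefix {
            format!("block_{}", block_index)
        } else {
            "block".to_string()
        };
        for param in params {
            paths.push(format!("{}.{}", prefix, param));
        }
    }
    paths
}

/// Number of distinct names; fewer than `paths.len()` means some names collide.
pub fn distinct_count(paths: &[String]) -> usize {
    paths.iter().collect::<HashSet<_>>().len()
}
```

With per-block prefixes, three blocks with two parameters each yield six distinct names; with a shared prefix they collapse to two.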

Once all the layers are set up, training the model with an optimizer of your choice is quite straightforward. Below, AdamW is used, and the train() method itself is an adaptation of this blog post:

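use candle_core::Shape;
use candle_nn::{loss, AdamW, Module, Optimizer, ParamsAdamW};

pub fn train(&self, dataset: &mut Dataset, num_epochs: usize, batch_size: usize) -> Result<()> {
    // AdamW updates every variable that was registered in the VarMap
    let mut optimizer = AdamW::new(self.var_map.all_vars(), ParamsAdamW::default())?;

    for epoch in 0..num_epochs {
        let (training_inputs, training_targets) =
            dataset.random_training_batch(self.block_size, batch_size)?;
        let logits = self.forward(&training_inputs)?;
        // Flatten (B, T, C) logits to (B*T, C) and (B, T) targets to (B*T,)
        let (batch_size, time_size, channel_size) = logits.shape().dims3()?;
        let loss = loss::cross_entropy(
            &logits.reshape(Shape::from((batch_size * time_size, channel_size)))?,
            &training_targets.reshape(Shape::from((batch_size * time_size,)))?,
        )?;
        // Backpropagate the loss and apply one optimizer step
        optimizer.backward_step(&loss)?;

        println!(
            "Epoch: {epoch:3} Train loss: {:8.5}",
            loss.to_scalar::<f32>()?
        );
    }
    Ok(())
}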

Notice that the Module trait defines a forward() method that takes a single Tensor and returns a Tensor, so the loss is computed here in the training loop instead. This is cleaner and removes the need for the if-statement that Karpathy uses in his forward() method to compute the loss only when targets are given.
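Since forward() only returns logits, the training loop flattens them from (B, T, C) to (B*T, C) and the targets from (B, T) to (B*T,), so that each row of logits lines up with exactly one target token. The following std-only sketch of mean cross-entropy over such a flattened buffer illustrates that pairing; the function name is hypothetical and this is not Candle's loss::cross_entropy implementation.

```rust
/// Mean cross-entropy over logits flattened to (B*T, C) (row-major, `num_classes`
/// values per row) and targets flattened to (B*T,). A std-only illustration.
pub fn cross_entropy(logits: &[f32], targets: &[usize], num_classes: usize) -> f32 {
    assert_eq!(logits.len(), targets.len() * num_classes, "shape mismatch");
    let mut total = 0.0f32;
    // Each row of C logits pairs with exactly one target token id.
    for (row, &target) in logits.chunks(num_classes).zip(targets) {
        // log-sum-exp with the row maximum subtracted for numerical stability
        let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let log_sum_exp = max + row.iter().map(|&x| (x - max).exp()).sum::<f32>().ln();
        // negative log-probability of the target class
        total += log_sum_exp - row[target];
    }
    total / targets.len() as f32
}
```

With all-zero logits every class is equally likely and the loss equals ln(C), a handy sanity check for the initial loss of an untrained model.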

Matrix Indexing and Manipulation

In this version of Candle, square brackets are not overloaded for indexing as they are in PyTorch, so anyone expecting DataFrame-esque slicing might be thrown off at first. Indexing is done via the i() method of the IndexOp trait, which (currently) only accepts non-negative integers (so no negative indexing to count from the end), ranges, or tuples of these (for multiple dimensions). This line from BigramLanguageModel::generate() shows all three at once:

use candle_core::IndexOp; // provides the i() method
// focus only on the last time step
let most_recent_logits = logits.i((0, generated_ids_cond_length - 1, ..))?; // becomes (C,): first batch entry, last time step
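Since negative indices are not available, the last time step is addressed as generated_ids_cond_length - 1. For a contiguous, row-major (B, T, C) tensor, selecting i((batch, time, ..)) amounts to the offset arithmetic in this std-only sketch; `last_dim_slice()` is a hypothetical helper, and real tensors may be strided rather than contiguous.

```rust
/// Hypothetical helper: the C values that `i((batch, time, ..))` would select,
/// assuming a contiguous, row-major (B, T, C) buffer.
pub fn last_dim_slice(data: &[f32], dims: (usize, usize, usize), batch: usize, time: usize) -> &[f32] {
    let (b, t, c) = dims;
    assert!(batch < b && time < t, "index out of range");
    // Row-major strides for (B, T, C) are (T*C, C, 1).
    let start = batch * t * c + time * c;
    &data[start..start + c]
}
```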

It is also important to realize that broadcasting has to be done explicitly via the broadcast_*() family of methods (such as broadcast_add() or broadcast_as()). Below is an example from the self_attention_examples module that shows a broadcast operation, the creation of a lower-triangular matrix (tril2()) and an approach to masking/logical indexing:

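// Created on the CPU: move to `device` and broadcast to (t, t) for where_cond()
let neg_inf = Tensor::try_from(f32::NEG_INFINITY)?
    .to_device(device)?
    .broadcast_as((t, t))?;

// Ones on and below the diagonal keep `weights`; zeros above it take -inf
let masked_fill = Tensor::tril2(t, DType::U32, device)?
    .where_cond(&weights, &neg_inf)?;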

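Notice the to_device() call: a tensor created from a single scalar lives on the CPU, so it must be moved to the same device as the tensors it is combined with. It must also be broadcast explicitly, since where_cond() expects the mask and both value tensors to have the same shape.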
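The masking step can be illustrated without Candle. The sketch below applies the same lower-triangular mask to a t x t matrix of attention scores stored as a flat row-major slice, fills the upper triangle with negative infinity, and then takes a row-wise softmax, so each position only attends to itself and earlier positions. `causal_softmax()` is a hypothetical helper written with the standard library only.

```rust
/// Causal masking followed by a row-wise softmax on a t x t row-major matrix.
/// Entries above the diagonal become -inf (like tril2 + where_cond) and thus
/// receive exactly zero probability. A std-only illustration, not Candle code.
pub fn causal_softmax(weights: &[f32], t: usize) -> Vec<f32> {
    assert_eq!(weights.len(), t * t);
    let mut out = vec![0.0f32; t * t];
    for row in 0..t {
        // keep weights where col <= row (lower triangle), otherwise -inf
        let masked: Vec<f32> = (0..t)
            .map(|col| if col <= row { weights[row * t + col] } else { f32::NEG_INFINITY })
            .collect();
        // stable softmax; column 0 is never masked, so `max` is finite
        let max = masked.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = masked.iter().map(|&x| (x - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        for (col, e) in exps.iter().enumerate() {
            out[row * t + col] = e / sum;
        }
    }
    out
}
```

With all-zero scores, row i spreads its attention uniformly over positions 0 through i, which matches the averaging example in Karpathy's video.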


To enable CUDA, or cuDNN for even faster kernels, add it as a feature flag in Cargo.toml:

# use features = ["cuda", "cudnn"] to enable cuDNN as well
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.2", features = ["cuda"] }

When moving a Candle program from the CPU to the GPU, you might suddenly encounter the following runtime error:

"matmul is only supported for contiguous tensors lstride: [1, 16, 512] rstride: [0, 1, 64] mnk: (32, 64, 64)"

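Luckily, the error message already points to the solution. A contiguous (row-major) tensor has strides that decrease from left to right; in the message, a stride of 0 marks a broadcast dimension and out-of-order strides such as [1, 16, 512] indicate a transposed or permuted view. Tensors have a contiguous() method that returns a contiguous version of the tensor (i.e. all elements packed in memory without gaps, in row-major order), copying the data only when necessary, and an is_contiguous() method to check first. See it in action in the MultiHeadAttention struct in bigram_language_model.rs, including a Family Guy reference: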

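let concatenated = Tensor::cat(
    // ... (iterator over the attention heads omitted)
        .map(|h| {
            // ... (applying head `h` to the input omitted)
                .map_err(|error| eprintln!("Error creating the model: {}", error))
                .expect("Could not apply head. Diggity")
        })
        // ... (collecting the outputs and the concatenation dimension omitted)
)?.contiguous()?; // JV: This was necessary for CUDA to fix the stride/contiguous error. It doesn't occur on the CPU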

This iterates over all the heads and concatenates their results.
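The stride bookkeeping behind that error message can be sketched in a few lines of plain Rust. This is a simplified model of the idea rather than Candle's layout code: the helper names are hypothetical, and real libraries may treat size-1 dimensions more leniently.

```rust
/// Row-major ("contiguous") strides for a shape, e.g. [2, 3, 4] -> [12, 4, 1].
pub fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

/// Simplified check: contiguous when the strides equal the row-major strides.
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    contiguous_strides(shape) == strides
}

/// Copy a strided 2D view into a fresh row-major buffer, which is the work a
/// contiguous() call has to do when the layout is not already contiguous.
pub fn make_contiguous_2d(data: &[f32], shape: [usize; 2], strides: [usize; 2]) -> Vec<f32> {
    let mut out = Vec::with_capacity(shape[0] * shape[1]);
    for r in 0..shape[0] {
        for c in 0..shape[1] {
            out.push(data[r * strides[0] + c * strides[1]]);
        }
    }
    out
}
```

A transpose only swaps strides over the same buffer, which is cheap, but a kernel that requires row-major input then needs the explicit copy.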

On-device Sampling

One thing that was still notably missing was an on-device multinomial sampler. The sampler.rs module, taken from one of the Candle examples, first copies the tensor of probabilities to CPU memory and samples from it there. This is a bottleneck in the BigramLanguageModel::generate() method, because every generated token requires a round trip between the GPU and the CPU. It would be faster to sample the tokens on the GPU and move data to the CPU only to construct the final string.
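Conceptually, drawing the next token is a single draw from a categorical distribution, which can be done by walking the cumulative sum of the probabilities until it exceeds a uniform random number. The std-only sketch below takes that random number as a parameter so it needs no RNG crate; `sample_from_probs()` is a hypothetical helper, not the code in sampler.rs. An on-device sampler could express the same cumulative-sum-and-compare step as tensor operations on the GPU.

```rust
/// Draw one token id from a categorical distribution (a multinomial draw with a
/// single trial) by inverting the CDF. `u` must be uniform in [0, 1); passing it
/// in keeps this sketch deterministic and free of RNG crates.
pub fn sample_from_probs(probs: &[f32], u: f32) -> usize {
    assert!(!probs.is_empty(), "need at least one probability");
    // Scaling by the total also makes unnormalized weights work.
    let threshold = u * probs.iter().sum::<f32>();
    let mut cumulative = 0.0f32;
    for (token_id, &p) in probs.iter().enumerate() {
        cumulative += p;
        if threshold < cumulative {
            return token_id;
        }
    }
    // Round-off can leave `threshold` just above the final sum: fall back to
    // the last token that has a non-zero probability.
    probs.iter().rposition(|&p| p > 0.0).unwrap_or(probs.len() - 1)
}
```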

In Conclusion

Porting the tutorial to Rust and Candle was great fun, and it made the code more readable. Candle is still very young, which is especially noticeable in the API documentation and in missing features (e.g. negative indexing and DataFrame-esque indexing) compared to a mature framework like PyTorch, but these are expected to arrive in future releases. Candle is one to watch!

Feel free to reach out or leave a comment if you want to know more after reading the code!