Throughout this series, we’ve built tensors with autodiff, layers and loss functions, optimizers that learn, and backends that run efficiently on different hardware. We have all the ingredients. Now the question becomes: what do we actually build with them?

This post explores neural network architectures — from simple feedforward networks to the attention-based transformers that power modern AI. We’ll focus on building intuition first, then see how these architectures map to the components we’ve already built.

The Architecture Landscape

Neural network “architecture” is just a fancy word for how we wire together our building blocks. The same Linear layers, activations, and loss functions can be arranged into dramatically different structures:

graph LR
    subgraph "Feedforward"
        ff_in["Input"] --> ff_h["Hidden"] --> ff_out["Output"]
    end

    subgraph "Recurrent"
        rnn_in["Input<sub>t</sub>"] --> rnn_h["Hidden"] --> rnn_out["Output<sub>t</sub>"]
        rnn_h --> |"state"| rnn_h
    end

    subgraph "Transformer"
        tf_in["Input"] --> tf_attn["Attention"] --> tf_ff["FFN"] --> tf_out["Output"]
        tf_in -.-> |"attends to"| tf_attn
    end

    classDef input fill:none,stroke:#60a5fa,stroke-width:2px
    classDef hidden fill:none,stroke:#a78bfa,stroke-width:2px
    classDef output fill:none,stroke:#34d399,stroke-width:2px
    class ff_in,rnn_in,tf_in input
    class ff_h,rnn_h,tf_attn,tf_ff hidden
    class ff_out,rnn_out,tf_out output

Each architecture makes different tradeoffs. Feedforward networks are simple but can’t handle sequences. Recurrent networks process sequences but struggle with long-range dependencies. Transformers handle both — at the cost of more computation.

New Building Blocks

Before diving into architectures, we need a few primitives beyond Part 2’s Linear layers and activations.

Layer Normalization

Batch normalization normalizes across the batch dimension — problematic for variable-length sequences. Layer normalization normalizes across the feature dimension instead:

\[\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta\]

Where $\mu$ and $\sigma^2$ are computed per-sample across features, and $\gamma$, $\beta$ are learnable.

pub struct LayerNorm<B: Backend> {
    gamma: Tensor<B>,  // [d_model]
    beta: Tensor<B>,   // [d_model]
    eps: f32,
}

impl<B: Backend> LayerNorm<B> {
    pub fn new(d_model: usize) -> Self {
        LayerNorm {
            gamma: Tensor::var("gamma", B::ones(&Shape::new(vec![d_model]))),
            beta: Tensor::var("beta", B::zeros(&Shape::new(vec![d_model]))),
            eps: 1e-5,
        }
    }

    pub fn forward(&self, x: &Tensor<B>) -> Tensor<B> {
        // x: [batch, seq_len, d_model]
        // Normalize over last dimension (d_model)
        let mean = x.mean(Some(&[x.ndim() - 1]), true);      // [batch, seq, 1]
        let var = x.var(Some(&[x.ndim() - 1]), true);        // [batch, seq, 1]
        let normalized = (x - &mean) / ((&var + self.eps).sqrt());
        &normalized * &self.gamma + &self.beta
    }
}
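As a sanity check on the math, independent of our `Tensor` and `Backend` types, here is the same normalization on a plain `Vec<f32>` with $\gamma = 1$ and $\beta = 0$. The output should come out with zero mean and unit variance:

```rust
/// Layer norm over a single feature vector (gamma = 1, beta = 0).
pub fn layer_norm_1d(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    // Biased variance, matching the usual LayerNorm definition
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    x.iter().map(|v| (v - mean) / (var + eps).sqrt()).collect()
}
```

Running this on `[1.0, 2.0, 3.0, 4.0]` gives a vector whose mean is 0 and whose variance is 1 (up to the `eps` term), exactly what the framework version computes per sample over the feature dimension.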

Softmax (Numerically Stable)

Softmax converts logits to probabilities. The naive implementation overflows; we subtract the max first:

pub fn softmax<B: Backend>(x: &Tensor<B>, axis: isize) -> Tensor<B> {
    // x: [batch, seq_len, d_model] or any shape
    // Resolve negative axes (e.g. -1 means the last dimension)
    let axis = if axis < 0 { (x.ndim() as isize + axis) as usize } else { axis as usize };
    // Subtract max for numerical stability
    let max_x = x.max(Some(&[axis]), true);  // keepdims=true
    let shifted = x - &max_x;
    let exp_x = shifted.exp();
    let sum_exp = exp_x.sum(Some(&[axis]), true);
    &exp_x / &sum_exp
}
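To see why the max subtraction matters, here is a standalone sketch on plain slices, with the naive version alongside for comparison:

```rust
/// Numerically stable softmax over a slice.
pub fn softmax_1d(x: &[f32]) -> Vec<f32> {
    let max = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = x.iter().map(|v| (v - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Naive softmax, for comparison: overflows for large logits.
pub fn softmax_naive(x: &[f32]) -> Vec<f32> {
    let exps: Vec<f32> = x.iter().map(|v| v.exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```

On logits like `[1000.0, 1000.0]`, `exp(1000.0)` overflows to infinity, so the naive version produces `inf / inf = NaN`; the stable version returns the correct `[0.5, 0.5]` because it only ever exponentiates non-positive numbers.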

Dropout

Dropout randomly zeros elements during training, scaling survivors to maintain expected values:

pub fn dropout<B: Backend>(x: &Tensor<B>, p: f32, training: bool) -> Tensor<B> {
    if !training || p == 0.0 {
        return x.clone();
    }
    // Generate random mask: 1 with probability (1-p), 0 with probability p
    let mask = B::random_uniform(x.shape()).gt(&B::scalar(p));
    let scale = 1.0 / (1.0 - p);
    x * &mask * scale
}
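The `1 / (1 - p)` scaling is what keeps the expected value unchanged. A quick standalone check, using a tiny hand-rolled LCG since `std` has no random number generator (the LCG is for illustration only, not for real use):

```rust
/// Minimal LCG so the example needs no external crates (illustration only).
pub struct Lcg(u64);

impl Lcg {
    /// Uniform-ish sample in [0, 1) from the high bits of the state.
    pub fn next_f32(&mut self) -> f32 {
        self.0 = self
            .0
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((self.0 >> 33) as f32) / (1u64 << 31) as f32
    }
}

/// Inverted dropout: zero with probability p, scale survivors by 1/(1-p).
pub fn dropout_1d(x: &[f32], p: f32, rng: &mut Lcg) -> Vec<f32> {
    let scale = 1.0 / (1.0 - p);
    x.iter()
        .map(|v| if rng.next_f32() < p { 0.0 } else { v * scale })
        .collect()
}
```

Averaged over many elements, the output mean stays close to the input mean even though roughly `p` of the entries are zeroed; without the scaling, the mean would shrink by a factor of `1 - p`.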

Feedforward Networks (Quick Review)

The simplest architecture: stack Linear layers with activations.

// From Part 2 - unchanged
pub struct MLP<B: Backend> {
    layers: Vec<Linear<B>>,
}

impl<B: Backend> MLP<B> {
    pub fn forward(&self, x: &Tensor<B>) -> Tensor<B> {
        let mut h = x.clone();
        for (i, layer) in self.layers.iter().enumerate() {
            h = layer.forward(&h);
            if i < self.layers.len() - 1 {
                h = h.relu();  // Activation between layers, not after last
            }
        }
        h
    }
}

Feedforward networks are universal function approximators — given enough hidden units, they can approximate any continuous function. But they have fixed input/output sizes. A 784-input MLP for MNIST can’t process a 1024-pixel image.

For variable-length sequences like text, we need something else.

Recurrent Neural Networks (Historical Context)

RNNs introduced the idea of hidden state that persists across time steps:

graph LR
    subgraph "Unrolled RNN"
        x0["x₀"] --> h0["h₀"]
        x1["x₁"] --> h1["h₁"]
        x2["x₂"] --> h2["h₂"]
        x3["x₃"] --> h3["h₃"]

        h0 --> |"W<sub>hh</sub>"| h1
        h1 --> |"W<sub>hh</sub>"| h2
        h2 --> |"W<sub>hh</sub>"| h3

        h0 --> y0["y₀"]
        h1 --> y1["y₁"]
        h2 --> y2["y₂"]
        h3 --> y3["y₃"]
    end

    classDef input fill:none,stroke:#60a5fa,stroke-width:2px
    classDef hidden fill:none,stroke:#a78bfa,stroke-width:2px
    classDef output fill:none,stroke:#34d399,stroke-width:2px
    class x0,x1,x2,x3 input
    class h0,h1,h2,h3 hidden
    class y0,y1,y2,y3 output

At each time step:

\[h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h)\]

The same weights $W_{xh}$ and $W_{hh}$ are reused at every step — this is called weight sharing or parameter tying.

pub struct RNNCell<B: Backend> {
    w_xh: Tensor<B>,  // [d_input, d_hidden]
    w_hh: Tensor<B>,  // [d_hidden, d_hidden]
    b_h: Tensor<B>,   // [d_hidden]
}

impl<B: Backend> RNNCell<B> {
    pub fn forward(&self, x: &Tensor<B>, h_prev: &Tensor<B>) -> Tensor<B> {
        // x: [batch, d_input], h_prev: [batch, d_hidden]
        let xh = x.matmul(&self.w_xh);           // [batch, d_hidden]
        let hh = h_prev.matmul(&self.w_hh);      // [batch, d_hidden]
        (&xh + &hh + &self.b_h).tanh()           // [batch, d_hidden]
    }
}

The Vanishing Gradient Problem

RNNs have a fatal flaw. During backpropagation through time, gradients flow through the same $W_{hh}$ repeatedly:

\[\frac{\partial h_t}{\partial h_0} = \prod_{i=1}^{t} \frac{\partial h_i}{\partial h_{i-1}} = \prod_{i=1}^{t} W_{hh}^T \cdot \text{diag}(\tanh'(z_i))\]

If the largest eigenvalue of $W_{hh}$ is less than 1, gradients shrink exponentially. If greater than 1, they explode. Either way, learning long-range dependencies becomes nearly impossible.
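The scalar version of this argument is easy to verify. Standing in for the largest eigenvalue with a single number $\lambda$, the gradient through $t$ steps scales like $\lambda^t$:

```rust
/// Product of t copies of a scalar "Jacobian" lambda: lambda^t.
pub fn gradient_scale(lambda: f32, t: u32) -> f32 {
    lambda.powi(t as i32)
}
```

Over 100 steps, $0.9^{100} \approx 2.7 \times 10^{-5}$ and $1.1^{100} \approx 1.4 \times 10^4$: even eigenvalues close to 1 vanish or explode quickly, which is why only $\lambda \approx 1$ (what LSTM gating and, later, residual connections approximate) lets gradients survive long horizons.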

LSTM and GRU cells mitigate this with gating mechanisms — learned gates that control information flow. But they’re complex, and a simpler solution emerged.

The Attention Mechanism

The key insight: instead of compressing all past information into a fixed-size hidden state, why not let the model look back at all previous positions directly?

Intuition: A Lookup Table

Think of attention as a soft database query:

  1. You have a query (what you’re looking for)
  2. You have keys (labels for stored items)
  3. You have values (the actual stored data)

Regular lookup: find the key that exactly matches, return its value.

Attention: compute similarity between query and all keys, return a weighted average of values.

Scaled Dot-Product Attention

The most common attention variant computes similarity using dot products:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

Let’s break this down step by step.

Given:

  • Queries $Q$: [batch, seq_q, d_k] — what we’re looking for
  • Keys $K$: [batch, seq_k, d_k] — what we’re matching against
  • Values $V$: [batch, seq_k, d_v] — what we retrieve

The computation flows:

graph LR
    Q["Q<br/>[seq_q, d_k]"] --> qk["Q @ Kᵀ"]
    K["K<br/>[seq_k, d_k]"] --> qk
    qk --> scores["scores<br/>[seq_q, seq_k]"]
    scores --> scale["÷ √d_k"]
    scale --> sm["softmax"]
    sm --> weights["weights<br/>[seq_q, seq_k]"]
    weights --> wv["weights @ V"]
    V["V<br/>[seq_k, d_v]"] --> wv
    wv --> out["output<br/>[seq_q, d_v]"]

    classDef input fill:none,stroke:#60a5fa,stroke-width:2px
    classDef op fill:none,stroke:#a78bfa,stroke-width:2px
    classDef output fill:none,stroke:#34d399,stroke-width:2px
    class Q,K,V input
    class qk,scale,sm,wv op
    class scores,weights,out output

Why scale by $\sqrt{d_k}$? Dot products grow with dimension. If $q$ and $k$ are vectors with independent components of variance 1, their dot product has variance $d_k$. Large values push softmax into saturation where gradients vanish. Scaling by $\sqrt{d_k}$ keeps the variance at 1.
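A concrete instance of the saturation problem: with $d_k = 64$ and $q = k$ both all-ones vectors, the raw dot product is 64. Softmax over the scores $[64, 0]$ leaves essentially zero weight (and zero gradient) on the second entry, while the scaled score $64 / \sqrt{64} = 8$ keeps the tail alive:

```rust
/// Stable two-way softmax, returning (weight_a, weight_b).
pub fn softmax2(a: f32, b: f32) -> (f32, f32) {
    let m = a.max(b);
    let (ea, eb) = ((a - m).exp(), (b - m).exp());
    let s = ea + eb;
    (ea / s, eb / s)
}
```

Unscaled, the smaller weight is about $e^{-64} \approx 10^{-28}$; scaled, it is about $e^{-8} \approx 3 \times 10^{-4}$, small but large enough to carry gradient signal.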

pub fn scaled_dot_product_attention<B: Backend>(
    q: &Tensor<B>,  // [batch, seq_q, d_k]
    k: &Tensor<B>,  // [batch, seq_k, d_k]
    v: &Tensor<B>,  // [batch, seq_k, d_v]
    mask: Option<&Tensor<B>>,  // [batch, seq_q, seq_k] or broadcastable
) -> Tensor<B> {
    let d_k = q.shape().dim(q.ndim() - 1) as f32;

    // scores = Q @ K^T / sqrt(d_k)
    // [batch, seq_q, d_k] @ [batch, d_k, seq_k] -> [batch, seq_q, seq_k]
    let scores = q.matmul(&k.transpose(-2, -1)) / d_k.sqrt();

    // Apply mask (for causal attention or padding)
    let scores = match mask {
        Some(m) => {
            // Where mask is 0, set scores to -inf (will become 0 after softmax)
            let neg_inf = B::scalar(f32::NEG_INFINITY);
            scores.where_cond(m, &neg_inf)
        }
        None => scores,
    };

    // Attention weights: softmax over keys dimension
    let weights = softmax(&scores, -1);  // [batch, seq_q, seq_k]

    // Output: weighted sum of values
    // [batch, seq_q, seq_k] @ [batch, seq_k, d_v] -> [batch, seq_q, d_v]
    weights.matmul(v)
}
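Stripped of batching and masking, the whole mechanism fits in a few lines over plain slices. Here is a minimal single-query sketch of the same computation:

```rust
/// One query attending over a set of key/value vectors (no batch, no mask).
pub fn attend(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    let d_k = q.len() as f32;
    // scores_i = q . k_i / sqrt(d_k)
    let scores: Vec<f32> = keys
        .iter()
        .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() / d_k.sqrt())
        .collect();
    // Softmax over the keys
    let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - m).exp()).collect();
    let z: f32 = exps.iter().sum();
    // Output is the weighted average of the values
    let mut out = vec![0.0; values[0].len()];
    for (w, v) in exps.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(v) {
            *o += (w / z) * x;
        }
    }
    out
}
```

A query strongly aligned with one key nearly performs a hard lookup of that key's value; a zero query matches all keys equally and returns the plain average of the values, which is the "soft database" intuition in miniature.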

Visualizing Attention Weights

The attention weight matrix shows which positions attend to which. In unmasked self-attention every position can attend to every other position; in causal (masked) attention the matrix is lower triangular, so each position's weights are zero for all later positions.

Multi-Head Attention

One attention head learns one type of relationship. Multi-head attention runs several attention heads in parallel, each with its own learned projections:

graph TB
    subgraph "Multi-Head Attention"
        input["Input<br/>[seq, d_model]"]

        subgraph "Head 1"
            q1["Q₁"] --> attn1["Attention"]
            k1["K₁"] --> attn1
            v1["V₁"] --> attn1
        end

        subgraph "Head 2"
            q2["Q₂"] --> attn2["Attention"]
            k2["K₂"] --> attn2
            v2["V₂"] --> attn2
        end

        subgraph "Head h"
            qh["Qₕ"] --> attnh["Attention"]
            kh["Kₕ"] --> attnh
            vh["Vₕ"] --> attnh
        end

        input --> q1 & k1 & v1
        input --> q2 & k2 & v2
        input --> qh & kh & vh

        attn1 --> concat["Concat"]
        attn2 --> concat
        attnh --> concat

        concat --> proj["W<sup>O</sup>"]
        proj --> output["Output<br/>[seq, d_model]"]
    end

    classDef input fill:none,stroke:#60a5fa,stroke-width:2px
    classDef proj fill:none,stroke:#a78bfa,stroke-width:1px
    classDef attn fill:none,stroke:#f472b6,stroke-width:2px
    classDef output fill:none,stroke:#34d399,stroke-width:2px
    class input input
    class q1,k1,v1,q2,k2,v2,qh,kh,vh proj
    class attn1,attn2,attnh attn
    class concat,proj,output output

Each head operates on a slice of dimension $d_k = d_{model} / h$. This doesn’t increase parameter count versus a single head of the same total dimension.
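The parameter-count claim is just arithmetic: $h$ heads, each with its own three $[d_{model}, d_k]$ projections, give $h \cdot 3 \cdot d_{model} \cdot d_k = 3 \cdot d_{model}^2$ parameters regardless of $h$:

```rust
/// Q/K/V projection parameters for h heads of width d_k = d_model / h.
pub fn qkv_params(d_model: usize, num_heads: usize) -> usize {
    let d_k = d_model / num_heads;
    // Each head has its own [d_model, d_k] projection for Q, K, and V.
    num_heads * 3 * d_model * d_k
}
```

So splitting 512 dimensions into 8 heads of width 64 costs exactly as many parameters as one head of width 512; the heads only partition the dimensions differently.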

pub struct MultiHeadAttention<B: Backend> {
    num_heads: usize,
    d_k: usize,       // d_model / num_heads
    w_q: Tensor<B>,   // [d_model, d_model]
    w_k: Tensor<B>,   // [d_model, d_model]
    w_v: Tensor<B>,   // [d_model, d_model]
    w_o: Tensor<B>,   // [d_model, d_model]
}

impl<B: Backend> MultiHeadAttention<B> {
    pub fn forward(
        &self,
        q: &Tensor<B>,  // [batch, seq_q, d_model]
        k: &Tensor<B>,  // [batch, seq_k, d_model]
        v: &Tensor<B>,  // [batch, seq_k, d_model]
        mask: Option<&Tensor<B>>,
    ) -> Tensor<B> {
        let batch = q.shape().dim(0);
        let seq_q = q.shape().dim(1);
        let seq_k = k.shape().dim(1);

        // Project to Q, K, V
        let q = q.matmul(&self.w_q);  // [batch, seq_q, d_model]
        let k = k.matmul(&self.w_k);  // [batch, seq_k, d_model]
        let v = v.matmul(&self.w_v);  // [batch, seq_k, d_model]

        // Reshape to [batch, num_heads, seq, d_k]
        let q = q.reshape(&[batch, seq_q, self.num_heads, self.d_k])
                 .transpose(1, 2);  // [batch, num_heads, seq_q, d_k]
        let k = k.reshape(&[batch, seq_k, self.num_heads, self.d_k])
                 .transpose(1, 2);  // [batch, num_heads, seq_k, d_k]
        let v = v.reshape(&[batch, seq_k, self.num_heads, self.d_k])
                 .transpose(1, 2);  // [batch, num_heads, seq_k, d_k]

        // Attention per head (batched)
        let attn_out = scaled_dot_product_attention(&q, &k, &v, mask);
        // [batch, num_heads, seq_q, d_k]

        // Concat heads: reshape back to [batch, seq_q, d_model]
        let concat = attn_out.transpose(1, 2)
                            .reshape(&[batch, seq_q, self.num_heads * self.d_k]);

        // Final projection
        concat.matmul(&self.w_o)  // [batch, seq_q, d_model]
    }
}

Why multiple heads? Different heads can learn different types of relationships — one might focus on syntax, another on semantics, another on positional patterns. The model decides which heads to use for each task.

The Transformer Architecture

Now we can assemble the full transformer. The original architecture has both encoder and decoder stacks:

graph TB
    subgraph "Encoder Stack (N×)"
        enc_in["Input Embeddings<br/>+ Positional Encoding"]
        enc_attn["Multi-Head<br/>Self-Attention"]
        enc_add1["Add & Norm"]
        enc_ff["Feed Forward"]
        enc_add2["Add & Norm"]
        enc_out["Encoder Output"]

        enc_in --> enc_attn --> enc_add1 --> enc_ff --> enc_add2 --> enc_out
        enc_in -.-> enc_add1
        enc_add1 -.-> enc_add2
    end

    subgraph "Decoder Stack (N×)"
        dec_in["Output Embeddings<br/>+ Positional Encoding"]
        dec_self["Masked Multi-Head<br/>Self-Attention"]
        dec_add1["Add & Norm"]
        dec_cross["Multi-Head<br/>Cross-Attention"]
        dec_add2["Add & Norm"]
        dec_ff["Feed Forward"]
        dec_add3["Add & Norm"]
        dec_out["Decoder Output"]

        dec_in --> dec_self --> dec_add1 --> dec_cross --> dec_add2 --> dec_ff --> dec_add3 --> dec_out
        dec_in -.-> dec_add1
        dec_add1 -.-> dec_add2
        dec_add2 -.-> dec_add3
    end

    enc_out --> |"K, V"| dec_cross

    classDef input fill:none,stroke:#60a5fa,stroke-width:2px
    classDef attn fill:none,stroke:#f472b6,stroke-width:2px
    classDef norm fill:none,stroke:#fbbf24,stroke-width:1px
    classDef ff fill:none,stroke:#a78bfa,stroke-width:2px
    classDef output fill:none,stroke:#34d399,stroke-width:2px

    class enc_in,dec_in input
    class enc_attn,dec_self,dec_cross attn
    class enc_add1,enc_add2,dec_add1,dec_add2,dec_add3 norm
    class enc_ff,dec_ff ff
    class enc_out,dec_out output

The Transformer Block

Each layer (encoder or decoder) follows the same pattern:

  1. Self-attention: each position attends to all positions
  2. Residual connection + LayerNorm: add the input back, normalize
  3. Feed-forward network: two linear layers with activation
  4. Another residual + LayerNorm
pub struct TransformerBlock<B: Backend> {
    self_attn: MultiHeadAttention<B>,
    norm1: LayerNorm<B>,
    ff: FeedForward<B>,
    norm2: LayerNorm<B>,
    dropout_p: f32,
}

impl<B: Backend> TransformerBlock<B> {
    pub fn forward(&self, x: &Tensor<B>, mask: Option<&Tensor<B>>, training: bool) -> Tensor<B> {
        // Self-attention with residual
        let attn_out = self.self_attn.forward(x, x, x, mask);
        let attn_out = dropout(&attn_out, self.dropout_p, training);
        let x = self.norm1.forward(&(x + &attn_out));

        // FFN with residual
        let ff_out = self.ff.forward(&x);
        let ff_out = dropout(&ff_out, self.dropout_p, training);
        self.norm2.forward(&(&x + &ff_out))
    }
}

pub struct FeedForward<B: Backend> {
    linear1: Linear<B>,  // [d_model, d_ff]
    linear2: Linear<B>,  // [d_ff, d_model]
}

impl<B: Backend> FeedForward<B> {
    pub fn forward(&self, x: &Tensor<B>) -> Tensor<B> {
        // x: [batch, seq, d_model]
        // Typical d_ff = 4 * d_model
        self.linear2.forward(&self.linear1.forward(x).gelu())
    }
}

Positional Encoding

Attention is permutation-equivariant: reordering the input tokens simply reorders the output correspondingly, so without position information “cat sat” and “sat cat” carry identical representations. We need to inject position awareness.

The original transformer uses sinusoidal encoding:

\[PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})\]

\[PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})\]

Each position gets a unique pattern of sines and cosines at different frequencies.


pub fn sinusoidal_positional_encoding<B: Backend>(max_len: usize, d_model: usize) -> Tensor<B> {
    let mut pe = vec![0.0; max_len * d_model];

    for pos in 0..max_len {
        for i in 0..d_model {
            let freq = 1.0 / (10000.0_f32).powf((2 * (i / 2)) as f32 / d_model as f32);
            pe[pos * d_model + i] = if i % 2 == 0 {
                (pos as f32 * freq).sin()
            } else {
                (pos as f32 * freq).cos()
            };
        }
    }

    B::from_vec(pe, Shape::new(vec![max_len, d_model]))
}
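The same loop on a plain `Vec<f32>` lets us check a couple of properties: position 0 encodes as `[0, 1, 0, 1, ...]` (since $\sin 0 = 0$ and $\cos 0 = 1$), every value stays in $[-1, 1]$, and distinct positions get distinct patterns:

```rust
/// Sinusoidal positional encoding as a flat [max_len * d_model] vector.
pub fn pos_encoding(max_len: usize, d_model: usize) -> Vec<f32> {
    let mut pe = vec![0.0; max_len * d_model];
    for pos in 0..max_len {
        for i in 0..d_model {
            let freq = 1.0 / 10000.0_f32.powf((2 * (i / 2)) as f32 / d_model as f32);
            pe[pos * d_model + i] = if i % 2 == 0 {
                (pos as f32 * freq).sin()   // even slots: sine
            } else {
                (pos as f32 * freq).cos()   // odd slots: cosine
            };
        }
    }
    pe
}
```

This mirrors the framework function above, minus the final `B::from_vec`.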

Modern models often use learned positional embeddings instead — just another embedding table indexed by position. Both work; learned embeddings are simpler but can’t extrapolate to longer sequences than seen in training.

Encoder-Only Models (BERT-style)

BERT keeps only the encoder stack. All positions can attend to all other positions (bidirectional), which is ideal for understanding tasks:

graph TB
    subgraph "BERT Architecture"
        input["[CLS] The cat sat [SEP]"]
        emb["Token + Position + Segment<br/>Embeddings"]
        enc1["Encoder Block 1"]
        enc2["Encoder Block 2"]
        encN["Encoder Block N"]
        out["Hidden States"]

        input --> emb --> enc1 --> enc2 --> |"..."| encN --> out
    end

    subgraph "Task Heads"
        cls["[CLS] → Classification"]
        tokens["Tokens → Token Classification"]
    end

    out --> cls
    out --> tokens

    classDef input fill:none,stroke:#60a5fa,stroke-width:2px
    classDef encoder fill:none,stroke:#a78bfa,stroke-width:2px
    classDef output fill:none,stroke:#34d399,stroke-width:2px
    class input input
    class emb,enc1,enc2,encN encoder
    class out,cls,tokens output

Pre-training: BERT is trained on two tasks:

  1. Masked Language Modeling (MLM): randomly mask 15% of tokens, predict them
  2. Next Sentence Prediction (NSP): predict if sentence B follows sentence A

Fine-tuning: Add a task-specific head on top of the pre-trained encoder:

  • Classification: use [CLS] token’s representation
  • NER/tagging: use each token’s representation
  • Question answering: predict start/end spans
pub struct BertForClassification<B: Backend> {
    encoder: TransformerEncoder<B>,  // Stack of encoder blocks
    classifier: Linear<B>,            // [d_model, num_classes]
}

impl<B: Backend> BertForClassification<B> {
    pub fn forward(&self, input_ids: &Tensor<B>, attention_mask: Option<&Tensor<B>>) -> Tensor<B> {
        // input_ids: [batch, seq_len]
        let hidden = self.encoder.forward(input_ids, attention_mask);
        // hidden: [batch, seq_len, d_model]

        // Take [CLS] token (position 0)
        let cls_hidden = hidden.slice(1, 0, 1).squeeze(1);  // [batch, d_model]

        // Classify
        self.classifier.forward(&cls_hidden)  // [batch, num_classes]
    }
}

Decoder-Only Models (GPT-style)

GPT keeps only the decoder stack, but with a key difference: causal masking. Each position can only attend to itself and earlier positions. This enables autoregressive generation.

graph TB
    subgraph "GPT Architecture"
        input["The cat sat"]
        emb["Token + Position<br/>Embeddings"]
        dec1["Decoder Block 1<br/>(causal mask)"]
        dec2["Decoder Block 2<br/>(causal mask)"]
        decN["Decoder Block N<br/>(causal mask)"]
        out["Hidden States"]
        lm_head["LM Head<br/>(predict next token)"]
        pred["on"]

        input --> emb --> dec1 --> dec2 --> |"..."| decN --> out --> lm_head --> pred
    end

    classDef input fill:none,stroke:#60a5fa,stroke-width:2px
    classDef decoder fill:none,stroke:#f472b6,stroke-width:2px
    classDef output fill:none,stroke:#34d399,stroke-width:2px
    class input input
    class emb,dec1,dec2,decN decoder
    class out,lm_head,pred output

Why causal masking? During training, we can compute loss for all positions in parallel — each position predicts its next token using only past context. Without the mask, the model could “cheat” by looking at the answer.

Why this architecture dominates now: Scaling laws. GPT-style models show consistent improvement with more parameters, data, and compute. They can also be prompted for diverse tasks without fine-tuning (few-shot learning).

pub struct GPT<B: Backend> {
    embedding: Embedding<B>,     // Token embeddings
    pos_embedding: Tensor<B>,    // Learned positional embeddings
    blocks: Vec<TransformerBlock<B>>,
    ln_final: LayerNorm<B>,
    lm_head: Linear<B>,          // [d_model, vocab_size]
    training: bool,              // Toggles dropout inside the blocks
}

impl<B: Backend> GPT<B> {
    pub fn forward(&self, input_ids: &Tensor<B>) -> Tensor<B> {
        // input_ids: [batch, seq_len]
        let seq_len = input_ids.shape().dim(1);

        // Embeddings
        let tok_emb = self.embedding.forward(input_ids);  // [batch, seq, d_model]
        let pos_emb = self.pos_embedding.slice(0, 0, seq_len);  // [seq, d_model]
        let mut h = &tok_emb + &pos_emb;

        // Create causal mask (lower triangular)
        let mask = create_causal_mask::<B>(seq_len);  // [seq, seq]

        // Transformer blocks
        for block in &self.blocks {
            h = block.forward(&h, Some(&mask), self.training);
        }

        // Final layer norm + LM head
        let h = self.ln_final.forward(&h);
        self.lm_head.forward(&h)  // [batch, seq, vocab_size]
    }
}

pub fn create_causal_mask<B: Backend>(seq_len: usize) -> Tensor<B> {
    // Lower triangular matrix: 1s where we CAN attend, 0s where we CAN'T
    let mut mask = vec![0.0; seq_len * seq_len];
    for i in 0..seq_len {
        for j in 0..=i {
            mask[i * seq_len + j] = 1.0;
        }
    }
    B::from_vec(mask, Shape::new(vec![seq_len, seq_len]))
}

Autoregressive Generation

To generate text, we feed the model’s output back as input:

pub fn generate<B: Backend>(
    model: &GPT<B>,
    prompt: &Tensor<B>,  // [1, prompt_len]
    max_new_tokens: usize,
    temperature: f32,
) -> Tensor<B> {
    let mut tokens = prompt.clone();

    for _ in 0..max_new_tokens {
        // Forward pass
        let logits = model.forward(&tokens);  // [1, seq, vocab]

        // Take logits for last position
        let seq_len = logits.shape().dim(1);
        let last_logits = logits.slice(1, seq_len - 1, seq_len).squeeze(1);  // [1, vocab]

        // Apply temperature
        let scaled = &last_logits / temperature;

        // Sample from distribution
        let probs = softmax(&scaled, -1);
        let next_token = categorical_sample(&probs);  // [1, 1]

        // Append to sequence
        tokens = B::concat(&[&tokens, &next_token], 1);
    }

    tokens
}
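Temperature controls how the logits are reshaped before sampling: dividing by a small temperature sharpens the distribution toward greedy decoding, while a large temperature flattens it toward uniform. A standalone sketch of that step:

```rust
/// Softmax of logits / temperature, computed stably.
pub fn softmax_with_temperature(logits: &[f32], t: f32) -> Vec<f32> {
    let scaled: Vec<f32> = logits.iter().map(|l| l / t).collect();
    let m = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|s| (s - m).exp()).collect();
    let z: f32 = exps.iter().sum();
    exps.iter().map(|e| e / z).collect()
}
```

With logits `[2.0, 1.0, 0.0]`, temperature 0.1 puts almost all probability on the top token, while temperature 10 spreads it nearly uniformly. In the limit $t \to 0$, sampling becomes argmax.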

Encoder-Decoder Models (T5-style)

For sequence-to-sequence tasks (translation, summarization), both encoder and decoder are useful:

  • Encoder: processes the full input with bidirectional attention
  • Decoder: generates output autoregressively, but also attends to encoder output

The key addition is cross-attention: the decoder’s queries attend to the encoder’s keys and values.

pub struct TransformerDecoderBlock<B: Backend> {
    self_attn: MultiHeadAttention<B>,
    norm1: LayerNorm<B>,
    cross_attn: MultiHeadAttention<B>,  // New: cross-attention
    norm2: LayerNorm<B>,
    ff: FeedForward<B>,
    norm3: LayerNorm<B>,
}

impl<B: Backend> TransformerDecoderBlock<B> {
    pub fn forward(
        &self,
        x: &Tensor<B>,              // Decoder hidden states
        encoder_out: &Tensor<B>,    // Encoder output
        self_attn_mask: Option<&Tensor<B>>,
        cross_attn_mask: Option<&Tensor<B>>,
        training: bool,
    ) -> Tensor<B> {
        // Causal self-attention
        let attn_out = self.self_attn.forward(x, x, x, self_attn_mask);
        let x = self.norm1.forward(&(x + &dropout(&attn_out, 0.1, training)));

        // Cross-attention to encoder
        // Q from decoder, K and V from encoder
        let cross_out = self.cross_attn.forward(&x, encoder_out, encoder_out, cross_attn_mask);
        let x = self.norm2.forward(&(&x + &dropout(&cross_out, 0.1, training)));

        // FFN
        let ff_out = self.ff.forward(&x);
        self.norm3.forward(&(&x + &dropout(&ff_out, 0.1, training)))
    }
}

Choosing an Architecture

| Task | Architecture | Why |
| --- | --- | --- |
| Text classification | Encoder (BERT) | Needs full context understanding |
| Named entity recognition | Encoder (BERT) | Per-token classification with bidirectional context |
| Text generation | Decoder (GPT) | Autoregressive by nature |
| Translation | Encoder-Decoder (T5) | Needs to understand source, generate target |
| Summarization | Encoder-Decoder (T5) or Decoder (GPT) | Both work; E-D traditional, GPT simpler |
| Question answering | Encoder (BERT) or Decoder (GPT) | BERT for extractive, GPT for generative |
| Code completion | Decoder (GPT) | Autoregressive, like text generation |

The trend: decoder-only models are increasingly used for everything. With enough scale and prompting, GPT-style models can handle tasks previously requiring specialized architectures.

Practical Considerations

Computational Complexity

Self-attention has $O(n^2)$ complexity in sequence length — every position attends to every other position. For long sequences:

| Sequence length | Attention operations |
| --- | --- |
| 512 | 262K |
| 2048 | 4.2M |
| 8192 | 67M |
| 32768 | 1B |

This is why context windows were limited. Solutions:

  • Sparse attention: only attend to nearby tokens + some distant ones
  • Flash attention: memory-efficient implementation (still $O(n^2)$ but much faster)
  • Linear attention: approximate attention with $O(n)$ complexity
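The table above is just $n^2$; a trivial check also shows that every doubling of the context quadruples the work:

```rust
/// Pairwise attention score count for sequence length n.
pub fn attention_ops(n: u64) -> u64 {
    n * n
}
```

This quadratic growth, rather than raw FLOPs per token, is what made long contexts expensive before the techniques listed above.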

Memory Requirements

A forward pass stores activations for backprop. For a transformer:

  • Each attention layer stores attention weights: $O(batch \times heads \times seq^2)$
  • Each layer stores intermediate activations: $O(batch \times seq \times d_{model})$

With 32 layers, batch size 8, sequence length 2048, and d_model 4096, activation memory can run to tens or even hundreds of gigabytes, with the $seq^2$ attention weights as the dominant term unless the implementation avoids materializing them.

Techniques to reduce memory:

  • Gradient checkpointing: recompute activations during backward instead of storing
  • Mixed precision: use FP16/BF16 instead of FP32
  • Activation memory optimization: fused kernels that don’t store intermediates
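To put numbers on the attention-weight term alone: in FP32 it costs $batch \times heads \times seq^2 \times 4$ bytes per layer. Using the configuration above with an assumed 32 heads (the head count is illustrative, not stated in the text), that is 4 GiB per layer:

```rust
/// Bytes to store one layer's attention weights in FP32:
/// batch * heads * seq^2 floats, 4 bytes each.
pub fn attn_weight_bytes(batch: u64, heads: u64, seq: u64) -> u64 {
    batch * heads * seq * seq * 4
}
```

Across 32 layers that is about 128 GiB for this one tensor, which is exactly why fused attention kernels that never materialize the full weight matrix, along with mixed precision and checkpointing, matter so much in practice.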

Connecting Back to Our Framework

Everything we’ve built in this series supports these architectures:

  • Part 1 (Tensors): All the shape manipulations, broadcasting, matmuls
  • Part 2 (Layers): Linear, activations — the building blocks
  • Part 3 (Optimizers): Adam trains transformers well with proper learning rate scheduling
  • Part 4 (Backends): Transformers are compute-intensive; GPU backends essential

The backward pass through attention? It just works. Our autodiff engine computes gradients for softmax, matmul, and all the shape operations automatically.

// This computes gradients through the entire transformer
let output = model.forward(&input_ids);
let loss = cross_entropy_loss(&output, &labels);
let grads = loss.backward();

// grads contains gradients for every parameter in every layer

Summary

We’ve covered the progression from simple feedforward networks to modern transformers:

  1. Feedforward: Fixed input/output, no sequence handling
  2. RNNs: Handle sequences via hidden state, but vanishing gradients limit them
  3. Attention: Direct access to all positions, weighted by relevance
  4. Transformers: Attention + feed-forward, stacked with residuals and normalization
  5. Architecture variants: Encoder-only (BERT), decoder-only (GPT), encoder-decoder (T5)

The transformer’s power comes from:

  • Parallelization: All positions computed simultaneously (unlike RNNs)
  • Direct connections: Any position can attend to any other (no information bottleneck)
  • Flexibility: Same architecture works across modalities (text, images, audio, code)

Understanding these architectures completes our deep learning toolkit. We can now not only implement and train models, but understand why they’re structured the way they are.


Part 5 of the “Deep Learning from Scratch in Rust” series. See Part 1 for tensor gradients, Part 2 for layers and loss functions, Part 3 for optimizers, and Part 4 for pluggable backends.

The Evolution of Neural Architectures

The components we’ve covered didn’t emerge in isolation — each built on limitations discovered in prior work. Understanding this evolution helps explain why modern architectures look the way they do.

The RNN Era and Its Limits

Hochreiter & Schmidhuber (1997). Long Short-Term Memory. LSTMs introduced gating mechanisms to control information flow through time, allowing networks to learn when to remember and when to forget. This was the dominant architecture for sequence modeling for nearly two decades. However, LSTMs process tokens sequentially — you can’t compute position 100 until you’ve computed positions 1-99. This made training slow and limited parallelization on GPUs.

The Attention Revolution

Bahdanau et al. (2014). Neural Machine Translation by Jointly Learning to Align and Translate. Before transformers, this paper introduced attention as an addition to RNNs for machine translation. The decoder could “look back” at encoder states rather than relying solely on a fixed context vector. This dramatically improved translation quality for long sentences and planted the seed for what came next.

Vaswani et al. (2017). Attention Is All You Need. The breakthrough: attention doesn’t need RNNs at all. By using self-attention (each position attends to all positions) plus positional encodings, transformers achieved state-of-the-art translation quality while being far more parallelizable. The key insight was that the sequential inductive bias of RNNs wasn’t necessary — attention could learn positional relationships directly.

Stabilizing Deep Networks

Ba et al. (2016). Layer Normalization. Deep transformers (12+ layers) are hard to train. Layer normalization — normalizing across features rather than batch — stabilizes training by keeping activations in a reasonable range. Combined with residual connections, this allows gradients to flow through dozens of layers without vanishing or exploding. Every modern transformer uses this pattern: attention → add & norm → FFN → add & norm.

The Pretraining Paradigm

Radford et al. (2018). Improving Language Understanding by Generative Pre-Training. GPT showed that pretraining a decoder-only transformer on next-token prediction, then fine-tuning on downstream tasks, outperformed training from scratch. This established the “pretrain then fine-tune” paradigm that dominates today.

Devlin et al. (2018). BERT: Pre-training of Deep Bidirectional Transformers. BERT took a different approach: encoder-only with bidirectional attention, pretrained on masked language modeling (predict masked tokens) and next sentence prediction. For understanding tasks (classification, NER, QA), bidirectional context proved superior. BERT became the default for NLU benchmarks.

Raffel et al. (2019). Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer. T5 unified everything as text-to-text: classification becomes “classify: [input]” → “positive”, translation becomes “translate English to German: [input]” → “[output]”. This showed that a single encoder-decoder architecture with the right framing could handle any NLP task.

Scale Changes Everything

Radford et al. (2019). Language Models are Unsupervised Multitask Learners. GPT-2 (1.5B parameters) showed emergent capabilities: zero-shot task performance, coherent long-form generation, and basic reasoning — abilities not explicitly trained for. This suggested that scale itself might be a path to capability.

Kaplan et al. (2020). Scaling Laws for Neural Language Models. This paper quantified the relationship: loss decreases predictably as a power law with model size, dataset size, and compute. Crucially, the returns don’t diminish as fast as expected. This gave labs a roadmap: if you want better models, scale up — and you can predict how much better.

Brown et al. (2020). Language Models are Few-Shot Learners. GPT-3 (175B parameters) demonstrated in-context learning: describe a task in the prompt with a few examples, and the model performs it without any gradient updates. This shifted the paradigm from “fine-tune for each task” to “prompt engineering.”

Making Scale Practical

Dao et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention. The $O(n^2)$ attention bottleneck limited context lengths. FlashAttention restructured the computation to be IO-aware — minimizing memory reads/writes by fusing operations and using tiling. This enabled 4-16x longer contexts at the same memory budget, making 100K+ token contexts practical.

How Modern Systems Chain These Together

Today’s frontier models combine all these insights into a unified recipe:

  1. Architecture: Decoder-only transformer (GPT-style) dominates. Simpler than encoder-decoder, and bidirectional attention isn’t necessary when you have enough scale and the right training objective.

  2. Scale: 100B+ parameters, trained on trillions of tokens. The scaling laws paper showed this works; subsequent work refined the optimal ratio of parameters to data.

  3. Training stack: Mixed-precision training (BF16), gradient checkpointing, tensor/pipeline parallelism across thousands of GPUs. FlashAttention for memory efficiency. Layer norm and residual connections for stability.

  4. The forward pass at scale: Input tokens → embedding lookup → 80+ transformer blocks (each: multi-head attention with rotary embeddings → layer norm → FFN with SwiGLU activation → layer norm) → final layer norm → output projection → softmax over vocabulary.

  5. Post-training: After pretraining, models undergo supervised fine-tuning (SFT) on high-quality examples, then reinforcement learning from human feedback (RLHF) or direct preference optimization (DPO) to align outputs with human preferences.

The remarkable thing: the core operations are still just matrix multiplications, softmax, and element-wise nonlinearities — exactly what we built in Parts 1-4. The magic is in how they’re arranged (architecture), how many there are (scale), and what data flows through them (training).