Deep Learning from Scratch in Rust, Part 4 — Pluggable Backends
Throughout this series, we’ve been writing B::add, B::matmul, B::exp without explaining what B actually is. Time to pay that debt.
B is a backend — an implementation of tensor operations. Different backends can target different hardware:
- CPU with SIMD intrinsics
- Metal shaders for macOS GPUs
- CUDA kernels for NVIDIA GPUs
Today we’ll see how Rust’s type system lets us write autodiff code once and run it anywhere — with the backend choice resolved entirely at compile time.
The Backend Trait
Here’s the core abstraction:
pub trait Backend: Clone + Send + Sync + 'static {
type Tensor: TensorData;
// Creation
fn zeros(shape: &Shape) -> Self::Tensor;
fn ones(shape: &Shape) -> Self::Tensor;
fn from_vec(data: Vec<f32>, shape: Shape) -> Self::Tensor;
// Element-wise ops
fn add(a: &Self::Tensor, b: &Self::Tensor) -> Self::Tensor;
fn mul(a: &Self::Tensor, b: &Self::Tensor) -> Self::Tensor;
fn exp(x: &Self::Tensor) -> Self::Tensor;
// ... more ops
// Linear algebra
fn matmul(a: &Self::Tensor, b: &Self::Tensor) -> Self::Tensor;
// Reductions
fn sum(x: &Self::Tensor, axes: Option<&[usize]>, keepdims: bool) -> Self::Tensor;
}
A backend provides two things:
- An associated type Tensor for storage
- Static methods implementing every operation
Notice: these are static methods, not instance methods. The backend is a type, not a value. This is crucial.
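To make that concrete, here's a stripped-down sketch (shapes, the TensorData bound, and most ops are omitted, and `double` is a hypothetical caller) showing how a function stays generic while every call bottoms out in a statically chosen backend:

```rust
trait Backend {
    type Tensor;
    fn from_vec(data: Vec<f32>) -> Self::Tensor;
    fn add(a: &Self::Tensor, b: &Self::Tensor) -> Self::Tensor;
}

struct CpuBackend;
struct CpuTensor(Vec<f32>);

impl Backend for CpuBackend {
    type Tensor = CpuTensor;
    fn from_vec(data: Vec<f32>) -> CpuTensor { CpuTensor(data) }
    fn add(a: &CpuTensor, b: &CpuTensor) -> CpuTensor {
        CpuTensor(a.0.iter().zip(&b.0).map(|(x, y)| x + y).collect())
    }
}

// Generic over B: monomorphized per backend, dispatched at compile time.
fn double<B: Backend>(x: &B::Tensor) -> B::Tensor {
    B::add(x, x)
}

fn main() {
    let x = CpuBackend::from_vec(vec![1.0, 2.0]);
    let y = double::<CpuBackend>(&x);
    assert_eq!(y.0, vec![2.0, 4.0]);
}
```

Note that `double` never receives a backend value — only a type parameter. That's what "the backend is a type, not a value" buys us.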
Type-Level Backend Selection
Our tensor type is parameterized by backend:
pub struct Tensor<B: Backend>(Arc<TensorNode<B>>);
struct TensorNode<B: Backend> {
id: NodeId,
op: TensorOp,
data: B::Tensor, // Backend-specific storage!
children: Vec<Tensor<B>>,
}
When you write:
type T = Tensor<CpuBackend>;
let x = T::var("x", CpuBackend::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3])));
let y = x.exp();
The compiler:
- Sees Tensor<CpuBackend>
- Monomorphizes all methods with B = CpuBackend
- Inlines the backend calls directly
There’s no runtime dispatch. The generic parameter is compiled away, leaving concrete machine code for the chosen backend. This is zero-cost abstraction.
The CPU Backend
Let’s implement a simple CPU backend:
#[derive(Clone)]
pub struct CpuBackend;
pub struct CpuTensor {
data: Vec<f32>,
shape: Shape,
strides: Strides,
}
impl TensorData for CpuTensor {
fn shape(&self) -> &Shape { &self.shape }
fn strides(&self) -> &Strides { &self.strides }
fn as_slice(&self) -> &[f32] { &self.data }
fn as_slice_mut(&mut self) -> &mut [f32] { &mut self.data }
}
impl Backend for CpuBackend {
type Tensor = CpuTensor;
fn zeros(shape: &Shape) -> CpuTensor {
CpuTensor {
data: vec![0.0; shape.numel()],
shape: shape.clone(),
strides: shape.contiguous_strides(),
}
}
fn add(a: &CpuTensor, b: &CpuTensor) -> CpuTensor {
// Handle broadcasting...
let out_shape = broadcast_shape(a.shape(), b.shape()).unwrap();
let mut out = vec![0.0; out_shape.numel()];
for i in 0..out.len() {
let a_idx = broadcast_index(i, &out_shape, a.shape());
let b_idx = broadcast_index(i, &out_shape, b.shape());
out[i] = a.data[a_idx] + b.data[b_idx];
}
CpuTensor {
data: out,
shape: out_shape.clone(),
strides: out_shape.contiguous_strides(),
}
}
fn exp(x: &CpuTensor) -> CpuTensor {
let data: Vec<f32> = x.data.iter().map(|&v| v.exp()).collect();
CpuTensor {
data,
shape: x.shape.clone(),
strides: x.strides.clone(),
}
}
// ... implement all other ops
}
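The element-wise ops all follow the same pattern, so it's worth factoring out a helper before writing a dozen near-identical bodies. A minimal sketch (shape and strides dropped for brevity; `map_unary` is a name invented here, not part of the trait):

```rust
#[derive(Clone, Debug)]
struct CpuTensor {
    data: Vec<f32>, // shape and strides omitted for brevity
}

// One helper covers exp, log, neg, relu, ... — each op just picks the scalar fn.
fn map_unary(x: &CpuTensor, f: impl Fn(f32) -> f32) -> CpuTensor {
    CpuTensor { data: x.data.iter().map(|&v| f(v)).collect() }
}

fn main() {
    let x = CpuTensor { data: vec![0.0, 1.0] };
    let y = map_unary(&x, f32::exp);
    assert!((y.data[0] - 1.0).abs() < 1e-6);
    assert!((y.data[1] - std::f32::consts::E).abs() < 1e-6);
}
```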
This works, but it’s slow. Let’s add SIMD.
SIMD Acceleration
Modern CPUs have vector instructions that process 4-8 floats simultaneously:
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

// AVX path: requires compiling with `-C target-feature=+avx`, or guard
// the call site with is_x86_feature_detected!("avx").
#[cfg(target_arch = "x86_64")]
fn add_simd(a: &[f32], b: &[f32], out: &mut [f32]) {
let n = a.len();
let simd_width = 8; // AVX processes 8 floats
let simd_end = n - (n % simd_width);
// SIMD loop
unsafe {
for i in (0..simd_end).step_by(simd_width) {
let va = _mm256_loadu_ps(a.as_ptr().add(i));
let vb = _mm256_loadu_ps(b.as_ptr().add(i));
let vout = _mm256_add_ps(va, vb);
_mm256_storeu_ps(out.as_mut_ptr().add(i), vout);
}
}
// Scalar tail
for i in simd_end..n {
out[i] = a[i] + b[i];
}
}
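If you'd rather avoid unsafe intrinsics, fixed-size chunks usually get the compiler to auto-vectorize the same loop on any target. A portable sketch (the LANES constant is a choice made here, matching AVX's eight f32 lanes):

```rust
const LANES: usize = 8; // matches AVX width for f32; tune per target

// Safe, portable addition: `chunks_exact` gives the optimizer fixed-size
// loops it can auto-vectorize, with no `unsafe` and no arch gating.
fn add_chunked(a: &[f32], b: &[f32], out: &mut [f32]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), out.len());
    let n = a.len();
    let main = n - n % LANES;
    for ((ca, cb), co) in a[..main]
        .chunks_exact(LANES)
        .zip(b[..main].chunks_exact(LANES))
        .zip(out[..main].chunks_exact_mut(LANES))
    {
        for i in 0..LANES {
            co[i] = ca[i] + cb[i];
        }
    }
    // Scalar tail for the leftover elements
    for i in main..n {
        out[i] = a[i] + b[i];
    }
}

fn main() {
    let a = vec![1.0f32; 10];
    let b = vec![2.0f32; 10];
    let mut out = vec![0.0f32; 10];
    add_chunked(&a, &b, &mut out);
    assert!(out.iter().all(|&v| (v - 3.0).abs() < 1e-6));
}
```

In practice this often matches hand-written intrinsics for simple element-wise ops, while staying safe and cross-platform.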
For transcendental functions (exp, log, sin), we can use approximations:
// Fast exp approximation using polynomial (Padé or Taylor)
#[cfg(target_arch = "x86_64")]
fn exp_simd(x: &[f32], out: &mut [f32]) {
    let n = x.len();
    let simd_end = n - (n % 8);
    unsafe {
        for i in (0..simd_end).step_by(8) {
            let v = _mm256_loadu_ps(x.as_ptr().add(i));
            // Clamp to avoid overflow/underflow
            let v = _mm256_min_ps(v, _mm256_set1_ps(88.0));
            let v = _mm256_max_ps(v, _mm256_set1_ps(-88.0));
            // exp(x) via range reduction and a polynomial approximation
            let result = fast_exp_avx(v);
            _mm256_storeu_ps(out.as_mut_ptr().add(i), result);
        }
    }
    // Scalar tail
    for i in simd_end..n {
        out[i] = x[i].clamp(-88.0, 88.0).exp();
    }
}
Speedup: 4-8x for element-wise ops compared to scalar loops.
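The fast_exp_avx body is left as an exercise above, but the underlying range-reduction trick is easier to see in scalar form. A sketch of the idea (an approximation written for this post, not the series' actual kernel): write exp(x) = 2^k · exp(r) with r = x − k·ln 2, so |r| ≤ ln 2 / 2, then a short Taylor polynomial covers exp(r) accurately:

```rust
// Scalar illustration of range reduction for exp; accuracy is roughly
// 1e-4 relative with a 4th-order polynomial — a sketch, not production code.
fn fast_exp(x: f32) -> f32 {
    // exp(x) = 2^k * exp(r), with r = x - k*ln(2) and |r| <= ln(2)/2
    let k = (x / std::f32::consts::LN_2).round();
    let r = x - k * std::f32::consts::LN_2;
    // 4th-order Taylor polynomial for exp(r) on the reduced range
    let p = 1.0 + r * (1.0 + r * (0.5 + r * (1.0 / 6.0 + r / 24.0)));
    p * (2.0f32).powi(k as i32)
}

fn main() {
    for &x in &[-3.0f32, -0.5, 0.0, 1.0, 5.0] {
        let rel = (fast_exp(x) - x.exp()).abs() / x.exp();
        assert!(rel < 1e-4, "x = {x}, relative error = {rel}");
    }
}
```

The SIMD version does the same thing eight lanes at a time, with the 2^k step done via exponent-field bit manipulation instead of powi.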
Matrix Multiplication: The Critical Op
MatMul dominates neural network compute. A naive implementation:
fn matmul_naive(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
for i in 0..m {
for j in 0..n {
let mut sum = 0.0;
for l in 0..k {
sum += a[i * k + l] * b[l * n + j];
}
c[i * n + j] = sum;
}
}
}
This is cache-hostile. Better: tile the matrices to fit in L1/L2 cache:
fn matmul_tiled(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
const TILE: usize = 64; // 64×64 f32 tile = 16 KB, sized so the working set stays in L1/L2
for i0 in (0..m).step_by(TILE) {
for j0 in (0..n).step_by(TILE) {
for l0 in (0..k).step_by(TILE) {
// Process TILE×TILE block
for i in i0..min(i0 + TILE, m) {
for l in l0..min(l0 + TILE, k) {
let a_val = a[i * k + l];
for j in j0..min(j0 + TILE, n) {
c[i * n + j] += a_val * b[l * n + j];
}
}
}
}
}
}
}
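Because the tiled loops still visit l in ascending order for each output element, the result should match the naive version to floating-point rounding. A quick self-contained sanity check (repeating both routines from above so the snippet compiles standalone, with odd sizes to exercise the partial tiles at the edges):

```rust
use std::cmp::min;

fn matmul_naive(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    for i in 0..m {
        for j in 0..n {
            let mut sum = 0.0;
            for l in 0..k {
                sum += a[i * k + l] * b[l * n + j];
            }
            c[i * n + j] = sum;
        }
    }
}

fn matmul_tiled(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
    const TILE: usize = 64;
    // Accumulates into c, so c must start zeroed.
    for i0 in (0..m).step_by(TILE) {
        for j0 in (0..n).step_by(TILE) {
            for l0 in (0..k).step_by(TILE) {
                for i in i0..min(i0 + TILE, m) {
                    for l in l0..min(l0 + TILE, k) {
                        let a_val = a[i * k + l];
                        for j in j0..min(j0 + TILE, n) {
                            c[i * n + j] += a_val * b[l * n + j];
                        }
                    }
                }
            }
        }
    }
}

fn main() {
    let (m, k, n) = (65, 70, 67); // odd sizes hit partial tiles
    let a: Vec<f32> = (0..m * k).map(|i| (i % 7) as f32 * 0.5).collect();
    let b: Vec<f32> = (0..k * n).map(|i| (i % 5) as f32 * 0.25).collect();
    let mut c1 = vec![0.0; m * n];
    let mut c2 = vec![0.0; m * n];
    matmul_naive(&a, &b, &mut c1, m, k, n);
    matmul_tiled(&a, &b, &mut c2, m, k, n);
    assert!(c1.iter().zip(&c2).all(|(x, y)| (x - y).abs() < 1e-2));
}
```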
For production, call BLAS (OpenBLAS, Intel MKL, Accelerate):
fn matmul_blas(a: &[f32], b: &[f32], c: &mut [f32], m: usize, k: usize, n: usize) {
unsafe {
cblas_sgemm(
CblasRowMajor, CblasNoTrans, CblasNoTrans,
m as i32, n as i32, k as i32,
1.0, // alpha
a.as_ptr(), k as i32, // A
b.as_ptr(), n as i32, // B
0.0, // beta
c.as_mut_ptr(), n as i32, // C
);
}
}
BLAS implementations are heavily optimized — hand-tuned assembly, multi-threading, architecture-specific kernels.
The Metal Backend
On macOS, we can use Metal for GPU acceleration:
#[derive(Clone)]
pub struct MetalBackend {
device: metal::Device,
command_queue: metal::CommandQueue,
}
pub struct MetalTensor {
buffer: metal::Buffer,
shape: Shape,
strides: Strides,
}
Metal operations work differently:
- Data lives in GPU memory (MTLBuffer)
- Operations are encoded into command buffers
- Work is submitted asynchronously
impl Backend for MetalBackend {
type Tensor = MetalTensor;
fn add(a: &MetalTensor, b: &MetalTensor) -> MetalTensor {
// Static trait methods can't reach MetalBackend's fields, so this sketch
// grabs the default device on every call; a real backend would cache it.
let device = metal::Device::system_default().unwrap();
let out_shape = broadcast_shape(a.shape(), b.shape()).unwrap();
let out_buffer = device.new_buffer(
(out_shape.numel() * 4) as u64,
MTLResourceOptions::StorageModeShared,
);
// Load the compute shader (in a real backend, compile once and cache the pipeline)
let library = device.new_library_with_source(ADD_SHADER, &Default::default()).unwrap();
let function = library.get_function("add_tensors", None).unwrap();
let pipeline = device.new_compute_pipeline_state_with_function(&function).unwrap();
// Encode and dispatch
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_buffer(0, Some(&a.buffer), 0);
encoder.set_buffer(1, Some(&b.buffer), 0);
encoder.set_buffer(2, Some(&out_buffer), 0);
// ... set shape uniforms
let threads_per_group = MTLSize::new(256, 1, 1);
let num_groups = MTLSize::new(
(out_shape.numel() as u64 + 255) / 256, 1, 1
);
encoder.dispatch_thread_groups(num_groups, threads_per_group);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed(); // blocking keeps the example simple; real backends defer syncs
MetalTensor {
buffer: out_buffer,
shape: out_shape,
strides: out_shape.contiguous_strides(),
}
}
}
The Metal shader (in MSL):
#include <metal_stdlib>
using namespace metal;
kernel void add_tensors(
device const float* a [[buffer(0)]],
device const float* b [[buffer(1)]],
device float* out [[buffer(2)]],
uint idx [[thread_position_in_grid]]
) {
// Handle broadcasting in shader...
out[idx] = a[idx] + b[idx];
}
Compile-Time Backend Selection
The magic of Rust’s type system: backend choice happens at compile time.
// Choose backend once at the top
#[cfg(feature = "metal")]
type B = MetalBackend;
#[cfg(not(feature = "metal"))]
type B = CpuBackend;
type T = Tensor<B>;
// All downstream code is generic
fn train_model(inputs: &T, targets: &T, weights: &T, biases: &T) -> T {
// This code works for ANY backend
let hidden = inputs.matmul(weights) + biases;
let output = hidden.relu();
let loss = mse_loss(&output, targets);
loss.backward()
// ...
}
You can also make it runtime-selectable, but you pay the cost of dynamic dispatch:
enum AnyBackend {
Cpu(CpuBackend),
Metal(MetalBackend),
}
// Now operations require match statements or trait objects
// This adds overhead, but sometimes worth it for flexibility
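A sketch of what that looks like (all names hypothetical): tensors become an enum, and every operation pays for a match — plus the possibility of mismatched devices surfacing at runtime rather than compile time:

```rust
// Runtime-dispatched tensor: the variant is inspected on every op.
#[derive(Debug)]
enum AnyTensor {
    Cpu(Vec<f32>),
    // Metal(MetalTensor), // would hold a GPU buffer
}

fn add(a: &AnyTensor, b: &AnyTensor) -> AnyTensor {
    match (a, b) {
        (AnyTensor::Cpu(x), AnyTensor::Cpu(y)) => {
            AnyTensor::Cpu(x.iter().zip(y).map(|(p, q)| p + q).collect())
        }
        // Mixed-device combinations would need a copy or an error here.
    }
}

fn main() {
    let a = AnyTensor::Cpu(vec![1.0, 2.0]);
    let b = AnyTensor::Cpu(vec![3.0, 4.0]);
    match add(&a, &b) {
        AnyTensor::Cpu(v) => assert_eq!(v, vec![4.0, 6.0]),
    }
}
```

The per-op branch is cheap next to a matmul, but it blocks inlining and monomorphization, which is exactly what the type-parameter approach preserves.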
Backend Interop
Sometimes you want to move data between backends:
trait IntoBackend<Target: Backend>: Backend {
fn convert(tensor: &Self::Tensor) -> Target::Tensor;
}
impl IntoBackend<MetalBackend> for CpuBackend {
fn convert(tensor: &CpuTensor) -> MetalTensor {
// Copy CPU data to GPU buffer
let device = metal::Device::system_default().unwrap();
let buffer = device.new_buffer_with_data(
tensor.data.as_ptr() as *const _,
(tensor.data.len() * 4) as u64,
MTLResourceOptions::StorageModeShared,
);
MetalTensor {
buffer,
shape: tensor.shape.clone(),
strides: tensor.strides.clone(),
}
}
}
When to Use Which Backend
| Backend | Best for | Latency | Throughput |
|---|---|---|---|
| CPU (scalar) | Debugging, small tensors | Low | Low |
| CPU (SIMD) | Medium tensors, dev machines | Low | Medium |
| Metal | macOS, large batches | Medium | High |
| CUDA | NVIDIA GPUs, production | Medium | Highest |
Rules of thumb:
- Tensors < 1000 elements: CPU is fine
- Tensors < 100K elements: SIMD CPU competitive with GPU
- Tensors > 100K elements: GPU wins for throughput
- Latency-sensitive: CPU avoids kernel launch overhead
The Power of Zero-Cost Abstraction
Here’s what’s remarkable: our autodiff code doesn’t know or care about backends.
// This is backend-agnostic
pub fn backward<B: Backend>(output: &Tensor<B>) -> Gradients<B> {
// Same algorithm, regardless of B
}
// So is the optimizer
impl<B: Backend> Optimizer<B> for Adam<B> {
fn step(&mut self, ...) -> B::Tensor {
// Uses B::mul, B::add, etc.
}
}
When you compile with --features metal, the entire autodiff engine becomes Metal-accelerated. When you compile for CPU, it uses SIMD. Same source code, optimal machine code for each target.
This is Rust’s promise: write generic, get specialized.
Summary
We’ve covered:
- Backend trait: Abstract interface for tensor operations
- Type-level selection: Backend resolved at compile time, zero runtime cost
- CPU backend: Scalar loops → SIMD → BLAS
- Metal backend: GPU shaders for parallel compute
- Interop: Moving data between backends
The autodiff engine runs on any backend that implements the trait. Write once, run anywhere — at native speed.
Part 4 of the “Deep Learning from Scratch in Rust” series. See Part 1 for tensor gradients, Part 2 for layers and loss functions, Part 3 for optimizers, and Part 5 for neural network architectures.