
🧠

MLX

Array framework for machine learning on Apple Silicon

Created by Apple's ML research team. Unified memory model, lazy evaluation, NumPy-like API, PyTorch-style neural networks. Optimized for M1/M2/M3/M4.

Python + C++ · Metal GPU · Unified Memory
💡

Why MLX?

What makes it different from PyTorch/JAX

Unified Memory

MLX: Arrays live in unified memory; the CPU and GPU access the same data without copying (see the sketch after this list).
Others: PyTorch/JAX require explicit .to('cuda') transfers over the PCIe bus.

Lazy Evaluation

MLX: Operations build a computation graph. Execute only when you call mx.eval().
Others: PyTorch is eager by default. JAX requires explicit jit compilation.

Dynamic Shapes

MLX: Graph structure changes with input shapes. No recompilation needed.
Others: JAX's jit-compiled functions retrace and recompile when input shapes change.

Hardware Target

MLX: Designed specifically for Apple Silicon's Metal GPU and unified architecture.
Others: PyTorch/JAX target NVIDIA CUDA primarily.
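
To make the unified-memory point concrete, here is a minimal sketch (array sizes are illustrative): the same arrays feed both a GPU and a CPU computation by passing a stream argument, with no copy or .to() step.

import mlx.core as mx

a = mx.random.normal((4096,))
b = mx.random.normal((4096,))

# The same buffers are visible to both devices; there is no transfer step in MLX
c_gpu = mx.add(a, b, stream=mx.gpu)  # runs on the Metal GPU
c_cpu = mx.add(a, b, stream=mx.cpu)  # runs on the CPU

mx.eval(c_gpu, c_cpu)
print(mx.allclose(c_gpu, c_cpu))  # True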
🏗️

Core Architecture

How MLX works under the hood

┌─────────────────────────────────────────────────────┐
│                    Python API                        │
│  mx.array, mx.nn.Module, mx.optimizers              │
├─────────────────────────────────────────────────────┤
│                  Computation Graph                   │
│  array → Primitive → array → Primitive → ...        │
├─────────────────────────────────────────────────────┤
│                    Transforms                        │
│  grad (VJP), jvp, vmap, compile (JIT)              │
├─────────────────────────────────────────────────────┤
│                     Scheduler                        │
│  Streams (CPU threads / Metal command queues)       │
├─────────────────────────────────────────────────────┤
│                     Backends                         │
│  Metal (GPU)  │  CPU (Accelerate)  │  CUDA (Linux) │
├─────────────────────────────────────────────────────┤
│              Unified Memory (Apple Silicon)          │
│  No data transfer between CPU ↔ GPU                 │
└─────────────────────────────────────────────────────┘
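
One way to see these layers from Python, as a small illustrative walk down the stack: the API builds a lazy graph, transforms rewrite it, and evaluation hands it to the scheduler and backend.

import mlx.core as mx

def f(x):
    return mx.sum(mx.exp(x))   # Python API: builds graph nodes (Primitives)

df = mx.grad(f)                # Transform: derives a new graph for the gradient
df = mx.compile(df)            # Transform: traces and fuses kernels

x = mx.arange(8, dtype=mx.float32)
g = df(x)                      # Still lazy - nothing has executed yet
mx.eval(g)                     # Scheduler dispatches to the Metal/CPU backend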
📊

Array System

The foundation of MLX

Array as Graph Node

Every array holds a reference to the operation (Primitive) that created it and its input arrays. This forms a directed acyclic graph (DAG) representing the computation.

// C++ Array Structure (mlx/array.h)
class array {
  struct ArrayDesc {
    Shape shape;              // Dimensions
    Strides strides;          // Memory layout
    Dtype dtype;              // float32, int32, etc.

    std::shared_ptr<Primitive> primitive;  // Op that produced this
    std::vector<array> inputs;             // Input arrays

    Status status;  // unscheduled → evaluated → available
    std::shared_ptr<Data> data;  // The actual buffer
  };
};

Array States & Lazy Evaluation

import mlx.core as mx

# 1. Create arrays - no computation yet
x = mx.random.normal((1000, 1000))  # status: unscheduled
y = mx.random.normal((1000, 1000))  # status: unscheduled

# 2. Build computation graph - still no computation
z = x @ y           # Matrix multiply - status: unscheduled
w = mx.relu(z)      # ReLU activation - status: unscheduled
loss = mx.mean(w)   # Mean reduction - status: unscheduled

# 3. Trigger evaluation - NOW computation happens
mx.eval(loss)  # All dependencies computed, loss.status → available

# Async evaluation (non-blocking)
mx.async_eval(loss)  # Returns immediately
# ... do other work ...
# Access loss.item() blocks until ready
Array states:

  • unscheduled: array created, but computation not yet scheduled
  • evaluated: scheduled for execution, may be in progress
  • available: computation complete, data ready to read

⚙️

Primitives

Operations as graph nodes

Primitive Interface

Every operation (add, matmul, conv, etc.) is a Primitive class implementing forward, backward, and vectorization.

// mlx/primitives.h
class Primitive {
  // Forward pass implementations
  virtual void eval_cpu(const std::vector<array>& inputs,
                        std::vector<array>& outputs) = 0;
  virtual void eval_gpu(const std::vector<array>& inputs,
                        std::vector<array>& outputs) = 0;

  // Automatic differentiation
  virtual std::vector<array> vjp(  // Reverse-mode (backprop)
      const std::vector<array>& primals,
      const std::vector<array>& cotangents,
      const std::vector<int>& argnums,
      const std::vector<array>& outputs);

  virtual std::vector<array> jvp(  // Forward-mode
      const std::vector<array>& primals,
      const std::vector<array>& tangents,
      const std::vector<int>& argnums);

  // Vectorization (batching)
  virtual std::pair<std::vector<array>, std::vector<int>> vmap(
      const std::vector<array>& inputs,
      const std::vector<int>& axes);
};
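
The same three pieces of the interface (forward evaluation, vjp/jvp, vmap) are exposed to Python as function transforms, so you can exercise them directly on a simple op; a small illustration:

import mlx.core as mx

def f(x):
    return x * x  # backed by the Multiply primitive

x = mx.array([1.0, 2.0, 3.0])

# eval_cpu / eval_gpu: the forward pass runs when the output is evaluated
mx.eval(f(x))

# vjp: reverse-mode rule, one reverse sweep per output cotangent
_, vjps = mx.vjp(f, [x], [mx.ones_like(x)])
print(vjps[0])  # [2., 4., 6.]

# jvp: forward-mode rule, one forward sweep per input tangent
_, jvps = mx.jvp(f, [x], [mx.ones_like(x)])
print(jvps[0])  # [2., 4., 6.]

# vmap: the batching rule
print(mx.vmap(f)(mx.ones((4, 3))).shape)  # (4, 3)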

Key Primitives

Arithmetic

  • Add, Subtract, Multiply, Divide
  • Power, Exp, Log
  • MatMul, Addmm

Activation

  • ReLU, Sigmoid, Tanh
  • GELU, SiLU, Softmax
  • LayerNorm, RMSNorm

Reduction

  • Sum, Mean, Prod
  • Max, Min, ArgMax
  • All, Any

Automatic Differentiation

Computing gradients

Two Modes

  • VJP (Reverse-mode): Backpropagation. Efficient when outputs < inputs. Used for training.
  • JVP (Forward-mode): Computes Jacobian-vector products. Efficient when inputs < outputs.
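
The cost asymmetry is easy to demonstrate: for a function from R^n to R, a single VJP sweep yields the whole gradient, while forward mode needs one JVP per input direction (a small illustrative check):

import mlx.core as mx

def f(x):                      # many inputs, one scalar output
    return mx.sum(mx.sin(x))

x = mx.random.normal((5,))

# Reverse mode: one sweep gives the full gradient
_, vjps = mx.vjp(f, [x], [mx.array(1.0)])
grad_rev = vjps[0]

# Forward mode: one sweep per basis direction -> n sweeps for the same gradient
basis = mx.eye(5)
grad_fwd = mx.stack([mx.jvp(f, [x], [basis[i]])[1][0] for i in range(5)])

print(mx.allclose(grad_rev, grad_fwd))  # True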

Basic Gradient Computation

import mlx.core as mx

# Define a function
def loss_fn(w, x, y):
    pred = x @ w
    return mx.mean((pred - y) ** 2)

# Create gradient function (VJP)
grad_fn = mx.grad(loss_fn)  # Differentiates w.r.t. first argument

# Or get both value and gradient
value_and_grad_fn = mx.value_and_grad(loss_fn)

# Compute
w = mx.random.normal((10, 1))
x = mx.random.normal((100, 10))
y = mx.random.normal((100, 1))

loss, grad_w = value_and_grad_fn(w, x, y)
mx.eval(loss, grad_w)

print(f"Loss: {loss.item()}")
print(f"Gradient shape: {grad_w.shape}")

Gradient w.r.t. Multiple Arguments

# Differentiate w.r.t. multiple arguments
grad_fn = mx.grad(loss_fn, argnums=[0, 1])  # Grad w.r.t. w and x
grad_w, grad_x = grad_fn(w, x, y)

# Higher-order derivatives
def f(x):
    return mx.sum(x ** 3)

df = mx.grad(f)      # First derivative: 3x^2
ddf = mx.grad(df)    # Second derivative: 6x

x = mx.array([1.0, 2.0, 3.0])
print(f"f(x)  = {f(x).item()}")      # 1 + 8 + 27 = 36
print(f"f'(x) = {df(x)}")            # [3, 12, 27]
print(f"f''(x) = {ddf(x)}")          # [6, 12, 18]

How Each Primitive Implements VJP

// Example: Multiply primitive VJP (mlx/primitives.cpp)
std::vector<array> Multiply::vjp(
    const std::vector<array>& primals,      // [a, b]
    const std::vector<array>& cotangents,   // [∂L/∂(a*b)]
    const std::vector<int>& argnums,
    const std::vector<array>& outputs) {

  // Chain rule: ∂L/∂a = ∂L/∂(a*b) * b
  //             ∂L/∂b = ∂L/∂(a*b) * a
  std::vector<array> grads;
  for (int argnum : argnums) {
    if (argnum == 0) {
      grads.push_back(cotangents[0] * primals[1]);  // ∂L/∂a
    } else {
      grads.push_back(cotangents[0] * primals[0]);  // ∂L/∂b
    }
  }
  return grads;
}
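
A quick sanity check of the rule above from Python: the gradient of a*b with respect to a is b, and with respect to b is a.

import mlx.core as mx

a = mx.array([1.0, 2.0, 3.0])
b = mx.array([4.0, 5.0, 6.0])

grad_a, grad_b = mx.grad(lambda a, b: mx.sum(a * b), argnums=[0, 1])(a, b)
print(grad_a)  # [4., 5., 6.]  (= b)
print(grad_b)  # [1., 2., 3.]  (= a)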
🔄

Transforms

Function transformations

mx.grad

Transform function to compute gradients via reverse-mode autodiff.

def f(x):
    return mx.sum(x ** 2)

grad_f = mx.grad(f)
x = mx.array([1., 2., 3.])
print(grad_f(x))  # [2., 4., 6.]

mx.vmap

Vectorize a function over a batch dimension automatically.

def single_example(x, w):
    return x @ w

# Vectorize over batch dimension
batched_fn = mx.vmap(single_example, in_axes=(0, None))

x_batch = mx.random.normal((32, 10))  # 32 examples
w = mx.random.normal((10, 5))
out = batched_fn(x_batch, w)  # (32, 5)

mx.compile

JIT compile for kernel fusion and optimization.

@mx.compile
def fused_ops(x):
    # These ops get fused into one kernel
    x = mx.relu(x)
    x = x * 2
    x = x + 1
    return x

# First call compiles, subsequent calls reuse
out = fused_ops(mx.random.normal((1000,)))

mx.vjp / mx.jvp

Low-level access to vector-Jacobian and Jacobian-vector products.

def f(x):
    return mx.stack([mx.sum(x**2), mx.sum(x**3)])

x = mx.array([1., 2., 3.])
v = mx.array([1., 0.])  # Select first output

# VJP: cotangent-weighted gradient, computed in one reverse pass
outputs, vjps = mx.vjp(f, [x], [v])
print(vjps[0])  # [2., 4., 6.] (gradient of the first output, sum(x**2))

# JVP: directional derivative of f along a tangent vector
outputs, jvps = mx.jvp(f, [x], [mx.ones_like(x)])
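
Transforms compose; a common pattern is per-example gradients, built by wrapping grad in vmap and optionally compiling the result (the function and shapes here are illustrative):

import mlx.core as mx

def loss(w, x):
    return mx.sum((x @ w) ** 2)

# Gradient w.r.t. w, vectorized over the batch of x, then compiled
per_example_grads = mx.compile(mx.vmap(mx.grad(loss), in_axes=(None, 0)))

w = mx.random.normal((10, 1))
x_batch = mx.random.normal((32, 10))
print(per_example_grads(w, x_batch).shape)  # (32, 10, 1): one gradient per example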
💾

Memory Management

Unified memory on Apple Silicon

Why Unified Memory Matters

Traditional (CUDA)

# PyTorch on NVIDIA
x_cpu = torch.randn(1000, 1000)
x_gpu = x_cpu.to('cuda')  # PCIe transfer!
y_gpu = model(x_gpu)
y_cpu = y_gpu.to('cpu')   # PCIe transfer!

# Each transfer: ~12 GB/s over PCIe
# Latency + bandwidth bottleneck

MLX (Apple Silicon)

# MLX on M1/M2/M3/M4
x = mx.random.normal((1000, 1000))
y = model(x)

# No transfer needed!
# CPU and GPU access the same memory
# Unified memory bandwidth: ~70 GB/s (base M1) up to 800+ GB/s (Ultra chips)

Memory Tracking & Limits

import mlx.core as mx

# Check memory usage
print(f"Active: {mx.metal.get_active_memory() / 1e9:.2f} GB")
print(f"Peak: {mx.metal.get_peak_memory() / 1e9:.2f} GB")
print(f"Cache: {mx.metal.get_cache_memory() / 1e9:.2f} GB")

# Set limits
mx.metal.set_memory_limit(8 * 1024**3)  # 8 GB max
mx.metal.set_cache_limit(2 * 1024**3)   # 2 GB cache

# Clear cache
mx.metal.clear_cache()

# Reset peak tracking
mx.metal.reset_peak_memory()
🧩

Neural Networks (mlx.nn)

PyTorch-style module system

Module Base Class

Modules are dictionaries containing parameters (arrays) and submodules. This enables easy parameter extraction and updates.

import mlx.nn as nn
import mlx.core as mx

class Module(dict):
    """Base class - a dict of params and submodules"""

    def parameters(self):
        """Get all trainable parameters recursively"""

    def trainable_parameters(self):
        """Get non-frozen parameters"""

    def freeze(self, keys=None, recurse=True):
        """Freeze parameters (no gradient computation)"""

    def unfreeze(self, keys=None, recurse=True):
        """Unfreeze parameters"""

    def update(self, parameters):
        """Replace parameters (used by optimizer)"""

    def train(self, mode=True):
        """Set training mode (affects Dropout, BatchNorm)"""

    def eval(self):
        """Set evaluation mode"""

Defining a Model

import mlx.nn as nn
import mlx.core as mx

class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(p=0.1)

    def __call__(self, x):
        x = nn.relu(self.fc1(x))
        x = self.dropout(x)
        x = nn.relu(self.fc2(x))
        x = self.dropout(x)
        return self.fc3(x)

model = MLP(784, 256, 10)

# Inspect parameters (parameters() returns a nested dict; flatten it first)
from mlx.utils import tree_flatten

for name, param in tree_flatten(model.parameters()):
    print(f"{name}: {param.shape}")

Available Layers

Linear

  • nn.Linear
  • nn.Bilinear
  • nn.Embedding

Convolution

  • nn.Conv1d, Conv2d, Conv3d
  • nn.ConvTranspose1d/2d/3d

Normalization

  • nn.LayerNorm
  • nn.BatchNorm
  • nn.GroupNorm
  • nn.RMSNorm

Recurrent

  • nn.RNN
  • nn.LSTM
  • nn.GRU

Attention

  • nn.MultiHeadAttention
  • nn.Transformer
  • nn.TransformerEncoder

Regularization

  • nn.Dropout
  • nn.Dropout2d
  • nn.Dropout3d
🔍

Attention & Transformers

Efficient implementations

MultiHeadAttention

import mlx.nn as nn
import mlx.core as mx

# Create attention layer
attn = nn.MultiHeadAttention(
    dims=512,           # Model dimension
    num_heads=8,        # Number of attention heads
    query_input_dims=512,
    key_input_dims=512,
    value_input_dims=512,
    value_dims=512,
    value_output_dims=512,
    bias=False
)

# Forward pass (self-attention: queries, keys, and values from the same sequence)
x = mx.random.normal((2, 10, 512))  # (batch, seq_len, dim)

# Optional: additive causal mask for autoregressive decoding
mask = nn.MultiHeadAttention.create_additive_causal_mask(10)

output = attn(x, x, x, mask=mask)
print(output.shape)  # (2, 10, 512)

Fused Attention (mx.fast)

MLX provides optimized fused kernels for common patterns:

import math
import mlx.core as mx

head_dim = 64  # per-head dimension (illustrative)

# Scaled dot-product attention - single fused kernel
# Inputs must already be split into heads
output = mx.fast.scaled_dot_product_attention(
    queries,    # (batch, heads, seq_q, head_dim)
    keys,       # (batch, heads, seq_k, head_dim)
    values,     # (batch, heads, seq_k, head_dim)
    scale=1.0 / math.sqrt(head_dim),
    mask=None   # Optional attention mask
)

# RoPE (Rotary Position Embedding)
output = mx.fast.rope(
    x,              # Input tensor
    dims=64,        # Dimensions to rotate
    traditional=False,
    base=10000.0,
    scale=1.0,
    offset=0
)

# RMS Normalization
output = mx.fast.rms_norm(x, weight, eps=1e-5)

# Layer Normalization
output = mx.fast.layer_norm(x, weight, bias, eps=1e-5)

Full Transformer Block

class TransformerBlock(nn.Module):
    def __init__(self, dims, num_heads, mlp_dims, dropout=0.0):
        super().__init__()
        self.attention = nn.MultiHeadAttention(dims, num_heads)
        self.norm1 = nn.LayerNorm(dims)
        self.norm2 = nn.LayerNorm(dims)
        self.mlp = nn.Sequential(
            nn.Linear(dims, mlp_dims),
            nn.GELU(),
            nn.Linear(mlp_dims, dims),
        )
        self.dropout = nn.Dropout(dropout)

    def __call__(self, x, mask=None):
        # Self-attention with residual
        h = self.norm1(x)
        h = self.attention(h, h, h, mask=mask)
        x = x + self.dropout(h)

        # FFN with residual
        h = self.norm2(x)
        h = self.mlp(h)
        x = x + self.dropout(h)

        return x
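
A quick shape check of the block above (sizes are illustrative; assumes the imports from the earlier examples):

block = TransformerBlock(dims=512, num_heads=8, mlp_dims=2048)

x = mx.random.normal((2, 16, 512))   # (batch, seq_len, dims)
mask = nn.MultiHeadAttention.create_additive_causal_mask(16)

out = block(x, mask=mask)
mx.eval(out)
print(out.shape)  # (2, 16, 512)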
🏋️

Training

Optimizers and training loops

Available Optimizers

import mlx.optimizers as optim

# SGD with momentum
optimizer = optim.SGD(learning_rate=0.01, momentum=0.9, nesterov=True)

# Adam
optimizer = optim.Adam(learning_rate=0.001, betas=(0.9, 0.999), eps=1e-8)

# AdamW (Adam with decoupled weight decay)
optimizer = optim.AdamW(learning_rate=0.001, weight_decay=0.01)

# Others: RMSprop, Adagrad, AdaDelta, Lion, Adafactor

Learning Rate Schedulers

import mlx.optimizers as optim

# Exponential decay
lr_schedule = optim.exponential_decay(init=1e-3, decay_rate=0.99)

# Step decay
lr_schedule = optim.step_decay(init=1e-3, decay_rate=0.5, step_size=1000)

# Cosine annealing
lr_schedule = optim.cosine_decay(init=1e-3, decay_steps=10000, end=1e-5)

# Linear warmup + cosine decay
lr_schedule = optim.join_schedules(
    schedules=[
        optim.linear_schedule(init=0, end=1e-3, steps=100),
        optim.cosine_decay(init=1e-3, decay_steps=9900),
    ],
    boundaries=[100]
)

# Use with optimizer
optimizer = optim.Adam(learning_rate=lr_schedule)

Complete Training Loop

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Model
model = MLP(784, 256, 10)

# Optimizer
optimizer = optim.AdamW(learning_rate=1e-3, weight_decay=0.01)

# Loss function
def loss_fn(model, x, y):
    logits = model(x)
    return nn.losses.cross_entropy(logits, y).mean()

# Training step - use nn.value_and_grad for models and capture the
# model/optimizer state so the whole step can be compiled
from functools import partial

state = [model.state, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def train_step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss

# Training loop
for epoch in range(num_epochs):
    model.train()  # Enable dropout

    for batch_x, batch_y in data_loader:
        # Forward + backward + parameter update
        loss = train_step(batch_x, batch_y)

        # Evaluate to materialize the updated parameters and free the graph
        mx.eval(state)

    # Validation
    model.eval()  # Disable dropout
    val_loss = evaluate(model, val_data)

    print(f"Epoch {epoch}: train_loss={loss.item():.4f}, val_loss={val_loss:.4f}")

Gradient Checkpointing (Memory Optimization)

# Trade compute for memory by recomputing activations during backward
checkpointed_model = nn.checkpoint(model)

def loss_fn(model, x, y):
    # Intermediate activations not stored, recomputed during backward
    logits = checkpointed_model(x)
    return nn.losses.cross_entropy(logits, y).mean()

# Useful for large models that don't fit in memory

Inference

Loading and running models

Save & Load Weights

# Save weights
model.save_weights("model.safetensors")  # Preferred format
model.save_weights("model.npz")          # NumPy format

# Load weights
model = MLP(784, 256, 10)
model.load_weights("model.safetensors")

# Or load a raw weight dict (e.g. a Hugging Face checkpoint exported to safetensors)
weights = mx.load("model.safetensors")
model.load_weights(list(weights.items()))

Efficient Inference

# Set to eval mode (disables dropout)
model.eval()

# Compile for faster inference
@mx.compile
def predict(model, x):
    return model(x)

# Run inference
x = mx.random.normal((1, 784))
logits = predict(model, x)
mx.eval(logits)

# Get predictions
probs = mx.softmax(logits, axis=-1)
pred_class = mx.argmax(probs, axis=-1).item()
📉

Quantization

Reduce model size and speed up inference

Quantized Layers

import mlx.nn as nn

# Quantized linear layer
quantized_linear = nn.QuantizedLinear(
    input_dims=1024,
    output_dims=1024,
    bias=True,
    group_size=64,  # Quantization group size
    bits=4          # 4-bit quantization
)

# Convert existing model
def quantize_model(model, group_size=64, bits=4):
    """Replace Linear layers with QuantizedLinear"""
    for name, module in model.items():
        if isinstance(module, nn.Linear):
            model[name] = nn.QuantizedLinear.from_linear(
                module, group_size=group_size, bits=bits
            )
        elif isinstance(module, nn.Module):
            quantize_model(module, group_size, bits)
    return model

# Quantize
quantized_model = quantize_model(model)

Quantization Primitives

import mlx.core as mx

# Quantize tensor
quantized, scales, biases = mx.quantize(
    w,                    # Weights to quantize
    group_size=64,        # Elements per quantization group
    bits=4                # Bits per element (2, 4, or 8)
)

# Dequantize
dequantized = mx.dequantize(quantized, scales, biases, group_size, bits)

# Quantized matrix multiply (fused kernel)
output = mx.quantized_matmul(
    x,          # Input (not quantized)
    w_quant,    # Quantized weights
    scales,
    biases,
    transpose=True,
    group_size=64,
    bits=4
)
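
A quick check of how lossy the 4-bit roundtrip is on a random weight matrix (shape is illustrative; the last dimension must be divisible by group_size):

import mlx.core as mx

w = mx.random.normal((256, 256))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
w_dq = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)

print(mx.mean(mx.abs(w - w_dq)).item())  # small but nonzero quantization error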
🔧

Fine-tuning

LoRA and parameter-efficient training

LoRA (Low-Rank Adaptation)

import mlx.core as mx
import mlx.nn as nn

class LoRALinear(nn.Module):
    """Linear layer with LoRA adaptation"""

    def __init__(self, linear, rank=8, alpha=16):
        super().__init__()
        in_dim = linear.weight.shape[1]
        out_dim = linear.weight.shape[0]

        # Freeze original weights
        self.linear = linear
        self.linear.freeze()

        # Low-rank matrices (trainable)
        scale = alpha / rank
        self.lora_a = mx.random.normal((in_dim, rank)) * 0.01
        self.lora_b = mx.zeros((rank, out_dim))
        self.scale = scale

    def __call__(self, x):
        # Original forward + LoRA
        y = self.linear(x)
        lora_out = (x @ self.lora_a @ self.lora_b) * self.scale
        return y + lora_out

# Apply LoRA to model
def apply_lora(model, rank=8, target_modules=["query_proj", "value_proj"]):
    for name, module in model.items():
        if any(t in name for t in target_modules) and isinstance(module, nn.Linear):
            model[name] = LoRALinear(module, rank=rank)
        elif isinstance(module, nn.Module):
            apply_lora(module, rank, target_modules)

# Freeze base model, only train LoRA
model.freeze()
apply_lora(model)

# Now only LoRA parameters are trainable
from mlx.utils import tree_flatten

trainable = tree_flatten(model.trainable_parameters())
print(f"Trainable params: {sum(p.size for _, p in trainable)}")

Selective Freezing

# Freeze entire model
model.freeze()

# Unfreeze specific layers
model.classifier.unfreeze()

# Freeze by key pattern
model.freeze(keys="bias")  # Freeze all biases
model.unfreeze(keys="weight")  # Unfreeze all weights

# Check what's trainable (flatten the nested parameter dict)
from mlx.utils import tree_flatten

for name, param in tree_flatten(model.trainable_parameters()):
    print(f"Training: {name} {param.shape}")
📈

Example: Linear Regression from Scratch

import mlx.core as mx

# Generate data
N = 100
D = 10
X = mx.random.normal((N, D))
true_w = mx.random.normal((D, 1))
y = X @ true_w + mx.random.normal((N, 1)) * 0.1

# Initialize weights
w = mx.random.normal((D, 1))

# Loss function
def loss_fn(w):
    pred = X @ w
    return mx.mean((pred - y) ** 2)

# Get gradient function
grad_fn = mx.grad(loss_fn)

# Training
lr = 0.1
for i in range(100):
    # Compute gradient
    grad = grad_fn(w)

    # Update weights
    w = w - lr * grad

    # Evaluate (trigger computation)
    mx.eval(w)

    if i % 10 == 0:
        loss = loss_fn(w)
        mx.eval(loss)
        print(f"Step {i}: loss = {loss.item():.6f}")

print(f"\nLearned vs True correlation: {mx.corrcoef(w.flatten(), true_w.flatten())[0,1].item():.4f}")
🧠

Example: MNIST Classifier

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from functools import partial

# Load MNIST (using numpy, then convert)
# In practice, use mlx-data or similar
train_images = mx.array(np.load("mnist_train_images.npy") / 255.0)
train_labels = mx.array(np.load("mnist_train_labels.npy"))

# Model
class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.5)

    def __call__(self, x):
        # MLX convolutions are channels-last: (batch, 28, 28) -> (batch, 28, 28, 1)
        x = x[..., None]

        x = self.pool(nn.relu(self.conv1(x)))  # -> (batch, 14, 14, 32)
        x = self.pool(nn.relu(self.conv2(x)))  # -> (batch, 7, 7, 64)
        x = x.reshape(x.shape[0], -1)          # -> (batch, 7*7*64)
        x = self.dropout(nn.relu(self.fc1(x)))
        return self.fc2(x)

model = MNISTClassifier()
optimizer = optim.Adam(learning_rate=1e-3)

def loss_fn(model, images, labels):
    logits = model(images)
    return nn.losses.cross_entropy(logits, labels).mean()

# Compiled training step - capture model/optimizer state as implicit inputs/outputs
state = [model.state, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def train_step(images, labels):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, images, labels)
    optimizer.update(model, grads)
    return loss

# Training
batch_size = 64
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    num_batches = 0

    # Shuffle
    indices = mx.random.permutation(len(train_images))

    for i in range(0, len(train_images), batch_size):
        batch_idx = indices[i:i+batch_size]
        batch_images = train_images[batch_idx]
        batch_labels = train_labels[batch_idx]

        loss = train_step(batch_images, batch_labels)
        mx.eval(state)

        total_loss += loss.item()
        num_batches += 1

    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1}: loss = {avg_loss:.4f}")

# Evaluation (assumes test_images / test_labels are loaded like the training data)
model.eval()
test_logits = model(test_images)
test_preds = mx.argmax(test_logits, axis=-1)
accuracy = mx.mean((test_preds == test_labels).astype(mx.float32)).item()
print(f"Test accuracy: {accuracy:.2%}")
🔬

Debugging & Profiling

Memory Debugging

import mlx.core as mx

# Track memory through training
for step in range(100):
    loss, grads = train_step(model, batch_x, batch_y)
    optimizer.update(model, grads)
    mx.eval(model.parameters())

    # Print memory stats
    active = mx.metal.get_active_memory() / 1e9
    peak = mx.metal.get_peak_memory() / 1e9
    print(f"Step {step}: active={active:.2f}GB, peak={peak:.2f}GB")

    # Reset peak to track per-step
    mx.metal.reset_peak_memory()

Timing

import time

# Synchronous timing
mx.eval(x)  # Ensure previous ops complete
start = time.perf_counter()

y = model(x)
mx.eval(y)  # Wait for completion

end = time.perf_counter()
print(f"Forward pass: {(end - start) * 1000:.2f}ms")

# Or use mx.synchronize()
mx.synchronize()
start = time.perf_counter()
# ... operations ...
mx.synchronize()
end = time.perf_counter()
📁

Key Files Reference

C++ Core

  • mlx/array.h - Array class
  • mlx/primitives.h - Operations
  • mlx/transforms.h - grad, vjp, jvp
  • mlx/compile.h - JIT compilation
  • mlx/scheduler.h - Execution
  • mlx/backend/metal/ - GPU kernels

Python API

  • python/mlx/nn/layers/ - NN layers
  • python/mlx/nn/losses.py - Loss funcs
  • python/mlx/optimizers/ - Optimizers
  • python/src/transforms.cpp - Bindings
  • examples/python/ - Examples
💎

Key Takeaways

Unified Memory Advantage

No CPU↔GPU transfers on Apple Silicon. Arrays are accessible by both without copying, at full unified-memory bandwidth (up to 800+ GB/s on Ultra chips).

Lazy by Default

Build computation graphs declaratively. Execute with mx.eval(). Enables automatic optimization and fusion.

Composable Transforms

grad, vmap, compile transform functions. Chain them: compile(vmap(grad(f))). JAX-style functional transforms.

PyTorch-like nn.Module

Familiar Module API for neural networks. Parameters are dict entries. freeze/unfreeze for fine-tuning.

Every Primitive has VJP

Each operation implements its own backward pass. Enables automatic differentiation through any computation.

Quantization Built-in

4-bit, 8-bit quantization with fused kernels. QuantizedLinear for memory-efficient inference.

🔗

Resources