MLX
Array framework for machine learning on Apple Silicon
Created by Apple's ML research team. Unified memory model, lazy evaluation, NumPy-like API, PyTorch-style neural networks. Optimized for M1/M2/M3/M4.
Why MLX?
What makes it different from PyTorch/JAX
- Unified Memory: CPU and GPU share the same physical memory on Apple Silicon, so arrays are never copied between devices.
- Lazy Evaluation: operations build a computation graph; nothing runs until you call mx.eval() or read a value.
- Dynamic Shapes: graphs are rebuilt on every call, so input shapes can change freely without retracing a static graph.
- Hardware Target: designed for Apple Silicon (M1/M2/M3/M4), with Metal GPU, Accelerate CPU, and CUDA (Linux) backends.
Core Architecture
How MLX works under the hood
┌─────────────────────────────────────────────────────┐
│                     Python API                      │
│        mx.array, mx.nn.Module, mx.optimizers        │
├─────────────────────────────────────────────────────┤
│                  Computation Graph                  │
│     array → Primitive → array → Primitive → ...     │
├─────────────────────────────────────────────────────┤
│                     Transforms                      │
│         grad (VJP), jvp, vmap, compile (JIT)        │
├─────────────────────────────────────────────────────┤
│                      Scheduler                      │
│    Streams (CPU threads / Metal command queues)     │
├─────────────────────────────────────────────────────┤
│                      Backends                       │
│    Metal (GPU) │ CPU (Accelerate) │ CUDA (Linux)    │
├─────────────────────────────────────────────────────┤
│           Unified Memory (Apple Silicon)            │
│         No data transfer between CPU ↔ GPU          │
└─────────────────────────────────────────────────────┘
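The Scheduler and Backends rows are visible from Python as devices and streams: every operation accepts an optional stream/device argument, and the results stay in the same unified memory either way. A minimal sketch (illustrative, using the standard mlx.core device APIs):
import mlx.core as mx
a = mx.random.normal((2048, 2048))
b = mx.random.normal((2048, 2048))
c = mx.matmul(a, b, stream=mx.gpu)   # dispatched to a Metal command queue
d = mx.sum(a, stream=mx.cpu)         # dispatched to a CPU thread, same buffers
mx.set_default_device(mx.cpu)        # or switch the default device globally
mx.eval(c, d)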
Array System
The foundation of MLX
Array as Graph Node
Every array holds a reference to the operation (Primitive) that created it and its input arrays. This forms a directed acyclic graph (DAG) representing the computation.
// C++ Array Structure (mlx/array.h)
class array {
struct ArrayDesc {
Shape shape; // Dimensions
Strides strides; // Memory layout
Dtype dtype; // float32, int32, etc.
std::shared_ptr<Primitive> primitive; // Op that produced this
std::vector<array> inputs; // Input arrays
Status status; // unscheduled → evaluated → available
std::shared_ptr<Data> data; // The actual buffer
};
};
Array States & Lazy Evaluation
import mlx.core as mx
# 1. Create arrays - no computation yet
x = mx.random.normal((1000, 1000)) # status: unscheduled
y = mx.random.normal((1000, 1000)) # status: unscheduled
# 2. Build computation graph - still no computation
z = x @ y # Matrix multiply - status: unscheduled
w = mx.relu(z) # ReLU activation - status: unscheduled
loss = mx.mean(w) # Mean reduction - status: unscheduled
# 3. Trigger evaluation - NOW computation happens
mx.eval(loss) # All dependencies computed, loss.status → available
# Async evaluation (non-blocking)
mx.async_eval(loss) # Returns immediately
# ... do other work ...
# Accessing loss.item() blocks until the result is ready
Array states:
- unscheduled: array created but computation not yet scheduled
- evaluated: scheduled for execution, may be in progress
- available: computation complete, data ready to read
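Evaluation also happens implicitly whenever a value has to leave the graph; a small sketch of the usual triggers (a minimal illustration, not an exhaustive list):
import mlx.core as mx
import numpy as np
x = mx.random.normal((3, 3))
y = x @ x
# Each of these forces y (and its dependencies) to be computed:
print(y)             # printing an array
s = y.sum().item()   # pulling a scalar into Python
n = np.array(y)      # converting to a NumPy array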
Primitives
Operations as graph nodes
Primitive Interface
Every operation (add, matmul, conv, etc.) is a Primitive class implementing forward, backward, and vectorization.
// mlx/primitives.h
class Primitive {
// Forward pass implementations
virtual void eval_cpu(const std::vector<array>& inputs,
std::vector<array>& outputs) = 0;
virtual void eval_gpu(const std::vector<array>& inputs,
std::vector<array>& outputs) = 0;
// Automatic differentiation
virtual std::vector<array> vjp( // Reverse-mode (backprop)
const std::vector<array>& primals,
const std::vector<array>& cotangents,
const std::vector<int>& argnums,
const std::vector<array>& outputs);
virtual std::vector<array> jvp( // Forward-mode
const std::vector<array>& primals,
const std::vector<array>& tangents,
const std::vector<int>& argnums);
// Vectorization (batching)
virtual std::pair<std::vector<array>, std::vector<int>> vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes);
};
Key Primitives
Arithmetic
- Add, Subtract, Multiply, Divide
- Power, Exp, Log
- MatMul, Addmm
Activation
- ReLU, Sigmoid, Tanh
- GELU, SiLU, Softmax
- LayerNorm, RMSNorm
Reduction
- Sum, Mean, Prod
- Max, Min, ArgMax
- All, Any
Automatic Differentiation
Computing gradients
Two Modes
- VJP (Reverse-mode): Backpropagation. Efficient when outputs < inputs. Used for training.
- JVP (Forward-mode): Computes Jacobian-vector products. Efficient when inputs < outputs.
Basic Gradient Computation
import mlx.core as mx
# Define a function
def loss_fn(w, x, y):
pred = x @ w
return mx.mean((pred - y) ** 2)
# Create gradient function (VJP)
grad_fn = mx.grad(loss_fn) # Differentiates w.r.t. first argument
# Or get both value and gradient
value_and_grad_fn = mx.value_and_grad(loss_fn)
# Compute
w = mx.random.normal((10, 1))
x = mx.random.normal((100, 10))
y = mx.random.normal((100, 1))
loss, grad_w = value_and_grad_fn(w, x, y)
mx.eval(loss, grad_w)
print(f"Loss: {loss.item()}")
print(f"Gradient shape: {grad_w.shape}")Gradient w.r.t. Multiple Arguments
# Differentiate w.r.t. multiple arguments
grad_fn = mx.grad(loss_fn, argnums=[0, 1]) # Grad w.r.t. w and x
grad_w, grad_x = grad_fn(w, x, y)
# Higher-order derivatives
def f(x):
return mx.sum(x ** 3)
df = mx.grad(f) # First derivative: 3x^2
ddf = mx.grad(df) # Second derivative: 6x
x = mx.array([1.0, 2.0, 3.0])
print(f"f(x) = {f(x).item()}") # 1 + 8 + 27 = 36
print(f"f'(x) = {df(x)}") # [3, 12, 27]
print(f"f''(x) = {ddf(x)}") # [6, 12, 18]How Each Primitive Implements VJP
// Example: Multiply primitive VJP (mlx/primitives.cpp)
std::vector<array> Multiply::vjp(
const std::vector<array>& primals, // [a, b]
const std::vector<array>& cotangents, // [∂L/∂(a*b)]
const std::vector<int>& argnums,
const std::vector<array>& outputs) {
// Chain rule: ∂L/∂a = ∂L/∂(a*b) * b
// ∂L/∂b = ∂L/∂(a*b) * a
std::vector<array> grads;
for (int argnum : argnums) {
if (argnum == 0) {
grads.push_back(cotangents[0] * primals[1]); // ∂L/∂a
} else {
grads.push_back(cotangents[0] * primals[0]); // ∂L/∂b
}
}
return grads;
}
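The same rule can be checked from Python with mx.vjp, which takes explicit primals and cotangents and returns the pulled-back gradients. A small sketch mirroring the Multiply VJP above:
import mlx.core as mx
a = mx.array([2.0, 3.0])
b = mx.array([4.0, 5.0])
def f(a, b):
    return a * b
# A cotangent of ones stands in for ∂L/∂(a*b)
outputs, grads = mx.vjp(f, [a, b], [mx.ones_like(a)])
print(grads[0])  # ∂L/∂a = cotangent * b = [4., 5.]
print(grads[1])  # ∂L/∂b = cotangent * a = [2., 3.]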
Transforms
Function transformations
mx.grad
Transform function to compute gradients via reverse-mode autodiff.
def f(x):
return mx.sum(x ** 2)
grad_f = mx.grad(f)
x = mx.array([1., 2., 3.])
print(grad_f(x)) # [2., 4., 6.]
mx.vmap
Vectorize a function over a batch dimension automatically.
def single_example(x, w):
return x @ w
# Vectorize over batch dimension
batched_fn = mx.vmap(single_example, in_axes=(0, None))
x_batch = mx.random.normal((32, 10)) # 32 examples
w = mx.random.normal((10, 5))
out = batched_fn(x_batch, w) # (32, 5)
mx.compile
JIT compile for kernel fusion and optimization.
@mx.compile
def fused_ops(x):
# These ops get fused into one kernel
x = mx.relu(x)
x = x * 2
x = x + 1
return x
# First call compiles, subsequent calls reuse
out = fused_ops(mx.random.normal((1000,)))
mx.vjp / mx.jvp
Low-level access to vector-Jacobian and Jacobian-vector products.
def f(x):
return mx.stack([mx.sum(x**2), mx.sum(x**3)])
x = mx.array([1., 2., 3.])
v = mx.array([1., 0.]) # Select first output
# VJP: compute ∂f/∂x weighted by v
outputs, grads = mx.vjp(f, [x], [v])
# JVP: directional derivative along a tangent
outputs, jvps = mx.jvp(f, [x], [mx.ones_like(x)])
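Transforms compose, which is the usual way to get per-example gradients: vmap a grad-transformed function over the batch. A short sketch (the loss and shapes are illustrative, assuming vmap's in_axes accepts None for unmapped arguments as used above):
import mlx.core as mx
def loss_fn(w, x, y):
    return mx.mean((x @ w - y) ** 2)
# Gradient w.r.t. w for one example, vectorized over a batch of (x, y) pairs
per_example_grad = mx.vmap(mx.grad(loss_fn), in_axes=(None, 0, 0))
w = mx.random.normal((10, 1))
x = mx.random.normal((32, 5, 10))   # 32 examples of shape (5, 10)
y = mx.random.normal((32, 5, 1))
grads = per_example_grad(w, x, y)   # (32, 10, 1): one gradient per example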
Memory Management
Unified memory on Apple Silicon
Why Unified Memory Matters
Traditional (CUDA)
# PyTorch on NVIDIA
x_cpu = torch.randn(1000, 1000)
x_gpu = x_cpu.to('cuda') # PCIe transfer!
y_gpu = model(x_gpu)
y_cpu = y_gpu.to('cpu') # PCIe transfer!
# Each transfer: ~12 GB/s over PCIe
# Latency + bandwidth bottleneck
MLX (Apple Silicon)
# MLX on M1/M2/M3/M4
x = mx.random.normal((1000, 1000))
y = model(x)
# No transfer needed!
# CPU and GPU access same memory
# ~200+ GB/s memory bandwidth
Memory Tracking & Limits
import mlx.core as mx
# Check memory usage
print(f"Active: {mx.metal.get_active_memory() / 1e9:.2f} GB")
print(f"Peak: {mx.metal.get_peak_memory() / 1e9:.2f} GB")
print(f"Cache: {mx.metal.get_cache_memory() / 1e9:.2f} GB")
# Set limits
mx.metal.set_memory_limit(8 * 1024**3) # 8 GB max
mx.metal.set_cache_limit(2 * 1024**3) # 2 GB cache
# Clear cache
mx.metal.clear_cache()
# Reset peak tracking
mx.metal.reset_peak_memory()
Neural Networks (mlx.nn)
PyTorch-style module system
Module Base Class
Modules are dictionaries containing parameters (arrays) and submodules. This enables easy parameter extraction and updates.
import mlx.nn as nn
import mlx.core as mx
class Module(dict):
"""Base class - a dict of params and submodules"""
def parameters(self):
"""Get all trainable parameters recursively"""
def trainable_parameters(self):
"""Get non-frozen parameters"""
def freeze(self, keys=None, recurse=True):
"""Freeze parameters (no gradient computation)"""
def unfreeze(self, keys=None, recurse=True):
"""Unfreeze parameters"""
def update(self, parameters):
"""Replace parameters (used by optimizer)"""
def train(self, mode=True):
"""Set training mode (affects Dropout, BatchNorm)"""
def eval(self):
"""Set evaluation mode"""Defining a Model
Defining a Model
import mlx.nn as nn
import mlx.core as mx
class MLP(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim):
super().__init__()
self.fc1 = nn.Linear(in_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, out_dim)
self.dropout = nn.Dropout(p=0.1)
def __call__(self, x):
x = nn.relu(self.fc1(x))
x = self.dropout(x)
x = nn.relu(self.fc2(x))
x = self.dropout(x)
return self.fc3(x)
model = MLP(784, 256, 10)
# Inspect parameters
from mlx.utils import tree_flatten
for name, param in tree_flatten(model.parameters()):
    print(f"{name}: {param.shape}")
Available Layers
Linear
- nn.Linear
- nn.Bilinear
- nn.Embedding
Convolution
- nn.Conv1d, Conv2d, Conv3d
- nn.ConvTranspose1d/2d/3d
Normalization
- nn.LayerNorm
- nn.BatchNorm
- nn.GroupNorm
- nn.RMSNorm
Recurrent
- nn.RNN
- nn.LSTM
- nn.GRU
Attention
- nn.MultiHeadAttention
- nn.Transformer
- nn.TransformerEncoder
Regularization
- nn.Dropout
- nn.Dropout2d
- nn.Dropout3d
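These layers compose exactly like the examples above; a minimal sketch stacking a few of them with nn.Sequential (the sizes are arbitrary):
import mlx.core as mx
import mlx.nn as nn
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.GELU(),
    nn.LayerNorm(256),
    nn.Dropout(0.1),
    nn.Linear(256, 10),
)
x = mx.random.normal((32, 784))
logits = model(x)   # (32, 10)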
Attention & Transformers
Efficient implementations
MultiHeadAttention
import mlx.nn as nn
import mlx.core as mx
# Create attention layer
attn = nn.MultiHeadAttention(
dims=512, # Model dimension
num_heads=8, # Number of attention heads
query_input_dims=512,
key_input_dims=512,
value_input_dims=512,
value_dims=512,
value_output_dims=512,
bias=False
)
# Forward pass (cross-attention: query length 10, key/value length 20)
queries = mx.random.normal((2, 10, 512)) # (batch, seq_len, dim)
keys = mx.random.normal((2, 20, 512))
values = mx.random.normal((2, 20, 512))
output = attn(queries, keys, values)
print(output.shape) # (2, 10, 512)
# Causal mask for autoregressive self-attention (query/key lengths must match)
x = mx.random.normal((2, 10, 512))
mask = nn.MultiHeadAttention.create_additive_causal_mask(10)
output = attn(x, x, x, mask=mask)
Fused Attention (mx.fast)
MLX provides optimized fused kernels for common patterns:
import math
import mlx.core as mx
batch, num_heads, seq_len, head_dim = 2, 8, 128, 64
queries = mx.random.normal((batch, num_heads, seq_len, head_dim))
keys = mx.random.normal((batch, num_heads, seq_len, head_dim))
values = mx.random.normal((batch, num_heads, seq_len, head_dim))
# Scaled dot-product attention - single fused kernel
output = mx.fast.scaled_dot_product_attention(
    queries, # (batch, heads, seq_q, head_dim)
    keys,    # (batch, heads, seq_k, head_dim)
    values,  # (batch, heads, seq_k, head_dim)
    scale=1.0 / math.sqrt(head_dim),
    mask=None # Optional attention mask
)
# RoPE (Rotary Position Embedding)
output = mx.fast.rope(
x, # Input tensor
dims=64, # Dimensions to rotate
traditional=False,
base=10000.0,
scale=1.0,
offset=0
)
# RMS Normalization
output = mx.fast.rms_norm(x, weight, eps=1e-5)
# Layer Normalization
output = mx.fast.layer_norm(x, weight, bias, eps=1e-5)
Full Transformer Block
class TransformerBlock(nn.Module):
def __init__(self, dims, num_heads, mlp_dims, dropout=0.0):
super().__init__()
self.attention = nn.MultiHeadAttention(dims, num_heads)
self.norm1 = nn.LayerNorm(dims)
self.norm2 = nn.LayerNorm(dims)
self.mlp = nn.Sequential(
nn.Linear(dims, mlp_dims),
nn.GELU(),
nn.Linear(mlp_dims, dims),
)
self.dropout = nn.Dropout(dropout)
def __call__(self, x, mask=None):
# Self-attention with residual
h = self.norm1(x)
h = self.attention(h, h, h, mask=mask)
x = x + self.dropout(h)
# FFN with residual
h = self.norm2(x)
h = self.mlp(h)
x = x + self.dropout(h)
return x
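A quick usage sketch of the block above with a causal mask (shapes are illustrative):
block = TransformerBlock(dims=512, num_heads=8, mlp_dims=2048)
x = mx.random.normal((2, 16, 512))                             # (batch, seq, dims)
mask = nn.MultiHeadAttention.create_additive_causal_mask(16)   # (seq, seq)
y = block(x, mask=mask)                                        # (2, 16, 512)
mx.eval(y)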
Training
Optimizers and training loops
Available Optimizers
import mlx.optimizers as optim
# SGD with momentum
optimizer = optim.SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
# Adam
optimizer = optim.Adam(learning_rate=0.001, betas=(0.9, 0.999), eps=1e-8)
# AdamW (Adam with decoupled weight decay)
optimizer = optim.AdamW(learning_rate=0.001, weight_decay=0.01)
# Others: RMSprop, Adagrad, AdaDelta, Lion, Adafactor
Learning Rate Schedulers
import mlx.optimizers as optim
# Exponential decay
lr_schedule = optim.exponential_decay(init=1e-3, decay_rate=0.99)
# Step decay
lr_schedule = optim.step_decay(init=1e-3, decay_rate=0.5, step_size=1000)
# Cosine annealing
lr_schedule = optim.cosine_decay(init=1e-3, decay_steps=10000, end=1e-5)
# Linear warmup + cosine decay
lr_schedule = optim.join_schedules(
schedules=[
optim.linear_schedule(init=0, end=1e-3, steps=100),
optim.cosine_decay(init=1e-3, decay_steps=9900),
],
boundaries=[100]
)
# Use with optimizer
optimizer = optim.Adam(learning_rate=lr_schedule)
Complete Training Loop
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
# Model
model = MLP(784, 256, 10)
# Optimizer
optimizer = optim.AdamW(learning_rate=1e-3, weight_decay=0.01)
# Loss function
def loss_fn(model, x, y):
logits = model(x)
return nn.losses.cross_entropy(logits, y).mean()
# Training step - use nn.value_and_grad for models
@mx.compile
def train_step(model, x, y):
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
return loss, grads
# Training loop
for epoch in range(num_epochs):
model.train() # Enable dropout
for batch_x, batch_y in data_loader:
# Forward + backward
loss, grads = train_step(model, batch_x, batch_y)
# Update parameters
optimizer.update(model, grads)
# Evaluate to release memory
mx.eval(model.parameters())
# Validation
model.eval() # Disable dropout
val_loss = evaluate(model, val_data)
print(f"Epoch {epoch}: train_loss={loss.item():.4f}, val_loss={val_loss:.4f}")Gradient Checkpointing (Memory Optimization)
Gradient Checkpointing (Memory Optimization)
# Trade compute for memory by recomputing activations during backward
checkpointed_model = nn.checkpoint(model)
def loss_fn(model, x, y):
# Intermediate activations not stored, recomputed during backward
logits = checkpointed_model(x)
return nn.losses.cross_entropy(logits, y).mean()
# Useful for large models that don't fit in memory
Inference
Loading and running models
Save & Load Weights
# Save weights
model.save_weights("model.safetensors") # Preferred format
model.save_weights("model.npz") # NumPy format
# Load weights
model = MLP(784, 256, 10)
model.load_weights("model.safetensors")
# Or load a raw weights dict and pass (name, array) pairs
weights = mx.load("model.safetensors") # mx.load reads .safetensors, .npz, and .gguf
model.load_weights(list(weights.items()))
Efficient Inference
# Set to eval mode (disables dropout)
model.eval()
# Compile for faster inference
@mx.compile
def predict(model, x):
return model(x)
# Run inference
x = mx.random.normal((1, 784))
logits = predict(model, x)
mx.eval(logits)
# Get predictions
probs = mx.softmax(logits, axis=-1)
pred_class = mx.argmax(probs, axis=-1).item()
Quantization
Reduce model size and speed up inference
Quantized Layers
import mlx.nn as nn
# Quantized linear layer
quantized_linear = nn.QuantizedLinear(
input_dims=1024,
output_dims=1024,
bias=True,
group_size=64, # Quantization group size
bits=4 # 4-bit quantization
)
# Convert existing model
def quantize_model(model, group_size=64, bits=4):
"""Replace Linear layers with QuantizedLinear"""
for name, module in model.items():
if isinstance(module, nn.Linear):
model[name] = nn.QuantizedLinear.from_linear(
module, group_size=group_size, bits=bits
)
elif isinstance(module, nn.Module):
quantize_model(module, group_size, bits)
return model
# Quantize
quantized_model = quantize_model(model)
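MLX also ships a helper that performs this replacement for you; a one-line sketch, assuming the nn.quantize utility available in recent releases:
# Quantize all compatible layers (e.g. Linear, Embedding) in place
nn.quantize(model, group_size=64, bits=4)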
Quantization Primitives
import mlx.core as mx
# Quantize tensor
quantized, scales, biases = mx.quantize(
w, # Weights to quantize
group_size=64, # Elements per quantization group
bits=4 # Bits per element (2, 4, or 8)
)
# Dequantize
dequantized = mx.dequantize(quantized, scales, biases, group_size, bits)
# Quantized matrix multiply (fused kernel)
output = mx.quantized_matmul(
x, # Input (not quantized)
w_quant, # Quantized weights
scales,
biases,
transpose=True,
group_size=64,
bits=4
)
Fine-tuning
LoRA and parameter-efficient training
LoRA (Low-Rank Adaptation)
import mlx.core as mx
import mlx.nn as nn
class LoRALinear(nn.Module):
"""Linear layer with LoRA adaptation"""
def __init__(self, linear, rank=8, alpha=16):
super().__init__()
in_dim = linear.weight.shape[1]
out_dim = linear.weight.shape[0]
# Freeze original weights
self.linear = linear
self.linear.freeze()
# Low-rank matrices (trainable)
scale = alpha / rank
self.lora_a = mx.random.normal((in_dim, rank)) * 0.01
self.lora_b = mx.zeros((rank, out_dim))
self.scale = scale
def __call__(self, x):
# Original forward + LoRA
y = self.linear(x)
lora_out = (x @ self.lora_a @ self.lora_b) * self.scale
return y + lora_out
# Apply LoRA to model
def apply_lora(model, rank=8, target_modules=["query_proj", "value_proj"]):
for name, module in model.items():
if any(t in name for t in target_modules) and isinstance(module, nn.Linear):
model[name] = LoRALinear(module, rank=rank)
elif isinstance(module, nn.Module):
apply_lora(module, rank, target_modules)
# Freeze base model, only train LoRA
model.freeze()
apply_lora(model)
# Now only LoRA parameters are trainable
from mlx.utils import tree_flatten
trainable = tree_flatten(model.trainable_parameters())
print(f"Trainable params: {sum(p.size for _, p in trainable)}")
Selective Freezing
# Freeze entire model
model.freeze()
# Unfreeze specific layers
model.classifier.unfreeze()
# Freeze by key pattern
model.freeze(keys="bias") # Freeze all biases
model.unfreeze(keys="weight") # Unfreeze all weights
# Check what's trainable
from mlx.utils import tree_flatten
for name, param in tree_flatten(model.trainable_parameters()):
    print(f"Training: {name} {param.shape}")
Example: Linear Regression from Scratch
import mlx.core as mx
# Generate data
N = 100
D = 10
X = mx.random.normal((N, D))
true_w = mx.random.normal((D, 1))
y = X @ true_w + mx.random.normal((N, 1)) * 0.1
# Initialize weights
w = mx.random.normal((D, 1))
# Loss function
def loss_fn(w):
pred = X @ w
return mx.mean((pred - y) ** 2)
# Get gradient function
grad_fn = mx.grad(loss_fn)
# Training
lr = 0.1
for i in range(100):
# Compute gradient
grad = grad_fn(w)
# Update weights
w = w - lr * grad
# Evaluate (trigger computation)
mx.eval(w)
if i % 10 == 0:
loss = loss_fn(w)
mx.eval(loss)
print(f"Step {i}: loss = {loss.item():.6f}")
print(f"\nLearned vs True correlation: {mx.corrcoef(w.flatten(), true_w.flatten())[0,1].item():.4f}")Example: MNIST Classifier
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
# Load MNIST (using numpy, then convert)
# In practice, use mlx-data or similar
train_images = mx.array(np.load("mnist_train_images.npy") / 255.0)
train_labels = mx.array(np.load("mnist_train_labels.npy"))
# Model
class MNISTClassifier(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10)
self.dropout = nn.Dropout(0.5)
def __call__(self, x):
    # MLX convs are channels-last: (batch, 28, 28) -> (batch, 28, 28, 1)
    x = x[..., None]
    x = self.pool(nn.relu(self.conv1(x))) # -> (batch, 14, 14, 32)
    x = self.pool(nn.relu(self.conv2(x))) # -> (batch, 7, 7, 64)
    x = x.reshape(x.shape[0], -1) # -> (batch, 7*7*64)
    x = self.dropout(nn.relu(self.fc1(x)))
    return self.fc2(x)
model = MNISTClassifier()
optimizer = optim.Adam(learning_rate=1e-3)
def loss_fn(model, images, labels):
logits = model(images)
return nn.losses.cross_entropy(logits, labels).mean()
# Compiled training step
@mx.compile
def train_step(model, images, labels):
loss, grads = nn.value_and_grad(model, loss_fn)(model, images, labels)
return loss, grads
# Training
batch_size = 64
num_epochs = 10
for epoch in range(num_epochs):
model.train()
total_loss = 0
num_batches = 0
# Shuffle
indices = mx.random.permutation(len(train_images))
for i in range(0, len(train_images), batch_size):
batch_idx = indices[i:i+batch_size]
batch_images = train_images[batch_idx]
batch_labels = train_labels[batch_idx]
loss, grads = train_step(model, batch_images, batch_labels)
optimizer.update(model, grads)
mx.eval(model.parameters())
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / num_batches
print(f"Epoch {epoch+1}: loss = {avg_loss:.4f}")
# Evaluation
model.eval()
test_logits = model(test_images)
test_preds = mx.argmax(test_logits, axis=-1)
accuracy = mx.mean(test_preds == test_labels).item()
print(f"Test accuracy: {accuracy:.2%}")Debugging & Profiling
Memory Debugging
import mlx.core as mx
# Track memory through training
for step in range(100):
loss, grads = train_step(model, batch_x, batch_y)
optimizer.update(model, grads)
mx.eval(model.parameters())
# Print memory stats
active = mx.metal.get_active_memory() / 1e9
peak = mx.metal.get_peak_memory() / 1e9
print(f"Step {step}: active={active:.2f}GB, peak={peak:.2f}GB")
# Reset peak to track per-step
mx.metal.reset_peak_memory()
Timing
import time
# Synchronous timing
mx.eval(x) # Ensure previous ops complete
start = time.perf_counter()
y = model(x)
mx.eval(y) # Wait for completion
end = time.perf_counter()
print(f"Forward pass: {(end - start) * 1000:.2f}ms")
# Or use mx.synchronize()
mx.synchronize()
start = time.perf_counter()
# ... operations ...
mx.synchronize()
end = time.perf_counter()
Key Files Reference
C++ Core
- mlx/array.h - Array class
- mlx/primitives.h - Operations
- mlx/transforms.h - grad, vjp, jvp
- mlx/compile.h - JIT compilation
- mlx/scheduler.h - Execution
- mlx/backend/metal/ - GPU kernels
Python API
- python/mlx/nn/layers/ - NN layers
- python/mlx/nn/losses.py - Loss funcs
- python/mlx/optimizers/ - Optimizers
- python/src/transforms.cpp - Bindings
- examples/python/ - Examples
Key Takeaways
Unified Memory Advantage
No CPU↔GPU transfers on Apple Silicon. Arrays accessible by both without copying. ~200+ GB/s memory bandwidth.
Lazy by Default
Build computation graphs declaratively. Execute with mx.eval(). Enables automatic optimization and fusion.
Composable Transforms
grad, vmap, compile transform functions. Chain them: compile(vmap(grad(f))). JAX-style functional transforms.
PyTorch-like nn.Module
Familiar Module API for neural networks. Parameters are dict entries. freeze/unfreeze for fine-tuning.
Every Primitive has VJP
Each operation implements its own backward pass. Enables automatic differentiation through any computation.
Quantization Built-in
4-bit, 8-bit quantization with fused kernels. QuantizedLinear for memory-efficient inference.