JAX Generative Models
Minimal, unified JAX implementation of DDPM and Flow Matching
Learn generative models from scratch. 2D point cloud datasets, Equinox neural networks, JIT-compiled training, Rerun visualization. Unified interface for both diffusion and flow.
Architecture
Unified interface for generative models
┌────────────────────────────────────────────────┐
│                   CLI (Tyro)                   │
│           train | generate | animate           │
├────────────────────────────────────────────────┤
│               Strategy Interface               │
│  forward() | reverse() | loss_fn() | sample()  │
├────────────────────────┬───────────────────────┤
│          DDPM          │     Flow Matching     │
│    Noise prediction    │    Velocity field     │
│   Stochastic reverse   │   Deterministic ODE   │
├────────────────────────┴───────────────────────┤
│                Neural Networks                 │
│    MLP | ResNet + Sinusoidal Time Embedding    │
├────────────────────────────────────────────────┤
│                    Datasets                    │
│   Cat | Moon | Swiss-Roll | Gaussian Mixture   │
├────────────────────────────────────────────────┤
│                 JAX Primitives                 │
│  jit | vmap | grad | lax.scan | random.split   │
└────────────────────────────────────────────────┘
Unified Time Convention
Both strategies use: t=0 (source/noise) → t=1 (target/data). This allows swapping strategies without changing training code.
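Because the training step only ever calls strategy.loss_fn(model, x, key), switching between diffusion and flow is a one-line change. A minimal sketch (assuming the DDPMStrategy / FlowMatchingStrategy classes and configs defined below, with model, x, and key already in scope):
# Pick either strategy; nothing downstream changes
strategy = DDPMStrategy(DDPMStrategyConfig())
# strategy = FlowMatchingStrategy(FlowMatchingStrategyConfig())
# Identical call for both: scalar loss for a single data point
loss = strategy.loss_fn(model, x, key)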
Strategy Interface
Common protocol for all generative models
base.py - Strategy Protocol
from typing import Protocol
import jax
import jax.numpy as jnp
import equinox as eqx
class Strategy(Protocol):
"""Unified interface for generative models."""
def loss_fn(
self,
model: eqx.Module,
x: jax.Array, # Single data point
key: jax.Array, # PRNG key
) -> jax.Array:
"""Compute loss for training. Returns scalar."""
...
def forward(
self,
t: float, # Time in [0, 1]
x: jax.Array, # Target data point
key: jax.Array,
) -> tuple[jax.Array, jax.Array]:
"""Perturb data for training. Returns (x_t, target)."""
...
def reverse(
self,
model: eqx.Module,
t: float, # Current time
x_t: jax.Array, # Current state
key: jax.Array,
) -> jax.Array:
"""One reverse step. Returns x_{t+dt} (closer to data)."""
...
def sample_from_source_distribution(
self,
key: jax.Array,
num_samples: int,
data_dim: int,
) -> jax.Array:
"""Sample from source (noise). Returns shape (num_samples, data_dim)."""
...
def sample_from_target_distribution(
self,
model: eqx.Module,
key: jax.Array,
num_samples: int,
data_dim: int,
) -> tuple[jax.Array, jax.Array]:
"""Generate samples. Returns (x_final, trajectory)."""
...
DDPM (Denoising Diffusion)
Learn to predict and remove noise
Mathematical Foundation
Forward Process: Progressively add Gaussian noise over T steps
x_t = √ᾱ_t · x_0 + √(1-ᾱ_t) · ε,  ε ~ N(0, I)
where ᾱ_t = ∏(s=1 to t) (1 - β_s) is the cumulative product of (1 - variance schedule)
Reverse Process: Learn to denoise, predicting the added noise
μ_θ(x_t, t) = (1/√α_t) · (x_t - (β_t/√(1-ᾱ_t)) · ε_θ(x_t, t))
Training Objective: Predict the noise that was added
L = E[ ‖ε - ε_θ(x_t, t)‖² ]
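To make the schedule concrete, here is a small numeric check of the linear β schedule used by the implementation below (a sketch assuming the default beta_min=1e-4, beta_max=0.02, and 50 steps; printed values are approximate):
import jax.numpy as jnp
# Linear variance schedule, matching the DDPMStrategyConfig defaults below
num_steps = 50
betas = jnp.linspace(1e-4, 0.02, num_steps)
alphas_cumprod = jnp.cumprod(1.0 - betas)  # ᾱ_t for t = 1..50
print(float(alphas_cumprod[0]))   # ≈ 0.9999 -> x_t is still almost the data
print(float(alphas_cumprod[-1]))  # ≈ 0.60   -> heavily, though not fully, noised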
DDPM Implementation
# ddpm.py - Key implementation details
@dataclass
class DDPMStrategyConfig:
name: Literal["ddpm"] = "ddpm"
num_transport_steps: int = 50 # Number of diffusion steps
beta_min: float = 1e-4 # Min noise variance
beta_max: float = 0.02 # Max noise variance
class DDPMStrategy:
def __init__(self, config: DDPMStrategyConfig):
self.num_steps = config.num_transport_steps
# Linear beta schedule
self.betas = jnp.linspace(config.beta_min, config.beta_max, self.num_steps)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = jnp.cumprod(self.alphas) # αΎ±_t
def forward(self, t: float, x: jax.Array, key: jax.Array):
"""Add noise to data point at time t."""
# Map t ∈ [0,1] to step index
# t=0 → step T-1 (most noise), t=1 → step 0 (data)
idx = jnp.clip(
jnp.floor((1.0 - t) * self.num_steps).astype(jnp.int32),
0, self.num_steps - 1
)
alpha_bar = self.alphas_cumprod[idx]
noise = jax.random.normal(key, shape=x.shape)
# x_t = √ᾱ_t · x + √(1-ᾱ_t) · ε
x_t = jnp.sqrt(alpha_bar) * x + jnp.sqrt(1 - alpha_bar) * noise
return x_t, noise # Return perturbed state and target (noise)
def loss_fn(self, model: eqx.Module, x: jax.Array, key: jax.Array):
"""MSE between predicted and true noise."""
t_key, noise_key = jax.random.split(key)
t = jax.random.uniform(t_key) # Random time
x_t, true_noise = self.forward(t, x, noise_key)
pred_noise = model(t, x_t)
return jnp.mean((pred_noise - true_noise) ** 2)
def reverse(self, model, t, x_t, key):
"""One denoising step."""
idx = jnp.clip(
jnp.floor((1.0 - t) * self.num_steps).astype(jnp.int32),
0, self.num_steps - 1
)
alpha = self.alphas[idx]
alpha_bar = self.alphas_cumprod[idx]
beta = self.betas[idx]
# Predict noise
eps_pred = model(t, x_t)
# Compute mean: μ = (1/√α) · (x_t - β/√(1-ᾱ) · ε_pred)
mean = (1.0 / jnp.sqrt(alpha)) * (
x_t - (beta / jnp.sqrt(1.0 - alpha_bar)) * eps_pred
)
# Compute variance (simplified)
alpha_bar_prev = jnp.where(idx > 0, self.alphas_cumprod[idx - 1], 1.0)
variance = ((1.0 - alpha_bar_prev) / (1.0 - alpha_bar)) * beta
sigma = jnp.sqrt(variance)
# Sample: x_{t-1} = μ + σ · z (no noise at final step)
def add_noise(k):
return mean + sigma * jax.random.normal(k, x_t.shape)
def no_noise(_):
return mean
return jax.lax.cond(idx == 0, no_noise, add_noise, key)
DDPM Sampling (Generation)
def sample_from_target_distribution(self, model, key, num_samples, data_dim):
"""Generate samples by iterative denoising."""
src_key, loop_key = jax.random.split(key)
# Start from pure noise at t=0
x_t = self.sample_from_source_distribution(src_key, num_samples, data_dim)
# Time steps: noise → data, chosen so that idx = floor((1-t)·num_steps)
# sweeps T-1, T-2, ..., 0 and the final no-noise step (idx == 0) is taken
ts = jnp.linspace(1.0 / self.num_steps, 1.0, self.num_steps)
keys = jax.random.split(loop_key, self.num_steps)
def scan_body(x_t, inputs):
t, step_key = inputs
batch_keys = jax.random.split(step_key, num_samples)
# Apply reverse step to each sample
x_next = jax.vmap(
self.reverse, in_axes=(None, None, 0, 0)
)(model, t, x_t, batch_keys)
return x_next, x_next # (carry, output)
# Efficient loop via lax.scan
x_final, trajectory = jax.lax.scan(scan_body, x_t, (ts, keys))
return x_final, trajectory
Flow Matching
Learn velocity field for direct transport
Mathematical Foundation
Key Insight: Instead of learning noise, learn a velocity field that transports noise to data
x_t = (1-t) · x_0 + t · x_1
where x_0 ~ N(0, σ²I) (source), x_1 = data (target)
Target velocity: v = x_1 - x_0 (straight-line direction)
Training Objective: Predict the velocity at each point
L = E[ ‖v_θ(x_t, t) - (x_1 - x_0)‖² ]
Generation: Integrate the ODE dx/dt = v_θ(x, t) from t=0 to t=1
x(0) = noise, x(1) = generated sample
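A tiny numeric check of the linear path (a sketch with made-up 2D points, not taken from the repo):
import jax.numpy as jnp
# Hypothetical source/target points, purely for illustration
x0 = jnp.array([0.0, 0.0])      # source sample (noise)
x1 = jnp.array([2.0, 1.0])      # target sample (data)
t = 0.5
x_t = (1.0 - t) * x0 + t * x1   # [1.0, 0.5] -> halfway along the straight line
v = x1 - x0                     # [2.0, 1.0] -> constant along the whole path
# One Euler step with the true velocity lands exactly on the point at t + dt
dt = 0.1
x_next = x_t + dt * v           # [1.2, 0.6] == (1 - 0.6) * x0 + 0.6 * x1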
Flow Matching Implementation
# flow_matching.py
@dataclass
class FlowMatchingStrategyConfig:
name: Literal["flow_matching"] = "flow_matching"
num_transport_steps: int = 50 # ODE integration steps
base_std: float = 1.0 # Source distribution std
class FlowMatchingStrategy:
def __init__(self, config: FlowMatchingStrategyConfig):
self.num_steps = config.num_transport_steps
self.base_std = config.base_std
def forward(self, t: float, x_target: jax.Array, key: jax.Array):
"""Linear interpolation between source and target."""
# Sample source noise
x_source = jax.random.normal(key, x_target.shape) * self.base_std
# Interpolate: x_t = (1-t) · x_source + t · x_target
x_t = (1.0 - t) * x_source + t * x_target
# Target velocity: direction from source to target
target_velocity = x_target - x_source
return x_t, target_velocity
def loss_fn(self, model: eqx.Module, x: jax.Array, key: jax.Array):
"""MSE between predicted and target velocity."""
t_key, noise_key = jax.random.split(key)
t = jax.random.uniform(t_key) # Random time
x_t, target_v = self.forward(t, x, noise_key)
pred_v = model(t, x_t)
return jnp.mean((pred_v - target_v) ** 2)
def reverse(self, model, t, x_t, key):
"""One ODE integration step (Forward Euler)."""
# Deterministic! No noise needed
dt = 1.0 / self.num_steps
v_pred = model(t, x_t)
# x_{t+dt} = x_t + dt · v_pred
return x_t + dt * v_pred
def sample_from_target_distribution(self, model, key, num_samples, data_dim):
"""Generate by integrating ODE."""
x_t = self.sample_from_source_distribution(key, num_samples, data_dim)
# Integrate from t=0 to t=1
ts = jnp.linspace(0.0, 1.0 - 1.0/self.num_steps, self.num_steps)
keys = jax.random.split(key, self.num_steps) # Not used (deterministic)
def scan_body(x_t, inputs):
t, _ = inputs
# No vmap over keys needed - deterministic
x_next = jax.vmap(
self.reverse, in_axes=(None, None, 0, None)
)(model, t, x_t, None)
return x_next, x_next
x_final, trajectory = jax.lax.scan(scan_body, x_t, (ts, keys))
return x_final, trajectory
DDPM vs Flow Matching
Key differences
| Aspect | DDPM | Flow Matching |
|---|---|---|
| Learning Target | Noise ε added at step t | Velocity v = x₁ - x₀ |
| Forward Process | Iterative noise addition | Linear interpolation |
| Reverse Process | Stochastic (adds noise) | Deterministic ODE |
| Typical Steps | 50-1000 steps | 20-50 steps |
| Loss Function | MSE(ε_pred, ε_true) | MSE(v_pred, v_true) |
| Math Complexity | Variance schedule, cumprod | Simple interpolation |
| Generation Speed | Slower (many stochastic steps) | Faster (fewer deterministic) |
Neural Network Architectures
Time-conditioned networks
Sinusoidal Time Embedding
Encodes continuous time t into high-dimensional representation. Allows network to distinguish nearby time steps.
class SinusoidalTimeEmbed(eqx.Module):
"""Transformer-style positional encoding for time."""
freqs: jax.Array # Precomputed frequencies
def __init__(self, dim: int):
half_dim = dim // 2
# Frequencies: 1, 1/10000^(2/d), 1/10000^(4/d), ...
self.freqs = jnp.exp(
-jnp.log(10000.0) * jnp.arange(half_dim) / half_dim
)
def __call__(self, t: float) -> jax.Array:
# t * frequencies
args = t * self.freqs
# Concatenate sin and cos
# Shape: (dim,) = (half_dim,) + (half_dim,)
return jnp.concatenate([jnp.cos(args), jnp.sin(args)])
# Example: t=0.5, dim=8
# freqs = [1.0, 0.1, 0.01, 0.001]
# args = [0.5, 0.05, 0.005, 0.0005]
# output = [cos(0.5), cos(0.05), ..., sin(0.5), sin(0.05), ...]
MLP Architecture
@dataclass
class MLPConfig:
type: Literal["mlp"] = "mlp"
hidden_dim: int = 512 # Width of hidden layers
depth: int = 10 # Number of layers
activation: str = "gelu" # gelu / relu / swish
class MLP(eqx.Module):
"""Simple MLP with time concatenation."""
time_embed: SinusoidalTimeEmbed
layers: list
activation: Callable
def __init__(self, config: MLPConfig, key: jax.Array, data_dim: int):
time_dim = config.hidden_dim // 4
self.time_embed = SinusoidalTimeEmbed(time_dim)
# Activation function
self.activation = {"gelu": jax.nn.gelu, "relu": jax.nn.relu, "swish": jax.nn.swish}[config.activation]
keys = jax.random.split(key, config.depth + 1)
self.layers = []
# First layer: data + time_embed β hidden
in_dim = data_dim + time_dim
self.layers.append(eqx.nn.Linear(in_dim, config.hidden_dim, key=keys[0]))
# Hidden layers
for i in range(1, config.depth):
self.layers.append(
eqx.nn.Linear(config.hidden_dim, config.hidden_dim, key=keys[i])
)
# Output layer: hidden β data_dim
self.layers.append(
eqx.nn.Linear(config.hidden_dim, data_dim, key=keys[-1])
)
def __call__(self, t: float, x: jax.Array) -> jax.Array:
# Embed time and concatenate
t_emb = self.time_embed(t)
h = jnp.concatenate([x, t_emb])
# Forward through layers
for layer in self.layers[:-1]:
h = self.activation(layer(h))
return self.layers[-1](h) # No activation on output
ResNet Architecture
@dataclass
class ResNetConfig:
type: Literal["resnet"] = "resnet"
hidden_dim: int = 512
num_blocks: int = 4
use_layer_norm: bool = True
class ResNetBlock(eqx.Module):
"""Residual block with optional LayerNorm."""
linear1: eqx.nn.Linear
linear2: eqx.nn.Linear
norm: eqx.nn.LayerNorm | None
def __init__(self, dim: int, use_norm: bool, key: jax.Array):
k1, k2 = jax.random.split(key)
self.linear1 = eqx.nn.Linear(dim, dim, key=k1)
self.linear2 = eqx.nn.Linear(dim, dim, key=k2)
self.norm = eqx.nn.LayerNorm(dim) if use_norm else None
def __call__(self, x: jax.Array) -> jax.Array:
h = x
if self.norm:
h = self.norm(h)
h = jax.nn.gelu(self.linear1(h))
h = self.linear2(h)
return x + h # Residual connection
class ResNet(eqx.Module):
"""ResNet with time embedding added globally."""
time_embed: SinusoidalTimeEmbed
time_proj: eqx.nn.Linear # Project time to hidden_dim
input_proj: eqx.nn.Linear # Project data to hidden_dim
blocks: list[ResNetBlock]
output_proj: eqx.nn.Linear
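# The repo's constructor is not shown here; the __init__ below is an assumed
# sketch consistent with the fields above (the hidden_dim // 4 time embedding
# size mirrors the MLP and is an assumption, not taken from the source).
def __init__(self, config: ResNetConfig, key: jax.Array, data_dim: int):
time_dim = config.hidden_dim // 4
keys = jax.random.split(key, config.num_blocks + 3)
self.time_embed = SinusoidalTimeEmbed(time_dim)
self.time_proj = eqx.nn.Linear(time_dim, config.hidden_dim, key=keys[0])
self.input_proj = eqx.nn.Linear(data_dim, config.hidden_dim, key=keys[1])
self.blocks = [
ResNetBlock(config.hidden_dim, config.use_layer_norm, key=k)
for k in keys[2:-1]
]
self.output_proj = eqx.nn.Linear(config.hidden_dim, data_dim, key=keys[-1])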
def __call__(self, t: float, x: jax.Array) -> jax.Array:
# Time embedding (added to all blocks)
t_emb = self.time_proj(self.time_embed(t))
# Project input
h = self.input_proj(x)
# Residual blocks with time
for block in self.blocks:
h = block(h + t_emb) # Add time at each block
return self.output_proj(h)
Training Pipeline
JIT-compiled training with JAX
Training Loop
# train.py
def train(cfg: TrainConfig, key: jax.Array) -> None:
# 1. Initialize model
init_key, data_key, train_key = jax.random.split(key, 3)
model = create_model(cfg.model, init_key, data_dim=2)
# 2. Create strategy (DDPM or Flow Matching)
strategy = create_strategy(cfg.strategy)
# 3. Partition model for JIT
params, static = eqx.partition(model, eqx.is_inexact_array)
# 4. Setup optimizer (Optax)
optimizer = create_optimizer(cfg.optimizer)
opt_state = optimizer.init(params)
# 5. JIT-compile training step
@eqx.filter_jit
def train_step(params, static, opt_state, batch, key):
def batch_loss(params, static, batch, keys):
model = eqx.combine(params, static)
# Vectorize loss over batch
losses = jax.vmap(
strategy.loss_fn,
in_axes=(None, 0, 0) # Model, data, keys
)(model, batch, keys)
return jnp.mean(losses)
# Compute gradients
loss, grads = eqx.filter_value_and_grad(batch_loss)(
params, static, batch, jax.random.split(key, cfg.batch_size)
)
# Update parameters
updates, opt_state = optimizer.update(grads, opt_state, params)
params = eqx.apply_updates(params, updates)
return params, opt_state, loss
# 6. Main loop
for step in range(cfg.train_steps):
# Get batch
data_key, batch_key = jax.random.split(data_key)
batch = get_batch(batch_key, cfg.dataset, cfg.batch_size)
# Train step
train_key, step_key = jax.random.split(train_key)
params, opt_state, loss = train_step(
params, static, opt_state, batch, step_key
)
if (step + 1) % cfg.log_interval == 0:
print(f"Step {step+1}: loss = {loss:.6f}")
# 7. Save model
model = eqx.combine(params, static)
eqx.tree_serialise_leaves(cfg.model_path, model)
JAX Patterns
Key techniques used throughout
jax.vmap
Vectorize single-sample functions over batches automatically.
# Apply loss to batch
batch_losses = jax.vmap(
strategy.loss_fn,
in_axes=(None, 0, 0) # Model, data[batch], keys[batch]
)(model, batch_data, batch_keys)
# Apply reverse step to batch
x_next = jax.vmap(
strategy.reverse,
in_axes=(None, None, 0, 0) # Model, time, states, keys
)(model, t, x_batch, keys_batch)
jax.lax.scan
Efficient loop that compiles to single XLA operation.
def scan_body(carry, inputs):
x_t = carry
t, key = inputs
x_next = reverse_step(x_t, t, key)
return x_next, x_next # (new_carry, output)
# Loop: num_steps iterations
x_final, trajectory = jax.lax.scan(
scan_body,
x_initial, # Initial carry
(times, keys) # Inputs per step
)
# trajectory shape: (num_steps, batch, dim)
jax.random.split
Explicit PRNG key management for reproducibility.
# Split key into multiple independent keys
key, init_key, train_key = jax.random.split(key, 3)
# One key per sample in batch
batch_keys = jax.random.split(key, batch_size)
# Each split produces completely independent streams
# Deterministic: same key → same sequence
eqx.partition / combine
Separate parameters from model structure for JIT.
# Partition: separate differentiable arrays
params, static = eqx.partition(
model,
eqx.is_inexact_array # Filter for float arrays
)
# JIT only operates on params (static is constant)
@eqx.filter_jit
def step(params, static, ...):
model = eqx.combine(params, static)
...
# Reconstruct model after training
final_model = eqx.combine(params, static)
jax.lax.cond
JIT-compatible conditional branching. Both branches are traced.
# DDPM: Don't add noise at final step
def add_noise(key):
noise = jax.random.normal(key, shape=mean.shape)
return mean + sigma * noise
def no_noise(_):
return mean
# Select branch based on condition
x_next = jax.lax.cond(
step_idx == 0, # Predicate
no_noise, # True branch
add_noise, # False branch
key # Argument passed to selected branch
)
# Note: Both branches are compiled, selection happens at runtime
Datasets
2D point cloud distributions
Cat
Points sampled from cat silhouette image
@dataclass
class CatConfig:
name: Literal["cat"] = "cat"
data_dim: int = 2
# Loads assets/cat.png
# Extracts dark pixels (< 128)
# Normalizes to [-1, 1]
# Adds small jitter noise
Gaussian Mixture
4 clusters at corners
@dataclass
class GaussianMixtureConfig:
name: Literal["gaussian_mixture"]
scale: float = 3.0
noise_std: float = 0.3
# Centers: [±scale, ±scale]
# Each point: center + N(0, noise_std²)
Two Moons
Non-convex crescent shapes
@dataclass
class MoonConfig:
name: Literal["moon"] = "moon"
noise: float = 0.05
scale: float = 3.0
# Uses sklearn.datasets.make_moons
# Centered and scaled
Swiss Roll
Spiral manifold (2D projection)
@dataclass
class SwissRollConfig:
name: Literal["swiss_roll"]
noise: float = 0.3
scale: float = 3.0
# sklearn.datasets.make_swiss_roll
# First 2 dimensions only
CLI Usage
Training, generation, and animation
Train a Model
# Train DDPM on cat dataset
uv run scripts/main.py train \
strategy:ddpm \
model:mlp \
dataset:cat \
--batch-size 1024 \
--train-steps 3000 \
--model-path outputs/cat_ddpm.eqx
# Train Flow Matching on moon dataset with ResNet
uv run scripts/main.py train \
strategy:flow-matching \
model:resnet \
dataset:moon \
--batch-size 512 \
--train-steps 5000 \
--model-path outputs/moon_flow.eqx
Generate Samples
# Generate from trained model
uv run scripts/main.py generate \
strategy:ddpm \
model:mlp \
dataset:cat \
--model-path outputs/cat_ddpm.eqx \
--num-samples 2000 \
--output-image-path outputs/cat_samples.png
Create Animation
# Visualize generation trajectory
uv run scripts/main.py animate \
strategy:flow-matching \
model:resnet \
dataset:moon \
--model-path outputs/moon_flow.eqx \
--num-samples 1500 \
--output-video-path outputs/moon.gif \
--fps 30
Complete Example
Train and generate step by step
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
from jax_gen.strategies import create_strategy
from jax_gen.models import create_model
from jax_gen.data import get_batch
from jax_gen.config import DDPMStrategyConfig, MLPConfig, CatConfig
# 1. Configuration
strategy_cfg = DDPMStrategyConfig(num_transport_steps=50)
model_cfg = MLPConfig(hidden_dim=256, depth=6)
data_cfg = CatConfig()
# 2. Initialize
key = jax.random.PRNGKey(42)
key, init_key, train_key = jax.random.split(key, 3)
model = create_model(model_cfg, init_key, data_dim=2)
strategy = create_strategy(strategy_cfg)
optimizer = optax.adam(1e-3)
params, static = eqx.partition(model, eqx.is_inexact_array)
opt_state = optimizer.init(params)
# 3. Training step
@eqx.filter_jit
def train_step(params, static, opt_state, batch, key):
def loss_fn(params):
model = eqx.combine(params, static)
keys = jax.random.split(key, len(batch))
losses = jax.vmap(strategy.loss_fn, in_axes=(None, 0, 0))(model, batch, keys)
return jnp.mean(losses)
loss, grads = eqx.filter_value_and_grad(loss_fn)(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = eqx.apply_updates(params, updates)
return params, opt_state, loss
# 4. Training loop
for step in range(1000):
train_key, batch_key, step_key = jax.random.split(train_key, 3)
batch = get_batch(batch_key, data_cfg, batch_size=512)
params, opt_state, loss = train_step(params, static, opt_state, batch, step_key)
if (step + 1) % 100 == 0:
print(f"Step {step+1}: loss = {loss:.6f}")
# 5. Generation
model = eqx.combine(params, static)
gen_key = jax.random.PRNGKey(0)
samples, trajectory = strategy.sample_from_target_distribution(
model, gen_key, num_samples=1000, data_dim=2
)
print(f"Generated {samples.shape[0]} samples!")
# samples.shape: (1000, 2)
Key Takeaways
Unified Interface
Strategy protocol allows swapping DDPM ↔ Flow Matching without changing training code. Same forward/reverse/loss_fn signature.
DDPM: Noise Prediction
Add noise progressively, learn to predict it. Stochastic reverse process needs many steps (50-1000).
Flow Matching: Velocity Field
Linear interpolation, learn velocity direction. Deterministic ODE needs fewer steps (20-50).
JAX Patterns
vmap for batching, lax.scan for loops, random.split for PRNG, partition/combine for JIT. All explicit, no hidden state.
Time Embedding
Sinusoidal encoding lets network distinguish nearby time steps. Critical for learning continuous dynamics.
2D Visualization
Low-dimensional datasets (Cat, Moon, Swiss-Roll) make it easy to visualize training progress and generation quality.