MizuhoAOKI/jax_generative_models

🎨

JAX Generative Models

Minimal, unified JAX implementation of DDPM and Flow Matching

Learn generative models from scratch. 2D point cloud datasets, Equinox neural networks, JIT-compiled training, Rerun visualization. Unified interface for both diffusion and flow.

JAX + Equinox · DDPM · Flow Matching
πŸ—οΈ

Architecture

Unified interface for generative models

┌──────────────────────────────────────────────────────────┐
│                        CLI (Tyro)                        │
│  train | generate | animate                              │
├──────────────────────────────────────────────────────────┤
│                    Strategy Interface                    │
│  forward() | reverse() | loss_fn() | sample()            │
├──────────────────────┬───────────────────────────────────┤
│        DDPM          │         Flow Matching             │
│  Noise prediction    │       Velocity field              │
│  Stochastic reverse  │      Deterministic ODE            │
├──────────────────────┴───────────────────────────────────┤
│                     Neural Networks                      │
│  MLP | ResNet  +  Sinusoidal Time Embedding              │
├──────────────────────────────────────────────────────────┤
│                         Datasets                         │
│  Cat | Moon | Swiss-Roll | Gaussian Mixture              │
├──────────────────────────────────────────────────────────┤
│                      JAX Primitives                      │
│  jit | vmap | grad | lax.scan | random.split             │
└──────────────────────────────────────────────────────────┘

Unified Time Convention

Both strategies use: t=0 (source/noise) → t=1 (target/data). This allows swapping strategies without changing training code.

📐

Strategy Interface

Common protocol for all generative models

base.py - Strategy Protocol

from typing import Protocol
import jax
import jax.numpy as jnp
import equinox as eqx

class Strategy(Protocol):
    """Unified interface for generative models."""

    def loss_fn(
        self,
        model: eqx.Module,
        x: jax.Array,      # Single data point
        key: jax.Array,    # PRNG key
    ) -> jax.Array:
        """Compute loss for training. Returns scalar."""
        ...

    def forward(
        self,
        t: float,          # Time in [0, 1]
        x: jax.Array,      # Target data point
        key: jax.Array,
    ) -> tuple[jax.Array, jax.Array]:
        """Perturb data for training. Returns (x_t, target)."""
        ...

    def reverse(
        self,
        model: eqx.Module,
        t: float,          # Current time
        x_t: jax.Array,    # Current state
        key: jax.Array,
    ) -> jax.Array:
        """One reverse step. Returns x_{t+dt} (closer to data)."""
        ...

    def sample_from_source_distribution(
        self,
        key: jax.Array,
        num_samples: int,
        data_dim: int,
    ) -> jax.Array:
        """Sample from source (noise). Returns shape (num_samples, data_dim)."""
        ...

    def sample_from_target_distribution(
        self,
        model: eqx.Module,
        key: jax.Array,
        num_samples: int,
        data_dim: int,
    ) -> tuple[jax.Array, jax.Array]:
        """Generate samples. Returns (x_final, trajectory)."""
        ...
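
Because both strategies implement this protocol, training and sampling code is written once against it. A minimal sketch of a batched loss built on the protocol (the helper name batch_loss is illustrative, not repository code):

# Hedged sketch: the same call works for DDPM and Flow Matching, since both
# expose loss_fn(model, x, key) with identical semantics.
def batch_loss(strategy: Strategy, model: eqx.Module, batch: jax.Array, key: jax.Array) -> jax.Array:
    keys = jax.random.split(key, batch.shape[0])  # one PRNG key per sample
    losses = jax.vmap(strategy.loss_fn, in_axes=(None, 0, 0))(model, batch, keys)
    return jnp.mean(losses)

# Swapping DDPM and Flow Matching only changes which strategy object is passed in.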
🌫️

DDPM (Denoising Diffusion)

Learn to predict and remove noise

Mathematical Foundation

Forward Process: Progressively add Gaussian noise over T steps

q(x_t | x_0) = N(x_t; √ᾱ_t · x_0, (1 - ᾱ_t) · I)

where ᾱ_t = ∏(s=1 to t) (1 - β_s) is the cumulative product over the variance schedule, with α_s = 1 - β_s

Reverse Process: Learn to denoise by predicting the added noise

p_θ(x_(t-1) | x_t) = N(x_(t-1); μ_θ(x_t, t), σ_t² · I)

μ_θ(x_t, t) = (1/√α_t) · (x_t - (β_t/√(1-ᾱ_t)) · ε_θ(x_t, t))

Training Objective: Predict the noise that was added

L = E_(x,t,ε)[ ||ε_θ(√ᾱ_t · x + √(1-ᾱ_t) · ε, t) - ε||² ]
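
A quick sanity check on the closed form (a standalone sketch, not repository code): iterating the one-step kernel x_t = √(1-β_t) · x_{t-1} + √β_t · ε_t for T steps gives the same distribution as sampling q(x_T | x_0) directly, because the per-step √(1-β_s) factors collapse into √ᾱ_T.

import jax
import jax.numpy as jnp

T = 50
betas = jnp.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bar = jnp.cumprod(alphas)                          # ᾱ_t = ∏ (1 - β_s)

x0 = jnp.array([1.0, -0.5])                              # one 2D data point
k_direct, k_iter = jax.random.split(jax.random.PRNGKey(0))

# Closed-form marginal: x_T = √ᾱ_T · x_0 + √(1-ᾱ_T) · ε
eps = jax.random.normal(k_direct, x0.shape)
x_T_direct = jnp.sqrt(alpha_bar[-1]) * x0 + jnp.sqrt(1 - alpha_bar[-1]) * eps

# Iterated one-step kernel: x_t = √(1-β_t) · x_{t-1} + √β_t · ε_t
def step(x, inp):
    beta, k = inp
    x = jnp.sqrt(1 - beta) * x + jnp.sqrt(beta) * jax.random.normal(k, x.shape)
    return x, None

x_T_iter, _ = jax.lax.scan(step, x0, (betas, jax.random.split(k_iter, T)))

# Different draws, same distribution: both are N(√ᾱ_T · x_0, (1 - ᾱ_T) · I)
print(x_T_direct, x_T_iter)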

DDPM Implementation

# ddpm.py - Key implementation details

@dataclass
class DDPMStrategyConfig:
    name: Literal["ddpm"] = "ddpm"
    num_transport_steps: int = 50    # Number of diffusion steps
    beta_min: float = 1e-4           # Min noise variance
    beta_max: float = 0.02           # Max noise variance


class DDPMStrategy:
    def __init__(self, config: DDPMStrategyConfig):
        self.num_steps = config.num_transport_steps

        # Linear beta schedule
        self.betas = jnp.linspace(config.beta_min, config.beta_max, self.num_steps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = jnp.cumprod(self.alphas)  # ᾱ_t

    def forward(self, t: float, x: jax.Array, key: jax.Array):
        """Add noise to data point at time t."""
        # Map t ∈ [0,1] to step index
        # t=0 → step T-1 (most noise), t=1 → step 0 (data)
        idx = jnp.clip(
            jnp.floor((1.0 - t) * self.num_steps).astype(jnp.int32),
            0, self.num_steps - 1
        )

        alpha_bar = self.alphas_cumprod[idx]
        noise = jax.random.normal(key, shape=x.shape)

        # x_t = √ᾱ_t · x + √(1-ᾱ_t) · ε
        x_t = jnp.sqrt(alpha_bar) * x + jnp.sqrt(1 - alpha_bar) * noise

        return x_t, noise  # Return perturbed state and target (noise)

    def loss_fn(self, model: eqx.Module, x: jax.Array, key: jax.Array):
        """MSE between predicted and true noise."""
        t_key, noise_key = jax.random.split(key)
        t = jax.random.uniform(t_key)  # Random time

        x_t, true_noise = self.forward(t, x, noise_key)
        pred_noise = model(t, x_t)

        return jnp.mean((pred_noise - true_noise) ** 2)

    def reverse(self, model, t, x_t, key):
        """One denoising step."""
        idx = jnp.clip(
            jnp.floor((1.0 - t) * self.num_steps).astype(jnp.int32),
            0, self.num_steps - 1
        )

        alpha = self.alphas[idx]
        alpha_bar = self.alphas_cumprod[idx]
        beta = self.betas[idx]

        # Predict noise
        eps_pred = model(t, x_t)

        # Compute mean: μ = (1/√α) · (x_t - β/√(1-ᾱ) · ε_pred)
        mean = (1.0 / jnp.sqrt(alpha)) * (
            x_t - (beta / jnp.sqrt(1.0 - alpha_bar)) * eps_pred
        )

        # Posterior variance: σ_t² = ((1 - ᾱ_{t-1}) / (1 - ᾱ_t)) · β_t
        alpha_bar_prev = jnp.where(idx > 0, self.alphas_cumprod[idx - 1], 1.0)
        variance = ((1.0 - alpha_bar_prev) / (1.0 - alpha_bar)) * beta
        sigma = jnp.sqrt(variance)

        # Sample: x_{t-1} = μ + σ · z (no noise at the final step)
        def add_noise(k):
            return mean + sigma * jax.random.normal(k, x_t.shape)

        def no_noise(_):
            return mean

        return jax.lax.cond(idx == 0, no_noise, add_noise, key)

DDPM Sampling (Generation)

def sample_from_target_distribution(self, model, key, num_samples, data_dim):
    """Generate samples by iterative denoising."""
    src_key, loop_key = jax.random.split(key)

    # Start from pure noise at t=0
    x_t = self.sample_from_source_distribution(src_key, num_samples, data_dim)

    # Time steps: 0 → 1 (noise → data)
    ts = jnp.linspace(0.0, 1.0 - 1.0/self.num_steps, self.num_steps)
    keys = jax.random.split(loop_key, self.num_steps)

    def scan_body(x_t, inputs):
        t, step_key = inputs
        batch_keys = jax.random.split(step_key, num_samples)

        # Apply reverse step to each sample
        x_next = jax.vmap(
            self.reverse, in_axes=(None, None, 0, 0)
        )(model, t, x_t, batch_keys)

        return x_next, x_next  # (carry, output)

    # Efficient loop via lax.scan
    x_final, trajectory = jax.lax.scan(scan_body, x_t, (ts, keys))

    return x_final, trajectory
🌊

Flow Matching

Learn velocity field for direct transport

Mathematical Foundation

Key Insight: Instead of learning the noise, learn a velocity field that transports noise to data

Interpolation: x_t = (1 - t) · x_0 + t · x_1
where x_0 ~ N(0, σ²I) (source), x_1 = data (target)

Target velocity: v = x_1 - x_0 (straight-line direction)

Training Objective: Predict the velocity at each point

L = E_(x_0,x_1,t)[ ||v_θ(x_t, t) - (x_1 - x_0)||² ]

Generation: Integrate the ODE from t=0 to t=1

dx/dt = v_θ(x, t)
x(0) = noise, x(1) = generated sample
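
Because the interpolation paths above are straight lines, the ideal conditional velocity is constant along each path, so even coarse forward-Euler integration carries a source point exactly onto its target. A small standalone check (not repository code; v_true is the ideal, non-learned field):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x0 = jax.random.normal(key, (2,))        # source sample (noise)
x1 = jnp.array([2.0, -1.0])              # target data point
v_true = lambda x, t: x1 - x0            # ideal conditional velocity: constant

num_steps = 5
dt = 1.0 / num_steps
x = x0
for i in range(num_steps):
    x = x + dt * v_true(x, i * dt)       # forward Euler: x_{t+dt} = x_t + dt · v

print(x, x1)                             # x lands exactly on x1 (straight-line path)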

Flow Matching Implementation

# flow_matching.py

@dataclass
class FlowMatchingStrategyConfig:
    name: Literal["flow_matching"] = "flow_matching"
    num_transport_steps: int = 50   # ODE integration steps
    base_std: float = 1.0           # Source distribution std


class FlowMatchingStrategy:
    def __init__(self, config: FlowMatchingStrategyConfig):
        self.num_steps = config.num_transport_steps
        self.base_std = config.base_std

    def forward(self, t: float, x_target: jax.Array, key: jax.Array):
        """Linear interpolation between source and target."""
        # Sample source noise
        x_source = jax.random.normal(key, x_target.shape) * self.base_std

        # Interpolate: x_t = (1-t) · x_source + t · x_target
        x_t = (1.0 - t) * x_source + t * x_target

        # Target velocity: direction from source to target
        target_velocity = x_target - x_source

        return x_t, target_velocity

    def loss_fn(self, model: eqx.Module, x: jax.Array, key: jax.Array):
        """MSE between predicted and target velocity."""
        t_key, noise_key = jax.random.split(key)
        t = jax.random.uniform(t_key)  # Random time

        x_t, target_v = self.forward(t, x, noise_key)
        pred_v = model(t, x_t)

        return jnp.mean((pred_v - target_v) ** 2)

    def reverse(self, model, t, x_t, key):
        """One ODE integration step (Forward Euler)."""
        # Deterministic! No noise needed
        dt = 1.0 / self.num_steps
        v_pred = model(t, x_t)

        # x_{t+dt} = x_t + dt · v_pred
        return x_t + dt * v_pred

    def sample_from_target_distribution(self, model, key, num_samples, data_dim):
        """Generate by integrating ODE."""
        x_t = self.sample_from_source_distribution(key, num_samples, data_dim)

        # Integrate from t=0 to t=1
        ts = jnp.linspace(0.0, 1.0 - 1.0/self.num_steps, self.num_steps)
        keys = jax.random.split(key, self.num_steps)  # Not used (deterministic)

        def scan_body(x_t, inputs):
            t, _ = inputs
            # No vmap over keys needed - deterministic
            x_next = jax.vmap(
                self.reverse, in_axes=(None, None, 0, None)
            )(model, t, x_t, None)
            return x_next, x_next

        x_final, trajectory = jax.lax.scan(scan_body, x_t, (ts, keys))
        return x_final, trajectory
⚖️

DDPM vs Flow Matching

Key differences

| Aspect | DDPM | Flow Matching |
|---|---|---|
| Learning Target | Noise ε added at step t | Velocity v = x₁ - x₀ |
| Forward Process | Iterative noise addition | Linear interpolation |
| Reverse Process | Stochastic (adds noise) | Deterministic ODE |
| Typical Steps | 50-1000 steps | 20-50 steps |
| Loss Function | MSE(ε_pred, ε_true) | MSE(v_pred, v_true) |
| Math Complexity | Variance schedule, cumprod | Simple interpolation |
| Generation Speed | Slower (many stochastic steps) | Faster (fewer deterministic steps) |
🧠

Neural Network Architectures

Time-conditioned networks

Sinusoidal Time Embedding

Encodes continuous time t into a high-dimensional representation, allowing the network to distinguish nearby time steps.

class SinusoidalTimeEmbed(eqx.Module):
    """Transformer-style positional encoding for time."""
    freqs: jax.Array  # Precomputed frequencies

    def __init__(self, dim: int):
        half_dim = dim // 2
        # Frequencies: 1, 1/10000^(2/d), 1/10000^(4/d), ...
        self.freqs = jnp.exp(
            -jnp.log(10000.0) * jnp.arange(half_dim) / half_dim
        )

    def __call__(self, t: float) -> jax.Array:
        # t * frequencies
        args = t * self.freqs

        # Concatenate sin and cos
        # Shape: (dim,) = (half_dim,) + (half_dim,)
        return jnp.concatenate([jnp.cos(args), jnp.sin(args)])


# Example: t=0.5, dim=8
# freqs = [1.0, 0.1, 0.01, 0.001]
# args = [0.5, 0.05, 0.005, 0.0005]
# output = [cos(0.5), cos(0.05), ..., sin(0.5), sin(0.05), ...]

MLP Architecture

@dataclass
class MLPConfig:
    type: Literal["mlp"] = "mlp"
    hidden_dim: int = 512     # Width of hidden layers
    depth: int = 10           # Number of layers
    activation: str = "gelu"  # gelu / relu / swish


class MLP(eqx.Module):
    """Simple MLP with time concatenation."""
    time_embed: SinusoidalTimeEmbed
    layers: list
    activation: Callable

    def __init__(self, config: MLPConfig, key: jax.Array, data_dim: int):
        time_dim = config.hidden_dim // 4
        self.time_embed = SinusoidalTimeEmbed(time_dim)

        # Activation function
        self.activation = {"gelu": jax.nn.gelu, "relu": jax.nn.relu, "swish": jax.nn.swish}[config.activation]

        keys = jax.random.split(key, config.depth + 1)
        self.layers = []

        # First layer: data + time_embed → hidden
        in_dim = data_dim + time_dim
        self.layers.append(eqx.nn.Linear(in_dim, config.hidden_dim, key=keys[0]))

        # Hidden layers
        for i in range(1, config.depth):
            self.layers.append(
                eqx.nn.Linear(config.hidden_dim, config.hidden_dim, key=keys[i])
            )

        # Output layer: hidden → data_dim
        self.layers.append(
            eqx.nn.Linear(config.hidden_dim, data_dim, key=keys[-1])
        )

    def __call__(self, t: float, x: jax.Array) -> jax.Array:
        # Embed time and concatenate
        t_emb = self.time_embed(t)
        h = jnp.concatenate([x, t_emb])

        # Forward through layers
        for layer in self.layers[:-1]:
            h = self.activation(layer(h))

        return self.layers[-1](h)  # No activation on output

ResNet Architecture

@dataclass
class ResNetConfig:
    type: Literal["resnet"] = "resnet"
    hidden_dim: int = 512
    num_blocks: int = 4
    use_layer_norm: bool = True


class ResNetBlock(eqx.Module):
    """Residual block with optional LayerNorm."""
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    norm: eqx.nn.LayerNorm | None

    def __init__(self, dim: int, use_norm: bool, key: jax.Array):
        k1, k2 = jax.random.split(key)
        self.linear1 = eqx.nn.Linear(dim, dim, key=k1)
        self.linear2 = eqx.nn.Linear(dim, dim, key=k2)
        self.norm = eqx.nn.LayerNorm(dim) if use_norm else None

    def __call__(self, x: jax.Array) -> jax.Array:
        h = x
        if self.norm:
            h = self.norm(h)
        h = jax.nn.gelu(self.linear1(h))
        h = self.linear2(h)
        return x + h  # Residual connection


class ResNet(eqx.Module):
    """ResNet with time embedding added globally."""
    time_embed: SinusoidalTimeEmbed
    time_proj: eqx.nn.Linear   # Project time to hidden_dim
    input_proj: eqx.nn.Linear  # Project data to hidden_dim
    blocks: list[ResNetBlock]
    output_proj: eqx.nn.Linear

    def __call__(self, t: float, x: jax.Array) -> jax.Array:
        # Time embedding (added to all blocks)
        t_emb = self.time_proj(self.time_embed(t))

        # Project input
        h = self.input_proj(x)

        # Residual blocks with time
        for block in self.blocks:
            h = block(h + t_emb)  # Add time at each block

        return self.output_proj(h)
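
The constructor is omitted above; a minimal sketch that matches the declared fields (assumptions: time_dim = hidden_dim // 4 as in the MLP, ResNetConfig from earlier) might look like this, not necessarily the repository's exact wiring:

    # Continues the ResNet class above (hedged sketch, not repository code)
    def __init__(self, config: ResNetConfig, key: jax.Array, data_dim: int):
        time_dim = config.hidden_dim // 4                 # assumption: same ratio as the MLP
        k_time, k_in, k_out, *k_blocks = jax.random.split(key, 3 + config.num_blocks)

        self.time_embed = SinusoidalTimeEmbed(time_dim)
        self.time_proj = eqx.nn.Linear(time_dim, config.hidden_dim, key=k_time)
        self.input_proj = eqx.nn.Linear(data_dim, config.hidden_dim, key=k_in)
        self.blocks = [
            ResNetBlock(config.hidden_dim, config.use_layer_norm, k) for k in k_blocks
        ]
        self.output_proj = eqx.nn.Linear(config.hidden_dim, data_dim, key=k_out)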
πŸ‹οΈ

Training Pipeline

JIT-compiled training with JAX

Training Loop

# train.py

def train(cfg: TrainConfig, key: jax.Array) -> None:
    # 1. Initialize model
    init_key, data_key, train_key = jax.random.split(key, 3)
    model = create_model(cfg.model, init_key, data_dim=2)

    # 2. Create strategy (DDPM or Flow Matching)
    strategy = create_strategy(cfg.strategy)

    # 3. Partition model for JIT
    params, static = eqx.partition(model, eqx.is_inexact_array)

    # 4. Setup optimizer (Optax)
    optimizer = create_optimizer(cfg.optimizer)
    opt_state = optimizer.init(params)

    # 5. JIT-compile training step
    @eqx.filter_jit
    def train_step(params, static, opt_state, batch, key):
        def batch_loss(params, static, batch, keys):
            model = eqx.combine(params, static)

            # Vectorize loss over batch
            losses = jax.vmap(
                strategy.loss_fn,
                in_axes=(None, 0, 0)  # Model, data, keys
            )(model, batch, keys)

            return jnp.mean(losses)

        # Compute gradients
        loss, grads = eqx.filter_value_and_grad(batch_loss)(
            params, static, batch, jax.random.split(key, cfg.batch_size)
        )

        # Update parameters
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = eqx.apply_updates(params, updates)

        return params, opt_state, loss

    # 6. Main loop
    for step in range(cfg.train_steps):
        # Get batch
        data_key, batch_key = jax.random.split(data_key)
        batch = get_batch(batch_key, cfg.dataset, cfg.batch_size)

        # Train step
        train_key, step_key = jax.random.split(train_key)
        params, opt_state, loss = train_step(
            params, static, opt_state, batch, step_key
        )

        if (step + 1) % cfg.log_interval == 0:
            print(f"Step {step+1}: loss = {loss:.6f}")

    # 7. Save model
    model = eqx.combine(params, static)
    eqx.tree_serialise_leaves(cfg.model_path, model)
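
To reuse the saved weights later, the serialised leaves are loaded back into a template with the same structure (a hedged sketch reusing cfg, create_model, and strategy from the snippet above; eqx.tree_deserialise_leaves is the counterpart of tree_serialise_leaves):

# Hedged sketch: reload a trained model and generate samples from it.
# Only array leaves are stored on disk, so a structurally identical
# template (same config, any PRNG key) is needed for deserialisation.
template = create_model(cfg.model, jax.random.PRNGKey(0), data_dim=2)
model = eqx.tree_deserialise_leaves(cfg.model_path, template)

samples, trajectory = strategy.sample_from_target_distribution(
    model, jax.random.PRNGKey(1), num_samples=1000, data_dim=2
)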
⚑

JAX Patterns

Key techniques used throughout

jax.vmap

Vectorize single-sample functions over batches automatically.

# Apply loss to batch
batch_losses = jax.vmap(
    strategy.loss_fn,
    in_axes=(None, 0, 0)  # Model, data[batch], keys[batch]
)(model, batch_data, batch_keys)

# Apply reverse step to batch
x_next = jax.vmap(
    strategy.reverse,
    in_axes=(None, None, 0, 0)  # Model, time, states, keys
)(model, t, x_batch, keys_batch)

jax.lax.scan

Efficient loop that compiles to a single XLA operation.

def scan_body(carry, inputs):
    x_t = carry
    t, key = inputs
    x_next = reverse_step(x_t, t, key)
    return x_next, x_next  # (new_carry, output)

# Loop: num_steps iterations
x_final, trajectory = jax.lax.scan(
    scan_body,
    x_initial,        # Initial carry
    (times, keys)     # Inputs per step
)
# trajectory shape: (num_steps, batch, dim)

jax.random.split

Explicit PRNG key management for reproducibility.

# Split key into multiple independent keys
key, init_key, train_key = jax.random.split(key, 3)

# One key per sample in batch
batch_keys = jax.random.split(key, batch_size)

# Each split produces completely independent streams
# Deterministic: same key → same sequence

eqx.partition / combine

Separate parameters from model structure for JIT.

# Partition: separate differentiable arrays
params, static = eqx.partition(
    model,
    eqx.is_inexact_array  # Filter for float arrays
)

# JIT only operates on params (static is constant)
@eqx.filter_jit
def step(params, static, ...):
    model = eqx.combine(params, static)
    ...

# Reconstruct model after training
final_model = eqx.combine(params, static)

jax.lax.cond

JIT-compatible conditional branching. Both branches are traced.

# DDPM: Don't add noise at final step
def add_noise(key):
    noise = jax.random.normal(key, shape=mean.shape)
    return mean + sigma * noise

def no_noise(_):
    return mean

# Select branch based on condition
x_next = jax.lax.cond(
    step_idx == 0,  # Predicate
    no_noise,       # True branch
    add_noise,      # False branch
    key             # Argument passed to selected branch
)

# Note: Both branches are compiled, selection happens at runtime
📊

Datasets

2D point cloud distributions

Cat

Points sampled from cat silhouette image

@dataclass
class CatConfig:
    name: Literal["cat"] = "cat"
    data_dim: int = 2

# Loads assets/cat.png
# Extracts dark pixels (< 128)
# Normalizes to [-1, 1]
# Adds small jitter noise

Gaussian Mixture

4 clusters at corners

@dataclass
class GaussianMixtureConfig:
    name: Literal["gaussian_mixture"]
    scale: float = 3.0
    noise_std: float = 0.3

# Centers: [±scale, ±scale]
# Each point: center + N(0, noise_std²)
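
A minimal sampler following the comments above (a sketch assuming uniform cluster assignment; the repository's get_batch may differ):

import jax
import jax.numpy as jnp

def sample_gaussian_mixture(key: jax.Array, n: int,
                            scale: float = 3.0, noise_std: float = 0.3) -> jax.Array:
    """Draw n points from 4 Gaussian clusters centered at [±scale, ±scale]."""
    centers = jnp.array([[sx * scale, sy * scale]
                         for sx in (-1.0, 1.0) for sy in (-1.0, 1.0)])
    k_idx, k_noise = jax.random.split(key)
    idx = jax.random.randint(k_idx, (n,), 0, 4)       # uniform cluster choice
    return centers[idx] + noise_std * jax.random.normal(k_noise, (n, 2))

batch = sample_gaussian_mixture(jax.random.PRNGKey(0), 512)
# batch.shape: (512, 2)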

Two Moons

Non-convex crescent shapes

@dataclass
class MoonConfig:
    name: Literal["moon"] = "moon"
    noise: float = 0.05
    scale: float = 3.0

# Uses sklearn.datasets.make_moons
# Centered and scaled

Swiss Roll

Spiral manifold (2D projection)

@dataclass
class SwissRollConfig:
    name: Literal["swiss_roll"]
    noise: float = 0.3
    scale: float = 3.0

# sklearn.datasets.make_swiss_roll
# First 2 dimensions only
💻

CLI Usage

Training, generation, and animation

Train a Model

# Train DDPM on cat dataset
uv run scripts/main.py train \
  strategy:ddpm \
  model:mlp \
  dataset:cat \
  --batch-size 1024 \
  --train-steps 3000 \
  --model-path outputs/cat_ddpm.eqx

# Train Flow Matching on moon dataset with ResNet
uv run scripts/main.py train \
  strategy:flow-matching \
  model:resnet \
  dataset:moon \
  --batch-size 512 \
  --train-steps 5000 \
  --model-path outputs/moon_flow.eqx

Generate Samples

# Generate from trained model
uv run scripts/main.py generate \
  strategy:ddpm \
  model:mlp \
  dataset:cat \
  --model-path outputs/cat_ddpm.eqx \
  --num-samples 2000 \
  --output-image-path outputs/cat_samples.png

Create Animation

# Visualize generation trajectory
uv run scripts/main.py animate \
  strategy:flow-matching \
  model:resnet \
  dataset:moon \
  --model-path outputs/moon_flow.eqx \
  --num-samples 1500 \
  --output-video-path outputs/moon.gif \
  --fps 30
📝

Complete Example

Train and generate step by step

import jax
import jax.numpy as jnp
import equinox as eqx
import optax

from jax_gen.strategies import create_strategy
from jax_gen.models import create_model
from jax_gen.data import get_batch
from jax_gen.config import DDPMStrategyConfig, MLPConfig, CatConfig

# 1. Configuration
strategy_cfg = DDPMStrategyConfig(num_transport_steps=50)
model_cfg = MLPConfig(hidden_dim=256, depth=6)
data_cfg = CatConfig()

# 2. Initialize
key = jax.random.PRNGKey(42)
key, init_key, train_key = jax.random.split(key, 3)

model = create_model(model_cfg, init_key, data_dim=2)
strategy = create_strategy(strategy_cfg)
optimizer = optax.adam(1e-3)

params, static = eqx.partition(model, eqx.is_inexact_array)
opt_state = optimizer.init(params)

# 3. Training step
@eqx.filter_jit
def train_step(params, static, opt_state, batch, key):
    def loss_fn(params):
        model = eqx.combine(params, static)
        keys = jax.random.split(key, len(batch))
        losses = jax.vmap(strategy.loss_fn, in_axes=(None, 0, 0))(model, batch, keys)
        return jnp.mean(losses)

    loss, grads = eqx.filter_value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = eqx.apply_updates(params, updates)
    return params, opt_state, loss

# 4. Training loop
for step in range(1000):
    train_key, batch_key, step_key = jax.random.split(train_key, 3)
    batch = get_batch(batch_key, data_cfg, batch_size=512)
    params, opt_state, loss = train_step(params, static, opt_state, batch, step_key)

    if (step + 1) % 100 == 0:
        print(f"Step {step+1}: loss = {loss:.6f}")

# 5. Generation
model = eqx.combine(params, static)
gen_key = jax.random.PRNGKey(0)
samples, trajectory = strategy.sample_from_target_distribution(
    model, gen_key, num_samples=1000, data_dim=2
)

print(f"Generated {samples.shape[0]} samples!")
# samples.shape: (1000, 2)
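
To eyeball the result, the 2D samples can be scattered directly (a sketch using matplotlib; the repository's own visualization goes through Rerun and the CLI's --output-image-path):

import matplotlib.pyplot as plt

plt.figure(figsize=(4, 4))
plt.scatter(samples[:, 0], samples[:, 1], s=2, alpha=0.5)
plt.axis("equal")
plt.savefig("cat_samples_preview.png", dpi=150)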
💎

Key Takeaways

Unified Interface

Strategy protocol allows swapping DDPM ↔ Flow Matching without changing training code. Same forward/reverse/loss_fn signature.

DDPM: Noise Prediction

Add noise progressively, learn to predict it. Stochastic reverse process needs many steps (50-1000).

Flow Matching: Velocity Field

Linear interpolation, learn velocity direction. Deterministic ODE needs fewer steps (20-50).

JAX Patterns

vmap for batching, lax.scan for loops, random.split for PRNG, partition/combine for JIT. All explicit, no hidden state.

Time Embedding

Sinusoidal encoding lets network distinguish nearby time steps. Critical for learning continuous dynamics.

2D Visualization

Low-dimensional datasets (Cat, Moon, Swiss-Roll) make it easy to visualize training progress and generation quality.

🔗

Resources