🏡/repos/jax-ml/

jax

🔢

JAX

Composable transformations of NumPy programs

Autodiff, JIT compilation, vectorization, and parallelization. Write NumPy code, run on GPU/TPU with transformations that compose freely.

grad · jit · vmap · pmap
🔄

Core Transformations

Composable function transformations

jax.grad

Automatic differentiation (reverse-mode). Computes gradients for backprop.

def loss(params, x, y):
  """Return the mean squared error between model(params, x) and y."""
  residual = model(params, x) - y
  return jnp.mean(residual ** 2)

# Build the gradient function (jax.grad differentiates with respect to the
# first argument, params, by default) and evaluate it.
grad_fn = jax.grad(loss)
grads = grad_fn(params, x, y)

jax.jit

JIT compilation via XLA. Fuses ops, runs on GPU/TPU.

@jax.jit
def train_step(params, batch):
  """Run one jit-compiled update: compute gradients, then update params.

  Args:
    params: model parameters passed to loss and update_params.
    batch: an (x, y) pair; it is unpacked into the two data arguments
      that loss expects.

  Returns:
    Whatever update_params returns for (params, grads).
  """
  # loss is declared as loss(params, x, y); passing the batch as a single
  # positional argument would raise a TypeError, so unpack it first.
  x, y = batch
  grads = jax.grad(loss)(params, x, y)
  return update_params(params, grads)

# First call compiles, subsequent calls are fast

jax.vmap

Auto-vectorization. Map over batch dimension without loops.

def single_loss(param, x, y):
  """Return the squared error of model(param, x) against y for one example."""
  error = model(param, x) - y
  return error ** 2

# Vectorize over the batch: params are shared by every example (None), while
# xs and ys are mapped along their leading axis (0).
batch_loss = jax.vmap(single_loss, in_axes=(None, 0, 0))
losses = batch_loss(params, xs, ys)

jax.pmap

Parallel map across devices. SPMD parallelism for multi-GPU.

@jax.pmap
def parallel_step(params, batch):
  """Run train_step under jax.pmap and return its result unchanged."""
  result = train_step(params, batch)
  return result

# Runs on all available GPUs/TPUs

Composition is Key

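Transformations compose freely: jax.jit(jax.vmap(jax.grad(f))) differentiates f, vectorizes the gradient over a batch (per-example gradients), and compiles the result. Note that jax.grad requires a scalar-output function, so grad is applied before vmap here.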

🧠

Neural Network Primitives

jax.nn module

Activation Functions

relu · gelu · silu · swish · sigmoid · tanh · softmax · log_softmax · softplus · elu · selu · leaky_relu

Attention (jax.nn.dot_product_attention)

import jax.nn

# Scaled dot-product attention
# Q: (batch, seq, heads, head_dim)
# K, V: (batch, seq, kv_heads, head_dim)
# query, key, value and d_k are not defined in this snippet; they are expected
# to come from the surrounding code.
# NOTE(review): per the JAX docs, `scale` already defaults to
# 1/sqrt(head_dim) when omitted, so the explicit value below looks redundant
# if d_k is the head dimension - confirm.
output = jax.nn.dot_product_attention(
    query,                    # Q
    key,                      # K
    value,                    # V
    scale=1.0/jnp.sqrt(d_k), # Scaling factor applied to the attention logits
    is_causal=True,          # Causal mask for autoregressive
    # implementation='cudnn'  # FlashAttention on GPU
)

# Supports:
# - Multi-Head Attention (MHA): heads == kv_heads
# - Grouped Query Attention (GQA): heads % kv_heads == 0
# - Multi-Query Attention (MQA): kv_heads == 1
🏗️

Building a Transformer

From scratch in JAX

Layer Implementations

import jax
import jax.numpy as jnp

def layer_norm(x, gamma, beta, eps=1e-6):
    """Normalize x over its last axis, then scale by gamma and shift by beta.

    The mean and variance are taken along axis -1 (kept as size-1 axes so
    they broadcast against x); eps is added to the variance before the
    square root.
    """
    centered = x - jnp.mean(x, axis=-1, keepdims=True)
    std = jnp.sqrt(jnp.var(x, axis=-1, keepdims=True) + eps)
    scaled = gamma * centered
    return scaled / std + beta

def linear(x, W, b=None):
    """Project the last axis of x through W and add the bias b when given.

    x may have any number of leading axes; W maps the last axis (d) to the
    output axis (h).
    """
    projected = jnp.einsum("...d,dh->...h", x, W)
    if b is None:
        return projected
    return projected + b

def mlp(x, W1, b1, W2, b2):
    """Two linear layers with a GELU activation between them."""
    hidden = linear(x, W1, b1)
    activated = jax.nn.gelu(hidden)
    return linear(activated, W2, b2)

def attention(q, k, v, mask=None):
    """Scaled dot-product attention over multiple heads.

    The einsum patterns fix the layout: q is (batch, t, heads, dim) and
    k, v are (batch, s, heads, dim); the result has q's layout. Where a
    mask is given, positions with a falsy mask value get a logit of -1e9
    before the softmax over the key axis.
    """
    scale = jnp.sqrt(q.shape[-1])
    logits = jnp.einsum("bthd,bshd->bhts", q, k) / scale
    logits = logits if mask is None else jnp.where(mask, logits, -1e9)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhts,bshd->bthd", weights, v)

Transformer Block

def transformer_block(x, params, mask=None):
    """Apply one pre-norm transformer block: self-attention, then an MLP.

    Each sub-layer normalizes its input, transforms it, and adds the result
    back onto the un-normalized input (residual connection). n_heads,
    head_dim and d_model are read from the enclosing scope.
    """
    lead_shape = x.shape[:-1]

    # --- Self-attention sub-layer ---
    normed = layer_norm(x, params['ln1_g'], params['ln1_b'])

    def split_heads(projected):
        # (..., d_model) -> (..., n_heads, head_dim)
        return projected.reshape(*lead_shape, n_heads, head_dim)

    q = split_heads(linear(normed, params['Wq']))
    k = split_heads(linear(normed, params['Wk']))
    v = split_heads(linear(normed, params['Wv']))

    # Merge the heads back together and apply the output projection
    merged = attention(q, k, v, mask).reshape(*lead_shape, d_model)
    x = x + linear(merged, params['Wo'])

    # --- MLP sub-layer ---
    normed = layer_norm(x, params['ln2_g'], params['ln2_b'])
    return x + mlp(normed, params['W1'], params['b1'], params['W2'], params['b2'])

Full Transformer

def transformer(tokens, params, is_training=False):
    """Decoder-only transformer: map token ids to next-token logits.

    tokens must be 2-D, (batch, seq). is_training is accepted but not used
    by this implementation. n_layers is read from the enclosing scope.
    """
    _, seq_len = tokens.shape

    # Sum of token and learned position embeddings: (batch, seq, d_model)
    positions = jnp.arange(seq_len)
    hidden = params['token_emb'][tokens] + params['pos_emb'][positions]

    # Lower-triangular mask: position t may attend to positions <= t
    causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))

    for layer in range(n_layers):
        hidden = transformer_block(hidden, params[f'block_{layer}'], causal_mask)

    # Final layer norm, then project to the vocabulary: (batch, seq, vocab_size)
    hidden = layer_norm(hidden, params['ln_f_g'], params['ln_f_b'])
    return linear(hidden, params['out_proj'])
🖱️

Action Prediction Transformer

For computer use tasks

Next-Action vs Next-Token

Instead of predicting the next token, we predict the next action (click, type, scroll, etc.) given screen state and history. The model outputs action type + parameters, not text.

Action Space Definition

from dataclasses import dataclass
from enum import IntEnum
import jax.numpy as jnp

class ActionType(IntEnum):
    """Integer ids for the kinds of action the model can emit."""

    # Pointer actions
    CLICK = 0
    DOUBLE_CLICK = 1
    RIGHT_CLICK = 2

    # Keyboard actions
    TYPE = 3
    # keyboard shortcut
    KEY = 4

    # Other actions
    SCROLL = 5
    DRAG = 6
    WAIT = 7
    # task complete
    DONE = 8

@dataclass
class Action:
    action_type: int       # ActionType enum
    x: float              # normalized [0, 1]
    y: float              # normalized [0, 1]
    text: str | None      # for TYPE action
    key: str | None       # for KEY action (e.g., "ctrl+c")
    scroll_delta: float   # for SCROLL action

# Model outputs:
# - action_type: (batch, num_action_types) logits
# - coordinates: (batch, 2) for x, y
# - text_tokens: (batch, max_text_len, vocab_size) for typing
# - key_tokens: (batch, max_key_len, vocab_size) for shortcuts

Screen Encoder (Vision)

def patch_embed(image, params, patch_size=16):
    """Convert an image batch to patch embeddings (ViT-style).

    Args:
        image: (batch, height, width, channels) array.
        params: mapping with 'patch_proj' (projection weights applied to
            the flattened patches) and 'patch_pos_emb' (position table
            indexed by patch number).
        patch_size: side length of each square patch. height and width
            must both be multiples of it.

    Returns:
        The projected patches plus position embeddings, one row per patch.

    Raises:
        ValueError: if height or width is not divisible by patch_size.
    """
    batch, h, w, c = image.shape
    if h % patch_size or w % patch_size:
        raise ValueError(
            f"image size ({h}, {w}) must be divisible by patch_size={patch_size}"
        )
    rows, cols = h // patch_size, w // patch_size
    n_patches = rows * cols

    # Split each spatial axis into (patch index, offset inside the patch),
    # then bring the two patch-index axes together so each patch is contiguous.
    patches = image.reshape(batch, rows, patch_size, cols, patch_size, c)
    patches = patches.transpose(0, 1, 3, 2, 4, 5)
    # Explicit patch count (not -1) so an empty batch still reshapes cleanly.
    patches = patches.reshape(batch, n_patches, patch_size * patch_size * c)

    # Linear projection to d_model
    x = linear(patches, params['patch_proj'])

    # Add position embeddings
    x = x + params['patch_pos_emb'][:n_patches]

    return x  # (batch, n_patches, d_model)

def encode_screen(screenshot, params):
    """Embed a screenshot as patches and refine them with the vision blocks.

    Uses params['vision'] for both the patch embedding and the
    n_vision_layers transformer blocks (no attention mask is passed).
    """
    vision_params = params['vision']
    x = patch_embed(screenshot, vision_params)

    for layer in range(n_vision_layers):
        x = transformer_block(x, vision_params[f'block_{layer}'])

    return x  # (batch, n_patches, d_model)

Action History Encoder

def encode_action(action, params):
    """Embed one action as the sum of type, coordinate and text embeddings.

    The coordinate part is non-zero only for pointer actions (click variants
    and drag); the text part is non-zero only for TYPE actions with
    non-empty text.
    """
    kind = action.action_type
    type_emb = params['action_type_emb'][kind]

    # Coordinates: Fourier features of (x, y), projected to d_model
    coord_emb = jnp.zeros(d_model)
    pointer_kinds = [ActionType.CLICK, ActionType.DOUBLE_CLICK,
                     ActionType.RIGHT_CLICK, ActionType.DRAG]
    if kind in pointer_kinds:
        xy = jnp.array([action.x, action.y])
        coord_emb = linear(fourier_features(xy), params['coord_proj'])

    # Typed text: mean of the token embeddings
    text_emb = jnp.zeros(d_model)
    has_text = kind == ActionType.TYPE and action.text
    if has_text:
        token_ids = tokenize(action.text)
        text_emb = jnp.mean(params['text_emb'][token_ids], axis=0)

    return type_emb + coord_emb + text_emb

def encode_history(actions, params):
    """Encode a sequence of past actions, one embedding per action.

    Args:
        actions: sequence of Action objects, oldest first. May be empty.
        params: model parameters; 'history_pos_emb' is the position table
            and the rest is forwarded to encode_action.

    Returns:
        Array of shape (history_len, d_model): each action's embedding plus
        the position embedding for its index. No batch axis is added.
    """
    pos_table = params['history_pos_emb']
    n_actions = len(actions)

    if n_actions == 0:
        # Nothing to stack yet (e.g. the first step of an episode); return an
        # empty history with the embedding width of the position table.
        return jnp.zeros((0, pos_table.shape[-1]), dtype=pos_table.dtype)

    # encode_action reads Python attributes and branches on the action type,
    # so it cannot be traced by jax.vmap; encode each action eagerly instead.
    action_embs = jnp.stack([encode_action(action, params) for action in actions])

    # Add position embeddings
    return action_embs + pos_table[:n_actions]

Full Action Prediction Model

def action_transformer(screenshot, action_history, task_prompt, params):
    """
    Predict next action given current screen, history, and task.

    The three inputs are encoded separately and joined into one sequence
    laid out as [CLS, screen, SEP, history, SEP, task]. The shared
    transformer blocks then run self-attention over the whole sequence, so
    every modality can attend to every other. n_layers is read from the
    enclosing scope.

    Args:
        screenshot: (batch, H, W, C) current screen
        action_history: list of past actions
        task_prompt: tokenized task description

    Returns:
        dict with keys:
            'action_type': (batch, num_action_types) logits from the CLS slot
            'coords': (batch, 2) sigmoid outputs, i.e. x, y in (0, 1)
            'text': (batch, seq, vocab_size) logits, one row per position
                of the joint sequence
    """
    # Encode inputs
    screen_emb = encode_screen(screenshot, params)      # (batch, n_patches, d)
    history_emb = encode_history(action_history, params)
    task_emb = encode_text(task_prompt, params)         # (batch, task_len, d)

    batch_size = screen_emb.shape[0]

    # A history encoded without a batch axis, (hist_len, d), is shared by
    # every example so that it can be concatenated along the sequence axis.
    if history_emb.ndim == 2:
        history_emb = jnp.broadcast_to(
            history_emb[None], (batch_size, *history_emb.shape)
        )

    # One copy of each special token per example: (batch, 1, d)
    cls_token = params['cls_token'].reshape(1, 1, -1).repeat(batch_size, axis=0)
    sep_token = params['sep_token'].reshape(1, 1, -1).repeat(batch_size, axis=0)

    x = jnp.concatenate([
        cls_token,
        screen_emb,
        sep_token,
        history_emb,
        sep_token,
        task_emb,
    ], axis=1)

    # Joint self-attention over all modalities (no mask)
    for i in range(n_layers):
        x = transformer_block(x, params[f'block_{i}'])

    # Output heads from CLS token
    cls_out = x[:, 0]  # (batch, d_model)

    # Action type prediction
    action_logits = linear(cls_out, params['action_head'])

    # Coordinate prediction (for click/drag)
    coord_pred = jax.nn.sigmoid(linear(cls_out, params['coord_head']))

    # Text generation (for TYPE action)
    text_logits = linear(x, params['text_head'])  # (batch, seq, vocab)

    return {
        'action_type': action_logits,
        'coords': coord_pred,
        'text': text_logits,
    }

Loss Function

def action_loss(params, batch):
    """
    Sum of the action-type, coordinate and text losses for one batch.

    batch contains:
        screenshot, action_history, task_prompt,
        target_action_type, target_coords, target_text

    The coordinate and text terms are zeroed for examples whose target
    action does not use them, but each term is still averaged over all
    entries. num_action_types and vocab_size are read from the enclosing
    scope.
    """
    preds = action_transformer(
        batch['screenshot'], batch['action_history'], batch['task_prompt'], params
    )
    target_types = batch['target_action_type']

    # Action type: cross-entropy against the one-hot target
    type_log_probs = jax.nn.log_softmax(preds['action_type'])
    type_targets = jax.nn.one_hot(target_types, num_action_types)
    action_type_loss = -jnp.sum(type_log_probs * type_targets, axis=-1).mean()

    # Coordinates: squared L2 error, counted only for click/drag targets
    pointer_types = jnp.array([ActionType.CLICK, ActionType.DOUBLE_CLICK,
                               ActionType.RIGHT_CLICK, ActionType.DRAG])
    uses_coords = jnp.isin(target_types, pointer_types)
    squared_error = jnp.sum((preds['coords'] - batch['target_coords']) ** 2, axis=-1)
    coord_loss = jnp.where(uses_coords, squared_error, 0.0).mean()

    # Text: per-position cross-entropy, counted only for TYPE targets
    is_typing = target_types == ActionType.TYPE
    text_log_probs = jax.nn.log_softmax(preds['text'])
    text_targets = jax.nn.one_hot(batch['target_text'], vocab_size)
    position_nll = -jnp.sum(text_log_probs * text_targets, axis=-1)
    text_loss = jnp.where(is_typing[:, None], position_nll, 0.0).mean()

    return action_type_loss + coord_loss + text_loss
🏋️

Training Loop

Optimizing the model

Parameter Initialization

def init_params(key, config):
    """Initialize transformer parameters.

    Args:
        key: JAX PRNG key; it is split so that every tensor and every
            transformer block receives its own independent subkey.
        config: object providing d_model, vocab_size, max_seq_len,
            n_patches and n_layers.

    Returns:
        dict with the embedding tables, the CLS/SEP token vectors, one
        'block_{i}' entry per layer, and the three output-head matrices.
    """
    d_model = config.d_model
    n_layers = config.n_layers

    # 6 embedding/token tensors + one key per block + 3 output heads. Sizing
    # the split by n_layers keeps block keys from colliding with head keys
    # (or running past the end) when the model has many layers.
    n_embed_keys = 6
    keys = jax.random.split(key, n_embed_keys + n_layers + 3)
    block_keys = keys[n_embed_keys:n_embed_keys + n_layers]
    head_keys = keys[n_embed_keys + n_layers:]

    he_init = jax.nn.initializers.he_normal()
    # he_normal needs at least a 2-D shape to compute fan-in, so the 1-D
    # special-token vectors use a plain normal initializer instead.
    token_init = jax.nn.initializers.normal(stddev=0.02)

    params = {
        # Embeddings
        'token_emb': he_init(keys[0], (config.vocab_size, d_model)),
        'pos_emb': he_init(keys[1], (config.max_seq_len, d_model)),
        'action_type_emb': he_init(keys[2], (len(ActionType), d_model)),
        'patch_pos_emb': he_init(keys[3], (config.n_patches, d_model)),

        # Special tokens
        'cls_token': token_init(keys[4], (d_model,)),
        'sep_token': token_init(keys[5], (d_model,)),
    }

    # Transformer blocks
    for i in range(n_layers):
        params[f'block_{i}'] = init_transformer_block(block_keys[i], config)

    # Output heads
    params['action_head'] = he_init(head_keys[0], (d_model, len(ActionType)))
    params['coord_head'] = he_init(head_keys[1], (d_model, 2))
    params['text_head'] = he_init(head_keys[2], (d_model, config.vocab_size))

    return params

Adam Optimizer

def adam_init(params):
    """Build the initial Adam state for params.

    Returns a dict with zero-filled first ('m') and second ('v') moment
    trees shaped like params, and the step counter 't' set to 0.
    """
    def zeros_like_params():
        return jax.tree.map(jnp.zeros_like, params)

    first_moment = zeros_like_params()
    second_moment = zeros_like_params()
    return {'m': first_moment, 'v': second_moment, 't': 0}

def adam_update(params, grads, opt_state, lr=1e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one Adam step and return (new_params, new_opt_state).

    opt_state holds the running moments 'm' and 'v' and the step count 't'.
    The returned state stores the uncorrected moments; bias correction is
    applied only when computing the parameter update.
    """
    step = opt_state['t'] + 1

    def blend_first(prev, g):
        return beta1 * prev + (1 - beta1) * g

    def blend_second(prev, g):
        return beta2 * prev + (1 - beta2) * g ** 2

    # Exponential moving averages of the gradient and its square
    first = jax.tree.map(blend_first, opt_state['m'], grads)
    second = jax.tree.map(blend_second, opt_state['v'], grads)

    # Bias-correction denominators for this step
    first_scale = 1 - beta1 ** step
    second_scale = 1 - beta2 ** step

    def apply_update(p, m, v):
        m_hat = m / first_scale
        v_hat = v / second_scale
        return p - lr * m_hat / (jnp.sqrt(v_hat) + eps)

    new_params = jax.tree.map(apply_update, params, first, second)
    return new_params, {'m': first, 'v': second, 't': step}

Training Step

@jax.jit
def train_step(params, opt_state, batch):
    """Run one optimization step and return (params, opt_state, metrics).

    metrics is a dict holding the loss value under 'loss', evaluated at the
    parameters from before the update.
    """
    loss_and_grad = jax.value_and_grad(action_loss)
    loss_value, gradients = loss_and_grad(params, batch)
    new_params, new_opt_state = adam_update(params, gradients, opt_state)
    metrics = {'loss': loss_value}
    return new_params, new_opt_state, metrics

def train(config, dataset, val_dataset=None):
    """Full training loop.

    Args:
        config: object providing seed, epochs and log_every, plus whatever
            init_params reads.
        dataset: iterable of training batches; it is iterated once per epoch.
        val_dataset: optional validation data. When given, evaluate() is run
            on it after every epoch; when None, validation is skipped.

    Returns:
        The trained parameters.
    """
    key = jax.random.key(config.seed)
    key, init_key = jax.random.split(key)

    # Initialize
    params = init_params(init_key, config)
    opt_state = adam_init(params)

    # Global step counter across all epochs (drives the logging interval)
    step = 0

    for epoch in range(config.epochs):
        for batch in dataset:
            params, opt_state, metrics = train_step(params, opt_state, batch)

            if step % config.log_every == 0:
                print(f"Step {step}: loss={metrics['loss']:.4f}")
            step += 1

        # Validation
        if val_dataset is not None:
            val_loss = evaluate(params, val_dataset)
            print(f"Epoch {epoch}: val_loss={val_loss:.4f}")

        # Checkpoint
        save_checkpoint(params, opt_state, epoch)

    return params
🎯

Inference

Predicting actions

Action Prediction Loop

def predict_action(params, screenshot, history, task):
    """Run the model on a single screenshot and build the resulting Action.

    The action type is the argmax of the type logits. Text is decoded only
    when that type is TYPE; key is always None and scroll_delta always 0.0.
    """
    # The model expects a leading batch axis on the screenshot and the task
    batched_screen = screenshot[None]
    batched_task = tokenize(task)[None]
    preds = action_transformer(batched_screen, history, batched_task, params)

    action_type = jnp.argmax(preds['action_type'][0])
    x, y = preds['coords'][0]

    if action_type == ActionType.TYPE:
        text = decode_text(preds['text'][0])
    else:
        text = None

    return Action(
        action_type=int(action_type),
        x=float(x),
        y=float(y),
        text=text,
        key=None,
        scroll_delta=0.0,
    )

def execute_task(params, task, env):
    """Drive env with predicted actions until DONE or max_steps is reached.

    Returns (True, history) as soon as the model predicts DONE (that action
    is neither executed nor recorded), or (False, history) when the step
    budget runs out. max_steps is read from the enclosing scope.
    """
    history = []

    for _ in range(max_steps):
        action = predict_action(params, env.screenshot(), history, task)
        finished = action.action_type == ActionType.DONE
        if finished:
            return True, history

        env.execute(action)
        history.append(action)

    return False, history
📊

Data Collection

Building training data

Trajectory Format

@dataclass
class Trajectory:
    # One recorded demonstration of a task.
    task: str                      # "Open Chrome and search for JAX"
    # np.ndarray is the array type; np.array is only the constructor function.
    screenshots: list[np.ndarray]  # Screen at each step
    actions: list[Action]          # Action taken at each step
    success: bool                  # Task completed successfully

def trajectory_to_examples(traj: Trajectory):
    """Turn a trajectory into one training example per recorded action.

    The example for step t pairs the screenshot at t and the actions taken
    before t with the action at t as the prediction target.
    """
    def example_at(t, action):
        return {
            'screenshot': traj.screenshots[t],
            'action_history': traj.actions[:t],
            'task_prompt': traj.task,
            'target_action_type': action.action_type,
            'target_coords': jnp.array([action.x, action.y]),
            'target_text': tokenize(action.text or ""),
        }

    return [example_at(t, action) for t, action in enumerate(traj.actions)]

# Collect demonstrations, then flatten them into a single list holding every
# per-step training example from every trajectory.
trajectories = collect_human_demos()  # Or from existing datasets
dataset = [ex for traj in trajectories for ex in trajectory_to_examples(traj)]
🎲

Random Numbers

Functional PRNG design

import jax.random as jr

# Keys are explicit states (not stateful like NumPy)
key = jr.key(42)

# Split key for different uses: keep `key` for later splits and give each
# consumer its own subkey so that no key is used twice.
key, subkey1, subkey2 = jr.split(key, 3)

# Generate random arrays
x = jr.normal(subkey1, (10, 10))
# Boolean array shaped like x; 0.1 is the probability passed to bernoulli
dropout_mask = jr.bernoulli(subkey2, 0.1, x.shape)

# In training loop (num_steps is not defined in this snippet)
for step in range(num_steps):
    # Carry the new `key` forward and consume only `subkey` in this step
    key, subkey = jr.split(key)
    # Use subkey for dropout, data augmentation, etc.
🌳

PyTrees

Nested structure handling

# PyTrees are nested structures (dicts, lists, tuples, namedtuples)
params = {
    'encoder': {
        'layer_0': {'W': jnp.ones((10, 10)), 'b': jnp.zeros(10)},
        'layer_1': {'W': jnp.ones((10, 10)), 'b': jnp.zeros(10)},
    },
    # Every leaf must be an array: a literal `{...}` placeholder would be a
    # set holding Ellipsis, which the tree utilities below cannot process.
    'decoder': {
        'layer_0': {'W': jnp.ones((10, 10)), 'b': jnp.zeros(10)},
    },
}

# Stand-in gradients with the same tree structure as params
grads = jax.tree.map(jnp.ones_like, params)

# Apply function to all leaves
zeros = jax.tree.map(jnp.zeros_like, params)
scaled = jax.tree.map(lambda x: x * 0.1, params)

# Combine trees (both must share the same structure)
updated = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)

# Flatten/unflatten: leaves is a flat list of arrays and treedef records the
# structure needed to rebuild the original tree.
leaves, treedef = jax.tree.flatten(params)
params = jax.tree.unflatten(treedef, leaves)
💎

Key Takeaways

Composable Transforms

grad, jit, vmap, pmap compose freely. Write simple code, transform it for performance.

Functional Paradigm

Pure functions, explicit state (params, PRNG keys). No hidden mutation. Predictable behavior.

NumPy + Accelerators

Write NumPy-like code, run on GPU/TPU via XLA. Same code works everywhere.

Action Prediction

Replace next-token with next-action. Same transformer architecture, different output heads.