JAX
Composable transformations of NumPy programs
Autodiff, JIT compilation, vectorization, and parallelization. Write NumPy code, run on GPU/TPU with transformations that compose freely.
Core Transformations
Composable function transformations
jax.grad
Automatic differentiation (reverse-mode). Computes gradients for backprop.
def loss(params, x, y):
pred = model(params, x)
return jnp.mean((pred - y) ** 2)
# Get gradient function
grad_fn = jax.grad(loss)
grads = grad_fn(params, x, y)
jax.jit
JIT compilation via XLA. Fuses ops, runs on GPU/TPU.
@jax.jit
def train_step(params, batch):
grads = jax.grad(loss)(params, batch)
return update_params(params, grads)
# First call compiles, subsequent calls are fast
jax.vmap
Auto-vectorization. Map over batch dimension without loops.
def single_loss(param, x, y):
return (model(param, x) - y) ** 2
# Vectorize over batch
batch_loss = jax.vmap(single_loss, in_axes=(None, 0, 0))
losses = batch_loss(params, xs, ys)
jax.pmap
Parallel map across devices. SPMD parallelism for multi-GPU.
@jax.pmap
def parallel_step(params, batch):
return train_step(params, batch)
# Runs on all available GPUs/TPUs
Composition is Key
Transformations compose freely: jax.jit(jax.grad(jax.vmap(f))) - vectorize, differentiate, compile.
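For example, per-example gradients can be vectorized and then compiled in one line (a quick sketch reusing the loss, params, xs, and ys names from the snippets above):
# Differentiate w.r.t. params, vectorize over the batch, then compile
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))
grads = per_example_grads(params, xs, ys)  # one gradient pytree per example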
Neural Network Primitives
jax.nn module
Activation Functions
relu, gelu, silu, swish, sigmoid, tanh, softmax, log_softmax, softplus, elu, selu, leaky_relu
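A minimal sketch of calling a few of them on an example array:
import jax
import jax.numpy as jnp
x = jnp.array([-2.0, -0.5, 0.0, 1.5])
jax.nn.relu(x)      # clamps negatives to zero
jax.nn.gelu(x)      # smooth activation common in transformer MLPs
jax.nn.softmax(x)   # normalizes to a probability distribution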
Attention (jax.nn.dot_product_attention)
import jax.nn
# Scaled dot-product attention
# Q: (batch, seq, heads, head_dim)
# K, V: (batch, seq, kv_heads, head_dim)
output = jax.nn.dot_product_attention(
query, # Q
key, # K
value, # V
scale=1.0/jnp.sqrt(d_k), # Scaling factor
is_causal=True, # Causal mask for autoregressive
# implementation='cudnn' # FlashAttention on GPU
)
# Supports:
# - Multi-Head Attention (MHA): heads == kv_heads
# - Grouped Query Attention (GQA): heads % kv_heads == 0
# - Multi-Query Attention (MQA): kv_heads == 1
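A small sketch of the GQA case with hypothetical shapes (8 query heads sharing 2 key/value heads, so heads % kv_heads == 0):
import jax
import jax.numpy as jnp
batch, seq, head_dim = 2, 128, 64
q = jnp.zeros((batch, seq, 8, head_dim))  # 8 query heads
k = jnp.zeros((batch, seq, 2, head_dim))  # 2 KV heads
v = jnp.zeros((batch, seq, 2, head_dim))
out = jax.nn.dot_product_attention(q, k, v, is_causal=True)  # (batch, seq, 8, head_dim)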
Building a Transformer
From scratch in JAX
Layer Implementations
import jax
import jax.numpy as jnp
def layer_norm(x, gamma, beta, eps=1e-6):
"""Layer normalization."""
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.var(x, axis=-1, keepdims=True)
return gamma * (x - mean) / jnp.sqrt(var + eps) + beta
def linear(x, W, b=None):
"""Linear projection."""
out = jnp.einsum("...d,dh->...h", x, W)
return out + b if b is not None else out
def mlp(x, W1, b1, W2, b2):
"""MLP with GELU activation."""
h = jax.nn.gelu(linear(x, W1, b1))
return linear(h, W2, b2)
def attention(q, k, v, mask=None):
"""Multi-head attention."""
d_k = q.shape[-1]
scores = jnp.einsum("bthd,bshd->bhts", q, k) / jnp.sqrt(d_k)
if mask is not None:
scores = jnp.where(mask, scores, -1e9)
attn = jax.nn.softmax(scores, axis=-1)
return jnp.einsum("bhts,bshd->bthd", attn, v)Transformer Block
def transformer_block(x, params, mask=None):
"""Single transformer block with pre-norm."""
# Self-attention
residual = x
x = layer_norm(x, params['ln1_g'], params['ln1_b'])
# Project to Q, K, V
q = linear(x, params['Wq']).reshape(*x.shape[:-1], n_heads, head_dim)
k = linear(x, params['Wk']).reshape(*x.shape[:-1], n_heads, head_dim)
v = linear(x, params['Wv']).reshape(*x.shape[:-1], n_heads, head_dim)
# Attention + project out
attn_out = attention(q, k, v, mask)
attn_out = attn_out.reshape(*x.shape[:-1], d_model)
attn_out = linear(attn_out, params['Wo'])
x = residual + attn_out
# MLP
residual = x
x = layer_norm(x, params['ln2_g'], params['ln2_b'])
x = mlp(x, params['W1'], params['b1'], params['W2'], params['b2'])
x = residual + x
return x
Full Transformer
def transformer(tokens, params, is_training=False):
"""Decoder-only transformer."""
batch_size, seq_len = tokens.shape
# Token + position embeddings
x = params['token_emb'][tokens] # (batch, seq, d_model)
pos = jnp.arange(seq_len)
x = x + params['pos_emb'][pos]
# Causal mask
mask = jnp.tril(jnp.ones((seq_len, seq_len)))
# Transformer blocks
for i in range(n_layers):
x = transformer_block(x, params[f'block_{i}'], mask)
# Final layer norm + output projection
x = layer_norm(x, params['ln_f_g'], params['ln_f_b'])
logits = linear(x, params['out_proj']) # (batch, seq, vocab_size)
return logits
Action Prediction Transformer
For computer use tasks
Next-Action vs Next-Token
Instead of predicting the next token, the model predicts the next action (click, type, scroll, etc.) given the current screen state, action history, and task. It outputs an action type plus its parameters (coordinates, text to type, etc.) rather than free-form text.
Action Space Definition
from dataclasses import dataclass
from enum import IntEnum
import jax.numpy as jnp
class ActionType(IntEnum):
CLICK = 0
DOUBLE_CLICK = 1
RIGHT_CLICK = 2
TYPE = 3
KEY = 4 # keyboard shortcut
SCROLL = 5
DRAG = 6
WAIT = 7
DONE = 8 # task complete
@dataclass
class Action:
action_type: int # ActionType enum
x: float # normalized [0, 1]
y: float # normalized [0, 1]
text: str | None # for TYPE action
key: str | None # for KEY action (e.g., "ctrl+c")
scroll_delta: float # for SCROLL action
# Model outputs:
# - action_type: (batch, num_action_types) logits
# - coordinates: (batch, 2) for x, y
# - text_tokens: (batch, max_text_len, vocab_size) for typing
# - key_tokens: (batch, max_key_len, vocab_size) for shortcuts
Screen Encoder (Vision)
def patch_embed(image, params):
"""Convert image to patch embeddings (ViT-style)."""
# image: (batch, height, width, channels)
batch, h, w, c = image.shape
patch_size = 16
# Extract patches: (batch, n_patches, patch_dim)
patches = image.reshape(
batch, h // patch_size, patch_size, w // patch_size, patch_size, c
)
patches = patches.transpose(0, 1, 3, 2, 4, 5)
patches = patches.reshape(batch, -1, patch_size * patch_size * c)
# Linear projection to d_model
x = linear(patches, params['patch_proj'])
# Add position embeddings
n_patches = x.shape[1]
x = x + params['patch_pos_emb'][:n_patches]
return x # (batch, n_patches, d_model)
def encode_screen(screenshot, params):
"""Encode screenshot to embeddings."""
# Patch embedding
x = patch_embed(screenshot, params['vision'])
# Vision transformer blocks
for i in range(n_vision_layers):
x = transformer_block(x, params['vision'][f'block_{i}'])
return x # (batch, n_patches, d_model)
Action History Encoder
def encode_action(action, params):
"""Encode a single action to embedding."""
# Action type embedding
type_emb = params['action_type_emb'][action.action_type]
# Coordinate embedding (if applicable)
coord_emb = jnp.zeros(d_model)
if action.action_type in [ActionType.CLICK, ActionType.DOUBLE_CLICK,
ActionType.RIGHT_CLICK, ActionType.DRAG]:
# Fourier features for coordinates
coord = jnp.array([action.x, action.y])
coord_emb = linear(fourier_features(coord), params['coord_proj'])
# Text embedding (if TYPE action)
text_emb = jnp.zeros(d_model)
if action.action_type == ActionType.TYPE and action.text:
text_tokens = tokenize(action.text)
text_emb = jnp.mean(params['text_emb'][text_tokens], axis=0)
return type_emb + coord_emb + text_emb
def encode_history(actions, params):
"""Encode action history sequence."""
# (batch, history_len, d_model)
action_embs = jax.vmap(encode_action, in_axes=(0, None))(actions, params)
# Add position embeddings
action_embs = action_embs + params['history_pos_emb'][:len(actions)]
return action_embs
Full Action Prediction Model
def action_transformer(screenshot, action_history, task_prompt, params):
"""
Predict next action given current screen, history, and task.
Args:
screenshot: (batch, H, W, C) current screen
action_history: list of past actions
task_prompt: tokenized task description
Returns:
action_type_logits: (batch, num_action_types)
coord_pred: (batch, 2) normalized x, y
text_logits: (batch, max_len, vocab_size) for TYPE
"""
# Encode inputs
screen_emb = encode_screen(screenshot, params) # (batch, n_patches, d)
history_emb = encode_history(action_history, params) # (batch, hist_len, d)
task_emb = encode_text(task_prompt, params) # (batch, task_len, d)
# Concatenate: [CLS, screen, SEP, history, SEP, task]
batch_size = screenshot.shape[0]
cls_token = params['cls_token'].reshape(1, 1, -1).repeat(batch_size, axis=0)
sep_token = params['sep_token'].reshape(1, 1, -1).repeat(batch_size, axis=0)
x = jnp.concatenate([
cls_token,
screen_emb,
sep_token,
history_emb,
sep_token,
task_emb,
], axis=1)
# Transformer (cross-attention between modalities)
for i in range(n_layers):
x = transformer_block(x, params[f'block_{i}'])
# Output heads from CLS token
cls_out = x[:, 0] # (batch, d_model)
# Action type prediction
action_logits = linear(cls_out, params['action_head'])
# Coordinate prediction (for click/drag)
coord_pred = jax.nn.sigmoid(linear(cls_out, params['coord_head']))
# Text generation (for TYPE action)
text_logits = linear(x, params['text_head']) # (batch, seq, vocab)
return {
'action_type': action_logits,
'coords': coord_pred,
'text': text_logits,
}
Loss Function
def action_loss(params, batch):
"""
Compute loss for action prediction.
batch contains:
screenshot, action_history, task_prompt,
target_action_type, target_coords, target_text
"""
preds = action_transformer(
batch['screenshot'],
batch['action_history'],
batch['task_prompt'],
params
)
# Action type loss (cross-entropy)
action_type_loss = -jnp.sum(
jax.nn.log_softmax(preds['action_type']) *
jax.nn.one_hot(batch['target_action_type'], num_action_types),
axis=-1
).mean()
# Coordinate loss (L2 for click/drag actions)
coord_mask = jnp.isin(
batch['target_action_type'],
jnp.array([ActionType.CLICK, ActionType.DOUBLE_CLICK,
ActionType.RIGHT_CLICK, ActionType.DRAG])
)
coord_loss = jnp.where(
coord_mask,
jnp.sum((preds['coords'] - batch['target_coords']) ** 2, axis=-1),
0.0
).mean()
# Text loss (cross-entropy for TYPE actions)
text_mask = batch['target_action_type'] == ActionType.TYPE
text_loss = jnp.where(
text_mask[:, None],
-jnp.sum(
jax.nn.log_softmax(preds['text']) *
jax.nn.one_hot(batch['target_text'], vocab_size),
axis=-1
),
0.0
).mean()
return action_type_loss + coord_loss + text_loss
Training Loop
Optimizing the model
Parameter Initialization
def init_params(key, config):
"""Initialize transformer parameters."""
keys = jax.random.split(key, 20)
d_model = config.d_model
n_heads = config.n_heads
head_dim = d_model // n_heads
ff_dim = config.ff_dim
he_init = jax.nn.initializers.he_normal()
params = {
# Embeddings
'token_emb': he_init(keys[0], (config.vocab_size, d_model)),
'pos_emb': he_init(keys[1], (config.max_seq_len, d_model)),
'action_type_emb': he_init(keys[2], (len(ActionType), d_model)),
'patch_pos_emb': he_init(keys[3], (config.n_patches, d_model)),
# Special tokens
'cls_token': 0.02 * jax.random.normal(keys[4], (d_model,)),  # 1-D, so fan-based init doesn't apply
'sep_token': 0.02 * jax.random.normal(keys[5], (d_model,)),
}
# Transformer blocks
for i in range(config.n_layers):
params[f'block_{i}'] = init_transformer_block(keys[6 + i], config)
# Output heads
params['action_head'] = he_init(keys[-3], (d_model, len(ActionType)))
params['coord_head'] = he_init(keys[-2], (d_model, 2))
params['text_head'] = he_init(keys[-1], (d_model, config.vocab_size))
return params
Adam Optimizer
def adam_init(params):
"""Initialize Adam optimizer state."""
return {
'm': jax.tree.map(jnp.zeros_like, params), # First moment
'v': jax.tree.map(jnp.zeros_like, params), # Second moment
't': 0
}
def adam_update(params, grads, opt_state, lr=1e-4, beta1=0.9, beta2=0.999, eps=1e-8):
"""Adam optimizer step."""
t = opt_state['t'] + 1
# Update moments
m = jax.tree.map(
lambda m, g: beta1 * m + (1 - beta1) * g,
opt_state['m'], grads
)
v = jax.tree.map(
lambda v, g: beta2 * v + (1 - beta2) * g ** 2,
opt_state['v'], grads
)
# Bias correction
m_hat = jax.tree.map(lambda m: m / (1 - beta1 ** t), m)
v_hat = jax.tree.map(lambda v: v / (1 - beta2 ** t), v)
# Update params
params = jax.tree.map(
lambda p, m, v: p - lr * m / (jnp.sqrt(v) + eps),
params, m_hat, v_hat
)
return params, {'m': m, 'v': v, 't': t}
Training Step
@jax.jit
def train_step(params, opt_state, batch):
"""Single training step."""
loss, grads = jax.value_and_grad(action_loss)(params, batch)
params, opt_state = adam_update(params, grads, opt_state)
return params, opt_state, {'loss': loss}
def train(config, dataset):
"""Full training loop."""
key = jax.random.key(config.seed)
key, init_key = jax.random.split(key)
# Initialize
params = init_params(init_key, config)
opt_state = adam_init(params)
step = 0
for epoch in range(config.epochs):
for batch in dataset:
params, opt_state, metrics = train_step(params, opt_state, batch)
step += 1
if step % config.log_every == 0:
print(f"Step {step}: loss={metrics['loss']:.4f}")
# Validation
val_loss = evaluate(params, val_dataset)
print(f"Epoch {epoch}: val_loss={val_loss:.4f}")
# Checkpoint
save_checkpoint(params, opt_state, epoch)
return params
Inference
Predicting actions
Action Prediction Loop
def predict_action(params, screenshot, history, task):
"""Predict next action."""
preds = action_transformer(
screenshot[None], # Add batch dim
history,
tokenize(task)[None],
params
)
# Sample or argmax action type
action_type = jnp.argmax(preds['action_type'][0])
# Get coordinates
x, y = preds['coords'][0]
# Generate text if TYPE action
text = None
if action_type == ActionType.TYPE:
text = decode_text(preds['text'][0])
return Action(
action_type=int(action_type),
x=float(x),
y=float(y),
text=text,
key=None,
scroll_delta=0.0
)
def execute_task(params, task, env):
"""Execute task in environment."""
history = []
for step in range(max_steps):
screenshot = env.screenshot()
action = predict_action(params, screenshot, history, task)
if action.action_type == ActionType.DONE:
return True, history
env.execute(action)
history.append(action)
return False, history
Data Collection
Building training data
Trajectory Format
@dataclass
class Trajectory:
task: str # "Open Chrome and search for JAX"
screenshots: list[np.ndarray] # Screen at each step
actions: list[Action] # Action taken at each step
success: bool # Task completed successfully
def trajectory_to_examples(traj: Trajectory):
"""Convert trajectory to training examples."""
examples = []
for t in range(len(traj.actions)):
examples.append({
'screenshot': traj.screenshots[t],
'action_history': traj.actions[:t],
'task_prompt': traj.task,
'target_action_type': traj.actions[t].action_type,
'target_coords': jnp.array([traj.actions[t].x, traj.actions[t].y]),
'target_text': tokenize(traj.actions[t].text or ""),
})
return examples
# Collect demonstrations
trajectories = collect_human_demos() # Or from existing datasets
dataset = [ex for traj in trajectories for ex in trajectory_to_examples(traj)]
Random Numbers
Functional PRNG design
import jax.random as jr
# Keys are explicit states (not stateful like NumPy)
key = jr.key(42)
# Split key for different uses
key, subkey1, subkey2 = jr.split(key, 3)
# Generate random arrays
x = jr.normal(subkey1, (10, 10))
dropout_mask = jr.bernoulli(subkey2, 0.1, x.shape)
# In training loop
for step in range(num_steps):
key, subkey = jr.split(key)
# Use subkey for dropout, data augmentation, etc.
PyTrees
Nested structure handling
# PyTrees are nested structures (dicts, lists, tuples, namedtuples)
params = {
'encoder': {
'layer_0': {'W': jnp.ones((10, 10)), 'b': jnp.zeros(10)},
'layer_1': {'W': jnp.ones((10, 10)), 'b': jnp.zeros(10)},
},
'decoder': {...}
}
# Apply function to all leaves
zeros = jax.tree.map(jnp.zeros_like, params)
scaled = jax.tree.map(lambda x: x * 0.1, params)
# Combine trees
updated = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
# Flatten/unflatten
leaves, treedef = jax.tree.flatten(params)
params = jax.tree.unflatten(treedef, leaves)
Key Takeaways
Composable Transforms
grad, jit, vmap, pmap compose freely. Write simple code, transform it for performance.
Functional Paradigm
Pure functions, explicit state (params, PRNG keys). No hidden mutation. Predictable behavior.
NumPy + Accelerators
Write NumPy-like code, run on GPU/TPU via XLA. Same code works everywhere.
Action Prediction
Replace next-token with next-action. Same transformer architecture, different output heads.