jax-js

JAX-style ML framework in pure TypeScript • Zero dependencies • WebGPU + WASM

What is jax-js?

A machine learning framework that brings JAX-style array programming to the browser. It compiles numerical operations into WebAssembly and WebGPU kernels at runtime, enabling high-performance ML directly in JavaScript with zero external dependencies.

🎯 Zero Dependencies: Pure TypeScript, no Python backend, runs entirely in the browser.

GPU Acceleration: WebGPU compute shaders with f16 support for fast inference.

🔄 Composable Transforms: grad(vmap(jit(f))) — nest transformations arbitrarily.

🧮 Automatic Differentiation: forward-mode (jvp) and reverse-mode (vjp) autodiff.
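To make the forward-mode (jvp) idea concrete, here is a minimal dual-number sketch in plain TypeScript. This illustrates the concept only — the names and types are not jax-js internals.

```typescript
// Toy forward-mode AD: each value carries its tangent (derivative) alongside it.
type Dual = { val: number; tan: number };

const add = (a: Dual, b: Dual): Dual => ({
  val: a.val + b.val,
  tan: a.tan + b.tan, // sum rule
});

const mul = (a: Dual, b: Dual): Dual => ({
  val: a.val * b.val,
  tan: a.tan * b.val + a.val * b.tan, // product rule
});

// f(x) = x*x + x, so f'(x) = 2x + 1
function f(x: Dual): Dual {
  return add(mul(x, x), x);
}

// Seed the tangent with 1 to read off df/dx at x = 3.
const out = f({ val: 3, tan: 1 });
console.log(out.val, out.tan); // 12 7
```

Reverse mode (vjp) records the same operations but propagates sensitivities backward from the output, which is what powers `grad` in the example below.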

High-Level Architecture

JavaScript (NumPy-like API)
  → Tracing (record operations)
  → Jaxpr IR (functional graph)
  → ALU Expr (scalar DAG)
  → Backend (Wasm / WebGPU)
Example: Gradient Descent (TypeScript)
import * as jax from "jax-js";

// Define a loss function
function loss(params: jax.Array, x: jax.Array, y: jax.Array) {
  const pred = x.dot(params);  // Linear prediction
  const diff = pred.sub(y);
  return diff.mul(diff).mean();  // MSE loss
}

// Get gradient function (reverse-mode autodiff)
const gradLoss = jax.grad(loss);

// JIT compile for speed
const step = jax.jit((params, x, y, lr) => {
  const g = gradLoss(params.ref, x, y);  // .ref for refcounting
  return params.sub(g.mul(lr));          // SGD update
});

// Training loop (toy random data for illustration)
const X_train = jax.randn([100, 10]);
const y_train = jax.randn([100, 1]);
let params = jax.randn([10, 1]);
for (let i = 0; i < 1000; i++) {
  params = step(params, X_train, y_train, 0.01);
}
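The jitted step above implements plain SGD on an MSE loss. Written out for a 1-D linear model in dependency-free TypeScript, the same update looks like this (toy data chosen here for illustration):

```typescript
// Gradient of mean((w*x - y)^2) wrt w is 2 * mean(x * (w*x - y)).
const xs = [1, 2, 3, 4];
const ys = xs.map(x => 3 * x); // true weight is 3
const lr = 0.01;

let w = 0;
for (let i = 0; i < 1000; i++) {
  const g = (2 / xs.length) * xs.reduce((s, x, j) => s + x * (w * x - ys[j]), 0);
  w -= lr * g; // SGD update, same form as params.sub(g.mul(lr))
}
console.log(w.toFixed(3)); // converges to ≈ 3
```

The jax-js version does the same thing, but over arrays, with the gradient derived automatically and the whole step compiled to a fused kernel.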
~6,700 source lines of TypeScript
39 source files (modular design)
30+ primitives (core operations)
3000+ GFLOP/s matmul performance on an M4 Pro