JAX-style ML framework in pure TypeScript • Zero dependencies • WebGPU + WASM
A machine learning framework that brings JAX-style array programming to the browser. It compiles numerical operations into WebAssembly and WebGPU kernels at runtime, enabling high-performance ML directly in JavaScript with zero external dependencies.
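To make "compiles numerical operations into kernels at runtime" more concrete, here is a toy sketch of the general idea. Elementwise operations are recorded as an expression tree instead of running eagerly, and the tree is then compiled into a single fused loop. Everything here (`Expr`, `input`, `compile`, and so on) is hypothetical and is not the jax-js API. It also generates a plain JavaScript function with `new Function`, not WebAssembly or a WebGPU shader.

```typescript
// Toy runtime kernel generation (hypothetical API; emits JavaScript, not WASM/WGSL).
// Elementwise ops are recorded as an expression tree instead of running eagerly,
// then compiled into ONE loop, so no intermediate arrays are allocated (fusion).

type Expr =
  | { kind: "input"; index: number }
  | { kind: "const"; value: number }
  | { kind: "binary"; op: "+" | "-" | "*"; lhs: Expr; rhs: Expr };

const input = (index: number): Expr => ({ kind: "input", index });
const constant = (value: number): Expr => ({ kind: "const", value });
const add = (lhs: Expr, rhs: Expr): Expr => ({ kind: "binary", op: "+", lhs, rhs });
const sub = (lhs: Expr, rhs: Expr): Expr => ({ kind: "binary", op: "-", lhs, rhs });
const mul = (lhs: Expr, rhs: Expr): Expr => ({ kind: "binary", op: "*", lhs, rhs });

// Generate source code for the per-element computation.
// Shared subexpressions are simply re-emitted (no common-subexpression elimination).
function emit(e: Expr): string {
  switch (e.kind) {
    case "input":
      return `in${e.index}[i]`;
    case "const":
      return `(${e.value})`;
    case "binary":
      return `(${emit(e.lhs)} ${e.op} ${emit(e.rhs)})`;
  }
}

type Kernel = (...inputs: Float32Array[]) => Float32Array;

// Compile an expression over `arity` input arrays into a fused elementwise kernel.
function compile(e: Expr, arity: number): Kernel {
  const params = Array.from({ length: arity }, (_, k) => `in${k}`);
  const body = [
    "const n = in0.length;",
    "const out = new Float32Array(n);",
    `for (let i = 0; i < n; i++) out[i] = ${emit(e)};`,
    "return out;",
  ].join("\n");
  return new Function(...params, body) as Kernel;
}

// (x - y) * (x - y) + 1 becomes a single loop over the inputs.
const x = input(0);
const y = input(1);
const d = sub(x, y);
const kernel = compile(add(mul(d, d), constant(1)), 2);
const result = kernel(new Float32Array([1, 2, 3]), new Float32Array([0, 0, 1]));
console.log(Array.from(result)); // [2, 5, 5]
```

A real framework would target WebAssembly or WGSL and handle shapes, broadcasting, and reductions, but the record-then-compile structure is the part this sketch is meant to show.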
- Pure TypeScript with no Python backend; runs entirely in the browser
- WebGPU compute shaders with f16 support for fast inference
- Composable transformations: nest them arbitrarily, e.g. jit(vmap(grad(f)))
- Forward-mode (jvp) and reverse-mode (vjp) autodiff
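The two autodiff modes, and why transformations nest, can be illustrated on scalars with a small self-contained toy. Forward mode (jvp) carries a tangent alongside every value using dual numbers; reverse mode (vjp) records operations on a tape and then propagates a cotangent backwards. In JAX, grad is built on top of vjp, and the toy does the same. This is a minimal sketch of the techniques, not how jax-js implements them: `Ops`, `Dual`, `TapeNode`, `jvp`, `vjp`, `grad`, and `vmap` here are hypothetical names.

```typescript
// Toy scalar autodiff (hypothetical names; a sketch of the two modes, not jax-js internals).

// The function is written once against an abstract set of operations, so the same
// code can run on dual numbers (forward mode) or on tape nodes (reverse mode).
interface Ops<T> {
  add(a: T, b: T): T;
  mul(a: T, b: T): T;
  sin(a: T): T;
}
type Fn = <T>(ops: Ops<T>, x: T) => T;

// f(x) = x * x + sin(x), so f'(x) = 2x + cos(x)
const f: Fn = (ops, x) => ops.add(ops.mul(x, x), ops.sin(x));

// Forward mode: each value carries a tangent (a dual number).
interface Dual {
  val: number;
  tan: number;
}
const dualOps: Ops<Dual> = {
  add: (a, b) => ({ val: a.val + b.val, tan: a.tan + b.tan }),
  mul: (a, b) => ({ val: a.val * b.val, tan: a.tan * b.val + a.val * b.tan }),
  sin: (a) => ({ val: Math.sin(a.val), tan: Math.cos(a.val) * a.tan }),
};

// jvp: returns [f(x), f'(x) * t] in a single forward pass.
function jvp(fn: Fn, x: number, t: number): [number, number] {
  const out = fn(dualOps, { val: x, tan: t });
  return [out.val, out.tan];
}

// Reverse mode: each recorded op stores its parents and local partial derivatives.
class TapeNode {
  grad = 0;
  constructor(
    readonly val: number,
    readonly parents: TapeNode[] = [],
    readonly partials: number[] = [],
  ) {}
}

// vjp: runs f forward once, returning f(x) and a pullback ct -> ct * f'(x).
function vjp(fn: Fn, x: number): [number, (ct: number) => number] {
  const tape: TapeNode[] = [];
  const record = (val: number, parents: TapeNode[] = [], partials: number[] = []) => {
    const node = new TapeNode(val, parents, partials);
    tape.push(node);
    return node;
  };
  const tapeOps: Ops<TapeNode> = {
    add: (a, b) => record(a.val + b.val, [a, b], [1, 1]),
    mul: (a, b) => record(a.val * b.val, [a, b], [b.val, a.val]),
    sin: (a) => record(Math.sin(a.val), [a], [Math.cos(a.val)]),
  };
  const inputNode = record(x);
  const out = fn(tapeOps, inputNode);
  const pullback = (ct: number): number => {
    for (const node of tape) node.grad = 0;
    out.grad = ct;
    // The tape is in creation (topological) order, so walk it backwards.
    for (let i = tape.length - 1; i >= 0; i--) {
      const node = tape[i];
      node.parents.forEach((p, k) => {
        p.grad += node.partials[k] * node.grad;
      });
    }
    return inputNode.grad;
  };
  return [out.val, pullback];
}

// Transformations return ordinary functions, so they compose.
const grad = (fn: Fn) => (x: number): number => vjp(fn, x)[1](1);
const vmap = (g: (x: number) => number) => (xs: number[]): number[] => xs.map(g);

console.log(jvp(f, 1, 1)[1]); // forward mode: f'(1) = 2 + cos(1)
console.log(grad(f)(1)); // reverse mode: the same value
console.log(vmap(grad(f))([0, 1, 2])); // per-example gradients
```

The ordering matters for composition: grad expects a scalar-output function, so it goes inside vmap (vmap(grad(f))), which turns a per-example gradient into a batched one.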
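```typescript
import * as jax from "jax-js";

// Define a loss function: mean squared error of a linear model
function loss(params: jax.Array, x: jax.Array, y: jax.Array) {
  const pred = x.dot(params); // linear prediction
  const diff = pred.sub(y);
  return diff.ref.mul(diff).mean(); // MSE loss; diff is used twice, so take a .ref
}

// Get the gradient function (reverse-mode autodiff w.r.t. the first argument, params)
const gradLoss = jax.grad(loss);

// JIT-compile the update step for speed
const step = jax.jit((params, x, y, lr) => {
  const g = gradLoss(params.ref, x, y); // .ref: params is used again below
  return params.sub(g.mul(lr)); // SGD update
});

// Placeholder training data (random values; 100 samples, 10 features)
const X_train = jax.randn([100, 10]);
const y_train = jax.randn([100, 1]);

// Training loop
let params = jax.randn([10, 1]);
for (let i = 0; i < 1000; i++) {
  params = step(params, X_train.ref, y_train.ref, 0.01); // .ref keeps the data alive across iterations
}
```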
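For this linear model, the gradient that jax.grad(loss) computes has a closed form: with n rows, the gradient of mean((X w - y)^2) with respect to w is (2/n) * X^T (X w - y). The sketch below is plain TypeScript with no jax-js. It uses nested number arrays, a hypothetical two-column synthetic dataset, and helper names (`predict`, `mseLoss`, `mseGrad`, `sgdStep`) chosen for illustration. It implements that formula and the same SGD update, which can be handy for cross-checking autodiff output on small inputs.

```typescript
// Plain-TypeScript reference for the model above (no jax-js): MSE loss, its
// hand-derived gradient, and plain SGD. The data below is synthetic.
type Matrix = number[][];

// pred[i] = sum_j X[i][j] * w[j]  (the X . params product)
function predict(X: Matrix, w: number[]): number[] {
  return X.map((row) => row.reduce((s, xij, j) => s + xij * w[j], 0));
}

// loss = mean((Xw - y)^2)
function mseLoss(X: Matrix, w: number[], y: number[]): number {
  const pred = predict(X, w);
  return pred.reduce((s, p, i) => s + (p - y[i]) ** 2, 0) / y.length;
}

// d loss / d w_j = (2/n) * sum_i X[i][j] * (pred[i] - y[i]), i.e. (2/n) X^T (Xw - y)
function mseGrad(X: Matrix, w: number[], y: number[]): number[] {
  const n = y.length;
  const residual = predict(X, w).map((p, i) => p - y[i]);
  return w.map((_, j) => (2 / n) * X.reduce((s, row, i) => s + row[j] * residual[i], 0));
}

// One SGD update: w <- w - lr * grad
function sgdStep(X: Matrix, w: number[], y: number[], lr: number): number[] {
  const g = mseGrad(X, w, y);
  return w.map((wj, j) => wj - lr * g[j]);
}

// Synthetic data: 20 samples, a bias column plus one feature, y generated from wTrue.
const X: Matrix = Array.from({ length: 20 }, (_, i) => [1, i / 10]);
const wTrue = [2, -3];
const y = predict(X, wTrue);

let w = [0, 0];
const initialLoss = mseLoss(X, w, y);
for (let i = 0; i < 1000; i++) {
  w = sgdStep(X, w, y, 0.1);
}
const finalLoss = mseLoss(X, w, y);
console.log(initialLoss, finalLoss, w); // loss shrinks toward 0; w approaches [2, -3]
```

Because y is generated exactly from wTrue, the loss can be driven to (numerically) zero; with real, noisy data the loss would instead level off at the least-squares residual.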
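The `.ref` on `params` in the update step points to a reference-counting scheme. Assuming it behaves like move semantics (each operation consumes one reference to each input array, and an array is freed when its count reaches zero), the toy class below shows why a value that is used twice needs an extra reference first. `RcArray` and its members are hypothetical illustrations of that assumed rule, not jax-js internals.

```typescript
// Toy reference-counted array (hypothetical; models only the assumed ownership
// rule: every operation consumes one reference to each input).
class RcArray {
  private refCount = 1;
  private data: Float64Array | null;

  constructor(values: number[]) {
    this.data = Float64Array.from(values);
  }

  // Add a reference and return the same array, allowing one more consumer.
  get ref(): RcArray {
    this.check();
    this.refCount++;
    return this;
  }

  // True until the last reference has been consumed.
  get alive(): boolean {
    return this.data !== null;
  }

  private check(): Float64Array {
    if (this.data === null) throw new Error("use of a freed array");
    return this.data;
  }

  // Consume one reference; free the buffer when none remain.
  private take(): Float64Array {
    const d = this.check();
    this.refCount--;
    if (this.refCount === 0) this.data = null;
    return d;
  }

  // Elementwise product; consumes one reference to `this` and one to `other`.
  mul(other: RcArray): RcArray {
    const a = this.take();
    const b = other.take();
    return new RcArray(Array.from(a, (v, i) => v * b[i]));
  }

  // Read out the values (also consumes a reference).
  toArray(): number[] {
    return Array.from(this.take());
  }
}

const diff = new RcArray([1, -2, 3]);
const squared = diff.ref.mul(diff).toArray(); // two references, two uses: OK
console.log(squared, diff.alive); // [1, 4, 9] false

const once = new RcArray([1, 2]);
let failed = false;
try {
  once.mul(once); // one reference, two uses: the second take() hits a freed array
} catch (err) {
  failed = err instanceof Error;
}
console.log(failed); // true
```

The design trade-off is explicit bookkeeping in exchange for deterministic freeing of buffers (important for GPU memory), rather than waiting for the JavaScript garbage collector.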