2025-09-30
I want to get started with JAX fast and am starting with a simple MLP with a continuous output. Rules:
I am slightly familiar with PyTorch, so obviously I started with a class MLP. As it turns out, this is not the most idiomatic way to write JAX code; more on that later. But here’s the first version, including all its bugs.
import jax
import jax.numpy as jnp

# Synthetic data gen
N = 10000
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
X = jax.random.normal(subkey, (N, 5))

def real_func(xs):
    return jnp.sum(xs, axis=1, keepdims=True)

ys = real_func(X)

print(f'Generated synthetic data: X: {X.shape}, ys: {ys.shape}')

# Model meta defn.
class MLP:
    def __init__(self, key, input_size, hidden_sizes, output_size):
        self.params = []
        sizes = [input_size] + hidden_sizes
        self.nparams = []
        for i, (in_size, out_size) in enumerate(zip(sizes[:-1], sizes[1:])):
            key, subkey = jax.random.split(key)
            W = jax.random.normal(subkey, shape=(out_size, in_size)) * jax.numpy.sqrt(2. / in_size)
            b = jax.random.normal(subkey, shape=(out_size,)) * jax.numpy.sqrt(2. / in_size)
            g = jax.random.normal(subkey, shape=(out_size,)) * jax.numpy.sqrt(2. / in_size)  # Adaptive gain for LayerNorm
            self.nparams.append(3)
            self.params.extend((W, b, g))
        W_out = jax.random.normal(subkey, shape=(output_size, hidden_sizes[-1])) * jax.numpy.sqrt(2. / hidden_sizes[-1])
        b_out = jax.random.normal(subkey, shape=(output_size,))
        self.params.extend((W_out, b_out))
        self.nparams.append(2)
        self._loss_and_grad = jax.jit(jax.value_and_grad(self._loss))

    def _forward(self, params, xs):
        for i in range(0, len(params)-2, 3):
            W, b, g = params[i:i+3]
            # LayerNorm
            a = xs @ W.T
            mu = a.mean(axis=1, keepdims=True)
            sigma = a.std(axis=1, keepdims=True)
            # print(f'mu:{mu.item()}:{mu.shape}, sigma:{sigma.item()}:{sigma.shape}')
            a = g * (a - mu) / sigma
            # print(f'a:{a.shape}')
            # ReLU
            xs = jax.nn.relu(a + b)
        W, b = params[-2:]
        xs = xs @ W.T + b
        return xs

    def forward(self, xs):
        return self._forward(self.params, xs)

    def _loss(self, params, xs, ys):
        preds = self._forward(params, xs)
        assert preds.shape == ys.shape, f'prediction shape {preds.shape} != truth shape {ys.shape}'
        return jax.numpy.mean((ys - preds) ** 2)

    def loss(self, xs, ys):
        return self._loss(self.params, xs, ys)

    def loss_and_grad(self, xs, ys):
        return self._loss_and_grad(self.params, xs, ys)

# Train/val split
ntrain = int(.8 * N)
xs_train, ys_train = X[:ntrain], ys[:ntrain]
xs_val, ys_val = X[ntrain:], ys[ntrain:]
print(f'Training set: {xs_train.shape}, {ys_train.shape}')
print(f'Validation set: {xs_val.shape}, {ys_val.shape}')

# Model instantiation
key, subkey = jax.random.split(key)
mlp = MLP(subkey, 5, [8, 6, 6], 1)
# print(mlp.loss_and_grad(xs_train[:1], ys_train[:1]))
# print('Weights\n', mlp.params)
# print('Weight shapes\n', [p.shape for p in mlp.params])

# Training
lr = 0.01
niter = 1000
nbatch = 256
idx = jax.random.permutation(subkey, ntrain)
start = 0
for i in range(niter):
    key, subkey = jax.random.split(key)
    idx = jax.random.permutation(subkey, ntrain)[:nbatch]
    xb, yb = xs_train[idx[start:start+nbatch]], ys_train[idx[start:start+nbatch]]
    start += nbatch
    if start >= ntrain:
        start = 0
        key, subkey = jax.random.split(key)
        idx = jax.random.permutation(subkey, ntrain)

    l, dl = mlp.loss_and_grad(xb, yb)
    if i % 100 == 0:
        train_loss = mlp.loss(xs_train, ys_train)
        print(f'Iteration {i}, batch loss = {l.item()}, train loss = {train_loss.item()}')
    for k in range(len(mlp.params)):
        mlp.params[k] -= lr * dl[k]

print(f'Final train loss = {mlp.loss(xs_train, ys_train).item()}')
print(f'Final val loss = {mlp.loss(xs_val, ys_val).item()}')
Generated synthetic data: X: (10000, 5), ys: (10000, 1)
Training set: (8000, 5), (8000, 1)
Validation set: (2000, 5), (2000, 1)
Iteration 0, batch loss = 13.008309364318848, train loss = 13.735170364379883
Iteration 100, batch loss = nan, train loss = 8.126752853393555
Iteration 200, batch loss = nan, train loss = 7.5402913093566895
Iteration 300, batch loss = nan, train loss = 7.13360595703125
Iteration 400, batch loss = nan, train loss = 6.8245697021484375
Iteration 500, batch loss = nan, train loss = 6.5315937995910645
Iteration 600, batch loss = nan, train loss = 6.269874572753906
Iteration 700, batch loss = nan, train loss = 6.0824151039123535
Iteration 800, batch loss = 6.443068504333496, train loss = 5.874807834625244
Iteration 900, batch loss = nan, train loss = 5.619460105895996
Final train loss = 5.504083156585693
Final val loss = 5.637689113616943
Both the final training and validation losses are high for a simple function and a frankly large network. I stared the code down but, I’ll be honest, I couldn’t figure out what was going on. So I moved on to step two: ask for a critique.
A summary of fixes:
1. Use separate RNG keys for weight initialization: The following reuses the random state and can cause the parameters to be correlated, which is no good.
key, subkey = jax.random.split(key)
W = jax.random.normal(subkey, shape=(out_size, in_size)) * jax.numpy.sqrt(2. / in_size)
b = jax.random.normal(subkey, shape=(out_size,)) * jax.numpy.sqrt(2. / in_size)
g = jax.random.normal(subkey, shape=(out_size,)) * jax.numpy.sqrt(2. / in_size) # Adaptive gain for LayerNorm
When I made this change alone, there wasn’t an appreciable decrease in training or val losses. But it was a good fix regardless (the corrected initialization, together with the LayerNorm changes from the next item, is sketched after this list).
2. LayerNorm implementation: One fix was obvious: use a small epsilon (say \(10^{-5}\)) in the denominator to guard against division by zero, in case the std is zero. The second fix was more interesting: apparently for LayerNorm, the beta and gamma parameters should not be He-initialized, but are to be initialized to zeros and ones respectively. These fixes moved my val loss from 5.64 to 2.86!! In hindsight, the gain parameter g starting out as all 1s makes sense – why decimate the activations a priori by using g values drawn from a normal?
3. Batching bug: Uh, embarrassing. See how I re-compute a 256-long idx array in every iteration of the training loop, but still honest-to-god advance the start index? Well, since len(idx) = 256, the second and subsequent iterations will attempt the indexing idx[256:512], idx[512:768], … which will all be empty slices… until start is reset, and then we get one nonempty batch again. Yuck. The fix was simple of course (see the sketch after this list), and got us from 5.64 in val loss to 0.82!!
4. JIT’ing the core: I am JIT’ing only the loss/grad computation function, but the parameter updates still happen in Python in the training loop. The idiom here is to JIT the top-level unit that gets invoked over and over. In our case, we’d define a train_step function and move the param updates in there. Making this change made the code much, much faster. I did not %timeit it (maybe I should – there’s a sketch of how after this list), but the difference is apparent.
5. Use optax: This library provides implementations of optimizers like AdamW, which do more sophisticated things like momentum and gradient clipping for numerical stability. It also becomes easier to support learning rate schedules later. Using optax moved the val loss from 0.82 to 0.57. Nice, free win.

But beware, naively updating parameters in the step function won’t work. Since JAX arrays are immutable, we have to return the updated params and continue in our training loop with these updated params (the basic optax update cycle is sketched after this list). This makes the code awkward, but I persist in using an MLP class for now, and punt the translation to a more functional style to the next iteration.
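To make fixes 1 and 2 concrete before the full listing, here is roughly what the per-layer setup ends up looking like (an excerpt of the second version below, not standalone code): a fresh subkey for the weight matrix, zeros for the LayerNorm beta, ones for the gain, and an epsilon in the normalization. Once b and g are constants, only W actually needs its own key per layer.

key, wkey = jax.random.split(key, 2)
W = jax.random.normal(wkey, shape=(out_size, in_size)) * jnp.sqrt(2. / in_size)  # He init for the weights
b = jnp.zeros(shape=(out_size,))   # LayerNorm beta starts at 0
g = jnp.ones(shape=(out_size,))    # LayerNorm gain starts at 1

# ... and in the forward pass:
eps = 1e-5
a = g * (a - mu) / (eps + sigma)   # epsilon guards against a zero std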
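Similarly, the corrected batching from fix 3: permute the full index array once, walk through it with start, and reshuffle only when a whole pass has been consumed. This is essentially the loop in the listing below.

idx = jax.random.permutation(subkey, ntrain)   # one full-length permutation, computed once
start = 0
for i in range(niter):
    xb, yb = xs_train[idx[start:start+nbatch]], ys_train[idx[start:start+nbatch]]
    start += nbatch
    if start >= ntrain:                        # finished a pass: reset and reshuffle
        start = 0
        key, subkey = jax.random.split(key)
        idx = jax.random.permutation(subkey, ntrain)
    ...  # loss/grad and parameter update go here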
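As for fix 4, if I wanted numbers instead of eyeballing, the measurement would look something like this in the notebook (a sketch only; I did not record timings for this post). The jax.block_until_ready call matters because JAX dispatches work asynchronously, so without it you mostly time the dispatch rather than the computation, and the first call should be excluded since it pays the compilation cost.

# Sketch: time the fused, jitted step (names as in the second listing).
_ = train_step(mlp.params, opt_state, xb, yb)   # warm-up call so compilation is not measured
%timeit jax.block_until_ready(train_step(mlp.params, opt_state, xb, yb))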
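And the optax piece of fix 5 is genuinely small; the whole update cycle I use is the three calls below (grads stands for whatever value_and_grad returned). apply_updates hands back new params rather than mutating anything, which is exactly why the params have to be threaded back through the training loop. The complete second version, with all of these fixes applied, follows.

opt = optax.adamw(learning_rate=0.01, weight_decay=0.001)
opt_state = opt.init(params)                      # optimizer state (moments etc.) lives outside the params

# inside each training step:
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)     # returns new params; the old ones are untouched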
import jax, optax
import jax.numpy as jnp

# Synthetic data gen
N = 10000
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
X = jax.random.normal(subkey, (N, 5))

def real_func(xs):
    return jnp.sum(xs, axis=1, keepdims=True)

ys = real_func(X)

print(f'Generated synthetic data: X: {X.shape}, ys: {ys.shape}')

# Model meta defn.
class MLP:
    def __init__(self, key, input_size, hidden_sizes, output_size):
        self.params = []
        sizes = [input_size] + hidden_sizes
        self.nparams = []
        for i, (in_size, out_size) in enumerate(zip(sizes[:-1], sizes[1:])):
            key, wkey = jax.random.split(key, 2)
            W = jax.random.normal(wkey, shape=(out_size, in_size)) * jax.numpy.sqrt(2. / in_size)
            b = jnp.zeros(shape=(out_size,))
            g = jnp.ones(shape=(out_size,))  # Adaptive gain for LayerNorm
            self.nparams.append(3)
            self.params.extend((W, b, g))
        key, wkey, bkey = jax.random.split(key, 3)
        W_out = jax.random.normal(wkey, shape=(output_size, hidden_sizes[-1])) * jax.numpy.sqrt(2. / hidden_sizes[-1])
        b_out = jax.random.normal(bkey, shape=(output_size,))
        self.params.extend((W_out, b_out))
        self.nparams.append(2)
        self._loss_and_grad = jax.value_and_grad(self._loss)

    def _forward(self, params, xs):
        eps = 1e-5
        for i in range(0, len(params)-2, 3):
            W, b, g = params[i:i+3]
            # LayerNorm
            a = xs @ W.T
            mu = a.mean(axis=1, keepdims=True)
            sigma = a.std(axis=1, keepdims=True)
            # print(f'mu:{mu.item()}:{mu.shape}, sigma:{sigma.item()}:{sigma.shape}')
            a = g * (a - mu) / (eps + sigma)
            # print(f'a:{a.shape}')
            # ReLU
            xs = jax.nn.relu(a + b)
        W, b = params[-2:]
        xs = xs @ W.T + b
        return xs

    def forward(self, xs):
        return self._forward(self.params, xs)

    def _loss(self, params, xs, ys):
        preds = self._forward(params, xs)
        assert preds.shape == ys.shape, f'prediction shape {preds.shape} != truth shape {ys.shape}'
        return jax.numpy.mean((ys - preds) ** 2)

    def loss(self, xs, ys):
        return self._loss(self.params, xs, ys)

    def loss_and_grad(self, xs, ys):
        return self._loss_and_grad(self.params, xs, ys)

# Train/val split
ntrain = int(.8 * N)
xs_train, ys_train = X[:ntrain], ys[:ntrain]
xs_val, ys_val = X[ntrain:], ys[ntrain:]
print(f'Training set: {xs_train.shape}, {ys_train.shape}')
print(f'Validation set: {xs_val.shape}, {ys_val.shape}')

# Model instantiation
key, subkey = jax.random.split(key)
mlp = MLP(subkey, 5, [8, 6, 6], 1)
# print(mlp.loss_and_grad(xs_train[:1], ys_train[:1]))
# print('Weights\n', mlp.params)
# print('Weight shapes\n', [p.shape for p in mlp.params])

# Training
opt = optax.adamw(learning_rate=0.01, weight_decay=0.001)

@jax.jit
def train_step(params, opt_state, xs, ys):
    l, dl = mlp._loss_and_grad(params, xs, ys)
    updates, opt_state = opt.update(dl, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

key, subkey = jax.random.split(key)
idx = jax.random.permutation(subkey, ntrain)
niter = 1000
nbatch = 256
start = 0
opt_state = opt.init(mlp.params)

for i in range(niter):
    key, subkey = jax.random.split(key)
    xb, yb = xs_train[idx[start:start+nbatch]], ys_train[idx[start:start+nbatch]]
    start += nbatch
    if start >= ntrain:
        start = 0
        key, subkey = jax.random.split(key)
        idx = jax.random.permutation(subkey, ntrain)
    mlp.params, opt_state = train_step(mlp.params, opt_state, xb, yb)
    if i % 100 == 0:
        train_loss = mlp.loss(xs_train, ys_train)
        print(f'Iteration {i}, batch loss = {l.item()}, train loss = {train_loss.item()}')

print(f'Final train loss = {mlp.loss(xs_train, ys_train).item()}')
print(f'Final val loss = {mlp.loss(xs_val, ys_val).item()}')
Generated synthetic data: X: (10000, 5), ys: (10000, 1)
Training set: (8000, 5), (8000, 1)
Validation set: (2000, 5), (2000, 1)
Iteration 0, batch loss = nan, train loss = 4.737359523773193
Iteration 100, batch loss = nan, train loss = 0.7498972415924072
Iteration 200, batch loss = nan, train loss = 0.6198509335517883
Iteration 300, batch loss = nan, train loss = 0.5501429438591003
Iteration 400, batch loss = nan, train loss = 0.49914073944091797
Iteration 500, batch loss = nan, train loss = 0.5024610757827759
Iteration 600, batch loss = nan, train loss = 0.48184746503829956
Iteration 700, batch loss = nan, train loss = 0.47913187742233276
Iteration 800, batch loss = nan, train loss = 0.4928268492221832
Iteration 900, batch loss = nan, train loss = 0.4822283685207367
Final train loss = 0.5036829710006714
Final val loss = 0.5718706250190735
The above code is modeled after how we’d write a model in PyTorch. And it smells: we have to do this weird shuffling of params in and out of the model instance. So here is the above code minus the redundant parts, with pure functions and a top-level training loop that plumbs all required state explicitly.
import jax, optax
import jax.numpy as jnp

# Synthetic data gen
N = 10000
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
X = jax.random.normal(subkey, (N, 5))

def real_func(xs):
    return jnp.sum(xs, axis=1, keepdims=True)

ys = real_func(X)

print(f'Generated synthetic data: X: {X.shape}, ys: {ys.shape}')

# Model meta defn.
def init_mlp(key, input_size, hidden_sizes, output_size):
    params = []
    sizes = [input_size] + hidden_sizes
    for i, (in_size, out_size) in enumerate(zip(sizes[:-1], sizes[1:])):
        key, wkey = jax.random.split(key, 2)
        params.append(dict(
            W=jax.random.normal(wkey, shape=(out_size, in_size)) * jax.numpy.sqrt(2. / in_size),
            b=jnp.zeros(shape=(out_size,)),
            g=jnp.ones(shape=(out_size,))))
    key, wkey, bkey = jax.random.split(key, 3)
    W_out = jax.random.normal(wkey, shape=(output_size, hidden_sizes[-1])) * jax.numpy.sqrt(2. / hidden_sizes[-1])
    b_out = jax.random.normal(bkey, shape=(output_size,))
    return params, {'W': W_out, 'b': b_out}

def predict(params, xs):
    layers, head = params
    eps = 1e-5
    for l in layers:
        W = l['W']
        b = l['b']
        g = l['g']
        # LayerNorm
        a = xs @ W.T
        mu = a.mean(axis=1, keepdims=True)
        sigma = a.std(axis=1, keepdims=True)
        a = g * (a - mu) / (eps + sigma)
        # ReLU
        xs = jax.nn.relu(a + b)
    # Output: No norm/nonlinearity.
    W, b = head['W'], head['b']
    xs = xs @ W.T + b
    return xs

@jax.jit
def mse_loss(params, xs, ys):
    preds = predict(params, xs)
    return jnp.mean((ys - preds) ** 2)

def step(opt, params, opt_state, xb, yb):
    l, dl = jax.value_and_grad(mse_loss)(params, xb, yb)
    #print(dl)
    updates, opt_state = opt.update(dl, opt_state, params)
    params = optax.apply_updates(params, updates)
    return tuple(params), opt_state

step = jax.jit(step, static_argnames='opt')

# Train/val split
ntrain = int(.8 * N)
xs_train, ys_train = X[:ntrain], ys[:ntrain]
xs_val, ys_val = X[ntrain:], ys[ntrain:]
print(f'Training set: {xs_train.shape}, {ys_train.shape}')
print(f'Validation set: {xs_val.shape}, {ys_val.shape}')

# Model instantiation
key, subkey = jax.random.split(key)
params = init_mlp(subkey, 5, [8, 6, 6], 1)

# Training
opt = optax.adamw(learning_rate=0.01, weight_decay=0.001)

key, subkey = jax.random.split(key)
idx = jax.random.permutation(subkey, ntrain)
niter = 1000
nbatch = 256
start = 0
opt_state = opt.init(params)

for i in range(niter):
    key, subkey = jax.random.split(key)
    xb, yb = xs_train[idx[start:start+nbatch]], ys_train[idx[start:start+nbatch]]
    start += nbatch
    if start >= ntrain:
        start = 0
        key, subkey = jax.random.split(key)
        idx = jax.random.permutation(subkey, ntrain)
    params, opt_state = step(opt, params, opt_state, xb, yb)
    if i % 100 == 0:
        l = mse_loss(params, xb, yb)
        train_loss = mse_loss(params, xs_train, ys_train)
        print(f'Iteration {i}, batch loss = {l.item()}, train loss = {train_loss.item()}')

print(f'Final train loss = {mse_loss(params, xs_train, ys_train).item()}')
print(f'Final val loss = {mse_loss(params, xs_val, ys_val).item()}')
Generated synthetic data: X: (10000, 5), ys: (10000, 1)
Training set: (8000, 5), (8000, 1)
Validation set: (2000, 5), (2000, 1)
Iteration 0, batch loss = 5.191179275512695, train loss = 4.737359523773193
Iteration 100, batch loss = 0.6641169786453247, train loss = 0.7498972415924072
Iteration 200, batch loss = 0.5052511096000671, train loss = 0.6198509335517883
Iteration 300, batch loss = 0.4988126754760742, train loss = 0.5501430034637451
Iteration 400, batch loss = 0.5596280694007874, train loss = 0.49914073944091797
Iteration 500, batch loss = 0.4847404360771179, train loss = 0.5024611353874207
Iteration 600, batch loss = 0.6246224045753479, train loss = 0.48184746503829956
Iteration 700, batch loss = 0.45626115798950195, train loss = 0.47913187742233276
Iteration 800, batch loss = 0.5729255676269531, train loss = 0.4928268492221832
Iteration 900, batch loss = 0.5909319519996643, train loss = 0.4822283983230591
Final train loss = 0.5036829710006714
Final val loss = 0.5718706250190735