Learning JAX, good ol’ MLP

2025-09-30

I want to get started with JAX fast and am starting with a simple MLP with a continuous output. Rules:

  1. Hand-write all code: Autocomplete OK, AI-based mindreading to be used sparingly only for boilerplate I am familiar with.
  2. Ask ChatGPT/Claude to critique the code
  3. Incorporate feedback

I am slightly familiar with PyTorch, so obviously I started with a class MLP. As it turns out, this is not the most idiomatic way to write JAX code, more on that later. But here’s the first version, including all its bugs.

import jax
import jax.numpy as jnp

# Synthetic data gen: N rows of 5 standard-normal features; the target is the
# plain sum of the features, i.e. a linear function of the inputs.

N = 10000
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
X = jax.random.normal(subkey, (N, 5))

def real_func(xs):
  """Ground-truth target: sum of `xs` over axis 1, keeping that axis (a column for 2-D xs)."""
  return jnp.sum(xs, axis=1, keepdims=True)

ys = real_func(X)  # shape (N, 1)

print(f'Generated synthetic data: X: {X.shape}, ys: {ys.shape}')

# Model meta defn.

class MLP:
  """First-draft MLP: [Linear -> LayerNorm -> ReLU] per hidden layer, then a linear head.

  This is the buggy first version discussed in the text; the problems are
  flagged inline with `BUG:` comments. Parameters live in a flat list,
  `self.params`: W, b, g for each hidden layer, followed by W_out, b_out.
  """
  def __init__(self, key, input_size, hidden_sizes, output_size):
    """Draw all parameters from `key` and build a jitted value-and-grad of the loss.

    `hidden_sizes` must be a non-empty list: it is concatenated onto a list,
    and its last entry sizes the output layer.
    """
    self.params = []
    sizes = [input_size] + hidden_sizes
    self.nparams = []  # parameter arrays per layer: 3 per hidden layer, 2 for the head
    for i, (in_size, out_size) in enumerate(zip(sizes[:-1], sizes[1:])):
      key, subkey = jax.random.split(key)
      # BUG: W, b and g are all drawn from the same `subkey`, so they are
      # correlated; b and g (same shape, same key) come out exactly identical.
      # BUG: the LayerNorm shift b and gain g should start at zeros and ones,
      # not He-style normal draws.
      W = jax.random.normal(subkey, shape=(out_size, in_size)) * jax.numpy.sqrt(2. / in_size)
      b = jax.random.normal(subkey, shape=(out_size,)) * jax.numpy.sqrt(2. / in_size)
      g = jax.random.normal(subkey, shape=(out_size,)) * jax.numpy.sqrt(2. / in_size) # Adaptive gain for LayerNorm
      self.nparams.append(3)
      self.params.extend((W, b, g))

    # BUG: the head reuses the last `subkey` from the loop (for both W_out and
    # b_out) instead of splitting fresh keys.
    W_out = jax.random.normal(subkey, shape=(output_size, hidden_sizes[-1])) * jax.numpy.sqrt(2. / hidden_sizes[-1])
    b_out = jax.random.normal(subkey, shape=(output_size,))
    self.params.extend((W_out, b_out))
    self.nparams.append(2)

    # Only loss+grad is jitted; the SGD update happens outside, in Python.
    self._loss_and_grad = jax.jit(jax.value_and_grad(self._loss))


  def _forward(self, params, xs):
    """Forward pass with an explicit flat parameter list (so it can be differentiated)."""
    for i in range(0, len(params)-2, 3):
      W, b, g = params[i:i+3]
      # LayerNorm
      a = xs @ W.T
      mu = a.mean(axis=1, keepdims=True)
      sigma = a.std(axis=1, keepdims=True)
      # print(f'mu:{mu.item()}:{mu.shape}, sigma:{sigma.item()}:{sigma.shape}')
      # BUG: no epsilon in the denominator; a row of `a` with zero spread divides by zero.
      a = g * (a - mu) / sigma
      # print(f'a:{a.shape}')
      # ReLU (b acts as the LayerNorm shift, added after the gain)
      xs = jax.nn.relu(a + b)
    W, b = params[-2:]
    xs = xs @ W.T + b
    return xs

  def forward(self, xs):
    """Run the network on `xs` using the stored parameters."""
    return self._forward(self.params, xs)

  def _loss(self, params, xs, ys):
    """Mean squared error of the predictions for `xs` against `ys`.

    The assert guards against silent broadcasting when the shapes differ.
    """
    preds = self._forward(params, xs)
    assert preds.shape == ys.shape, f'prediction shape {preds.shape} != truth shape {ys.shape}'
    return jax.numpy.mean((ys - preds) ** 2)

  def loss(self, xs, ys):
    """MSE loss using the stored parameters (not jitted)."""
    return self._loss(self.params, xs, ys)

  def loss_and_grad(self, xs, ys):
    """Return (loss, grads) for the stored parameters; grads is a list parallel to self.params."""
    return self._loss_and_grad(self.params, xs, ys)

# Train/val split: first 80% of rows for training, the rest for validation
# (rows of X are i.i.d. draws, so no shuffle is needed first).

ntrain = int(.8 * N)
xs_train, ys_train = X[:ntrain], ys[:ntrain]
xs_val, ys_val = X[ntrain:], ys[ntrain:]
print(f'Training set: {xs_train.shape}, {ys_train.shape}')
print(f'Validation set: {xs_val.shape}, {ys_val.shape}')

# Model instantiation

key, subkey = jax.random.split(key)
mlp = MLP(subkey, 5, [8, 6, 6], 1)
# print(mlp.loss_and_grad(xs_train[:1], ys_train[:1]))
# print('Weights\n', mlp.params)
# print('Weight shapes\n', [p.shape for p in mlp.params])

# Training: plain SGD with a hand-rolled minibatch loop.

lr = 0.01
niter = 1000
nbatch = 256
# NOTE: reuses the `subkey` just passed to MLP(); this permutation is also
# overwritten on the first loop iteration anyway.
idx = jax.random.permutation(subkey, ntrain)
start = 0
for i in range(niter):
  key, subkey = jax.random.split(key)
  # BUG: a fresh 256-long idx is drawn every iteration while `start` keeps
  # advancing, so idx[start:start+nbatch] is an empty slice except when
  # start == 0 (once every 32 iterations with ntrain=8000). The mean loss
  # over an empty batch is nan, so those iterations print a nan batch loss.
  idx = jax.random.permutation(subkey, ntrain)[:nbatch]
  xb, yb = xs_train[idx[start:start+nbatch]], ys_train[idx[start:start+nbatch]]
  start += nbatch
  if start >= ntrain:
    start = 0
    key, subkey = jax.random.split(key)
    idx = jax.random.permutation(subkey, ntrain)  # overwritten at the top of the next iteration

  l, dl = mlp.loss_and_grad(xb, yb)
  if i % 100 == 0:
    train_loss = mlp.loss(xs_train, ys_train)
    print(f'Iteration {i}, batch loss = {l.item()}, train loss = {train_loss.item()}')
  # SGD update, applied eagerly in Python outside the jitted loss/grad.
  for k in range(len(mlp.params)):
    mlp.params[k] -= lr * dl[k]

print(f'Final train loss = {mlp.loss(xs_train, ys_train).item()}')
print(f'Final val loss = {mlp.loss(xs_val, ys_val).item()}')
Generated synthetic data: X: (10000, 5), ys: (10000, 1)
Training set: (8000, 5), (8000, 1)
Validation set: (2000, 5), (2000, 1)
Iteration 0, batch loss = 13.008309364318848, train loss = 13.735170364379883
Iteration 100, batch loss = nan, train loss = 8.126752853393555
Iteration 200, batch loss = nan, train loss = 7.5402913093566895
Iteration 300, batch loss = nan, train loss = 7.13360595703125
Iteration 400, batch loss = nan, train loss = 6.8245697021484375
Iteration 500, batch loss = nan, train loss = 6.5315937995910645
Iteration 600, batch loss = nan, train loss = 6.269874572753906
Iteration 700, batch loss = nan, train loss = 6.0824151039123535
Iteration 800, batch loss = 6.443068504333496, train loss = 5.874807834625244
Iteration 900, batch loss = nan, train loss = 5.619460105895996
Final train loss = 5.504083156585693
Final val loss = 5.637689113616943

Both the final training and validation losses are high for a simple function and a frankly large network. I stared down the code, but, I’ll be honest, I couldn’t figure out what was going on. So I moved on to step two: ask for a critique.

A summary of fixes:

1. Use separate RNG keys for weight initialization: The snippet below reuses the same `subkey` for `W`, `b` and `g`, which correlates the parameters. In fact, `b` and `g` have the same shape, so they come out exactly identical. That’s no good.

      key, subkey = jax.random.split(key)
      W = jax.random.normal(subkey, shape=(out_size, in_size)) * jax.numpy.sqrt(2. / in_size)
      b = jax.random.normal(subkey, shape=(out_size,)) * jax.numpy.sqrt(2. / in_size)
      g = jax.random.normal(subkey, shape=(out_size,)) * jax.numpy.sqrt(2. / in_size) # Adaptive gain for LayerNorm

When I made these changes alone, there wasn’t an appreciable decrease in training or val losses. But this was a good fix.

2. LayerNorm implementation: One was an obvious fix: use a small epsilon (say \(10^{-5}\)) in the denominator to guard against division by zero in case the std is zero. The usual form puts it inside the square root, \((a-\mu)/\sqrt{\sigma^2+\epsilon}\), which also keeps the gradient finite when the variance is zero. The second fix was more interesting: apparently for LayerNorm, the beta and gamma parameters (my `b` and `g`) should not be He initialized, but are to be initialized to zeros and ones respectively. These fixes moved my val loss from 5.64 to 2.86!! In hindsight, having the gain parameters g start out as all 1s makes sense – why decimate the activations a priori by scaling them with g values drawn from a normal?

3. Batching bug: Uh, embarrassing. See how I re-compute a 256-long idx array in every iteration of the training loop, but still honest-to-god advance the start index? Well, since len(idx)=256, the second and subsequent iterations will attempt the indexing idx[256:512], idx[512:768]… which will all be empty slices… until start is reset, and then we get one nonempty batch again. With 8000 training rows, that is one real batch every 32 iterations. The empty batches are also where the `nan` batch losses in the log come from, since the mean over an empty batch is `nan`. Yuck. The fix was simple, of course, and got us from 5.64 in val loss to 0.82!!

4. JIT’ing the core: I am JIT’ing only the loss/grad computation function, but the parameter updates still happen in Python in the training loop. The idiom here is to JIT the top-level unit that gets invoked over and over. In our case, we’d define a train_step function and move the param updates in there. Making this change made the code much, much faster. I did not %timeit, maybe I should, but the difference is apparent.

5. Use optax: This library provides implementations of optimizers like AdamW, which do more sophisticated things than plain SGD: momentum, per-parameter adaptive step sizes and decoupled weight decay. Optax also provides composable extras such as gradient clipping (e.g. `optax.clip_by_global_norm`) for numerical stability. It also becomes easier to support learning rate schedules later. Using optax moved the val loss from 0.82 to 0.57. Nice, free win. But beware, naively updating parameters in the step function won’t work. JAX arrays are immutable, and a Python-side mutation such as assigning to `mlp.params` inside a jitted function happens only once, at trace time. So we have to return the updated params and continue in our training loop with these updated params. This makes the code awkward, but I persist in using an MLP class for now, and punt the translation to a more functional style to the next iteration.

import jax, optax
import jax.numpy as jnp

# Synthetic data gen: 5 i.i.d. standard-normal features per row; the target
# is simply their sum, stored as a single column.

N = 10000
key, subkey = jax.random.split(jax.random.PRNGKey(42))
X = jax.random.normal(subkey, shape=(N, 5))

def real_func(xs):
  """Return the ground-truth target: `xs` summed over axis 1, keeping that axis."""
  row_sums = jnp.sum(xs, keepdims=True, axis=1)
  return row_sums

ys = real_func(X)
print(f'Generated synthetic data: X: {X.shape}, ys: {ys.shape}')

# Model meta defn.

class MLP:
  """MLP with LayerNorm: [Linear -> LayerNorm -> ReLU] per hidden layer, then a linear head.

  Parameters are kept in a flat list, `self.params`: (W, b, g) for each
  hidden layer, followed by (W_out, b_out). The `_`-prefixed methods take
  `params` explicitly so they can be differentiated and traced by jit; the
  public wrappers plug in `self.params`.
  """
  def __init__(self, key, input_size, hidden_sizes, output_size):
    """Initialise parameters.

    Args:
      key: JAX PRNG key; split once per hidden layer and once for the head.
      input_size: number of input features.
      hidden_sizes: sequence (list or tuple) of hidden-layer widths. May be
        empty, in which case the model is a single linear layer.
      output_size: number of outputs.
    """
    self.params = []
    sizes = [input_size, *hidden_sizes]
    self.nparams = []  # parameter arrays per layer: 3 per hidden layer, 2 for the head
    for in_size, out_size in zip(sizes[:-1], sizes[1:]):
      key, wkey = jax.random.split(key, 2)
      # He-initialised weights; LayerNorm shift b starts at 0 and gain g at 1.
      W = jax.random.normal(wkey, shape=(out_size, in_size)) * jnp.sqrt(2. / in_size)
      b = jnp.zeros(shape=(out_size,))
      g = jnp.ones(shape=(out_size,))  # Adaptive gain for LayerNorm
      self.nparams.append(3)
      self.params.extend((W, b, g))
    # The head reads from the last hidden layer, or straight from the input if there is none.
    last_size = sizes[-1]
    key, wkey, bkey = jax.random.split(key, 3)
    W_out = jax.random.normal(wkey, shape=(output_size, last_size)) * jnp.sqrt(2. / last_size)
    b_out = jax.random.normal(bkey, shape=(output_size,))
    self.params.extend((W_out, b_out))
    self.nparams.append(2)

    # Left un-jitted so it can be traced inside a jitted training step.
    self._loss_and_grad = jax.value_and_grad(self._loss)

  def _forward(self, params, xs):
    """Run the network on `xs` (one sample per row) with an explicit parameter list."""
    eps = 1e-5
    for i in range(0, len(params) - 2, 3):
      W, b, g = params[i:i + 3]
      a = xs @ W.T
      # LayerNorm over the feature axis, with eps *inside* the square root.
      # `eps + std` keeps the forward pass finite but not the backward one:
      # d sqrt(v)/dv is infinite at v = 0, so a zero-variance row (a width-1
      # layer, or an all-zero ReLU input) would produce NaN gradients.
      mu = jnp.mean(a, axis=1, keepdims=True)
      var = jnp.var(a, axis=1, keepdims=True)
      a = g * (a - mu) / jnp.sqrt(var + eps)
      xs = jax.nn.relu(a + b)
    # Output head: no norm / nonlinearity.
    W, b = params[-2:]
    return xs @ W.T + b

  def forward(self, xs):
    """Run the network on `xs` using the stored parameters."""
    return self._forward(self.params, xs)

  def _loss(self, params, xs, ys):
    """Mean squared error of the predictions for `xs` against `ys`.

    The assert guards against silent broadcasting when the shapes differ.
    """
    preds = self._forward(params, xs)
    assert preds.shape == ys.shape, f'prediction shape {preds.shape} != truth shape {ys.shape}'
    return jnp.mean((ys - preds) ** 2)

  def loss(self, xs, ys):
    """MSE loss using the stored parameters."""
    return self._loss(self.params, xs, ys)

  def loss_and_grad(self, xs, ys):
    """Return (loss, grads) for the stored parameters; grads parallels self.params."""
    return self._loss_and_grad(self.params, xs, ys)

# Train/val split

ntrain = int(.8 * N)
xs_train, ys_train = X[:ntrain], ys[:ntrain]
xs_val, ys_val = X[ntrain:], ys[ntrain:]
print(f'Training set: {xs_train.shape}, {ys_train.shape}')
print(f'Validation set: {xs_val.shape}, {ys_val.shape}')

# Model instantiation

key, subkey = jax.random.split(key)
mlp = MLP(subkey, 5, [8, 6, 6], 1)

# Training

opt = optax.adamw(learning_rate=0.01, weight_decay=0.001)

@jax.jit
def train_step(params, opt_state, xs, ys):
  """One AdamW step on a minibatch; returns the updated (params, opt_state).

  The whole step (loss, gradient and optimizer update) is jitted, not just
  the gradient. `mlp._loss_and_grad` takes `params` explicitly, so closing
  over `mlp` does not bake `mlp.params` into the trace.
  """
  _, dl = mlp._loss_and_grad(params, xs, ys)
  updates, opt_state = opt.update(dl, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state


key, subkey = jax.random.split(key)
idx = jax.random.permutation(subkey, ntrain)
niter = 1000
nbatch = 256
start = 0
opt_state = opt.init(mlp.params)

for i in range(niter):
  # Next minibatch from the current shuffle; reshuffle once the epoch is used up.
  batch_idx = idx[start:start + nbatch]
  xb, yb = xs_train[batch_idx], ys_train[batch_idx]
  start += nbatch
  if start >= ntrain:
    start = 0
    key, subkey = jax.random.split(key)
    idx = jax.random.permutation(subkey, ntrain)
  mlp.params, opt_state = train_step(mlp.params, opt_state, xb, yb)
  if i % 100 == 0:
    # train_step does not return the loss, so evaluate the batch here
    # (after the update). `l` must be assigned in this loop rather than
    # left to whatever value an earlier run happened to leave behind.
    l = mlp.loss(xb, yb)
    train_loss = mlp.loss(xs_train, ys_train)
    print(f'Iteration {i}, batch loss = {l.item()}, train loss = {train_loss.item()}')
print(f'Final train loss = {mlp.loss(xs_train, ys_train).item()}')
print(f'Final val loss = {mlp.loss(xs_val, ys_val).item()}')
Generated synthetic data: X: (10000, 5), ys: (10000, 1)
Training set: (8000, 5), (8000, 1)
Validation set: (2000, 5), (2000, 1)
Iteration 0, batch loss = nan, train loss = 4.737359523773193
Iteration 100, batch loss = nan, train loss = 0.7498972415924072
Iteration 200, batch loss = nan, train loss = 0.6198509335517883
Iteration 300, batch loss = nan, train loss = 0.5501429438591003
Iteration 400, batch loss = nan, train loss = 0.49914073944091797
Iteration 500, batch loss = nan, train loss = 0.5024610757827759
Iteration 600, batch loss = nan, train loss = 0.48184746503829956
Iteration 700, batch loss = nan, train loss = 0.47913187742233276
Iteration 800, batch loss = nan, train loss = 0.4928268492221832
Iteration 900, batch loss = nan, train loss = 0.4822283685207367
Final train loss = 0.5036829710006714
Final val loss = 0.5718706250190735

More idiomatic, functional code

The above code is modeled after how we’d code a model in PyTorch. And it smells, as we have to do the weird param shuffling in and out of the model instance. So here is the above code minus redundant parts, with pure functions and a top-level training loop that plumbs all required state explicitly.

import jax, optax
import jax.numpy as jnp

# Synthetic data gen: 5 i.i.d. standard-normal features per row; the target
# is their sum, kept as a one-column matrix.

N = 10000
key, subkey = jax.random.split(jax.random.PRNGKey(42))
X = jax.random.normal(subkey, shape=(N, 5))

def real_func(xs):
  """Ground-truth function: each row of `xs` summed over axis 1, keeping that axis."""
  totals = jnp.sum(xs, keepdims=True, axis=1)
  return totals

ys = real_func(X)
print(f'Generated synthetic data: X: {X.shape}, ys: {ys.shape}')

# Model meta defn.

def init_mlp(key, input_size, hidden_sizes, output_size):
  """Initialise MLP parameters as a pytree `(layers, head)`.

  Args:
    key: JAX PRNG key; split once per hidden layer and once for the head.
    input_size: number of input features.
    hidden_sizes: sequence (list or tuple) of hidden-layer widths. May be
      empty, which yields a purely linear model.
    output_size: number of outputs.

  Returns:
    `(layers, head)`. `layers` is a list of dicts with 'W' of shape
    (out, in) (He-initialised), 'b' of shape (out,) (LayerNorm shift, zeros)
    and 'g' of shape (out,) (LayerNorm gain, ones). `head` is a dict with
    'W' of shape (output_size, last) and 'b' of shape (output_size,), where
    `last` is the final hidden width, or `input_size` if there is none.
  """
  sizes = [input_size, *hidden_sizes]
  layers = []
  for in_size, out_size in zip(sizes[:-1], sizes[1:]):
    key, wkey = jax.random.split(key, 2)
    layers.append(dict(
      W=jax.random.normal(wkey, shape=(out_size, in_size)) * jnp.sqrt(2. / in_size),
      b=jnp.zeros(shape=(out_size,)),
      g=jnp.ones(shape=(out_size,))))
  # The head reads from the last hidden layer, or directly from the input if there is none.
  last_size = sizes[-1]
  key, wkey, bkey = jax.random.split(key, 3)
  W_out = jax.random.normal(wkey, shape=(output_size, last_size)) * jnp.sqrt(2. / last_size)
  b_out = jax.random.normal(bkey, shape=(output_size,))
  return layers, {'W': W_out, 'b': b_out}

def predict(params, xs):
  """Forward pass of the MLP.

  Args:
    params: `(layers, head)` pytree; each layer dict holds 'W', 'b', 'g' and
      `head` holds 'W', 'b'.
    xs: 2-D input batch, one sample per row.

  Returns:
    The head's linear output applied to the last hidden activation (or to
    `xs` directly when `layers` is empty).
  """
  layers, head = params
  eps = 1e-5
  for layer in layers:
    W, b, g = layer['W'], layer['b'], layer['g']
    a = xs @ W.T
    # LayerNorm over the feature axis, with eps inside the square root.
    # `eps + std` would keep the forward pass finite but not the backward one:
    # d sqrt(v)/dv is infinite at v = 0, so a zero-variance row (a width-1
    # layer, or an all-zero ReLU input) would give NaN gradients.
    mu = jnp.mean(a, axis=1, keepdims=True)
    var = jnp.var(a, axis=1, keepdims=True)
    a = g * (a - mu) / jnp.sqrt(var + eps)
    # ReLU after the LayerNorm shift.
    xs = jax.nn.relu(a + b)
  # Output: No norm/nonlinearity.
  return xs @ head['W'].T + head['b']

@jax.jit
def mse_loss(params, xs, ys):
  """Mean squared error of `predict(params, xs)` against targets `ys`.

  Raises:
    ValueError: if the prediction and target shapes differ. Without this
      check, a (N,) target against (N, 1) predictions would silently
      broadcast to an (N, N) difference and give a meaningless loss. Shapes
      are static under jit, so the check runs at trace time.
  """
  preds = predict(params, xs)
  if preds.shape != ys.shape:
    raise ValueError(f'prediction shape {preds.shape} != truth shape {ys.shape}')
  return jnp.mean((ys - preds) ** 2)

def step(opt, params, opt_state, xb, yb):
  """Apply one optimizer update computed from the minibatch (xb, yb).

  Args:
    opt: optax gradient transformation; passed as a static jit argument.
    params: current model parameter pytree.
    opt_state: optimizer state matching `params`.
    xb, yb: minibatch inputs and targets.

  Returns:
    `(new_params, new_opt_state)`, with `new_params` converted to a tuple.
  """
  _, grads = jax.value_and_grad(mse_loss)(params, xb, yb)
  updates, new_state = opt.update(grads, opt_state, params)
  return tuple(optax.apply_updates(params, updates)), new_state

# `opt` bundles Python callables rather than arrays, so jit treats it as a
# static (hashable) argument and specialises on it instead of tracing it.
step = jax.jit(step, static_argnames='opt')

# Train/val split

ntrain = int(.8 * N)
xs_train, ys_train = X[:ntrain], ys[:ntrain]
xs_val, ys_val = X[ntrain:], ys[ntrain:]
print(f'Training set: {xs_train.shape}, {ys_train.shape}')
print(f'Validation set: {xs_val.shape}, {ys_val.shape}')

# Model instantiation

key, subkey = jax.random.split(key)
params = init_mlp(subkey, 5, [8, 6, 6], 1)

# Training

opt = optax.adamw(learning_rate=0.01, weight_decay=0.001)

key, subkey = jax.random.split(key)
idx = jax.random.permutation(subkey, ntrain)
niter = 1000
nbatch = 256
start = 0
opt_state = opt.init(params)

for i in range(niter):
  # Slice the next minibatch out of the current shuffle. With ntrain=8000 and
  # nbatch=256 the last batch of each epoch has only 64 rows, which costs one
  # extra jit compilation for that shape.
  batch_idx = idx[start:start + nbatch]
  xb, yb = xs_train[batch_idx], ys_train[batch_idx]
  start += nbatch
  if start >= ntrain:
    # Epoch exhausted: reshuffle with a freshly split key.
    start = 0
    key, subkey = jax.random.split(key)
    idx = jax.random.permutation(subkey, ntrain)
  params, opt_state = step(opt, params, opt_state, xb, yb)
  if i % 100 == 0:
    # Evaluated after the update, so this is the post-step loss on the batch.
    l = mse_loss(params, xb, yb)
    train_loss = mse_loss(params, xs_train, ys_train)
    print(f'Iteration {i}, batch loss = {l.item()}, train loss = {train_loss.item()}')
print(f'Final train loss = {mse_loss(params, xs_train, ys_train).item()}')
print(f'Final val loss = {mse_loss(params, xs_val, ys_val).item()}')

Generated synthetic data: X: (10000, 5), ys: (10000, 1)
Training set: (8000, 5), (8000, 1)
Validation set: (2000, 5), (2000, 1)
Iteration 0, batch loss = 5.191179275512695, train loss = 4.737359523773193
Iteration 100, batch loss = 0.6641169786453247, train loss = 0.7498972415924072
Iteration 200, batch loss = 0.5052511096000671, train loss = 0.6198509335517883
Iteration 300, batch loss = 0.4988126754760742, train loss = 0.5501430034637451
Iteration 400, batch loss = 0.5596280694007874, train loss = 0.49914073944091797
Iteration 500, batch loss = 0.4847404360771179, train loss = 0.5024611353874207
Iteration 600, batch loss = 0.6246224045753479, train loss = 0.48184746503829956
Iteration 700, batch loss = 0.45626115798950195, train loss = 0.47913187742233276
Iteration 800, batch loss = 0.5729255676269531, train loss = 0.4928268492221832
Iteration 900, batch loss = 0.5909319519996643, train loss = 0.4822283983230591
Final train loss = 0.5036829710006714
Final val loss = 0.5718706250190735