Regression example with a neural network

Published

July 19, 2023

Prepare data

import pickle
import gzip
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
# Load the histograms of reconstructed (rec) and generated (gen) counts.
# Their axes 0-4 are iterated below as eta, pT, phi, charge index and species
# index. NOTE: pickle can execute arbitrary code on load; only open trusted files.
with gzip.open("rec_hist_pt_25.pkl.gz") as f:
    rec = pickle.load(f)

with gzip.open("gen_hist_pt_25.pkl.gz") as f:
    gen = pickle.load(f)

# Avoid division by zero: empty generated bins are set to 1. This writes
# through the array returned by ``values()``.
gen.values()[gen.values() == 0] = 1

n_gen = gen.values()
# Clip reconstructed counts so that the efficiency k / n never exceeds 1.
raw_rec = rec.values()
n_rec = np.where(raw_rec > n_gen, n_gen, raw_rec)

# One entry per histogram cell, in C (row-major) order of the five axes.
centers = [rec.axes[i].centers for i in range(3)]
bins = [
    (ieta, ipt, iphi, ich, isp)
    for ieta in range(len(centers[0]))
    for ipt in range(len(centers[1]))
    for iphi in range(len(centers[2]))
    for ich in range(len(rec.axes[3]))
    for isp in range(len(rec.axes[4]))
]

# Features: eta, log(pT), phi encoded as (cos, sin) to respect periodicity,
# and the integer charge / species indices.
eta_c, pt_c, phi_c = centers
X0 = np.array(
    [
        (eta_c[a], np.log(pt_c[b]), np.cos(phi_c[c]), np.sin(phi_c[c]), d, e)
        for a, b, c, d, e in bins
    ]
).astype(np.float32)
data_k = np.array([n_rec[idx] for idx in bins]).astype(np.float32)
data_n = np.array([n_gen[idx] for idx in bins]).astype(np.float32)
# Regression target: efficiency = reconstructed / generated.
y = data_k / data_n

scaler = StandardScaler().fit(X0)
X = scaler.transform(X0)

print("data points", X.shape[0])
data points 14400
def draw(model):
    """Plot measured efficiencies and a model's prediction versus pT.

    One figure is drawn per non-negative phi bin center of ``rec``. Each figure
    holds a 3x2 grid of panels, one per eta bin center. In each panel, markers
    show the data ``y`` and lines show ``model`` evaluated on a pT grid from 10
    to 1e4, for every species label. Values are multiplied by the charge (-1 or
    +1), so the two charges appear below and above zero.

    Parameters
    ----------
    model : callable
        Takes a standardized feature matrix (the output of
        ``scaler.transform``) with the six feature columns used to build
        ``X0`` and returns one prediction per row.

    Reads the module-level ``rec``, ``X0``, ``y`` and ``scaler``.
    """
    labels = ("π", "K", "p", "other")
    # The pT grid of the model curves is the same in every panel.
    mpt = np.geomspace(10, 1e4, 2000)
    # X0 is stored as float32, so round the bin centers to that dtype before
    # the exact comparisons below. Under NumPy 2 promotion rules, comparing a
    # float32 column with a float64 scalar happens in float64. Values that were
    # rounded when X0 was built would then never match, and no data would be
    # plotted.
    as_x0 = X0.dtype.type
    for phii in rec.axes[2].centers[:]:
        if phii < 0:
            continue
        cos_phi = as_x0(np.cos(phii))
        sin_phi = as_x0(np.sin(phii))
        _, axes = plt.subplots(3, 2, sharex=True, sharey=True, layout="compressed")
        plt.suptitle(rf"$\phi = {np.degrees(phii):.0f}$ deg")
        for etai, axi in zip(rec.axes[0].centers, axes.flat):
            plt.sca(axi)
            # Rows of X0 that belong to this (eta, phi) bin.
            in_bin = (
                (X0[:, 0] == as_x0(etai))
                & (X0[:, 2] == cos_phi)
                & (X0[:, 3] == sin_phi)
            )
            for ich, charge in enumerate((-1, 1)):
                for isp, label in enumerate(labels):
                    ma = in_bin & (X0[:, 4] == ich) & (X0[:, 5] == isp)
                    # Label only one charge so each species appears once in the legend.
                    plt.plot(
                        np.exp(X0[ma, 1]),
                        charge * y[ma],
                        "o",
                        ms=4,
                        color=f"C{isp}",
                        label=label if ich else None,
                    )

                    # Features of the model curve for this panel/charge/species.
                    Xp = np.empty((len(mpt), 6), dtype=np.float32)
                    Xp[:, 0] = etai
                    Xp[:, 1] = np.log(mpt)
                    Xp[:, 2] = np.cos(phii)
                    Xp[:, 3] = np.sin(phii)
                    Xp[:, 4] = ich
                    Xp[:, 5] = isp
                    yp = model(scaler.transform(Xp))
                    plt.plot(np.exp(Xp[:, 1]), charge * yp, color=f"C{isp}")
        plt.sca(axes[0, 0])
        plt.semilogx()
        plt.ylim(-1.1, 1.1)
        plt.legend(fontsize="xx-small", ncol=2, frameon=False)

Scikit-Learn

It turns out that the simple MLPRegressor in Scikit-Learn works very well on small datasets.

from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error

# Four hidden layers of 64 units each; no L2 penalty (alpha=0).
clf = MLPRegressor(
    hidden_layer_sizes=(64,) * 4,
    alpha=0,
    batch_size=1000,
    max_iter=1000,
    tol=1e-6,
    verbose=1,
)
clf.fit(X, y)

# The fitted weights (coefs_) and biases (intercepts_) exist only after fit().
n = sum(w.size for w in clf.coefs_) + sum(b.size for b in clf.intercepts_)
print("number of parameters", n)

print("MLPRegressor", mean_squared_error(clf.predict(X), y))
Iteration 1, loss = 0.03872877
Iteration 2, loss = 0.01838284
Iteration 3, loss = 0.01322917
Iteration 4, loss = 0.00992924
Iteration 5, loss = 0.00764789
Iteration 6, loss = 0.00596758
Iteration 7, loss = 0.00476454
Iteration 8, loss = 0.00403427
Iteration 9, loss = 0.00354749
Iteration 10, loss = 0.00312269
Iteration 11, loss = 0.00285058
Iteration 12, loss = 0.00266859
Iteration 13, loss = 0.00250857
Iteration 14, loss = 0.00243164
Iteration 15, loss = 0.00227876
Iteration 16, loss = 0.00218516
Iteration 17, loss = 0.00204908
Iteration 18, loss = 0.00195609
Iteration 19, loss = 0.00188667
Iteration 20, loss = 0.00181428
Iteration 21, loss = 0.00181260
Iteration 22, loss = 0.00172098
Iteration 23, loss = 0.00166783
Iteration 24, loss = 0.00162889
Iteration 25, loss = 0.00162617
Iteration 26, loss = 0.00151559
Iteration 27, loss = 0.00148309
Iteration 28, loss = 0.00146949
Iteration 29, loss = 0.00142209
Iteration 30, loss = 0.00137413
Iteration 31, loss = 0.00133488
Iteration 32, loss = 0.00132600
Iteration 33, loss = 0.00133857
Iteration 34, loss = 0.00133807
Iteration 35, loss = 0.00125656
Iteration 36, loss = 0.00125801
Iteration 37, loss = 0.00124734
Iteration 38, loss = 0.00116203
Iteration 39, loss = 0.00112787
Iteration 40, loss = 0.00113262
Iteration 41, loss = 0.00111350
Iteration 42, loss = 0.00109534
Iteration 43, loss = 0.00109662
Iteration 44, loss = 0.00103106
Iteration 45, loss = 0.00099346
Iteration 46, loss = 0.00098687
Iteration 47, loss = 0.00103844
Iteration 48, loss = 0.00099100
Iteration 49, loss = 0.00096235
Iteration 50, loss = 0.00097157
Iteration 51, loss = 0.00092747
Iteration 52, loss = 0.00091701
Iteration 53, loss = 0.00089235
Iteration 54, loss = 0.00087139
Iteration 55, loss = 0.00087560
Iteration 56, loss = 0.00087910
Iteration 57, loss = 0.00090950
Iteration 58, loss = 0.00085914
Iteration 59, loss = 0.00082885
Iteration 60, loss = 0.00081474
Iteration 61, loss = 0.00084764
Iteration 62, loss = 0.00081955
Iteration 63, loss = 0.00083712
Iteration 64, loss = 0.00080643
Iteration 65, loss = 0.00078607
Iteration 66, loss = 0.00077217
Iteration 67, loss = 0.00080849
Iteration 68, loss = 0.00081811
Iteration 69, loss = 0.00078761
Iteration 70, loss = 0.00077427
Iteration 71, loss = 0.00075709
Iteration 72, loss = 0.00075573
Iteration 73, loss = 0.00075636
Iteration 74, loss = 0.00074296
Iteration 75, loss = 0.00075769
Iteration 76, loss = 0.00078198
Iteration 77, loss = 0.00074714
Iteration 78, loss = 0.00076107
Iteration 79, loss = 0.00073871
Iteration 80, loss = 0.00071247
Iteration 81, loss = 0.00073748
Iteration 82, loss = 0.00074470
Iteration 83, loss = 0.00070695
Iteration 84, loss = 0.00069866
Iteration 85, loss = 0.00068365
Iteration 86, loss = 0.00067912
Iteration 87, loss = 0.00068739
Iteration 88, loss = 0.00067988
Iteration 89, loss = 0.00070646
Iteration 90, loss = 0.00067668
Iteration 91, loss = 0.00068584
Iteration 92, loss = 0.00066119
Iteration 93, loss = 0.00065713
Iteration 94, loss = 0.00079503
Iteration 95, loss = 0.00070443
Iteration 96, loss = 0.00069255
Iteration 97, loss = 0.00064721
Iteration 98, loss = 0.00066456
Iteration 99, loss = 0.00066503
Iteration 100, loss = 0.00064768
Iteration 101, loss = 0.00063161
Iteration 102, loss = 0.00065152
Iteration 103, loss = 0.00063627
Iteration 104, loss = 0.00062377
Iteration 105, loss = 0.00063346
Iteration 106, loss = 0.00065331
Iteration 107, loss = 0.00063847
Iteration 108, loss = 0.00061856
Iteration 109, loss = 0.00061806
Iteration 110, loss = 0.00060691
Iteration 111, loss = 0.00061332
Iteration 112, loss = 0.00061673
Iteration 113, loss = 0.00064613
Iteration 114, loss = 0.00060685
Iteration 115, loss = 0.00059055
Iteration 116, loss = 0.00060077
Iteration 117, loss = 0.00062042
Iteration 118, loss = 0.00065500
Iteration 119, loss = 0.00060160
Iteration 120, loss = 0.00060204
Iteration 121, loss = 0.00059747
Iteration 122, loss = 0.00061649
Iteration 123, loss = 0.00060509
Iteration 124, loss = 0.00063298
Iteration 125, loss = 0.00060878
Iteration 126, loss = 0.00061471
Training loss did not improve more than tol=0.000001 for 10 consecutive epochs. Stopping.
number of parameters 12993
MLPRegressor 0.0011639113
draw(model=clf.predict)  # plot the scikit-learn fit

PyTorch

import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim import Adam

# Use the GPU when available; tensors created below are placed on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)
torch.set_default_dtype(torch.float32)

# Four hidden layers of 64 units with ReLU, then a single linear output unit.
nonlin = nn.ReLU()
num = 2**6
hidden = [nn.Linear(6, num), nonlin]
for _ in range(3):
    hidden.extend([nn.Linear(num, num), nonlin])
torch_model = nn.Sequential(
    *hidden,
    nn.Linear(num, 1),
    # adding a ReLU here makes training unstable
    nn.Flatten(0, 1),
)

n = sum(np.prod(p.shape) for p in torch_model.parameters())
print("number of parameters", n)

# must make copies here
torch_X = torch.tensor(X.copy())
torch_y = torch.tensor(y.copy())

max_epoch = 5000
shrink_patience = 10
abort_patience = 100
learning_rate = 1e-3
learning_rate_shrink_factor = 0.8
tol = 1e-6

loss_fn = nn.MSELoss()
opt = Adam(torch_model.parameters(), lr=learning_rate)
scheduler = ReduceLROnPlateau(
    opt, factor=learning_rate_shrink_factor, patience=shrink_patience, verbose=True
)

# Full-batch training. Stop early once the loss has not dropped by more than
# tol below the best value for more than abort_patience consecutive epochs.
losses = []
best_loss = np.inf
stale_epochs = 0
for epoch in range(1, max_epoch + 1):
    opt.zero_grad()
    batch_loss = loss_fn(torch_model(torch_X), torch_y)
    batch_loss.backward()
    opt.step()
    scheduler.step(batch_loss)

    loss = batch_loss.item()
    losses.append(loss)
    if epoch == 1 or epoch % 100 == 0:
        print("epoch", epoch, "loss", loss)

    if loss < best_loss - tol:
        best_loss = loss
        stale_epochs = 0
    else:
        stale_epochs += 1
    if stale_epochs > abort_patience:
        break

plt.plot(losses)
plt.semilogy();
number of parameters 12993
epoch 1 loss 0.25884100794792175
epoch 100 loss 0.010155524127185345
epoch 200 loss 0.003915281500667334
epoch 300 loss 0.002415228867903352
epoch 400 loss 0.0017484889831393957
epoch 500 loss 0.0014784386148676276
Epoch 00508: reducing learning rate of group 0 to 8.0000e-04.
epoch 600 loss 0.0013304103631526232
Epoch 00695: reducing learning rate of group 0 to 6.4000e-04.
epoch 700 loss 0.001243863138370216
epoch 800 loss 0.0011801073560491204
epoch 900 loss 0.0011410759761929512
Epoch 00905: reducing learning rate of group 0 to 5.1200e-04.
epoch 1000 loss 0.00109128060285002
epoch 1100 loss 0.0010561983799561858
Epoch 01152: reducing learning rate of group 0 to 4.0960e-04.
epoch 1200 loss 0.0010263388976454735
epoch 1300 loss 0.001002235570922494
epoch 1400 loss 0.0009775016224011779
Epoch 01450: reducing learning rate of group 0 to 3.2768e-04.
epoch 1500 loss 0.0009563906933180988
epoch 1600 loss 0.0009381748968735337
epoch 1700 loss 0.0009208598639816046
epoch 1800 loss 0.0009038595599122345
Epoch 01887: reducing learning rate of group 0 to 2.6214e-04.
epoch 1900 loss 0.0008835234912112355
epoch 2000 loss 0.0008701192564330995
epoch 2100 loss 0.0008563698502257466
epoch 2200 loss 0.000845507369376719
epoch 2300 loss 0.0008326928946189582
Epoch 02305: reducing learning rate of group 0 to 2.0972e-04.
epoch 2400 loss 0.0008199398289434612
epoch 2500 loss 0.0008091804338619113
epoch 2600 loss 0.000800078792963177
epoch 2700 loss 0.0007896188762970269
epoch 2800 loss 0.000779114430770278
Epoch 02864: reducing learning rate of group 0 to 1.6777e-04.
epoch 2900 loss 0.0007669864571653306
epoch 3000 loss 0.0007585080456919968
epoch 3100 loss 0.0007498746854253113
epoch 3200 loss 0.0007412403356283903
epoch 3300 loss 0.0007329632644541562
epoch 3400 loss 0.0007246120949275792
Epoch 03408: reducing learning rate of group 0 to 1.3422e-04.
epoch 3500 loss 0.0007171895122155547
epoch 3600 loss 0.0007101951632648706
epoch 3700 loss 0.0007029927219264209
epoch 3800 loss 0.0006956694996915758
Epoch 03889: reducing learning rate of group 0 to 1.0737e-04.
epoch 3900 loss 0.0006882176967337728
epoch 4000 loss 0.0006816069362685084
epoch 4100 loss 0.0006751783657819033
epoch 4200 loss 0.0006688352441415191
epoch 4300 loss 0.0006628871196880937
epoch 4400 loss 0.0006569336983375251
epoch 4500 loss 0.0006518965819850564
epoch 4600 loss 0.0006464265170507133
epoch 4700 loss 0.00063846365083009
epoch 4800 loss 0.0006334380595944822
epoch 4900 loss 0.0006282057147473097
epoch 5000 loss 0.0006211302825249732

# .cpu() moves the PyTorch prediction to host memory first; .numpy() raises on
# a CUDA tensor (and .cpu() is a no-op when the model runs on the CPU).
print("MLPRegressor", mean_squared_error(clf.predict(X), y))
print("PyTorch     ", mean_squared_error(torch_model(torch_X).detach().cpu().numpy(), y))
MLPRegressor 0.0011639113
PyTorch      0.00062100013
# .cpu() is required before .numpy() when the model lives on a CUDA device.
draw(lambda x: torch_model(torch.tensor(x)).detach().cpu().numpy())

Flax / JAX

Flax, built on JAX, is a modern functional-programming library that is great for trying out new things. It offers less high-level functionality than PyTorch: where PyTorch provides ReduceLROnPlateau, we have to implement an analogous learning-rate schedule from scratch.

from typing import Callable, Iterable
from types import SimpleNamespace
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax


class Model(nn.Module):
    """Multilayer perceptron whose width-1 output layer is flattened.

    Every hidden ``nn.Dense`` layer (widths from ``sizes``) is followed by
    ``nonlin``.
    """

    sizes: Iterable[int]  # hidden-layer widths, input side first
    nonlin: Callable  # activation applied after each hidden layer

    @nn.compact
    def __call__(self, x):
        h = x
        for width in self.sizes:
            h = self.nonlin(nn.Dense(width)(h))
        out = nn.Dense(1)(h)
        return out.flatten()


def train(config: SimpleNamespace, x, y):
    """Fit ``Model`` to (x, y) by full-batch training with a plateau LR schedule.

    Emulates PyTorch's ``ReduceLROnPlateau``: whenever the training loss has not
    reached a new minimum for ``config.shrink_patience`` consecutive epochs, the
    learning rate is multiplied by ``config.learning_rate_shrink_factor``.
    Training stops after ``config.max_epoch`` epochs, or earlier once the
    learning rate falls below ``config.tolerance``.

    Parameters
    ----------
    config : SimpleNamespace
        Needs ``layers``, ``nonlin``, ``learning_rate``, ``max_epoch``,
        ``print_freq``, ``shrink_patience``, ``learning_rate_shrink_factor``
        and ``tolerance``.
    x, y : array-like
        Feature matrix and regression targets.

    Returns
    -------
    (model, theta, losses)
        The Flax module, its trained variables and the per-epoch MSE.
    """
    x, y = jax.device_put(x), jax.device_put(y)

    model = Model(config.layers, config.nonlin)

    @jax.jit
    def loss_fn(theta, x, y):
        # Mean squared error, with the model applied to one row at a time.
        def per_row(xi, yi):
            return jnp.square(yi - model.apply(theta, xi))

        return jnp.mean(jax.vmap(per_row, in_axes=0)(x, y))

    sample_key, param_key = jax.random.split(jax.random.PRNGKey(0))
    theta = model.init(param_key, jax.random.normal(sample_key, x.shape[1:]))

    n = sum(
        np.prod(layer["kernel"].shape) + layer["bias"].shape[0]
        for layer in theta["params"].values()
    )
    print("number of parameters", n)

    # optax.clip(1) limits each component of the batch gradient to [-1, 1]
    # before AdaBelief. This is element-wise gradient clipping, not the same
    # thing as a Huber loss, which would limit per-sample residual gradients.
    # inject_hyperparams exposes the learning rate in the optimizer state so
    # the loop below can shrink it.
    tx = optax.inject_hyperparams(
        lambda learning_rate: optax.chain(optax.clip(1), optax.adabelief(learning_rate))
    )(learning_rate=config.learning_rate)
    tx_state = tx.init(theta)
    loss_and_grad = jax.value_and_grad(loss_fn)

    losses = []
    best = np.inf
    stale = 0
    for epoch in range(1, config.max_epoch + 1):
        loss, grads = loss_and_grad(theta, x, y)
        updates, tx_state = tx.update(grads, tx_state)
        theta = optax.apply_updates(theta, updates)
        losses.append(loss)

        if epoch == 1 or epoch % config.print_freq == 0:
            print(f"epoch {epoch} loss={loss}")

        if loss < best:
            best = loss
            stale = 0
        else:
            stale += 1

        hparams = tx_state.hyperparams
        if stale >= config.shrink_patience:
            hparams["learning_rate"] *= config.learning_rate_shrink_factor
            print(f"epoch {epoch} learning rate={hparams['learning_rate']}")
            stale = 0
        elif hparams["learning_rate"] < config.tolerance:
            break

    return model, theta, losses

print(f"JAX process: {jax.process_index()} / {jax.process_count()}")
print(f"JAX local devices: {jax.local_devices()}")

config = SimpleNamespace()
config.layers = (2 ** 6,) * 4
config.nonlin = nn.relu
config.max_epoch = 5000
config.learning_rate = 1e-2
config.shrink_patience = 10
config.learning_rate_shrink_factor = 0.8
config.print_freq = 100
# Training stops early once the learning rate falls below this value.
config.tolerance = 1e-6

# Loss curve of the previous execution of this cell (None on the first one),
# for comparing runs with different settings. It is kept under a dedicated
# name because ``losses`` is also bound by other cells (e.g. the PyTorch one).
prev_losses = globals().get("flax_losses")

flax_model, flax_theta, flax_losses = train(config, X, y)
losses = flax_losses

# Keep the curve of the run with the lowest final loss seen so far.
best_losses = globals().get("best_losses", flax_losses)
if flax_losses[-1] < best_losses[-1]:
    best_losses = flax_losses

plt.plot(flax_losses, label="now")
if prev_losses is not None:
    plt.plot(prev_losses, label="previous")
plt.plot(best_losses, label="best")
plt.legend()
plt.semilogy();
JAX process: 0 / 1
JAX local devices: [CpuDevice(id=0)]
number of parameters 12993
epoch 1 loss=0.1615249365568161
epoch 100 loss=0.004056446719914675
epoch 200 loss=0.0022707092575728893
epoch 237 learning rate=0.00800000037997961
epoch 247 learning rate=0.006400000303983688
epoch 300 loss=0.0015574624994769692
epoch 400 loss=0.0013330463552847505
epoch 500 loss=0.0012052441015839577
epoch 600 loss=0.0011462707770988345
epoch 604 learning rate=0.005120000336319208
epoch 700 loss=0.0010661932174116373
epoch 773 learning rate=0.0040960004553198814
epoch 800 loss=0.0010216122027486563
epoch 900 loss=0.0009868466295301914
epoch 982 learning rate=0.0032768005039542913
epoch 1000 loss=0.0009563774801790714
epoch 1100 loss=0.0009271150338463485
epoch 1200 loss=0.0008946977904997766
epoch 1256 learning rate=0.002621440449729562
epoch 1300 loss=0.0008701515616849065
epoch 1400 loss=0.0008422626415267587
epoch 1500 loss=0.0008212537504732609
epoch 1600 loss=0.0008094692602753639
epoch 1602 learning rate=0.002097152406349778
epoch 1700 loss=0.0007778366561979055
epoch 1800 loss=0.0007623148267157376
epoch 1867 learning rate=0.0016777219716459513
epoch 1900 loss=0.0007488274131901562
epoch 2000 loss=0.0007354258559644222
epoch 2100 loss=0.0007237782701849937
epoch 2200 loss=0.0007126012351363897
epoch 2236 learning rate=0.0013421776238828897
epoch 2300 loss=0.0007015920709818602
epoch 2400 loss=0.0006939161103218794
epoch 2500 loss=0.000684640312101692
epoch 2544 learning rate=0.001073742168955505
epoch 2600 loss=0.0006755936774425209
epoch 2700 loss=0.0006690087029710412
epoch 2768 learning rate=0.0008589937351644039
epoch 2800 loss=0.0006616609753109515
epoch 2853 learning rate=0.0006871949881315231
epoch 2900 loss=0.0006564840441569686
epoch 3000 loss=0.0006521755130961537
epoch 3100 loss=0.0006475721020251513
epoch 3200 loss=0.0006428791675716639
epoch 3300 loss=0.0006390072521753609
epoch 3344 learning rate=0.0005497559905052185
epoch 3386 learning rate=0.000439804804045707
epoch 3400 loss=0.000634271593298763
epoch 3500 loss=0.0006310778553597629
epoch 3600 loss=0.0006278276559896767
epoch 3700 loss=0.0006245431723073125
epoch 3800 loss=0.000620958860963583
epoch 3900 loss=0.0006179117481224239
epoch 4000 loss=0.0006144311628304422
epoch 4091 learning rate=0.00035184386069886386
epoch 4100 loss=0.0006107029039412737
epoch 4200 loss=0.0006076746503822505
epoch 4300 loss=0.0006047412171028554
epoch 4400 loss=0.0006017453270033002
epoch 4500 loss=0.0005980761488899589
epoch 4510 learning rate=0.00028147510602138937
epoch 4600 loss=0.0005954525549896061
epoch 4700 loss=0.0005925073637627065
epoch 4800 loss=0.0005897046066820621
epoch 4900 loss=0.0005866923602297902
epoch 5000 loss=0.0005837904172949493

# .cpu() moves the PyTorch prediction to host memory first; .numpy() raises on
# a CUDA tensor (and .cpu() is a no-op when the model runs on the CPU).
print("MLPRegressor", mean_squared_error(clf.predict(X), y))
print("PyTorch     ", mean_squared_error(torch_model(torch_X).detach().cpu().numpy(), y))
print("Flax        ", mean_squared_error(flax_model.apply(flax_theta, X), y))
MLPRegressor 0.0011639113
PyTorch      0.00062100013
Flax         0.0005836923
def _flax_prediction(x):
    """Evaluate the plateau-schedule Flax model on standardized features."""
    return flax_model.apply(flax_theta, x)


draw(_flax_prediction)

Another version of training with Flax uses meta-learning: the learning rate of the model optimization is itself learned by an outer optimization step. The author observed a slightly shorter run-time with this approach. In the run shown below, however, the final MSE after 5000 epochs (0.00073) is higher than with the plateau-based schedule above (0.00058), so here meta-learning does not outperform the previous approach.

from typing import Callable, Iterable
from types import SimpleNamespace
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax


class Model(nn.Module):
    """Multilayer perceptron whose width-1 output layer is flattened.

    ``sizes`` gives the hidden-layer widths; every hidden ``nn.Dense`` layer
    is followed by ``nonlin``.
    """

    sizes: Iterable[int]  # hidden-layer widths, input side first
    nonlin: Callable  # activation applied after each hidden layer

    @nn.compact
    def __call__(self, x):
        for width in self.sizes:
            dense = nn.Dense(width)
            x = self.nonlin(dense(x))
        head = nn.Dense(1)
        return head(x).flatten()


def train(config, x, y):
    """Fit ``Model`` to (x, y) with MSE while meta-learning the learning rate.

    The inner optimizer is RMSProp. Its learning rate is ``sigmoid(eta)``, and
    ``eta`` is updated each epoch by an AdaBelief meta-optimizer. Each epoch:
    permute the data, take one inner step on all rows except the last, and
    differentiate the loss on that held-out last row with respect to ``eta``.

    Parameters
    ----------
    config : SimpleNamespace
        Needs ``layers``, ``nonlin``, ``learning_rate`` (initial inner rate,
        must lie in (0, 1) so that its logit is finite), ``meta_learning_rate``,
        ``max_epoch`` and ``print_freq``.
    x, y : array-like
        Feature matrix and regression targets.

    Returns
    -------
    (model, theta, losses, learning_rates)
        The Flax module, its trained variables, the per-epoch inner training
        loss (before that epoch's update) and the per-epoch learning rate.
    """
    x = jax.device_put(x)
    y = jax.device_put(y)
    model = Model(config.layers, config.nonlin)

    root_key = jax.random.PRNGKey(0)
    param_key, perm_key = jax.random.split(root_key)
    theta = model.init(param_key, x)

    # inject_hyperparams keeps the learning rate in the optimizer state, so
    # outer_loss can overwrite it with a value that depends on eta.
    opt = optax.inject_hyperparams(optax.rmsprop)(learning_rate=config.learning_rate)

    meta_opt = optax.adabelief(learning_rate=config.meta_learning_rate)

    state = opt.init(theta)
    # eta is the logit of the learning rate: sigmoid(eta) == config.learning_rate.
    eta = -np.log(1.0 / config.learning_rate - 1)
    meta_state = meta_opt.init(eta)

    @jax.jit
    def loss_fn(theta, x, y):
        # Mean squared error, with the model applied to one row at a time.
        def squared_error(x, y):
            y_pred = model.apply(theta, x)
            return jnp.square(y - y_pred)
        return jnp.mean(jax.vmap(squared_error, in_axes=0)(x, y))

    @jax.jit
    def step(theta, state, x, y):
        # One inner optimizer step; returns the loss at the pre-update theta.
        loss, grad = jax.value_and_grad(loss_fn)(theta, x, y)
        updates, state = opt.update(grad, state)
        theta = optax.apply_updates(theta, updates)
        return loss, theta, state

    @jax.jit
    def outer_loss(eta, theta, state, x, y):
        # Make the inner step depend on eta so it can be differentiated.
        state.hyperparams["learning_rate"] = jax.nn.sigmoid(eta)
        # Inner step on all rows but the last ...
        loss, theta, state = step(theta, state, x[:-1], y[:-1])
        # ... then the meta objective is the loss on the held-out last row.
        # NOTE(review): a single held-out row presumably gives a noisy
        # meta-gradient.
        return loss_fn(theta, x[-1:], y[-1:]), (loss, theta, state)

    @jax.jit
    def outer_step(eta, theta, meta_state, state, x, y):
        # Gradient w.r.t. eta only (argument 0); the updated theta/state come
        # back as auxiliary outputs.
        grad, (loss, theta, state) = jax.grad(outer_loss, has_aux=True)(eta, theta, state, x, y)
        meta_updates, meta_state = meta_opt.update(grad, meta_state)
        eta = optax.apply_updates(eta, meta_updates)
        return eta, theta, meta_state, state, loss

    learning_rates = []
    losses = []

    for epoch in range(1, config.max_epoch + 1):
        # A fresh, reproducible permutation each epoch, so the held-out row changes.
        perm_train_key = jax.random.fold_in(perm_key, epoch)
        perm = jax.random.permutation(perm_train_key, len(x))
        eta, theta, meta_state, state, loss = outer_step(eta, theta, meta_state, state, x[perm], y[perm])
        learning_rate = jax.nn.sigmoid(eta)

        losses.append(loss)
        learning_rates.append(learning_rate)

        if epoch == 1 or epoch % config.print_freq == 0:
            print(f"epoch {epoch} loss={loss} lr={learning_rate}")

    return model, theta, losses, learning_rates

config = SimpleNamespace(
    layers=(2**6,) * 4,
    nonlin=nn.relu,
    max_epoch=5000,
    learning_rate=1e-2,
    meta_learning_rate=0.03,
    print_freq=100,
)

flax_model, flax_ml_theta, losses, learning_rates = train(config, X, y)

# Training loss (top) and meta-learned learning rate (bottom) per epoch.
fig, ax = plt.subplots(2, 1, sharex=True)
loss_ax, lr_ax = ax
loss_ax.plot(losses)
loss_ax.set(ylabel="loss", yscale="log")
lr_ax.plot(learning_rates)
lr_ax.set(xlabel="epoch", ylabel="lr", yscale="log");
epoch 1 loss=0.31326305866241455 lr=0.00970732606947422
epoch 100 loss=0.022382477298378944 lr=0.008342958986759186
epoch 200 loss=0.009440391324460506 lr=0.008291025646030903
epoch 300 loss=0.005982980597764254 lr=0.008259016089141369
epoch 400 loss=0.004722251556813717 lr=0.008235295303165913
epoch 500 loss=0.0026633746456354856 lr=0.008185734041035175
epoch 600 loss=0.0032396698370575905 lr=0.008146436884999275
epoch 700 loss=0.002778511494398117 lr=0.008116255514323711
epoch 800 loss=0.003227198962122202 lr=0.00809383299201727
epoch 900 loss=0.0033369455486536026 lr=0.008060149848461151
epoch 1000 loss=0.0021190992556512356 lr=0.008024145849049091
epoch 1100 loss=0.002638544887304306 lr=0.007992972619831562
epoch 1200 loss=0.002345192711800337 lr=0.007957393303513527
epoch 1300 loss=0.001486414228565991 lr=0.007931879721581936
epoch 1400 loss=0.0017564770532771945 lr=0.007895255461335182
epoch 1500 loss=0.0017623923486098647 lr=0.007862220518290997
epoch 1600 loss=0.0017055724747478962 lr=0.007831801660358906
epoch 1700 loss=0.0017303635831922293 lr=0.007792978081852198
epoch 1800 loss=0.0011820418294519186 lr=0.007761101704090834
epoch 1900 loss=0.0019437490263953805 lr=0.007730470504611731
epoch 2000 loss=0.0014381277142092586 lr=0.007678112480789423
epoch 2100 loss=0.001700665452517569 lr=0.007647846359759569
epoch 2200 loss=0.0012748426524922252 lr=0.007603880017995834
epoch 2300 loss=0.0014644829789176583 lr=0.007567934691905975
epoch 2400 loss=0.0017952200723811984 lr=0.007531826850026846
epoch 2500 loss=0.0014581787399947643 lr=0.0074939788319170475
epoch 2600 loss=0.0011716835433617234 lr=0.0074319071136415005
epoch 2700 loss=0.0018534348346292973 lr=0.007394589018076658
epoch 2800 loss=0.0011852614115923643 lr=0.0073562441393733025
epoch 2900 loss=0.0014049195451661944 lr=0.007317622657865286
epoch 3000 loss=0.0016386691713705659 lr=0.007295084185898304
epoch 3100 loss=0.0010843385243788362 lr=0.007261158432811499
epoch 3200 loss=0.0009666284313425422 lr=0.007225275970995426
epoch 3300 loss=0.0010213557397946715 lr=0.0071740285493433475
epoch 3400 loss=0.001257327850908041 lr=0.0071229999884963036
epoch 3500 loss=0.0009237469639629126 lr=0.007060170639306307
epoch 3600 loss=0.0010209324536845088 lr=0.007001352030783892
epoch 3700 loss=0.0010023850481957197 lr=0.006939604412764311
epoch 3800 loss=0.0011092709610238671 lr=0.006892509292811155
epoch 3900 loss=0.0008320825872942805 lr=0.006856516934931278
epoch 4000 loss=0.0015411889180541039 lr=0.006792154163122177
epoch 4100 loss=0.000866543676238507 lr=0.00673663429915905
epoch 4200 loss=0.0009100940660573542 lr=0.00666508823633194
epoch 4300 loss=0.0008933907374739647 lr=0.006604321300983429
epoch 4400 loss=0.0006370239425450563 lr=0.00655223848298192
epoch 4500 loss=0.0007405806099995971 lr=0.006423715502023697
epoch 4600 loss=0.0006541237817145884 lr=0.006394413765519857
epoch 4700 loss=0.0005402541719377041 lr=0.006318734493106604
epoch 4800 loss=0.0009636411559768021 lr=0.00621804129332304
epoch 4900 loss=0.0007088639540597796 lr=0.006063143257051706
epoch 5000 loss=0.0007585420971736312 lr=0.006008874159306288

# .cpu() moves the PyTorch prediction to host memory first; .numpy() raises on
# a CUDA tensor (and .cpu() is a no-op when the model runs on the CPU).
print("MLPRegressor       ", mean_squared_error(clf.predict(X), y))
print("PyTorch            ", mean_squared_error(torch_model(torch_X).detach().cpu().numpy(), y))
print("Flax               ", mean_squared_error(flax_model.apply(flax_theta, X), y))
print("Flax: meta-learning", mean_squared_error(flax_model.apply(flax_ml_theta, X), y))
MLPRegressor        0.0011639113
PyTorch             0.00062100013
Flax                0.0005836923
Flax: meta-learning 0.0007262042
def _meta_prediction(x):
    """Evaluate the meta-learned Flax model on standardized features."""
    return flax_model.apply(flax_ml_theta, x)


draw(_meta_prediction)

from typing import Callable, Iterable, Any
from types import SimpleNamespace
import jax
from jax import numpy as jnp
from flax import struct
from flax.core import FrozenDict
from flax import linen as nn
import optax


class Model(nn.Module):
    """MLP with optional dropout before a width-1, bias-free output layer.

    ``sizes`` lists the hidden-layer widths, each followed by ``nonlin``.
    Dropout with rate ``dropout_rate`` is applied once, after the last hidden
    layer, and is only active when ``train`` is True; callers then pass a
    ``"dropout"`` PRNG key. The output is flattened.
    """

    sizes: Iterable[int]  # hidden-layer widths, input side first
    nonlin: Callable  # activation applied after each hidden layer
    dropout_rate: float = 0.0  # dropout rate used when train is True

    @nn.compact
    def __call__(self, x, train):
        h = x
        for width in self.sizes:
            h = self.nonlin(nn.Dense(width)(h))
        h = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(h)
        return nn.Dense(1, use_bias=False)(h).flatten()


class TrainState(struct.PyTreeNode):
    """Everything the meta-learning loop updates, bundled as one JAX pytree.

    As a ``flax.struct.PyTreeNode`` it can be passed through ``jax.jit`` and
    ``jax.grad`` and is updated functionally with ``.replace(...)``.
    """

    # PRNG key used for dropout; re-derived from the epoch number each epoch.
    dropout_key: jax.Array
    # Model parameters (the "params" collection of the Flax variables).
    params: FrozenDict[str, Any] = struct.field(pytree_node=True)
    # State of the inner optimizer that updates ``params``.
    opt_state: optax.OptState = struct.field(pytree_node=True)
    # Logit of the inner learning rate (the rate is sigmoid(meta_params)).
    meta_params: float
    # State of the outer (meta) optimizer that updates ``meta_params``.
    meta_opt_state: optax.OptState = struct.field(pytree_node=True)


def predict(model, theta, x):
    """Return predicted probabilities for the rows of ``x``.

    The network output is treated as a logit and mapped through the logistic
    function. Dropout is disabled (``train=False``). ``theta`` is the bare
    ``params`` collection, which is wrapped as ``{"params": theta}``.
    """
    logits = model.apply({"params": theta}, x, train=False)
    with np.errstate(over="ignore"):
        # exp(-logit) may overflow to inf for very negative logits;
        # 1 / (1 + inf) == 0 is the correct limit, so the warning is silenced.
        denominator = 1 + np.exp(-logits)
        return 1 / denominator


def train(config, x, k, n):
    """Fit ``Model`` to counts k out of n with a binomial loss and meta-learned LR.

    The network output is treated as the logit of the success probability.
    Each epoch the rows are permuted. One inner RMSProp step is taken on all
    rows except the last. The binomial loss on that last row is then
    differentiated with respect to the logit of the inner learning rate, which
    an AdaBelief meta-optimizer updates.

    Parameters
    ----------
    config : SimpleNamespace
        Needs ``layers``, ``nonlin``, ``dropout_rate``, ``learning_rate``
        (initial inner rate, must lie in (0, 1) so its logit is finite),
        ``meta_learning_rate``, ``max_epoch`` and ``print_freq``.
    x : array-like
        Feature matrix, one row per sample.
    k, n : array-like
        Per-row success counts and trial counts.

    Returns
    -------
    (model, params, losses, learning_rates)
        The Flax module, its trained ``params`` collection (without the
        ``{"params": ...}`` wrapper), the per-epoch inner training loss and the
        per-epoch inner learning rate.
    """
    x = jax.device_put(x)
    k = jax.device_put(k)
    n = jax.device_put(n)

    model = Model(config.layers, config.nonlin, config.dropout_rate)

    root_key = jax.random.PRNGKey(0)
    params_key, dropout_key, perm_key = jax.random.split(root_key, num=3)
    variables = model.init(params_key, x, train=False)

    # inject_hyperparams keeps the learning rate in the optimizer state, so
    # outer_loss can overwrite it with a value that depends on meta_params.
    opt = optax.inject_hyperparams(optax.rmsprop)(learning_rate=config.learning_rate)
    meta_opt = optax.adabelief(learning_rate=config.meta_learning_rate)
    # Logit of the initial learning rate: sigmoid(eta) == config.learning_rate.
    eta = -np.log(1.0 / config.learning_rate - 1)

    state = TrainState(
        dropout_key=dropout_key,
        params=variables["params"],
        opt_state=opt.init(variables["params"]),
        meta_params=eta,
        meta_opt_state=meta_opt.init(eta),
    )

    @jax.jit
    def loss_fn(params, state, x, k, n):
        # Mean binomial deviance: -2 * mean(logL(y) - logL(y_sat)), where y is
        # the predicted logit and y_sat the logit of the observed fraction k/n.
        def fn(x, k, n):
            # NOTE(review): state.dropout_key is not split per row, so every
            # row presumably gets the same dropout mask within one call —
            # confirm this is intended.
            y = model.apply(
                {"params": params},
                x,
                train=True,
                rngs={"dropout": state.dropout_key},
            )
            # Saturated model: observed fraction, clipped away from 0 and 1 so
            # that its logit stays finite.
            p_sat = jnp.clip(k / (n + 1e-6), 1e-6, 1 - 1e-6)
            y_sat = jnp.log(p_sat / (1 - p_sat))
            # Binomial log-likelihood k*y - n*log(1 + e^y), relative to the
            # saturated model; logaddexp(0, y) == log(1 + e^y).
            l = k * (y - y_sat) - n * (jnp.logaddexp(0, y) - jnp.logaddexp(0, y_sat))
            return l

        l = jax.vmap(fn, in_axes=0)(x, k, n)
        return -2 * jnp.mean(l)

    @jax.jit
    def inner_step(state, x, k, n):
        # One RMSProp step on the params; returns the loss at the pre-update params.
        loss, grad = jax.value_and_grad(loss_fn)(state.params, state, x, k, n)
        updates, opt_state = opt.update(grad, state.opt_state)
        params = optax.apply_updates(state.params, updates)
        state = state.replace(params=params, opt_state=opt_state)
        return loss, state

    @jax.jit
    def outer_loss(eta, state, x, k, n):
        # Make the inner step depend on eta so it can be differentiated.
        state.opt_state.hyperparams["learning_rate"] = jax.nn.sigmoid(eta)
        # Inner step on all rows but the last ...
        loss, state = inner_step(state, x[:-1], k[:-1], n[:-1])
        # ... then the meta objective is the loss on the held-out last row.
        # NOTE(review): a single held-out row presumably gives a noisy
        # meta-gradient.
        return loss_fn(state.params, state, x[-1:], k[-1:], n[-1:]), (loss, state)

    @jax.jit
    def outer_step(state, x, k, n):
        # Gradient w.r.t. the meta-parameter only; the updated state is auxiliary.
        grad, (loss, state) = jax.grad(outer_loss, has_aux=True)(
            state.meta_params, state, x, k, n
        )
        meta_updates, meta_opt_state = meta_opt.update(grad, state.meta_opt_state)
        meta_params = optax.apply_updates(state.meta_params, meta_updates)
        state = state.replace(meta_params=meta_params, meta_opt_state=meta_opt_state)
        return loss, state

    learning_rates = []
    losses = []

    for epoch in range(1, config.max_epoch + 1):
        # Fresh, reproducible permutation and dropout key for every epoch.
        perm_train_key = jax.random.fold_in(key=perm_key, data=epoch)
        state = state.replace(
            dropout_key=jax.random.fold_in(key=dropout_key, data=epoch),
        )
        perm = jax.random.permutation(perm_train_key, len(x))
        loss, state = outer_step(state, x[perm], k[perm], n[perm])
        learning_rate = jax.nn.sigmoid(state.meta_params)
        losses.append(loss)
        learning_rates.append(learning_rate)

        if epoch == 1 or epoch % config.print_freq == 0:
            print(f"epoch {epoch} loss={loss} lr={learning_rate}")

    return model, state.params, losses, learning_rates


config = SimpleNamespace(
    layers=(2**6,) * 4,
    nonlin=nn.relu,
    dropout_rate=0.01,
    max_epoch=5000,
    learning_rate=1e-2,
    meta_learning_rate=0.03,
    print_freq=100,
)

# Fit the raw counts (k out of n) with the binomial loss instead of the ratio y.
flax_model_2, flax_theta_2, losses, learning_rates = train(config, X, data_k, data_n)

# Training loss (top) and meta-learned learning rate (bottom) per epoch.
fig, ax = plt.subplots(2, 1, sharex=True)
loss_ax, lr_ax = ax
loss_ax.plot(losses)
loss_ax.set(ylabel="loss", yscale="log")
lr_ax.plot(learning_rates)
lr_ax.set(xlabel="epoch", ylabel="lr", yscale="log");
epoch 1 loss=8698.1416015625 lr=0.00970732606947422
epoch 100 loss=879.1326904296875 lr=0.006159865763038397
epoch 200 loss=270.08587646484375 lr=0.004554468207061291
epoch 300 loss=317.17718505859375 lr=0.0036662158090621233
epoch 400 loss=91.02045440673828 lr=0.0032094777561724186
epoch 500 loss=47.988033294677734 lr=0.0026605320163071156
epoch 600 loss=49.83475112915039 lr=0.0023816365282982588
epoch 700 loss=116.31311798095703 lr=0.0021545598283410072
epoch 800 loss=35.58114242553711 lr=0.0017672122921794653
epoch 900 loss=64.84205627441406 lr=0.0015509785152971745
epoch 1000 loss=10.535590171813965 lr=0.001485868007875979
epoch 1100 loss=26.29697036743164 lr=0.0013874215073883533
epoch 1200 loss=16.937255859375 lr=0.0013577654026448727
epoch 1300 loss=15.62310791015625 lr=0.0013214262435212731
epoch 1400 loss=13.242799758911133 lr=0.001257105148397386
epoch 1500 loss=11.61507797241211 lr=0.001235545496456325
epoch 1600 loss=14.939531326293945 lr=0.0012049111537635326
epoch 1700 loss=27.3996639251709 lr=0.001163122127763927
epoch 1800 loss=21.23556900024414 lr=0.0010698930127546191
epoch 1900 loss=32.21183776855469 lr=0.0010345007758587599
epoch 2000 loss=9.131669044494629 lr=0.001013571978546679
epoch 2100 loss=13.851628303527832 lr=0.0009823944419622421
epoch 2200 loss=10.54258918762207 lr=0.0009545115754008293
epoch 2300 loss=9.792181968688965 lr=0.000919334648642689
epoch 2400 loss=4.94473123550415 lr=0.000911602983251214
epoch 2500 loss=46.33889389038086 lr=0.0008922962588258088
epoch 2600 loss=7.025729179382324 lr=0.0008748381515033543
epoch 2700 loss=15.959792137145996 lr=0.0008621393935754895
epoch 2800 loss=48.963592529296875 lr=0.0008530952618457377
epoch 2900 loss=5.237392902374268 lr=0.0008423625258728862
epoch 3000 loss=15.789849281311035 lr=0.0008185050683096051
epoch 3100 loss=12.49224853515625 lr=0.0007375497953034937
epoch 3200 loss=5.135715961456299 lr=0.0007307584746740758
epoch 3300 loss=3.9630870819091797 lr=0.0006970498943701386
epoch 3400 loss=9.80554485321045 lr=0.0006649661809206009
epoch 3500 loss=4.468692302703857 lr=0.000663100800011307
epoch 3600 loss=48.7807502746582 lr=0.0006592243444174528
epoch 3700 loss=3.8573150634765625 lr=0.000651187205221504
epoch 3800 loss=5.408437728881836 lr=0.0006451716180890799
epoch 3900 loss=4.103397846221924 lr=0.0006489546503871679
epoch 4000 loss=7.220167636871338 lr=0.0006393020157702267
epoch 4100 loss=8.81872272491455 lr=0.0006356375524774194
epoch 4200 loss=3.7671053409576416 lr=0.0006279195658862591
epoch 4300 loss=3.080355167388916 lr=0.0006261372473090887
epoch 4400 loss=26.2066650390625 lr=0.0006219103815965354
epoch 4500 loss=2.989922046661377 lr=0.0006011227960698307
epoch 4600 loss=5.964871406555176 lr=0.00058519042795524
epoch 4700 loss=2.8790411949157715 lr=0.0005803073872812092
epoch 4800 loss=3.9400908946990967 lr=0.0005809014546684921
epoch 4900 loss=10.3778715133667 lr=0.0005587802152149379
epoch 5000 loss=5.031178951263428 lr=0.0005468810559250414

# .cpu() moves the PyTorch prediction to host memory first; .numpy() raises on
# a CUDA tensor (and .cpu() is a no-op when the model runs on the CPU).
print("MLPRegressor              ", mean_squared_error(clf.predict(X), y))
print("PyTorch                   ", mean_squared_error(torch_model(torch_X).detach().cpu().numpy(), y))
print("Flax                      ", mean_squared_error(flax_model.apply(flax_theta, X), y))
print(" + meta-learning          ", mean_squared_error(flax_model.apply(flax_ml_theta, X), y))
print(" + binomial loss, dropout ", mean_squared_error(predict(flax_model_2, flax_theta_2, X), y))
MLPRegressor               0.0011639113
PyTorch                    0.00062100013
Flax                       0.0005836923
 + meta-learning           0.0007262042
 + binomial loss, dropout  0.001051374
def _binomial_prediction(x):
    """Evaluate the binomial-loss Flax model (as probabilities) on standardized features."""
    return predict(flax_model_2, flax_theta_2, x)


draw(_binomial_prediction)

As a rough sanity check, the parameters of a well-trained model should not deviate too much from a normal distribution. We plot the parameter distributions of each layer of the last model.

# One row of histograms per layer: one panel per parameter array (kernel, and
# bias where the layer has one).
for i, layer in enumerate(flax_theta_2.values(), start=1):
    _, panels = plt.subplots(1, len(layer), figsize=(10, 4))
    # With a single parameter array, plt.subplots returns one Axes, not an array.
    for panel, (name, values) in zip(np.atleast_1d(panels), layer.items()):
        panel.hist(np.ravel(values), bins=20)
        panel.set_title(f"layer {i}: {name}")