Fast deep sets with FLAX

Published

January 11, 2024

Deep Sets, a network architecture with interesting applications to particle physics, were introduced by Zaheer et al., NIPS 2017. One of the first applications in particle physics is discussed in P. T. Komiske, E. M. Metodiev, J. Thaler, MIT-CTP 5064.

In particle physics, we often want to use mappings from events to numbers. For example, we could assign the label 1 if the event contained an interesting signal, and 0 otherwise. Each event contains a variable number of tracks. This is an obstacle for applying normal feed-forward networks, which cannot handle inputs of varying size.

Deep Sets can learn the mapping from a set of inputs to an output. The size of the set can vary from sample to sample. The architecture assumes that the mapping is invariant to permutations of the elements in the set. This is the case in the aforementioned example application. Although we could introduce an ordering over the set (for example, sort the tracks by momentum), such an ordering is arbitrary.

Mathematically, a Deep Set architecture consists of two feed-forward networks \(\rho\) and \(\phi\) and a pooling operation that is invariant to permutations. One typically uses the sum, \[ y(\{x_i\}) = \rho\left(\sum_i \phi(x_i)\right), \] but other operations, like the maximum, are possible. Note that the same network \(\phi\) is applied to every element \(x_i\) of the set.
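To make the formula concrete, here is a minimal NumPy sketch of a Deep Set with sum pooling. The two "networks" are hypothetical stand-ins (a single random linear layer each, with a ReLU in `phi`), chosen only to illustrate that the output does not depend on the order of the elements and that sets of different sizes pass through the same function.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical tiny networks: one random linear layer each (not trained).
W_phi = rng.normal(size=(2, 8))  # maps each 2-vector element to 8 latent features
W_rho = rng.normal(size=(8, 1))  # maps the pooled latent vector to a scalar

def phi(x):
    # Applied to every element of the set independently.
    return np.maximum(x @ W_phi, 0.0)

def rho(z):
    return (z @ W_rho).item()

def deep_set(x):
    # x has shape (n_elements, 2); sum pooling removes the element axis.
    return rho(phi(x).sum(axis=0))

x = rng.normal(size=(5, 2))
shuffled = x[rng.permutation(len(x))]

# The same elements in a different order give the same output (up to rounding).
print(np.isclose(deep_set(x), deep_set(shuffled)))
# A set of a different size is handled by the same function.
print(deep_set(rng.normal(size=(3, 2))))
```

Replacing `.sum(axis=0)` with `.max(axis=0)` gives the maximum-pooling variant, which is permutation invariant as well.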

Deep Sets are a very promising architecture, but implementing them in JAX / FLAX is a challenge. Here I show a fast implementation, the first that I am aware of, which is able to handle sets of variable length. Other implementations I found on GitHub only work for sets of constant size.

The main challenge to solve was to write the code in such a way that JAX can process input and output as tensors of fixed size. This is a requirement for the JIT compiler and to obtain a fast training loop. The compilation generally requires that all shapes are of fixed size, which contradicts the idea of passing sets of variable size to the network.

My solution consists of three parts.

Efficient padding

JAX wants inputs of fixed size. We could decide to process one set at a time and pad each set to the size of the largest possible set. This would be wasteful in our application, however, because the distribution of the number of tracks per event has a long tail. Most events have only a few tracks, but occasionally an event has many. Padding all events up to the largest size would be very inefficient.

My solution is to let the model process several events at once. To do this, I concatenate all the sets into a large contiguous block. This block can be padded to a fixed size. If the block is large enough, the sizes of different events average out, and we only need to pad a small fraction of the block size.
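The following sketch compares the two padding strategies on made-up numbers. The set sizes, the exponential scale of 20, and the choice to pad the block to a multiple of 1024 are all assumptions for illustration; they are not taken from real event data.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical set sizes with a long tail, at least one element per set.
sizes = rng.exponential(scale=20.0, size=1000).astype(int) + 1

# Option 1: pad every set to the size of the largest set.
per_set_total = len(sizes) * sizes.max()
per_set_waste = 1 - sizes.sum() / per_set_total

# Option 2: concatenate all sets into one block and pad only the block,
# here up to the next multiple of 1024 (an arbitrary choice).
block_size = int(np.ceil(sizes.sum() / 1024) * 1024)
block_waste = 1 - sizes.sum() / block_size

print(f"padding each set:  {per_set_waste:.2f} of the entries are padding")
print(f"padding the block: {block_waste:.2f} of the entries are padding")
```

With a long-tailed size distribution, the first number is dominated by the rare large sets, while the second is bounded by the ratio of the rounding step to the block size.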

Handling the sum

Next, we need to replace the sum, because it operates on dynamically sized chunks, and we can’t have those inside a JIT’ed function. My solution is to replace it with a matrix product. All the inputs are already in one large block. We pass those in parallel to the first network \(\phi\). The output is another large block of latent vectors, one for each element in each set. To sum these over the original sets, I use matrices constructed like this one \[ \begin{pmatrix} 1 & 1 & 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 1 & 1 & 0 \\ 0 & 0 & 0 & 0 & 0 & 1 \\ \end{pmatrix}, \] which, when applied to a block with six vectors from three sets, computes the sum over the first set with three vectors, the second with two, and the third with one vector. The output consists of three latent vectors, which are further processed in parallel by the second network \(\rho\). The summation matrix is also padded with zeros to match the padded input block. Finally, when the loss is computed, we discard the padded outputs by multiplying them with zeros.
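As a small check of this idea, the sketch below builds the example matrix above with dense NumPy arrays and verifies that multiplying it with a block of latent vectors gives the per-set sums. The latent vectors are random stand-ins for the output of \(\phi\), and the padded shapes (4 rows, 8 columns) are arbitrary choices for illustration.

```python
import numpy as np

# Sizes of the three sets from the example matrix.
n = [3, 2, 1]

# Row j has ones in the columns that belong to set j.
set_index = np.repeat(np.arange(len(n)), n)  # [0, 0, 0, 1, 1, 2]
S = np.zeros((len(n), sum(n)), dtype=np.float32)
S[set_index, np.arange(sum(n))] = 1

# Stand-in for the output of phi: six latent vectors of size four.
rng = np.random.default_rng(0)
latent = rng.normal(size=(sum(n), 4)).astype(np.float32)

pooled = S @ latent  # shape (3, 4): one summed latent vector per set

# Reference: sum each set separately with ordinary slicing.
bounds = np.cumsum([0] + n)
expected = np.stack(
    [latent[a:b].sum(axis=0) for a, b in zip(bounds[:-1], bounds[1:])]
)
print(np.allclose(pooled, expected, atol=1e-6))

# Padding: rows appended to the block are ignored by zero columns in S,
# and extra zero rows in S produce zero outputs that the loss mask removes.
S_pad = np.zeros((4, 8), dtype=np.float32)
S_pad[:3, :6] = S
latent_pad = np.concatenate([latent, np.ones((2, 4), dtype=np.float32)])
pooled_pad = S_pad @ latent_pad
print(np.allclose(pooled_pad[:3], pooled, atol=1e-6), np.all(pooled_pad[3] == 0))
```

Because the shapes of `S_pad` and `latent_pad` do not depend on how many elements each set has, the same compiled function can be reused for every batch.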

Using sparse matrices

The summation matrix is sparse: each column contains at most a single one. We use JAX's experimental support for sparse matrices (jax.experimental.sparse) to speed up the computation. In this toy example, I observed a speedup by a factor of two.
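A minimal sketch of the sparse representation, assuming a JAX version that ships `jax.experimental.sparse` with the `BCOO` format. It converts the small dense example matrix with `BCOO.fromdense` and checks that the sparse and dense products agree; the latent vectors are again random stand-ins.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse

# Dense summation matrix for three sets of sizes 3, 2, and 1.
S_dense = jnp.array(
    [[1, 1, 1, 0, 0, 0],
     [0, 0, 0, 1, 1, 0],
     [0, 0, 0, 0, 0, 1]],
    dtype=jnp.float32,
)

# BCOO stores only the nonzero values and their (row, column) indices.
S_sparse = sparse.BCOO.fromdense(S_dense)

latent = jnp.asarray(
    np.random.default_rng(0).normal(size=(6, 4)), dtype=jnp.float32
)

# The sparse product gives the same pooled vectors as the dense one.
print(jnp.allclose(S_sparse @ latent, S_dense @ latent, atol=1e-6))
print(S_sparse.nse, "stored values instead of", S_dense.size)
```

For large blocks it is better to construct the `BCOO` matrix directly from the index pairs, so that the dense matrix never has to be materialized.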

The code below demonstrates these ideas on a toy example.

import jax
import jax.numpy as jnp
from jax.experimental import sparse
from flax import linen as nn
import numpy as np
import optax
from typing import Collection, Callable, Optional

# We generate some toy data: `size` sets whose sizes are drawn from an
# exponential distribution and rescaled so that the largest set has about
# 400 elements. Each set has elements which are vectors of size two.
rng = np.random.default_rng(1)

def generate(rng, size):
    n = rng.exponential(size=size)
    n *= 400 / np.max(n)
    n += 1  # ensure there are no empty sets
    n = n.astype(int)

    # The mapping to learn is a non-linear function of the inputs. One can also
    # replace np.sum with np.mean or np.max.
    def true_mapping(x):
        return np.log1p(np.abs(np.sum(x[:, 0] ** 2 + 3 * x[:, 1])))

    X = [np.array(rng.normal(size=(ni, 2)), dtype=np.float32) for ni in n]
    y = [true_mapping(x) for x in X]

    return X, y

def preprocess(X, y, padded_X_size, padded_y_size, dtype=np.float32):
    """
    Concatenates the input sets and pads inputs and outputs to fixed sizes.

    Returns padded arrays, the summation matrix, and a mask to undo the padding.
    """
    assert padded_y_size >= len(y)
    n = [len(x) for x in X]
    mask = np.zeros(padded_y_size, dtype=bool)
    mask[:len(y)] = 1
    assert padded_X_size >= np.sum(n)
    y = np.concatenate([y, np.zeros(padded_y_size - len(y), dtype=dtype)])
    X = np.concatenate(X, dtype=dtype)
    X = np.concatenate([X, np.zeros((padded_X_size - len(X),) + X.shape[1:], dtype=dtype)])
    indices = np.empty((np.sum(n), 2), dtype=int)
    a = 0
    for j, b in enumerate(np.cumsum(n)):
        indices[a:b, 0] = j
        indices[a:b, 1] = np.arange(a, b)
        a = b
    sum_matrix = sparse.BCOO(
        (np.ones(len(indices), dtype=np.int8), indices),
        shape=(padded_y_size, padded_X_size),
        indices_sorted=True, unique_indices=True)
    print(f"X fractional overhead {np.mean(np.sum(sum_matrix.todense(), axis=0) == 0):.2f}")
    print(f"y fractional overhead {np.mean(~mask):.2f}")
    X = jnp.array(X)
    y = jnp.array(y)
    return X, y, sum_matrix, mask

X_train, y_train = generate(rng, 100)
X_train, y_train, sum_train, mask_train = preprocess(X_train, y_train, 5_000, 100)

X_test, y_test = generate(rng, 100)
X_test, y_test, sum_test, mask_test = preprocess(X_test, y_test, sum(len(x) for x in X_test), len(y_test))
X fractional overhead 0.00
y fractional overhead 0.00
X fractional overhead 0.00
y fractional overhead 0.00

We train the model until the loss on the test sample is lower than 0.1. Achieving a better loss is difficult with such a small training sample.

class MLP(nn.Module):
    nodes: Collection[int]
    nonlin: Callable
    output: Optional[int] = None

    @nn.compact
    def __call__(self, x):
        for size in self.nodes:
            x = nn.Dense(size)(x)
            x = self.nonlin(x)
        if self.output is not None:
            return nn.Dense(self.output)(x)
        return x

class Model(nn.Module):
    phi_nodes: Collection[int]
    rho_nodes: Collection[int]

    def setup(self):
        self.phi = MLP(self.phi_nodes, nn.relu)
        self.rho = MLP(self.rho_nodes, nn.relu, 1)

    @nn.compact
    def __call__(self, x, sum_matrix):
        x = self.phi(x)
        y = sum_matrix @ x
        y = self.rho(y)
        return y.flatten()

rng_key = jax.random.PRNGKey(0)

model = Model((256,) * 6, (256,) * 6)
theta = model.init(rng_key, X_train, sum_train)
# model.tabulate does not work with BCOO
# print(model.tabulate(rng_key, X_train, sum_train))

opt = optax.adam(learning_rate=1e-3)
opt_state = opt.init(theta)

@jax.jit
def loss_fn(theta, X, y, sum_matrix, mask):
    yp = model.apply(theta, X, sum_matrix)
    # Average only over real outputs; the mask zeroes out padded entries.
    return jnp.sum(mask * (y - yp) ** 2) / jnp.sum(mask)

@jax.jit
def step(theta, opt_state, X, y, sum_matrix, mask):
    loss, grad = jax.value_and_grad(loss_fn)(theta, X, y, sum_matrix, mask)
    updates, opt_state = opt.update(grad, opt_state)
    theta = optax.apply_updates(theta, updates)
    return loss, theta, opt_state

best_loss = np.inf
best_theta = theta
train_loss = []
test_loss = []
for epoch in range(1, 5000):
    loss, theta, opt_state = step(theta, opt_state,
                                  X_train, y_train, 
                                  sum_train, mask_train)
    train_loss.append(loss)

    loss = loss_fn(theta, X_test, y_test, sum_test, mask_test)
    test_loss.append(loss)

    if test_loss[-1] < best_loss:
        best_loss = test_loss[-1]
        best_theta = theta

    stop = best_loss < 0.1
    if epoch == 1 or stop or epoch % 50 == 0:
        print(f"epoch = {epoch:5} "
              f"loss(train) = {train_loss[-1]:6.3f} "
              f"loss(test) = {test_loss[-1]:6.3f}")
    if stop:
        break
epoch =     1 loss(train) = 23.354 loss(test) =  7.235
epoch =    50 loss(train) =  2.143 loss(test) =  2.690
epoch =   100 loss(train) =  0.434 loss(test) =  0.438
epoch =   149 loss(train) =  0.019 loss(test) =  0.095

I plot the training and testing loss as a function of the epoch. Adam does a good job of minimizing the loss without further learning-rate scheduling. I tried several other optimizers; Adam performed best out of those which worked (some optimizers from optax fail on this problem with an error message).

from matplotlib import pyplot as plt

plt.figure()
plt.plot(train_loss, label="train")
plt.plot(test_loss, label="test")
plt.legend()
plt.xlabel("epoch")
plt.ylabel("loss")
plt.semilogy();

To show that the network managed to learn the mapping, I compare the predictions on the test set with the corresponding true values. While there are still sizable residuals, the model has learned to approximate the true mapping. Improving the model further would require a larger training sample.

yp = model.apply(best_theta, X_test, sum_test)

fig, ax = plt.subplots(2, 1, sharex=True, sharey=False)
ax[0].plot(y_test, ".", label="input")
ax[0].plot(yp, "o", mfc="none", label="model")
ax[0].legend()
res = y_test - yp
ax[1].plot(res, ".")
ax[1].set_ylabel("residual")
ax[1].set_xlabel("set index")

plt.figure()
plt.hist(res, bins=20, density=True)
plt.title(f"standard deviation = {np.std(res):.2f}");