How to use embedding nets for high-dimensional observations

Many simulators return high-dimensional outputs such as time series or images. To learn the posterior efficiently from such outputs, sbi can use embedding networks. These compress each high-dimensional simulation output into a lower-dimensional vector of learned summary statistics before it reaches the neural density estimator.
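
To make "compress" concrete, here is a minimal sketch in plain PyTorch of an embedding network that maps a batch of 1000-step time series to 10 features each. It is not sbi's FCEmbedding; the class name SimpleMLPEmbedding and the layer sizes are hypothetical. A density estimator would then be conditioned on these 10 numbers instead of the raw 1000.

```python
import torch
import torch.nn as nn


# Hypothetical example (not sbi's FCEmbedding): a small MLP that compresses
# each simulation output (a time series with 1000 steps) to 10 features.
class SimpleMLPEmbedding(nn.Module):
    def __init__(self, input_dim: int = 1000, output_dim: int = 10, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch, input_dim); the result has shape (batch, output_dim)
        return self.net(x)


embedding = SimpleMLPEmbedding()
x = torch.randn(50, 1000)  # 50 fake simulation outputs with 1000 values each
features = embedding(x)
print(features.shape)  # torch.Size([50, 10])
```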

sbi provides pre-configured embedding networks (MLP, CNN, and permutation-invariant networks) and also lets you pass custom embedding networks.
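
Permutation-invariant networks are intended for simulation outputs that are sets of exchangeable elements, for example several i.i.d. trials per parameter setting. The sketch below shows the underlying idea in plain PyTorch. It is a DeepSets-style mean pooling and not sbi's PermutationInvariantEmbedding; the class name and all sizes are assumptions. Each trial is embedded separately, the trial embeddings are averaged, and the average is passed through a second network, so reordering the trials leaves the output unchanged.

```python
import torch
import torch.nn as nn


class MeanPoolingSetEmbedding(nn.Module):
    """Hypothetical DeepSets-style embedding: phi per trial, mean over trials, then rho."""

    def __init__(self, trial_dim: int = 3, hidden_dim: int = 32, output_dim: int = 8):
        super().__init__()
        # phi is applied to every trial independently
        self.phi = nn.Sequential(nn.Linear(trial_dim, hidden_dim), nn.ReLU())
        # rho is applied to the pooled, permutation-invariant representation
        self.rho = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch, num_trials, trial_dim)
        per_trial = self.phi(x)         # (batch, num_trials, hidden_dim)
        pooled = per_trial.mean(dim=1)  # (batch, hidden_dim); trial order is irrelevant
        return self.rho(pooled)         # (batch, output_dim)


torch.manual_seed(0)
net = MeanPoolingSetEmbedding()
x = torch.randn(4, 20, 3)  # 4 simulations, 20 trials each, 3 values per trial
perm = torch.randperm(20)  # a random reordering of the trials
out = net(x)
out_permuted = net(x[:, perm, :])
print(torch.allclose(out, out_permuted, atol=1e-6))  # True
```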

Only NPE and NRE methods can use an embedding_net to learn summary statistics from simulation outputs. NLE does not offer this functionality. In NLE the simulation outputs are not the input of the neural density estimator but its output: NLE models the density of the simulation outputs given the parameters.
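
This asymmetry shows up clearly in a toy conditional density estimator. The sketch below is written in plain PyTorch. A diagonal Gaussian stands in for the flows sbi uses, and the names ToyConditionalGaussian and head are hypothetical. The embedding network only touches the conditioning input x, while the density is evaluated over the parameters theta. In NLE the roles swap: x becomes the variable whose density is evaluated, so there is no conditioning input for an embedding to compress. In this sketch the embedding also receives gradients from the same negative log-likelihood loss, so its summary statistics are learned together with the density estimator.

```python
import torch
import torch.nn as nn


class ToyConditionalGaussian(nn.Module):
    """Hypothetical q(theta | x): a diagonal Gaussian whose parameters depend on embedding(x)."""

    def __init__(self, embedding_net: nn.Module, embedding_dim: int, theta_dim: int):
        super().__init__()
        self.embedding_net = embedding_net
        # Maps the learned summary statistics to a Gaussian mean and log-scale
        self.head = nn.Linear(embedding_dim, 2 * theta_dim)

    def log_prob(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        summary = self.embedding_net(x)  # the embedding acts on the condition x only
        mean, log_scale = self.head(summary).chunk(2, dim=-1)
        dist = torch.distributions.Normal(mean, log_scale.exp())
        return dist.log_prob(theta).sum(dim=-1)  # one log-density per (theta, x) pair


torch.manual_seed(0)
embedding_net = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 16), nn.ReLU())
estimator = ToyConditionalGaussian(embedding_net, embedding_dim=16, theta_dim=2)

theta = torch.randn(8, 2)   # 8 parameter vectors
x = torch.randn(8, 32, 32)  # the 8 matching "images"
loss = -estimator.log_prob(theta, x).mean()
loss.backward()

# The embedding receives gradients from the density estimator's loss
print(embedding_net[1].weight.grad is not None)  # True
```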

Using pre-configured embedding networks

# import required modules
from sbi.inference import NPE
from sbi.neural_nets import posterior_nn

# import the different choices of pre-configured embedding networks
from sbi.neural_nets.embedding_nets import (
    FCEmbedding,
    CNNEmbedding,
    PermutationInvariantEmbedding
)

# Choose which type of pre-configured embedding net to use (e.g. CNN)
embedding_net = CNNEmbedding(input_shape=(32, 32))

# Instantiate the conditional neural density estimator
neural_posterior = posterior_nn(model="maf", embedding_net=embedding_net)

# Set up the inference procedure with NPE
trainer = NPE(density_estimator=neural_posterior)
# Continue as always...

Defining custom embedding networks

Alternatively, you can define a custom embedding network and pass it to the neural density estimator. For example, a custom CNN can be implemented as follows:

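import torch.nn as nn
import torch.nn.functional as F
from sbi.inference import NPE
from sbi.neural_nets import posterior_nn


class CustomCNN(nn.Module):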
    def __init__(self):
        super().__init__()
        # 2D convolutional layer
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
        # Max-pooling layer that reduces each 32x32 feature map to 4x4
        self.pool = nn.MaxPool2d(kernel_size=8, stride=8)
        # Fully connected layer taking as input the 6 flattened 4x4 feature maps
        # from the max-pooling layer and returning 8 summary features
        self.fc = nn.Linear(in_features=6 * 4 * 4, out_features=8)

    def forward(self, x):
        # Reshape each simulation output to (batch, channels=1, 32, 32)
        x = x.view(-1, 1, 32, 32)
        x = self.pool(F.relu(self.conv1(x)))
        # Flatten the 6 feature maps of size 4x4 into one vector per sample
        x = x.view(-1, 6 * 4 * 4)
        x = F.relu(self.fc(x))
        return x

# instantiate the custom embedding_net
embedding_net_custom = CustomCNN()

# Instantiate the conditional neural density estimator
neural_posterior = posterior_nn(model="maf", embedding_net=embedding_net_custom)
trainer = NPE(density_estimator=neural_posterior)
# Continue as always...
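
The hard-coded sizes in CustomCNN (6 * 4 * 4 in fc and in forward) follow from the standard output-size formula for convolution and pooling layers, floor((n + 2p - k) / s) + 1. The sketch below checks that arithmetic and confirms it against the same layer configuration in plain PyTorch. The helper out_size is hypothetical. A check like this is a quick way to catch shape mismatches before passing a custom network to posterior_nn.

```python
import torch
import torch.nn as nn


def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv or pooling layer (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


# conv1: kernel 5 with padding 2 keeps 32x32; pool: kernel 8, stride 8 gives 4x4
after_conv = out_size(32, kernel=5, stride=1, padding=2)
after_pool = out_size(after_conv, kernel=8, stride=8)
print(after_conv, after_pool)  # 32 4

# Same layer settings as CustomCNN, applied to a batch of two 32x32 images
layers = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=8, stride=8),
)
feature_maps = layers(torch.randn(2, 1, 32, 32))
print(feature_maps.shape)       # torch.Size([2, 6, 4, 4])
print(feature_maps[0].numel())  # 96, i.e. 6 * 4 * 4, the in_features of fc
```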

Example: Inferring parameters from images

For a full example of using embedding networks on simulation outputs that are images, see this tutorial.