How to use embedding nets for high-dimensional observations
Many simulators return high-dimensional outputs such as time series or images. To learn the posterior efficiently from such outputs, sbi can employ embedding networks that compress them into lower-dimensional, learned summary statistics.
sbi provides pre-configured embedding networks (MLP, CNN, and permutation-invariant networks) and also allows you to pass custom-written embedding networks.
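Conceptually, an embedding network is any `torch.nn.Module` that maps each simulation output to a much shorter feature vector. Its weights are trained jointly with the neural density estimator. The minimal sketch below uses plain PyTorch. It assumes a batch of time series with 500 samples each, compressed to 10 features. Both sizes and the layer widths are arbitrary illustrative choices, not sbi defaults.

```python
import torch
import torch.nn as nn

# Hypothetical batch: 100 simulations, each a time series of 500 samples.
x = torch.randn(100, 500)

# A plain MLP mapping each 500-dimensional output to 10 learned features.
# Layer sizes are illustrative, not sbi's defaults.
embedding = nn.Sequential(
    nn.Linear(500, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

summary = embedding(x)
print(summary.shape)  # torch.Size([100, 10])
```

The density estimator then conditions on these 10 features instead of on the raw 500 values.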
Only NPE and NRE methods can use an `embedding_net` to learn summary statistics from simulation outputs. NLE does not offer such functionality because the simulation outputs are also the output of the neural density estimator.
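To see why this restriction arises, look at where `x` enters the network. The toy sketch below is a hypothetical diagonal-Gaussian stand-in for a posterior estimator q(theta | x). It is not sbi's implementation, which uses normalizing flows such as `"maf"`.

In this sketch, `x` is only a conditioning input, so an embedding can compress it, and the embedding's weights receive gradients from the same loss. The classifier in NRE likewise takes `x` as an input. In NLE, by contrast, the network models the density of `x` itself, so `x` cannot be replaced by a learned summary.

```python
import torch
import torch.nn as nn


class ToyConditionalGaussian(nn.Module):
    """Hypothetical NPE-style estimator q(theta | x): a diagonal Gaussian over
    theta whose mean and log-std are predicted from embedded observations."""

    def __init__(self, embedding_net: nn.Module, embed_dim: int, theta_dim: int):
        super().__init__()
        self.embedding_net = embedding_net
        # Maps the learned summary statistics to Gaussian parameters for theta.
        self.head = nn.Linear(embed_dim, 2 * theta_dim)

    def log_prob(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # x only enters as the conditioning input, so it can be compressed.
        z = self.embedding_net(x)
        mean, log_std = self.head(z).chunk(2, dim=-1)
        # theta is the variable whose density is evaluated; it is not embedded.
        dist = torch.distributions.Normal(mean, log_std.exp())
        return dist.log_prob(theta).sum(dim=-1)


torch.manual_seed(0)
embedding = nn.Sequential(nn.Linear(500, 64), nn.ReLU(), nn.Linear(64, 10))
estimator = ToyConditionalGaussian(embedding, embed_dim=10, theta_dim=3)

theta = torch.randn(100, 3)  # hypothetical parameters
x = torch.randn(100, 500)    # hypothetical simulation outputs
loss = -estimator.log_prob(theta, x).mean()
loss.backward()

# The embedding is trained through the same loss as the density estimator.
print(embedding[0].weight.grad is not None)  # True
```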
Using pre-configured embedding networks
```python
# import required modules
from sbi.inference import NPE
from sbi.neural_nets import posterior_nn

# import the different choices of pre-configured embedding networks
from sbi.neural_nets.embedding_nets import (
    FCEmbedding,
    CNNEmbedding,
    PermutationInvariantEmbedding,
)

# Choose which type of pre-configured embedding net to use (e.g. a CNN for 32x32 images)
embedding_net = CNNEmbedding(input_shape=(32, 32))

# Instantiate the conditional neural density estimator
neural_posterior = posterior_nn(model="maf", embedding_net=embedding_net)

# Set up the inference procedure with NPE
trainer = NPE(density_estimator=neural_posterior)

# Continue as always...
```
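sbi's `PermutationInvariantEmbedding` is typically used when an observation consists of several i.i.d. trials whose order carries no information. The underlying idea, as in DeepSets, is to embed every trial with a shared network, pool over trials with a symmetric operation such as the mean, and map the pooled vector to features.

The sketch below is a hypothetical plain-PyTorch illustration of that principle, not sbi's class or its constructor arguments. It checks that shuffling the trials leaves the output unchanged.

```python
import torch
import torch.nn as nn


class ToyPermutationInvariantEmbedding(nn.Module):
    """Hypothetical DeepSets-style embedding: embed each trial with a shared
    network, average over trials, then map the pooled vector to features."""

    def __init__(self, trial_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        # Applied independently to every trial (shared weights).
        self.trial_net = nn.Sequential(nn.Linear(trial_dim, hidden_dim), nn.ReLU())
        # Applied once to the pooled representation.
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch, num_trials, trial_dim).
        per_trial = self.trial_net(x)
        # Mean pooling over the trial axis does not depend on trial order.
        pooled = per_trial.mean(dim=1)
        return self.fc(pooled)


torch.manual_seed(0)
net = ToyPermutationInvariantEmbedding(trial_dim=2, hidden_dim=16, output_dim=4)
x = torch.randn(8, 20, 2)  # 8 datasets, each with 20 trials of 2 values
shuffled = x[:, torch.randperm(20), :]

out = net(x)
print(out.shape)  # torch.Size([8, 4])
print(torch.allclose(out, net(shuffled), atol=1e-6))  # True
```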
Defining custom embedding networks
Alternatively, you can define a custom embedding network and pass it to the neural density estimator. For example, a custom CNN can be implemented as follows:
```python
import torch.nn as nn
import torch.nn.functional as F

from sbi.inference import NPE
from sbi.neural_nets import posterior_nn


class CustomCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # 2D convolutional layer (padding=2 keeps the 32x32 spatial size)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
        # Maxpool layer that reduces 32x32 image to 4x4
        self.pool = nn.MaxPool2d(kernel_size=8, stride=8)
        # Fully connected layer taking as input the 6 flattened output arrays
        # from the maxpooling layer
        self.fc = nn.Linear(in_features=6 * 4 * 4, out_features=8)

    def forward(self, x):
        # Reshape flat simulation outputs to (batch, channels, height, width)
        x = x.view(-1, 1, 32, 32)
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 6 * 4 * 4)
        x = F.relu(self.fc(x))
        return x


# instantiate the custom embedding_net
embedding_net_custom = CustomCNN()

# Instantiate the conditional neural density estimator
neural_posterior = posterior_nn(model="maf", embedding_net=embedding_net_custom)
trainer = NPE(density_estimator=neural_posterior)

# Continue as always...
```
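The layer sizes in `CustomCNN` follow from the standard output-size formula for convolution and pooling layers with dilation 1: floor((n + 2p - k) / s) + 1. The short pure-Python check below confirms that the flattened input to `self.fc` has 6 * 4 * 4 = 96 features. The helper name `conv_output_size` is hypothetical and not part of sbi or PyTorch.

```python
def conv_output_size(n: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size along one axis of a Conv2d/MaxPool2d layer
    (dilation 1, floor rounding as in PyTorch's default)."""
    return (n + 2 * padding - kernel_size) // stride + 1


side = 32
side = conv_output_size(side, kernel_size=5, padding=2)  # conv1 keeps 32
side = conv_output_size(side, kernel_size=8, stride=8)   # pooling: 32 -> 4
flat_features = 6 * side * side  # 6 output channels from conv1

print(side, flat_features)  # 4 96
```

If you change the image size or the pooling kernel, recompute `flat_features` this way and update `in_features` of the linear layer and the `view` call in `forward` to match.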
Example: Inferring parameters from images
For a full example on using embedding networks (on simulation outputs that are images), see this tutorial.