Source code for sbi.neural_nets.embedding_nets.permutation_invariant
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import Optional
import torch
from torch import Tensor, nn
from sbi.neural_nets.embedding_nets.fully_connected import FCEmbedding
[docs]
class PermutationInvariantEmbedding(nn.Module):
    """Permutation invariant embedding network.

    Takes as input a tensor with (batch, permutation_dim, input_dim)
    and outputs (batch, output_dim).

    References:
        Chan et al. (2018): "A likelihood-free inference framework for population
        genetic data using exchangeable neural networks"
        Radev et al. (2020): "BayesFlow: Learning complex stochastic models with
        invertible neural networks"
    """

    def __init__(
        self,
        trial_net: nn.Module,
        trial_net_output_dim: int,
        aggregation_fn: Optional[str] = "sum",
        num_hiddens: int = 100,
        num_layers: int = 2,
        output_dim: int = 20,
        aggregation_dim: int = 1,
    ):
        """Permutation invariant multi-layer NN.

        Applies the trial_net to every trial to obtain trial embeddings.
        It then aggregates the trial embeddings across the aggregation dimension to
        construct a permutation invariant embedding across iid trials.
        The resulting embedding is processed further using an additional fully
        connected net. The input to the final embedding net is the aggregated
        trial embedding plus the number of trials N:
        (batch, trial_net_output_dim + 1).

        If the data x has varying number of trials per batch element, missing trials
        should be encoded by setting all of their entries to NaN. In the forward
        pass, these trials are masked out.

        Args:
            trial_net: Network to process the trials. It is called once on the
                whole input (with NaNs replaced by zeros), i.e. on
                (batch, num_trials, input_dim), and should act on every trial
                independently, returning (batch, num_trials, trial_net_output_dim).
                The aggregation_fn is applied to its output.
                Remark: This network should be large enough as it acts on all (iid)
                inputs separately and needs enough capacity to process the
                information of all inputs.
            trial_net_output_dim: Dimensionality of the output of the trial_net.
            aggregation_fn: Function to aggregate the trial embeddings, either
                "sum" (default) or "mean". Both only use the non-missing trials.
            num_hiddens: Number of hidden dimensions in fully-connected layers.
            num_layers: Number of fully connected layer, minimum of 2.
            output_dim: Dimensionality of the output.
            aggregation_dim: Dimension along which to aggregate the trial
                embeddings, i.e. the trial dimension of the input.
        """
        super().__init__()
        self.trial_net = trial_net
        self.aggregation_dim = aggregation_dim
        assert aggregation_fn in [
            "mean",
            "sum",
        ], "aggregation_fn must be 'mean' or 'sum'."
        self.aggregation_fn = aggregation_fn

        # construct fully connected layers
        self.fc_subnet = FCEmbedding(
            input_dim=trial_net_output_dim + 1,  # +1 to encode number of trials
            output_dim=output_dim,
            num_layers=num_layers,
            num_hiddens=num_hiddens,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Network forward pass.

        Args:
            x: Input tensor (batch_size, permutation_dim, input_dim). Missing
                trials are encoded by setting all of their entries to NaN.
                Trials with more than one trailing dimension are supported as
                long as trial_net maps the input to
                (batch_size, permutation_dim, trial_net_output_dim).

        Returns:
            Network output (batch_size, output_dim).

        Raises:
            AssertionError: If the aggregated embedding contains NaNs, e.g. when
                aggregation_fn is "mean" and a batch element has no observed
                trial (0 / 0).
        """
        # Normalize so that a negative aggregation dim indexes the same axis in
        # every operation below.
        trial_dim = self.aggregation_dim % x.dim()

        # A trial is missing when *all* of its entries are NaN. Flattening the
        # per-trial entries yields a (batch, num_trials) mask for any trial shape.
        trial_is_missing = torch.isnan(x).flatten(start_dim=trial_dim + 1).all(-1)
        # Count the observed trials separately for every batch element.
        num_trials = (~trial_is_missing).sum(dim=trial_dim)  # (batch,)

        # Embed all trials at once with NaNs replaced by zeros, then zero out the
        # embeddings of missing trials so that they do not enter the aggregate.
        # masked_fill (rather than multiplying by a 0/1 mask) also discards any
        # inf/NaN the trial_net may produce for the zero-filled placeholders.
        trial_embeddings = self.trial_net(torch.nan_to_num(x, nan=0.0))
        trial_embeddings = trial_embeddings.masked_fill(
            trial_is_missing.unsqueeze(-1), 0.0
        )
        combined_embedding = trial_embeddings.sum(dim=trial_dim)

        # (batch, 1) in the embedding dtype, so it can be used both as divisor
        # and as an extra input feature.
        trial_counts = num_trials.unsqueeze(-1).to(combined_embedding.dtype)

        if self.aggregation_fn == "mean":
            # Divide by the number of observed trials (instead of just taking
            # torch.mean) to account for masking.
            combined_embedding = combined_embedding / trial_counts

        assert not torch.isnan(combined_embedding).any(), "NaNs in embedding."

        # add number of trials as additional input
        return self.fc_subnet(torch.cat([combined_embedding, trial_counts], dim=1))