# Source code for sbi.neural_nets.embedding_nets.resnet

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from functools import partial
from typing import Callable, Dict, List, Optional, Type, Union

import torch
import torch.nn as nn
from torch import Tensor


def zero_pad(x: Tensor, c_out: int) -> Tensor:
    """
    Add additional channels to the input tensor by padding it with zeros.

    The zero channels are appended after the existing channels along
    dimension 1 and share the dtype and device of `x`.

    Args:
        x:      Input tensor with 4 dimensions; dimension 1 is treated as the
                channel dimension
        c_out:  Number of output channels, must be larger than `x.shape[1]`

    Returns:
        x_aug:      Input tensor with additional channels filled with zeros

    Raises:
        ValueError: If `x` is not a 4D tensor or if `c_out` is not larger
                    than the number of channels of `x`.
    """

    # Check if it is an image like-input
    if x.dim() != 4:
        raise ValueError("Only 4D input tensors are supported.")

    batch_size, c_in, height, width = x.shape

    # Check if the number of channels is increased
    if c_out <= c_in:
        raise ValueError("c_out must be larger than c_in to apply zero padding.")

    # new_zeros inherits dtype and device from x. A plain torch.zeros would
    # always be float32 and thereby change (or break) the dtype of the result
    # for half precision or integer inputs.
    padding = x.new_zeros(batch_size, c_out - c_in, height, width)

    x_aug = torch.cat([x, padding], dim=1)

    return x_aug


def construct_simple_conv_net(
    c_in: int,
    c_out: int,
    c_hidden: int,
    activation: Type[nn.Module],
    downsample: bool,
) -> nn.Module:
    """
    Build the default three-layer convolutional mapping.

    All convolutions use 3x3 kernels with a padding of 1. Only the first
    convolution can be strided; an activation follows the first and the
    second convolution, but not the last one.

    Args:
        c_in:       Number of input channels
        c_out:      Number of output channels
        c_hidden:   Number of hidden channels
        activation: Constructor for an activation function
        downsample: Use a stride of 2 in the first convolution to halve
                    the height and width of the input

    Returns:
        layers:     A sequential container with the convolutional layers
    """

    # The spatial resolution is only ever reduced by the first convolution
    if downsample:
        first_stride = 2
    else:
        first_stride = 1

    # Shared settings of all three convolutions
    conv3x3 = partial(nn.Conv2d, kernel_size=3, padding=1)

    modules = [
        conv3x3(c_in, c_hidden, stride=first_stride),
        activation(),
        conv3x3(c_hidden, c_hidden, stride=1),
        activation(),
        conv3x3(c_hidden, c_out, stride=1),
    ]

    return nn.Sequential(*modules)


def construct_simple_fc_net(
    c_in: int, c_out: int, c_hidden: int, activation: Type[nn.Module]
) -> nn.Module:
    """
    Build the default fully connected mapping with three linear layers.

    An activation follows the first and the second linear layer, but not
    the last one.

    Args:
        c_in:       Number of input channels
        c_out:      Number of output channels
        c_hidden:   Number of hidden channels
        activation: Constructor for an activation function

    Returns:
        layers:     A sequential container with the fully connected layers
    """

    modules = []

    # Two hidden layers, each followed by its own activation instance
    for n_in, n_out in ((c_in, c_hidden), (c_hidden, c_hidden)):
        modules.append(nn.Linear(n_in, n_out))
        modules.append(activation())

    # Linear read-out without an activation
    modules.append(nn.Linear(c_hidden, c_out))

    return nn.Sequential(*modules)


class ResNetEmbedding2D(nn.Module):
    """Residual neural network mapping image-like data to a fixed sized vector."""

    def __init__(
        self,
        c_in: int,
        c_out: int = 20,
        c_hidden_fc: int = 1000,
        n_stages: int = 4,
        blocks_per_stage: Optional[List] = None,
        c_stages: Optional[List] = None,
        activation: Type[nn.Module] = nn.ReLU,
        change_c_mode: str = "conv",
        construct_mapping: Callable = construct_simple_conv_net,
        construct_mapping_kwargs: Optional[Dict] = None,
        residual_block_kwargs: Optional[Dict] = None,
    ) -> None:
        """Residual neural network mapping image-like data to a fixed sized vector.

        This network consists of a stacked set of residual blocks. The output of
        a block is given by the input to the block plus the transformed input,
        i.e. there is a skip connection. The network is structured in stages.
        The last block of every stage except the final one halves the height
        and the width of its input and switches to the number of channels of
        the next stage. Each stage can have a different number of residual
        blocks and a different number of channels.

        Image-like input is expected, i.e., a 4D tensor with dimensions
        (batch_size, channels, height, width) or a 3D tensor with dimensions
        (batch_size, height, width). In the case of 3D input the input is
        internally transformed to a 4D tensor with dimensions
        (batch_size, 1, height, width).

        The image like input data is transformed into a fixed sized vector of
        dimensions [batch_size, c_out]. By default, convolutional networks are
        used to model the transformation in each of the residual blocks. At the
        end of the network, a fully connected network is applied to map the
        flattened output of the last stage to the fixed sized output vector.

        Note:
            The final fully connected network is created during the first call
            of `forward`, because its input size depends on the spatial size of
            the data. Its parameters are therefore not part of `parameters()`
            or `state_dict()` before the first forward pass.

        References:
            He et al. (2015): "Deep Residual Learning for Image Recognition"

        Args:
            c_in: Number of input channels.
            c_out: Dimensionality of the embedding vector.
            c_hidden_fc: Number of hidden units in the fully connected layers.
            n_stages: Number of stages in the network.
            blocks_per_stage: Number of residual blocks per stage. Defaults to
                `[2, 2, 2, 2]`.
            c_stages: Number of channels per stage. Defaults to
                `[64, 128, 256, 512]`.
            activation: Constructor for the activation function
            change_c_mode: Mode to change the number of channels. Options are
                "conv" and "zeros". If "conv" is selected, 1x1 convolutions are
                applied to change the number of channels. If "zeros" is
                selected, the required additional channels are filled with
                zeros.
            construct_mapping: Constructor for the mapping function. The
                mapping function is applied to the input tensor before the
                residual connection is added. The default is a simple
                convolutional network. The function must have the signature
                `construct_mapping(c_in, c_out, activation, **kwargs)`; a
                `downsample` keyword is passed for every block as well.
            construct_mapping_kwargs: Additional keyword arguments for the
                mapping functions. Defaults to `{"c_hidden": 128}`. The passed
                dictionary is not modified.
            residual_block_kwargs: Additional keyword arguments for the
                initialization of the residual blocks.

        Raises:
            ValueError: If the length of `blocks_per_stage` or `c_stages` does
                not match `n_stages`.
        """
        super().__init__()

        # Additional keyword arguments for initializing the mapping function.
        # Work on a copy so that the dictionary of the caller is never
        # modified by the entries added below.
        if construct_mapping_kwargs is None:
            mapping_kwargs = {"c_hidden": 128}
        else:
            mapping_kwargs = dict(construct_mapping_kwargs)
        mapping_kwargs["activation"] = activation

        # Additional keyword arguments for initializing the residual blocks
        if residual_block_kwargs is None:
            residual_block_kwargs = {}

        # Number of residual blocks in each stage
        if blocks_per_stage is None:
            blocks_per_stage = [2, 2, 2, 2]

        # Number of channels in each stage
        if c_stages is None:
            c_stages = [64, 128, 256, 512]

        # Check consistency of the specified network structure
        if len(blocks_per_stage) != n_stages:
            raise ValueError(
                "The number of stages and the number of specified block must match."
            )
        if len(c_stages) != n_stages:
            raise ValueError(
                "The number of stages and the number of specified channels must match."
            )

        self.c_hidden_fc = c_hidden_fc
        self.c_out = c_out
        self.activation = activation

        # Initial transformation applied before the residual blocks
        self.initial = nn.Sequential(
            nn.Conv2d(c_in, c_stages[0], kernel_size=7, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # Initialize the residual blocks
        blocks = []
        for i in range(n_stages):
            for j in range(blocks_per_stage[i]):
                # Last block in the stage is a downsampling block which
                # increases the number of channels, except for the final stage
                downsample = (j == blocks_per_stage[i] - 1) and (i != n_stages - 1)
                c_out_ij = c_stages[i + 1] if downsample else c_stages[i]

                block_ij = ResidualBLock(
                    is_conv_block=True,
                    c_in=c_stages[i],
                    c_out=c_out_ij,
                    downsample=downsample,
                    change_c_mode=change_c_mode,
                    construct_mapping=construct_mapping,
                    # Every block gets its own dictionary
                    construct_mapping_kwargs={
                        **mapping_kwargs,
                        "downsample": downsample,
                    },
                    activation=activation,
                    **residual_block_kwargs,
                )
                blocks.append(block_ij)

        self.blocks = nn.Sequential(*blocks)

        # Final transformation is initialized on the first forward pass
        self.final = None

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the residual network.

        Args:
            x:  Input tensor with dimensions (batch_size, channels, height,
                width) or (batch_size, height, width)

        Returns:
            z:  Output tensor with dimensions (batch_size, c_out)

        Raises:
            ValueError: If the input is neither a 3D nor a 4D tensor.
        """

        # Only three dimensions, interpret as one channel and add it
        if x.dim() == 3:
            x = x.unsqueeze(1)

        if x.dim() != 4:
            raise ValueError("Only 4D input tensors are supported.")

        # Apply the initial transformation
        x = self.initial(x)

        # Apply the residual blocks
        y = self.blocks(x)

        # On first call initialize the final transformation
        if self.final is None:
            # Get the shape of the input to the final layer
            c_in_final = int(torch.tensor(y.shape[1:]).prod().item())

            final = nn.Sequential(
                nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
                nn.Flatten(),
                nn.Linear(c_in_final, self.c_hidden_fc),
                self.activation(),
                nn.Linear(self.c_hidden_fc, self.c_hidden_fc),
                self.activation(),
                nn.Linear(self.c_hidden_fc, self.c_out),
            )

            # New modules are created on the CPU, in the default dtype and in
            # training mode. Align them with the already existing part of the
            # network, which may have been moved or cast since construction.
            reference = self.initial[0].weight
            final = final.to(device=reference.device, dtype=reference.dtype)
            final.train(self.training)
            self.final = final

        # Apply the final transformation
        z = self.final(y)

        return z
class ResNetEmbedding1D(nn.Module):
    """Residual neural network mapping vector-like data to a fixed-size vector."""

    def __init__(
        self,
        c_in: int,
        c_out: int,
        construct_mapping: Callable = construct_simple_fc_net,
        construct_mapping_kwargs: Union[Dict, None] = None,
        residual_block_kwargs: Union[Dict, None] = None,
        n_blocks: int = 20,
        c_internal: int = 128,
        c_hidden_final: int = 1000,
        activation: Type[nn.Module] = nn.ReLU,
    ) -> None:
        """Residual neural network mapping vector-like data to a fixed-size vector.

        This network consists of a stacked set of residual blocks. The output of
        a block is given by the input to the block plus the transformed input,
        i.e., there is a skip connection.

        The input is expected to be vector-like, i.e., a 2D tensor with
        dimensions (batch_size, channels). The input data is transformed into a
        fixed-size vector of dimensions [batch_size, c_out]. By default, fully
        connected networks are used to model the transformation in each of the
        residual blocks.

        Initially, the input can be mapped to a space of different
        dimensionality than the input. On this space, the residual blocks act.
        At the end of the network, a fully connected network is applied to map
        the output of the last residual block to the fixed-size output vector.

        Args:
            c_in: Input dimensionality.
            c_out: Dimensionality of the embedding vector.
            construct_mapping: Constructor for the mapping function. The
                mapping function is applied to the input tensor before the
                residual connection is added. The default is a simple fully
                connected network. The function must have the signature
                `construct_mapping(c_in, c_out, activation, **kwargs)`.
            construct_mapping_kwargs: Additional keyword arguments for the
                mapping functions. Defaults to `{"c_hidden": 128}`. The passed
                dictionary is not modified.
            residual_block_kwargs: Additional keyword arguments for the
                initialization of the residual blocks.
            n_blocks: Number of residual blocks.
            c_internal: Dimensionality of the internal space.
            c_hidden_final: Hidden dimensionality of the final aggregation
                network.
            activation: Constructor for the activation function.
        """
        super().__init__()

        # Parameters for the subnetwork initialization. Work on a copy so
        # that the dictionary of the caller is never modified.
        if construct_mapping_kwargs is None:
            mapping_kwargs = {"c_hidden": 128}
        else:
            mapping_kwargs = dict(construct_mapping_kwargs)
        mapping_kwargs["activation"] = activation

        # Parameters for the residual block initialization
        if residual_block_kwargs is None:
            residual_block_kwargs = {}

        # Mapping of the input to the dimensionality used internally
        if c_in == c_internal:
            self.initial = nn.Identity()
        else:
            self.initial = nn.Linear(c_in, c_internal)

        # Final aggregation network
        self.final = nn.Sequential(
            nn.Linear(c_internal, c_hidden_final),
            activation(),
            nn.Linear(c_hidden_final, c_hidden_final),
            activation(),
            nn.Linear(c_hidden_final, c_out),
        )

        # Residual blocks, all acting on the internal space
        blocks = [
            ResidualBLock(
                c_in=c_internal,
                c_out=c_internal,
                is_conv_block=False,
                construct_mapping=construct_mapping,
                construct_mapping_kwargs=mapping_kwargs,
                activation=activation,
                **residual_block_kwargs,
            )
            for _ in range(n_blocks)
        ]

        self.blocks = nn.Sequential(*blocks)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the residual network.

        Args:
            x:  Input tensor with dimensions (batch_size, c_in)

        Returns:
            z:  Output tensor with dimensions (batch_size, c_out)

        Raises:
            ValueError: If the input is not a 2D tensor.
        """

        # Check if vector like data is used
        if x.dim() != 2:
            raise ValueError("Only 2D input tensors are supported.")

        # Apply the initial transformation
        x = self.initial(x)

        # Apply the residual blocks
        y = self.blocks(x)

        # Apply the final transformation
        z = self.final(y)

        return z
class ResidualBLock(nn.Module): def __init__( self, c_in: int, c_out: Optional[int], is_conv_block: bool, construct_mapping: Callable, construct_mapping_kwargs: Optional[Dict], activation: Type[nn.Module], downsample: bool = False, change_c_mode: str = "conv", ) -> None: """ Single residual block used in residual networks. Args: c_in: Number of input channels. c_out: Number of output channels. If None, use c_in. change_c_mode: Mode to change the number of channels. Options are "conv" and "zeros". If "conv" is selected, 1x1 convolutions are used to adjust the number of channels. If "zeros" is selected, any additional required channels are filled with zeros. Only relevant for image-like data. downsample: Apply a strided convolution to halve the height and width of the input. Only relevant for image-like data. construct_mapping: Constructor for the mapping function. The mapping function is applied to the input tensor before the residual connection is added. construct_mapping_kwargs: Additional keyword arguments for the mapping function. activation: Constructor for the activation function. 
""" super(ResidualBLock, self).__init__() # Additional keyword arguments for the mapping function if construct_mapping_kwargs is None: construct_mapping_kwargs = {} ############################################################################### # Change number of channels ############################################################################### # Preserve the input dimensionality if the output is not specified if c_out is None: c_out = c_in # Preserve the dimensionality of the input and the output of the block if c_in == c_out: self.residual = nn.Identity() # Initialize the transformation of the residual for the case of image-like data elif is_conv_block and (c_in != c_out): # Apply 1x1 convolutions to change the number of channels if change_c_mode == "conv": self.residual = nn.Conv2d( c_in, c_out, kernel_size=1, stride=1, bias=False ) # Fill the required additional channels with zeros elif change_c_mode == "zeros": # Check if there are more output channels than input channels if c_in > c_out: raise ValueError("c_in channels must be smaller than c_out.") self.residual = partial(zero_pad, c_out=c_out) else: raise ValueError(f"Invalid change_c_mode {change_c_mode}.") # Initialize the transformation of the residual for the case of vector-like data else: self.residual = nn.Linear(c_in, c_out, bias=False) ############################################################################### # Halve the height and width of the input for image-like data ############################################################################### # Apply a strided convolution if the input is downsampled if is_conv_block and downsample: self.downsampling = nn.Conv2d( c_out, c_out, kernel_size=3, stride=2, bias=False, padding=1 ) construct_mapping_kwargs["downsample"] = downsample # No downsampling else: self.downsampling = nn.Identity() ############################################################################### # Define the transformation of the residual block 
############################################################################### self.f = construct_mapping(c_in, c_out, **construct_mapping_kwargs) self.final_activation = activation() def forward(self, x: Tensor) -> Tensor: """ Forward pass of the residual block. Args: x: Input tensor Returns: y: Output tensor """ # Check if image like data is used if not ((len(x.shape) == 4) or (len(x.shape) == 2)): raise ValueError("Only 2D images or 1D vectors are supported.") # Compute the residual r = self.downsampling(self.residual(x)) # Full transformation y = self.final_activation(self.f(x) + r) return y