Source code for sbi.neural_nets.embedding_nets.cnn

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor, nn

from sbi.neural_nets.embedding_nets.fully_connected import FCEmbedding


def calculate_filter_output_size(input_size, padding, dilation, kernel, stride) -> int:
    """Compute the output size along one axis after applying a filter.

    Follows the output-shape expression from
    https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html.
    Every argument is converted with ``int`` before use, and the final value is
    converted with ``int`` as well (i.e., truncated toward zero).

    Args:
        input_size: Size of the input along the axis.
        padding: Padding applied to both sides of the axis.
        dilation: Spacing between filter elements.
        kernel: Size of the filter along the axis.
        stride: Step size of the filter.

    Returns:
        Size of the output along the axis.
    """
    size, pad, dil, ksize, step = (
        int(value) for value in (input_size, padding, dilation, kernel, stride)
    )

    # Span covered by the dilated filter beyond its first element.
    dilated_span = dil * (ksize - 1)
    numerator = size + 2 * pad - dilated_span - 1

    return int(numerator / step + 1)


def get_new_cnn_output_size(
    input_shape: Tuple,
    conv_layer: Union[nn.Conv1d, nn.Conv2d],
    pool: Union[nn.MaxPool1d, nn.MaxPool2d],
) -> Union[Tuple[int], Tuple[int, int]]:
    """Returns the spatial output size after one convolution followed by pooling.

    Args:
        input_shape: Tuple with the spatial size of the input (one or two
            entries, without batch or channel dimensions).
        conv_layer: Convolutional layer that is applied first. Its ``padding``
            attribute must be a tuple; ``padding``, ``dilation``,
            ``kernel_size`` and ``stride`` are indexed per axis.
        pool: Pooling layer that is applied after the convolution. Its
            ``padding`` must be an integer; the same ``padding``, ``dilation``,
            ``kernel_size`` and ``stride`` values are used for every axis.

    Returns:
        Tuple with the output size per axis after convolution and pooling.

    """
    assert isinstance(input_shape, tuple), "input shape must be Tuple."
    assert len(input_shape) in (1, 2), "input shape must be 1 or 2d."
    assert isinstance(conv_layer.padding, tuple), "conv layer attributes must be Tuple."
    assert isinstance(pool.padding, int), "pool layer attributes must be integers."

    # First pass: size of every axis after the convolution.
    sizes_after_conv = []
    for axis, axis_size in enumerate(input_shape):
        sizes_after_conv.append(
            calculate_filter_output_size(
                axis_size,
                conv_layer.padding[axis],
                conv_layer.dilation[axis],
                conv_layer.kernel_size[axis],
                conv_layer.stride[axis],
            )
        )

    # Second pass: the pooling layer uses the same scalar settings on all axes.
    sizes_after_pool = tuple(
        calculate_filter_output_size(
            conv_size,
            pool.padding,
            pool.dilation,
            pool.kernel_size,
            pool.stride,
        )
        for conv_size in sizes_after_conv
    )
    return sizes_after_pool  # pyright: ignore[reportReturnType]


class CNNEmbedding(nn.Module):
    """Convolutional embedding network (1D or 2D convolutions)."""

    def __init__(
        self,
        input_shape: Tuple,
        in_channels: int = 1,
        out_channels_per_layer: Optional[List] = None,
        num_conv_layers: int = 2,
        num_linear_layers: int = 2,
        num_linear_units: int = 50,
        output_dim: int = 20,
        kernel_size: int = 5,
        pool_kernel_size: int = 2,
    ):
        """Convolutional embedding network.

        The network consists of `num_conv_layers` blocks of convolution, ReLU
        and max pooling, followed by fully connected layers. Automatically
        infers whether to apply 1D or 2D convolution depending on input_shape.
        Allows usage of multiple (color) channels by passing in_channels > 1.

        All convolutions use a stride of 1 and a padding of 1.

        Args:
            input_shape: Dimensionality of input, e.g., (28,) for 1D, (28, 28)
                for 2D. The channel dimension is passed separately.
            in_channels: Number of image channels, default 1.
            out_channels_per_layer: Number of convolutional out_channels for
                each layer. Must have exactly `num_conv_layers` entries.
                Defaults to [6, 12].
            num_conv_layers: Number of convolutional layers.
            num_linear_layers: Number of fully connected layers.
            num_linear_units: Number of hidden units in fully-connected layers.
            output_dim: Number of output units of the final layer.
            kernel_size: Kernel size for all convolutional layers.
            pool_kernel_size: Kernel size of the max-pooling operation applied
                after each convolutional layer.

        Raises:
            AssertionError: If `input_shape` is not a tuple of length 1 or 2,
                if `out_channels_per_layer` does not have `num_conv_layers`
                entries, or if the spatial output size becomes non-positive
                at any convolutional layer.
        """
        super().__init__()

        assert isinstance(input_shape, Tuple), (
            "input_shape must be a Tuple of size 1 or 2, e.g., (width, [height])."
        )
        assert 0 < len(input_shape) < 3, (
            "input_shape must be a Tuple of size 1 or 2, e.g., (width, [height]). "
            "Number of input channels are passed separately"
        )

        use_2d_cnn = len(input_shape) == 2
        conv_module = nn.Conv2d if use_2d_cnn else nn.Conv1d
        pool_module = nn.MaxPool2d if use_2d_cnn else nn.MaxPool1d

        if out_channels_per_layer is None:
            out_channels_per_layer = [6, 12]
        assert len(out_channels_per_layer) == num_conv_layers, (
            "out_channels_per_layer needs as many entries as num_conv_layers."
        )

        # Define input shape with channel; used to reshape inputs in forward().
        self.input_shape = (in_channels, *input_shape)

        # Construct CNN feature extractor.
        cnn_layers = []
        cnn_output_size = input_shape
        stride = 1
        padding = 1
        for ii in range(num_conv_layers):
            # Convolution (1D or 2D); the first layer consumes the input
            # channels, later layers consume the previous layer's out_channels.
            conv_layer = conv_module(
                in_channels=in_channels if ii == 0 else out_channels_per_layer[ii - 1],
                out_channels=out_channels_per_layer[ii],
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            )
            pool = pool_module(kernel_size=pool_kernel_size)
            cnn_layers += [conv_layer, nn.ReLU(inplace=True), pool]

            # Calculate change of output size of each CNN layer.
            cnn_output_size = get_new_cnn_output_size(cnn_output_size, conv_layer, pool)
            # Check for strictly positive sizes: a kernel larger than the
            # remaining input yields a negative size, which a plain truthiness
            # check would let through.
            assert all(size > 0 for size in cnn_output_size), (
                f"CNN output size is non-positive at layer {ii + 1}. Either reduce "
                f"num_conv_layers to {ii} or adjust the kernel_size and "
                "pool_kernel_size accordingly."
            )

        self.cnn_subnet = nn.Sequential(*cnn_layers)

        # Number of features after flattening the CNN output, as a plain int
        # (rather than a 0-d tensor) so that it is a valid layer dimension.
        num_cnn_features = out_channels_per_layer[-1] * int(
            torch.prod(torch.tensor(cnn_output_size)).item()
        )

        # Construct linear post processing net.
        self.linear_subnet = FCEmbedding(
            input_dim=num_cnn_features,
            output_dim=output_dim,
            num_layers=num_linear_layers,
            num_hiddens=num_linear_units,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Embeds a batch of inputs.

        Args:
            x: Batch of inputs with the batch dimension first. It is reshaped
                to (batch_size, in_channels, *input_shape), so the channel
                dimension may be omitted for single-channel data.

        Returns:
            Output of the fully connected subnet applied to the flattened CNN
            features.
        """
        batch_size = x.size(0)

        # Reshape to account for single channel data. `reshape` (instead of
        # `view`) also supports non-contiguous inputs.
        x = self.cnn_subnet(x.reshape(batch_size, *self.input_shape))

        # Flatten for linear layers.
        x = x.reshape(batch_size, -1)
        return self.linear_subnet(x)