Source code for sbi.neural_nets.embedding_nets.causal_cnn

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import List, Optional, Tuple, Union

from torch import Tensor, nn

from sbi.neural_nets.embedding_nets.cnn import calculate_filter_output_size


def causalConv1d(
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    dilation: int = 1,
    stride: int = 1,
) -> nn.Module:
    """Returns a causal convolution by left padding the input

    Args:
        in_channels: number of input channels
        out_channels: number of output channels wanted
        kernel_size: wanted kernel size
        dilation: dilation to use in the convolution.
        stride: stride to use in the convolution.
            Stride and dilation cannot both be > 1.

    Returns:
        An nn.Sequential object that represents a 1D causal convolution.
    """
    assert not (dilation > 1 and stride > 1), (
        "we don't allow combining stride with dilation."
    )
    padding_size = dilation * (kernel_size - 1)
    padding = nn.ZeroPad1d(padding=(padding_size, 0))
    conv_layer = nn.Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        dilation=dilation,
        stride=stride,
        padding=0,
    )
    return nn.Sequential(padding, conv_layer)


def WaveNetSRLikeAggregator(
    in_channels: int,
    num_timepoints: int,
    output_dim: int,
    activation: nn.Module = nn.LeakyReLU(inplace=True),
    kernel_sizes: Optional[List] = None,
    intermediate_channel_sizes: Optional[List] = None,
    stride_sizes: Union[int, List] = 1,
) -> nn.Module:
    """
    Creates a non-causal 1D CNN aggregator based on the WaveNet speech recognition task

    By default this function creates an aggregator with two CNN layers.
    Every convolution is followed by the activation and a max-pooling layer; the
    pooling uses kernel size 2 (halving the number of timepoints) while the
    tracked number of timepoints is larger than 2, and kernel size 1 otherwise.
    The final CNN will have as many channels as the desired output dimension.
    A global average pooling operation is applied in the end. The dimension of the
    output will thus be (batch_size, output_dim, 1) regardless of the input size.

    Args:
        in_channels: number of channels at input.
        num_timepoints: length of the input.
        output_dim: wanted number of features as output.
        activation: activation to apply after the convolution. The same module
            instance is reused after every convolution.
        kernel_sizes: (optional) alter the kernel size used and the number of CNN layers
            (through the length of the kernel size vector).
        intermediate_channel_sizes: (optional) alter the intermediate channel sizes
            used, should have length = len(kernel_sizes) - 1. The passed list is
            not modified.
        stride_sizes: (optional) alter the stride used, either a vector of
            len = len(kernel_sizes) or a single integer, in which case the same stride
            is used in every convolution. Layers with stride 1 use 'same' padding;
            layers with another stride use a symmetric padding of
            (kernel_size - 1) // 2, because 'same' padding is not available for
            strided convolutions.

    Returns:
        nn.Module object that contains a sequence of CNN and max_pool layer
            and finally a global average pooling layer.
    """
    if kernel_sizes is None:
        kernel_sizes = [
            min(9, num_timepoints),
            min(5, int(num_timepoints / 2)),
        ]
    if intermediate_channel_sizes is None:
        intermediate_channel_sizes = [64]
    assert len(intermediate_channel_sizes) == len(kernel_sizes) - 1, (
        "Provided kernel size list should be exactly one element longer "
        "than channel size list."
    )
    # Build a new list instead of extending in place, so a list handed in by
    # the caller is left untouched.
    channel_sizes = [*intermediate_channel_sizes, output_dim]

    if isinstance(stride_sizes, list):
        assert len(stride_sizes) == len(kernel_sizes), (
            "Provided stride size list should have the same size as "
            "the kernel size list."
        )
    else:
        stride_sizes = [stride_sizes] * len(kernel_sizes)

    non_causal_layers = []
    current_channels = in_channels
    # Estimate of the sequence length entering each layer; only used to decide
    # whether the max-pooling of that layer may still halve the sequence.
    current_length = num_timepoints
    for kernel, stride, out_channels in zip(kernel_sizes, stride_sizes, channel_sizes):
        conv_layer = nn.Conv1d(
            in_channels=current_channels,
            out_channels=out_channels,
            kernel_size=kernel,
            stride=stride,
            # nn.Conv1d rejects padding='same' for strided convolutions, so
            # fall back to the equivalent explicit symmetric padding there.
            padding="same" if stride == 1 else (kernel - 1) // 2,
        )
        maxpool = nn.MaxPool1d(kernel_size=2 if current_length > 2 else 1)
        non_causal_layers += [conv_layer, activation, maxpool]

        current_channels = out_channels
        current_length = int(
            calculate_filter_output_size(
                current_length,
                (kernel - 1) / 2,
                1,
                kernel,
                stride,
            )
            / 2
        )

    return nn.Sequential(*non_causal_layers, nn.AdaptiveAvgPool1d(1))


[docs] class CausalCNNEmbedding(nn.Module): """Embedding network that uses 1D causal convolutions.""" def __init__( self, input_shape: Tuple, in_channels: int = 1, out_channels_per_layer: Optional[List] = None, dilation: Union[str, List] = "exponential_cyclic", num_conv_layers: int = 5, activation: nn.Module = nn.LeakyReLU(inplace=True), pool_kernel_size: int = 160, kernel_size: int = 2, aggregator: Optional[nn.Module] = None, output_dim: int = 20, ): """Intitialize embedding network that uses 1D causal convolutions. This is a simplified version of the architecture introduced for the speech recognition task in the WaveNet paper (van den Oord, et al. (2016)) After several dilated causal convolutions (that maintain the dimensionality of the input), an aggregator network is used to bring down the dimensionality. You can provide an aggregator network that you deem reasonable for your data. If you do not provide an aggregator network yourself, a default aggregator is used. This default aggregator is based on the WaveNet paper's description of their Speech Recognition Task, and uses non-causal convolutions and pooling layers, and global average poolingg to obtain a final low dimensional embedding. Args: input_shape: Dimensionality of the input e.g. (num_timepoints,), currently only 1D is supported. in_channels: Number of input channels, default = 1. out_channels_per_layer: number of out_channels for each layer, number of entries should correspond with num_conv_layers passed below. Default = 16 in every convolutional layer. dilation: type of dilation to use either one of "none" (dilation = 1 in every layer), "exponential" (increase dilation by a factor of 2 every layer), "exponential_cyclic" (as exponential, but reset to 1 after dilation = 2**9) or pass a list with dilation size per layer. By default the cyclic, exponential scheme from WaveNet is used. 
num_conv_layers: the number of causal convolutional layers kernel_size: size of the kernels in the causal convolutional layers. activation: activation function to use between convolutions, default = LeakyReLU. pool_kernel_size: pool size to use for the AvgPool1d operation after the causal convolutional layers. aggregator: aggregation net that reduces the dimensionality of the data to a low-dimensional embedding. output_dim: number of output units in the final layer when using the default aggregation """ super(CausalCNNEmbedding, self).__init__() assert isinstance(input_shape, Tuple), ( "input_shape must be a Tuple of size 1, e.g. (timepoints,)." ) assert len(input_shape) == 1, "Currently only 1D causal CNNs are supported." self.input_shape = (in_channels, *input_shape) total_timepoints = input_shape[0] assert total_timepoints >= pool_kernel_size, ( "Please ensure that the pool kernel size is not " "larger than the number of observed timepoints." ) if isinstance(dilation, str): match dilation.lower(): case "exponential_cyclic": max_dil_exp = 10 ## Use dilation scheme as in WaveNet paper dilation_per_layer = [ 2 ** (i % max_dil_exp) for i in range(num_conv_layers) ] case "exponential": dilation_per_layer = [2**i for i in range(num_conv_layers)] case "none": dilation_per_layer = [1] * num_conv_layers case _: raise ValueError( f"{dilation} is not a valid option, please use \"none\"," "\"exponential\",or \"exponential_cyclic\", or pass a list " "of dilation sizes." ) else: assert isinstance(dilation, List), ( "Please pass dilation size as list or a string option." ) dilation_per_layer = dilation assert max(dilation_per_layer) < total_timepoints, ( "Your maximal dilations size used is larger than the number of " "timepoints in your input, please provide a list with smaller dilations." 
) if out_channels_per_layer is None: out_channels_per_layer = [16] * num_conv_layers causal_layers = [] for ll in range(num_conv_layers): causal_layers += [ causalConv1d( in_channels if ll == 0 else out_channels_per_layer[ll - 1], out_channels_per_layer[ll], kernel_size, dilation_per_layer[ll], 1, ), activation, ] self.causal_cnns = nn.Sequential(*causal_layers) self.pooling_layer = nn.AvgPool1d(kernel_size=pool_kernel_size) if aggregator is None: aggregator_out_shape = ( out_channels_per_layer[-1], int(total_timepoints / pool_kernel_size), ) assert aggregator_out_shape[-1] > 1, ( "Your dimensionality is already small," "Please ensure a larger input size or use a custom aggregator." ) aggregator = WaveNetSRLikeAggregator( aggregator_out_shape[0], aggregator_out_shape[-1], output_dim=output_dim, ) self.aggregation = aggregator
[docs] def forward(self, x: Tensor) -> Tensor: batch_size = x.size(0) x = x.view(batch_size, *self.input_shape) x = self.causal_cnns(x) x = self.pooling_layer(x) x = self.aggregation(x) # ensure flattening when aggregator uses global average pooling x = x.view(batch_size, -1) return x