# Source code for sbi.neural_nets.embedding_nets.SC_embedding

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

# This code is based on the following three papers:

# Lingsch et al. (2024) FUSE: Fast Unified Simulation and Estimation for PDEs
# (https://proceedings.neurips.cc/paper_files/paper/2024/file/266c0f191b04cbbbe529016d0edc847e-Paper-Conference.pdf)
#
# Lingsch et al. (2024) Beyond Regular Grids: Fourier-Based Neural Operators
# on Arbitrary Domains
# (https://arxiv.org/pdf/2305.19663)

# Li et al. (2021) Fourier Neural Operator for Parametric Partial Differential Equations
# (https://openreview.net/pdf?id=c8P9NQVtmnO)

# and partially adapted from the following repository:
# https://github.com/camlab-ethz/FUSE

from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn


class VFT:
    """Class for performing Fourier transformations for non-equally
    and equally spaced 1d grids.

    It provides a function for creating the grid-dependent operator V to compute
    the forward Fourier transform X of data x with X = V*x.
    The inverse Fourier transform is then computed by x = V_inv*X with
    V_inv = transpose(conjugate(V)).

    Adapted from: Lingsch et al. (2024) Beyond Regular Grids: Fourier-Based
    Neural Operators on Arbitrary Domains

    Args:
        batch_size: Training batch size
        n_points: Number of 1d grid points
        modes: number of Fourier modes that should be used
            (maximal floor(n_points/2) + 1)
        point_positions: Grid point positions of shape (batch_size, n_points).
            If not provided, equispaced points are used. Positions have to be
            normalized with domain length.
    """

    def __init__(
        self,
        batch_size: int,
        n_points: int,
        modes: int,
        point_positions: Optional[Tensor] = None,
    ):
        self.number_points = n_points
        self.batch_size = batch_size
        self.modes = modes

        if point_positions is not None:
            new_times = point_positions[:, None, :]
        else:
            equispaced = torch.arange(self.number_points) / self.number_points
            new_times = equispaced.repeat(self.batch_size, 1)[:, None, :]

        # Positions scaled by 2*pi, shape (batch, 1, points)
        self.new_times = new_times * 2 * np.pi

        # Mode indices 0..modes-1, shape (batch, modes, 1). They are created on
        # the device of the positions so that the batched product in
        # make_matrix does not mix devices.
        self.X_ = (
            torch.arange(modes, device=self.new_times.device)
            .repeat(self.batch_size, 1)[:, :, None]
            .float()
        )
        # V_fwd: (batch, modes, points) V_inv: (batch, points, modes)
        self.V_fwd, self.V_inv = self.make_matrix()

    def make_matrix(self) -> Tuple[Tensor, Tensor]:
        """Create matrix operators V and V_inv for forward and backward
        Fourier transformation on arbitrary grids

        Returns:
            Tuple (V, V_inv) with V of shape (batch, modes, points) and
            V_inv, the conjugate transpose of V, of shape (batch, points, modes).
        """

        # Phase for every (mode, point) pair: mode index times scaled position
        X_mat = torch.bmm(self.X_, self.new_times)
        forward_mat = torch.exp(-1j * X_mat)

        inverse_mat = torch.conj(forward_mat).permute(0, 2, 1)

        return forward_mat, inverse_mat

    def _divisor(self, norm: str, scaled_norm: str) -> float:
        """Return the normalization divisor for the given norm mode.

        Args:
            norm: One of 'forward', 'ortho' or 'backward'.
            scaled_norm: The norm mode for which the calling transform
                direction divides by the number of points.

        Raises:
            ValueError: If norm is not one of the three supported modes.
        """
        if norm not in ('forward', 'ortho', 'backward'):
            raise ValueError(
                f"Unknown norm '{norm}'. "
                "Expected one of 'forward', 'ortho' or 'backward'."
            )
        if norm == 'ortho':
            return float(np.sqrt(self.number_points))
        return float(self.number_points) if norm == scaled_norm else 1.0

    def _move_operators(self, device: torch.device) -> None:
        """Move V_fwd and V_inv to the given device if they live elsewhere."""
        if self.V_fwd.device != device:
            self.V_fwd = self.V_fwd.to(device)
            self.V_inv = self.V_inv.to(device)

    def forward(self, data: Tensor, norm: str = 'forward') -> Tensor:
        """Perform forward Fourier transformation
        Args:
            data: Input data with shape (batch_size, n_points, conv_channel)
            norm: Normalization mode. 'forward' divides by n_points, 'ortho'
                divides by sqrt(n_points), 'backward' applies no scaling.

        Returns:
            Fourier coefficients with shape (batch, modes, conv_channels).

        Raises:
            ValueError: If norm is not 'forward', 'ortho' or 'backward'.
        """
        divisor = self._divisor(norm, scaled_norm='forward')
        self._move_operators(data.device)

        return torch.bmm(self.V_fwd, data) / divisor

    def inverse(self, data: Tensor, norm: str = 'forward') -> Tensor:
        """Perform inverse Fourier transformation
        Args:
            data: Input data with shape (batch_size, modes, conv_channel)
            norm: Normalization mode. 'backward' divides by n_points, 'ortho'
                divides by sqrt(n_points), 'forward' applies no scaling.

        Returns:
            Data in physical space with shape (batch, n_points, conv_channels).

        Raises:
            ValueError: If norm is not 'forward', 'ortho' or 'backward'.
        """
        divisor = self._divisor(norm, scaled_norm='backward')
        self._move_operators(data.device)

        return torch.bmm(self.V_inv, data) / divisor


class SpectralConv1d_SMM(nn.Module):
    """
    1D spectral convolution layer operating on Fourier coefficients.
    The input is transformed to Fourier space, where each retained mode is
    multiplied with a learned complex weight that mixes the channels.

    Adapted from:
    - Lingsch et al. (2024) FUSE: Fast Unified Simulation and Estimation for PDEs
    - Li et al. (2021) Fourier Neural Operator for Parametric Partial Differential
                        Equations

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        modes: Number of Fourier modes to multiply,
            at most floor(N/2) + 1.
    """

    def __init__(self, in_channels: int, out_channels: int, modes: int):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes = modes

        # Complex weights drawn uniformly and scaled by 1/(in * out)
        self.scale = 1 / (in_channels * out_channels)
        initial_weights = torch.rand(
            in_channels, out_channels, self.modes, dtype=torch.cfloat
        )
        self.weights1 = nn.Parameter(self.scale * initial_weights)

    def compl_mul1d(self, input: Tensor, weights: Tensor) -> Tensor:
        """
        Multiply Fourier coefficients with complex weights, mode by mode,
        summing over the input channels.

        Args:
            input: Input tensor of shape (batch, in_channels, modes).
            weights: Weight tensor of shape (in_channels, out_channels, modes).

        Returns:
            torch.Tensor: Output tensor of shape (batch, out_channels, modes).
        """

        return torch.einsum("bix,iox->box", input, weights)

    def _mix_modes(self, x: Tensor, transform: VFT) -> Tensor:
        """
        Transform x to Fourier space and apply the learned weights.

        Args:
            x: Input tensor of shape (batch, n_points, in_channels).
            transform: Fourier transform operator with forward and inverse methods.

        Returns:
            Complex coefficients of shape (batch, modes, out_channels).
        """
        coefficients = transform.forward(x.to(torch.complex64), norm='forward')
        # compl_mul1d expects channels before modes
        mixed = self.compl_mul1d(coefficients.permute(0, 2, 1), self.weights1)
        return mixed.permute(0, 2, 1)

    def forward(self, x: Tensor, transform: VFT) -> Tensor:
        """
        Apply the spectral convolution and return to physical space.

        Args:
            x: Input tensor of shape (batch, n_points, in_channels).
            transform: Fourier transform operator with forward and inverse methods.

        Returns:
            The real part of the transformed output tensor
            with shape (batch, points, out_channels).
        """
        spectrum = self._mix_modes(x, transform)
        physical = transform.inverse(spectrum, norm='forward')

        return physical.real

    def last_layer(self, x: Tensor, transform: VFT) -> Tensor:
        """
        Final spectral layer: returns the weighted Fourier coefficients
        themselves (no inverse transform) to be used as embedding.

        Args:
            x: Input tensor of shape (batch, points, in_channels).
            transform: Fourier transform operator with forward and inverse methods.

        Returns:
            Transformed output tensor of shape (batch, 2*modes, out_channels),
            where each mode contributes its real part followed by its
            imaginary part along the second dimension.
        """
        spectrum = self._mix_modes(x, transform)  # (batch, modes, conv_channels)
        # Split into real/imaginary parts: (batch, modes, conv_channels, 2),
        # then bring the real/imag axis next to the mode axis.
        real_imag = torch.view_as_real(spectrum).transpose(2, 3)

        return real_imag.reshape(x.shape[0], 2 * self.modes, self.out_channels)


class SpectralConvEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        modes: int = 10,
        out_channels: int = 1,
        conv_channels: int = 5,
        num_layers: int = 3,
    ):
        """SpectralConvEmbedding is a neural network module that performs
        convolution in Fourier space for 1D input data (that can have multiple
        channels).

        It uses a series of spectral convolution layers and pointwise
        convolution layers to transform the input tensor.

        Adapted from: Lingsch et al. (2024) Beyond Regular Grids: Fourier-Based
        Neural Operators on Arbitrary Domains

        Args:
            in_channels: Number of channels in the input data.
            modes: Number of modes considered in the spectral convolution,
                at most floor(n_points/2) + 1.
            out_channels: Number of channels for final output.
            conv_channels: Number of channels going in and out of the
                convolutional layers.
            num_layers: Number of convolution layers.
        """
        super().__init__()
        self.modes = modes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_channels = conv_channels
        self.num_layers = num_layers

        # Initialize fully connected layer to raise number of
        # input channels to number of convolutional channels
        self.fc0 = nn.Linear(self.in_channels, self.conv_channels)

        # Initialize layers performing convolution in Fourier space
        self.conv_layers = nn.ModuleList([
            SpectralConv1d_SMM(self.conv_channels, self.conv_channels, self.modes)
            for _ in range(self.num_layers)
        ])

        # Initialize layers performing pointwise convolution
        self.w_layers = nn.ModuleList([
            nn.Conv1d(self.conv_channels, self.conv_channels, 1)
            for _ in range(self.num_layers)
        ])

        # Initialize last convolutional layer with output in Fourier space
        self.conv_last = SpectralConv1d_SMM(
            self.conv_channels, self.conv_channels, self.modes
        )

        # Initialize fully connected layer to reduce number of output channels
        self.fc_last = nn.Linear(self.conv_channels, self.out_channels)

    def forward(self, x: Tensor) -> Tensor:
        """Network forward pass.

        Args:
            x: 3D input tensor (batch_size, in_channels, n_points) for
                equi-spaced data or 4D tensor
                (batch_size, 2, in_channels, n_points) for non-equispaced data,
                where we additionally pass the point positions in the second
                dimension, repeating the same point positions for each channel.
                For non-equispaced data, the positions have to be normalized
                with physical domain length.

        Exemplary code:
            # Example for equispaced grid data with batch size of 256,
            # 3 channels and sequence length of 500
            data_equispaced = torch.rand(256, 3, 500)

            embedding_net = SpectralConvEmbedding(modes=15, in_channels=3,
                out_channels=1, conv_channels=5, num_layers=4)
            neural_posterior = posterior_nn(model="nsf",
                embedding_net=embedding_net)
            inference = SNPE(prior=sbi_prior,
                density_estimator=neural_posterior)
            _ = inference.append_simulations(theta, data_equispaced)

            # Example for non-equispaced data with batch size of 256,
            # 3 channels and sequence length of 500

            # non-equally spaced positions in [0;1]
            irregular_positions = torch.rand(500)
            irregular_positions, indices = torch.sort(irregular_positions, 0)
            irregular_positions = irregular_positions.repeat(256, 3, 1)

            random_data = torch.rand(256, 3, 500)

            data_nonequispaced = torch.zeros(256, 2, 3, 500)
            data_nonequispaced[:, 0, :, :] = random_data
            data_nonequispaced[:, 1, :, :] = irregular_positions

            embedding_net = SpectralConvEmbedding(modes=15, in_channels=3,
                out_channels=1, conv_channels=5, num_layers=4)
            neural_posterior = posterior_nn(model="nsf",
                embedding_net=embedding_net)
            inference = SNPE(prior=sbi_prior,
                density_estimator=neural_posterior)
            _ = inference.append_simulations(theta, data_nonequispaced)

        Returns:
            Network output (batch_size, out_channels * 2 * modes).

        Raises:
            ValueError: If x is neither a 3D nor a 4D tensor.
            AssertionError: If modes exceeds floor(n_points/2) + 1.
        """
        batch_size = x.shape[0]

        # Check dimension of input data and reshape it
        if x.ndim == 3:
            x = x.permute(0, 2, 1)  # (batch, n_points, in_channels)
            point_positions = None
        elif x.ndim == 4:
            # Positions are repeated per channel, so the first channel suffices
            point_positions = x[:, 1, 0, :]  # (batch, n_points)
            x = x[:, 0, :, :].permute(0, 2, 1)  # (batch, n_points, in_channels)
        else:
            # Adjacent string literals (no commas) so the exception carries
            # one readable message instead of a tuple of fragments.
            raise ValueError(
                'Input tensor should be 3D (batch_size, channels, n_points) '
                'or 4D (batch_size, 2, channels, n_points). '
                f'The tensor that was passed has shape {x.shape}.'
            )

        n_points = x.shape[1]

        # Kept as an assert so the exception type seen by callers is unchanged
        assert self.modes <= n_points // 2 + 1, (
            "Modes should be at most floor(n_points/2) + 1"
        )

        x = self.fc0(x)  # (batch_size, n_points, conv_channels)

        # Initialize Fourier transform for arbitrarily spaced points
        fourier_transform = VFT(batch_size, n_points, self.modes, point_positions)

        # Send the data through Fourier layers, output in original space
        for conv, w in zip(self.conv_layers, self.w_layers, strict=False):
            x1 = conv(x, fourier_transform)
            # nn.Conv1d expects channels before points
            x2 = w(x.permute(0, 2, 1))
            x = x1 + x2.permute(0, 2, 1)
            x = F.gelu(x)

        # Send data through last convolutional layer which returns data in
        # Fourier space
        x_spec = self.conv_last.last_layer(
            x, fourier_transform
        )  # (batch, 2*modes, conv_channels)

        # Reduce the number of channels with last layer
        x_spec = self.fc_last(x_spec)  # (batch, 2*modes, out_channels)

        return x_spec.reshape(batch_size, -1)