# Source code for sbi.inference.trainers.marginal.marginal_base

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import time
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union
from warnings import warn

import torch
from torch import Tensor
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.optim.adam import Adam
from torch.utils import data
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard.writer import SummaryWriter

from sbi.neural_nets.estimators import UnconditionalDensityEstimator
from sbi.neural_nets.estimators.shape_handling import (
    reshape_to_batch_event,
)
from sbi.neural_nets.factory import ZukoFlowType, marginal_nn
from sbi.sbi_types import Tracker
from sbi.utils import check_estimator_arg, get_log_root
from sbi.utils.torchutils import assert_all_finite, process_device
from sbi.utils.tracking import TensorBoardTracker

# Accepted ways to specify the density estimator: a `ZukoFlowType` member, the
# name of a flow as a string, or a callable that is handed the training samples
# and returns the network to train.
DensityEstimatorType = Union[ZukoFlowType, str, Callable[[Tensor], Any]]


class MarginalTrainer:
    r"""Utility class for training a marginal density estimator $p(x)$.

    The marginal density estimator learns the distribution of simulation outputs
    $x$ without conditioning on parameters. In the sbi toolbox, it is primarily
    used for misspecification diagnostics by comparing the estimated marginal
    $p(x)$ to the observed data distribution.

    Example:
    --------

    ::

        import torch
        from sbi.inference import MarginalTrainer

        # 1. Simulate data
        theta = torch.randn(100, 3)
        x = theta + torch.randn_like(theta) * 0.1

        # 2. Train marginal density estimator
        marginal_trainer = MarginalTrainer(density_estimator="nsf")
        marginal_estimator = marginal_trainer.append_samples(x).train()

        # 3. Evaluate log probability of new observations
        x_new = torch.randn(10, 3)
        log_prob = marginal_estimator.log_prob(x_new)
    """

    def __init__(
        self,
        density_estimator: DensityEstimatorType = ZukoFlowType.NSF,
        device: str = "cpu",
        summary_writer: Optional[SummaryWriter] = None,
        tracker: Optional[Tracker] = None,
        show_progress_bars: bool = True,
    ):
        """Initialize the marginal trainer.

        Args:
            density_estimator: Density estimator to use. Can be a `ZukoFlowType`,
                a string, or a callable. If a string, it must be one of the
                following:

                - "bpf": Bernstein Polynomial Flow
                - "maf": Masked Autoregressive Flow
                - "naf": Neural Autoregressive Flow
                - "ncsf": Neural Circular Spline Flow
                - "nsf": Neural Spline Flow
                - "sospf": Sum-of-Squares Polynomial Flow
                - "unaf": Unconstrained Neural Autoregressive Flow

                If a callable, it is called with the training samples and must
                return a neural network that inherits from
                `UnconditionalDensityEstimator`.
            device: Device to use for training. Can be "cpu" or "cuda".
            summary_writer: Deprecated alias for the TensorBoard summary writer.
                Use ``tracker`` instead.
            tracker: Tracking adapter used to log training progress. If None, a
                TensorBoard tracker is created.
            show_progress_bars: Whether to show progress bars during training.

        Raises:
            ValueError: If both `summary_writer` and `tracker` are passed, or if
                `density_estimator` is neither a `ZukoFlowType`, a string, nor
                a callable.
        """
        self._device = process_device(device)
        self._neural_net = None
        self._show_progress_bars = show_progress_bars
        self._val_loss = float("Inf")
        # Training samples; populated by `append_samples`.
        self._x: Optional[Tensor] = None

        if summary_writer is not None:
            warn(
                "summary_writer is deprecated. Use tracker instead.",
                FutureWarning,
                stacklevel=2,
            )
            if tracker is not None:
                raise ValueError("Pass only one of summary_writer or tracker.")
            # Wrap the legacy writer so the rest of the class only sees a tracker.
            tracker = TensorBoardTracker(summary_writer)
        self._tracker = self._default_tracker() if tracker is None else tracker

        # Logging during training.
        self._summary = dict(
            epochs_trained=[],
            best_validation_loss=[],
            validation_loss=[],
            training_loss=[],
            epoch_durations_sec=[],
        )

        if isinstance(density_estimator, ZukoFlowType):
            check_estimator_arg(density_estimator.value)
            self._build_neural_net = marginal_nn(model=density_estimator)
        elif isinstance(density_estimator, str):
            check_estimator_arg(density_estimator)
            self._build_neural_net = marginal_nn(
                model=ZukoFlowType(density_estimator.lower())
            )
        elif callable(density_estimator):
            check_estimator_arg(density_estimator)
            self._build_neural_net = density_estimator
        else:
            raise ValueError(
                "density_estimator must be either a DensityEstimator, str, or a "
                "Callable[[Tensor], Any]."
            )

    def get_dataloaders(
        self,
        training_batch_size: int = 200,
        validation_fraction: float = 0.1,
        dataloader_kwargs: Optional[dict] = None,
    ) -> Tuple[data.DataLoader, data.DataLoader]:
        """Return training and validation dataloaders.

        The stored samples are split at random into a training and a validation
        subset. The chosen indices are stored in `self.train_indices` and
        `self.val_indices`.

        Args:
            training_batch_size: Batch size for both loaders. It is capped at the
                size of the respective subset.
            validation_fraction: The fraction of data to use for validation.
            dataloader_kwargs: Additional or updated kwargs passed to both
                dataloaders. Entries override the defaults built here.

        Returns:
            Tuple of (training dataloader, validation dataloader).

        Raises:
            ValueError: If no samples were appended, or if the split leaves the
                training or the validation subset empty.
        """
        x = self.get_samples()
        dataset = data.TensorDataset(x)

        # Get total number of training examples.
        num_examples = x.size(0)
        # Select random train and validation splits from the samples x.
        num_training_examples = int((1 - validation_fraction) * num_examples)
        num_validation_examples = num_examples - num_training_examples

        # Fail early with a readable message. An empty subset would otherwise
        # surface as a `batch_size` error inside the DataLoader or as a division
        # by zero when the epoch losses are averaged.
        if num_training_examples < 1 or num_validation_examples < 1:
            raise ValueError(
                f"Splitting {num_examples} samples with validation_fraction="
                f"{validation_fraction} leaves {num_training_examples} training "
                f"and {num_validation_examples} validation samples, but both "
                "subsets need at least one sample."
            )

        # Separate indices for training and validation
        permuted_indices = torch.randperm(num_examples)
        self.train_indices, self.val_indices = (
            permuted_indices[:num_training_examples],
            permuted_indices[num_training_examples:],
        )

        train_loader_kwargs = {
            "batch_size": min(training_batch_size, num_training_examples),
            "drop_last": True,
            "sampler": SubsetRandomSampler(self.train_indices.tolist()),
        }
        val_loader_kwargs = {
            "batch_size": min(training_batch_size, num_validation_examples),
            "shuffle": False,
            "drop_last": True,
            "sampler": SubsetRandomSampler(self.val_indices.tolist()),
        }
        if dataloader_kwargs is not None:
            train_loader_kwargs = dict(train_loader_kwargs, **dataloader_kwargs)
            val_loader_kwargs = dict(val_loader_kwargs, **dataloader_kwargs)

        train_loader = data.DataLoader(dataset, **train_loader_kwargs)
        val_loader = data.DataLoader(dataset, **val_loader_kwargs)

        return train_loader, val_loader

    def append_samples(self, x: Tensor) -> "MarginalTrainer":
        """Store the samples the density estimator is trained on.

        Any previously stored samples are replaced.

        Args:
            x: Samples used for training.

        Returns:
            The trainer itself, so that calls can be chained.
        """
        self._x = x
        return self

    def get_samples(self) -> Tensor:
        """Return the samples stored via `append_samples`.

        Raises:
            ValueError: If no samples have been stored yet.
        """
        # `getattr` keeps this usable on instances created without `__init__`.
        x = getattr(self, "_x", None)
        if x is None:
            raise ValueError(
                "No samples have been provided. Please call `append_samples` "
                "first."
            )
        return x

    def loss(self, x: Tensor) -> Tensor:
        """Return loss. The loss is the negative log prob.

        Args:
            x: Batch of samples to evaluate.

        Returns:
            Negative log prob.

        Raises:
            ValueError: If the neural network has not been built yet.
        """
        if self._neural_net is None:
            raise ValueError(
                "Neural network has not been initialized. Please call `train` first."
            )

        x = reshape_to_batch_event(x, event_shape=self._neural_net.input_shape)
        loss = self._neural_net.loss(x)
        assert_all_finite(loss, "loss")
        return loss

    def train(
        self,
        training_batch_size: int = 200,
        learning_rate: float = 5e-4,
        validation_fraction: float = 0.1,
        stop_after_epochs: int = 20,
        max_num_epochs: int = 2**31 - 1,
        clip_max_norm: Optional[float] = 5.0,
        dataloader_kwargs: Optional[dict] = None,
    ) -> UnconditionalDensityEstimator:
        r"""Return density estimator that approximates the distribution $p(x)$.

        Args:
            training_batch_size: Training batch size.
            learning_rate: Learning rate for Adam optimizer.
            validation_fraction: The fraction of data to use for validation.
            stop_after_epochs: The number of epochs to wait for improvement on the
                validation set before terminating training.
            max_num_epochs: Maximum number of epochs to run. If reached, we stop
                training even when the validation loss is still decreasing.
                Otherwise, we train until validation loss increases (see also
                `stop_after_epochs`).
            clip_max_norm: Value at which to clip the total gradient norm in order
                to prevent exploding gradients. Use None for no clipping.
            dataloader_kwargs: Additional or updated kwargs to be passed to the
                training and validation dataloaders (like, e.g., a collate_fn)

        Returns:
            Density estimator that approximates the distribution $p(x)$.

        Raises:
            ValueError: If no samples were appended before training.
        """
        # fake round setting just for compatibility with NeuralInference
        self._round = 0

        train_loader, val_loader = self.get_dataloaders(
            training_batch_size,
            validation_fraction,
            dataloader_kwargs=dataloader_kwargs,
        )

        if self._neural_net is None:
            # Get x to initialize NN
            x = self.get_samples()

            # Use only training data for building the neural net (z-scoring
            # transforms)
            self._neural_net = self._build_neural_net(
                x[self.train_indices].to("cpu"),
            )

        # Each call starts a fresh optimisation run on the current network.
        self.optimizer = Adam(list(self._neural_net.parameters()), lr=learning_rate)
        self.epoch, self._val_loss = 0, float("Inf")

        self._neural_net.to(self._device)

        # NOTE(review): because of `<=`, the loop can run `max_num_epochs + 1`
        # epochs before the limit stops it. Kept unchanged to preserve the
        # existing behaviour; confirm whether this is intended.
        while self.epoch <= max_num_epochs and not self._converged(
            self.epoch, stop_after_epochs
        ):
            # Train for a single epoch.
            self._neural_net.train()
            train_loss_sum = 0
            epoch_start_time = time.time()
            for batch in train_loader:
                self.optimizer.zero_grad()
                # Get batches on current device.
                x_batch = batch[0].to(self._device)

                train_losses = self.loss(x_batch)
                train_loss = torch.mean(train_losses)
                train_loss_sum += train_losses.sum().item()

                train_loss.backward()
                if clip_max_norm is not None:
                    clip_grad_norm_(
                        self._neural_net.parameters(), max_norm=clip_max_norm
                    )
                self.optimizer.step()

            self.epoch += 1

            # With `drop_last=True` every batch is full, so this product is the
            # number of samples that contributed to the sum.
            train_loss_average = train_loss_sum / (
                len(train_loader) * train_loader.batch_size  # type: ignore
            )
            self._summary["training_loss"].append(train_loss_average)

            # Calculate validation performance.
            self._neural_net.eval()
            val_loss_sum = 0

            with torch.no_grad():
                for batch in val_loader:
                    x_batch = batch[0].to(self._device)
                    # Accumulate the per-sample validation loss (negative
                    # log prob).
                    val_losses = self.loss(x_batch)
                    val_loss_sum += val_losses.sum().item()

            # Take mean over all validation samples.
            self._val_loss = val_loss_sum / (
                len(val_loader) * val_loader.batch_size  # type: ignore
            )
            # Log validation loss for every epoch.
            self._summary["validation_loss"].append(self._val_loss)
            self._summary["epoch_durations_sec"].append(time.time() - epoch_start_time)

            self._maybe_show_progress(self._show_progress_bars, self.epoch)

        # Update summary.
        self._summary["epochs_trained"].append(self.epoch)
        self._summary["best_validation_loss"].append(self._best_val_loss)

        # Update tensorboard and summary dict.
        self._summarize(round_=self._round)

        # Avoid keeping the gradients in the resulting network, which can
        # cause memory leakage when benchmarking.
        self._neural_net.zero_grad(set_to_none=True)

        return self._neural_net

    def _default_tracker(self) -> Tracker:
        """Return default tracker logging to a TensorBoard directory."""
        method = self.__class__.__name__
        # ":" is replaced in the timestamp before it is used as a directory name.
        logdir = Path(
            get_log_root(), method, datetime.now().isoformat().replace(":", "_")
        )
        return TensorBoardTracker(SummaryWriter(logdir))

    def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
        """Return whether the training converged yet and save best model state so far.

        Checks for improvement in validation performance over previous epochs.
        When training is found to have converged, the best model state seen so
        far is loaded back into the network.

        Args:
            epoch: Current epoch in training.
            stop_after_epochs: How many fruitless epochs to let pass before
                stopping.

        Returns:
            Whether the training has stopped improving, i.e. has converged.
        """
        converged = False

        assert self._neural_net is not None
        neural_net = self._neural_net

        # (Re)-start the epoch count with the first epoch or any improvement.
        if epoch == 0 or self._val_loss < self._best_val_loss:
            self._best_val_loss = self._val_loss
            self._epochs_since_last_improvement = 0
            self._best_model_state_dict = deepcopy(neural_net.state_dict())
        else:
            self._epochs_since_last_improvement += 1

        # If no validation improvement over many epochs, stop training.
        if self._epochs_since_last_improvement > stop_after_epochs - 1:
            neural_net.load_state_dict(self._best_model_state_dict)
            converged = True

        return converged

    def _summarize(
        self,
        round_: int,
    ) -> None:
        """Update the tracker with statistics for a given round.

        During training several performance statistics are added to the summary,
        e.g., using `self._summary['key'].append(value)`. This function writes
        these values into the tracker.

        Args:
            round_: index of round

        Scalar tags:
            - epochs_trained:
                number of epochs trained
            - best_validation_loss:
                best validation loss (for each round).
            - validation_loss:
                validation loss for every epoch (for each round).
            - training_loss
                training loss for every epoch (for each round).
            - epoch_durations_sec
                epoch duration for every epoch (for each round)
        """
        # Add most recent training stats to tracker.
        self._tracker.log_metric(
            name="epochs_trained",
            value=self._summary["epochs_trained"][-1],
            step=round_ + 1,
        )

        self._tracker.log_metric(
            name="best_validation_loss",
            value=self._summary["best_validation_loss"][-1],
            step=round_ + 1,
        )

        # Add validation loss for every epoch.
        # Offset with all previous epochs.
        offset = (
            torch.tensor(self._summary["epochs_trained"][:-1], dtype=torch.int)
            .sum()
            .item()
        )
        for i, vlp in enumerate(self._summary["validation_loss"][offset:]):
            self._tracker.log_metric(
                name="validation_loss",
                value=vlp,
                step=int(offset + i),
            )

        for i, tlp in enumerate(self._summary["training_loss"][offset:]):
            self._tracker.log_metric(
                name="training_loss",
                value=tlp,
                step=int(offset + i),
            )

        for i, eds in enumerate(self._summary["epoch_durations_sec"][offset:]):
            self._tracker.log_metric(
                name="epoch_durations_sec",
                value=eds,
                step=int(offset + i),
            )

        self._tracker.flush()

    @staticmethod
    def _maybe_show_progress(show: bool, epoch: int) -> None:
        """Print the number of epochs trained so far if `show` is True.

        Args:
            show: Whether to print the progress line.
            epoch: Number of epochs trained so far.
        """
        if show:
            # end="\r" deletes the print statement when a new one appears.
            # https://stackoverflow.com/questions/3419984/. `\r` in the beginning due
            # to #330.
            print("\r", f"Training neural network. Epochs trained: {epoch}", end="")