# Source code for sbi.inference.trainers.nre.bnre

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Dict, Optional, Sequence, Union

import torch
from torch import Tensor, nn, ones
from torch.distributions import Distribution
from torch.utils.tensorboard.writer import SummaryWriter

from sbi.inference.trainers._contracts import LossArgs, LossArgsBNRE
from sbi.inference.trainers.nre.nre_a import NRE_A
from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
from sbi.neural_nets.ratio_estimators import RatioEstimator
from sbi.sbi_types import Tracker
from sbi.utils.sbiutils import del_entries
from sbi.utils.torchutils import assert_all_finite


class BNRE(NRE_A):
    r"""Balanced Neural Ratio Estimation (BNRE) as in Delaunoy et al. (2022) [1].

    BNRE is a variation of NRE-A that adds a balancing regularizer to the binary
    cross-entropy loss. The regularizer pushes the classifier's average predicted
    probability on joint samples plus its average predicted probability on marginal
    samples towards one. This can lead to more conservative and reliable posterior
    approximations. BNRE is particularly useful when robustness is prioritized over
    tightness of the posterior.

    NRE can be run multi-round without need for correction, but requires running
    potentially expensive posterior sampling in each round.

    [1] Towards Reliable Simulation-Based Inference with Balanced Neural Ratio
        Estimation, Delaunoy, A., Hermans, J., Rozet, F., Wehenkel, A., & Louppe, G.,
        NeurIPS 2022. https://arxiv.org/abs/2208.13624

    Example:
    --------

    ::

        import torch
        from sbi.inference import BNRE
        from sbi.utils import BoxUniform

        # 1. Setup prior and simulate data
        prior = BoxUniform(low=torch.zeros(3), high=torch.ones(3))
        theta = prior.sample((100,))
        x = theta + torch.randn_like(theta) * 0.1

        # 2. Train balanced ratio estimator
        inference = BNRE(prior=prior)
        # Note: regularization_strength needs to be tuned carefully for your problem
        ratio_estimator = inference.append_simulations(theta, x).train(
            regularization_strength=100.0
        )

        # 3. Build posterior
        posterior = inference.build_posterior(ratio_estimator)

        # 4. Sample from posterior
        x_o = torch.randn(1, 3)
        samples = posterior.sample((1000,), x=x_o)
    """

    def __init__(
        self,
        prior: Optional[Distribution] = None,
        classifier: Union[str, ConditionalEstimatorBuilder[RatioEstimator]] = "resnet",
        device: str = "cpu",
        logging_level: Union[int, str] = "warning",
        summary_writer: Optional[SummaryWriter] = None,
        tracker: Optional[Tracker] = None,
        show_progress_bars: bool = True,
    ):
        r"""Balanced neural ratio estimation (BNRE).

        Args:
            prior: A probability distribution that expresses prior knowledge about the
                parameters, e.g. which ranges are meaningful for them. If `None`, the
                prior must be passed to `.build_posterior()`.
            classifier: Classifier trained to approximate likelihood ratios. If it is
                a string, use a pre-configured network of the provided type (one of
                linear, mlp, resnet), or a callable that implements the
                `ConditionalEstimatorBuilder` protocol. The callable will be called
                with the first batch of simulations (theta, x), which can thus be
                used for shape inference and potentially for z-scoring. It returns a
                `RatioEstimator`.
            device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
            logging_level: Minimum severity of messages to log. One of the strings
                INFO, WARNING, DEBUG, ERROR and CRITICAL.
            summary_writer: Deprecated alias for the TensorBoard summary writer.
                Use ``tracker`` instead.
            tracker: Tracking adapter used to log training metrics. If None, a
                TensorBoard tracker is used with a default log directory.
            show_progress_bars: Whether to show a progressbar during simulation and
                sampling.
        """
        # Must be the first statement so that only the constructor arguments are
        # captured; `__class__` appears in locals() because `super()` is used.
        kwargs = del_entries(locals(), entries=("self", "__class__"))
        super().__init__(**kwargs)

    def train(
        self,
        regularization_strength: float = 100.0,
        training_batch_size: int = 200,
        learning_rate: float = 5e-4,
        validation_fraction: float = 0.1,
        stop_after_epochs: int = 20,
        max_num_epochs: int = 2**31 - 1,
        clip_max_norm: Optional[float] = 5.0,
        resume_training: bool = False,
        discard_prior_samples: bool = False,
        retrain_from_scratch: bool = False,
        show_train_summary: bool = False,
        dataloader_kwargs: Optional[Dict] = None,
    ) -> RatioEstimator:
        r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.

        Args:
            regularization_strength: The multiplicative coefficient applied to the
                balancing regularizer ($\lambda$). A value of 0 removes the
                regularizer from the loss.
            training_batch_size: Training batch size.
            learning_rate: Learning rate for Adam optimizer.
            validation_fraction: The fraction of data to use for validation.
            stop_after_epochs: The number of epochs to wait for improvement on the
                validation set before terminating training.
            max_num_epochs: Maximum number of epochs to run. If reached, we stop
                training even when the validation loss is still decreasing.
                Otherwise, we train until validation loss increases (see also
                `stop_after_epochs`).
            clip_max_norm: Value at which to clip the total gradient norm in order to
                prevent exploding gradients. Use None for no clipping.
            resume_training: Can be used in case training time is limited, e.g. on a
                cluster. If `True`, the split between train and validation set, the
                optimizer, the number of epochs, and the best validation log-prob will
                be restored from the last time `.train()` was called.
            discard_prior_samples: Whether to discard samples simulated in round 1,
                i.e. from the prior. Training may be sped up by ignoring such less
                targeted samples.
            retrain_from_scratch: Whether to retrain the conditional density
                estimator for the posterior from scratch each round.
            show_train_summary: Whether to print the number of epochs and validation
                loss and leakage after the training.
            dataloader_kwargs: Additional or updated kwargs to be passed to the
                training and validation dataloaders (like, e.g., a collate_fn).

        Returns:
            Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
        """
        # Must be the first statement so that only the method arguments are captured.
        kwargs = del_entries(locals(), entries=("self", "__class__"))

        # The regularization strength is not a parent-trainer argument: move it into
        # the BNRE-specific loss arguments that are forwarded to `_loss`.
        kwargs["loss_kwargs"] = LossArgsBNRE(
            regularization_strength=kwargs.pop("regularization_strength"),
        )
        return super().train(**kwargs)

    def _loss(
        self, theta: Tensor, x: Tensor, num_atoms: int, regularization_strength: float
    ) -> Tensor:
        r"""Return the BNRE loss: binary cross-entropy plus a balancing regularizer.

        The classifier takes as input a $(\theta,x)$ pair. It is trained to predict 1
        if the pair was sampled from the joint $p(\theta,x)$, and to predict 0 if the
        pair was sampled from the marginals $p(\theta)p(x)$.

        The balancing term is
        $\lambda\,(\bar{d}_{joint} + \bar{d}_{marginal} - 1)^2$, where
        $\bar{d}$ denotes the classifier's mean predicted probability on the
        respective samples.

        Args:
            theta: Batch of parameters.
            x: Batch of simulation outputs; must have the same batch size as `theta`.
            num_atoms: Number of atoms forwarded to `_classifier_logits`. The label
                layout below assumes two atoms (one joint, one marginal per pair).
            regularization_strength: Coefficient $\lambda$ of the balancing term.

        Returns:
            Scalar loss tensor.
        """
        assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
        batch_size = theta.shape[0]

        logits = self._classifier_logits(theta, x, num_atoms)
        # One logit per (theta, x) pair, independent of any trailing singleton dims.
        flat_logits = logits.reshape(-1)

        # Alternating pairs where there is one sampled from the joint and one
        # sampled from the marginals. The first element is sampled from the
        # joint p(theta, x) and is labelled 1. The second element is sampled
        # from the marginals p(theta)p(x) and is labelled 0. And so on.
        # `new_ones` matches the logits' dtype and device.
        labels = flat_logits.new_ones(2 * batch_size)  # two atoms
        labels[1::2] = 0.0

        # Binary cross entropy to learn the likelihood (AALR-specific). Computed on
        # the raw logits: unlike sigmoid followed by BCELoss, this stays accurate
        # (and keeps gradients) when the logits saturate.
        bce = nn.BCEWithLogitsLoss()(flat_logits, labels)

        # Balancing regularizer: the mean over pairs of (d_joint + d_marginal - 1)
        # equals mean(d_joint) + mean(d_marginal) - 1, which is then squared.
        probabilities = torch.sigmoid(flat_logits)
        regularizer = (
            (probabilities[0::2] + probabilities[1::2] - 1).mean().square()
        )

        loss = bce + regularization_strength * regularizer
        assert_all_finite(loss, "BNRE loss")
        return loss

    def _get_losses(self, batch: Sequence[Tensor], loss_args: LossArgs) -> Tensor:
        """Check that `loss_args` is `LossArgsBNRE`, then defer to the parent.

        Raises:
            TypeError: If `loss_args` is not an instance of `LossArgsBNRE`.
        """
        if not isinstance(loss_args, LossArgsBNRE):
            raise TypeError(
                "Expected type of loss_args to be LossArgsBNRE,"
                f" but got {type(loss_args)}"
            )
        return super()._get_losses(batch=batch, loss_args=loss_args)