# Source code for sbi.inference.trainers.nre.nre_b

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Dict, Optional, Union

import torch
from torch import Tensor
from torch.distributions import Distribution
from torch.utils.tensorboard.writer import SummaryWriter

from sbi.inference.trainers._contracts import LossArgsNRE
from sbi.inference.trainers.nre.nre_base import (
    RatioEstimatorTrainer,
)
from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
from sbi.neural_nets.ratio_estimators import RatioEstimator
from sbi.sbi_types import Tracker
from sbi.utils.sbiutils import del_entries
from sbi.utils.torchutils import assert_all_finite


class NRE_B(RatioEstimatorTrainer):
    r"""Neural Ratio Estimation (NRE-B / SRE) as in Durkan et al. (2020) [1].

    NRE-B is an extension of NRE-A that trains a neural classifier using a
    contrastive (1-out-of-K) loss to estimate the likelihood-to-evidence ratio.
    Instead of binary classification, it contrasts one sample from the joint
    $p(\theta, x)$ against $K-1$ samples from the marginals $p(\theta)p(x)$.
    This multi-class formulation improves training stability compared to NRE-A.

    NRE can be run multi-round without need for correction, but requires running
    potentially expensive posterior sampling in each round.

    [1] *On Contrastive Learning for Likelihood-free Inference*, Durkan et al.,
        ICML 2020, https://arxiv.org/pdf/2002.03712

    Example:
    --------

    ::

        import torch
        from sbi.inference import NRE_B
        from sbi.utils import BoxUniform

        # 1. Setup prior and simulate data
        prior = BoxUniform(low=torch.zeros(3), high=torch.ones(3))
        theta = prior.sample((100,))
        x = theta + torch.randn_like(theta) * 0.1

        # 2. Train ratio estimator with contrastive loss
        inference = NRE_B(prior=prior)
        ratio_estimator = inference.append_simulations(theta, x).train(num_atoms=10)

        # 3. Build posterior
        posterior = inference.build_posterior(ratio_estimator)

        # 4. Sample from posterior
        x_o = torch.randn(1, 3)
        samples = posterior.sample((1000,), x=x_o)
    """

    def __init__(
        self,
        prior: Optional[Distribution] = None,
        classifier: Union[str, ConditionalEstimatorBuilder[RatioEstimator]] = "resnet",
        device: str = "cpu",
        logging_level: Union[int, str] = "warning",
        summary_writer: Optional[SummaryWriter] = None,
        tracker: Optional[Tracker] = None,
        show_progress_bars: bool = True,
    ):
        r"""Initialize NRE_B.

        Args:
            prior: A probability distribution that expresses prior knowledge about
                the parameters, e.g. which ranges are meaningful for them. If
                `None`, the prior must be passed to `.build_posterior()`.
            classifier: Classifier trained to approximate likelihood ratios. If it
                is a string, use a pre-configured network of the provided type (one
                of linear, mlp, resnet), or a callable that implements the
                `ConditionalEstimatorBuilder` protocol. The callable will be called
                with the first batch of simulations (theta, x), which can thus be
                used for shape inference and potentially for z-scoring. It returns
                a `RatioEstimator`.
            device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
            logging_level: Minimum severity of messages to log. One of the strings
                INFO, WARNING, DEBUG, ERROR and CRITICAL.
            summary_writer: Deprecated alias for the TensorBoard summary writer.
                Use ``tracker`` instead.
            tracker: Tracking adapter used to log training metrics. If None, a
                TensorBoard tracker is used with a default log directory.
            show_progress_bars: Whether to show a progressbar during simulation
                and sampling.
        """
        # `locals()` must be captured before any other local name is bound, so
        # that `kwargs` holds exactly the constructor arguments. `__class__` is
        # present in the local namespace because zero-argument `super()` is used.
        kwargs = del_entries(locals(), entries=("self", "__class__"))
        super().__init__(**kwargs)

    def train(
        self,
        num_atoms: int = 10,
        training_batch_size: int = 200,
        learning_rate: float = 5e-4,
        validation_fraction: float = 0.1,
        stop_after_epochs: int = 20,
        max_num_epochs: int = 2**31 - 1,
        clip_max_norm: Optional[float] = 5.0,
        resume_training: bool = False,
        discard_prior_samples: bool = False,
        retrain_from_scratch: bool = False,
        show_train_summary: bool = False,
        dataloader_kwargs: Optional[Dict] = None,
    ) -> RatioEstimator:
        r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.

        Args:
            num_atoms: Number of atoms to use for classification.
            training_batch_size: Training batch size.
            learning_rate: Learning rate for Adam optimizer.
            validation_fraction: The fraction of data to use for validation.
            stop_after_epochs: The number of epochs to wait for improvement on the
                validation set before terminating training.
            max_num_epochs: Maximum number of epochs to run. If reached, we stop
                training even when the validation loss is still decreasing.
                Otherwise, we train until validation loss increases (see also
                `stop_after_epochs`).
            clip_max_norm: Value at which to clip the total gradient norm in order
                to prevent exploding gradients. Use None for no clipping.
            resume_training: Can be used in case training time is limited, e.g. on
                a cluster. If `True`, the split between train and validation set,
                the optimizer, the number of epochs, and the best validation
                log-prob will be restored from the last time `.train()` was
                called.
            discard_prior_samples: Whether to discard samples simulated in round 1,
                i.e. from the prior. Training may be sped up by ignoring such less
                targeted samples.
            retrain_from_scratch: Whether to retrain the conditional density
                estimator for the posterior from scratch each round.
            show_train_summary: Whether to print the number of epochs and
                validation loss and leakage after the training.
            dataloader_kwargs: Additional or updated kwargs to be passed to the
                training and validation dataloaders (like, e.g., a collate_fn)

        Returns:
            Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
        """
        # Capture the call arguments first; any local bound before this line would
        # otherwise be forwarded to the parent `train()` as an unexpected keyword.
        kwargs = del_entries(locals(), entries=("self", "__class__"))
        # The parent trainer receives loss-specific options bundled as
        # `loss_kwargs` rather than as a top-level `num_atoms` argument.
        kwargs["loss_kwargs"] = LossArgsNRE(num_atoms=kwargs.pop("num_atoms"))
        return super().train(**kwargs)

    def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
        r"""Return cross-entropy (via softmax activation) loss for
        1-out-of-`num_atoms` classification.

        The classifier takes as input `num_atoms` $(\theta,x)$ pairs. Out of these
        pairs, one pair was sampled from the joint $p(\theta,x)$ and all others
        from the marginals $p(\theta)p(x)$. The classifier is trained to predict
        which of the pairs was sampled from the joint $p(\theta,x)$.

        Args:
            theta: Batch of parameters; the leading dimension is the batch size.
            x: Batch of simulation outputs with the same leading dimension as
                `theta`.
            num_atoms: Number of $(\theta,x)$ pairs contrasted per datapoint.

        Returns:
            Scalar tensor holding the negative log-probability assigned to the
            joint pair, averaged over the batch.
        """
        assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
        batch_size = theta.shape[0]
        logits = self._classifier_logits(theta, x, num_atoms)

        # For 1-out-of-`num_atoms` classification each datapoint consists
        # of `num_atoms` points, with one of them being the correct one.
        # We have a batch of `batch_size` such datapoints.
        logits = logits.reshape(batch_size, num_atoms)

        # Index 0 is the theta-x-pair sampled from the joint p(theta,x) and hence the
        # "correct" one for the 1-out-of-N classification.
        # Subtracting the logsumexp over atoms yields the log-softmax of entry 0.
        log_prob = logits[:, 0] - torch.logsumexp(logits, dim=-1)
        loss = -torch.mean(log_prob)
        assert_all_finite(loss, "NRE-B loss")
        return loss