# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import Dict, Optional, Sequence, Union
import torch
from torch import Tensor, nn, ones
from torch.distributions import Distribution
from torch.utils.tensorboard.writer import SummaryWriter
from sbi.inference.trainers._contracts import LossArgs, LossArgsBNRE
from sbi.inference.trainers.nre.nre_a import NRE_A
from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
from sbi.neural_nets.ratio_estimators import RatioEstimator
from sbi.sbi_types import Tracker
from sbi.utils.sbiutils import del_entries
from sbi.utils.torchutils import assert_all_finite
class BNRE(NRE_A):
    r"""Balanced Neural Ratio Estimation (BNRE) as in Delaunoy et al. (2022) [1].

    BNRE is a variation of NRE-A that adds a balancing regularizer to the binary
    cross-entropy loss. This regularizer encourages the classifier to predict equal
    probabilities for joint and marginal samples on average, which can lead to more
    conservative and reliable posterior approximations. BNRE is particularly useful
    when robustness is prioritized over tightness of the posterior.

    NRE can be run multi-round without need for correction, but requires running
    potentially expensive posterior sampling in each round.

    [1] Towards Reliable Simulation-Based Inference with Balanced Neural Ratio
        Estimation, Delaunoy, A., Hermans, J., Rozet, F., Wehenkel, A., & Louppe, G.,
        NeurIPS 2022. https://arxiv.org/abs/2208.13624

    Example:
    --------

    ::

        import torch
        from sbi.inference import BNRE
        from sbi.utils import BoxUniform

        # 1. Setup prior and simulate data
        prior = BoxUniform(low=torch.zeros(3), high=torch.ones(3))
        theta = prior.sample((100,))
        x = theta + torch.randn_like(theta) * 0.1

        # 2. Train balanced ratio estimator
        inference = BNRE(prior=prior)
        # Note: regularization_strength needs to be tuned carefully for your problem
        ratio_estimator = inference.append_simulations(theta, x).train(
            regularization_strength=100.0
        )

        # 3. Build posterior
        posterior = inference.build_posterior(ratio_estimator)

        # 4. Sample from posterior
        x_o = torch.randn(1, 3)
        samples = posterior.sample((1000,), x=x_o)
    """

    def __init__(
        self,
        prior: Optional[Distribution] = None,
        classifier: Union[str, ConditionalEstimatorBuilder[RatioEstimator]] = "resnet",
        device: str = "cpu",
        logging_level: Union[int, str] = "warning",
        summary_writer: Optional[SummaryWriter] = None,
        tracker: Optional[Tracker] = None,
        show_progress_bars: bool = True,
    ):
        r"""Balanced neural ratio estimation (BNRE).

        Args:
            prior: A probability distribution that expresses prior knowledge about the
                parameters, e.g. which ranges are meaningful for them. If `None`, the
                prior must be passed to `.build_posterior()`.
            classifier: Classifier trained to approximate likelihood ratios. If it is
                a string, use a pre-configured network of the provided type (one of
                linear, mlp, resnet), or a callable that implements the
                `ConditionalEstimatorBuilder` protocol. The callable will
                be called with the first batch of simulations (theta, x), which can thus
                be used for shape inference and potentially for z-scoring. It returns a
                `RatioEstimator`.
            device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
            logging_level: Minimum severity of messages to log. One of the strings
                INFO, WARNING, DEBUG, ERROR and CRITICAL.
            summary_writer: Deprecated alias for the TensorBoard summary writer.
                Use ``tracker`` instead.
            tracker: Tracking adapter used to log training metrics. If None, a
                TensorBoard tracker is used with a default log directory.
            show_progress_bars: Whether to show a progressbar during simulation and
                sampling.
        """
        # NOTE: `locals()` must be captured before any other local is bound, so that
        # it holds exactly the constructor arguments forwarded to the parent class.
        kwargs = del_entries(locals(), entries=("self", "__class__"))
        super().__init__(**kwargs)

    def train(
        self,
        regularization_strength: float = 100.0,
        training_batch_size: int = 200,
        learning_rate: float = 5e-4,
        validation_fraction: float = 0.1,
        stop_after_epochs: int = 20,
        max_num_epochs: int = 2**31 - 1,
        clip_max_norm: Optional[float] = 5.0,
        resume_training: bool = False,
        discard_prior_samples: bool = False,
        retrain_from_scratch: bool = False,
        show_train_summary: bool = False,
        dataloader_kwargs: Optional[Dict] = None,
    ) -> RatioEstimator:
        r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.

        Args:
            regularization_strength: The multiplicative coefficient applied to the
                balancing regularizer ($\lambda$).
            training_batch_size: Training batch size.
            learning_rate: Learning rate for Adam optimizer.
            validation_fraction: The fraction of data to use for validation.
            stop_after_epochs: The number of epochs to wait for improvement on the
                validation set before terminating training.
            max_num_epochs: Maximum number of epochs to run. If reached, we stop
                training even when the validation loss is still decreasing. Otherwise,
                we train until validation loss increases (see also `stop_after_epochs`).
            clip_max_norm: Value at which to clip the total gradient norm in order to
                prevent exploding gradients. Use None for no clipping.
            resume_training: Can be used in case training time is limited, e.g. on a
                cluster. If `True`, the split between train and validation set, the
                optimizer, the number of epochs, and the best validation log-prob will
                be restored from the last time `.train()` was called.
            discard_prior_samples: Whether to discard samples simulated in round 1, i.e.
                from the prior. Training may be sped up by ignoring such less targeted
                samples.
            retrain_from_scratch: Whether to retrain the conditional density
                estimator for the posterior from scratch each round.
            show_train_summary: Whether to print the number of epochs and validation
                loss and leakage after the training.
            dataloader_kwargs: Additional or updated kwargs to be passed to the training
                and validation dataloaders (like, e.g., a collate_fn)

        Returns:
            Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
        """
        # NOTE: `locals()` must be captured before any other local is bound, so that
        # it holds exactly the arguments of this call.
        kwargs = del_entries(locals(), entries=("self", "__class__"))

        # The parent `train` does not take `regularization_strength` directly; move
        # it out of the kwargs and hand it over wrapped in a `LossArgsBNRE`.
        kwargs["loss_kwargs"] = LossArgsBNRE(
            regularization_strength=kwargs.pop("regularization_strength"),
        )
        return super().train(**kwargs)

    def _loss(
        self, theta: Tensor, x: Tensor, num_atoms: int, regularization_strength: float
    ) -> Tensor:
        r"""Return the balanced binary cross-entropy loss for the classifier.

        The classifier takes as input a $(\theta,x)$ pair. It is trained to predict 1
        if the pair was sampled from the joint $p(\theta,x)$, and to predict 0 if the
        pair was sampled from the marginals $p(\theta)p(x)$.

        On top of the binary cross-entropy, a balancing regularizer is added: the
        squared batch mean of (prediction on joint pair + prediction on marginal
        pair - 1), scaled by `regularization_strength`.

        Args:
            theta: Batch of parameters; first dimension is the batch dimension.
            x: Batch of simulation outputs; must have the same batch size as `theta`.
            num_atoms: Forwarded to `_classifier_logits`. The labels built here
                assume two atoms per sample, i.e. `2 * batch_size` logits.
            regularization_strength: Multiplicative coefficient of the balancing
                regularizer.

        Returns:
            Scalar loss `bce + regularization_strength * regularizer`.
        """
        assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
        batch_size = theta.shape[0]

        logits = self._classifier_logits(theta, x, num_atoms)
        # Compute the class probabilities once; they feed both loss terms.
        probs = torch.sigmoid(logits)
        likelihood = probs.squeeze()

        # Alternating pairs where there is one sampled from the joint and one
        # sampled from the marginals. The first element is sampled from the
        # joint p(theta, x) and is labelled 1. The second element is sampled
        # from the marginals p(theta)p(x) and is labelled 0. And so on.
        labels = ones(2 * batch_size, device=self._device)  # two atoms
        labels[1::2] = 0.0

        # Binary cross entropy to learn the likelihood (AALR-specific)
        bce = nn.BCELoss()(likelihood, labels)

        # Balancing regularizer: even entries are joint pairs, odd entries are
        # marginal pairs; their predictions should sum to one on average.
        regularizer = (probs[0::2] + probs[1::2] - 1).mean().square()

        loss = bce + regularization_strength * regularizer
        assert_all_finite(loss, "BNRE loss")
        return loss

    def _get_losses(self, batch: Sequence[Tensor], loss_args: LossArgs) -> Tensor:
        """Overrides the parent class method to check the type of loss_args.

        Args:
            batch: Batch of tensors, passed through unchanged to the parent method.
            loss_args: Loss arguments; must be a `LossArgsBNRE` instance.

        Returns:
            Whatever the parent `_get_losses` returns for this batch.

        Raises:
            TypeError: If `loss_args` is not a `LossArgsBNRE` instance.
        """
        if not isinstance(loss_args, LossArgsBNRE):
            raise TypeError(
                "Expected type of loss_args to be LossArgsBNRE,"
                f" but got {type(loss_args)}"
            )

        return super()._get_losses(batch=batch, loss_args=loss_args)