# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import Dict, Optional, Union
import torch
from torch import Tensor, nn, ones
from torch.distributions import Distribution
from torch.utils.tensorboard.writer import SummaryWriter
from sbi.inference.trainers._contracts import LossArgsNRE_A
from sbi.inference.trainers.nre.nre_base import (
RatioEstimatorTrainer,
)
from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
from sbi.neural_nets.ratio_estimators import RatioEstimator
from sbi.sbi_types import Tracker
from sbi.utils.sbiutils import del_entries
from sbi.utils.torchutils import assert_all_finite
class NRE_A(RatioEstimatorTrainer):
    r"""Neural Ratio Estimation (NRE-A / AALR) as in Hermans et al. (2020) [1].

    NRE-A trains a neural classifier to estimate the likelihood-to-evidence ratio
    $r(\theta, x) = p(x|\theta) / p(x)$ by distinguishing between samples from the
    joint distribution $p(\theta, x)$ and samples from the marginals $p(\theta)p(x)$.
    Posterior sampling is then performed via MCMC, rejection sampling, or variational
    inference using the estimated ratio.

    NRE can be run multi-round without need for correction, but requires running
    potentially expensive posterior sampling in each round.

    [1] *Likelihood-free MCMC with Amortized Approximate Likelihood Ratios*, Hermans
        et al., ICML 2020, https://arxiv.org/abs/1903.04057

    Example:
    --------

    ::

        import torch
        from sbi.inference import NRE_A
        from sbi.utils import BoxUniform

        # 1. Setup prior and simulate data
        prior = BoxUniform(low=torch.zeros(3), high=torch.ones(3))
        theta = prior.sample((100,))
        x = theta + torch.randn_like(theta) * 0.1

        # 2. Train ratio estimator
        inference = NRE_A(prior=prior)
        ratio_estimator = inference.append_simulations(theta, x).train()

        # 3. Build posterior (uses MCMC or rejection sampling)
        posterior = inference.build_posterior(ratio_estimator)

        # 4. Sample from posterior
        x_o = torch.randn(1, 3)
        samples = posterior.sample((1000,), x=x_o)
    """

    def __init__(
        self,
        prior: Optional[Distribution] = None,
        classifier: Union[str, ConditionalEstimatorBuilder[RatioEstimator]] = "resnet",
        device: str = "cpu",
        logging_level: Union[int, str] = "warning",
        summary_writer: Optional[SummaryWriter] = None,
        tracker: Optional[Tracker] = None,
        show_progress_bars: bool = True,
    ):
        r"""Initialize NRE_A.

        Args:
            prior: A probability distribution that expresses prior knowledge about the
                parameters, e.g. which ranges are meaningful for them. If `None`, the
                prior must be passed to `.build_posterior()`.
            classifier: Classifier trained to approximate likelihood ratios. Either a
                string naming a pre-configured network type (one of linear, mlp,
                resnet), or a callable that implements the
                `ConditionalEstimatorBuilder` protocol. The callable is called with
                the first batch of simulations (theta, x), which can thus be used for
                shape inference and potentially for z-scoring, and it must return a
                `RatioEstimator`.
            device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
            logging_level: Minimum severity of messages to log. One of the strings
                INFO, WARNING, DEBUG, ERROR and CRITICAL.
            summary_writer: Deprecated alias for the TensorBoard summary writer.
                Use ``tracker`` instead.
            tracker: Tracking adapter used to log training metrics. If None, a
                TensorBoard tracker is used with a default log directory.
            show_progress_bars: Whether to show a progressbar during simulation and
                sampling.
        """
        # `locals()` must be captured before any other local is created so that
        # exactly the constructor arguments are forwarded to the base class.
        kwargs = del_entries(locals(), entries=("self", "__class__"))
        super().__init__(**kwargs)

    def train(
        self,
        training_batch_size: int = 200,
        learning_rate: float = 5e-4,
        validation_fraction: float = 0.1,
        stop_after_epochs: int = 20,
        max_num_epochs: int = 2**31 - 1,
        clip_max_norm: Optional[float] = 5.0,
        resume_training: bool = False,
        discard_prior_samples: bool = False,
        retrain_from_scratch: bool = False,
        show_train_summary: bool = False,
        dataloader_kwargs: Optional[Dict] = None,
        loss_kwargs: Optional[LossArgsNRE_A] = None,
    ) -> RatioEstimator:
        r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.

        Args:
            training_batch_size: Training batch size.
            learning_rate: Learning rate for Adam optimizer.
            validation_fraction: The fraction of data to use for validation.
            stop_after_epochs: The number of epochs to wait for improvement on the
                validation set before terminating training.
            max_num_epochs: Maximum number of epochs to run. If reached, we stop
                training even when the validation loss is still decreasing. Otherwise,
                we train until validation loss increases (see also `stop_after_epochs`).
            clip_max_norm: Value at which to clip the total gradient norm in order to
                prevent exploding gradients. Use None for no clipping.
            resume_training: Can be used in case training time is limited, e.g. on a
                cluster. If `True`, the split between train and validation set, the
                optimizer, the number of epochs, and the best validation log-prob will
                be restored from the last time `.train()` was called.
            discard_prior_samples: Whether to discard samples simulated in round 1, i.e.
                from the prior. Training may be sped up by ignoring such less targeted
                samples.
            retrain_from_scratch: Whether to retrain the ratio-estimating classifier
                from scratch each round.
            show_train_summary: Whether to print a training summary (e.g. the number
                of epochs and the validation loss) after the training.
            dataloader_kwargs: Additional or updated kwargs to be passed to the training
                and validation dataloaders (like, e.g., a collate_fn).
            loss_kwargs: Additional or updated kwargs to be passed to the `self._loss`
                fn. Must be a `LossArgsNRE_A` (or subclass) instance; if `None`, a
                default `LossArgsNRE_A()` is used.

        Returns:
            Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.

        Raises:
            TypeError: If `loss_kwargs` is neither `None` nor a `LossArgsNRE_A`
                instance.
        """
        # `locals()` must be captured first so that only the call arguments (and no
        # helper locals) are forwarded to the base-class `train`.
        kwargs = del_entries(locals(), entries=("self", "__class__"))
        if loss_kwargs is None:
            kwargs["loss_kwargs"] = LossArgsNRE_A()
        elif not isinstance(loss_kwargs, LossArgsNRE_A):
            raise TypeError(
                "Expected loss_kwargs to be a subclass of LossArgsNRE_A,"
                f" but got {type(loss_kwargs)}"
            )
        return super().train(**kwargs)

    def _loss(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
        r"""Returns the binary cross-entropy loss for the trained classifier.

        The classifier takes as input a $(\theta,x)$ pair. It is trained to predict 1
        if the pair was sampled from the joint $p(\theta,x)$, and to predict 0 if the
        pair was sampled from the marginals $p(\theta)p(x)$.

        Args:
            theta: Batch of parameters; the first dimension is the batch size.
            x: Batch of simulation outputs paired with `theta`; must have the same
                batch size as `theta`.
            num_atoms: Number of atoms forwarded to `self._classifier_logits`. The
                labels built here assume the classifier returns `2 * batch_size`
                logits, i.e. one joint and one marginal pair per sample (two atoms).

        Returns:
            Scalar (mean-reduced) binary cross-entropy loss.
        """
        assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
        batch_size = theta.shape[0]

        # Flatten (e.g. drop a trailing singleton dimension) so that the logits line
        # up one-to-one with the 1-D label vector below.
        logits = torch.flatten(self._classifier_logits(theta, x, num_atoms))

        # Alternating pairs where there is one sampled from the joint and one
        # sampled from the marginals. The first element is sampled from the
        # joint p(theta, x) and is labelled 1. The second element is sampled
        # from the marginals p(theta)p(x) and is labelled 0. And so on.
        # Labels share the logits' dtype/device so the loss never sees a mismatch.
        labels = ones(2 * batch_size, device=logits.device, dtype=logits.dtype)
        labels[1::2] = 0.0

        # Binary cross entropy to learn the likelihood (AALR-specific). It is
        # computed directly on the logits: sigmoid followed by BCELoss saturates to
        # exactly 0/1 for large |logits|, where BCELoss clamps the log and the
        # gradient vanishes, whereas BCEWithLogitsLoss stays numerically stable.
        loss = nn.BCEWithLogitsLoss()(logits, labels)
        assert_all_finite(loss, "NRE-A loss")
        return loss