# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import Dict, Optional, Sequence, Tuple, Union
import torch
from torch import Tensor
from torch.distributions import Distribution
from torch.utils.tensorboard.writer import SummaryWriter
from sbi.inference.trainers._contracts import LossArgs, LossArgsNRE_C
from sbi.inference.trainers.nre.nre_base import (
RatioEstimatorTrainer,
)
from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
from sbi.neural_nets.ratio_estimators import RatioEstimator
from sbi.sbi_types import Tracker
from sbi.utils.sbiutils import del_entries
from sbi.utils.torchutils import assert_all_finite
[docs]
class NRE_C(RatioEstimatorTrainer):
r"""Neural Ratio Estimation (NRE-C / CNRE) as in Miller et al. (2022) [1].
NRE-C generalizes NRE-A and NRE-B using a "multi-class sigmoid" loss that
ensures the estimated ratio $p(\theta,x)/p(\theta)p(x)$ is exact at optimum
in the first round. This addresses the issue that NRE-B's ratio is only defined
up to an arbitrary function of $x$. NRE-C provides more accurate ratio estimates
while maintaining the benefits of contrastive learning.
[1] *Contrastive Neural Ratio Estimation*, Benjamin Kurt Miller, et al.,
NeurIPS 2022, https://arxiv.org/abs/2210.06170
Example:
--------
::
import torch
from sbi.inference import NRE_C
from sbi.utils import BoxUniform
# 1. Setup prior and simulate data
prior = BoxUniform(low=torch.zeros(3), high=torch.ones(3))
theta = prior.sample((100,))
x = theta + torch.randn_like(theta) * 0.1
# 2. Train ratio estimator
inference = NRE_C(prior=prior)
ratio_estimator = inference.append_simulations(theta, x).train(num_classes=5)
# 3. Build posterior
posterior = inference.build_posterior(ratio_estimator)
# 4. Sample from posterior
x_o = torch.randn(1, 3)
samples = posterior.sample((1000,), x=x_o)
"""
def __init__(
self,
prior: Optional[Distribution] = None,
classifier: Union[str, ConditionalEstimatorBuilder[RatioEstimator]] = "resnet",
device: str = "cpu",
logging_level: Union[int, str] = "warning",
summary_writer: Optional[SummaryWriter] = None,
tracker: Optional[Tracker] = None,
show_progress_bars: bool = True,
):
r"""Initialize NRE-C.
Args:
prior: A probability distribution that expresses prior knowledge about the
parameters, e.g. which ranges are meaningful for them. If `None`, the
prior must be passed to `.build_posterior()`.
classifier: Classifier trained to approximate likelihood ratios. If it is
a string, use a pre-configured network of the provided type (one of
linear, mlp, resnet), or a callable that implements the
`ConditionalEstimatorBuilder` protocol. The callable will
be called with the first batch of simulations (theta, x), which can thus
be used for shape inference and potentially for z-scoring. It returns a
`RatioEstimator`.
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
logging_level: Minimum severity of messages to log. One of the strings
INFO, WARNING, DEBUG, ERROR and CRITICAL.
summary_writer: Deprecated alias for the TensorBoard summary writer.
Use ``tracker`` instead.
tracker: Tracking adapter used to log training metrics. If None, a
TensorBoard tracker is used with a default log directory.
show_progress_bars: Whether to show a progressbar during simulation and
sampling.
"""
kwargs = del_entries(locals(), entries=("self", "__class__"))
super().__init__(**kwargs)
[docs]
def train(
self,
num_classes: int = 5,
gamma: float = 1.0,
training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
max_num_epochs: int = 2**31 - 1,
clip_max_norm: Optional[float] = 5.0,
resume_training: bool = False,
discard_prior_samples: bool = False,
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
) -> RatioEstimator:
r"""Return classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
Args:
num_classes: Number of theta to classify against, corresponds to $K$ in
_Contrastive Neural Ratio Estimation_. Minimum value is 1. Similar to
`num_atoms` for SNRE_B except SNRE_C has an additional independently
drawn sample. The total number of alternative parameters `NRE-C` "sees"
is $2K-1$ or `2 * num_classes - 1` divided between two loss terms.
gamma: Determines the relative weight of the sum of all $K$ dependently
drawn classes against the marginally drawn one. Specifically,
$p(y=k) :=p_K$, $p(y=0) := p_0$, $p_0 = 1 - K p_K$, and finally
$\gamma := K p_K / p_0$.
training_batch_size: Training batch size.
learning_rate: Learning rate for Adam optimizer.
validation_fraction: The fraction of data to use for validation.
stop_after_epochs: The number of epochs to wait for improvement on the
validation set before terminating training.
max_num_epochs: Maximum number of epochs to run. If reached, we stop
training even when the validation loss is still decreasing. Otherwise,
we train until validation loss increases (see also `stop_after_epochs`).
clip_max_norm: Value at which to clip the total gradient norm in order to
prevent exploding gradients. Use None for no clipping.
exclude_invalid_x: Whether to exclude simulation outputs `x=NaN` or `x=±∞`
during training. Expect errors, silent or explicit, when `False`.
resume_training: Can be used in case training time is limited, e.g. on a
cluster. If `True`, the split between train and validation set, the
optimizer, the number of epochs, and the best validation log-prob will
be restored from the last time `.train()` was called.
discard_prior_samples: Whether to discard samples simulated in round 1, i.e.
from the prior. Training may be sped up by ignoring such less targeted
samples.
retrain_from_scratch: Whether to retrain the conditional density
estimator for the posterior from scratch each round.
show_train_summary: Whether to print the number of epochs and validation
loss and leakage after the training.
dataloader_kwargs: Additional or updated kwargs to be passed to the training
and validation dataloaders (like, e.g., a collate_fn)
Returns:
Classifier that approximates the ratio $p(\theta,x)/p(\theta)p(x)$.
"""
kwargs = del_entries(locals(), entries=("self", "__class__"))
kwargs["loss_kwargs"] = LossArgsNRE_C(
num_atoms=kwargs.pop("num_classes") + 1, gamma=kwargs.pop("gamma")
)
return super().train(**kwargs)
def _loss(
self, theta: Tensor, x: Tensor, num_atoms: int, gamma: float
) -> torch.Tensor:
r"""Return cross-entropy loss (via ''multi-class sigmoid'' activation) for
1-out-of-`K + 1` classification.
At optimum, this loss function returns the exact likelihood-to-evidence ratio
in the first round.
Details of loss computation are described in Contrastive Neural Ratio
Estimation[1]. The paper does not discuss the sequential case.
[1] _Contrastive Neural Ratio Estimation_, Benajmin Kurt Miller, et. al.,
NeurIPS 2022, https://arxiv.org/abs/2210.06170
"""
# Reminder: K = num_classes
# The algorithm is written with K, so we convert back to K format rather than
# reasoning in num_atoms.
num_classes = num_atoms - 1
assert num_classes >= 1, f"num_classes = {num_classes} must be greater than 1."
assert theta.shape[0] == x.shape[0], "Batch sizes for theta and x must match."
batch_size = theta.shape[0]
# We append a contrastive theta to the marginal case because we will remove
# the jointly drawn
# sample in the logits_marginal[:, 0] position. That makes the remaining sample
# marginally drawn.
# We have a batch of `batch_size` datapoints.
logits_marginal = self._classifier_logits(theta, x, num_classes + 1).reshape(
batch_size, num_classes + 1
)
logits_joint = self._classifier_logits(theta, x, num_classes).reshape(
batch_size, num_classes
)
dtype = logits_marginal.dtype
device = logits_marginal.device
# Index 0 is the theta-x-pair sampled from the joint p(theta,x) and hence
# we remove the jointly drawn sample from the logits_marginal
logits_marginal = logits_marginal[:, 1:]
# ... and retain it in the logits_joint. Now we have two arrays with K choices.
# To use logsumexp, we extend the denominator logits with loggamma
loggamma = torch.tensor(gamma, dtype=dtype, device=device).log()
logK = torch.tensor(num_classes, dtype=dtype, device=device).log()
denominator_marginal = torch.concat(
[loggamma + logits_marginal, logK.expand((batch_size, 1))],
dim=-1,
)
denominator_joint = torch.concat(
[loggamma + logits_joint, logK.expand((batch_size, 1))],
dim=-1,
)
# Compute the contributions to the loss from each term in the classification.
log_prob_marginal = logK - torch.logsumexp(denominator_marginal, dim=-1)
log_prob_joint = (
loggamma + logits_joint[:, 0] - torch.logsumexp(denominator_joint, dim=-1)
)
# relative weights. p_marginal := p_0, and p_joint := p_K * K from the notation.
p_marginal, p_joint = self._get_prior_probs_marginal_and_joint(gamma)
loss = -torch.mean(p_marginal * log_prob_marginal + p_joint * log_prob_joint)
assert_all_finite(loss, "NRE-C loss")
return loss
@staticmethod
def _get_prior_probs_marginal_and_joint(gamma: float) -> Tuple[float, float]:
r"""Return a tuple (p_marginal, p_joint) where `p_marginal := `$p_0$,
`p_joint := `$p_K \cdot K$.
We let the joint (dependently drawn) class to be equally likely across K
options. The marginal class is therefore restricted to get the remaining
probability.
"""
p_joint = gamma / (1 + gamma)
p_marginal = 1 / (1 + gamma)
return p_marginal, p_joint
def _get_losses(self, batch: Sequence[Tensor], loss_args: LossArgs) -> Tensor:
"""Overrides the parent class method to check the type of loss_args."""
if not isinstance(loss_args, LossArgsNRE_C):
raise TypeError(
"Expected type of loss_args to be LossArgsNRE_C,"
f" but got {type(loss_args)}"
)
return super()._get_losses(batch=batch, loss_args=loss_args)