# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import Any, Literal, Optional, Union
import torch
from torch import Tensor
from torch.distributions import Distribution
from torch.utils.tensorboard.writer import SummaryWriter
import sbi.utils as utils
from sbi.inference.trainers.npe.npe_base import (
PosteriorEstimatorTrainer,
)
from sbi.neural_nets.estimators.base import (
ConditionalDensityEstimator,
ConditionalEstimatorBuilder,
)
from sbi.neural_nets.estimators.shape_handling import reshape_to_sample_batch_event
from sbi.sbi_types import Tracker
from sbi.utils.sbiutils import del_entries
[docs]
class NPE_B(PosteriorEstimatorTrainer):
r"""Neural Posterior Estimation algorithm (NPE-B) as in Lueckmann et al. (2017) [1].
NPE-B (also known as SNPE-B) trains a neural network to directly approximate
the posterior $p(\theta|x)$ using an importance-weighted loss. Unlike NPE-A,
this importance weighting ensures convergence to the true posterior in multi-round
inference, and it is not limited to Gaussian proposals. NPE-B can use flexible
density estimators like normalizing flows.
For single-round inference, NPE-A, NPE-B, and NPE-C are equivalent and use
plain NLL loss.
[1] *Flexible statistical inference for mechanistic models of neural
dynamics*, Lueckmann, Gonçalves et al., NeurIPS 2017.
https://arxiv.org/abs/1711.01861
Example:
--------
::
import torch
from sbi.inference import NPE_B
from sbi.utils import BoxUniform
# 1. Setup simulator, prior, and observation
prior = BoxUniform(low=torch.zeros(3), high=torch.ones(3))
x_o = torch.randn(1, 3) # Observed data
def simulator(theta):
return theta + torch.randn_like(theta) * 0.1
# 2. Multi-round inference
inference = NPE_B(prior=prior)
proposal = prior
for round_idx in range(5):
theta = proposal.sample((100,))
x = simulator(theta)
density_estimator = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior(density_estimator)
proposal = posterior.set_default_x(x_o)
# 3. Sample from final posterior
samples = posterior.sample((1000,), x=x_o)
"""
def __init__(
self,
prior: Optional[Distribution] = None,
density_estimator: Union[
Literal["nsf", "maf", "mdn", "made"],
ConditionalEstimatorBuilder[ConditionalDensityEstimator],
] = "maf",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[SummaryWriter] = None,
tracker: Optional[Tracker] = None,
show_progress_bars: bool = True,
):
r"""Initialize NPE-B.
Args:
prior: A probability distribution that expresses prior knowledge about the
parameters, e.g. which ranges are meaningful for them.
density_estimator: If it is a string, use a pre-configured network of the
provided type (one of nsf, maf, mdn, made). Alternatively, a function
that builds a custom neural network, which adheres to
`ConditionalEstimatorBuilder` protocol can be provided. The function
will be called with the first batch of simulations (theta, x), which can
thus be used for shape inference and potentially for z-scoring. The
density estimator needs to provide the methods `.log_prob` and
`.sample()` and must return a `ConditionalDensityEstimator`.
device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
logging_level: Minimum severity of messages to log. One of the strings
INFO, WARNING, DEBUG, ERROR and CRITICAL.
summary_writer: Deprecated alias for the TensorBoard summary writer.
Use ``tracker`` instead.
tracker: Tracking adapter used to log training metrics. If None, a
TensorBoard tracker is used with a default log directory.
show_progress_bars: Whether to show a progressbar during training.
"""
kwargs = del_entries(locals(), entries=("self", "__class__"))
super().__init__(**kwargs)
def _log_prob_proposal_posterior(
self,
theta: Tensor,
x: Tensor,
masks: Tensor,
proposal: Optional[Any],
) -> Tensor:
"""
Return importance-weighted log probability (Lueckmann, Goncalves et al., 2017).
Args:
theta: Batch of parameters θ.
x: Batch of data.
masks: Mask that is True for prior samples in the batch in order to train
them with prior loss.
proposal: Proposal distribution.
Returns:
Importance-weighted log-probability of the proposal posterior.
"""
# Evaluate prior
# we accept prior log prob to be -Inf at theta
# meaning that theta is out of the prior range (the weight is thus 0)
utils.assert_not_nan_or_plus_inf(
self._prior.log_prob(theta), "prior log probs of proposal samples"
)
prior = torch.exp(self._prior.log_prob(theta))
# Evaluate proposal
# (as theta comes from prior and proposal from previous rounds,
# the last proposal is actually a mixture of the prior
# and of all the previous proposals with coefficients representing
# the proportion of the new theta added at each round)
prop = torch.zeros(self._round + 1, device=theta.device)
nb_samples = 0 # total number of theta from all the rounds
for k in range(self._round + 1):
nb_samples += self._theta_roundwise[k].size(0)
# the number of new theta sampled in the round k
prop[k] = self._theta_roundwise[k].size(0)
prop /= nb_samples
log_prop = torch.log(prop).repeat(theta.size(0), 1)
log_previous_proposals = torch.zeros(
(theta.size(0), self._round + 1), device=theta.device
)
for k, density in enumerate(self._proposal_roundwise):
# we accept the k th proposal log prob to be -Inf at theta
# meaning that theta is out of the k th proposal range
log_previous_proposals[:, k] = density.log_prob(theta)
utils.assert_not_nan_or_plus_inf(
log_previous_proposals[:, k], "proposal log probs of proposal samples"
)
log_proposal = torch.logsumexp(log_prop + log_previous_proposals, dim=1)
proposal = torch.exp(log_proposal)
# Construct the importance weights and normalize them
importance_weights = prior / proposal
importance_weights /= importance_weights.sum()
theta = reshape_to_sample_batch_event(theta, theta.shape[1:])
# Reshape the density estimator log probs
# from (sample_shape, batch_shape) to (batch_shape)
posterior_log_probs = self._neural_net.log_prob(theta, x).squeeze(dim=0)
return importance_weights * posterior_log_probs