# Source code for sbi.inference.trainers.npe.npe_a

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, Literal, Optional, Union

import torch
from torch import Tensor
from torch.distributions import Distribution, MultivariateNormal
from torch.utils.tensorboard.writer import SummaryWriter

from sbi.inference.posteriors.npe_a_posterior import NPE_A_Posterior
from sbi.inference.trainers.npe.npe_base import (
    PosteriorEstimatorTrainer,
)
from sbi.neural_nets.estimators.base import (
    ConditionalDensityEstimator,
    ConditionalEstimatorBuilder,
)
from sbi.neural_nets.estimators.mixture_density_estimator import (
    MixtureDensityEstimator,
)
from sbi.neural_nets.estimators.mog import MoG
from sbi.sbi_types import Tracker
from sbi.utils.sbiutils import del_entries
from sbi.utils.torchutils import BoxUniform

# Constant for numerical stability in matrix operations.
_CORRECTION_EPSILON: float = 1e-6


class NPE_A(PosteriorEstimatorTrainer):
    r"""Neural Posterior Estimation algorithm as in Papamakarios et al. (2016) [1].

    [1] *Fast epsilon-free Inference of Simulation Models with Bayesian Conditional
    Density Estimation*, Papamakarios et al., NeurIPS 2016.
    https://arxiv.org/abs/1605.06376

    Like all NPE methods, this method trains a deep neural density estimator to
    directly approximate the posterior. Also like all other NPE methods, in the
    first round, this density estimator is trained with a maximum-likelihood loss.

    This class implements NPE-A. NPE-A trains across multiple rounds with a
    maximum-likelihood loss. This will make training converge to the proposal
    posterior instead of the true posterior. To correct for this, SNPE-A applies a
    post-hoc correction after training. This correction is performed analytically
    and requires Mixture of Gaussians (MoG) density estimators.

    Note:
        In multi-round SNPE-A, the number of MoG components grows multiplicatively
        with each round: if the proposal has L components and the density estimator
        has K components, the corrected posterior has L×K components. For many
        rounds, consider using SNPE-C (APT) instead, which handles multi-round
        inference more efficiently.

    Example:
    --------

    ::

        import torch

        from sbi.inference import NPE_A
        from sbi.utils import BoxUniform

        # 1. Setup simulator, prior, and observation
        prior = BoxUniform(low=torch.zeros(3), high=torch.ones(3))
        x_o = torch.randn(1, 3)  # Observed data

        def simulator(theta):
            return theta + torch.randn_like(theta) * 0.1

        # 2. Multi-round inference. The proposal used in each round must be passed
        # to `append_simulations`, otherwise no post-hoc correction is applied.
        inference = NPE_A(prior=prior, num_components=5)
        proposal = prior
        for round_idx in range(5):
            theta = proposal.sample((100,))
            x = simulator(theta)
            density_estimator = inference.append_simulations(
                theta, x, proposal=proposal
            ).train()
            posterior = inference.build_posterior(density_estimator)
            proposal = posterior.set_default_x(x_o)

        # 3. Sample from final posterior
        samples = posterior.sample((1000,), x=x_o)
    """

    def __init__(
        self,
        prior: Optional[Distribution] = None,
        density_estimator: Union[
            Literal["mdn_snpe_a"],
            ConditionalEstimatorBuilder[ConditionalDensityEstimator],
        ] = "mdn_snpe_a",
        num_components: int = 10,
        device: str = "cpu",
        logging_level: Union[int, str] = "WARNING",
        summary_writer: Optional[SummaryWriter] = None,
        tracker: Optional[Tracker] = None,
        show_progress_bars: bool = True,
    ):
        r"""Initialize NPE-A [1].

        Args:
            prior: A probability distribution that expresses prior knowledge about
                the parameters, e.g. which ranges are meaningful for them. Any
                object with `.log_prob()` and `.sample()` (for example, a PyTorch
                distribution) can be used. Note that the analytical multi-round
                correction in `build_posterior()` only supports
                `MultivariateNormal` and `BoxUniform` priors.
            density_estimator: If it is a string (only "mdn_snpe_a" is valid), use
                a pre-configured mixture of densities network. Alternatively, a
                function that builds a custom neural network, which adheres to
                `ConditionalEstimatorBuilder` protocol can be provided. The
                function will be called with the first batch of simulations
                (theta, x), which can thus be used for shape inference and
                potentially for z-scoring. The density estimator needs to provide
                the methods `.log_prob` and `.sample()` and must return a
                `ConditionalDensityEstimator`. `train()` binds a
                `num_components=` keyword to the network builder, so a custom
                builder presumably has to accept it (TODO: confirm).
            num_components: Number of components of the mixture of Gaussians.
                Note: In multi-round SNPE-A, the number of components grows
                multiplicatively with each round due to the analytical correction
                (L components in proposal × K components in density = L*K
                posterior components).
            device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
            logging_level: Minimum severity of messages to log. One of the strings
                INFO, WARNING, DEBUG, ERROR and CRITICAL.
            summary_writer: Deprecated alias for the TensorBoard summary writer.
                Use ``tracker`` instead.
            tracker: Tracking adapter used to log training metrics. If None, a
                TensorBoard tracker is used with a default log directory.
            show_progress_bars: Whether to show a progressbar during training.

        Raises:
            TypeError: If `density_estimator` is neither "mdn_snpe_a" nor callable.
        """
        # Catch invalid inputs.
        if not ((density_estimator == "mdn_snpe_a") or callable(density_estimator)):
            raise TypeError(
                "The `density_estimator` passed to SNPE_A needs to be a "
                "callable or the string 'mdn_snpe_a'!"
            )

        # `num_components` is not a parameter of the parent constructor. Store it
        # on `self` so that `train()` can bind it to the network builder, and drop
        # it from the kwargs forwarded below. Note that this passes state between
        # methods implicitly through the object.
        self._num_components = num_components

        # Forward every constructor argument (except the ones removed here) to
        # the parent. No other local variable may be created before this call,
        # because `locals()` would forward it too.
        kwargs = del_entries(
            locals(),
            entries=("self", "__class__", "num_components"),
        )
        super().__init__(**kwargs)

    def train(
        self,
        training_batch_size: int = 200,
        learning_rate: float = 5e-4,
        validation_fraction: float = 0.1,
        stop_after_epochs: int = 20,
        max_num_epochs: int = 2**31 - 1,
        clip_max_norm: Optional[float] = 5.0,
        calibration_kernel: Optional[Callable] = None,
        resume_training: bool = False,
        retrain_from_scratch: bool = False,
        show_train_summary: bool = False,
        dataloader_kwargs: Optional[Dict] = None,
    ) -> ConditionalDensityEstimator:
        r"""Return density estimator that approximates the proposal posterior.

        [1] _Fast epsilon-free Inference of Simulation Models with Bayesian
        Conditional Density Estimation_, Papamakarios et al., NeurIPS 2016,
        https://arxiv.org/abs/1605.06376.

        Training is performed with maximum likelihood on samples from the latest
        round, which leads the algorithm to converge to the proposal posterior.

        Args:
            training_batch_size: Training batch size.
            learning_rate: Learning rate for Adam optimizer.
            validation_fraction: The fraction of data to use for validation.
            stop_after_epochs: The number of epochs to wait for improvement on the
                validation set before terminating training.
            max_num_epochs: Maximum number of epochs to run. If reached, we stop
                training even when the validation loss is still decreasing.
                Otherwise, we train until validation loss increases (see also
                `stop_after_epochs`).
            clip_max_norm: Value at which to clip the total gradient norm in order
                to prevent exploding gradients. Use None for no clipping.
            calibration_kernel: A function to calibrate the loss with respect to
                the simulations `x`. See Lueckmann, Gonçalves et al., NeurIPS 2017.
            resume_training: Can be used in case training time is limited, e.g. on
                a cluster. If `True`, the split between train and validation set,
                the optimizer, the number of epochs, and the best validation
                log-prob will be restored from the last time `.train()` was called.
            retrain_from_scratch: Whether to retrain the conditional density
                estimator for the posterior from scratch each round. Not supported
                for SNPE-A.
            show_train_summary: Whether to print the number of epochs and
                validation loss and leakage after the training.
            dataloader_kwargs: Additional or updated kwargs to be passed to the
                training and validation dataloaders (like, e.g., a collate_fn)

        Returns:
            Density estimator that approximates the distribution $p(\theta|x)$.

        Raises:
            AssertionError: If `retrain_from_scratch=True`.
            RuntimeError: If no simulations have been appended yet.
        """
        assert not retrain_from_scratch, (
            "Retraining from scratch is not supported in SNPE-A yet. The reason "
            "for this is that, if we reininitialized the density estimator, the "
            "z-scoring would change, which would break the posthoc correction. "
            "This is a pure implementation issue."
        )

        # Must run before any other local is created: `locals()` is forwarded.
        kwargs = del_entries(
            locals(),
            entries=(
                "self",
                "__class__",
            ),
        )
        # SNPE-A always discards the prior samples.
        kwargs["discard_prior_samples"] = True
        # SNPE-A always trains with plain maximum likelihood; the proposal is
        # accounted for afterwards by the analytical correction.
        kwargs["force_first_round_loss"] = True

        if len(self._data_round_index) == 0:
            raise RuntimeError(
                "No simulations found. You must call .append_simulations() "
                "before calling .train()."
            )

        self._round = max(self._data_round_index)

        # Always use the specified number of components. The same keyword value
        # is bound on every call, so repeated `train()` calls stay consistent.
        self._build_neural_net = partial(
            self._build_neural_net, num_components=self._num_components
        )

        density_estimator = super().train(**kwargs)
        return density_estimator

    def _get_proposal_mog(
        self,
        proposal: Union["NPE_A_Posterior", MultivariateNormal, MoG, Any],
    ) -> MoG:
        """Extract MoG parameters from a proposal distribution.

        Supports multiple proposal types:
        - NPE_A_Posterior: extracts corrected MoG via get_mog_params()
        - MultivariateNormal: converts to single-component MoG
        - MoG: uses directly
        - Any object with get_mog_params(x) method

        In every case the returned MoG is moved to the training device.

        Args:
            proposal: The proposal distribution from the previous round.

        Returns:
            MoG parameters from the proposal.

        Raises:
            ValueError: If a proposal that is evaluated via `get_mog_params()`
                (including NPE_A_Posterior) has no default_x set, or its
                default_x has a batch size other than 1.
            TypeError: If proposal type is not supported, or a custom proposal's
                `get_mog_params()` does not return a MoG.
        """
        # Case 1: posterior from a previous NPE-A round.
        if isinstance(proposal, NPE_A_Posterior):
            default_x = proposal.default_x
            if default_x is None:
                raise ValueError(
                    "Proposal posterior must have a default_x set for SNPE-A "
                    "correction. Call posterior.set_default_x(x_o) before using "
                    "as proposal."
                )
            if default_x.shape[0] != 1:
                raise ValueError(
                    f"SNPE-A requires default_x batch size of 1, got "
                    f"{default_x.shape[0]}. SNPE-A only supports single "
                    "observations for correction."
                )
            # Move to the training device, consistent with the other cases.
            return proposal.get_mog_params(default_x).to(self._device)

        # Case 2: a single Gaussian becomes a one-component MoG.
        if isinstance(proposal, MultivariateNormal):
            mean: Tensor = proposal.mean.to(self._device)  # type: ignore[assignment]
            cov: Tensor = proposal.covariance_matrix.to(  # type: ignore[assignment]
                self._device
            )
            return MoG.from_gaussian(mean.unsqueeze(0), cov.unsqueeze(0))

        # Case 3: already a MoG.
        if isinstance(proposal, MoG):
            return proposal.to(self._device)

        # Case 4: Any object with get_mog_params method
        if hasattr(proposal, "get_mog_params"):
            # Try to get default_x if available
            default_x = getattr(proposal, "default_x", None)
            if default_x is None:
                raise ValueError(
                    "Proposal has get_mog_params() but no default_x set. "
                    "Call proposal.set_default_x(x_o) before using as proposal."
                )
            if default_x.shape[0] != 1:
                raise ValueError(
                    f"SNPE-A requires default_x batch size of 1, got "
                    f"{default_x.shape[0]}."
                )
            mog = proposal.get_mog_params(default_x)
            if not isinstance(mog, MoG):
                raise TypeError(
                    f"Proposal's get_mog_params() must return MoG, "
                    f"got {type(mog).__name__}."
                )
            return mog.to(self._device)

        # Unsupported type
        raise TypeError(
            f"For multi-round SNPE-A, proposal must be one of: NPE_A_Posterior, "
            f"MultivariateNormal, MoG, or an object with get_mog_params() method. "
            f"Got {type(proposal).__name__}. For custom proposals, construct "
            f"NPE_A_Posterior directly with your proposal_mog parameter."
        )

    def _compute_z_scored_prior_mog(
        self,
        density_estimator: MixtureDensityEstimator,
        prior: Optional[Distribution] = None,
    ) -> Optional[MoG]:
        """Compute the prior as a MoG in z-scored space (if applicable).

        For SNPE-A correction, the prior needs to be in the same coordinate system
        as the density estimator's output. When z-scoring is applied to inputs, the
        density estimator outputs MoG parameters in z-scored space, so the prior
        must also be transformed to z-scored space.

        For uniform priors (BoxUniform), returns None since uniform priors have
        zero precision (infinite covariance) and are handled specially in the
        correction formula.

        Mathematical background:
            For z-score transform: z = (theta - shift) / scale
            If theta ~ N(mu, Sigma), then:
            z ~ N((mu - shift) / scale, Sigma / (scale ⊗ scale))
            where (scale ⊗ scale)_ij = scale_i * scale_j

        Args:
            density_estimator: The MixtureDensityEstimator (to get z-score
                parameters).
            prior: The prior to transform. If None, the prior passed at
                initialization (`self._prior`) is used.

        Returns:
            MoG representation of the z-scored prior for Gaussian priors,
            or None for uniform priors.

        Raises:
            TypeError: If the prior is neither MultivariateNormal nor BoxUniform.
            ValueError: If the z-scored prior covariance is not positive definite.
        """
        if prior is None:
            prior = self._prior

        # Uniform priors have zero precision, return None
        if isinstance(prior, BoxUniform):
            return None

        if not isinstance(prior, MultivariateNormal):
            raise TypeError(
                f"Prior must be MultivariateNormal or BoxUniform, "
                f"got {type(prior).__name__}"
            )

        # Get prior parameters on the training device, consistent with how
        # Gaussian proposals are handled in `_get_proposal_mog`.
        prior_mean = prior.mean.to(self._device)
        prior_cov = prior.covariance_matrix.to(self._device)

        # Apply z-score transform if enabled
        if density_estimator.has_input_transform:
            shift = density_estimator._transform_shift
            scale = density_estimator._transform_scale

            # Z-scored mean: (mu - shift) / scale
            z_mean = (prior_mean - shift) / scale

            # Z-scored covariance: Sigma_z[i,j] = Sigma[i,j] / (scale_i * scale_j)
            scale_outer = scale.unsqueeze(-1) * scale.unsqueeze(-2)
            z_cov = prior_cov / scale_outer
        else:
            z_mean = prior_mean
            z_cov = prior_cov

        # Validate covariance is positive definite
        try:
            torch.linalg.cholesky(z_cov)
        except RuntimeError as e:
            raise ValueError(
                "Z-scored prior covariance is not positive definite. "
                "This may indicate numerical issues with the z-score transform. "
                f"Original error: {e}"
            ) from e

        # Convert to MoG
        return MoG.from_gaussian(z_mean, z_cov)

    def build_posterior(
        self,
        density_estimator: Optional[ConditionalDensityEstimator] = None,
        prior: Optional[Distribution] = None,
        sample_with: Literal["direct"] = "direct",
        **kwargs,
    ) -> NPE_A_Posterior:
        r"""Build posterior from the neural density estimator.

        Returns an NPE_A_Posterior that applies the SNPE-A correction formula:
            p(θ|x) ∝ q(θ|x) × prior(θ) / proposal(θ)

        Note:
            NPE_A only supports `sample_with="direct"`. The corrected posterior is
            a Mixture of Gaussians (MoG) which can be sampled directly and
            efficiently. MCMC, VI, rejection, and importance sampling methods do
            not provide benefits over direct MoG sampling and are therefore not
            supported.

        Args:
            density_estimator: The density estimator that the posterior is based
                on. If `None`, use the latest neural density estimator that was
                trained.
            prior: Prior distribution. If `None`, the prior passed at
                initialization is used. The same prior is used both for the
                returned posterior and for the analytical correction.
            sample_with: Must be "direct". Other sampling methods are not
                supported.
            **kwargs: Additional arguments passed to NPE_A_Posterior.

        Returns:
            NPE_A_Posterior with the SNPE-A correction applied.

        Raises:
            ValueError: If sample_with is not "direct".
            TypeError: If the density estimator is not a MixtureDensityEstimator.
        """
        if sample_with != "direct":
            raise ValueError(
                f"NPE_A only supports sample_with='direct', got '{sample_with}'. "
                "The corrected posterior is a Mixture of Gaussians which can be "
                "sampled directly and efficiently. MCMC, VI, rejection, and "
                "importance sampling do not provide benefits over direct MoG sampling."
            )

        if prior is None:
            assert self._prior is not None, (
                "You did not pass a prior. You have to pass the prior either at "
                "initialization `inference = NPE_A(prior)` or to "
                "`.build_posterior(prior=prior)`."
            )
            prior = self._prior

        # Resolve and validate density estimator
        if density_estimator is None:
            density_estimator = deepcopy(self._neural_net)

        if not isinstance(density_estimator, MixtureDensityEstimator):
            raise TypeError(
                "NPE_A requires MixtureDensityEstimator, "
                f"got {type(density_estimator).__name__}. "
                "Use density_estimator='mdn_snpe_a' when initializing NPE_A."
            )

        # Compute correction parameters. If the latest proposal is the prior
        # itself, prior / proposal cancels and no correction is needed.
        proposal = self._proposal_roundwise[-1]
        is_first_round = (
            proposal is None or proposal is self._prior or proposal is prior
        )

        if is_first_round:
            proposal_mog = None
            prior_mog = None
        else:
            proposal_mog = self._get_proposal_mog(proposal)
            # Use the resolved prior (argument or initialization prior) so the
            # correction matches the prior given to the posterior.
            prior_mog = self._compute_z_scored_prior_mog(density_estimator, prior)

        # Build the posterior
        self._posterior = NPE_A_Posterior(
            posterior_estimator=density_estimator,
            prior=prior,
            proposal_mog=proposal_mog,
            prior_mog=prior_mog,
            device=self._device,
            **kwargs,
        )
        return self._posterior

    def _log_prob_proposal_posterior(
        self,
        theta: Tensor,
        x: Tensor,
        masks: Tensor,
        proposal: Optional[Any],
    ) -> Tensor:
        """Return the log-probability of the proposal posterior.

        For SNPE-A this is the same as `self._neural_net.log_prob(theta, x)` in
        `_loss()` to be found in `snpe_base.py`.

        Args:
            theta: Batch of parameters θ.
            x: Batch of data.
            masks: Mask that is True for prior samples in the batch in order to
                train them with prior loss. Unused by SNPE-A.
            proposal: Proposal distribution. Unused by SNPE-A.

        Returns:
            Log-probability of the proposal posterior.
        """
        return self._neural_net.log_prob(theta, x)
def _correct_for_proposal( density_mog: MoG, proposal_mog: MoG, prior_mog: Optional[MoG] = None, ) -> MoG: """Compute SNPE-A corrected posterior from density estimator output. Implements Appendix C of Papamakarios et al. 2016 (SNPE-A paper). The true posterior is computed as: posterior = density_estimator * prior / proposal Since all distributions are MoGs, this can be done analytically. If the proposal has L components and the density has K components, the posterior has L*K components. For uniform priors, pass `prior_mog=None`. The prior term is then omitted from the correction (uniform has zero precision). Warning: Component count grows multiplicatively across rounds. In multi-round SNPE-A, if round N has L components, round N+1 will have L*K components (where K is the density estimator's component count). This can lead to memory issues for many rounds. However, the original SNPE-A algorithm uses a single Gaussian (K=1) in intermediate rounds, expanding to multiple components only in the final round, which avoids this issue. Args: density_mog: MoG from the density estimator for current observation. proposal_mog: MoG from previous round's proposal distribution. prior_mog: MoG representation of the prior. Use `MoG.from_gaussian()` for Gaussian priors. Pass None for uniform priors. Returns: Corrected MoG representing the true posterior. Raises: ValueError: If posterior precision is not positive definite. """ num_comps_proposal = proposal_mog.num_components num_comps_density = density_mog.num_components dim = density_mog.dim # Compute posterior precisions (Eq. 
23) # prec_post = prec_density - prec_proposal + prec_prior # For uniform priors, prec_prior = 0 prec_proposal_rep = proposal_mog.precisions.repeat_interleave( num_comps_density, dim=1 ) prec_density_rep = density_mog.precisions.repeat(1, num_comps_proposal, 1, 1) prec_post = prec_density_rep - prec_proposal_rep # Add prior precision term only for Gaussian priors if prior_mog is not None: prec_prior_rep = prior_mog.precisions.repeat( 1, num_comps_proposal * num_comps_density, 1, 1 ) prec_post = prec_post + prec_prior_rep # Add small epsilon to diagonal for numerical stability eye = torch.eye(dim, device=prec_post.device, dtype=prec_post.dtype) prec_post_stabilized = prec_post + _CORRECTION_EPSILON * eye # Compute precision factors via Cholesky (also validates positive definiteness) try: precf_post = torch.linalg.cholesky(prec_post_stabilized, upper=True) except torch.linalg.LinAlgError as e: raise ValueError( "Posterior precision matrix is not positive definite. " "This is a known issue with SNPE-A when the proposal and density " "estimator don't align well. Try different hyperparameters. " f"Original error: {e}" ) from e # Compute posterior covariances using solve for numerical stability batch_shape = prec_post_stabilized.shape[:-2] eye_expanded = eye.expand(*batch_shape, dim, dim) cov_post = torch.linalg.solve(prec_post_stabilized, eye_expanded) # Compute posterior means (Eq. 
24) # mean_post = cov_post @ (prec_density @ mean_density # - prec_proposal @ mean_proposal # + prec_prior @ mean_prior) prec_mean_proposal = _batched_mv(proposal_mog.precisions, proposal_mog.means) prec_mean_density = _batched_mv(density_mog.precisions, density_mog.means) prec_mean_proposal_rep = prec_mean_proposal.repeat_interleave( num_comps_density, dim=1 ) prec_mean_density_rep = prec_mean_density.repeat(1, num_comps_proposal, 1) summed_prec_mean = prec_mean_density_rep - prec_mean_proposal_rep # Add prior mean term only for Gaussian priors if prior_mog is not None: prec_mean_prior = _batched_mv(prior_mog.precisions, prior_mog.means) prec_mean_prior_rep = prec_mean_prior.repeat( 1, num_comps_proposal * num_comps_density, 1 ) summed_prec_mean = summed_prec_mean + prec_mean_prior_rep mean_post = _batched_mv(cov_post, summed_prec_mean) # Compute posterior logits (Eqs. 25-26) logits_post = _compute_posterior_logits( mean_post, prec_post, cov_post, proposal_mog.logits, proposal_mog.means, proposal_mog.precisions, density_mog.logits, density_mog.means, density_mog.precisions, num_comps_proposal, num_comps_density, ) return MoG( logits=logits_post, means=mean_post, precisions=prec_post_stabilized, precision_factors=precf_post, ) def _compute_posterior_logits( mean_post: Tensor, prec_post: Tensor, cov_post: Tensor, logits_proposal: Tensor, mean_proposal: Tensor, prec_proposal: Tensor, logits_density: Tensor, mean_density: Tensor, prec_density: Tensor, num_comps_proposal: int, num_comps_density: int, ) -> Tensor: """Compute posterior logits using Eqs. 25-26 from SNPE-A paper. Computes unnormalized log weights for the posterior MoG components. The formula combines logit differences, log-determinant ratios, and quadratic form differences from the proposal, density, and posterior. Args: mean_post: Posterior means, shape (batch, L*K, dim). prec_post: Posterior precisions, shape (batch, L*K, dim, dim). cov_post: Posterior covariances, shape (batch, L*K, dim, dim). 
logits_proposal: Proposal logits, shape (batch, L). mean_proposal: Proposal means, shape (batch, L, dim). prec_proposal: Proposal precisions, shape (batch, L, dim, dim). logits_density: Density logits, shape (batch, K). mean_density: Density means, shape (batch, K, dim). prec_density: Density precisions, shape (batch, K, dim, dim). num_comps_proposal: Number of proposal components (L). num_comps_density: Number of density components (K). Returns: Posterior logits of shape (batch, L*K). """ # Compute logit factors logits_proposal_rep = logits_proposal.repeat_interleave(num_comps_density, dim=1) logits_density_rep = logits_density.repeat(1, num_comps_proposal) logit_factors = logits_density_rep - logits_proposal_rep # Compute log-determinant terms using slogdet for numerical stability _, logdet_cov_post = torch.linalg.slogdet(cov_post) _, logdet_prec_proposal = torch.linalg.slogdet(prec_proposal) _, logdet_prec_density = torch.linalg.slogdet(prec_density) logdet_cov_proposal = -logdet_prec_proposal logdet_cov_density = -logdet_prec_density logdet_cov_proposal_rep = logdet_cov_proposal.repeat_interleave( num_comps_density, dim=1 ) logdet_cov_density_rep = logdet_cov_density.repeat(1, num_comps_proposal) log_sqrt_det_ratio = 0.5 * ( logdet_cov_post + logdet_cov_proposal_rep - logdet_cov_density_rep ) # Compute quadratic form terms (m^T P m) exponent_proposal = _batched_vmv(prec_proposal, mean_proposal) exponent_density = _batched_vmv(prec_density, mean_density) exponent_post = _batched_vmv(prec_post, mean_post) exponent_proposal_rep = exponent_proposal.repeat_interleave( num_comps_density, dim=1 ) exponent_density_rep = exponent_density.repeat(1, num_comps_proposal) exponent = -0.5 * (exponent_density_rep - exponent_proposal_rep - exponent_post) return logit_factors + log_sqrt_det_ratio + exponent def _batched_mv(matrix: Tensor, vector: Tensor) -> Tensor: """Batched matrix-vector product with component dimension. Args: matrix: Shape (batch, num_components, dim, dim). 
vector: Shape (batch, num_components, dim). Returns: Product of shape (batch, num_components, dim). """ return torch.einsum("bcij,bcj->bci", matrix, vector) def _batched_vmv(matrix: Tensor, vector: Tensor) -> Tensor: """Batched vector-matrix-vector product (quadratic form). Args: matrix: Shape (batch, num_components, dim, dim). vector: Shape (batch, num_components, dim). Returns: Quadratic form v^T M v of shape (batch, num_components). """ mv = torch.einsum("bcij,bcj->bci", matrix, vector) return torch.einsum("bci,bci->bc", vector, mv)