Source code for sbi.inference.posteriors.importance_posterior

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Any, Callable, Literal, Optional, Tuple, Union

import torch
from torch import Tensor

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.importance.importance_sampling import importance_sample
from sbi.samplers.importance.sir import sampling_importance_resampling
from sbi.sbi_types import Shape, TorchTransform
from sbi.utils.sbiutils import mcmc_transform
from sbi.utils.torchutils import ensure_theta_batched


class ImportanceSamplingPosterior(NeuralPosterior):
    r"""Provides importance sampling to sample from the posterior.

    SNLE or SNRE train neural networks to approximate the likelihood(-ratios).
    `ImportanceSamplingPosterior` allows to estimate the posterior log-probability by
    estimating the normalization constant with importance sampling. It also allows to
    perform importance sampling (with `.sample(method="importance")`) and to draw
    approximate samples with sampling-importance-resampling (SIR) (with
    `.sample(method="sir")`).
    """

    def __init__(
        self,
        potential_fn: Union[Callable, BasePotential],
        proposal: Any,
        theta_transform: Optional[TorchTransform] = None,
        method: Literal["sir", "importance"] = "sir",
        oversampling_factor: int = 32,
        max_sampling_batch_size: int = 10_000,
        device: Optional[Union[str, torch.device]] = None,
        x_shape: Optional[torch.Size] = None,
    ):
        """
        Args:
            potential_fn: The potential function from which to draw samples. Must be a
                `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs.
            proposal: The proposal distribution.
            theta_transform: Transformation that is applied to parameters. Is not used
                during sampling but only when calling `.map()`.
            method: Either of [`sir`|`importance`]. This sets the behavior of the
                `.sample()` method. With `sir`, approximate posterior samples are
                generated with sampling importance resampling (SIR). With
                `importance`, the `.sample()` method returns a tuple of samples and
                corresponding importance weights.
            oversampling_factor: Number of proposed samples from which only one is
                selected based on its importance weight.
            max_sampling_batch_size: The batch size of samples being drawn from the
                proposal at every iteration.
            device: Device on which to sample, e.g., "cpu", "cuda" or "cuda:0". If
                None, `potential_fn.device` is used.
            x_shape: Deprecated, should not be passed.
        """
        super().__init__(
            potential_fn,
            theta_transform=theta_transform,
            device=device,
            x_shape=x_shape,
        )

        self.proposal = proposal
        # Cache for the normalization constant estimated at the default x.
        self._normalization_constant = None
        self.method = method
        self.theta_transform = theta_transform

        self.oversampling_factor = oversampling_factor
        self.max_sampling_batch_size = max_sampling_batch_size

        self._purpose = (
            "It provides sampling-importance resampling (SIR) to .sample() from the "
            "posterior and can evaluate the _unnormalized_ posterior density with "
            ".log_prob()."
        )
        # Kept so that `.to()` can re-run the parent initializer with the same value.
        self.x_shape = x_shape

    def to(self, device: Union[str, torch.device]) -> None:
        """Move the potential, the proposal and x_o to a new device.

        It also reinstantiates the posterior with the new device. Note that the
        parameter transform is rebuilt from the proposal via `mcmc_transform`,
        replacing whatever `theta_transform` was stored before.

        Args:
            device: Device on which to move the posterior to.
        """
        self.device = device
        self.potential_fn.to(device)  # type: ignore
        self.proposal.to(device)

        # Remember the default observation (if any) before re-initializing.
        x_o = None
        if hasattr(self, "_x") and (self._x is not None):
            x_o = self._x.to(device)

        self.theta_transform = mcmc_transform(self.proposal, device=device)
        super().__init__(
            self.potential_fn,
            theta_transform=self.theta_transform,
            device=device,
            x_shape=self.x_shape,
        )
        # super().__init__ erases the self._x, so we need to set it again
        if x_o is not None:
            self.set_default_x(x_o)

    def log_prob(
        self,
        theta: Tensor,
        x: Optional[Tensor] = None,
        track_gradients: bool = False,
        normalization_constant_params: Optional[dict] = None,
    ) -> Tensor:
        r"""Returns the log-probability of theta under the posterior.

        The normalization constant is estimated with importance sampling.

        Args:
            theta: Parameters $\theta$.
            x: Conditioning observation $x_o$. If not provided, the default `x` is
                used (resolved via `_x_else_default_x`).
            track_gradients: Whether the returned tensor supports tracking gradients.
                This can be helpful for e.g. sensitivity analysis, but increases memory
                consumption.
            normalization_constant_params: Parameters passed on to
                `estimate_normalization_constant()`.

        Returns:
            `len($\theta$)`-shaped log-probability.
        """
        x = self._x_else_default_x(x)
        self.potential_fn.set_x(x)

        theta = ensure_theta_batched(torch.as_tensor(theta))

        with torch.set_grad_enabled(track_gradients):
            potential_values = self.potential_fn(
                theta.to(self._device), track_gradients=track_gradients
            )

            if normalization_constant_params is None:
                normalization_constant_params = dict()  # use defaults
            normalization_constant = self.estimate_normalization_constant(
                x, **normalization_constant_params
            )

            return (potential_values - torch.log(normalization_constant)).to(
                self._device
            )

    @torch.no_grad()
    def estimate_normalization_constant(
        self, x: Tensor, num_samples: int = 10_000, force_update: bool = False
    ) -> Tensor:
        """Returns the normalization constant via importance sampling.

        The estimate is the mean of the exponentiated log importance weights. An
        estimate obtained at the default x is cached and reused; an estimate at any
        other x is returned without being stored.

        Args:
            x: Observation at which the normalization constant is estimated. It is
                set on the potential function before importance sampling.
            num_samples: Number of importance samples used for the estimate.
            force_update: Whether to re-calculate the normalization constant when x is
                unchanged and have a cached value.

        Returns:
            The estimated normalization constant.
        """
        # Check if the provided x matches the default x (short-circuit on identity).
        is_new_x = bool(
            self.default_x is None
            or (x is not self.default_x and (x != self.default_x).any())
        )

        cached = self._normalization_constant
        if not is_new_x and cached is not None and not force_update:
            # Valid cached value for the default x.
            return cached.to(self._device)

        # Make sure the potential is conditioned on the x we are asked about;
        # otherwise the estimate would silently refer to whichever observation was
        # set on the potential last.
        self.potential_fn.set_x(x)
        _, log_importance_weights = importance_sample(
            self.potential_fn,
            proposal=self.proposal,
            num_samples=num_samples,
        )
        normalization_constant = torch.mean(torch.exp(log_importance_weights))

        if is_new_x:
            # Calculated at a non-default x; don't save.
            return normalization_constant

        # Calculated at default_x; save.
        self._normalization_constant = normalization_constant
        return normalization_constant.to(self._device)

    def sample(
        self,
        sample_shape: Shape = torch.Size(),
        x: Optional[Tensor] = None,
        method: Optional[str] = None,
        oversampling_factor: int = 32,
        max_sampling_batch_size: int = 10_000,
        show_progress_bars: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.

        Args:
            sample_shape: Shape of samples that are drawn from posterior.
            x: Conditioning observation $x_o$. If not provided, uses the default `x`
                set via `.set_default_x()`.
            method: Either of [`sir`|`importance`]. This sets the behavior of the
                `.sample()` method. With `sir`, approximate posterior samples are
                generated with sampling importance resampling (SIR). With
                `importance`, the `.sample()` method returns a tuple of samples and
                corresponding importance weights. If None, the method chosen at
                construction time is used.
            oversampling_factor: Number of proposed samples from which only one is
                selected based on its importance weight. Only used with `sir`.
                NOTE(review): this default always takes precedence over the
                `oversampling_factor` passed to `__init__`.
            max_sampling_batch_size: The batch size of samples being drawn from the
                proposal at every iteration. Only used with `sir`.
                NOTE(review): this default always takes precedence over the
                `max_sampling_batch_size` passed to `__init__`.
            show_progress_bars: Whether to show a progressbar during sampling.

        Returns:
            Samples for `sir`; a tuple of samples and log importance weights for
            `importance`.

        Raises:
            NameError: If `method` is neither `sir` nor `importance`.
        """
        method = self.method if method is None else method

        self.potential_fn.set_x(self._x_else_default_x(x))

        if method == "sir":
            return self._sir_sample(
                sample_shape,
                oversampling_factor=oversampling_factor,
                max_sampling_batch_size=max_sampling_batch_size,
                show_progress_bars=show_progress_bars,
            )
        elif method == "importance":
            return self._importance_sample(
                sample_shape, show_progress_bars=show_progress_bars
            )
        else:
            raise NameError(
                f"Unknown sampling method {method!r}; "
                "expected 'sir' or 'importance'."
            )

    def sample_batched(
        self,
        sample_shape: Shape,
        x: Tensor,
        max_sampling_batch_size: int = 10000,
        show_progress_bars: bool = True,
    ) -> Tensor:
        """Not supported for this posterior; always raises.

        Raises:
            NotImplementedError: Always. Use `.sample()` in a loop over the
                observations instead.
        """
        raise NotImplementedError(
            "Batched sampling is not implemented for ImportanceSamplingPosterior. "
            "Alternatively you can use `sample` in a loop "
            "[posterior.sample(sample_shape, x_o) for x_o in x]."
        )

    def _importance_sample(
        self,
        sample_shape: Shape = torch.Size(),
        show_progress_bars: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        """Returns samples from the proposal and log of their importance weights.

        Args:
            sample_shape: Desired shape of samples that are drawn from posterior.
            show_progress_bars: Whether to show sampling progress monitor.

        Returns:
            Samples and logarithm of corresponding importance weights.
        """
        num_samples = torch.Size(sample_shape).numel()
        samples, log_importance_weights = importance_sample(
            self.potential_fn,
            proposal=self.proposal,
            num_samples=num_samples,
            show_progress_bars=show_progress_bars,
        )

        samples = samples.reshape((*sample_shape, -1)).to(self._device)
        return samples, log_importance_weights.to(self._device)

    def _sir_sample(
        self,
        sample_shape: Shape = torch.Size(),
        oversampling_factor: int = 32,
        max_sampling_batch_size: int = 10_000,
        show_progress_bars: bool = False,
    ):
        r"""Returns approximate samples from posterior $p(\theta|x)$ via SIR.

        Args:
            sample_shape: Desired shape of samples that are drawn from posterior. If
                sample_shape is multidimensional we simply draw `sample_shape.numel()`
                samples and then reshape into the desired shape.
            oversampling_factor: Number of proposed samples from which only one is
                selected based on its importance weight. If None, the value stored on
                the instance is used.
            max_sampling_batch_size: The batch size of samples being drawn from the
                proposal at every iteration. If None, the value stored on the
                instance is used.
            show_progress_bars: Whether to show sampling progress monitor.

        Returns:
            Samples from posterior.
        """
        # Replace arguments that were not passed with their default.
        oversampling_factor = (
            self.oversampling_factor
            if oversampling_factor is None
            else oversampling_factor
        )
        max_sampling_batch_size = (
            self.max_sampling_batch_size
            if max_sampling_batch_size is None
            else max_sampling_batch_size
        )

        num_samples = torch.Size(sample_shape).numel()
        samples = sampling_importance_resampling(
            self.potential_fn,
            proposal=self.proposal,
            num_samples=num_samples,
            num_candidate_samples=oversampling_factor,
            show_progress_bars=show_progress_bars,
            max_sampling_batch_size=max_sampling_batch_size,
            device=self._device,
        )

        return samples.reshape((*sample_shape, -1)).to(self._device)

    def map(
        self,
        x: Optional[Tensor] = None,
        num_iter: int = 1_000,
        num_to_optimize: int = 100,
        learning_rate: float = 0.01,
        init_method: Union[str, Tensor] = "proposal",
        num_init_samples: int = 1_000,
        save_best_every: int = 10,
        show_progress_bars: bool = False,
        force_update: bool = False,
    ) -> Tensor:
        r"""Returns the maximum-a-posteriori estimate (MAP).

        This method only forwards its arguments to the parent class's `.map()`.

        The method can be interrupted (Ctrl-C) when the user sees that the
        log-probability converges. The best estimate will be saved in `self._map` and
        can be accessed with `self.map()`. The MAP is obtained by running gradient
        ascent from a given number of starting positions (samples from the posterior
        with the highest log-probability). After the optimization is done, we select the
        parameter set that has the highest log-probability after the optimization.

        Warning: The default values used by this function are not well-tested. They
        might require hand-tuning for the problem at hand.

        For developers: if the prior is a `BoxUniform`, we carry out the optimization
        in unbounded space and transform the result back into bounded space.

        Args:
            x: Deprecated - use `.set_default_x()` prior to `.map()`.
            num_iter: Number of optimization steps that the algorithm takes
                to find the MAP.
            num_to_optimize: From the drawn `num_init_samples`, use the
                `num_to_optimize` with highest log-probability as the initial points
                for the optimization.
            learning_rate: Learning rate of the optimizer.
            init_method: How to select the starting parameters for the optimization.
                If it is a string (default `proposal`), it names the distribution that
                is sampled `num_init_samples` times; the accepted strings are defined
                by the parent class's `.map()`. If it is a tensor, the tensor will be
                used as init locations.
            num_init_samples: Draw this number of samples and evaluate the
                log-probability of all of them.
            save_best_every: The best log-probability is computed, saved in the
                `map`-attribute, and printed every `save_best_every`-th iteration.
                Computing the best log-probability creates a significant overhead
                (thus, the default is `10`.)
            show_progress_bars: Whether to show a progressbar during sampling from
                the posterior.
            force_update: Whether to re-calculate the MAP when x is unchanged and
                have a cached value.

        Returns:
            The MAP estimate.
        """
        return super().map(
            x=x,
            num_iter=num_iter,
            num_to_optimize=num_to_optimize,
            learning_rate=learning_rate,
            init_method=init_method,
            num_init_samples=num_init_samples,
            save_best_every=save_best_every,
            show_progress_bars=show_progress_bars,
            force_update=force_update,
        )