EnsemblePosterior

class EnsemblePosterior(posteriors, weights=None, theta_transform=None, device=None)

Bases: NeuralPosterior

Wrapper for bundling together different posterior instances into an ensemble.

This class creates a posterior ensemble from a set of \(N\) different, already trained posterior estimators \(p_{i}(\theta \mid x_o)\), where \(i \in \{1, \ldots, N\}\).

It can wrap all posterior classes available in sbi, and even a mixture of different posteriors (e.g. posteriors obtained via SNLE and SNPE at the same time), because it only passes calls through to the class methods of each posterior in the ensemble. The only constraint is that the individual posteriors share the same prior.

So far, log_prob(), sample(), and map() are supported.

Example:

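import torch
from sbi.inference import NPE, EnsemblePosterior
from sbi.utils import BoxUniform

# Toy setup for illustration: a 3-D uniform prior and a noisy identity simulator.
prior = BoxUniform(low=-2 * torch.ones(3), high=2 * torch.ones(3))

def simulate(theta):
    return theta + 0.1 * torch.randn_like(theta)

theta = prior.sample((100,))
x = simulate(theta)

# Train several NPE posteriors on the same simulations; they differ through random network initialization.
n_ensembles = 10
posteriors = []
for _ in range(n_ensembles):
    inference = NPE(prior=prior)
    inference.append_simulations(theta, x).train()
    posteriors.append(inference.build_posterior())

ensemble = EnsemblePosterior(posteriors)
ensemble.set_default_x(torch.zeros((3,)))
samples = ensemble.sample((1,))

posteriors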

List of the posterior estimators making up the ensemble.

num_components

Number of posterior estimators.

weights

Weight of each posterior distribution. If none are provided, each posterior is weighted with 1/N.

priors

Prior distributions of all posterior components.

theta_transform

If passed, this transformation is applied during the optimization performed to obtain the MAP estimate. It does not affect the .sample() and .log_prob() methods.

device

Device on which the posterior distribution is hosted.

to(device)

Moves each posterior to device.

Prior and weights are also moved to the specified device.

Parameters:

device (str | device) – The device to move the ensemble posterior to.

Return type:

None

ensure_same_device(posteriors)

Ensures that all posteriors in the ensemble are on the same device.

Parameters:

posteriors (List) – List containing the trained posterior instances that will make up the ensemble.

Raises:

AssertionError if the ensemble components are on different devices.

Returns:

A device string shared by all posteriors.

Return type:

str
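The device check described above can be sketched as follows. This is a minimal sketch, not sbi's implementation: each ensemble component is represented by a plain tensor (a hypothetical stand-in for a trained posterior), and the helper name ensure_same_device_sketch is illustrative only.

```python
import torch


def ensure_same_device_sketch(tensors):
    """Return the single device string shared by all tensors.

    Hypothetical stand-in: each ensemble component is represented here by
    one of its tensors rather than by an sbi posterior object.
    """
    devices = {str(t.device) for t in tensors}
    # Exactly one distinct device is allowed; otherwise raise an AssertionError.
    assert len(devices) == 1, f"Components live on different devices: {devices}"
    return devices.pop()


components = [torch.zeros(3), torch.ones(3)]  # both created on the CPU
device = ensure_same_device_sketch(components)
print(device)  # cpu
```

Collecting the devices into a set makes the check independent of the number of components: the set has exactly one element only if every component agrees.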

property weights: Tensor

sample(sample_shape=(), x=None, **kwargs)

Return samples from posterior ensemble.

The samples are allocated to the components according to their weights: the number of samples drawn from each component is itself drawn from a multinomial distribution parameterized by the weights. Each component posterior is then sampled individually, and all samples are aggregated afterwards.

All kwargs are passed directly through to posterior.sample().

Parameters:
  • sample_shape (Size | Tuple[int, ...]) – Desired shape of samples that are drawn from posterior ensemble. If sample_shape is multidimensional, we draw sample_shape.numel() samples and then reshape them into the desired shape.

  • x (Tensor | None) – Conditioning context. If none is provided and no default context is set, an error will be raised.

Returns:

Samples drawn from the ensemble distribution.

Return type:

Tensor
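The multinomial allocation described above can be sketched in a few lines of PyTorch. This is a minimal sketch under stated assumptions, not sbi's code: the two components are hypothetical MultivariateNormal distributions standing in for trained posteriors, the weights are assumed to sum to one, and sample_ensemble is an illustrative name.

```python
import torch
from torch.distributions import MultivariateNormal, Multinomial

torch.manual_seed(0)

# Hypothetical stand-ins for two trained component posteriors over a 2-D theta.
components = [
    MultivariateNormal(torch.zeros(2), torch.eye(2)),
    MultivariateNormal(torch.full((2,), 5.0), torch.eye(2)),
]
weights = torch.tensor([0.25, 0.75])  # assumed to sum to one


def sample_ensemble(sample_shape):
    sample_shape = torch.Size(sample_shape)
    num_samples = sample_shape.numel()  # draw numel() samples, reshape at the end
    # Number of draws per component ~ Multinomial(num_samples, weights).
    counts = Multinomial(total_count=num_samples, probs=weights).sample()
    chunks = [comp.sample((int(n),)) for comp, n in zip(components, counts)]
    # Samples are grouped by component here; shuffle them if their order matters.
    samples = torch.cat(chunks, dim=0)
    return samples.reshape(*sample_shape, -1)


samples = sample_ensemble((10, 20))
print(samples.shape)  # torch.Size([10, 20, 2])
```

Drawing the per-component counts once and then sampling each component in a single call keeps the number of component calls equal to N, regardless of how many samples are requested.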

sample_batched(sample_shape, x, **kwargs)

Draw samples from the posteriors for a batch of different xs.

Given a batch of observations [x_1, …, x_B], this method samples from posteriors \(p(\theta|x_1), \ldots, p(\theta|x_B)\) in a vectorized manner.

Parameters:
  • sample_shape (Size | Tuple[int, ...]) – Shape of samples to draw for each observation.

  • x (Tensor) – Batch of observations with shape (batch_dim, *event_shape_x).

  • show_progress_bars – Whether to show a progress bar during sampling.

  • **kwargs – Additional keyword arguments passed to the specific posterior’s sampling method.

Returns:

Samples with shape (*sample_shape, batch_dim, *theta_shape).

Return type:

Tensor
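The output shape convention (*sample_shape, batch_dim, *theta_shape) can be illustrated with a toy posterior whose samples for all observations are drawn in one vectorized call. This is a sketch only: the posterior p(θ|x) = N(x, 0.1² I) and the function toy_sample_batched are hypothetical, chosen so that θ has the same shape as a single x.

```python
import torch

torch.manual_seed(0)


def toy_sample_batched(sample_shape, x):
    """Hypothetical posterior p(theta | x) = N(x, 0.1^2 I), sampled for a batch of x.

    x has shape (batch_dim, *event_shape_x); theta has the shape of one x.
    """
    sample_shape = torch.Size(sample_shape)
    noise = torch.randn(*sample_shape, *x.shape)
    # Broadcasting x over the leading sample dimensions yields
    # (*sample_shape, batch_dim, *theta_shape) without a Python loop over x.
    return x + 0.1 * noise


x_batch = torch.stack([torch.zeros(3), torch.ones(3)])  # batch_dim = 2
samples = toy_sample_batched((1000,), x_batch)
print(samples.shape)  # torch.Size([1000, 2, 3])
```

Indexing samples[:, b] then gives the draws for observation x_b.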

log_prob(theta, x=None, individually=False, **kwargs)

Returns the log-probability of the posterior ensemble, i.e. the log of the weighted mixture of the component posteriors:

\(\log \sum_{i=1}^{N} w_{i} p_i(\theta|x)\).

All kwargs are passed directly through to posterior.log_prob().

Parameters:
  • theta (Tensor) – Parameters \(\theta\).

  • x (Tensor | None) – Conditioning context. If none is provided and no default context is set, an error will be raised.

  • individually (bool) – If True, returns the log weights and the component log_probs individually.

Raises:

AssertionError if posterior estimators are a mixture of different methods.

Returns:

(len(θ),)-shaped ensemble log posterior probability \(\log p(\theta|x)\) for θ in the support of the prior, -∞ (corresponding to 0 probability) outside.

Return type:

Tensor | Tuple[Tensor, Tensor]
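Evaluating the weighted mixture density directly can underflow when the component log-probabilities are very negative. A common remedy is to add the log weights to the component log-probabilities and combine them with logsumexp. The sketch below assumes the mixture formula given above; the components are hypothetical torch Normal distributions rather than sbi posteriors, and the shapes returned with individually=True are those of this sketch only.

```python
import torch
from torch.distributions import Normal

# Hypothetical 1-D component posteriors p_i(theta | x) and their weights w_i.
components = [Normal(0.0, 1.0), Normal(3.0, 0.5)]
weights = torch.tensor([0.4, 0.6])


def ensemble_log_prob(theta, individually=False):
    log_probs = torch.stack([c.log_prob(theta) for c in components])  # (N, len(theta))
    log_weights = torch.log(weights).unsqueeze(-1)  # (N, 1)
    if individually:
        return log_weights, log_probs
    # log sum_i w_i p_i(theta) = logsumexp_i(log w_i + log p_i(theta)), numerically stable.
    return torch.logsumexp(log_weights + log_probs, dim=0)


theta = torch.tensor([-1.0, 0.0, 3.0])
print(ensemble_log_prob(theta))
```

Note that this is the log of a weighted average of densities, which is not the same as a weighted average of log-densities.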

set_default_x(x)

Set new default x for .sample(), .log_prob() as conditioning context.

This is a pure convenience to avoid having to repeatedly specify x in calls to .sample() and .log_prob() - only θ needs to be passed.

This convenience is particularly useful when the posterior ensemble is focused, i.e. has been trained over multiple rounds to be accurate in the vicinity of a particular x=x_o (you can check if your posterior object is focused by printing one exemplary component of the ensemble).

NOTE: this method is chainable, i.e. it returns the EnsemblePosterior object, so that calls like `posterior_ensemble.set_default_x(my_x).log_prob(mytheta)` are possible.

Parameters:
  • x (Tensor) – The default observation to set for every posterior \(p_i(\theta|x)\) in the ensemble.

Returns:

EnsemblePosterior that will use a default x when not explicitly passed.

Return type:

NeuralPosterior
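The chainable pattern comes from the setter returning the object itself. The class below is a hypothetical minimal illustration, not the EnsemblePosterior source: ToyEnsemble and its default_x attribute are assumptions, and each component is a SimpleNamespace standing in for a posterior.

```python
from types import SimpleNamespace


class ToyEnsemble:
    """Hypothetical minimal class illustrating a chainable default-x setter."""

    def __init__(self, components):
        self.components = components
        self.default_x = None

    def set_default_x(self, x):
        self.default_x = x
        for component in self.components:
            component.default_x = x  # propagate to every component p_i(theta | x)
        return self  # returning self is what makes the call chainable


ensemble = ToyEnsemble([SimpleNamespace(default_x=None) for _ in range(3)])
same_object = ensemble.set_default_x(1.5)
print(same_object is ensemble)  # True
```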

potential(theta, x=None, track_gradients=False)

Evaluates \(\theta\) under the potential that is used to sample the posterior. The potential is the unnormalized log-probability of \(\theta\) under the posterior.

Parameters:
  • theta (Tensor) – Parameters \(\theta\).

  • track_gradients (bool) – Whether the returned tensor supports tracking gradients. This can be helpful for e.g. sensitivity analysis, but increases memory consumption.

Return type:

Tensor
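The effect of track_gradients can be illustrated with a toy potential. This is a sketch under assumptions: toy_potential is hypothetical and uses the unnormalized log-density of a standard normal (the log-density up to an additive constant) rather than an sbi posterior; gating gradient tracking with torch.set_grad_enabled is one common way to implement such a flag.

```python
import torch


def toy_potential(theta, track_gradients=False):
    """Unnormalized log-density of a standard normal: log N(theta; 0, I) up to a constant."""
    # Gradients are only recorded when requested, which saves memory otherwise.
    with torch.set_grad_enabled(track_gradients):
        return -0.5 * (theta ** 2).sum(dim=-1)


theta = torch.tensor([[1.0, -2.0]], requires_grad=True)  # one 2-D parameter vector
pot = toy_potential(theta, track_gradients=True)
pot.sum().backward()
print(theta.grad)  # gradient of -0.5 * |theta|^2 is -theta
```

The gradient with respect to θ is exactly the quantity used in a simple sensitivity analysis; with track_gradients=False, the returned tensor carries no autograd graph.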

map(x=None, num_iter=1000, num_to_optimize=100, learning_rate=0.01, init_method='posterior', num_init_samples=1000, save_best_every=10, show_progress_bars=False, individually=False)

Returns the average maximum-a-posteriori estimate (MAP).

Computes MAP estimate across the whole ensemble or for each component individually. All args and kwargs are passed directly through to gradient_ascent.

The routine can be interrupted (individually) with [Ctrl-C] once the user sees that the log-probability has converged. The best estimate will be saved in self.posteriors[idx].map_.

For details on how the MAP estimate is obtained, see the .map() docstring of self.posteriors[idx].

Parameters:
  • x (Tensor | None) – Observed data at which to evaluate the MAP.

  • num_iter (int) – Number of optimization steps that the algorithm takes to find the MAP.

  • num_to_optimize (int) – From the drawn num_init_samples, use the num_to_optimize with the highest log-probability as the initial points for the optimization.

  • learning_rate (float) – Learning rate of the optimizer.

  • init_method (str | Tensor) – How to select the starting parameters for the optimization. If it is a string, it can be either [posterior, prior], which samples the respective distribution num_init_samples times. If it is a tensor, the tensor will be used as init locations.

  • num_init_samples (int) – Draw this number of samples from the posterior and evaluate the log-probability of all of them.

  • save_best_every (int) – The best log-probability is computed, saved in the map attribute, and printed every save_best_every-th iteration. Computing the best log-probability creates significant overhead (hence the default of 10).

  • show_progress_bars (bool) – Whether to show a progress bar during sampling from the posterior.

  • individually (bool) – If True, returns the log weights and the component MAPs individually.

Returns:

The ensemble MAP estimate, or the individual log_weights and component MAP estimates if individually == True.

Return type:

Tensor | Tuple[Tensor, Tensor]
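The interplay of num_init_samples, num_to_optimize, learning_rate, and num_iter can be sketched as gradient ascent on a toy mixture log-probability. This is an illustrative sketch, not sbi's gradient_ascent: the mixture components are hypothetical, a broad Normal(0, 3) proposal stands in for init_method, the optimizer is assumed to be Adam, and there is no save_best_every bookkeeping or Ctrl-C handling.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Hypothetical 1-D mixture standing in for the ensemble posterior.
components = [Normal(0.0, 1.0), Normal(3.0, 0.5)]
log_weights = torch.log(torch.tensor([0.4, 0.6]))


def log_prob(theta):
    log_probs = torch.stack([c.log_prob(theta) for c in components])
    return torch.logsumexp(log_weights.unsqueeze(-1) + log_probs, dim=0)


def map_sketch(num_iter=1000, num_to_optimize=100, learning_rate=0.01, num_init_samples=1000):
    # Draw candidate starting points from a broad, hypothetical proposal.
    init = Normal(0.0, 3.0).sample((num_init_samples,))
    # Keep the num_to_optimize candidates with the highest log-probability.
    top = log_prob(init).topk(num_to_optimize).indices
    theta = init[top].clone().requires_grad_(True)
    optimizer = torch.optim.Adam([theta], lr=learning_rate)
    for _ in range(num_iter):
        optimizer.zero_grad()
        loss = -log_prob(theta).sum()  # ascent on log p is descent on -log p
        loss.backward()
        optimizer.step()
    # Return the optimized candidate with the highest final log-probability.
    with torch.no_grad():
        best = log_prob(theta).argmax()
    return theta[best].detach()


theta_map = map_sketch()
print(theta_map)
```

Optimizing many starting points in parallel and keeping the best one makes the estimate less likely to get stuck at a local mode of a multimodal mixture.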

property default_x: Tensor | None

Return default x used by .sample() and .log_prob() as conditioning context.

Parameters: