# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.distributions import Distribution
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.inference.potentials.posterior_based_potential import PosteriorBasedPotential
from sbi.sbi_types import Shape, TorchTransform
from sbi.utils.sbiutils import gradient_ascent, mcmc_transform
from sbi.utils.torchutils import ensure_theta_batched
from sbi.utils.user_input_checks import process_x
class EnsemblePosterior(NeuralPosterior):
    r"""Wrapper for bundling together different posterior instances into an ensemble.

    This class creates a posterior ensemble from a set of :math:`N` different, already
    trained posterior estimators :math:`p_{i}(\theta \mid x_o)`, where
    :math:`i \in \{1, \ldots, N\}`.

    It can wrap all posterior classes available in ``sbi`` and even a mixture of
    different posteriors, i.e. obtained via SNLE and SNPE at the same time, since it
    only provides a pass-through to the class methods of each posterior in the
    ensemble. The only constraint is that the individual posteriors have the same
    prior.

    So far, ``log_prob()``, ``sample()`` and ``map()`` functionality are supported.

    Example:
    --------

    ::

        import torch
        from sbi.inference import NPE, EnsemblePosterior

        theta = prior.sample((100,))
        x = simulate(theta)

        n_ensembles = 10
        posteriors = []
        for _ in range(n_ensembles):
            inference = NPE()
            inference.append_simulations(theta, x).train()
            posteriors.append(inference.build_posterior())

        ensemble = EnsemblePosterior(posteriors)
        ensemble.set_default_x(torch.zeros((3,)))
        ensemble.sample((1,))

    Attributes:
        posteriors: List of the posterior estimators making up the ensemble.
        num_components: Number of posterior estimators.
        weights: Weight of each posterior distribution. If none are provided each
            posterior is weighted with 1/N.
        prior: Prior of the first posterior, used as the reference prior for the
            whole ensemble.
        theta_transform: If passed, this transformation will be applied during the
            optimization performed when obtaining the map. It does not affect the
            .sample() and .log_prob() methods.
        device: device to host the posterior distribution.
    """

    def __init__(
        self,
        posteriors: List,
        weights: Optional[Union[List[float], Tensor]] = None,
        theta_transform: Optional[TorchTransform] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        r"""
        Args:
            posteriors: List containing the trained posterior instances that will make
                up the ensemble. Must not be empty.
            weights: Assign weights to posteriors manually, otherwise they will be
                weighted with 1/N.
            theta_transform: If passed, this transformation will be applied during the
                optimization performed when obtaining the map. It does not affect the
                `.sample()` and `.log_prob()` methods.
            device: Device of the ensemble. If `None`, the common device of the
                component posteriors is used (see `ensure_same_device`).

        Raises:
            ValueError: If `posteriors` is empty.
        """
        if len(posteriors) == 0:
            raise ValueError("`posteriors` must contain at least one posterior.")

        self.posteriors = posteriors
        self.num_components = len(posteriors)
        # Goes through the property setter, which normalises the weights.
        self.weights = weights
        self.theta_transform = theta_transform
        # Take first prior as reference
        self.prior = posteriors[0].potential_fn.prior

        self.device = (
            self.ensure_same_device(posteriors) if device is None else device
        )

        # Also runs the base-class initialisation with the ensemble potential.
        self._build_potential_fns()

    def to(self, device: Union[str, torch.device]) -> None:
        """Moves each posterior to device.

        Prior and weights are also moved to the specified device, the
        `theta_transform` is rebuilt from the prior via `mcmc_transform`, and the
        ensemble potential is rebuilt afterwards.

        Args:
            device: The device to move the ensemble posterior to.
        """
        self.device = device
        self._device = device
        for posterior in self.posteriors:
            posterior.to(device)
        self.prior.to(device)
        self.theta_transform = mcmc_transform(self.prior, device=device)
        # `Tensor.to` is not in-place: the result has to be re-bound, otherwise the
        # weights silently stay on their old device.
        self._weights = self._weights.to(device)
        self._build_potential_fns()

    def _build_potential_fns(self):
        """Collect the component potentials and (re-)initialise the base class.

        Builds an `EnsemblePotential` from the `potential_fn` of every component
        posterior and hands it to the `NeuralPosterior` constructor together with
        the current `theta_transform` and `device`.

        Raises:
            AssertionError: If the prior of a component is not an instance of the
                reference prior's type. Only the type is compared.
        """
        potential_fns = []
        for posterior in self.posteriors:
            potential = posterior.potential_fn
            # make sure all priors are the same
            assert isinstance(potential.prior, type(self.prior)), (
                "All posteriors in ensemble must have the same prior: "
                f"{potential.prior} {self.prior}"
            )
            potential_fns.append(potential)

        # Pass the ensemble's device along; otherwise the potential falls back to
        # its "cpu" default regardless of where the components live.
        potential_fn = EnsemblePotential(
            potential_fns, self._weights, self.prior, None, device=self.device
        )
        super().__init__(
            potential_fn=potential_fn,
            theta_transform=self.theta_transform,
            device=self.device,
        )

    def ensure_same_device(self, posteriors: List) -> str:
        """Ensures that all posteriors in the ensemble are on the same device.

        Args:
            posteriors: List containing the trained posterior instances that will make
                up the ensemble.

        Raises:
            AssertionError if ensemble components have different device variables.

        Returns:
            A device string, that is the same for all posteriors.
        """
        devices = [posterior._device for posterior in posteriors]
        reference = devices[0]
        assert all(device == reference for device in devices), (
            "Only supported if all posteriors are on the same device."
        )
        return reference

    @property
    def weights(self) -> Tensor:
        """Normalised weight of each posterior in the ensemble."""
        return self._weights

    @weights.setter
    def weights(self, weights: Optional[Union[List[float], Tensor]]) -> None:
        """Set relative weight for each posterior in the ensemble.

        Weights are normalised so that they sum to one. `None` yields uniform
        weights of 1/N.

        Args:
            weights: Assigns a weight to each posterior distribution.

        Raises:
            TypeError: If weights are provided in an unsupported format (anything
                other than `None`, a list or a `Tensor`).
            ValueError: If the number of weights differs from the number of
                posteriors in the ensemble.
        """
        if weights is None:
            self._weights = torch.full(
                (self.num_components,), 1.0 / self.num_components
            )
        elif isinstance(weights, (Tensor, list)):
            # `as_tensor` avoids the copy-construct warning `torch.tensor` emits
            # when it is handed a tensor; the division below creates a new tensor,
            # so the caller's tensor is never aliased.
            raw_weights = torch.as_tensor(weights)
            if raw_weights.numel() != self.num_components:
                raise ValueError(
                    f"Expected {self.num_components} weights (one per posterior), "
                    f"but got {raw_weights.numel()}."
                )
            self._weights = raw_weights / raw_weights.sum()
        else:
            raise TypeError(
                "`weights` must be None, a list of floats or a Tensor, but got "
                f"{type(weights).__name__}."
            )

    def _draw_component_counts(self, num_samples: int) -> List[Tuple[int, int]]:
        """Draw how many of `num_samples` samples each component contributes.

        Component indices are drawn with replacement from a multinomial over the
        ensemble weights.

        Returns:
            `(component_index, count)` pairs in ascending index order. Components
            that were not drawn are omitted.
        """
        drawn = torch.multinomial(self._weights, num_samples, replacement=True)
        indices, counts = drawn.unique(return_counts=True)
        return list(zip(indices.tolist(), counts.tolist()))

    def sample(
        self, sample_shape: Shape = torch.Size(), x: Optional[Tensor] = None, **kwargs
    ) -> Tensor:
        r"""Return samples from posterior ensemble.

        The samples are drawn according to their assigned weight. The number of samples
        for each distribution is drawn from a corresponding multinomial distribution.
        Then each component posterior is sampled individually and all samples are
        aggregated afterwards. The aggregated samples are grouped by component (in
        ascending component order), not shuffled.

        All kwargs are passed directly through to `posterior.sample()`.

        Args:
            sample_shape: Desired shape of samples that are drawn from posterior
                ensemble. If sample_shape is multidimensional we simply draw
                `sample_shape.numel()` samples and then reshape into the desired shape.
            x: Conditioning context. If none is provided and no default context is set,
                an error will be raised.

        Returns:
            Samples drawn from the ensemble distribution.
        """
        num_samples = torch.Size(sample_shape).numel()
        samples = [
            self.posteriors[index].sample(torch.Size((count,)), x=x, **kwargs)
            for index, count in self._draw_component_counts(num_samples)
        ]
        return torch.vstack(samples).reshape(*sample_shape, -1)

    def sample_batched(
        self,
        sample_shape: Shape,
        x: Tensor,
        **kwargs,
    ) -> Tensor:
        r"""Return samples from the posterior ensemble for a batch of contexts.

        Works like `sample()`: the number of samples contributed by each component
        is drawn from a multinomial over the ensemble weights, each selected
        component is queried through its own `sample_batched()`, and the results
        are stacked along the first dimension.

        All kwargs are passed directly through to `posterior.sample_batched()`.

        Args:
            sample_shape: Desired shape of samples drawn per context. If it is
                multidimensional, `sample_shape.numel()` samples are drawn and
                reshaped afterwards.
            x: Batch of conditioning contexts, forwarded unchanged to the
                components.

        Returns:
            Samples of shape `sample_shape` followed by the trailing dimensions
            returned by the components' `sample_batched()`.
        """
        num_samples = torch.Size(sample_shape).numel()
        samples = torch.vstack([
            self.posteriors[index].sample_batched(
                torch.Size((count,)), x=x, **kwargs
            )
            for index, count in self._draw_component_counts(num_samples)
        ])
        return samples.reshape(torch.Size(sample_shape) + samples.shape[1:])

    def log_prob(
        self,
        theta: Tensor,
        x: Optional[Tensor] = None,
        individually: bool = False,
        **kwargs,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        r"""Returns the log-probability of the posterior ensemble
        $\log \sum_{i}^{N} w_{i} p_i(\theta|x)$.

        All kwargs are passed directly through to `posterior.log_prob()`.

        Args:
            theta: Parameters $\theta$.
            x: Conditioning context. If none is provided and no default context is
                set, an error will be raised.
            individually: If true, returns log weights and log_probs individually.

        Raises:
            AssertionError if posterior estimators are a mixture of different methods.

        Returns:
            `(len(θ),)`-shaped log posterior probability $\log p(\theta|x)$ of the
            weighted mixture for θ in the support of the prior, -∞ (corresponding
            to 0 probability) outside. If `individually` is true, the tuple
            `(log_weights, log_probs)` with one row per component is returned
            instead.
        """
        reference_type = type(self.posteriors[0])
        assert all(
            isinstance(posterior, reference_type) for posterior in self.posteriors
        ), "`log_prob()` only works for ensembles of the same type of posterior."

        log_probs = torch.stack([
            posterior.log_prob(theta, x=x, **kwargs) for posterior in self.posteriors
        ])
        # Column vector so that it broadcasts over the theta dimension.
        log_weights = torch.log(self._weights).reshape(-1, 1)

        if individually:
            return log_weights, log_probs
        return torch.logsumexp(log_weights.expand_as(log_probs) + log_probs, dim=0)

    def set_default_x(self, x: Tensor) -> "NeuralPosterior":
        r"""Set new default x for `.sample(), .log_prob()` as conditioning context.

        This is a pure convenience to avoid having to repeatedly specify `x` in calls to
        `.sample()` and `.log_prob()` - only θ needs to be passed.

        This convenience is particularly useful when the posterior ensemble is focused,
        i.e. has been trained over multiple rounds to be accurate in the vicinity of a
        particular `x=x_o` (you can check if your posterior object is focused by
        printing one exemplary component of the ensemble).

        NOTE: this method is chainable, i.e. will return the EnsemblePosterior
        object so that calls like
        `posterior_ensemble.set_default_x(my_x).sample(mytheta)` are possible.

        Args:
            x: The default observation to set for every posterior $p_i(theta|x)$ in the
                ensemble.

        Returns:
            `EnsemblePosterior` that will use a default `x` when not explicitly
            passed.
        """
        self._x = process_x(x, x_event_shape=None).to(self._device)
        # Each component receives the unprocessed `x` and stores its own default.
        for posterior in self.posteriors:
            posterior.set_default_x(x)
        return self

    def potential(
        self, theta: Tensor, x: Optional[Tensor] = None, track_gradients: bool = False
    ) -> Tensor:
        r"""Evaluates $\theta$ under the potential that is used to sample the posterior.

        The potential is the unnormalized log-probability of $\theta$ under the
        posterior.

        Args:
            theta: Parameters $\theta$.
            x: Conditioning context. If `None`, the default `x` is used.
            track_gradients: Whether the returned tensor supports tracking gradients.
                This can be helpful for e.g. sensitivity analysis, but increases memory
                consumption.
        """
        self.potential_fn.set_x(self._x_else_default_x(x))

        theta = ensure_theta_batched(torch.as_tensor(theta))
        return self.potential_fn(
            theta.to(self._device), track_gradients=track_gradients
        )

    def map(
        self,
        x: Optional[Tensor] = None,
        num_iter: int = 1_000,
        num_to_optimize: int = 100,
        learning_rate: float = 0.01,
        init_method: Union[str, Tensor] = "posterior",
        num_init_samples: int = 1_000,
        save_best_every: int = 10,
        show_progress_bars: bool = False,
        individually: bool = False,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        r"""Returns the average maximum-a-posteriori estimate (MAP).

        Computes MAP estimate across the whole ensemble or for each component
        individually. All args and kwargs are passed directly through to
        `gradient_ascent`.

        The routine can be interrupted (individually) with [Ctrl-C], when the user sees
        that the log-probability converges. The best estimate will be saved in
        `self.posteriors[idx].map_`.

        For more details of how the MAP estimate is obtained see `.map()` docstring of
        self.posteriors[idx].

        Args:
            x: Observed data at which to evaluate the MAP.
            num_iter: Number of optimization steps that the algorithm takes
                to find the MAP.
            num_to_optimize: From the drawn `num_init_samples`, use the
                `num_to_optimize` with highest log-probability as the initial points
                for the optimization.
            learning_rate: Learning rate of the optimizer.
            init_method: How to select the starting parameters for the optimization. If
                it is a string, it can be either [`posterior`, `prior`], which samples
                the respective distribution `num_init_samples` times. If it is a
                tensor, the tensor will be used as init locations.
            num_init_samples: Draw this number of samples from the posterior and
                evaluate the log-probability of all of them.
            save_best_every: The best log-probability is computed, saved in the
                `map`-attribute, and printed every `save_best_every`-th iteration.
                Computing the best log-probability creates a significant overhead
                (thus, the default is `10`.)
            show_progress_bars: Whether to show a progressbar during sampling from
                the posterior.
            individually: If true, returns log weights and MAPs individually.

        Raises:
            ValueError: If `init_method` is neither a tensor nor one of the
                supported strings (only checked when `individually` is false).

        Returns:
            The ensemble MAP estimate or individual log_weights and component MAP
            estimate if individually == True.
        """
        if individually:
            log_weights = torch.log(self._weights).reshape(-1, 1)
            maps = torch.stack([
                posterior.map(
                    x=x,
                    num_iter=num_iter,
                    num_to_optimize=num_to_optimize,
                    learning_rate=learning_rate,
                    init_method=init_method,
                    num_init_samples=num_init_samples,
                    save_best_every=save_best_every,
                    show_progress_bars=show_progress_bars,
                )
                for posterior in self.posteriors
            ])
            return log_weights, maps

        x = self._x_else_default_x(x)
        self.potential_fn.set_x(x)

        # Check for a tensor first so that a tensor is never compared to a string.
        if isinstance(init_method, Tensor):
            inits = init_method
        elif init_method == "posterior":
            inits = self.sample((num_init_samples,), x)
        elif init_method == "prior":
            inits = self.prior.sample((num_init_samples,))
        else:
            raise ValueError(
                "`init_method` must be 'posterior', 'prior' or a Tensor of initial "
                f"locations, but got {init_method!r}."
            )

        # `gradient_ascent` returns a tuple; the first entry is the estimate.
        return gradient_ascent(
            potential_fn=self.potential_fn,
            inits=inits,
            theta_transform=self.theta_transform,
            num_iter=num_iter,
            num_to_optimize=num_to_optimize,
            learning_rate=learning_rate,
            save_best_every=save_best_every,
            show_progress_bars=show_progress_bars,
        )[0]
class EnsemblePotential(BasePotential):
    r"""Potential function backing `EnsemblePosterior`.

    Calling the potential evaluates every component potential and combines the
    results as `logsumexp(log(w_i) + potential_i(theta))` over the components,
    i.e. the log of the weighted sum of the exponentiated component potentials.

    This class was modelled off of `PosteriorBasedPotential` and should provide
    similar functionality.

    Attributes:
        potential_fns: List of the potential_fns from each posterior component.
        _weights: Weights of each posterior distribution.
    """

    def __init__(
        self,
        potential_fns: List,
        weights: Tensor,
        prior: Distribution,
        x_o: Optional[Tensor],
        device: Union[str, torch.device] = "cpu",
    ):
        r"""
        Args:
            potential_fns: List of the potential_fns from each posterior component.
            weights: Weights of each posterior distribution.
            prior: Prior distribution, forwarded to the base class.
            x_o: Used as conditioning context for `potential_fn`.
            device: Device which the component distributions sit on.
        """
        # Stored before the base-class constructor runs, as in the original order.
        # NOTE(review): presumably the base constructor calls `set_x`, which
        # iterates over `potential_fns` - confirm against `BasePotential`.
        self.potential_fns = potential_fns
        self._weights = weights
        super().__init__(prior, x_o, device)

    def to(self, device: Union[str, torch.device]) -> None:
        """Move the component potentials, weights, prior and `x_o` to `device`.

        Args:
            device: The device to move the ensemble potential to.
        """
        self.device = device
        for component in self.potential_fns:
            component.to(device)
        self._weights = self._weights.to(device)
        self.prior.to(device)  # type: ignore
        current_x_o = self._x_o
        if current_x_o is not None:
            self._x_o = current_x_o.to(device)

    def allow_iid_x(self) -> bool:
        """Return True if any component is a `PosteriorBasedPotential`."""
        # In case there are different kinds of posteriors, this will produce an
        # error in `set_x()`.
        for component in self.potential_fns:
            if isinstance(component, PosteriorBasedPotential):
                return True
        return False

    def set_x(self, x_o: Optional[Tensor]):
        """Check the shape of the observed data and, if valid, set it.

        A non-None `x_o` is run through `process_x` and moved to this potential's
        device; the result (or `None`) is stored and forwarded to every component
        potential.
        """
        processed = (
            None
            if x_o is None
            else process_x(x_o).to(self.device)  # type: ignore
        )
        self._x_o = processed
        for component in self.potential_fns:
            component.set_x(processed)

    def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
        r"""Returns the potential for posterior-based methods.

        Args:
            theta: The parameter set at which to evaluate the potential function.
            track_gradients: Whether to track the gradients.

        Returns:
            The potential.
        """
        batched_theta = ensure_theta_batched(torch.as_tensor(theta)).to(self.device)

        # One row per component potential.
        component_values = torch.vstack([
            component(batched_theta, track_gradients=track_gradients)
            for component in self.potential_fns
        ])

        # Column of log-weights, broadcast over the theta dimension.
        log_weights = torch.log(self._weights.reshape(-1, 1))
        weighted = log_weights.expand_as(component_values) + component_values
        return torch.logsumexp(weighted, dim=0)