# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
import warnings
from typing import Optional, Union
import torch
from torch import Tensor, log
from torch.distributions import Distribution
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.posterior_based_potential import (
posterior_estimator_based_potential,
)
from sbi.neural_nets.estimators.base import ConditionalDensityEstimator
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.samplers.rejection import rejection
from sbi.sbi_types import Shape
from sbi.utils.sbiutils import warn_if_outside_prior_support, within_support
from sbi.utils.torchutils import ensure_theta_batched
from sbi.utils.user_input_checks import check_prior
[docs]
class DirectPosterior(NeuralPosterior):
r"""Posterior based on neural networks that directly estimate the posterior (NPE).
NPE trains a neural network to directly approximate the posterior distribution.
However, for bounded priors, the neural network can have leakage: it puts non-zero
mass in regions where the prior is zero. The `DirectPosterior` class wraps the
trained network to deal with these cases.
Specifically, this class offers the following functionality:
- correct the calculation of the log probability such that it compensates for the
leakage.
- reject samples that lie outside of the prior bounds.
This class can not be used in combination with NLE or NRE.
"""
def __init__(
self,
posterior_estimator: ConditionalDensityEstimator,
prior: Distribution,
max_sampling_batch_size: int = 10_000,
device: Optional[Union[str, torch.device]] = None,
x_shape: Optional[torch.Size] = None,
enable_transform: bool = True,
):
"""
Args:
prior: Prior distribution with `.log_prob()` and `.sample()`.
posterior_estimator: The trained neural posterior.
max_sampling_batch_size: Batchsize of samples being drawn from
the proposal at every iteration.
device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
`potential_fn.device` is used.
x_shape: Deprecated, should not be passed.
enable_transform: Whether to transform parameters to unconstrained space
during MAP optimization. When False, an identity transform will be
returned for `theta_transform`.
"""
# Because `DirectPosterior` does not take the `potential_fn` as input, it
# builds it itself. The `potential_fn` and `theta_transform` are used only for
# obtaining the MAP.
check_prior(prior)
self.enable_transform = enable_transform
self.x_shape = x_shape
potential_fn, theta_transform = posterior_estimator_based_potential(
posterior_estimator,
prior,
x_o=None,
enable_transform=enable_transform,
)
super().__init__(
potential_fn=potential_fn,
theta_transform=theta_transform,
device=device,
x_shape=x_shape,
)
self.device = device
self.prior = prior
self.posterior_estimator = posterior_estimator
self.max_sampling_batch_size = max_sampling_batch_size
self._leakage_density_correction_factor = None
self._purpose = """It samples the posterior network and rejects samples that
lie outside of the prior bounds."""
[docs]
def to(self, device: Union[str, torch.device]) -> None:
"""Move posterior_estimator, prior and x_o to device.
Changes the device attribute, reinstanciates the
posterior, and resets the default x.
Args:
device: device where to move the posterior to.
"""
self.device = device
if hasattr(self.prior, "to"):
self.prior.to(device) # type: ignore
else:
raise ValueError("""Prior has no attribute to(device).""")
if hasattr(self.posterior_estimator, "to"):
self.posterior_estimator.to(device)
else:
raise ValueError("""Posterior estimator has no attribute to(device).""")
potential_fn, theta_transform = posterior_estimator_based_potential(
self.posterior_estimator,
self.prior,
x_o=None,
enable_transform=self.enable_transform,
)
x_o = None
if hasattr(self, "_x") and (self._x is not None):
x_o = self._x.to(device)
super().__init__(
potential_fn=potential_fn,
theta_transform=theta_transform,
device=device,
x_shape=self.x_shape,
)
# super().__init__ erases the self._x, so we need to set it again
if x_o is not None:
self.set_default_x(x_o)
[docs]
def sample(
self,
sample_shape: Shape = torch.Size(),
x: Optional[Tensor] = None,
max_sampling_batch_size: int = 10_000,
show_progress_bars: bool = True,
reject_outside_prior: bool = True,
max_sampling_time: Optional[float] = None,
return_partial_on_timeout: bool = False,
) -> Tensor:
r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.
Args:
sample_shape: Desired shape of samples that are drawn from posterior. If
sample_shape is multidimensional we simply draw `sample_shape.numel()`
samples and then reshape into the desired shape.
x: Conditioning observation $x_o$. If not provided, uses the default `x`
set via `.set_default_x()`.
max_sampling_batch_size: Maximum batch size for rejection sampling.
show_progress_bars: Whether to show sampling progress monitor.
reject_outside_prior: If True (default), rejection sampling is used to
ensure samples lie within the prior support. If False, samples are drawn
directly from the neural density estimator without rejection, which is
faster but may include samples outside the prior support.
max_sampling_time: Optional maximum allowed sampling time in seconds.
If exceeded, sampling is aborted and a RuntimeError is raised. Only
applies when `reject_outside_prior=True` (no effect otherwise since
direct sampling is fast).
return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
return the samples collected so far instead of raising a RuntimeError.
A warning will be issued. Only applies when `reject_outside_prior=True`
(default).
"""
num_samples = torch.Size(sample_shape).numel()
x = self._x_else_default_x(x)
x = reshape_to_batch_event(
x, event_shape=self.posterior_estimator.condition_shape
)
if x.shape[0] > 1:
raise ValueError(
".sample() supports only `batchsize == 1`. If you intend "
"to sample multiple observations, use `.sample_batched()`. "
"If you intend to sample i.i.d. observations, set up the "
"posterior density estimator with an appropriate permutation "
"invariant embedding net."
)
max_sampling_batch_size = (
self.max_sampling_batch_size
if max_sampling_batch_size is None
else max_sampling_batch_size
)
if reject_outside_prior:
# Normal rejection behavior.
samples = rejection.accept_reject_sample(
proposal=self.posterior_estimator.sample,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs={"condition": x},
alternative_method="build_posterior(..., sample_with='mcmc')",
max_sampling_time=max_sampling_time,
return_partial_on_timeout=return_partial_on_timeout,
)[0]
else:
# Bypass rejection sampling entirely.
samples = self.posterior_estimator.sample(
torch.Size([num_samples]),
condition=x,
)
warn_if_outside_prior_support(self.prior, samples[:, 0])
return samples[:, 0] # Remove batch dimension.
[docs]
def sample_batched(
self,
sample_shape: Shape,
x: Tensor,
max_sampling_batch_size: int = 10_000,
show_progress_bars: bool = True,
reject_outside_prior: bool = True,
max_sampling_time: Optional[float] = None,
return_partial_on_timeout: bool = False,
) -> Tensor:
r"""Draw samples from the posteriors for a batch of different xs.
Given a batch of observations `[x_1, ..., x_B]`, this method samples from
posteriors $p(\theta|x_1), \ldots, p(\theta|x_B)$ in a vectorized manner.
Args:
sample_shape: Desired shape of samples that are drawn from the posterior
given every observation.
x: A batch of observations, of shape `(batch_dim, event_shape_x)`.
`batch_dim` corresponds to the number of observations to be drawn.
max_sampling_batch_size: Maximum batch size for rejection sampling.
show_progress_bars: Whether to show sampling progress monitor.
reject_outside_prior: If True (default), rejection sampling is used to
ensure samples lie within the prior support. If False, samples are drawn
directly from the neural density estimator without rejection, which is
faster but may include samples outside the prior support.
max_sampling_time: Optional maximum allowed sampling time in seconds.
If exceeded, sampling is aborted and a RuntimeError is raised. Only
applies when `reject_outside_prior=True`.
return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
return the samples collected so far instead of raising a RuntimeError.
A warning will be issued. Only applies when `reject_outside_prior=True`.
Returns:
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
"""
num_samples = torch.Size(sample_shape).numel()
condition_shape = self.posterior_estimator.condition_shape
x = reshape_to_batch_event(x, event_shape=condition_shape)
num_xos = x.shape[0]
# throw warning if num_x * num_samples is too large
if num_xos * num_samples > 2**21: # 2 million-ish
warnings.warn(
f"Note that for batched sampling, the direct posterior sampling "
f"generates {num_xos} * {num_samples} = {num_xos * num_samples} "
"samples. This can be slow and memory-intensive. Consider "
"reducing the number of samples or batch size.",
stacklevel=2,
)
max_sampling_batch_size = (
self.max_sampling_batch_size
if max_sampling_batch_size is None
else max_sampling_batch_size
)
# Adjust max_sampling_batch_size to avoid excessive memory usage
if max_sampling_batch_size * num_xos > 100_000:
capped = max(1, 100_000 // num_xos)
warnings.warn(
f"Capping max_sampling_batch_size from {max_sampling_batch_size} "
f"to {capped} to avoid excessive memory usage.",
stacklevel=2,
)
max_sampling_batch_size = capped
if reject_outside_prior:
# Normal rejection behavior.
samples = rejection.accept_reject_sample(
proposal=self.posterior_estimator.sample,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs={"condition": x},
alternative_method="build_posterior(..., sample_with='mcmc')",
max_sampling_time=max_sampling_time,
return_partial_on_timeout=return_partial_on_timeout,
)[0]
else:
# Bypass rejection sampling entirely.
samples = self.posterior_estimator.sample(
torch.Size([num_samples]),
condition=x,
)
warn_if_outside_prior_support(self.prior, samples)
return samples
[docs]
def log_prob(
self,
theta: Tensor,
x: Optional[Tensor] = None,
norm_posterior: bool = True,
track_gradients: bool = False,
leakage_correction_params: Optional[dict] = None,
) -> Tensor:
r"""Returns the log-probability of the posterior $p(\theta|x)$.
Args:
theta: Parameters $\theta$.
norm_posterior: Whether to enforce a normalized posterior density.
Renormalization of the posterior is useful when some
probability falls out or leaks out of the prescribed prior support.
The normalizing factor is calculated via rejection sampling, so if you
need speedier but unnormalized log posterior estimates set here
`norm_posterior=False`. The returned log posterior is set to
-∞ outside of the prior support regardless of this setting.
track_gradients: Whether the returned tensor supports tracking gradients.
This can be helpful for e.g. sensitivity analysis, but increases memory
consumption.
leakage_correction_params: A `dict` of keyword arguments to override the
default values of `leakage_correction()`. Possible options are:
`num_rejection_samples`, `force_update`, `show_progress_bars`, and
`rejection_sampling_batch_size`.
These parameters only have an effect if `norm_posterior=True`.
Returns:
`(len(θ),)`-shaped log posterior probability $\log p(\theta|x)$ for θ in the
support of the prior, -∞ (corresponding to 0 probability) outside.
"""
x = self._x_else_default_x(x)
theta = ensure_theta_batched(torch.as_tensor(theta))
theta_density_estimator = reshape_to_sample_batch_event(
theta, theta.shape[1:], leading_is_sample=True
)
x_density_estimator = reshape_to_batch_event(
x, event_shape=self.posterior_estimator.condition_shape
)
if x_density_estimator.shape[0] > 1:
raise ValueError(
".log_prob() supports only `batchsize == 1`. If you intend "
"to evaluate given multiple observations, use `.log_prob_batched()`. "
"If you intend to evaluate given i.i.d. observations, set up the "
"posterior density estimator with an appropriate permutation "
"invariant embedding net."
)
self.posterior_estimator.eval()
with torch.set_grad_enabled(track_gradients):
# Evaluate on device, move back to cpu for comparison with prior.
unnorm_log_prob = self.posterior_estimator.log_prob(
theta_density_estimator, condition=x_density_estimator
)
# `log_prob` supports only a single observation (i.e. `batchsize==1`).
# We now remove this additional dimension.
unnorm_log_prob = unnorm_log_prob.squeeze(dim=1)
# Force probability to be zero outside prior support.
in_prior_support = within_support(self.prior, theta)
masked_log_prob = torch.where(
in_prior_support,
unnorm_log_prob,
torch.tensor(float("-inf"), dtype=torch.float32, device=self._device),
)
if leakage_correction_params is None:
leakage_correction_params = dict() # use defaults
log_factor = (
log(self.leakage_correction(x=x, **leakage_correction_params))
if norm_posterior
else 0
)
return masked_log_prob - log_factor
[docs]
def log_prob_batched(
self,
theta: Tensor,
x: Tensor,
norm_posterior: bool = True,
track_gradients: bool = False,
leakage_correction_params: Optional[dict] = None,
) -> Tensor:
"""Given a batch of observations [x_1, ..., x_B] and a batch of parameters \
[$\theta_1$,..., $\theta_B$] this function evalautes the log-probabilities \
of the posteriors $p(\theta_1|x_1)$, ..., $p(\theta_B|x_B)$ in a batched \
(i.e. vectorized) manner.
Args:
theta: Batch of parameters $\theta$ of shape \
`(*sample_shape, batch_dim, *theta_shape)`.
x: Batch of observations $x$ of shape \
`(batch_dim, *condition_shape)`.
norm_posterior: Whether to enforce a normalized posterior density.
Renormalization of the posterior is useful when some
probability falls out or leaks out of the prescribed prior support.
The normalizing factor is calculated via rejection sampling, so if you
need speedier but unnormalized log posterior estimates set here
`norm_posterior=False`. The returned log posterior is set to
-∞ outside of the prior support regardless of this setting.
track_gradients: Whether the returned tensor supports tracking gradients.
This can be helpful for e.g. sensitivity analysis, but increases memory
consumption.
leakage_correction_params: A `dict` of keyword arguments to override the
default values of `leakage_correction()`. Possible options are:
`num_rejection_samples`, `force_update`, `show_progress_bars`, and
`rejection_sampling_batch_size`.
These parameters only have an effect if `norm_posterior=True`.
Returns:
`(len(θ), B)`-shaped log posterior probability $\\log p(\theta|x)$\\ for θ \
in the support of the prior, -∞ (corresponding to 0 probability) outside.
"""
theta = ensure_theta_batched(torch.as_tensor(theta))
event_shape = self.posterior_estimator.input_shape
# If theta has 1 leading dim (batch, event), treat it as batch (matching x).
# overwise, the leading is sample.
num_leading = len(theta.shape) - len(event_shape)
theta_density_estimator = reshape_to_sample_batch_event(
theta, event_shape, leading_is_sample=(num_leading > 1)
)
x_density_estimator = reshape_to_batch_event(
x, event_shape=self.posterior_estimator.condition_shape
)
self.posterior_estimator.eval()
with torch.set_grad_enabled(track_gradients):
# Evaluate on device, move back to cpu for comparison with prior.
unnorm_log_prob = self.posterior_estimator.log_prob(
theta_density_estimator, condition=x_density_estimator
)
# Force probability to be zero outside prior support.
in_prior_support = within_support(self.prior, theta)
masked_log_prob = torch.where(
in_prior_support,
unnorm_log_prob,
torch.tensor(float("-inf"), dtype=torch.float32, device=self._device),
)
if leakage_correction_params is None:
leakage_correction_params = dict() # use defaults
log_factor = (
log(self.leakage_correction(x=x, **leakage_correction_params))
if norm_posterior
else 0
)
return masked_log_prob - log_factor
[docs]
@torch.no_grad()
def leakage_correction(
self,
x: Tensor,
num_rejection_samples: int = 10_000,
force_update: bool = False,
show_progress_bars: bool = False,
rejection_sampling_batch_size: int = 10_000,
) -> Tensor:
r"""Return leakage correction factor for a leaky posterior density estimate.
The factor is estimated from the acceptance probability during rejection
sampling from the posterior.
This is to avoid re-estimating the acceptance probability from scratch
whenever `log_prob` is called and `norm_posterior=True`. Here, it
is estimated only once for `self.default_x` and saved for later. We
re-evaluate only whenever a new `x` is passed.
Arguments:
num_rejection_samples: Number of samples used to estimate correction factor.
show_progress_bars: Whether to show a progress bar during sampling.
rejection_sampling_batch_size: Batch size for rejection sampling.
Returns:
Saved or newly-estimated correction factor (as a scalar `Tensor`).
"""
def acceptance_at(x: Tensor) -> Tensor:
# [1:] to remove batch-dimension for `reshape_to_batch_event`.
return rejection.accept_reject_sample(
proposal=self.posterior_estimator.sample,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_rejection_samples,
show_progress_bars=show_progress_bars,
sample_for_correction_factor=True,
max_sampling_batch_size=rejection_sampling_batch_size,
proposal_sampling_kwargs={
"condition": reshape_to_batch_event(
x, event_shape=self.posterior_estimator.condition_shape
)
},
)[1]
# Check if the provided x matches the default x (short-circuit on identity).
is_new_x = self.default_x is None or (
x is not self.default_x and (x != self.default_x).any()
)
not_saved_at_default_x = self._leakage_density_correction_factor is None
if is_new_x: # Calculate at x; don't save.
return acceptance_at(x)
elif not_saved_at_default_x or force_update: # Calculate at default_x; save.
assert self.default_x is not None
self._leakage_density_correction_factor = acceptance_at(self.default_x)
return self._leakage_density_correction_factor # type: ignore
[docs]
def map(
self,
x: Optional[Tensor] = None,
num_iter: int = 1_000,
num_to_optimize: int = 100,
learning_rate: float = 0.01,
init_method: Union[str, Tensor] = "posterior",
num_init_samples: int = 1_000,
save_best_every: int = 10,
show_progress_bars: bool = False,
force_update: bool = False,
) -> Tensor:
r"""Returns the maximum-a-posteriori estimate (MAP).
The method can be interrupted (Ctrl-C) when the user sees that the
log-probability converges. The best estimate will be saved in `self._map` and
can be accessed with `self.map()`. The MAP is obtained by running gradient
ascent from a given number of starting positions (samples from the posterior
with the highest log-probability). After the optimization is done, we select the
parameter set that has the highest log-probability after the optimization.
Warning: The default values used by this function are not well-tested. They
might require hand-tuning for the problem at hand.
For developers: if the prior is a `BoxUniform`, we carry out the optimization
in unbounded space and transform the result back into bounded space.
Args:
x: Deprecated - use `.set_default_x()` prior to `.map()`.
num_iter: Number of optimization steps that the algorithm takes
to find the MAP.
learning_rate: Learning rate of the optimizer.
init_method: How to select the starting parameters for the optimization. If
it is a string, it can be either [`posterior`, `prior`], which samples
the respective distribution `num_init_samples` times. If it is a
tensor, the tensor will be used as init locations.
num_init_samples: Draw this number of samples from the posterior and
evaluate the log-probability of all of them.
num_to_optimize: From the drawn `num_init_samples`, use the
`num_to_optimize` with highest log-probability as the initial points
for the optimization.
save_best_every: The best log-probability is computed, saved in the
`map`-attribute, and printed every `save_best_every`-th iteration.
Computing the best log-probability creates a significant overhead
(thus, the default is `10`.)
show_progress_bars: Whether to show a progressbar during sampling from the
posterior.
force_update: Whether to re-calculate the MAP when x is unchanged and
have a cached value.
log_prob_kwargs: Will be empty for SNLE and SNRE. Will contain
{'norm_posterior': True} for SNPE.
Returns:
The MAP estimate.
"""
return super().map(
x=x,
num_iter=num_iter,
num_to_optimize=num_to_optimize,
learning_rate=learning_rate,
init_method=init_method,
num_init_samples=num_init_samples,
save_best_every=save_best_every,
show_progress_bars=show_progress_bars,
force_update=force_update,
)