# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from __future__ import annotations
from typing import Optional, Tuple, Union
import torch
from torch import Tensor
from torch.distributions import Distribution
from sbi.inference.potentials.base_potential import BasePotential
from sbi.neural_nets.estimators import ConditionalDensityEstimator
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.sbi_types import TorchTransform
from sbi.utils.sbiutils import (
mcmc_transform,
within_support,
)
from sbi.utils.torchutils import ensure_theta_batched
def posterior_estimator_based_potential(
    posterior_estimator: ConditionalDensityEstimator,
    prior: Distribution,
    x_o: Optional[Tensor],
    enable_transform: bool = True,
) -> Tuple[PosteriorBasedPotential, TorchTransform]:
    r"""Create a posterior-based potential together with a parameter transform.

    The potential evaluates the log-probability of `posterior_estimator` at the
    observation `x_o` and is $-\inf$ for parameters outside the prior bounds. The
    accompanying transform (built by `mcmc_transform`) can be used to move the
    potential into unconstrained space.

    Both objects are created for the device that holds the estimator's first
    parameter.

    Args:
        posterior_estimator: Neural network that models the posterior.
        prior: Prior distribution over the parameters.
        x_o: Observed data at which the posterior is evaluated.
        enable_transform: If True, parameters are mapped to unconstrained space;
            if False, the returned `theta_transform` is an identity transform.

    Returns:
        A `(potential_fn, theta_transform)` pair: the potential function and the
        transformation that maps to unconstrained space.
    """
    # Evaluate wherever the estimator's weights live.
    first_weight = next(posterior_estimator.parameters())
    estimator_device = str(first_weight.device)

    potential = PosteriorBasedPotential(
        posterior_estimator,
        prior,
        x_o,
        device=estimator_device,
    )
    transform = mcmc_transform(
        prior,
        device=estimator_device,
        enable_transform=enable_transform,
    )
    return potential, transform
class PosteriorBasedPotential(BasePotential):
    r"""Potential given by the log-probability of a neural posterior estimator.

    The potential equals the estimator's log-probability of `theta` given the
    observed data `x_o`, and is set to $-\inf$ wherever `theta` lies outside the
    support of the prior.
    """

    def __init__(
        self,
        posterior_estimator: ConditionalDensityEstimator,
        prior: Distribution,  # type: ignore
        x_o: Optional[Tensor] = None,
        device: Union[str, torch.device] = "cpu",
    ):
        r"""Initialize the potential for posterior-based methods.

        The potential is the same as the log-probability of the
        `posterior_estimator`, but it is set to $-\inf$ outside of the prior
        bounds.

        Args:
            posterior_estimator: The neural network modelling the posterior. It is
                put into evaluation mode (`.eval()`) here.
            prior: The prior distribution.
            x_o: The observed data at which to evaluate the posterior.
            device: Device on which the potential is evaluated; forwarded to the
                base class.
        """
        super().__init__(prior, x_o, device)
        self.posterior_estimator = posterior_estimator
        self.posterior_estimator.eval()

    def to(self, device: Union[str, torch.device]) -> "PosteriorBasedPotential":
        """Move posterior estimator, prior and x_o to the given device.

        Args:
            device: Device to move the posterior_estimator, prior and x_o to.

        Returns:
            Self for method chaining.
        """
        super().to(device)
        self.posterior_estimator.to(device)
        return self

    def set_x(self, x_o: Optional[Tensor], x_is_iid: Optional[bool] = False):
        """Check the shape of the observed data and, if valid, set it.

        Delegates entirely to the base-class implementation.
        """
        super().set_x(x_o, x_is_iid=x_is_iid)

    def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
        r"""Returns the potential for posterior-based methods.

        Args:
            theta: The parameter set at which to evaluate the potential function.
                It is converted with `torch.as_tensor`, batched with
                `ensure_theta_batched` and moved to `self.device`.
            track_gradients: Whether to track the gradients.

        Returns:
            The potential, i.e. the posterior log-probability, with $-\inf$ for
            parameters outside the prior support. With a single (or iid) `x_o`
            there is one value per `theta`; with a batch of several `x_o` the
            batch dimension of the log-prob is not squeezed.

        Raises:
            ValueError: If no observed data `x_o` is set, or if `x_o` is iid and
                no prior is available.
            AssertionError: If a batch of several `x_o` does not match the batch
                size of `theta`.
        """
        if self._x_o is None:
            raise ValueError(
                "No observed data x_o is available. Please reinitialize "
                "the potential or manually set self._x_o."
            )

        with torch.set_grad_enabled(track_gradients):
            # Normalize `theta` first so that the support mask below is computed
            # on the same device and batch shape as the log-probs it is combined
            # with in `torch.where` (and works for non-tensor inputs).
            theta = ensure_theta_batched(torch.as_tensor(theta)).to(self.device)

            # Force probability to be zero outside prior support.
            in_prior_support = within_support(self.prior, theta)

            if self.x_is_iid and self.x_o.shape[0] > 1:
                # iid observations: sum_i log q(theta | x_i) - (n - 1) log p(theta),
                # which combines the per-observation posteriors up to a
                # normalizing constant.
                if self.prior is None:
                    raise ValueError(
                        "A proper prior is required for evaluating the "
                        "posterior potential with iid observations."
                    )
                num_iid = self.x_o.shape[0]
                theta_sbe = reshape_to_sample_batch_event(
                    theta, event_shape=theta.shape[1:], leading_is_sample=True
                )
                x_iid = reshape_to_batch_event(
                    self.x_o,
                    event_shape=self.posterior_estimator.condition_shape,
                )
                # Pair every theta with every iid observation along dim 1.
                theta_expanded = theta_sbe.expand(-1, num_iid, *theta_sbe.shape[2:])
                iid_log_probs = self.posterior_estimator.log_prob(
                    theta_expanded, condition=x_iid
                )
                posterior_log_prob = iid_log_probs.sum(dim=1) - (
                    num_iid - 1
                ) * self.prior.log_prob(theta)
            else:
                x = reshape_to_batch_event(
                    self.x_o,
                    event_shape=self.posterior_estimator.condition_shape,
                )
                theta_batch_size = theta.shape[0]
                x_batch_size = x.shape[0]
                assert theta_batch_size == x_batch_size or x_batch_size == 1, (
                    f"Batch size mismatch: {theta_batch_size} and {x_batch_size}. "
                    "When performing batched sampling for multiple `x`, the batch "
                    "size of `theta` must match the batch size of `x`."
                )
                if x_batch_size == 1:
                    # A single `x` (batch size 1): treat every theta as a sample
                    # for that one condition and squeeze the batch dimension of
                    # the log-prob with `.squeeze(1)`.
                    theta_sbe = reshape_to_sample_batch_event(
                        theta, event_shape=theta.shape[1:], leading_is_sample=True
                    )
                    posterior_log_prob = self.posterior_estimator.log_prob(
                        theta_sbe, condition=x
                    ).squeeze(1)
                else:
                    # Multiple `x`: return the log-probs for each (x, theta) pair
                    # and do not squeeze the batch dimension.
                    posterior_log_prob = self.posterior_estimator.log_prob(
                        theta.unsqueeze(0), condition=x
                    )

            # Set the potential to -inf outside the prior support. The fill value
            # follows the log-prob's dtype/device rather than a fixed float32.
            posterior_log_prob = torch.where(
                in_prior_support,
                posterior_log_prob,
                torch.tensor(
                    float("-inf"),
                    dtype=posterior_log_prob.dtype,
                    device=posterior_log_prob.device,
                ),
            )
        return posterior_log_prob