# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
import warnings
from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.distributions import Distribution
from sbi.inference.potentials.base_potential import BasePotential
from sbi.neural_nets.estimators import (
ConditionalDensityEstimator,
MixedDensityEstimator,
)
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.sbi_types import TorchTransform
from sbi.utils.sbiutils import mcmc_transform
[docs]
def likelihood_estimator_based_potential(
    likelihood_estimator: ConditionalDensityEstimator,
    prior: Distribution,  # type: ignore
    x_o: Optional[Tensor],
    enable_transform: bool = True,
) -> Tuple["LikelihoodBasedPotential", TorchTransform]:
    r"""Build the potential :math:`\log(p(x_o|\theta)p(\theta))` and its transform.

    Besides the potential, a transformation is returned that can be used to map the
    potential into unconstrained space.

    Args:
        likelihood_estimator: The density estimator modelling the likelihood.
        prior: The prior distribution.
        x_o: The observed data at which to evaluate the likelihood.
        enable_transform: Whether to transform parameters to unconstrained space.
            When False, an identity transform will be returned for `theta_transform`.

    Returns:
        A tuple of the potential function $p(x_o|\theta)p(\theta)$ and a
        transformation that maps to unconstrained space.
    """
    # The device of the estimator's first parameter is used for both the potential
    # and the transform.
    first_parameter = next(likelihood_estimator.parameters())
    estimator_device = str(first_parameter.device)

    potential = LikelihoodBasedPotential(
        likelihood_estimator, prior, x_o, device=estimator_device
    )
    transform = mcmc_transform(
        prior, device=estimator_device, enable_transform=enable_transform
    )
    return potential, transform
class LikelihoodBasedPotential(BasePotential):
    r"""Potential $\log(p(x_o|\theta)p(\theta))$ built from a likelihood estimator.

    Evaluating the potential adds the estimator's log-likelihood of `x_o` given
    `theta` to the prior log-probability of `theta`.
    """

    def __init__(
        self,
        likelihood_estimator: ConditionalDensityEstimator,
        prior: Distribution,
        x_o: Optional[Tensor] = None,
        device: str = "cpu",
    ):
        r"""Initialize the potential function for likelihood-based methods.

        Args:
            likelihood_estimator: The density estimator modelling the likelihood.
                It is switched to evaluation mode here.
            prior: The prior distribution.
            x_o: The observed data at which to evaluate the likelihood.
            device: The device to which parameters and data are moved before
                evaluating the likelihood estimator.
        """
        super().__init__(prior, x_o, device)
        self.likelihood_estimator = likelihood_estimator
        self.likelihood_estimator.eval()

    def to(self, device: Union[str, torch.device]) -> "LikelihoodBasedPotential":
        """Move likelihood estimator, prior and x_o to the given device.

        Args:
            device: Device to move the likelihood_estimator, prior and x_o to.

        Returns:
            Self for method chaining.
        """
        super().to(device)
        self.likelihood_estimator.to(device)
        return self

    def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
        r"""Returns the potential $\log(p(x_o|\theta)p(\theta))$.

        Args:
            theta: The parameter set at which to evaluate the potential function.
            track_gradients: Whether to track the gradients of the likelihood
                evaluation.

        Returns:
            The potential $\log(p(x_o|\theta)p(\theta))$.

        Raises:
            AssertionError: If `x_o` is not iid and the batch sizes of `theta` and
                `x_o` differ.
        """
        # Move `theta` exactly once so that the likelihood estimator and the prior
        # are evaluated on the same tensor, on the same device, in both branches.
        theta = theta.to(self.device)

        if self.x_is_iid:
            # For each theta, calculate the likelihood sum over all x in batch.
            log_likelihood = _log_likelihoods_over_trials(
                x=self.x_o,
                theta=theta,
                estimator=self.likelihood_estimator,
                track_gradients=track_gradients,
            )
        else:
            # Calculate likelihood for each (theta,x) pair separately
            theta_batch_size = theta.shape[0]
            x_batch_size = self.x_o.shape[0]
            assert theta_batch_size == x_batch_size, (
                f"Batch size mismatch: {theta_batch_size} and {x_batch_size}. "
                "When performing batched sampling for multiple `x`, the batch size "
                "of `theta` must match the batch size of `x`."
            )
            # Prepend a singleton leading dimension before calling `log_prob`.
            x = self.x_o.unsqueeze(0)
            with torch.set_grad_enabled(track_gradients):
                log_likelihood = self.likelihood_estimator.log_prob(
                    x, condition=theta
                )

        return log_likelihood + self.prior.log_prob(theta)  # type: ignore

    def condition_on_theta(
        self, local_theta: Tensor, dims_global_theta: List[int]
    ) -> Callable:
        r"""Returns a potential function conditioned on a subset of theta dimensions.

        The goal of this function is to divide the original `theta` into a
        `global_theta` we do inference over, and a `local_theta` we condition on (in
        addition to conditioning on `x_o`). Thus, the returned potential function
        will calculate $\sum_{i=1}^{N} \log p(x_i | global\_theta, local\_theta_i)$,
        i.e. the log of the product of the per-trial likelihoods, where `x_i` and
        `local_theta_i` are fixed and `global_theta` varies at inference time. The
        returned function does not add the prior log-probability.

        Args:
            local_theta: The condition values to be conditioned on.
            dims_global_theta: The indices of the columns in `theta` that will be
                sampled, i.e., that are *not* conditioned on. For example, if the
                original theta has shape `(batch_dim, 3)`, and
                `dims_global_theta=[0, 1]`, then the potential will set
                `theta[:, 2] = local_theta` at inference time.

        Returns:
            A potential function conditioned on the `local_theta`.

        Raises:
            AssertionError: If the observed data is not iid.
        """
        assert self.x_is_iid, "Conditioning is only supported for iid data."

        def conditioned_potential(
            theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
        ) -> Tensor:
            """Evaluate the summed log-likelihood with `local_theta` held fixed."""
            assert len(dims_global_theta) == theta.shape[-1], (
                "dims_global_theta must match the number of parameters to sample."
            )
            if theta.dim() > 2:
                assert theta.shape[0] == 1, (
                    "condition_on_theta does not support sample shape for theta."
                )
                theta = theta.squeeze(0)

            # NOTE(review): `_log_likelihood_over_iid_trials_and_local_theta`
            # concatenates `[global_theta, local_theta]` in that order, so this
            # presumably expects the global dims to be the leading columns of the
            # original `theta` - confirm.
            global_theta = theta[:, dims_global_theta]

            # Fall back to the stored observation when none is passed explicitly.
            x_o = x_o if x_o is not None else self.x_o
            # x needs shape (sample_dim (iid), batch_dim (xs), *event_shape)
            if x_o.dim() < 3:
                x_o = reshape_to_sample_batch_event(
                    x_o, event_shape=x_o.shape[1:], leading_is_sample=self.x_is_iid
                )

            return _log_likelihood_over_iid_trials_and_local_theta(
                x=x_o.to(self.device),
                global_theta=global_theta.to(self.device),
                local_theta=local_theta.to(self.device),
                estimator=self.likelihood_estimator,
                track_gradients=track_gradients,
            )

        return conditioned_potential
def _log_likelihoods_over_trials(
    x: Tensor,
    theta: Tensor,
    estimator: ConditionalDensityEstimator,
    track_gradients: bool = False,
) -> Tensor:
    r"""Return log likelihoods summed over iid trials of `x`.

    Note: `x` can be a batch with batch size larger 1. Batches in `x` are assumed
    to be iid trials, i.e., data generated based on the same parameters /
    experimental conditions.

    Expands `x` along the batch dimension so that every trial is evaluated against
    every entry of $\theta$.

    Args:
        x: Batch of iid data of shape `(iid_dim, *event_shape)`.
        theta: Batch of parameters of shape `(batch_dim, *event_shape)`.
        estimator: DensityEstimator.
        track_gradients: Whether to track gradients.

    Returns:
        log_likelihood_trial_sum: log likelihood for each parameter, summed over all
            batch entries (iid trials) in `x`.

    Raises:
        AssertionError: If the estimator, `x` and `theta` are not all on the same
            device.
    """
    # Shape of `x` is (iid_dim, *event_shape).
    x = reshape_to_sample_batch_event(
        x, event_shape=x.shape[1:], leading_is_sample=True
    )

    # Match the number of `x` to the number of conditions (`theta`). This is important
    # if the potential is simultaneously evaluated at multiple `theta` (e.g.
    # multi-chain MCMC).
    theta_batch_size = theta.shape[0]
    trailing_minus_ones = [-1] * (x.dim() - 2)
    x = x.expand(-1, theta_batch_size, *trailing_minus_ones)

    # Look the estimator device up once; it is needed for the check and the message.
    estimator_device = next(estimator.parameters()).device
    assert estimator_device == x.device and x.device == theta.device, (
        "device mismatch: estimator, x, theta: "
        f"{estimator_device}, {x.device}, {theta.device}."
    )

    # Shape of `theta` is (batch_dim, *event_shape). Therefore, the call below should
    # not change anything, and we just have it as "best practice" before calling
    # `DensityEstimator.log_prob`.
    theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:])

    # Calculate likelihood in one batch.
    with torch.set_grad_enabled(track_gradients):
        log_likelihood_trial_batch = estimator.log_prob(x, condition=theta)
        # Sum over trial-log likelihoods.
        log_likelihood_trial_sum = log_likelihood_trial_batch.sum(0)

    return log_likelihood_trial_sum
def _log_likelihood_over_iid_trials_and_local_theta(
    x: Tensor,
    global_theta: Tensor,
    local_theta: Tensor,
    estimator: ConditionalDensityEstimator,
    track_gradients: bool = False,
) -> Tensor:
    r"""Returns $\sum_{i=1}^N \log p(x_i | \theta, local\_theta_i)$.

    `x` is a batch of iid data, and `local_theta` is a matching batch of condition
    values that were part of `theta` but are treated as local iid variables at
    inference time.

    Unlike `_log_likelihoods_over_trials`, this function moves the iid dimension of
    `x` onto the batch dimension of `theta`. Each trial `x_i` is paired only with its
    own `local_theta_i` (and with every `global_theta`), which avoids evaluating the
    likelihood for every combination of `x` and `local_theta`.

    Args:
        x: data with shape `(sample_dim, x_batch_dim, *x_event_shape)`, where
            sample_dim holds the i.i.d. trials and batch_dim holds a batch of xs,
            e.g., non-iid observations. Only `x_batch_dim == 1` is supported.
        global_theta: Batch of parameters `(theta_batch_dim, num_parameters)`.
        local_theta: Batch of conditions of shape `(sample_dim, num_local_thetas)`,
            must match x's `sample_dim`.
        estimator: DensityEstimator; its `log_prob(x, condition=...)` is called once.
        track_gradients: Whether to track gradients.

    Returns:
        log_likelihood: log likelihood for each theta in theta_batch_dim, summed
            over all iid trials. The leading x batch dimension is squeezed away,
            giving shape `(theta_batch_dim,)`.

    Raises:
        AssertionError: If the inputs do not have the documented ranks, or if
            `local_theta` and `x` disagree on the number of trials.
        NotImplementedError: If `x` holds more than one entry in its batch dim.
    """
    assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)."
    assert local_theta.dim() == 2, (
        "condition must have shape (sample_dim, num_conditions)."
    )
    assert global_theta.dim() == 2, "theta must have shape (batch_dim, num_parameters)."

    num_trials = x.shape[0]
    num_xs = x.shape[1]
    num_thetas = global_theta.shape[0]

    assert local_theta.shape[0] == num_trials, (
        "Condition batch size must match the number of iid trials in x."
    )
    if num_xs > 1:
        raise NotImplementedError(
            "Batched sampling for multiple `x` is not supported for iid conditions."
        )

    # Put the trials on the batch axis; each trial occupies `num_thetas`
    # consecutive slots (AABB layout).
    x_tiled = x.transpose(0, 1).repeat_interleave(num_thetas, dim=1)

    # Build the matching conditions: global thetas cycle within each trial (ABAB)
    # while the local theta of a trial stays constant over its slots (AABB).
    tiled_global = global_theta.repeat(num_trials, 1)
    tiled_local = local_theta.repeat_interleave(num_thetas, dim=0)
    condition = torch.cat((tiled_global, tiled_local), dim=-1)

    with torch.set_grad_enabled(track_gradients):
        # One batched evaluation. Returns (1, num_trials * num_theta)
        flat_log_probs = estimator.log_prob(x_tiled, condition=condition)
        # Unflatten to (xs, trials, thetas) and add up the per-trial terms.
        summed_over_trials = flat_log_probs.reshape(
            num_xs, num_trials, num_thetas
        ).sum(1)

    # remove xs batch dimension
    return summed_over_trials.squeeze(0)
[docs]
def mixed_likelihood_estimator_based_potential(
    likelihood_estimator: MixedDensityEstimator,
    prior: Distribution,
    x_o: Optional[Tensor],
) -> Tuple[Callable, TorchTransform]:
    r"""Deprecated: returns $\log(p(x_o|\theta)p(\theta))$ for mixed likelihoods.

    Emits a `DeprecationWarning` pointing to `likelihood_estimator_based_potential`.
    It also returns a transformation that can be used to transform the potential
    into unconstrained space.

    Args:
        likelihood_estimator: The neural network modelling the likelihood.
        prior: The prior distribution.
        x_o: The observed data at which to evaluate the likelihood.

    Returns:
        The potential function $p(x_o|\theta)p(\theta)$ and a transformation that
        maps to unconstrained space.
    """
    deprecation_notice = (
        "This function is deprecated and will be removed in a future release. Use "
        "`likelihood_estimator_based_potential` instead."
    )
    warnings.warn(deprecation_notice, DeprecationWarning, stacklevel=2)

    # The device is read from the first parameter of the estimator's `discrete_net`.
    first_parameter = next(likelihood_estimator.discrete_net.parameters())
    net_device = str(first_parameter.device)

    potential = MixedLikelihoodBasedPotential(
        likelihood_estimator, prior, x_o, device=net_device
    )
    transform = mcmc_transform(prior, device=net_device)
    return potential, transform
class MixedLikelihoodBasedPotential(LikelihoodBasedPotential):
    """Deprecated potential for mixed likelihood estimators.

    Instantiating this class emits a `DeprecationWarning` that points to
    `LikelihoodBasedPotential`.
    """

    def __init__(
        self,
        likelihood_estimator: MixedDensityEstimator,
        prior: Distribution,
        x_o: Optional[Tensor],
        device: str = "cpu",
    ):
        """Initialize the potential and warn about the deprecation.

        Args:
            likelihood_estimator: The mixed density estimator modelling the
                likelihood.
            prior: The prior distribution.
            x_o: The observed data at which to evaluate the likelihood.
            device: The device on which the potential is evaluated.
        """
        super().__init__(likelihood_estimator, prior, x_o, device)
        warnings.warn(
            "This function is deprecated and will be removed in a future release. Use "
            "`LikelihoodBasedPotential` instead.",
            DeprecationWarning,
            stacklevel=2,
        )

    def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
        r"""Returns the potential $\log(p(x_o|\theta)p(\theta))$.

        The log-likelihood is summed over the leading (iid trial) dimension of
        `x_o` before the prior log-probability is added.

        Args:
            theta: The parameter set at which to evaluate the potential function.
            track_gradients: Whether to track the gradients of the likelihood
                evaluation.

        Returns:
            The potential $\log(p(x_o|\theta)p(\theta))$.
        """
        # Move `theta` once up front so that the prior and the estimator are both
        # evaluated on the same tensor on `self.device`.
        theta = theta.to(self.device)
        prior_log_prob = self.prior.log_prob(theta)  # type: ignore

        # Shape of `x` is (iid_dim, *event_shape)
        theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:])
        x = reshape_to_sample_batch_event(
            self.x_o, event_shape=self.x_o.shape[1:], leading_is_sample=True
        )
        theta_batch_dim = theta.shape[0]

        # Match the number of `x` to the number of conditions (`theta`). This is
        # important if the potential is simultaneously evaluated at multiple `theta`
        # (e.g. multi-chain MCMC).
        trailing_minus_ones = [-1] * (x.dim() - 2)
        x = x.expand(-1, theta_batch_dim, *trailing_minus_ones)

        # Calculate likelihood in one batch.
        with torch.set_grad_enabled(track_gradients):
            # Call the specific log prob method of the mixed likelihood estimator as
            # this optimizes the evaluation of the discrete data part.
            log_likelihood_trial_batch = self.likelihood_estimator.log_prob(
                input=x,
                condition=theta,
            )
            # Reshape to (x-trials x parameters), sum over trial-log likelihoods.
            log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
                self.x_o.shape[0], -1
            ).sum(0)

        return log_likelihood_trial_sum + prior_log_prob