# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
import warnings
from typing import Dict, Optional, Sequence, Union
import torch
from torch import Tensor, float32
from torch.distributions import (
Bernoulli,
Binomial,
Categorical,
Distribution,
Multinomial,
MultivariateNormal,
constraints,
)
from sbi.utils.torchutils import move_all_tensor_to_device, process_device
def get_distribution_parameters(
    dist: Distribution, device: Union[str, torch.device]
) -> Dict:
    """Collect the constructor parameters of a torch distribution on ``device``.

    Parameters are read from ``dist.arg_constraints``. Tensor-valued parameters are
    moved to ``device``; any other value is passed through unchanged. Derived
    parametrizations that torch also exposes via ``arg_constraints`` are set to
    ``None``, so that ``type(dist)(**params)`` receives only one of them.

    Args:
        dist: The distribution whose parameters are collected.
        device: Target device for the tensor-valued parameters.

    Returns:
        Mapping from parameter name to value, intended for re-instantiating
        ``type(dist)`` on ``device``.
    """
    params = {}
    for name in dist.arg_constraints:
        value = getattr(dist, name)
        params[name] = value.to(device) if isinstance(value, Tensor) else value

    if isinstance(dist, MultivariateNormal):
        # MultivariateNormal derives `precision_matrix` and `scale_tril` from the
        # covariance and exposes all of them via `arg_constraints`. When
        # reinstantiating, we must provide only one of them.
        params["precision_matrix"] = None
        params["scale_tril"] = None
    elif "probs" in params and "logits" in params:
        # Distributions such as Bernoulli, Binomial, Categorical or Multinomial
        # derive `logits` from `probs` (or vice versa) and expose both via
        # `arg_constraints`. When reinstantiating, we must provide only one of them.
        params["logits"] = None

    if isinstance(dist, Multinomial):
        # `total_count` is a plain attribute of Multinomial and not part of its
        # `arg_constraints`; without it the rebuilt distribution would fall back to
        # the constructor default.
        params["total_count"] = dist.total_count

    return params
def move_distribution_to_device(
    dist: Distribution, device: Union[str, torch.device]
) -> Distribution:
    """Return ``dist`` placed on ``device``.

    If ``dist`` has a ``.to()`` method (e.g. sbi prior wrappers like
    ``PytorchReturnTypeWrapper`` or ``BoxUniform``), that method is called, its
    return value is ignored, and ``dist`` itself is returned. Any exception raised
    by ``.to()`` propagates to the caller.

    Otherwise a new instance of ``type(dist)`` is built from the parameters returned
    by ``get_distribution_parameters()``. If that fails for any reason, ``dist`` is
    handed to ``move_all_tensor_to_device`` (presumably relocating its tensor
    attributes in place) and returned.

    Args:
        dist: The distribution to relocate.
        device: The target device.

    Returns:
        ``dist`` itself when ``.to()`` was used or reconstruction failed, otherwise a
        newly constructed distribution of the same type.
    """
    if not hasattr(dist, "to"):
        try:
            return type(dist)(**get_distribution_parameters(dist, device))
        except Exception:
            # Best-effort fallback: not every distribution can be rebuilt from its
            # `arg_constraints`, so relocate the existing object instead.
            move_all_tensor_to_device(dist, device)
            return dist
    dist.to(device)  # type: ignore
    return dist
class CustomPriorWrapper(Distribution):
    """Expose a user-defined prior object through the torch ``Distribution`` API.

    The wrapped ``custom_prior`` is called via ``sample(sample_shape)`` and
    ``log_prob(value)``. The outputs of ``sample``, ``log_prob``, ``mean`` and
    ``variance`` are converted to tensors of dtype ``return_type``. If the prior has
    no ``mean`` or ``variance`` attribute, the missing one is estimated from samples
    at construction time and stored on the prior object itself.
    """

    def __init__(
        self,
        custom_prior,
        return_type: Optional[torch.dtype] = float32,
        batch_shape=torch.Size(),
        event_shape=torch.Size(),
        validate_args=None,
        arg_constraints: Optional[Dict[str, constraints.Constraint]] = None,
        lower_bound: Optional[Tensor] = None,
        upper_bound: Optional[Tensor] = None,
    ):
        """
        Args:
            custom_prior: Object providing `sample(sample_shape)` and
                `log_prob(value)`.
            return_type: dtype of all tensors returned by this wrapper.
            batch_shape: Batch shape reported by this distribution.
            event_shape: Event shape reported by this distribution.
            validate_args: Forwarded to `torch.distributions.Distribution`.
            arg_constraints: Returned by the `arg_constraints` property; `{}` if
                not given.
            lower_bound: Optional lower bound used to build the support.
            upper_bound: Optional upper bound used to build the support.
        """
        self.custom_arg_constraints = arg_constraints or {}
        self.custom_prior = custom_prior
        self.return_type = return_type
        self.custom_support = build_support(lower_bound, upper_bound)
        # Needs `custom_prior`; may draw samples and emit warnings.
        self._set_mean_and_variance()
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    def _as_return_type(self, value) -> Tensor:
        """Convert `value` to a tensor of dtype `self.return_type`."""
        return torch.as_tensor(value, dtype=self.return_type)  # type: ignore

    def log_prob(self, value) -> Tensor:
        prior_log_prob = self.custom_prior.log_prob(value)
        return self._as_return_type(prior_log_prob)

    def sample(self, sample_shape=torch.Size()) -> Tensor:
        draws = self.custom_prior.sample(sample_shape)
        return self._as_return_type(draws)

    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        return self.custom_arg_constraints

    @property
    def support(self) -> constraints.Constraint:
        return self.custom_support

    def _set_mean_and_variance(self):
        """Fill in a missing `mean` / `variance` attribute on the custom prior.

        Each missing statistic is estimated from its own batch of 1000 samples,
        assigned onto `self.custom_prior`, and reported with a `UserWarning`.
        """
        prior = self.custom_prior

        if not hasattr(prior, "mean"):
            mean_draws = torch.as_tensor(prior.sample((1000,)))
            prior.mean = torch.mean(mean_draws, dim=0)
            warnings.warn(
                "Prior is lacking mean attribute, estimating prior mean from samples.",
                UserWarning,
                stacklevel=2,
            )

        if not hasattr(prior, "variance"):
            variance_draws = torch.as_tensor(prior.sample((1000,)))
            prior.variance = torch.std(variance_draws, dim=0) ** 2
            warnings.warn(
                "Prior is lacking variance attribute, estimating prior variance from "
                "samples.",
                UserWarning,
                stacklevel=2,
            )

    def to(self, device: Union[str, torch.device]) -> None:
        """
        Device placement is not supported for custom priors.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "This class is not supported on the GPU. Use on cpu or use "
            "any of `PytorchReturnTypeWrapper`, `BoxUniform`, or `MultipleIndependent`."
        )

    @property
    def mean(self):
        return self._as_return_type(self.custom_prior.mean)

    @property
    def variance(self):
        return self._as_return_type(self.custom_prior.variance)
class PytorchReturnTypeWrapper(Distribution):
    """Wrap PyTorch Distribution to return a given return type."""

    def __init__(
        self,
        prior: Distribution,  # type: ignore
        return_type: Optional[torch.dtype] = float32,
        batch_shape=torch.Size(),
        event_shape=torch.Size(),
        validate_args=None,
    ):
        """
        Args:
            prior: The PyTorch distribution to wrap.
            return_type: dtype of the tensors returned by `sample`, `log_prob`,
                `mean` and `variance`.
            batch_shape: Batch shape reported by the wrapper.
            event_shape: Event shape reported by the wrapper.
            validate_args: Validation flag stored on the wrapper. If None, the flag
                of `prior` is used.
        """
        resolved_validate_args = (
            prior._validate_args if validate_args is None else validate_args
        )
        # This wrapper does not define `arg_constraints`, so torch's parameter
        # validation in `Distribution.__init__` would only emit a warning about the
        # missing constraints. The actual parameters live on `prior`, whose own
        # constructor is responsible for validating them. Skip the check here and
        # store the resolved flag afterwards.
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=False,
        )
        self._validate_args = resolved_validate_args
        self.prior = prior
        self.device = None
        self.return_type = return_type

    def log_prob(self, value) -> Tensor:
        return torch.as_tensor(
            self.prior.log_prob(value),
            dtype=self.return_type,  # type: ignore
        )

    def sample(self, sample_shape=torch.Size()) -> Tensor:
        return torch.as_tensor(
            self.prior.sample(sample_shape),
            dtype=self.return_type,  # type: ignore
        )

    @property
    def mean(self):
        return torch.as_tensor(
            self.prior.mean,
            dtype=self.return_type,  # type: ignore
        )

    @property
    def variance(self):
        return torch.as_tensor(
            self.prior.variance,
            dtype=self.return_type,  # type: ignore
        )

    @property
    def support(self):
        return self.prior.support

    def to(self, device: Union[str, torch.device]) -> None:
        """
        Move the distribution to the specified device.

        Moves the distribution parameters to the specific device
        and updates the device attribute.

        Args:
            device: device to move the distribution to.
        """
        self.prior = move_distribution_to_device(self.prior, device)
        self.device = device
class MultipleIndependent(Distribution):
    """Wrap a sequence of PyTorch distributions into a joint PyTorch distribution."""

    def __init__(
        self,
        dists: list[Distribution],
        validate_args: Optional[bool] = None,
        arg_constraints: Optional[Dict[str, constraints.Constraint]] = None,
        device: Optional[str] = None,
    ):
        """Joint distribution of multiple independent :class:`torch.distributions`.

        Every element of the sequence is treated as independent from the \
        other elements. Single elements can be multivariate with dependent dimensions.

        Args:
            dists: Sequence of PyTorch distributions.
            validate_args (Optional): If not None, passed to
                `set_default_validate_args` of every distribution and to the
                `Distribution` base class. Note that in PyTorch
                `set_default_validate_args` is a static method, i.e. it changes the
                default for all distributions, not only those in `dists`.
            arg_constraints (Optional): Dictionary of constraints for the parameters \
                of the distribution.
            device (Optional): Device to move the distribution to. If None, \
                the distribution is moved to the CPU.

        Example:
        --------

        ::

            import torch
            from torch.distributions import Gamma, Beta, MultivariateNormal
            from sbi.utils.user_input_checks_utils import MultipleIndependent

            prior = MultipleIndependent([
                Gamma(torch.ones(1), torch.ones(1)),
                Beta(torch.ones(1), torch.ones(1)),
                MultivariateNormal(torch.ones(2), torch.tensor([[1, .1], [.1, 1.]]))
            ])
        """
        self._check_distributions(dists)
        if validate_args is not None:
            for d in dists:
                d.set_default_validate_args(validate_args)

        # Copy into a list of our own: `to()` replaces elements, which must work
        # for any Sequence (e.g. a tuple) and must not mutate the caller's object.
        self.dists = list(dists)
        self.device = process_device(device or "cpu")
        self.to(self.device)

        # numel() instead of event_shape because for all dists both is possible,
        # event_shape=[1] or batch_shape=[1]
        self.dims_per_dist = [d.sample().numel() for d in self.dists]
        self.ndims = sum(self.dims_per_dist)

        self.custom_arg_constraints = arg_constraints or {}
        self.validate_args = validate_args

        super().__init__(
            batch_shape=torch.Size([]),  # batch size was ensured to be <= 1 above.
            event_shape=torch.Size([self.ndims]),  # Sum of all per-dist ndims.
            validate_args=validate_args,
        )

    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        """Return argument constraints."""
        return self.custom_arg_constraints

    def _check_distributions(self, dists):
        """Check if dists is Sequence and longer 1 and check every member."""
        assert isinstance(
            dists, Sequence
        ), f"""The combination of independent priors must be of type Sequence, is
            {type(dists)}."""
        assert len(dists) > 1, "Provide at least 2 distributions to combine."
        # Check every element of the sequence.
        for d in dists:
            self._check_distribution(d)

    def _check_distribution(self, dist: Distribution):
        """Check type and shape of a single input distribution."""
        assert not isinstance(dist, (MultipleIndependent, Sequence)), (
            "Nesting of combined distributions is not possible."
        )
        assert isinstance(
            dist, Distribution
        ), """priors passed to MultipleIndependent must be PyTorch distributions. Make \
            sure to process custom priors individually using :func:`process_prior` \
            before passing them in a list to :func:`process_prior`."""
        # Make sure batch shape is smaller or equal to 1.
        assert dist.batch_shape in (
            torch.Size([1]),
            torch.Size([0]),
            torch.Size([]),
        ), "The batch shape of every distribution must be smaller or equal to 1."

        assert (
            len(dist.batch_shape) > 0 or len(dist.event_shape) > 0
        ), """One of the distributions you passed is defined over a scalar only. Make
        sure pass distributions with one of event_shape or batch_shape > 0: For example
            - instead of Uniform(0.0, 1.0) pass Uniform(torch.zeros(1), torch.ones(1))
            - instead of Beta(1.0, 2.0) pass Beta(tensor([1.0]), tensor([2.0])).
        """

    def sample(self, sample_shape=torch.Size()) -> Tensor:
        """Sample every component and concatenate along the last dimension.

        Args:
            sample_shape: Leading shape of the returned samples.

        Returns:
            Tensor of shape `(*sample_shape, ndims)`.
        """
        sample_shape = torch.Size(sample_shape)
        sample = torch.cat([d.sample(sample_shape) for d in self.dists], dim=-1)
        # The reshape removes any singleton batch dimension left by the components
        # (e.g. batch_shape [1] together with a non-empty event_shape) while keeping
        # every leading `sample_shape` dimension, as required by event_shape.
        return sample.reshape(sample_shape + torch.Size([self.ndims]))

    def log_prob(self, value) -> Tensor:
        """Joint log-probability, i.e. the sum of the component log-probabilities.

        Args:
            value: Tensor of shape `(ndims,)` or `(*batch, ndims)`.

        Returns:
            Tensor of shape `(1,)` for a 1D `value`, otherwise of shape `batch`.
        """
        value = self._prepare_value(value)
        batch_shape = value.shape[:-1]
        # Evaluate on a flat (num_samples, ndims) view, taking into account that
        # individual distributions can be multivariate.
        flat_value = value.reshape(-1, self.ndims)
        num_samples = flat_value.shape[0]

        log_probs = []
        dims_covered = 0
        for d, ndims in zip(self.dists, self.dims_per_dist):
            v = flat_value[:, dims_covered : dims_covered + ndims]
            # Reshape here to ensure all returned log_probs are 2D for concatenation.
            log_probs.append(d.log_prob(v).reshape(num_samples, 1))
            dims_covered += ndims

        # Sum across the last dimension to get the joint log prob, then restore the
        # leading batch dimensions of `value`.
        return torch.cat(log_probs, dim=1).sum(-1).reshape(batch_shape)

    def _prepare_value(self, value) -> Tensor:
        """Return input value with at least two dimensions.

        Raises:
            AssertionError: if the size of the last dimension does not match the
                number of dimensions of this joint.
        """
        if value.ndim < 2:
            value = value.unsqueeze(0)

        num_value_dims = value.shape[-1]
        assert num_value_dims == self.ndims, (
            f"Number of dimensions must match dimensions of this joint: {self.ndims}."
        )

        return value

    @property
    def mean(self) -> Tensor:
        return torch.cat([d.mean for d in self.dists])

    @property
    def variance(self) -> Tensor:
        return torch.cat([d.variance for d in self.dists])

    @property
    def support(self):
        # First, we remove all `independent` constraints. This applies to e.g.
        # `MultivariateNormal`. An `independent` constraint returns a 1D `[True]`
        # when `.support.check(sample)` is called, whereas distributions that are
        # not `independent` (e.g. `Gamma`), return a 2D `[[True]]`. When such
        # constraints would be combined with the `constraint.cat(..., dim=1)`, it
        # fails because the `independent` constraint returned only a 1D `[True]`.
        supports = []
        for d in self.dists:
            if isinstance(d.support, constraints.independent):
                supports.append(d.support.base_constraint)
            else:
                supports.append(d.support)

        # Wrap as `independent` in order to have the correct shape of the
        # `log_abs_det`, i.e. summed over the parameter dimensions.
        return constraints.independent(
            constraints.cat(supports, dim=-1, lengths=self.dims_per_dist),
            reinterpreted_batch_ndims=1,
        )

    def to(self, device: Union[str, torch.device]) -> None:
        """Move the distribution to the specified device.

        If a component distribution has the `to` method, it is used. Otherwise, the
        component is rebuilt (or its tensors are moved) on the specified device.

        Args:
            device: device to move the distribution to.
        """
        self.dists = [move_distribution_to_device(d, device) for d in self.dists]
        self.device = device
def build_support(
    lower_bound: Optional[Tensor] = None, upper_bound: Optional[Tensor] = None
) -> constraints.Constraint:
    """Build the support constraint of a prior from optional bounds.

    With no bounds, the support is the whole real line and a warning suggests
    passing bounds. With one bound, values must be greater than `lower_bound` or
    less than `upper_bound`; with both, values must lie in the interval between
    them. For bounds with more than one element, the constraint is wrapped as
    `independent` over one dimension.

    Args:
        lower_bound: lower bound of the prior support, can be None
        upper_bound: upper bound of the prior support, can be None

    Returns:
        support: Pytorch constraint object.

    Raises:
        AssertionError: if both bounds are given with different numbers of elements.
    """
    if lower_bound is None and upper_bound is None:
        warnings.warn(
            "No prior bounds were passed, consider passing lower_bound "
            "and / or upper_bound if your prior has bounded support.",
            stacklevel=2,
        )
        return constraints.real

    if upper_bound is None:
        num_dims = lower_bound.numel()  # type: ignore
        base = constraints.greater_than(lower_bound)
    elif lower_bound is None:
        num_dims = upper_bound.numel()
        base = constraints.less_than(upper_bound)
    else:
        num_dims = lower_bound.numel()
        assert num_dims == upper_bound.numel(), (
            "There must be an equal number of independent bounds."
        )
        base = constraints.interval(lower_bound, upper_bound)

    # Multi-element bounds: treat the last dimension as part of the event so that
    # `check` reduces over it.
    if num_dims > 1:
        return constraints.independent(base, 1)
    return base
class OneDimPriorWrapper(Distribution):
    """Wrap batched 1D distributions to get rid of the batch dim of `.log_prob()`.

    1D pytorch distributions such as `torch.distributions.Exponential`, `.Uniform`, or
    `.Normal` do not, by default, return any sample or batch dimension. E.g.:
    ```python
    dist = torch.distributions.Exponential(torch.tensor(3.0))
    dist.sample((10,)).shape  # (10,)
    ```

    `sbi` will raise an error that the sample dimension is missing. A simple solution
    is to add a batch dimension to `dist` as follows:
    ```python
    dist = torch.distributions.Exponential(torch.tensor([3.0]))
    dist.sample((10,)).shape  # (10, 1)
    ```

    Unfortunately, this `dist` will return the batch dimension also for `.log_prob()`:
    ```python
    dist = torch.distributions.Exponential(torch.tensor([3.0]))
    samples = dist.sample((10,))
    dist.log_prob(samples).shape  # (10, 1)
    ```

    This will lead to unexpected errors in `sbi`. The point of this class is to wrap
    those batched 1D distributions to get rid of their batch dimension in
    `.log_prob()`.
    """

    def __init__(self, prior: Distribution, validate_args=None) -> None:
        """
        Args:
            prior: The batched 1D distribution to wrap.
            validate_args: Validation flag stored on the wrapper. If None, the flag
                of `prior` is used.
        """
        resolved_validate_args = (
            prior._validate_args if validate_args is None else validate_args
        )
        # Torch's parameter validation in `Distribution.__init__` cannot run on this
        # wrapper: `arg_constraints` reads `self.prior`, which is not yet assigned at
        # that point, and the constrained parameters live on `prior`, not on the
        # wrapper. Skip it and store the resolved flag afterwards.
        super().__init__(
            batch_shape=prior.batch_shape,
            event_shape=prior.event_shape,
            validate_args=False,
        )
        self._validate_args = resolved_validate_args
        self.prior = prior
        self.device = None

    def to(self, device: Union[str, torch.device]) -> None:
        """
        Move the distribution to the specified device.

        Moves the distribution parameters to the specific device
        and updates the device attribute.

        Args:
            device: device to move the distribution to.
        """
        self.prior = move_distribution_to_device(self.prior, device)
        self.device = device

    def sample(self, *args, **kwargs) -> Tensor:
        return self.prior.sample(*args, **kwargs)

    def log_prob(self, *args, **kwargs) -> Tensor:
        """Override the log_prob method to get rid of the additional batch dimension."""
        return self.prior.log_prob(*args, **kwargs)[..., 0]

    @property
    def arg_constraints(self) -> Dict[str, constraints.Constraint]:
        return self.prior.arg_constraints

    @property
    def support(self):
        return self.prior.support

    @property
    def mean(self) -> Tensor:
        return self.prior.mean

    @property
    def variance(self) -> Tensor:
        return self.prior.variance