sbi.utils.process_prior

process_prior(prior, custom_prior_wrapper_kwargs=None)

Return a PyTorch distribution-like prior from a user-provided prior.

NOTE: If the prior argument is a sequence of PyTorch distributions, the elements are interpreted as independent prior dimensions and wrapped in a MultipleIndependent PyTorch Distribution. If the elements are not PyTorch distributions, call process_prior() on each element of the sequence beforehand.

NOTE: Returns a tuple (prior, theta_numel, prior_returns_numpy), as described under Returns. The last two entries in the tuple can be passed on to process_simulator to prepare the simulator. For example, process_simulator casts parameters to NumPy or adds a batch dimension to the simulator output, if needed.

Parameters:
  • prior (Distribution or Sequence[Distribution]) – Prior object with .sample() and .log_prob(), or a sequence of such objects.

  • custom_prior_wrapper_kwargs (dict, optional) – Keyword arguments passed to the wrapper class that turns a custom (non-PyTorch) prior into a PyTorch Distribution, such as bounds for a prior with bounded support (lower_bound, upper_bound) or argument constraints (arg_constraints).

Raises:

AttributeError – If prior objects lack .sample() or .log_prob().

Returns:

  • prior: A PyTorch-compatible prior whose .sample() and .log_prob() return PyTorch tensors.

  • theta_numel: Number of parameters, i.e., the number of elements in a single sample from the prior.

  • prior_returns_numpy: Whether the prior originally returned NumPy arrays.

Return type:

Tuple[torch.distributions.Distribution, int, bool]

Example:

import torch
from torch.distributions import Uniform
from sbi.utils.user_input_checks import process_prior

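# A one-dimensional uniform prior on [0, 1].
prior = Uniform(torch.zeros(1), torch.ones(1))
# Returns the processed prior, the number of parameters, and the NumPy flag.
prior, theta_numel, prior_returns_numpy = process_prior(prior)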