sbi.utils.process_prior
- process_prior(prior, custom_prior_wrapper_kwargs=None)
Return a PyTorch distribution-like prior from a user-provided prior.
NOTE: If the prior argument is a sequence of PyTorch distributions, they will be interpreted as independent prior dimensions wrapped in a MultipleIndependent PyTorch distribution. If the elements are not PyTorch distributions, make sure to call process_prior() on each element of the list beforehand.
NOTE: Returns a tuple (processed_prior, num_params, whether_prior_returns_numpy). The last two entries in the tuple can be passed on to process_simulator to prepare the simulator. process_simulator uses them, for example, to cast parameters to NumPy or to add a batch dimension to the simulator output, if needed.
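The independence assumption in the first note can be illustrated with a minimal, stdlib-only sketch: a joint sample concatenates one draw per component, and the joint log-density is the sum of the per-component log-densities. The class names UniformSketch and IndependentPriorSketch are hypothetical and not part of sbi. The sketch also assumes every component is one-dimensional and ignores the tensor batching that the real MultipleIndependent class handles.

```python
import math
import random
from typing import List, Sequence


class UniformSketch:
    """Hypothetical 1-D uniform prior exposing .sample() and .log_prob()."""

    def __init__(self, low: float, high: float) -> None:
        self.low = low
        self.high = high

    def sample(self) -> List[float]:
        # One draw, returned as a length-1 vector.
        return [random.uniform(self.low, self.high)]

    def log_prob(self, value: Sequence[float]) -> float:
        x = value[0]
        if self.low <= x <= self.high:
            return -math.log(self.high - self.low)
        return -math.inf


class IndependentPriorSketch:
    """Hypothetical joint prior built from independent 1-D components."""

    def __init__(self, components: Sequence[UniformSketch]) -> None:
        self.components = list(components)

    def sample(self) -> List[float]:
        # Concatenate one sample from each component.
        theta: List[float] = []
        for component in self.components:
            theta.extend(component.sample())
        return theta

    def log_prob(self, theta: Sequence[float]) -> float:
        # Independence: the joint log-density is the sum of the
        # per-dimension log-densities (one entry of theta per component).
        return sum(
            component.log_prob([x])
            for component, x in zip(self.components, theta)
        )


prior = IndependentPriorSketch([UniformSketch(0.0, 1.0), UniformSketch(-2.0, 2.0)])
theta = prior.sample()
print(len(theta))  # 2
print(prior.log_prob([0.5, 0.0]))  # equals -log(4)
```

Because the log-densities are summed, a value outside the support of any single dimension gives -inf for the whole parameter vector.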
- Parameters:
prior (Distribution or Sequence[Distribution]) – Prior object with .sample() and .log_prob(), or a sequence of such objects.
custom_prior_wrapper_kwargs (dict, optional) – Additional arguments passed to the wrapper class that processes the prior into a PyTorch Distribution, such as bounds (lower_bound, upper_bound) or argument constraints (arg_constraints).
- Raises:
AttributeError – If prior objects lack .sample() or .log_prob().
- Returns:
prior: A PyTorch-compatible prior.
theta_numel: Dimensionality of a single sample from the prior.
prior_returns_numpy: Whether the prior originally returned NumPy arrays.
- Return type:
Tuple[Distribution, int, bool]
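The forwarding of custom_prior_wrapper_kwargs described in the parameter list can be sketched as follows, assuming the wrapper simply receives the dictionary's entries as keyword arguments. UserUniform, CustomPriorWrapperSketch, and process_prior_sketch are hypothetical stand-ins, not sbi's classes, and in this sketch the bounds only describe the support of the wrapped prior.

```python
import math
import random
from typing import Any, Dict, List, Optional, Sequence


class UserUniform:
    """Hypothetical user prior on [0, 1] that is not a PyTorch distribution."""

    def sample(self) -> List[float]:
        return [random.random()]

    def log_prob(self, theta: Sequence[float]) -> float:
        return 0.0 if 0.0 <= theta[0] <= 1.0 else -math.inf


class CustomPriorWrapperSketch:
    """Hypothetical wrapper that stores optional support bounds."""

    def __init__(
        self,
        prior: Any,
        lower_bound: Optional[float] = None,
        upper_bound: Optional[float] = None,
    ) -> None:
        self.prior = prior
        # Missing bounds default to an unbounded support.
        self.lower_bound = -math.inf if lower_bound is None else lower_bound
        self.upper_bound = math.inf if upper_bound is None else upper_bound

    def sample(self) -> List[float]:
        return self.prior.sample()

    def log_prob(self, theta: Sequence[float]) -> float:
        return self.prior.log_prob(theta)

    def support_contains(self, theta: Sequence[float]) -> bool:
        return all(self.lower_bound <= x <= self.upper_bound for x in theta)


def process_prior_sketch(
    prior: Any, custom_prior_wrapper_kwargs: Optional[Dict[str, Any]] = None
) -> CustomPriorWrapperSketch:
    # None means "no extra arguments"; otherwise forward them as keywords.
    kwargs = custom_prior_wrapper_kwargs or {}
    return CustomPriorWrapperSketch(prior, **kwargs)


wrapped = process_prior_sketch(
    UserUniform(), custom_prior_wrapper_kwargs={"lower_bound": 0.0, "upper_bound": 1.0}
)
print(wrapped.support_contains([0.5]))  # True
print(wrapped.support_contains([1.5]))  # False
```

Using None as the default instead of an empty dict avoids Python's shared mutable default argument pitfall, which matches the custom_prior_wrapper_kwargs=None signature above.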
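The AttributeError listed under Raises amounts to a duck-typing check on the prior. The sketch below uses a hypothetical helper, require_prior_methods, with its own error message; it is not sbi's internal check, only an illustration of the condition described above.

```python
from typing import Any, List


def require_prior_methods(prior: Any) -> None:
    """Hypothetical check: raise AttributeError if .sample() or .log_prob() is missing."""
    for name in ("sample", "log_prob"):
        if not callable(getattr(prior, name, None)):
            raise AttributeError(f"Prior object lacks a callable .{name}() method.")


class SampleOnly:
    def sample(self) -> List[float]:
        return [0.5]


class SampleAndLogProb(SampleOnly):
    def log_prob(self, theta: List[float]) -> float:
        return 0.0


require_prior_methods(SampleAndLogProb())  # passes silently

message = ""
try:
    require_prior_methods(SampleOnly())
except AttributeError as err:
    message = str(err)
print(message)  # Prior object lacks a callable .log_prob() method.
```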
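The three return values can be mimicked for a prior whose samples are plain Python lists. In this hypothetical sketch (summarize_prior_sketch, count_elements, and ListPrior are not sbi names), theta_numel is the number of scalar entries in one sample, and the NumPy flag is inferred from the module of the sample's type so that NumPy does not need to be imported.

```python
from typing import Any, List, Tuple


def count_elements(sample: Any) -> int:
    """Count scalar entries in a (possibly nested) list or tuple."""
    if isinstance(sample, (list, tuple)):
        return sum(count_elements(item) for item in sample)
    return 1


def summarize_prior_sketch(prior: Any) -> Tuple[Any, int, bool]:
    sample = prior.sample()
    # Dimensionality of a single sample.
    theta_numel = count_elements(sample)
    # NumPy arrays and scalars report "numpy" as their type's module.
    returns_numpy = type(sample).__module__ == "numpy"
    return prior, theta_numel, returns_numpy


class ListPrior:
    """Hypothetical prior returning a three-dimensional list sample."""

    def sample(self) -> List[float]:
        return [0.1, 0.2, 0.3]

    def log_prob(self, theta: List[float]) -> float:
        return 0.0


prior, theta_numel, prior_returns_numpy = summarize_prior_sketch(ListPrior())
print(theta_numel, prior_returns_numpy)  # 3 False
```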
Example:
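import torch
from torch.distributions import Uniform
from sbi.utils.user_input_checks import process_prior

# A one-dimensional uniform prior on [0, 1].
prior = Uniform(torch.zeros(1), torch.ones(1))
# Unpack the processed prior, its dimensionality, and the NumPy flag.
prior, theta_numel, prior_returns_numpy = process_prior(prior)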