How to specify a custom prior (e.g., multiple independent priors)

sbi works with torch distributions only, so we recommend using those whenever possible. If you want a different prior in each parameter dimension, then you can use the sbi utility MultipleIndependent:

import torch
from torch.distributions import MultivariateNormal, Exponential
from sbi.utils import MultipleIndependent, BoxUniform

prior = MultipleIndependent([
    MultivariateNormal(torch.ones(2,), torch.eye(2,)),
    BoxUniform(-torch.ones(3,), torch.ones(3,)),
    Exponential(torch.ones(1,))
])

This will create a prior over 6 parameters: the first two follow a (multivariate) Normal distribution, the next three follow a uniform distribution, and the last one follows an exponential distribution.

Wrapping non-torch distributions

If you want to use a custom prior that is not among the common torch distributions, that is possible as well: you need to write a prior class that mimics the behaviour of a torch.distributions.Distribution class. sbi will wrap this class to make it a fully functional torch Distribution.

Essentially, the class needs two methods:

  • .sample(sample_shape), where sample_shape is a shape tuple, e.g., (n,), which returns a batch of n samples, e.g., of shape (n, 2) for a two-dimensional prior.

  • .log_prob(value), which returns the “log probs” of parameters under the prior, e.g., for a batch of n parameters with shape (n, ndims) it should return an array of log probs of shape (n,).

For sbi > 0.17.2 this could look like the following:

import torch
from torch import Tensor
from sbi.utils import BoxUniform


class CustomUniformPrior:
    """User defined numpy uniform prior.

    Custom prior with user-defined valid .sample and .log_prob methods.
    """

    def __init__(self, lower: Tensor, upper: Tensor, return_numpy: bool = False):
        self.lower = lower
        self.upper = upper
        self.dist = BoxUniform(lower, upper)
        self.return_numpy = return_numpy

    def sample(self, sample_shape=torch.Size([])):
        samples = self.dist.sample(sample_shape)
        return samples.numpy() if self.return_numpy else samples

    def log_prob(self, values):
        if self.return_numpy:
            values = torch.as_tensor(values)
        log_probs = self.dist.log_prob(values)
        return log_probs.numpy() if self.return_numpy else log_probs
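Before wrapping a custom prior, it can help to verify that it satisfies the shape contract of the two methods listed earlier. The following is a minimal sketch using only torch; `check_prior_interface` and `ToyGaussianPrior` are hypothetical names introduced here for illustration and are not part of sbi.

```python
import torch


class ToyGaussianPrior:
    """Hypothetical stand-in for a user-defined prior (not part of sbi)."""

    def __init__(self, ndims: int):
        self.ndims = ndims

    def sample(self, sample_shape=torch.Size([])):
        # Standard-normal draws of shape (*sample_shape, ndims).
        return torch.randn(*sample_shape, self.ndims)

    def log_prob(self, values):
        # Sum of independent standard-normal log densities over the last dim.
        values = torch.as_tensor(values)
        return torch.distributions.Normal(0.0, 1.0).log_prob(values).sum(dim=-1)


def check_prior_interface(prior, n: int = 10) -> int:
    """Check the (n, ndims) / (n,) shape contract; return the inferred ndims."""
    raw_samples = prior.sample((n,))
    samples = torch.as_tensor(raw_samples)  # accepts numpy or torch outputs
    assert samples.ndim == 2 and samples.shape[0] == n, (
        f"expected samples of shape (n, ndims), got {tuple(samples.shape)}"
    )
    log_probs = torch.as_tensor(prior.log_prob(raw_samples))
    assert log_probs.shape == (n,), (
        f"expected log probs of shape (n,), got {tuple(log_probs.shape)}"
    )
    return samples.shape[1]


ndims = check_prior_interface(ToyGaussianPrior(ndims=2))
print(ndims)
```

Because the helper converts outputs with `torch.as_tensor`, the same check should also apply to a prior that returns numpy arrays.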

Once you have such a class, you can wrap it into a Distribution using the process_prior function sbi provides:

from sbi.utils import process_prior

custom_prior = CustomUniformPrior(torch.zeros(2), torch.ones(2))
prior, *_ = process_prior(custom_prior)  # Keeping only the first return.
# use this wrapped prior in sbi...

In sbi it is sometimes required to check the support of the prior, e.g., when the prior support is bounded and one wants to reject samples from the posterior density estimator that lie outside the prior support. In torch Distributions this is handled automatically. However, when using a custom prior, it is not. Thus, if your prior has bounded support (like the one above), it makes sense to pass the bounds to the wrapper function such that sbi can pass them to torch Distributions:

from sbi.utils import process_prior

custom_prior = CustomUniformPrior(torch.zeros(2), torch.ones(2))
prior, *_ = process_prior(  # Keeping only the first return.
    custom_prior,
    custom_prior_wrapper_kwargs=dict(
        lower_bound=torch.zeros(2), upper_bound=torch.ones(2)
    ),
)
# use this wrapped prior in sbi...

Note that in custom_prior_wrapper_kwargs you can pass additional arguments for the wrapper, e.g., validate_args or arg_constraints; see the Distribution documentation for more details.
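To make the support check mentioned above concrete, here is a small torch-only sketch of how a bounded torch Distribution can report whether candidate parameters lie inside its support. It illustrates the general mechanism, not sbi's internal rejection code; the names `bounded`, `candidates`, `inside`, and `accepted` are illustrative.

```python
import torch
from torch.distributions import Independent, Uniform

# A 2-D uniform distribution on the unit square, with both dimensions
# treated as a single event.
bounded = Independent(Uniform(torch.zeros(2), torch.ones(2)), 1)

candidates = torch.tensor([
    [0.5, 0.5],   # inside the unit square
    [1.5, 0.5],   # first coordinate above the upper bound
    [0.2, -0.1],  # second coordinate below the lower bound
])

# support.check returns one boolean per candidate (all dimensions must pass).
inside = bounded.support.check(candidates)
accepted = candidates[inside]  # keep only candidates within the support

print(inside)
print(accepted)
```

A custom prior class carries no such support information on its own, which is why the bounds have to be supplied to the wrapper explicitly.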

If you are using sbi < 0.17.2 together with NLE, the code above will produce a NotImplementedError (see #581). In this case, you need to update to a newer version of sbi or use NPE instead.