sbi.utils.process_simulator



process_simulator(user_simulator, prior, is_numpy_simulator)

Return a simulator that meets the requirements for use in sbi: it returns torch.Tensor outputs and can handle batches of parameters.

Parameters:
  • user_simulator (Callable) – simulator provided by the user, possibly written in NumPy.

  • prior (torch.distributions.Distribution) – prior, either as a PyTorch distribution or as returned by process_prior().

  • is_numpy_simulator (bool) – whether the simulator expects theta as NumPy arrays rather than torch tensors. This flag is returned by process_prior() as its third return value.

Returns:

simulator: the processed simulator, which returns a torch.Tensor and can handle batches of parameters.

Return type:

Callable
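To make "returns torch.Tensor and can handle batches of parameters" concrete, here is a minimal sketch of such a wrapper, assuming PyTorch and NumPy are installed. `wrap_simulator` is a hypothetical helper written only for illustration; it is not sbi's implementation, and sbi's own process_simulator may differ in its details.

```python
import numpy as np
import torch


def wrap_simulator(user_simulator, is_numpy_simulator):
    """Hypothetical illustration only, NOT sbi's process_simulator.

    Wraps a simulator written for one parameter vector so that it accepts
    a batch of parameters as a torch.Tensor and returns a torch.Tensor.
    """

    def simulator(theta: torch.Tensor) -> torch.Tensor:
        # Add a batch dimension if needed: (num_params,) -> (1, num_params).
        theta = torch.atleast_2d(theta)
        outputs = []
        for theta_i in theta:
            # Hand the user simulator NumPy input if it expects NumPy types.
            arg = theta_i.numpy() if is_numpy_simulator else theta_i
            x_i = user_simulator(arg)
            # Convert each result back to a float32 torch tensor.
            outputs.append(torch.as_tensor(np.asarray(x_i), dtype=torch.float32))
        # Stack the per-sample outputs into one (batch_size, ...) tensor.
        return torch.stack(outputs)

    return simulator


# Usage: a NumPy simulator that squares its parameters.
batched = wrap_simulator(lambda t: np.square(t), is_numpy_simulator=True)
x = batched(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
print(x.shape)  # torch.Size([2, 2])
```

Looping over the rows of the batch is the simplest way to give batch support to a simulator written for a single parameter vector; a simulator that is already vectorized would not need the loop.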

Example:

import torch
from torch.distributions import Uniform
from sbi.utils.user_input_checks import process_prior, process_simulator

# process_prior also returns the number of parameters and whether the
# prior returns NumPy arrays.
prior = Uniform(torch.zeros(1), torch.ones(1))
prior, theta_numel, prior_returns_numpy = process_prior(prior)

# A toy simulator that adds 1 to each parameter.
def simulator(theta):
    return theta + 1

simulator = process_simulator(simulator, prior, prior_returns_numpy)