sbi.neural_nets.likelihood_nn

likelihood_nn(model, z_score_theta='independent', z_score_x='independent', hidden_features=50, num_transforms=5, num_bins=10, embedding_net=Identity(), num_components=10, **kwargs)

Returns a function that builds a density estimator for learning the likelihood.

This function is typically used with SNLE (sequential neural likelihood estimation). When using the flexible interface, pass the returned function to the inference class as its density_estimator argument.

Parameters:
  • model (str) – The type of density estimator that will be created. One of [mdn, made, maf, maf_rqs, nsf].

  • z_score_theta (Literal['independent', 'structured', 'transform_to_unconstrained', 'none'] | None) – Whether to z-score the parameters \(\theta\) before passing them into the network. Can take one of the following:
      - none, or None: do not z-score.
      - independent: z-score each dimension independently.
      - structured: treat dimensions as related, and therefore compute the mean and std over the entire batch instead of per dimension. Should be used when each sample is, for example, a time series or an image.

  • z_score_x (Literal['independent', 'structured', 'transform_to_unconstrained', 'none'] | None) – Whether to z-score the simulation outputs \(x\) before passing them into the network. Takes the same options as z_score_theta.

  • hidden_features (int) – Number of hidden features (units) in the hidden layers of the network.

  • num_transforms (int) – Number of transforms when a flow is used. Only relevant if the density estimator is a normalizing flow (currently maf, maf_rqs, or nsf). Ignored if the density estimator is an mdn or made.

  • num_bins (int) – Number of bins used for the splines in nsf. Ignored if the density estimator is not nsf.

  • embedding_net (Module) – Optional embedding network for the parameters \(\theta\). Defaults to Identity(), i.e. no embedding.

  • num_components (int) – Number of mixture components for a mixture of Gaussians. Ignored if the density estimator is not an mdn.

  • **kwargs (Any) – Additional estimator arguments. Valid keys are defined by ConditionalFlowConfig; unknown keys trigger a warning and are forwarded to the builder.

Return type:

Callable
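To make the difference between the independent and structured options of z_score_theta and z_score_x concrete, here is a small NumPy sketch of the statistics each option computes, following the descriptions above. It is an illustration written for this page, not sbi's internal implementation (which operates on torch tensors and may differ in details); the z_score helper and the toy batch are hypothetical.

```python
import numpy as np


def z_score(batch, mode="independent"):
    """Hypothetical helper: z-score a (num_samples, num_dims) batch."""
    if mode is None or mode == "none":
        # No standardization.
        return batch
    if mode == "independent":
        # One mean and one std per dimension (column).
        mean, std = batch.mean(axis=0), batch.std(axis=0)
    elif mode == "structured":
        # A single mean and std shared by all dimensions, computed over the whole batch.
        mean, std = batch.mean(), batch.std()
    else:
        raise ValueError(f"unsupported mode: {mode!r}")
    return (batch - mean) / std


rng = np.random.default_rng(0)
# Toy "time series": 1000 samples of length 4 with an upward trend across time steps.
batch = rng.normal(size=(1000, 4)) + np.arange(4)

ind = z_score(batch, "independent")   # every column ends up with mean 0, std 1
struct = z_score(batch, "structured")  # the trend across columns is preserved
```

The independent mode erases the offsets between dimensions, which is appropriate when dimensions are unrelated quantities. The structured mode applies one shift and scale to every dimension, so relationships between dimensions (such as the trend in a time series or relative pixel intensities in an image) survive the normalization.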