RestrictionEstimator

class RestrictionEstimator(prior, model='resnet', decision_criterion='nan', hidden_features=100, num_blocks=2, dropout_probability=0.5, z_score='independent', embedding_net=Identity())

Bases: object

Classifier to estimate regions of the prior that give good simulation results.

Parameters:
build_resnet(theta)
Return type:

Module

build_mlp(theta)
Return type:

Module

append_simulations(theta, x)

Store parameters and simulation outputs to use them for training later. Data are stored as entries in lists, one list for each type of variable (parameter/data).

Parameters:
  • theta (Tensor) – Parameter sets.

  • x (Tensor) – Simulation outputs.

Returns:

RestrictionEstimator object (returned so that this function is chainable).

Return type:

RestrictionEstimator
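The storage and chaining behaviour described above can be illustrated with a minimal sketch. The class `SimulationStore` and its attribute names are hypothetical stand-ins, not the library's implementation; the sketch only assumes that each call appends one entry per variable type and returns the object itself.

```python
class SimulationStore:
    """Hypothetical stand-in that keeps one list per variable type."""

    def __init__(self):
        self.theta_rounds = []  # one entry per call to append_simulations
        self.x_rounds = []

    def append_simulations(self, theta, x):
        if len(theta) != len(x):
            raise ValueError("theta and x must contain the same number of rows")
        self.theta_rounds.append(list(theta))
        self.x_rounds.append(list(x))
        return self  # returning the object makes the call chainable


store = SimulationStore()
# Two chained calls, each adding one batch of (theta, x) pairs.
store.append_simulations(
    [[0.1], [0.2]], [[1.0], [float("nan")]]
).append_simulations([[0.3]], [[2.0]])
num_rounds = len(store.theta_rounds)
```

Because each call returns the object, several batches can be appended in a single expression.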

get_simulations(starting_round=0)

Return all \((\theta, x, label)\) pairs that have been passed to this object.

The label is inferred from the valid_or_invalid_criterion.

Parameters:

starting_round (int)

Return type:

Tuple[Tensor, Tensor, Tensor]

train(training_batch_size=200, learning_rate=0.0005, validation_fraction=0.1, stop_after_epochs=20, max_num_epochs=2147483647, clip_max_norm=5.0, loss_importance_weights=False, subsample_invalid_sims=1.0)

Train the classifier to distinguish parameters with valid or invalid outputs.

Parameters:
  • training_batch_size (int) – Training batch size.

  • learning_rate (float) – Learning rate for Adam optimizer.

  • validation_fraction (float) – The fraction of data to use for validation.

  • stop_after_epochs (int) – The number of epochs to wait for improvement on the validation set before terminating training.

  • max_num_epochs (int) – Maximum number of epochs to run. If reached, we stop training even when the validation loss is still decreasing. If None, we train until validation loss increases (see also stop_after_epochs).

  • clip_max_norm (float | None) – Value at which to clip the total gradient norm in order to prevent exploding gradients. Use None for no clipping.

  • loss_importance_weights (bool | float) – If bool: whether to reweigh the loss such that the prior between valid and invalid simulations is uniform. This is one way to deal with imbalanced data (e.g. 99% invalid simulations). To reweigh with a custom weight, pass a float: the value is the reweighing factor for invalid simulations, and (1 - value) is the factor for valid simulations.

  • subsample_invalid_sims (float | str) – Sampling weight of invalid simulations. This can be useful when the fraction of invalid simulations is extremely high and one wants to train on a larger fraction of valid simulations. A float has to be in [0, 1]. If set to 'auto', the subsample weights are inferred automatically such that the data is balanced.

Return type:

Module
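To make the loss_importance_weights description concrete, here is a hedged sketch of how per-class loss weights could be derived. The function `class_weights` is hypothetical, and the exact formula used by the library may differ. The sketch assumes that `True` weights each class by the frequency of the other class (so both classes contribute equally in total), that a float `w` gives `w` to invalid and `1 - w` to valid simulations, and that `False` means equal weights.

```python
def class_weights(num_valid, num_invalid, loss_importance_weights):
    """Return (weight_invalid, weight_valid) under the assumptions above."""
    if loss_importance_weights is True:
        total = num_valid + num_invalid
        # Weighting each class by the frequency of the other class makes the
        # summed weight of both classes equal.
        return num_valid / total, num_invalid / total
    if loss_importance_weights is False:
        return 0.5, 0.5  # assumption: no reweighing means equal weights
    w = float(loss_importance_weights)
    if not 0.0 <= w <= 1.0:
        raise ValueError("a custom weight must lie in [0, 1]")
    return w, 1.0 - w


# Imbalanced data set: 99% of the simulations are invalid.
w_invalid, w_valid = class_weights(10, 990, True)
total_invalid = 990 * w_invalid  # summed weight of the invalid class
total_valid = 10 * w_valid       # summed weight of the valid class
custom = class_weights(10, 990, 0.3)
```

With balanced total weights, the rare valid simulations are not drowned out by the many invalid ones during training.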

restrict_prior(classifier=None, allowed_false_negatives=0.0, reweigh_factor=None)

Return the restricted prior.

The restricted prior (Deistler et al. 2020, in preparation) is the part of the prior that can produce valid simulations. More formally, the restricted prior \(p_r(\theta)\) is:

\(p_r(\theta) = c \cdot p(\theta)\) if \(\theta \in \mathrm{support}(p(\theta \mid x = \text{valid}))\), and \(p_r(\theta) = 0\) otherwise.

We sample from the restricted prior by sampling from the prior and then rejecting the sample if the classifier predicts that the simulation output cannot be valid.
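The rejection step can be sketched in a few lines. This is a toy illustration, not the library's code: `sample_restricted_prior`, `sample_prior`, and `predicts_valid` are hypothetical names, the prior is a uniform distribution on [-2, 2], and a simple threshold rule stands in for the trained classifier.

```python
import random


def sample_restricted_prior(sample_prior, predicts_valid, num_samples, rng):
    """Draw from the prior and keep only draws the classifier accepts."""
    accepted = []
    while len(accepted) < num_samples:
        theta = sample_prior(rng)
        if predicts_valid(theta):  # reject draws predicted to be invalid
            accepted.append(theta)
    return accepted


def sample_prior(rng):
    # Toy prior: uniform on [-2, 2].
    return rng.uniform(-2.0, 2.0)


def predicts_valid(theta):
    # Toy stand-in for the classifier: only positive parameters are "valid".
    return theta > 0.0


rng = random.Random(0)
samples = sample_restricted_prior(sample_prior, predicts_valid, 100, rng)
```

The accepted draws follow the prior density inside the accepted region and have zero density outside it, which matches the definition of \(p_r(\theta)\) above up to the constant \(c\).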

Parameters:
  • classifier (Module | None) – Classifier that is used to predict whether parameter sets are valid or invalid.

  • allowed_false_negatives (float) – Fraction of false-negative predictions the classifier is allowed to make. The threshold of the classifier will be tuned such that this criterion is fulfilled. A high value makes the classifier reject more parameter sets, so that a larger share of the accepted parameter sets is valid. However, a high value also means that some potentially valid parameter sets will be missed. Inference is only exact for allowed_false_negatives=0.0. The value specified here corresponds approximately to the fraction of parameter sets that will be systematically missed by the inference procedure.

  • reweigh_factor (float | None) – Post-hoc correction factor. Should be in [0, 1]. A large reweigh factor will increase the probability of predicting an invalid simulation.

Returns:

Restricted prior with .sample() and .predict() methods.

Return type:

RestrictedPrior
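One way to picture how a threshold could be tuned to respect allowed_false_negatives is the sketch below. It is an assumption about the general idea, not the library's procedure: `tune_threshold` is a hypothetical helper that takes the predicted probability of "valid" for parameter sets known to be valid, and returns a threshold such that at most the allowed fraction of them falls below it. A parameter set is accepted when its predicted probability is at least the threshold.

```python
import math


def tune_threshold(valid_probs, allowed_false_negatives):
    """Pick a threshold that rejects at most the allowed fraction of valid sets."""
    if not 0.0 <= allowed_false_negatives <= 1.0:
        raise ValueError("allowed_false_negatives must lie in [0, 1]")
    ordered = sorted(valid_probs)
    # Number of valid parameter sets that may be rejected.
    num_rejectable = math.floor(allowed_false_negatives * len(ordered))
    index = min(num_rejectable, len(ordered) - 1)
    # Only entries strictly below ordered[index] are rejected, which is at
    # most num_rejectable of them.
    return ordered[index]


# Predicted probability of "valid" for ten parameter sets known to be valid.
probs = [0.05, 0.2, 0.4, 0.6, 0.7, 0.8, 0.9, 0.95, 0.97, 0.99]
threshold_exact = tune_threshold(probs, 0.0)
threshold_lenient = tune_threshold(probs, 0.2)
missed = sum(p < threshold_lenient for p in probs) / len(probs)
```

With allowed_false_negatives=0.0 the threshold drops to the lowest probability seen for a valid parameter set, so none of them is rejected; a larger value raises the threshold and trades missed valid parameter sets for fewer invalid simulations.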