RestrictionEstimator
- class RestrictionEstimator(prior, model='resnet', decision_criterion='nan', hidden_features=100, num_blocks=2, dropout_probability=0.5, z_score='independent', embedding_net=Identity())
Bases: object
Classifier to estimate regions of the prior that give good simulation results.
- Parameters:
- append_simulations(theta, x)
Store parameters and simulation outputs to use them for training later. Data are stored as entries in lists for each type of variable (parameter/data).
- Parameters:
- Returns:
RestrictionEstimator object (returned so that this function is chainable).
- Return type:
- get_simulations(starting_round=0)
Return all \((\theta, x, label)\) pairs that have been passed to this object.
The label is inferred from the valid_or_invalid_criterion.
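As a rough illustration of how such labels could arise, the sketch below assumes that the 'nan' decision criterion marks a simulation as invalid whenever its output contains a NaN. The helper label_simulations is hypothetical and not part of the library; plain Python lists stand in for tensors.

```python
import math


def label_simulations(thetas, xs):
    """Return (theta, x, label) triples; label is True for valid outputs.

    Assumption: an output counts as invalid if any of its entries is NaN.
    """
    triples = []
    for theta, x in zip(thetas, xs):
        is_valid = not any(math.isnan(value) for value in x)
        triples.append((theta, x, is_valid))
    return triples


# Toy data: the second simulation produced a NaN and is labelled invalid.
thetas = [[0.1, 0.2], [0.5, 0.9], [1.3, 0.4]]
xs = [[1.0, 2.0], [float("nan"), 0.3], [0.7, 0.8]]
labelled = label_simulations(thetas, xs)
labels = [label for _, _, label in labelled]
```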
- train(training_batch_size=200, learning_rate=0.0005, validation_fraction=0.1, stop_after_epochs=20, max_num_epochs=2147483647, clip_max_norm=5.0, loss_importance_weights=False, subsample_invalid_sims=1.0)
Train the classifier to distinguish parameters with valid/invalid outputs.
- Parameters:
training_batch_size (int) – Training batch size.
learning_rate (float) – Learning rate for Adam optimizer.
validation_fraction (float) – The fraction of data to use for validation.
stop_after_epochs (int) – The number of epochs to wait for improvement on the validation set before terminating training.
max_num_epochs (int) – Maximum number of epochs to run. If reached, we stop training even when the validation loss is still decreasing. If None, we train until validation loss increases (see also stop_after_epochs).
clip_max_norm (float | None) – Value at which to clip the total gradient norm in order to prevent exploding gradients. Use None for no clipping.
loss_importance_weights (bool | float) – If a bool: whether to reweigh the loss such that the prior between valid and invalid simulations is uniform. This is one way to deal with imbalanced data (e.g. 99% invalid simulations). To reweigh with a custom weight, pass a float: the value is the reweighing factor for invalid simulations, and (1-reweigh_factor) is the factor for valid simulations.
subsample_invalid_sims (float | str) – Sampling weight of invalid simulations. This can be useful when the fraction of invalid simulations is extremely high and one wants to train on a larger fraction of valid simulations. This factor has to be in [0, 1]. If it is 'auto', the subsample weights are inferred automatically such that the data is balanced.
- Return type:
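To make the loss_importance_weights description concrete, here is a minimal sketch of one way the two class weights could be derived. It is not the library's implementation: the function loss_weights is hypothetical, and the rule used for True (weight each class by the frequency of the other class) is an assumption chosen so that both classes contribute equally to the total loss.

```python
def loss_weights(labels, loss_importance_weights):
    """Return (invalid_weight, valid_weight) for a weighted classification loss.

    labels: list of bools, True for valid simulations.
    """
    # bool must be checked before float because True/False are also numbers.
    if isinstance(loss_importance_weights, bool):
        if not loss_importance_weights:
            return 1.0, 1.0  # no reweighing
        # Assumption: weight invalid simulations by the fraction of valid ones.
        invalid_weight = sum(labels) / len(labels)
    else:
        # Custom float: used directly as the factor for invalid simulations.
        invalid_weight = float(loss_importance_weights)
    return invalid_weight, 1.0 - invalid_weight


# 99% invalid simulations, as in the example from the parameter description.
labels = [False] * 99 + [True]
invalid_w, valid_w = loss_weights(labels, True)
total_invalid = invalid_w * 99  # summed weight of all invalid simulations
total_valid = valid_w * 1       # summed weight of all valid simulations
```

With these weights the 99 invalid simulations and the single valid simulation carry the same total weight, which is the sense in which the class prior becomes uniform.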
- restrict_prior(classifier=None, allowed_false_negatives=0.0, reweigh_factor=None)
Return the restricted prior.
The restricted prior (Deistler et al. 2020, in preparation) is the part of the prior that can produce valid simulations. More formally, the restricted prior \(p_r(\theta)\) is:
\(p_r(\theta) = c \cdot p(\theta)\) if \(\theta \in \mathrm{support}(p(\theta \mid x = \text{valid}))\), and \(p_r(\theta) = 0\) otherwise.
We sample from the restricted prior by sampling from the prior and then rejecting if the classifier predicts that the simulation output cannot be valid.
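The rejection step described above can be sketched in a few lines of plain Python. Everything here is illustrative: sample_restricted_prior, prior_sample and predict_valid are hypothetical names, the prior is a toy uniform distribution, and the "classifier" is a fixed rule standing in for a trained network.

```python
import random


def sample_restricted_prior(prior_sample, predict_valid, num_samples,
                            max_tries=100_000):
    """Draw from the prior and keep only draws the classifier accepts."""
    accepted = []
    tries = 0
    while len(accepted) < num_samples:
        if tries >= max_tries:
            raise RuntimeError("acceptance rate too low; gave up")
        theta = prior_sample()
        tries += 1
        if predict_valid(theta):  # reject draws predicted to be invalid
            accepted.append(theta)
    return accepted


rng = random.Random(0)
# Toy setup: uniform prior on [-2, 2]; the stand-in classifier predicts
# a valid simulation only for positive parameters.
samples = sample_restricted_prior(
    prior_sample=lambda: rng.uniform(-2.0, 2.0),
    predict_valid=lambda theta: theta > 0.0,
    num_samples=500,
)
```

Accepted draws keep the relative density of the prior inside the accepted region, which corresponds to the constant \(c\) in the definition above.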
- Parameters:
classifier (Module | None) – Classifier that is used to predict whether parameter sets are valid or invalid.
allowed_false_negatives (float) – Fraction of false-negative predictions the classifier is allowed to make. The threshold of the classifier will be tuned such that this criterion is fulfilled. A high value leads the classifier to reject more parameter sets, so that a larger share of the accepted parameter sets is valid. However, a high value also means that some potentially valid parameter sets will be missed. Inference is only exact for allowed_false_negatives=0.0. The value specified here corresponds approximately to the fraction of parameter sets that will be systematically missed by the inference procedure.
reweigh_factor (float | None) – Post-hoc correction factor. Should be in [0, 1]. A large reweigh factor will increase the probability of predicting an invalid simulation.
- Returns:
Restricted prior with .sample() and .predict() methods.
- Return type:
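The effect of allowed_false_negatives on the classifier threshold can be illustrated with a small sketch. It assumes the threshold is chosen from the predicted validity probabilities of parameter sets known to be valid, and that a parameter set is accepted when its probability is at least the threshold. The function tune_threshold is hypothetical and may differ from what the library does internally.

```python
import math


def tune_threshold(valid_probs, allowed_false_negatives):
    """Pick a threshold that rejects at most the allowed fraction of valid sets.

    valid_probs: predicted probability of validity for parameter sets
    whose simulations are known to be valid.
    """
    if not 0.0 <= allowed_false_negatives <= 1.0:
        raise ValueError("allowed_false_negatives must be in [0, 1]")
    ordered = sorted(valid_probs)
    # Number of known-valid parameter sets we may reject.
    num_rejectable = math.floor(allowed_false_negatives * len(ordered))
    # Clamp so that at least the most confident parameter set is accepted.
    index = min(num_rejectable, len(ordered) - 1)
    return ordered[index]


probs = [0.9, 0.1, 0.6, 0.4, 0.8]
exact_threshold = tune_threshold(probs, 0.0)    # accepts every valid set
relaxed_threshold = tune_threshold(probs, 0.2)  # may reject one of five
accepted = [p for p in probs if p >= relaxed_threshold]
```

With allowed_false_negatives=0.0 the threshold falls to the lowest probability assigned to any valid parameter set, so none is rejected; raising the value moves the threshold up and trades missed valid regions for fewer invalid simulations.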