# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from copy import deepcopy
from math import floor
from typing import Any, Callable, Literal, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from nflows.nn import nets
from torch import Tensor, nn, relu
from torch.distributions import Distribution
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.optim.adam import Adam
from torch.utils import data
from torch.utils.data.sampler import SubsetRandomSampler, WeightedRandomSampler
from sbi.samplers import rejection
from sbi.samplers.importance.sir import sampling_importance_resampling
from sbi.sbi_types import Shape
from sbi.utils.sbiutils import (
get_simulations_since_round,
handle_invalid_x,
standardizing_net,
z_score_parser,
)
from sbi.utils.torchutils import ensure_theta_batched
from sbi.utils.user_input_checks import validate_theta_and_x
def build_input_layer(
batch_theta: Tensor,
z_score_theta: Optional[str] = "independent",
embedding_net_theta: nn.Module = nn.Identity(),
) -> nn.Module:
r"""Builds input layer for the `RestrictionEstimator` with option to z-score.
The classifier used in the `RestrictionEstimator` will receive batches of $\theta$s.
Args:
batch_theta: Batch of $\theta$s, used to infer dimensionality and (optional)
z-scoring.
z_score_theta: Whether to z-score parameters $\theta$ before passing them into
the network, can take one of the following:
- `none`, or None: do not z-score.
- `independent`: z-score each dimension independently.
- `structured`: treat dimensions as related, therefore compute mean and std
over the entire batch, instead of per-dimension. Should be used when each
sample is, for example, a time series or an image.
embedding_net_theta: Optional embedding network for $\theta$s.
Returns:
Input layer with optional embedding net and z-scoring.
"""
z_score_theta_bool, structured_theta = z_score_parser(z_score_theta)
if z_score_theta_bool:
input_layer = nn.Sequential(
standardizing_net(batch_theta, structured_theta), embedding_net_theta
)
else:
input_layer = embedding_net_theta
return input_layer
[docs]
class RestrictionEstimator:
"""Classifier to estimate regions of the prior that give good simulation results."""
def __init__(
self,
prior: Distribution,
model: Union[str, Callable] = "resnet",
decision_criterion: Union[str, Callable] = "nan",
hidden_features: int = 100,
num_blocks: int = 2,
dropout_probability: float = 0.5,
z_score: Optional[
Literal["independent", "structured", "transform_to_unconstrained", "none"]
] = "independent",
embedding_net: nn.Module = nn.Identity(),
) -> None:
r"""
Estimator that trains a classifier to restrict the prior.
The classifier learns to distinguish `valid` simulation outputs from `invalid`
simulation outputs.
Reference:
Deistler et al. (2022): "Energy-efficient network activity from disparate
circuit parameters"
Args:
prior: Prior distribution.
model: Neural network used to distinguish valid from invalid samples. If it
is a string, use a pre-configured network of the provided type (either
mlp or resnet). Alternatively, a function that builds a custom
neural network can be provided. The function will be called with the
first batch of parameters (theta,), which can thus be used for shape
inference and potentially for z-scoring. It needs to return a PyTorch
`nn.Module` implementing the classifier.
decision_criterion: Callable that takes in the simulation output $x$ and
outputs whether $x$ is counted as `valid` simulation (output 1) or as a
`invalid` simulation (output 0). By default, the function checks
whether a simulation output $x$ contains at least one `nan` or `inf`.
hidden_features: Number of hidden units of the classifier if `model` is a
string.
num_blocks: Number of hidden layers of the classifier if `model` is a
string.
dropout_probability: Dropout probability of the classifier if `model` is
`resnet`.
z_score: Whether to z-score the parameters $\theta$ used to train the
classifier.
embedding_net: Neural network used to encode the parameters before they are
passed to the classifier.
"""
self._prior = prior
self._classifier = None
self._device = "cpu" # TODO hot fix to prevent the tests from crashing
if isinstance(model, str):
self._model = model
self._hidden_features = hidden_features
self._num_blocks = num_blocks
self._dropout_probability = dropout_probability
self._z_score = z_score
self._embedding_net = embedding_net
if model == "resnet":
self._build_nn = self.build_resnet
elif model == "mlp":
self._build_nn = self.build_mlp
else:
raise NameError(
f"The `model` must be either of [resnet|mlp]. You passed {model}."
)
else:
self._build_nn = model
self._valid_or_invalid_criterion = decision_criterion
self._theta_roundwise = []
self._x_roundwise = []
self._label_roundwise = []
self._data_round_index = []
self._validation_log_probs = []
[docs]
def build_resnet(self, theta) -> nn.Module:
classifier = nets.ResidualNet(
in_features=theta.shape[1],
out_features=2,
hidden_features=self._hidden_features,
context_features=None,
num_blocks=self._num_blocks,
activation=relu,
dropout_probability=self._dropout_probability,
use_batch_norm=True,
)
z_score_theta = self._z_score
embedding_net_theta = self._embedding_net
input_layer = build_input_layer(theta, z_score_theta, embedding_net_theta)
classifier = nn.Sequential(input_layer, classifier)
return classifier
[docs]
def build_mlp(self, theta) -> nn.Module:
classifier = nn.Sequential(
nn.Linear(theta.shape[1], self._hidden_features),
nn.LayerNorm(self._hidden_features),
nn.ReLU(),
nn.Linear(self._hidden_features, self._hidden_features),
nn.LayerNorm(self._hidden_features),
nn.ReLU(),
nn.Linear(self._hidden_features, 2),
)
z_score_theta = self._z_score
embedding_net_theta = self._embedding_net
input_layer = build_input_layer(theta, z_score_theta, embedding_net_theta)
classifier = nn.Sequential(input_layer, classifier)
return classifier
[docs]
def append_simulations(self, theta: Tensor, x: Tensor) -> "RestrictionEstimator":
r"""
Store parameters and simulation outputs to use them for training later.
Data ar stored as entries in lists for each type of variable (parameter/data).
Args:
theta: Parameter sets.
x: Simulation outputs.
Returns:
`RestrictionEstimator` object (returned so that this function is chainable).
"""
theta, x = validate_theta_and_x(theta, x, training_device=self._device)
if self._valid_or_invalid_criterion == "nan":
label, _, _ = handle_invalid_x(x)
else:
assert isinstance(self._valid_or_invalid_criterion, Callable)
label = self._valid_or_invalid_criterion(x)
label = label.long()
if self._data_round_index:
self._data_round_index.append(self._data_round_index[-1] + 1)
else:
self._data_round_index.append(0)
self._theta_roundwise.append(theta)
self._x_roundwise.append(x)
self._label_roundwise.append(label)
return self
[docs]
def get_simulations(self, starting_round: int = 0) -> Tuple[Tensor, Tensor, Tensor]:
r"""
Return all $(\theta, x, label)$ pairs that have been passed to this object.
The label had been inferred from the `valid_or_invalid_criterion`.
"""
theta = get_simulations_since_round(
self._theta_roundwise, self._data_round_index, starting_round
)
x = get_simulations_since_round(
self._x_roundwise, self._data_round_index, starting_round
)
label = get_simulations_since_round(
self._label_roundwise, self._data_round_index, starting_round
)
return theta, x, label
[docs]
def train(
self,
training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
max_num_epochs: int = 2**31 - 1,
clip_max_norm: Optional[float] = 5.0,
loss_importance_weights: Union[bool, float] = False,
subsample_invalid_sims: Union[float, str] = 1.0,
) -> torch.nn.Module:
r"""
Train the classifier to distinguish parameters with `valid`|`invalid` outputs.
Args:
training_batch_size: Training batch size.
learning_rate: Learning rate for Adam optimizer.
validation_fraction: The fraction of data to use for validation.
stop_after_epochs: The number of epochs to wait for improvement on the
validation set before terminating training.
max_num_epochs: Maximum number of epochs to run. If reached, we stop
training even when the validation loss is still decreasing. If None, we
train until validation loss increases (see also `stop_after_epochs`).
clip_max_norm: Value at which to clip the total gradient norm in order to
prevent exploding gradients. Use None for no clipping.
loss_importance_weights: If `bool`: whether or not to reweigh the loss such
that the prior between `valid` and `invalid` simulations is uniform.
This is one way to deal with imbalanced data (e.g. 99% invalid
simulations). If you want to reweigh with a custom weight, pass a
`float`. The value assigned will be the reweighing factor for invalid
simulations, (1-reweigh_factor) will be the factor for good simulations.
subsample_invalid_sims: Sampling weight of invalid simulations. This can be
useful when the fraction of invalid simulations is extremely high and
one wants to train on a larger fraction of valid simulations. This
factor has to be in [0, 1]. If it is `auto`, automatically infer
subsample weights such that the data is balanced.
"""
theta: Tensor = torch.cat(self._theta_roundwise)
label: Tensor = torch.cat(self._label_roundwise)
# Get indices for permutation of the data.
num_examples = len(theta)
permuted_indices = torch.randperm(num_examples)
num_training_examples = int((1 - validation_fraction) * num_examples)
num_validation_examples = num_examples - num_training_examples
train_indices, val_indices = (
permuted_indices[:num_training_examples],
permuted_indices[num_training_examples:],
)
# The ratio of `valid` and `invalid` simulation outputs might not be balanced.
# E.g. if there are fewer `valid` datapoints, one might want to sample them
# more often (i.e. show them to the neural network more often). Such a sampler
# is implemented below. Also see: https://discuss.pytorch.org/t/29907
subsample_weights: Tensor = torch.ones(num_examples)
if subsample_invalid_sims == "auto":
subsample_invalid = float(label.sum()) / float(theta.shape[0] - label.sum())
else:
assert isinstance(subsample_invalid_sims, float)
subsample_invalid = subsample_invalid_sims
subsample_weights[torch.logical_not(label.bool())] = subsample_invalid
subsample_weights = deepcopy(subsample_weights)
subsample_weights[val_indices] = 0.0
# Dataset is shared for training and validation loaders.
dataset = data.TensorDataset(theta, label)
# Create neural_net and validation loaders using a subset sampler.
train_loader = data.DataLoader(
dataset,
batch_size=training_batch_size,
drop_last=True,
sampler=WeightedRandomSampler(
subsample_weights.tolist(),
int(subsample_weights.sum()),
replacement=False,
),
)
val_loader = data.DataLoader(
dataset,
batch_size=min(
max(200, training_batch_size), num_examples - num_training_examples
),
shuffle=False,
drop_last=True,
sampler=SubsetRandomSampler(val_indices.tolist()),
)
if self._classifier is None:
self._classifier = self._build_nn(theta[train_indices])
# If we are in the first round, save the validation data in order to be able to
# tune the classifier threshold.
if max(self._data_round_index) == 0:
self._first_round_validation_theta = theta[val_indices]
self._first_round_validation_label = label[val_indices]
optimizer = Adam(
list(self._classifier.parameters()),
lr=learning_rate,
)
# Compute the fraction of good simulations in dataset.
if loss_importance_weights:
if isinstance(loss_importance_weights, bool):
good_sim_fraction = torch.sum(label, dtype=torch.float) / label.shape[0]
importance_weights = good_sim_fraction
else:
importance_weights = loss_importance_weights
else:
importance_weights = 0.5
# Factor of two such that the average learning rate remains the same.
# Needed because the average of reweigh_factor and 1-reweigh_factor will be 0.5
# only.
importance_weights = 2 * torch.tensor([
importance_weights,
1 - importance_weights,
])
criterion = nn.CrossEntropyLoss(importance_weights, reduction="none")
epoch, self._val_log_prob = 0, float("-Inf")
while epoch <= max_num_epochs and not self._converged(epoch, stop_after_epochs):
self._classifier.train()
for parameters, observations in train_loader:
optimizer.zero_grad()
outputs = self._classifier(parameters)
loss = criterion(outputs, observations).mean()
loss.backward()
if clip_max_norm is not None:
clip_grad_norm_(
self._classifier.parameters(),
max_norm=clip_max_norm,
)
optimizer.step()
epoch += 1
# calculate validation performance
self._classifier.eval()
val_loss = 0.0
with torch.no_grad():
for parameters, observations in val_loader:
outputs = self._classifier(parameters)
loss = criterion(outputs, observations)
loss[~observations.bool()] *= subsample_invalid_sims
val_loss += loss.sum().item()
self._val_log_prob = -val_loss / num_validation_examples
self._validation_log_probs.append(self._val_log_prob)
print("Training neural network. Epochs trained: ", epoch, end="\r")
return deepcopy(self._classifier)
[docs]
def restrict_prior(
self,
classifier: Optional[nn.Module] = None,
allowed_false_negatives: float = 0.0,
reweigh_factor: Optional[float] = None,
) -> "RestrictedPrior":
r"""
Return the restricted prior.
The restricted prior (Deistler et al. 2020, in preparation) is the part of the
prior that can produce `valid` simulations. More formally, the restricted prior
$p_r(\theta)$ is:
$p_r(\theta) = c \cdot p(\theta) if \theta \in support(p(\theta|x=`valid`))$
$p_r(\theta) = 0 otherwise$.
We sample from the restricted prior by sampling from the prior and then
rejecting if the classifier predicts that the simulation output can not be
`valid`.
Args:
classifier: Classifier that is used to predict whether parameter sets are
`valid` or `invalid`.
allowed_false_negatives: Fraction of false-negative predictions the
classifier is allowed to make. The threshold of the classifier will be
tuned such that this criterion is fulfilled. A high value will lead to
the classifier rejecting more parameter sets, which will give many
`valid` parameter sets. However, a high value also means that some
potentially `valid` parameter sets will be missed. Inference is only
**exact** for `allowed_false_negatives=0.0`. The value specified here
corresponds approximately to the fraction of parameter sets that will
be systematically missed by the inference procedure.
reweigh_factor: Post-hoc correction factor. Should be in [0, 1]. A large
reweigh factor will increase the probability of predicting a `invalid`
simulation.
Returns:
Restricted prior with `.sample()` and `.predict()` methods.
"""
if classifier is None:
assert self._classifier is not None, "Classifier must be trained first."
classifier_ = self._classifier
else:
classifier_ = classifier
classifier_.zero_grad(set_to_none=True)
accept_reject_fn = AcceptRejectFunction(
self._classifier,
self._first_round_validation_theta,
self._first_round_validation_label,
allowed_false_negatives=allowed_false_negatives,
reweigh_factor=reweigh_factor,
)
return RestrictedPrior(self._prior, accept_reject_fn)
def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
r"""
Return whether the training converged yet and save best model state so far.
Checks for improvement in validation performance over previous epochs.
Args:
epoch: Current epoch in training.
stop_after_epochs: How many fruitless epochs to let pass before stopping.
Returns:
Whether the training has stopped improving, i.e. has converged.
"""
converged = False
assert self._classifier is not None, "Classifier must be trained first."
posterior_nn = self._classifier
# (Re)-start the epoch count with the first epoch or any improvement.
if epoch == 0 or self._val_log_prob > self._best_val_log_prob:
self._best_val_log_prob = self._val_log_prob
self._epochs_since_last_improvement = 0
self._best_model_state_dict = deepcopy(posterior_nn.state_dict())
else:
self._epochs_since_last_improvement += 1
# If no validation improvement over many epochs, stop training.
if self._epochs_since_last_improvement > stop_after_epochs - 1:
posterior_nn.load_state_dict(self._best_model_state_dict)
converged = True
return converged
[docs]
def get_density_thresholder(
dist: Any,
quantile: float = 1e-4,
num_samples_to_estimate_support: int = 1_000_000,
) -> Callable:
"""Returns function that thresholds a density at a particular `1-quantile`.
Reference:
Deistler et al. (2022): "Truncated proposals for scalable and hassle-free
simulation-based inference"
Args:
dist: Probability distribution to be thresholded, must have `.sample()` and
`.log_prob()`.
quantile: The returned function will be `True` for $\theta$ within the
`1-quantile` high-probability region of the distribution. In other words:
`quantile` is the fraction of mass that is excluded from `dist`.
num_samples_to_estimate_support: The number of samples that are drawn from
`dist` in order to obtain the threshold. Higher values are more accurate
but slower.
Returns:
Callabe which is true for $\theta$ in the `1-quantile` high-probability region
of the `dist`.
"""
samples = dist.sample((num_samples_to_estimate_support,))
log_probs = dist.log_prob(samples)
sorted_log_probs, _ = torch.sort(log_probs)
log_prob_threshold = sorted_log_probs[
int(quantile * num_samples_to_estimate_support)
]
def density_thresholder(theta: Tensor) -> Tensor:
theta_log_probs = dist.log_prob(theta)
predictions = theta_log_probs > log_prob_threshold
return predictions.bool()
return density_thresholder
class AcceptRejectFunction:
def __init__(
self,
classifier: Any,
validation_theta: Tensor,
validation_label: Tensor,
allowed_false_negatives: Optional[float] = None,
reweigh_factor: Optional[float] = None,
print_fp_rate: bool = False,
safety_margin: Optional[Union[str, float]] = "frequentist",
) -> None:
self._classifier = classifier
self._validation_theta = validation_theta
self._validation_label = validation_label
self._allowed_false_negatives = allowed_false_negatives
self._reweigh_factor = reweigh_factor
self._print_fp_rate = print_fp_rate
self._safety_margin = safety_margin
assert (
allowed_false_negatives is None or reweigh_factor is None
), """Both the `allowed_false_negatives` and the `reweigh_factor` are set. You
can only set one of them."""
valid_val_theta = validation_theta[validation_label.bool()]
num_valid = valid_val_theta.shape[0]
clf_probs = F.softmax(classifier.forward(valid_val_theta), dim=1)[:, 1]
if allowed_false_negatives == 0.0:
if safety_margin is None:
self._classifier_thr = torch.min(clf_probs)
elif isinstance(safety_margin, float):
self._classifier_thr = torch.min(clf_probs) - safety_margin
elif safety_margin == "frequentist":
# We seek the minimum classifier output, not the maximum, as it usually
# is in the `German Tank Problem`. Hence, we transform the outputs with
# (1-output), apply the estimator, and then transform back.
tf_min_val = torch.max(1.0 - clf_probs)
tf_estimate = tf_min_val + tf_min_val / num_valid
self._classifier_thr = 1.0 - tf_estimate
else:
raise NameError(f"`safety_margin` {safety_margin} not supported.")
else:
assert allowed_false_negatives is not None, (
"`allowed_false_negatives` must be set."
)
quantile_index = floor(num_valid * allowed_false_negatives)
self._classifier_thr, _ = torch.kthvalue(clf_probs, quantile_index + 1)
if self._print_fp_rate:
self.print_false_positive_rate(
self.__call__, self._validation_theta, self._validation_label
)
def __call__(self, theta):
pred = F.softmax(self._classifier.forward(theta), dim=1)[:, 1]
if self._reweigh_factor is None:
threshold = self._classifier_thr
predictions = pred > threshold
else:
probs_invalid = pred * self._reweigh_factor
probs_valid = (1 - pred) * (1 - self._reweigh_factor)
predictions = probs_valid > probs_invalid
return predictions.bool()
def print_false_positive_rate(
self,
accept_reject_fn: Callable,
validation_theta: Tensor,
validation_label: Tensor,
) -> float:
r"""
Print and return the rate of false positive predictions on the validation set.
Returns:
The false positive rate.
"""
invalid_val_theta = validation_theta[~validation_label.bool()]
predictions = accept_reject_fn(invalid_val_theta)
num_false_positives = int(predictions.sum())
fraction_false_positives = num_false_positives / invalid_val_theta.shape[0]
print(
f"Fraction of false positives: "
f"{num_false_positives} / {invalid_val_theta.shape[0]} = "
f"{fraction_false_positives:.3f}"
)
return fraction_false_positives
[docs]
class RestrictedPrior(Distribution):
"""Distribution that restricts the prior distribution to a smaller region."""
def __init__(
self,
prior: Distribution,
accept_reject_fn: Callable,
posterior: Optional[Any] = None,
sample_with: str = "rejection",
device: str = "cpu",
) -> None:
r"""Initialize the simulation-informed prior.
References:
- Deistler et al. (2022): *Energy-efficient network activity from disparate
circuit parameters*
- Deistler et al. (2022): *Truncated proposals for scalable and hassle-free
simulation-based inference*
Args:
prior: Prior distribution, will be used as proposal distribution whose
samples will be evaluated by the classifier.
accept_reject_fn: Callable that returns `True` inside the restricted region
and `False` outside of it.
posterior: Posterior distribution. Only used as proposal for `sir`.
sample_with: Either of [`rejection`|`sir`]. Sets that method that is used
to sample from the restricted prior. If `sir`, youu must have passed
a `posterior` at initialization.
device: Device used for sampling and evaluating.
"""
super().__init__(validate_args=False)
self._prior = prior
self._accept_reject_fn = accept_reject_fn
self._posterior = posterior # Only used for SIR.
self._sample_with = sample_with
self._device = device
self.acceptance_rate = None # Only defined for rejection sampling.
[docs]
def sample(
self,
sample_shape: Shape = torch.Size(),
sample_with: Optional[str] = None,
max_sampling_batch_size: int = 10_000,
oversampling_factor: int = 1024,
save_acceptance_rate: bool = False,
show_progress_bars: bool = False,
print_rejected_frac: bool = True,
) -> Tensor:
"""
Return samples from the `RestrictedPrior`.
Samples are obtained by sampling from the prior, evaluating them under the
trained classifier (`RestrictionEstimator`) and using only those that were
accepted.
Args:
sample_shape: Shape of the returned samples.
sample_with: Either of [`rejection`|`sir`]. Sets that method that is used
to sample from the restricted prior. If `sir`, youu must have passed
a `posterior` at initialization.
max_sampling_batch_size: Batch size for drawing samples from the posterior.
Takes effect only in the second iteration of the loop below, i.e., in
case of leakage or `num_samples>max_sampling_batch_size`. Larger batch
size speeds up sampling.
oversampling_factor: Number of proposed samples for `sir` from which only
one is selected based on its importance weight.
save_acceptance_rate: If `True`, the acceptance rate is saved and such that
it can potentially be used later in `log_prob()`.
show_progress_bars: Whether to show a progressbar during sampling.
print_rejected_frac: Whether to print the rejection rate of the restriction
estimator during sampling.
Returns:
Samples from the `RestrictedPrior`.
"""
num_samples = torch.Size(sample_shape).numel()
sample_with = self._sample_with if sample_with is None else sample_with
if sample_with == "rejection":
samples, acceptance_rate = rejection.accept_reject_sample(
proposal=lambda sample_shape, **kwargs: self._prior.sample(
sample_shape
),
accept_reject_fn=self._accept_reject_fn,
num_samples=num_samples,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
alternative_method="sample_with='sir'",
)
# NOTE: This currently requires a float acceptance rate. A previous version
# of accept_reject_sample returned a float. In favour to batched sampling
# it now returns a tensor.
acceptance_rate = acceptance_rate.min().item()
if save_acceptance_rate:
self.acceptance_rate = torch.as_tensor(acceptance_rate)
if print_rejected_frac:
print(
f"The `RestrictedPrior` rejected "
f"{(1.0 - acceptance_rate) * 100:.1f}% of prior samples. You will "
f"get a speed-up of {(1.0 / acceptance_rate - 1.0) * 100:.1f}%."
)
elif sample_with == "sir":
assert self._posterior is not None, (
"In order to use SIR sampling, you must provide a `posterior`: "
"`RestrictionEstimator(..., posterior=posterior)`."
)
num_samples = torch.Size(sample_shape).numel()
accept_reject_fn = lambda theta: self._accept_reject_fn(theta).type(
torch.float32
)
samples = sampling_importance_resampling(
accept_reject_fn,
proposal=self._posterior,
num_samples=num_samples,
oversampling_factor=oversampling_factor,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
device=self._device,
)
else:
raise ValueError("Only [rejection | sir] implemented as `method`")
return samples.reshape((*sample_shape, -1)).to(self._device)
[docs]
def log_prob(
self,
theta: Tensor,
norm_restricted_prior: bool = True,
track_gradients: bool = False,
prior_acceptance_params: Optional[dict] = None,
) -> Tensor:
r"""Returns the log-probability of the restricted prior.
Args:
theta: Parameters $\theta$.
norm_restricted_prior: Whether to enforce a normalized restricted prior
density. The normalizing factor is calculated via rejection sampling,
so if you need speedier but unnormalized log probability estimates set
here `norm_restricted_prior=False`. The returned log probability is set
to -∞ outside of the restriceted prior support regardless of this
setting.
track_gradients: Whether the returned tensor supports tracking gradients.
This can be helpful for e.g. sensitivity analysis, but increases memory
consumption.
prior_acceptance_params: A `dict` of keyword arguments to override the
default values of `prior_acceptance()`. Possible options are:
`num_rejection_samples`, `force_update`, `show_progress_bars`, and
`rejection_sampling_batch_size`. These parameters only have an effect
if `norm_restricted_prior=True`.
Returns:
`(len(θ),)`-shaped log probability for θ in the support of the restricted
prior, -∞ (corresponding to 0 probability) outside.
"""
theta = ensure_theta_batched(torch.as_tensor(theta))
with torch.set_grad_enabled(track_gradients):
# Evaluate on device, move back to cpu for comparison with prior.
prior_log_prob = self._prior.log_prob(theta)
accepted = self._accept_reject_fn(theta).bool()
masked_log_prob = torch.where(
accepted,
prior_log_prob,
torch.tensor(float("-inf"), dtype=torch.float32),
)
if prior_acceptance_params is None:
prior_acceptance_params = dict() # use defaults
log_factor = (
torch.log(self.prior_acceptance(**prior_acceptance_params))
if norm_restricted_prior
else 0
)
return masked_log_prob - log_factor
[docs]
@torch.no_grad()
def prior_acceptance(
self,
num_rejection_samples: int = 10_000,
force_update: bool = False,
show_progress_bars: bool = False,
rejection_sampling_batch_size: int = 10_000,
) -> Tensor:
"""
Return the fraction of prior samples accepted by the classifier.
The factor is estimated from the acceptance probability during rejection
sampling from the prior.
Args:
num_rejection_samples: Number of samples to estimate the acceptance rate.
force_update: Whether to force update the acceptance rate.
show_progress_bars: Whether to show progress bars during sampling.
rejection_sampling_batch_size: Batch size for rejection sampling.
Returns:
Tensor of the estimated acceptance rate.
"""
if self.acceptance_rate is None or force_update:
self.sample(
sample_shape=torch.Size((num_rejection_samples,)),
sample_with="rejection",
show_progress_bars=show_progress_bars,
max_sampling_batch_size=rejection_sampling_batch_size,
save_acceptance_rate=True,
)
# after calling sample, self.acceptance_rate will be a Tensor.
return self.acceptance_rate # type: ignore
@property
def mean(self) -> Tensor:
"""Mean of the restricted prior (not implemented)."""
raise NotImplementedError("Mean is not implemented for RestrictedPrior.")
@property
def variance(self) -> Tensor:
"""Variance of the restricted prior (not implemented)."""
raise NotImplementedError("Variance is not implemented for RestrictedPrior.")
@property
def support(self):
# Return base prior's support or raise an error
try:
return self._prior.support
except AttributeError as e:
raise NotImplementedError(
"Support is not implemented for this RestrictedPrior."
) from e