MarginalTrainer
- class MarginalTrainer(density_estimator=ZukoFlowType.NSF, device='cpu', summary_writer=None, tracker=None, show_progress_bars=True)
Bases: object
Utility class for training a marginal density estimator \(p(x)\).
The marginal density estimator learns the distribution of simulation outputs \(x\) without conditioning on parameters. In the sbi toolbox, it is primarily used for misspecification diagnostics by comparing the estimated marginal \(p(x)\) to the observed data distribution.
Example:
    import torch
    from sbi.inference import MarginalTrainer

    # 1. Simulate data
    theta = torch.randn(100, 3)
    x = theta + torch.randn_like(theta) * 0.1

    # 2. Train marginal density estimator
    marginal_trainer = MarginalTrainer(density_estimator="nsf")
    marginal_estimator = marginal_trainer.append_samples(x).train()

    # 3. Evaluate log probability of new observations
    x_new = torch.randn(10, 3)
    log_prob = marginal_estimator.log_prob(x_new)
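The misspecification check mentioned above can be sketched as follows. This is a minimal illustration, not sbi's own diagnostic: a 1-D Gaussian fitted by moment matching stands in for the trained estimator so the snippet runs without sbi, and the helper names `gaussian_log_prob` and `tail_fraction` are hypothetical. The idea is to compare the log probability of an observation with the log probabilities of the simulated outputs; an observation that is less likely than almost all simulations suggests that the simulator may not cover the observed data.

```python
import math
import random
import statistics


def gaussian_log_prob(x, mean, std):
    """Log density of a 1-D Gaussian; a stand-in for estimator.log_prob."""
    return -0.5 * math.log(2 * math.pi * std**2) - (x - mean) ** 2 / (2 * std**2)


def tail_fraction(log_prob_sim, log_prob_obs):
    """Fraction of simulated outputs that are at most as likely as the observation."""
    return sum(lp <= log_prob_obs for lp in log_prob_sim) / len(log_prob_sim)


random.seed(0)
# Simulated outputs x from a hypothetical 1-D simulator.
x_sim = [random.gauss(0.0, 1.0) for _ in range(2000)]

# "Train" the stand-in marginal estimator by moment matching.
mean, std = statistics.fmean(x_sim), statistics.stdev(x_sim)
log_prob_sim = [gaussian_log_prob(x, mean, std) for x in x_sim]

# One observation in the bulk of p(x), one far in the tail.
frac_typical = tail_fraction(log_prob_sim, gaussian_log_prob(0.1, mean, std))
frac_outlier = tail_fraction(log_prob_sim, gaussian_log_prob(6.0, mean, std))
```

A small `tail_fraction` (here for the observation at 6.0) flags a potentially misspecified model. With a trained estimator, the same comparison would use `marginal_estimator.log_prob` on held-out simulations and on the observation.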
- get_dataloaders(training_batch_size=200, validation_fraction=0.1, dataloader_kwargs=None)
Return training and validation dataloaders.
- Parameters:
- Return type:
- train(training_batch_size=200, learning_rate=0.0005, validation_fraction=0.1, stop_after_epochs=20, max_num_epochs=2147483647, clip_max_norm=5.0, dataloader_kwargs=None)
Return density estimator that approximates the distribution \(p(x)\).
- Parameters:
training_batch_size (int) – Training batch size.
learning_rate (float) – Learning rate for Adam optimizer.
validation_fraction (float) – The fraction of data to use for validation.
stop_after_epochs (int) – The number of epochs to wait for improvement on the validation set before terminating training.
max_num_epochs (int) – Maximum number of epochs to run. If reached, training stops even when the validation loss is still decreasing. Otherwise, training continues until the validation loss has stopped improving (see also stop_after_epochs).
clip_max_norm (float | None) – Value at which to clip the total gradient norm in order to prevent exploding gradients. Use None for no clipping.
show_train_summary – Whether to print the number of epochs and validation loss after the training.
dataloader_kwargs (dict | None) – Additional or updated kwargs to be passed to the training and validation dataloaders (e.g., a collate_fn).
- Returns:
Density estimator that approximates the distribution \(p(x)\).
- Return type:
UnconditionalDensityEstimator
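The interplay of stop_after_epochs and max_num_epochs described in the parameters above can be sketched as a patience-based early-stopping rule. This is an assumption about the stopping logic based only on the parameter descriptions, not a copy of sbi's implementation; the function `epochs_until_stop` and the loss values are hypothetical.

```python
def epochs_until_stop(val_losses, stop_after_epochs=20, max_num_epochs=2**31 - 1):
    """Return the number of epochs run before training would stop."""
    best_loss = float("inf")
    epochs_since_improvement = 0
    epoch = 0
    for loss in val_losses:
        # Hard cap: stop even if the validation loss is still decreasing.
        if epoch >= max_num_epochs:
            break
        epoch += 1
        if loss < best_loss:
            best_loss = loss
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
        # Patience exhausted: no improvement for stop_after_epochs epochs.
        if epochs_since_improvement >= stop_after_epochs:
            break
    return epoch


# Validation loss improves for three epochs, then stalls.
losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]
stopped_at = epochs_until_stop(losses, stop_after_epochs=3)
capped_at = epochs_until_stop(losses, stop_after_epochs=3, max_num_epochs=2)
```

In this sketch the default max_num_epochs equals 2147483647 (the value in the signature), so in practice the patience criterion decides when training ends unless a smaller cap is passed.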