MarginalTrainer

class MarginalTrainer(density_estimator=ZukoFlowType.NSF, device='cpu', summary_writer=None, tracker=None, show_progress_bars=True)

Bases: object

Utility class for training a marginal density estimator \(p(x)\).

The marginal density estimator learns the distribution of simulation outputs \(x\) without conditioning on parameters. In the sbi toolbox, it is primarily used for misspecification diagnostics by comparing the estimated marginal \(p(x)\) to the observed data distribution.
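
The diagnostic idea above can be sketched without any sbi dependency. The snippet below is only an illustration under stated assumptions: `GaussianMarginal` is a hypothetical stand-in for a trained marginal estimator (a 1-D Gaussian fitted by moments rather than a normalizing flow), and `tail_fraction` is a hypothetical helper, not part of the sbi API. It scores an observation by the fraction of simulated outputs whose estimated log-density is at or below that of the observation; a value near zero suggests the observation is atypical under the simulator.

```python
import math
import random


class GaussianMarginal:
    """Hypothetical stand-in for a trained marginal estimator (1-D Gaussian)."""

    def __init__(self, samples):
        n = len(samples)
        self.mean = sum(samples) / n
        self.var = sum((s - self.mean) ** 2 for s in samples) / n

    def log_prob(self, x):
        # Log-density of a univariate Gaussian with the fitted mean and variance.
        return -0.5 * (
            math.log(2 * math.pi * self.var) + (x - self.mean) ** 2 / self.var
        )


def tail_fraction(estimator, simulated, x_obs):
    """Fraction of simulated outputs with log-density at or below that of x_obs."""
    lp_obs = estimator.log_prob(x_obs)
    return sum(estimator.log_prob(s) <= lp_obs for s in simulated) / len(simulated)


rng = random.Random(0)
simulated = [rng.gauss(0.0, 1.0) for _ in range(2000)]  # stand-in simulator outputs
estimator = GaussianMarginal(simulated)

typical = tail_fraction(estimator, simulated, x_obs=0.1)   # close to the bulk
atypical = tail_fraction(estimator, simulated, x_obs=6.0)  # far in the tail
print(f"typical: {typical:.3f}, atypical: {atypical:.3f}")
```

With a real trained estimator, the same comparison would use its `log_prob` on held-out simulations and on the observation; the cutoff for flagging misspecification is a modelling choice that this page does not specify.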

Example:

import torch
from sbi.inference import MarginalTrainer

# 1. Simulate data
theta = torch.randn(100, 3)
x = theta + torch.randn_like(theta) * 0.1

# 2. Train marginal density estimator
marginal_trainer = MarginalTrainer(density_estimator="nsf")
marginal_estimator = marginal_trainer.append_samples(x).train()

# 3. Evaluate log probability of new observations
x_new = torch.randn(10, 3)
log_prob = marginal_estimator.log_prob(x_new)

get_dataloaders(training_batch_size=200, validation_fraction=0.1, dataloader_kwargs=None)

Return training and validation dataloaders.

Parameters:
  • training_batch_size (int)

  • validation_fraction (float)

  • dataloader_kwargs (dict | None)

Return type:

Tuple[DataLoader, DataLoader]
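
As a rough illustration of how `validation_fraction` and `training_batch_size` interact, the following sketch splits sample indices into a training and a validation set and then batches the training indices. The helpers `split_indices` and `batches` are hypothetical and written in plain Python; how sbi rounds the split, shuffles, or handles a final partial batch is not stated on this page, so those details are assumptions.

```python
import random


def split_indices(num_samples, validation_fraction=0.1, seed=0):
    """Shuffle indices and hold out a fraction of them for validation."""
    num_val = int(validation_fraction * num_samples)  # assumption: round down
    num_train = num_samples - num_val
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    return indices[:num_train], indices[num_train:]


def batches(indices, batch_size):
    """Yield consecutive batches; the last one may be smaller."""
    for start in range(0, len(indices), batch_size):
        yield indices[start:start + batch_size]


train_idx, val_idx = split_indices(1000, validation_fraction=0.1)
train_batches = list(batches(train_idx, batch_size=200))
print(len(train_idx), len(val_idx), [len(b) for b in train_batches])
```

With 1000 samples and the default settings shown in the signature, this sketch yields 900 training and 100 validation indices, so the training set spans four full batches of 200 plus one partial batch.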

append_samples(x)
Return type:

MarginalTrainer

get_samples()
Return type:

Tensor

loss(x)

Return the loss.

The loss is the negative log-probability of \(x\) under the density estimator.

Returns:

Negative log-probability.

Parameters:

x (Tensor)

Return type:

Tensor
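
To make the negative log-probability loss concrete, here is a minimal sketch that uses a fixed univariate Gaussian in place of the flow. The functions `gaussian_log_prob` and `nll_loss` are hypothetical names chosen for illustration, and returning one value per sample is an assumption, since this page only states that a Tensor is returned.

```python
import math


def gaussian_log_prob(x, mean, std):
    """Log-density of a univariate Gaussian evaluated at x."""
    return -0.5 * math.log(2 * math.pi * std ** 2) - (x - mean) ** 2 / (2 * std ** 2)


def nll_loss(batch, mean, std):
    """Per-sample negative log-probabilities under the Gaussian."""
    return [-gaussian_log_prob(x, mean, std) for x in batch]


batch = [-0.5, 0.0, 0.7]
good = nll_loss(batch, mean=0.0, std=1.0)  # density centred on the data
bad = nll_loss(batch, mean=5.0, std=1.0)   # density far from the data

mean_good = sum(good) / len(good)
mean_bad = sum(bad) / len(bad)
print(f"well-matched: {mean_good:.3f}, mismatched: {mean_bad:.3f}")
```

A density that places more mass on the data gives a lower average loss, which is why minimizing this quantity fits the estimator to \(p(x)\).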

train(training_batch_size=200, learning_rate=0.0005, validation_fraction=0.1, stop_after_epochs=20, max_num_epochs=2147483647, clip_max_norm=5.0, dataloader_kwargs=None)

Return density estimator that approximates the distribution \(p(x)\).

Parameters:
  • training_batch_size (int) – Training batch size.

  • learning_rate (float) – Learning rate for the Adam optimizer.

  • validation_fraction (float) – The fraction of data to use for validation.

  • stop_after_epochs (int) – The number of epochs to wait for improvement on the validation set before terminating training.

  • max_num_epochs (int) – Maximum number of epochs to run. If reached, training stops even if the validation loss is still decreasing. Otherwise, training continues until the validation loss has stopped improving (see also stop_after_epochs).

  • clip_max_norm (float | None) – Value at which to clip the total gradient norm in order to prevent exploding gradients. Use None for no clipping.

  • show_train_summary – Whether to print the number of epochs and validation loss after the training.

  • dataloader_kwargs (dict | None) – Additional or updated kwargs to be passed to the training and validation dataloaders (e.g., a collate_fn).

Returns:

Density estimator that approximates the distribution \(p(x)\).

Return type:

UnconditionalDensityEstimator
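
The interplay of `stop_after_epochs` and `max_num_epochs` described in the parameter list can be sketched as a patience-based stopping rule. The function below is a hypothetical, simplified illustration that replays a given sequence of validation losses; it is not sbi's training loop, and the exact comparison sbi uses to decide convergence is an assumption.

```python
def train_with_early_stopping(val_losses, stop_after_epochs=20, max_num_epochs=2**31 - 1):
    """Replay validation losses and report when training would stop.

    Returns the number of epochs run and the best validation loss seen.
    """
    best_loss = float("inf")
    epochs_since_improvement = 0
    epoch = 0
    for val_loss in val_losses:
        if epoch >= max_num_epochs:
            break  # hard cap, even if the loss is still decreasing
        epoch += 1
        if val_loss < best_loss:
            best_loss = val_loss
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
        if epochs_since_improvement >= stop_after_epochs:
            break  # no improvement for `stop_after_epochs` epochs
    return epoch, best_loss


losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]
stopped_by_patience = train_with_early_stopping(losses, stop_after_epochs=3)
stopped_by_cap = train_with_early_stopping(losses, stop_after_epochs=3, max_num_epochs=2)
print(stopped_by_patience, stopped_by_cap)
```

In the first call the loss last improves at epoch 3, so a patience of 3 ends training after epoch 6; in the second call the epoch cap ends training after epoch 2 although the loss is still falling.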

Parameters: