How to track experiments

Experiment tracking helps you compare model variants and keep a record of hyperparameters and training metrics. By default, sbi logs to TensorBoard. You can also bring your own tracker: implement the lightweight Tracker protocol and pass an instance to the inference class via the tracker argument.

If you use your own tracker (e.g., wandb, mlflow, or trackio), you are responsible for the run lifecycle: start the run (e.g., with wandb.init or mlflow.start_run) before training and end it when you are done.
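Since the run lifecycle stays on your side, a try/finally (or context manager) around training ensures the run is closed even if training raises. The sketch below illustrates only that pattern. FakeRun and managed_run are hypothetical names, and FakeRun merely stands in for a real run object such as the one returned by wandb.init():

```python
from contextlib import contextmanager


class FakeRun:
    """Hypothetical stand-in for a tracking run such as the object wandb.init() returns."""

    def __init__(self):
        self.active = True
        self.history = []

    def log(self, data, step=None):
        if not self.active:
            raise RuntimeError("cannot log to a finished run")
        self.history.append((step, dict(data)))

    def finish(self):
        self.active = False


@contextmanager
def managed_run():
    """Start a run before training and always finish it afterwards."""
    run = FakeRun()  # with a real tool this would be e.g. wandb.init(...)
    try:
        yield run
    finally:
        # Executes even if training raises, so the run is never left open.
        run.finish()


with managed_run() as run:
    # Placeholder for building the adapter and calling inference.train(...).
    run.log({"train_loss": 1.25}, step=0)
```

With the real libraries, with wandb.init(...) as run: and with mlflow.start_run(): follow the same pattern.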

Define a minimal training setup

import torch

from sbi.inference import NPE
from sbi.neural_nets import posterior_nn
from sbi.neural_nets.embedding_nets import FCEmbedding
from sbi.utils import BoxUniform

torch.manual_seed(0)

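# Toy simulator: observations are the parameters plus small Gaussian noise.
def simulator(theta):
    return theta + 0.1 * torch.randn_like(theta)

# Uniform prior over the box [-2, 2]^2.
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))

# Simulate 5000 parameter/observation pairs for training.
theta = prior.sample((5000,))
x = simulator(theta)

# Neural spline flow (NSF) posterior estimator with a fully connected embedding net.
embedding_net = FCEmbedding(input_dim=x.shape[1], output_dim=32)
density_estimator = posterior_nn(
    model="nsf",
    embedding_net=embedding_net,
    num_transforms=4,
)

# Training settings reused by the examples below.
train_kwargs = dict(
    max_num_epochs=50,
    training_batch_size=128,
    validation_fraction=0.1,
    show_train_summary=False,
)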

Train with a tracker

By default, sbi uses a TensorBoard tracker to log the training loss, the validation loss, the number of epochs, and more.

To track additional quantities, such as hyperparameters, instantiate the tracker yourself and pass it to the inference class:

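from torch.utils.tensorboard.writer import SummaryWriter

from sbi.utils.tracking import TensorBoardTracker

# Write TensorBoard event files to ./sbi-logs and record the hyperparameters.
tracker = TensorBoardTracker(SummaryWriter("sbi-logs"))
tracker.log_params({"embedding_dim": 32, "num_transforms": 4})

# Pass the configured density estimator so the logged hyperparameters match the trained model.
inference = NPE(prior=prior, density_estimator=density_estimator, tracker=tracker)
inference.append_simulations(theta, x)
estimator = inference.train(**train_kwargs)
posterior = inference.build_posterior(estimator)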
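The log_params call above receives a flat dictionary. If your configuration is nested, flattening it into dotted keys first is a common convention that keeps every value a scalar. Whether your tracker needs this is an assumption to check; torch's SummaryWriter.add_hparams, for example, only accepts int, float, bool, str, or tensor values. The helper flatten_params is hypothetical and not part of sbi:

```python
def flatten_params(params, prefix=""):
    """Flatten nested dicts into dotted keys, e.g. {"nsf": {"bins": 8}} -> {"nsf.bins": 8}."""
    flat = {}
    for key, value in params.items():
        full_key = f"{prefix}{key}"
        if isinstance(value, dict):
            # Recurse into sub-configs, extending the dotted prefix.
            flat.update(flatten_params(value, prefix=f"{full_key}."))
        else:
            flat[full_key] = value
    return flat
```

For example, tracker.log_params(flatten_params({"density_estimator": {"model": "nsf", "num_transforms": 4}, "train": train_kwargs})) would log keys such as density_estimator.model and train.max_num_epochs.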

View TensorBoard results

You can then view your tracked runs in TensorBoard, which serves a dashboard on localhost that you open in your browser (by default at http://localhost:6006). By default, sbi creates the log directory sbi-logs in the working directory from which the training script is run.

tensorboard --logdir=sbi-logs
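TensorBoard lists each subdirectory of --logdir that contains event files as a separate run, so giving every experiment its own subdirectory makes variants easy to compare side by side. The helper run_dir below is a hypothetical convenience (not part of sbi) that derives such a path from the hyperparameters:

```python
from pathlib import Path


def run_dir(root, **hparams):
    """Build a per-run log directory such as sbi-logs/embedding_dim=32,num_transforms=4."""
    # Sort keys so the same hyperparameters always map to the same directory name.
    name = ",".join(f"{key}={value}" for key, value in sorted(hparams.items()))
    return Path(root) / (name or "default")
```

You would then pass the result to SummaryWriter in place of "sbi-logs", e.g. SummaryWriter(str(run_dir("sbi-logs", embedding_dim=32, num_transforms=4))), and keep running tensorboard --logdir=sbi-logs to see all runs together.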

Using other trackers

To support other tracking tools, sbi defines a lightweight Tracker protocol. Write a small adapter class that satisfies this protocol and pass an instance via tracker=. Below are minimal adapters for common tools.

# W&B adapter (requires `wandb.init()` before training)
class WandBAdapter:
    log_dir = None

    def __init__(self, run):
        self._run = run

    def log_metric(self, name, value, step=None):
        self._run.log({name: value}, step=step)

    def log_metrics(self, metrics, step=None):
        self._run.log(metrics, step=step)

    def log_params(self, params):
        self._run.config.update(params)

    def add_figure(self, name, figure, step=None):
        import wandb
        self._run.log({name: wandb.Image(figure)}, step=step)

    def flush(self):
        pass


# MLflow adapter (configure the tracking URI and start a run, e.g. with mlflow.start_run(), before training)
class MLflowAdapter:
    log_dir = None

    def __init__(self, mlflow):
        # Pass the imported mlflow module itself; its fluent API logs to the active run.
        self._mlflow = mlflow

    def log_metric(self, name, value, step=None):
        self._mlflow.log_metric(name, value, step=step)

    def log_metrics(self, metrics, step=None):
        for name, value in metrics.items():
            self.log_metric(name, value, step=step)

    def log_params(self, params):
        self._mlflow.log_params(params)

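    def add_figure(self, name, figure, step=None):
        # Include the step in the file name so repeated calls don't overwrite each other.
        suffix = "" if step is None else f"_step{step}"
        self._mlflow.log_figure(figure, f"{name}{suffix}.png")

    def flush(self):
        pass


# Trackio adapter (requires `trackio.init()` before training)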
class TrackioAdapter:
    log_dir = None

    def __init__(self, trackio):
        self._trackio = trackio

    def log_metric(self, name, value, step=None):
        self._trackio.log({name: value}, step=step)

    def log_metrics(self, metrics, step=None):
        self._trackio.log(metrics, step=step)

    def log_params(self, params):
        self._trackio.log(params)

    def add_figure(self, name, figure, step=None):
        # Image-logging support may differ between trackio versions; check its documentation.
        self._trackio.log_image(figure, name=name, step=step)

    def flush(self):
        pass
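All three adapters implement the same members. As a sketch, that shape can be written down as a typing.Protocol. The name TrackerLike is hypothetical and the type hints are assumptions inferred from the adapters above, so sbi's actual Tracker definition may differ in its details:

```python
from typing import Any, Mapping, Optional, Protocol, runtime_checkable


@runtime_checkable
class TrackerLike(Protocol):
    """Assumed shape of sbi's Tracker protocol, inferred from the adapters above."""

    log_dir: Optional[str]

    def log_metric(self, name: str, value: float, step: Optional[int] = None) -> None: ...

    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: ...

    def log_params(self, params: Mapping[str, Any]) -> None: ...

    def add_figure(self, name: str, figure: Any, step: Optional[int] = None) -> None: ...

    def flush(self) -> None: ...
```

A quick isinstance(WandBAdapter(run), TrackerLike) check then catches a missing method early. Note that a runtime_checkable protocol only checks that the members exist, not their signatures.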

When using external trackers, create an adapter instance and pass it to tracker=:

import wandb

run = wandb.init(...)  # configure project, entity, etc. as usual
tracker = WandBAdapter(run)
inference = NPE(prior=prior, density_estimator=density_estimator, tracker=tracker)
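For quick debugging or unit tests, a tracker does not need an external service at all. The hypothetical RecordingTracker below (not part of sbi) keeps everything in memory and implements the same methods as the adapters above:

```python
from collections import defaultdict


class RecordingTracker:
    """Hypothetical in-memory tracker for debugging and unit tests."""

    log_dir = None  # nothing is written to disk

    def __init__(self):
        self.metrics = defaultdict(list)  # name -> list of (step, value)
        self.params = {}
        self.figures = {}  # (name, step) -> figure object

    def log_metric(self, name, value, step=None):
        # float() also converts scalar tensors to plain Python numbers.
        self.metrics[name].append((step, float(value)))

    def log_metrics(self, metrics, step=None):
        for name, value in metrics.items():
            self.log_metric(name, value, step=step)

    def log_params(self, params):
        self.params.update(params)

    def add_figure(self, name, figure, step=None):
        self.figures[(name, step)] = figure

    def flush(self):
        pass  # everything already lives in memory
```

After passing tracker=RecordingTracker() to NPE and training, tracker.metrics.keys() shows which metric names sbi actually logged.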

Log figures

Trackers can also store matplotlib figures. For example, after training you can log a pairplot:

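from sbi.analysis import pairplot

# Treat the first simulated observation as the observed data x_o.
x_o = x[:1]
samples = posterior.sample((1000,), x=x_o)
fig, _ = pairplot(samples)  # pairplot returns the figure and its axes
tracker.add_figure("posterior_pairplot", fig, step=0)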

How a figure is stored depends on the tracker implementation (e.g., wandb.Image for W&B, mlflow.log_figure for MLflow).
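For a tracker without a native figure API, add_figure can simply write image files. The hypothetical helper save_figure below could back the add_figure method of a custom adapter. It relies only on the figure's savefig method (matplotlib's Figure.savefig accepts a path) and puts the step in the file name so later figures with the same name do not overwrite earlier ones:

```python
from pathlib import Path


def save_figure(log_dir, name, figure, step=None):
    """Save a matplotlib-style figure under log_dir, keeping one file per step."""
    # Including the step avoids overwriting earlier figures that share a name.
    suffix = "" if step is None else f"_step{step}"
    path = Path(log_dir) / f"{name}{suffix}.png"
    path.parent.mkdir(parents=True, exist_ok=True)
    figure.savefig(path)
    return path
```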

Custom training loop (optional)

To log custom diagnostics per epoch, write your own training loop as described in the training interface tutorial: https://sbi.readthedocs.io/en/latest/advanced_tutorials/18_training_interface.html.
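Whatever the details of the loop, the logging pattern is the same: call the tracker once per epoch with the epoch index as the step. A sketch, assuming hypothetical callables train_one_epoch and validate that each return a float loss. The metric names used here are also illustrative:

```python
def train_with_tracking(tracker, train_one_epoch, validate, max_num_epochs):
    """Generic per-epoch loop; train_one_epoch and validate are hypothetical callables."""
    for epoch in range(max_num_epochs):
        train_loss = train_one_epoch()
        val_loss = validate()
        # Use the epoch index as the step so every tool aligns the curves.
        tracker.log_metrics({"train_loss": train_loss, "val_loss": val_loss}, step=epoch)
        # Extra diagnostics go through the same interface, e.g. the generalization gap.
        tracker.log_metric("loss_gap", val_loss - train_loss, step=epoch)
    # Make sure buffered values reach the backend before returning.
    tracker.flush()
```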

Notes

  • Each tool supports richer logging (artifacts, checkpoints, plots), but the patterns above are enough to track hyperparameters, epoch-wise losses, and validation metrics.

  • If you already use Optuna or another sweep tool, you can call the tracker inside the objective function to log each trial.
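As a sketch of per-trial logging, the loop below uses plain random search as a stand-in for a sweep tool. run_sweep, the search space, and the "objective" metric name are all hypothetical. make_tracker returns one tracker (or run) per trial, and objective would wrap training, e.g. returning the best validation loss:

```python
import random


def run_sweep(make_tracker, objective, n_trials, seed=0):
    """Random search used here as a stand-in for a sweep tool such as Optuna."""
    rng = random.Random(seed)
    results = []
    for trial in range(n_trials):
        # Sample one configuration per trial (the search space is illustrative).
        params = {
            "num_transforms": rng.choice([2, 4, 8]),
            "embedding_dim": rng.choice([16, 32, 64]),
        }
        tracker = make_tracker(trial)  # e.g. one tracker or run per trial
        tracker.log_params({"trial": trial, **params})
        score = objective(params)
        tracker.log_metric("objective", score, step=trial)
        tracker.flush()
        results.append((score, params))
    # Return the (score, params) pair with the lowest objective value.
    return min(results, key=lambda result: result[0])
```

With Optuna, the same tracker calls go inside def objective(trial): ..., with trial.suggest_categorical(...) replacing rng.choice and trial.number serving as the step.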