How to track experiments
Experiment tracking helps you compare model variants and keep a record of hyperparameters and training metrics. By default, sbi logs to TensorBoard. You can also bring your own tracker by implementing the lightweight Tracker protocol and passing an instance via the tracker argument.
If you use your own tracker (e.g., wandb, mlflow or trackio), you are responsible for the run lifecycle: start the run (e.g., wandb.init, mlflow.start_run) before training and end it afterwards.
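Because sbi does not start or end runs for external trackers, one way to keep the lifecycle tidy is a small context manager that starts the run, hands an adapter to sbi, and always ends the run, even if training raises. This is a minimal sketch: `tracked_run` is a hypothetical helper (not part of sbi), and `start_run`, `end_run` and `make_adapter` stand for whatever callables your tracking tool and adapter provide.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator


@contextmanager
def tracked_run(
    start_run: Callable[[], Any],
    end_run: Callable[[], None],
    make_adapter: Callable[[Any], Any],
) -> Iterator[Any]:
    """Hypothetical helper: start a run, yield a tracker adapter, always end the run."""
    run = start_run()  # e.g. the object returned by wandb.init(...)
    try:
        # Pass the yielded adapter to the inference class via tracker=...
        yield make_adapter(run)
    finally:
        end_run()  # executed even if training inside the with-block raises
```

With W&B this could look like `with tracked_run(lambda: wandb.init(project="my-project"), wandb.finish, WandBAdapter) as tracker: ...`, where `WandBAdapter` is the adapter shown further down this page and the project name is a placeholder.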
Define a minimal training setup
import torch
from sbi.inference import NPE
from sbi.neural_nets import posterior_nn
from sbi.neural_nets.embedding_nets import FCEmbedding
from sbi.utils import BoxUniform

torch.manual_seed(0)

# Toy simulator: observations are the parameters plus small Gaussian noise.
def simulator(theta):
    return theta + 0.1 * torch.randn_like(theta)

prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
theta = prior.sample((5000,))
x = simulator(theta)

# Neural spline flow ("nsf") conditioned on a 32-dimensional embedding of x.
embedding_net = FCEmbedding(input_dim=x.shape[1], output_dim=32)
density_estimator = posterior_nn(
    model="nsf",
    embedding_net=embedding_net,
    num_transforms=4,
)

train_kwargs = dict(
    max_num_epochs=50,
    training_batch_size=128,
    validation_fraction=0.1,
    show_train_summary=False,
)
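As a rough guide to what one logged epoch corresponds to, the numbers in this setup can be worked out by hand. This sketch assumes the validation set is `validation_fraction` of the simulations and the remainder is split into mini-batches of `training_batch_size`; sbi's exact rounding and its handling of a final partial batch may differ.

```python
num_simulations = 5000
validation_fraction = 0.1
training_batch_size = 128
max_num_epochs = 50

# Simulations held out for validation vs. used for gradient steps.
num_validation = round(num_simulations * validation_fraction)
num_training = num_simulations - num_validation

# Full mini-batches per epoch, plus the size of the leftover partial batch.
full_batches, leftover = divmod(num_training, training_batch_size)

print(num_validation, num_training)   # 500 4500
print(full_batches, leftover)         # 35 20
# Full-batch steps if all epochs run (an upper bound when training stops early).
print(full_batches * max_num_epochs)  # 1750
```

Since sbi stops training once the validation loss stops improving, fewer than `max_num_epochs` epochs may show up in the logs.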
Train with a tracker
By default, sbi uses a TensorBoard tracker to log the training loss, the validation loss, the number of epochs, and more. To track additional quantities, such as hyperparameters, instantiate the tracker yourself and pass it to the inference class:
from torch.utils.tensorboard.writer import SummaryWriter
from sbi.utils.tracking import TensorBoardTracker
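# Write TensorBoard event files to ./sbi-logs.
tracker = TensorBoardTracker(SummaryWriter("sbi-logs"))
# Record hyperparameters once, before training starts.
tracker.log_params({"embedding_dim": 32, "num_transforms": 4})
# Pass the density estimator defined above so the logged hyperparameters match the trained model.
inference = NPE(prior=prior, density_estimator=density_estimator, tracker=tracker)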
inference.append_simulations(theta, x)
estimator = inference.train(**train_kwargs)
posterior = inference.build_posterior(estimator)
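Hyperparameters are often kept in nested dictionaries, whereas parameter logging in most trackers works best with flat key/value pairs of simple types (strings, numbers, booleans). A minimal sketch of a hypothetical `flatten_params` helper you could call before `tracker.log_params(...)`; it is not part of sbi, and the dotted-key convention is an arbitrary choice.

```python
from typing import Any, Dict


def flatten_params(params: Dict[str, Any], prefix: str = "", sep: str = ".") -> Dict[str, Any]:
    """Hypothetical helper: turn {"a": {"b": 1}} into {"a.b": 1}."""
    flat: Dict[str, Any] = {}
    for key, value in params.items():
        name = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, dict):
            # Recurse into nested sections, extending the dotted prefix.
            flat.update(flatten_params(value, name, sep))
        elif isinstance(value, (int, float, str, bool)) or value is None:
            flat[name] = value
        else:
            # Fall back to a string for lists, tuples and other objects.
            flat[name] = str(value)
    return flat
```

For example, `tracker.log_params(flatten_params({"model": {"embedding_dim": 32, "num_transforms": 4}, "train": train_kwargs}))` would log keys such as `model.embedding_dim` and `train.max_num_epochs`.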
View TensorBoard results
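You can then view your tracked runs in TensorBoard, which serves a dashboard on localhost that you open in your browser (by default at http://localhost:6006). By default, sbi writes its logs to a directory sbi-logs in the working directory from which the training script was launched; the example above uses the same directory name.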
tensorboard --logdir=sbi-logs
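TensorBoard discovers runs recursively below --logdir, but if several runs accumulate it can help to check from Python which subdirectories actually contain event files (TensorBoard names them `events.out.tfevents.*`). A small standard-library sketch; `find_runs` is a hypothetical helper and assumes nothing about the layout under sbi-logs beyond event files existing somewhere below it.

```python
from pathlib import Path
from typing import List


def find_runs(logdir: str = "sbi-logs") -> List[str]:
    """Hypothetical helper: list run directories (relative to logdir) holding event files."""
    root = Path(logdir)
    if not root.is_dir():
        return []
    # Every directory that directly contains at least one event file counts as a run.
    run_dirs = {p.parent for p in root.rglob("events.out.tfevents.*")}
    # Sorted for stable output; "." means event files sit directly in logdir.
    return sorted(str(d.relative_to(root)) for d in run_dirs)
```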
Using other trackers
To support other trackers, sbi provides a lightweight Tracker protocol. Implement a small adapter class that satisfies it and pass an instance via tracker=. Below are minimal adapters for common tools; each provides log_metric, log_metrics, log_params, add_figure and flush, plus a log_dir attribute.
# W&B adapter (requires `wandb.init()` before training)
class WandBAdapter:
    log_dir = None  # no local log directory to report

    def __init__(self, run):
        self._run = run  # an active run, e.g. the object returned by wandb.init()

    def log_metric(self, name, value, step=None):
        self._run.log({name: value}, step=step)

    def log_metrics(self, metrics, step=None):
        self._run.log(metrics, step=step)

    def log_params(self, params):
        self._run.config.update(params)

    def add_figure(self, name, figure, step=None):
        import wandb

        # wandb.Image converts a matplotlib figure into a loggable image.
        self._run.log({name: wandb.Image(figure)}, step=step)

    def flush(self):
        pass  # W&B syncs in the background
# MLflow adapter (configure tracking URI separately)
class MLflowAdapter:
    log_dir = None  # no local log directory to report

    def __init__(self, mlflow):
        self._mlflow = mlflow  # the imported `mlflow` module (fluent API)

    def log_metric(self, name, value, step=None):
        self._mlflow.log_metric(name, value, step=step)

    def log_metrics(self, metrics, step=None):
        for name, value in metrics.items():
            self.log_metric(name, value, step=step)

    def log_params(self, params):
        self._mlflow.log_params(params)

    def add_figure(self, name, figure, step=None):
        # Saved as an artifact file; `step` is not used here.
        self._mlflow.log_figure(figure, f"{name}.png")

    def flush(self):
        pass
# Trackio adapter (requires `trackio.init()` before training)
class TrackioAdapter:
    log_dir = None

    def __init__(self, trackio):
        self._trackio = trackio

    def log_metric(self, name, value, step=None):
        self._trackio.log({name: value}, step=step)

    def log_metrics(self, metrics, step=None):
        self._trackio.log(metrics, step=step)

    def log_params(self, params):
        self._trackio.log(params)

    def add_figure(self, name, figure, step=None):
        self._trackio.log_image(figure, name=name, step=step)

    def flush(self):
        pass
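The method set shared by these three adapters can be written down as a `typing.Protocol`. This is a sketch inferred from the adapters, not a copy of sbi's own Tracker definition, so check sbi's source for the authoritative signatures. Because `TrackerLike` (a hypothetical name) is marked `runtime_checkable`, `isinstance` reports whether an object has all of the members; it does not check their signatures.

```python
from typing import Any, Dict, Optional, Protocol, runtime_checkable


@runtime_checkable
class TrackerLike(Protocol):
    """Assumed tracker interface, inferred from the adapters on this page."""

    log_dir: Optional[str]

    def log_metric(self, name: str, value: float, step: Optional[int] = None) -> None: ...

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: ...

    def log_params(self, params: Dict[str, Any]) -> None: ...

    def add_figure(self, name: str, figure: Any, step: Optional[int] = None) -> None: ...

    def flush(self) -> None: ...
```

A quick `assert isinstance(MLflowAdapter(mlflow), TrackerLike)` before training catches a forgotten method early.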
When using an external tracker, start the run yourself, then create an adapter instance and pass it via tracker=:
import wandb

# wandb.init(...)  # start the W&B run first; WandBAdapter needs the active run
tracker = WandBAdapter(wandb.run)
inference = NPE(prior=prior, density_estimator=density_estimator, tracker=tracker)
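Because all adapters expose the same methods, they can also be combined, for example to keep TensorBoard logs while also sending metrics to W&B. A minimal sketch of a hypothetical `MultiTracker` (not part of sbi) that forwards every call to several trackers; it assumes sbi only calls the methods shown in the adapters above, and reporting the first non-None `log_dir` is a design choice.

```python
from typing import Any, Dict, Optional, Sequence


class MultiTracker:
    """Hypothetical fan-out tracker: forwards each call to every wrapped tracker."""

    def __init__(self, trackers: Sequence[Any]):
        self._trackers = list(trackers)
        # Expose the first non-None log_dir (e.g. from a TensorBoard tracker), if any.
        self.log_dir: Optional[str] = next(
            (t.log_dir for t in self._trackers if getattr(t, "log_dir", None) is not None),
            None,
        )

    def log_metric(self, name: str, value: float, step: Optional[int] = None) -> None:
        for t in self._trackers:
            t.log_metric(name, value, step=step)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        for t in self._trackers:
            t.log_metrics(metrics, step=step)

    def log_params(self, params: Dict[str, Any]) -> None:
        for t in self._trackers:
            t.log_params(params)

    def add_figure(self, name: str, figure: Any, step: Optional[int] = None) -> None:
        for t in self._trackers:
            t.add_figure(name, figure, step=step)

    def flush(self) -> None:
        for t in self._trackers:
            t.flush()
```

Usage could look like `tracker = MultiTracker([TensorBoardTracker(SummaryWriter("sbi-logs")), WandBAdapter(wandb.run)])`, followed by `NPE(..., tracker=tracker)` as usual.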
Log figures
Trackers can also store matplotlib figures. For example, after training you can log a pairplot:
from sbi.analysis import pairplot
x_o = x[:1]  # use the first simulated observation as x_o
samples = posterior.sample((1000,), x=x_o)
fig, _ = pairplot(samples)  # pairplot returns (figure, axes)
tracker.add_figure("posterior_pairplot", fig, step=0)
How figures are stored depends on the tracker implementation (e.g., wandb.Image for W&B, mlflow.log_figure for MLflow).
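For a tool without native image support, one option is to write each figure to disk inside the adapter's add_figure and log only the file path. A minimal sketch of a hypothetical `save_figure` helper; it relies only on the figure having a `savefig(path)` method, as matplotlib figures do, and the `<name>_step<k>.png` naming scheme is an arbitrary choice.

```python
from pathlib import Path
from typing import Any, Optional


def save_figure(figure: Any, log_dir: str, name: str, step: Optional[int] = None) -> Path:
    """Hypothetical helper: write `figure` to <log_dir>/figures/<name>[_step<k>].png."""
    out_dir = Path(log_dir) / "figures"
    out_dir.mkdir(parents=True, exist_ok=True)
    suffix = "" if step is None else f"_step{step}"
    path = out_dir / f"{name}{suffix}.png"
    figure.savefig(path)  # matplotlib's Figure.savefig accepts a path-like object
    return path
```

Inside an adapter, `add_figure` could call this and then record the returned path with whatever artifact or text logging the tool offers.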
Custom training loop (optional)
To log custom diagnostics every epoch, write your own training loop as described in the training interface tutorial: https://sbi.readthedocs.io/en/latest/advanced_tutorials/18_training_interface.html.
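The general shape of per-epoch logging in a custom loop is sketched below. `train_one_epoch` and `validate` are hypothetical callables standing in for the steps described in the training interface tutorial; the metric names and the patience-based early stopping are illustrative choices, not sbi's internals.

```python
from typing import Any, Callable, Dict


def fit_with_tracking(
    train_one_epoch: Callable[[], float],
    validate: Callable[[], Dict[str, float]],
    tracker: Any,
    max_num_epochs: int = 50,
    patience: int = 5,
) -> int:
    """Run epochs, log metrics after each one, stop when validation loss stalls.

    Returns the number of epochs that were run.
    """
    best_val = float("inf")
    epochs_without_improvement = 0
    epochs_run = 0
    for epoch in range(max_num_epochs):
        epochs_run = epoch + 1
        train_loss = train_one_epoch()
        diagnostics = validate()  # expected to contain "validation_loss"
        # One call per epoch, using the epoch index as the step.
        tracker.log_metrics({"train_loss": train_loss, **diagnostics}, step=epoch)
        if diagnostics["validation_loss"] < best_val:
            best_val = diagnostics["validation_loss"]
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    tracker.flush()
    return epochs_run
```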
Notes
Each tool supports richer logging (artifacts, checkpoints, plots), but the patterns above are enough to track hyperparameters, epoch-wise losses, and validation metrics.
If you already use Optuna or another sweep tool, you can create a tracker inside the objective function so that each trial's hyperparameters and metrics are logged.
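A minimal, library-free sketch of that pattern: each trial gets its own tracker, logs its hyperparameters, and records the final objective value. `run_trial`, `sweep`, `make_tracker` and `train_and_score` are hypothetical names; with Optuna, the body of `run_trial` would go inside `objective(trial)`, with the hyperparameters drawn from `trial.suggest_int` or `trial.suggest_float`.

```python
from typing import Any, Callable, Dict, List


def run_trial(
    trial_id: int,
    params: Dict[str, Any],
    make_tracker: Callable[[int], Any],
    train_and_score: Callable[[Dict[str, Any], Any], float],
) -> float:
    """Hypothetical per-trial wrapper: one tracker per trial, params logged first."""
    tracker = make_tracker(trial_id)  # e.g. a fresh log directory or a new run
    tracker.log_params({"trial_id": trial_id, **params})
    score = train_and_score(params, tracker)  # would pass tracker=tracker to NPE
    tracker.log_metric("objective", score)
    tracker.flush()
    return score


def sweep(
    grid: List[Dict[str, Any]],
    make_tracker: Callable[[int], Any],
    train_and_score: Callable[[Dict[str, Any], Any], float],
) -> Dict[str, Any]:
    """Evaluate each configuration and return the best (lowest-score) one."""
    scores = [run_trial(i, p, make_tracker, train_and_score) for i, p in enumerate(grid)]
    best = min(range(len(grid)), key=scores.__getitem__)
    return {"params": grid[best], "score": scores[best]}
```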