EP-01: Pluggable Training Infrastructure for sbi#
Status: Discussion
Feedback: See GitHub Discussion → EP-01 Discussion
Summary#
This enhancement proposal describes the refactoring of sbi neural network training
infrastructure to address technical debt, improve maintainability, and provide users
with better experiment tracking capabilities. The refactoring will introduce typed
configuration objects, extract a unified training loop, and implement pluggable
interfaces for logging and early stopping while maintaining full backward compatibility.
Motivation#
The sbi package has evolved organically over five years to serve researchers in
simulation-based inference. This growth has resulted in several architectural issues in
the training infrastructure:
Code duplication: Training logic is duplicated across NPE, NLE, NRE, and other inference methods, spanning over 3,000 lines of code
Type safety issues: Unstructured dictionary parameters lead to runtime errors and poor IDE support
Limited logging options: Users are restricted to TensorBoard with no easy way to integrate their preferred experiment tracking tools
Missing features: No built-in early stopping, leading to wasted compute and potential overfitting
Integration barriers: Community contributions like PR #1629 cannot be easily integrated due to architectural constraints
These issues affect both users and maintainers, slowing down development and making the codebase harder to work with.
Goals#
For Users#
Seamless experiment tracking: Support for TensorBoard, WandB, MLflow, and stdout without changing existing code
Clear API for early stopping strategies: Better documentation for the patience-based strategy (implemented internally), plus possibly a lightweight interface to external early stopping strategies
Better debugging: Typed configurations with clear error messages
Zero migration effort: Full backward compatibility with existing code
For Maintainers#
Reduced code duplication: Extract shared training logic into reusable components
Type safety: Prevent entire classes of bugs through static typing
Easier feature integration: Clean interfaces for community contributions
Future extensibility: Lightweight interfaces for external tools, e.g., logging and early stopping
Non-Goals#
We want to avoid removing code duplication just for the sake of reducing LOC. Sometimes code duplication is necessary, and preferable to overcomplicated large class structures with unclear separation of concerns. In other words, having clear API interfaces is more important than reducing code duplication.
We also want to avoid adding complexity by trying to implement every possible feature. For example, we probably should not implement all kinds of early stopping tools internally, because this would add maintenance and documentation burden for us, and overhead for users who have to understand the API and the docs. Instead, we should either implement a lightweight interface that allows plugging in external early stopping tools, or implement just a basic version in our internal training (as we do now) and refer to the flexible training interface when a user wants to use other approaches.
Design#
These are very rough sketches of what this could look like. They are open for discussion and may change substantially during implementation (subject to the non-goals defined above).
Current API#
# Current: Each method has its own training implementation, mixing general
# training options with method-specific loss options.
inference = NPE(prior=prior)
inference.train(
    training_batch_size=50,
    learning_rate=5e-4,
    validation_fraction=0.1,
    stop_after_epochs=20,
    max_num_epochs=2**31 - 1,
    clip_max_norm=5.0,
    exclude_invalid_x=True,
    resume_training=False,
    show_train_summary=False,
)
Proposed API#
# Proposed: Cleaner API with typed configurations
from sbi.training import TrainConfig, LossArgsNPE
# Configure training (with IDE autocomplete and validation)
train_config = TrainConfig(
    batch_size=50,
    learning_rate=5e-4,
    max_epochs=1000,
    device="cuda",
)
# Method-specific loss configuration
loss_args = LossArgsNPE(exclude_invalid_x=True)
# Train with clean API
inference = NPE(prior=prior)
inference.train(train_config, loss_args)
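To make "typed configurations with clear error messages" concrete, here is a minimal sketch of what such a config object might look like. It is an illustration only, not the planned sbi implementation: the class body, the validation rules, and the default values (copied from the example call in the Current API section) are assumptions.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class TrainConfig:
    """Hypothetical typed training options; field names follow the proposal."""

    batch_size: int = 50
    learning_rate: float = 5e-4
    validation_fraction: float = 0.1
    max_epochs: int = 2**31 - 1
    clip_max_norm: Optional[float] = 5.0
    device: str = "cpu"

    def __post_init__(self) -> None:
        # Fail at construction time rather than deep inside the training loop.
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.learning_rate <= 0:
            raise ValueError(
                f"learning_rate must be positive, got {self.learning_rate}"
            )
        if not 0.0 < self.validation_fraction < 1.0:
            raise ValueError(
                "validation_fraction must lie strictly between 0 and 1, "
                f"got {self.validation_fraction}"
            )
        if self.max_epochs < 1:
            raise ValueError(f"max_epochs must be >= 1, got {self.max_epochs}")
```

With a dataclass, a misspelled option such as `TrainConfig(batchsize=50)` raises a `TypeError` immediately, whereas a misspelled key in an unstructured dictionary can go unnoticed until much later.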
Logging Interface#
Users can seamlessly switch between logging backends:
from sbi.training import LoggingConfig
# Choose your backend - no other code changes needed
logging = LoggingConfig(backend="wandb", project="my-experiment")
# or: LoggingConfig(backend="tensorboard", log_dir="./runs")
# or: LoggingConfig(backend="mlflow", experiment_name="sbi-run")
# or: LoggingConfig(backend="stdout") # default
inference.train(train_config, loss_args, logging=logging)
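One way the backend switch could work internally is a small logger protocol plus a registry keyed by backend name. The sketch below is an assumption for illustration: `MetricLogger`, `StdoutLogger`, `InMemoryLogger`, and `make_logger` are hypothetical names, and only dependency-free backends are shown. A WandB or MLflow backend would implement the same two methods.

```python
from typing import List, Protocol, Tuple


class MetricLogger(Protocol):
    """Hypothetical minimal interface every logging backend would satisfy."""

    def log_scalar(self, name: str, value: float, step: int) -> None: ...

    def close(self) -> None: ...


class StdoutLogger:
    """Prints each metric; needs no third-party dependency."""

    def log_scalar(self, name: str, value: float, step: int) -> None:
        print(f"[epoch {step}] {name}={value:.4f}")

    def close(self) -> None:
        pass


class InMemoryLogger:
    """Keeps metrics in a list, which is convenient for unit tests."""

    def __init__(self) -> None:
        self.records: List[Tuple[str, float, int]] = []

    def log_scalar(self, name: str, value: float, step: int) -> None:
        self.records.append((name, value, step))

    def close(self) -> None:
        pass


# Registry mapping backend names to logger classes (illustrative subset).
_BACKENDS = {"stdout": StdoutLogger, "memory": InMemoryLogger}


def make_logger(backend: str = "stdout") -> MetricLogger:
    """Look up a backend by name and instantiate it."""
    if backend not in _BACKENDS:
        raise ValueError(
            f"Unknown logging backend {backend!r}; choose from {sorted(_BACKENDS)}"
        )
    return _BACKENDS[backend]()
```

Because the training loop would only ever call `log_scalar` and `close`, adding a backend would not require touching any inference class.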
Early Stopping#
Multiple strategies available out of the box:
from sbi.training import EarlyStopping
# Stop when validation loss plateaus
early_stop = EarlyStopping.validation_loss(patience=20, min_delta=1e-4)
# Stop when learning rate drops too low
early_stop = EarlyStopping.lr_threshold(min_lr=1e-6)
inference.train(train_config, loss_args, early_stopping=early_stop)
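For reference, the patience-based rule with `patience` and `min_delta` can be sketched in a few lines. The class name `PatienceStopper` and its method are hypothetical and do not reflect existing sbi code; the sketch only illustrates the stopping criterion itself.

```python
class PatienceStopper:
    """Hypothetical patience-based early stopping on the validation loss."""

    def __init__(self, patience: int = 20, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.epochs_without_improvement = 0

    def should_stop(self, val_loss: float) -> bool:
        """Update the state with this epoch's loss and report whether to stop."""
        if val_loss < self.best - self.min_delta:
            # Improvement larger than min_delta: remember it, reset the counter.
            self.best = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = PatienceStopper(patience=2, min_delta=0.01)
decisions = [stopper.should_stop(loss) for loss in [1.0, 0.9, 0.895, 0.899]]
```

In this run the last two losses improve on the best value (0.9) by less than `min_delta`, so the counter reaches the patience of 2 and the fourth call returns `True`.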
Backward Compatibility#
All existing code continues to work:
# Old API still supported - no breaking changes
inference.train(training_batch_size=100, learning_rate=1e-3)
# Mix old and new as needed during migration
inference.train(
    training_batch_size=100,  # old style
    logging=LoggingConfig(backend="wandb"),  # new feature
)
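Backward compatibility of this kind could be achieved by translating legacy keyword arguments into the new config object before training starts. The following is a sketch under assumptions: the helper `config_from_legacy_kwargs`, the name mapping, and the reduced `TrainConfig` (with defaults taken from the Current API example) are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    """Reduced stand-in for the proposed config (assumed fields)."""

    batch_size: int = 50
    learning_rate: float = 5e-4
    max_epochs: int = 2**31 - 1


# Legacy keyword name -> proposed config field name (illustrative subset).
_LEGACY_NAMES = {
    "training_batch_size": "batch_size",
    "learning_rate": "learning_rate",
    "max_num_epochs": "max_epochs",
}


def config_from_legacy_kwargs(**kwargs) -> TrainConfig:
    """Build a TrainConfig from old-style train() keyword arguments."""
    unknown = set(kwargs) - set(_LEGACY_NAMES)
    if unknown:
        raise TypeError(f"Unexpected training arguments: {sorted(unknown)}")
    return TrainConfig(**{_LEGACY_NAMES[key]: value for key, value in kwargs.items()})


config = config_from_legacy_kwargs(training_batch_size=100, learning_rate=1e-3)
```

Arguments that are not passed keep the defaults of the config, so existing calls would behave the same as long as those defaults match the current ones.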
Unified Backend#
All inference methods share the same training infrastructure:
# NPE, NLE, NRE all use the same configuration
npe = NPE(prior=prior)
npe.train(train_config, loss_args)
nle = NLE(prior=prior)
nle.train(train_config, loss_args)
Example: Complete Training Pipeline#
import torch
from sbi import utils
from sbi.inference import NPE, simulate_for_sbi
from sbi.training import TrainConfig, LossArgsNPE, LoggingConfig, EarlyStopping
# Setup simulation
prior = utils.BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
simulator = lambda theta: theta + 0.1 * torch.randn_like(theta)
# Configure training with type safety and autocomplete
config = TrainConfig(
    batch_size=100,
    learning_rate=1e-3,
    max_epochs=1000,
)
# Setup logging and early stopping
logging = LoggingConfig(backend="wandb", project="sbi-experiment")
early_stop = EarlyStopping.validation_loss(patience=20)
# Train with new features
inference = NPE(prior=prior)
theta, x = simulate_for_sbi(simulator, prior, num_simulations=10000)
inference.append_simulations(theta, x)
neural_net = inference.train(
    config,
    LossArgsNPE(exclude_invalid_x=True),
    logging=logging,
    early_stopping=early_stop,
)
Next steps#
Centralizing training logic in base.py has historically increased the size and
responsibilities of the NeuralInference “god class”. As a natural next step, we
propose extracting the entire training loop into a standalone function that takes the
configured options and training components, and returns the trained network (plus
optional artifacts), e.g., something like:
from collections.abc import Callable, Sequence
import torch
from torch.utils.data import DataLoader
# TrainConfig, Callback, and TrainingSummary are proposed types that do not exist yet.
def run_training(
    config: TrainConfig,
    model: torch.nn.Module,
    loss_fn: Callable[..., torch.Tensor],
    train_loader: DataLoader,
    val_loader: DataLoader | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    scheduler: torch.optim.lr_scheduler._LRScheduler | None = None,
    callbacks: Sequence[Callback] | None = None,  # logging, early stopping, etc.
    device: str | torch.device | None = None,
) -> tuple[torch.nn.Module, TrainingSummary]:
    """Runs the unified training loop and returns the trained model and summary."""
Benefits:
Shrinks NeuralInference and makes responsibilities explicit.
Improves testability (train loop covered independently; inference classes can be tested with lightweight mocks).
Enables pluggable logging/early-stopping via callbacks without entangling method-specific logic.
Keeps backward compatibility: inference classes compose run_training() internally while still exposing the existing .train(...) entry point.
This should be tackled in a follow-up EP or PR that would introduce run_training()
(and a minimal Callback protocol), migrate NPE/NLE/NRE to call it, and add focused
unit tests for the training runner.
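As a starting point for discussion, a minimal Callback protocol and the way a training runner might consult it can be sketched without any PyTorch dependency. Everything here is hypothetical: the single `on_epoch_end` hook, the convention that returning `True` requests a stop, and the `run_epochs` driver (a simplified stand-in for `run_training()` that receives a function performing one epoch and returning the validation loss).

```python
from typing import Callable, List, Protocol, Sequence


class Callback(Protocol):
    """Hypothetical minimal hook interface for the training runner."""

    def on_epoch_end(self, epoch: int, val_loss: float) -> bool:
        """Return True to request that training stops."""
        ...


class RecordLosses:
    """Logging-style callback: stores every validation loss, never stops."""

    def __init__(self) -> None:
        self.history: List[float] = []

    def on_epoch_end(self, epoch: int, val_loss: float) -> bool:
        self.history.append(val_loss)
        return False


class StopBelow:
    """Stopping-style callback: requests a stop once the loss is small enough."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold

    def on_epoch_end(self, epoch: int, val_loss: float) -> bool:
        return val_loss < self.threshold


def run_epochs(
    train_one_epoch: Callable[[int], float],
    max_epochs: int,
    callbacks: Sequence[Callback] = (),
) -> int:
    """Run up to max_epochs epochs and return how many were completed."""
    for epoch in range(1, max_epochs + 1):
        val_loss = train_one_epoch(epoch)
        # Call every callback (no short-circuit) so loggers always see the epoch.
        stop_requests = [cb.on_epoch_end(epoch, val_loss) for cb in callbacks]
        if any(stop_requests):
            return epoch
    return max_epochs


recorder = RecordLosses()
# Toy "training" whose validation loss is 1/epoch.
completed = run_epochs(lambda epoch: 1.0 / epoch, 10, [StopBelow(0.3), recorder])
```

Logging and early stopping share one interface in this sketch, so the runner needs no knowledge of either. Whether a single hook is sufficient, or whether hooks such as training start and end are also needed, is an open design question.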
Feedback Wanted#
We welcome feedback and implementation interest in GitHub Discussions:
Which logging backends are most important?
What early stopping strategies would be useful?
Any concerns about the proposed API?
What do you think about the external training function?
Discussion thread: EP-01 Discussion
References#
PR #1629: Community early stopping implementation
NUMFOCUS SDG Proposal: Related funding proposal