# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import Any, Dict, Literal, Optional, Union
from torch.distributions import Distribution
from torch.utils.tensorboard.writer import SummaryWriter
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.posteriors.posterior_parameters import (
ImportanceSamplingPosteriorParameters,
MCMCPosteriorParameters,
RejectionPosteriorParameters,
VIPosteriorParameters,
)
from sbi.inference.trainers.nle.nle_base import LikelihoodEstimatorTrainer
from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.neural_nets.estimators.base import ConditionalEstimatorBuilder
from sbi.sbi_types import Tracker
from sbi.utils.sbiutils import del_entries
class MNLE(LikelihoodEstimatorTrainer):
    r"""Mixed Neural Likelihood Estimation for discrete and continuous data [1].

    MNLE extends NLE to handle data with mixed types, such as continuous reaction
    times and discrete choices in decision-making experiments. It trains a neural
    network to approximate the likelihood $p(x|\theta)$ where $x$ contains both
    discrete and continuous components.

    [1] Flexible and efficient simulation-based inference for models of
    decision-making, Boelts et al., eLife 2022,
    https://www.biorxiv.org/content/10.1101/2021.12.22.473472v2

    Example:
    --------

    ::

        import torch
        from sbi.inference import MNLE
        from sbi.utils import BoxUniform

        # 1. Setup prior and simulate mixed-type data
        prior = BoxUniform(low=torch.zeros(3), high=torch.ones(3))
        theta = prior.sample((100,))
        # First 5 dims continuous, last 3 dims discrete
        x_continuous = torch.randn(100, 5)
        x_discrete = torch.randint(0, 3, (100, 3))
        x = torch.cat([x_continuous, x_discrete.float()], dim=1)

        # 2. Train likelihood estimator
        inference = MNLE(prior=prior)
        likelihood_estimator = inference.append_simulations(theta, x).train()

        # 3. Build posterior
        posterior = inference.build_posterior(likelihood_estimator)

        # 4. Sample from posterior
        x_o = torch.cat([torch.randn(1, 5), torch.tensor([[1., 0., 2.]])], dim=1)
        samples = posterior.sample((1000,), x=x_o)
    """

    def __init__(
        self,
        prior: Optional[Distribution] = None,
        density_estimator: Union[
            Literal["mnle"],
            ConditionalEstimatorBuilder[MixedDensityEstimator],
        ] = "mnle",
        device: str = "cpu",
        logging_level: Union[int, str] = "WARNING",
        summary_writer: Optional[SummaryWriter] = None,
        tracker: Optional[Tracker] = None,
        show_progress_bars: bool = True,
    ):
        r"""Initialize MNLE.

        Args:
            prior: A probability distribution that expresses prior knowledge about the
                parameters, e.g. which ranges are meaningful for them. If `None`, the
                prior must be passed to `.build_posterior()`.
            density_estimator: If it is a string, it must be "mnle" to use the
                preconfigured neural nets for MNLE. Alternatively, a function
                that builds a custom neural network, which adheres to the
                `ConditionalEstimatorBuilder` protocol can be provided. The function
                will be called with the first batch of simulations (theta, x), which
                can thus be used for shape inference and potentially for z-scoring.
                The density estimator needs to provide the methods `.log_prob` and
                `.sample()` and must return a `MixedDensityEstimator`.
            device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
            logging_level: Minimum severity of messages to log. One of the strings
                INFO, WARNING, DEBUG, ERROR and CRITICAL.
            summary_writer: Deprecated alias for the TensorBoard summary writer.
                Use ``tracker`` instead.
            tracker: Tracking adapter used to log training metrics. If None, a
                TensorBoard tracker is used with a default log directory.
            show_progress_bars: Whether to show a progressbar during simulation and
                sampling.

        Raises:
            AssertionError: If `density_estimator` is a string other than "mnle".
        """
        # Raised explicitly (instead of via `assert`) so the check also runs under
        # `python -O`; the exception type and message are kept for existing callers.
        if isinstance(density_estimator, str) and density_estimator != "mnle":
            raise AssertionError(
                "MNLE can be used with preconfigured 'mnle' density estimator only,\n"
                f"not with {density_estimator}."
            )

        # `locals()` is forwarded wholesale to the parent constructor, so no extra
        # local variable may be bound in this method before this line.
        kwargs = del_entries(locals(), entries=("self", "__class__"))
        super().__init__(**kwargs)

    def train(
        self,
        training_batch_size: int = 200,
        learning_rate: float = 5e-4,
        validation_fraction: float = 0.1,
        stop_after_epochs: int = 20,
        max_num_epochs: int = 2**31 - 1,
        clip_max_norm: Optional[float] = 5.0,
        resume_training: bool = False,
        discard_prior_samples: bool = False,
        retrain_from_scratch: bool = False,
        show_train_summary: bool = False,
        dataloader_kwargs: Optional[Dict] = None,
    ) -> MixedDensityEstimator:
        r"""Train the mixed likelihood estimator and return it.

        Thin wrapper around the parent trainer's `train`: every argument is
        forwarded unchanged by keyword, and the returned network is checked to be a
        `MixedDensityEstimator`. The descriptions below only summarize the argument
        names; the authoritative semantics are those documented on
        `LikelihoodEstimatorTrainer.train`.

        Args:
            training_batch_size: Training batch size.
            learning_rate: Learning rate used for training.
            validation_fraction: Fraction of the data used for validation.
            stop_after_epochs: Number of epochs to wait for improvement on the
                validation set before terminating training.
            max_num_epochs: Maximum number of epochs to run.
            clip_max_norm: Value at which to clip the gradient norm; `None`
                presumably disables clipping (confirm against the parent trainer).
            resume_training: Whether to resume a previous training run.
            discard_prior_samples: Whether to discard samples simulated from the
                prior.
            retrain_from_scratch: Whether to retrain the estimator from scratch.
            show_train_summary: Whether to print a summary after training.
            dataloader_kwargs: Additional kwargs for the training and validation
                dataloaders.

        Returns:
            The trained `MixedDensityEstimator`.

        Raises:
            AssertionError: If the parent trainer returns a network that is not a
                `MixedDensityEstimator`.
        """
        # `locals()` must be read before any additional local is bound; otherwise
        # that name would be forwarded to the parent as an unexpected keyword.
        density_estimator = super().train(
            **del_entries(locals(), entries=("self", "__class__"))
        )
        # Internal invariant; also narrows the return type for type checkers.
        assert isinstance(density_estimator, MixedDensityEstimator), (
            "Internal net must be of type\n"
            f"MixedDensityEstimator but is {type(density_estimator)}."
        )
        return density_estimator

    def build_posterior(
        self,
        density_estimator: Optional[MixedDensityEstimator] = None,
        prior: Optional[Distribution] = None,
        sample_with: Literal["mcmc", "rejection", "vi"] = "mcmc",
        mcmc_method: Literal[
            "slice_np",
            "slice_np_vectorized",
            "hmc_pyro",
            "nuts_pyro",
            "slice_pymc",
            "hmc_pymc",
            "nuts_pymc",
        ] = "slice_np_vectorized",
        vi_method: Literal["rKL", "fKL", "IW", "alpha"] = "rKL",
        mcmc_parameters: Optional[Dict[str, Any]] = None,
        vi_parameters: Optional[Dict[str, Any]] = None,
        rejection_sampling_parameters: Optional[Dict[str, Any]] = None,
        importance_sampling_parameters: Optional[Dict[str, Any]] = None,
        posterior_parameters: Optional[
            Union[
                MCMCPosteriorParameters,
                VIPosteriorParameters,
                RejectionPosteriorParameters,
                ImportanceSamplingPosteriorParameters,
            ]
        ] = None,
    ) -> NeuralPosterior:
        r"""Build posterior from the neural density estimator.

        MNLE trains a neural network to approximate the likelihood $p(x|\theta)$. The
        posterior wraps the trained network such that one can directly evaluate the
        unnormalized posterior log probability $p(\theta|x) \propto p(x|\theta) \cdot
        p(\theta)$ and draw samples from the posterior with MCMC or rejection
        sampling.

        Args:
            density_estimator: The density estimator that the posterior is based on.
                If `None`, use the latest neural density estimator that was trained.
            prior: Prior distribution.
            sample_with: Method to use for sampling from the posterior. Must be one of
                [`mcmc` | `rejection` | `vi`].
            mcmc_method: Method used for MCMC sampling, one of `slice_np`,
                `slice_np_vectorized`, `hmc_pyro`, `nuts_pyro`, `slice_pymc`,
                `hmc_pymc`, `nuts_pymc`. `slice_np` is a custom
                numpy implementation of slice sampling. `slice_np_vectorized` is
                identical to `slice_np`, but if `num_chains>1`, the chains are
                vectorized for `slice_np_vectorized` whereas they are run sequentially
                for `slice_np`. The samplers ending on `_pyro` are using Pyro, and
                likewise the samplers ending on `_pymc` are using PyMC.
            vi_method: Method used for VI, one of [`rKL`, `fKL`, `IW`, `alpha`]. Note
                some of the methods admit a `mode seeking` property (e.g. rKL) whereas
                some admit a `mass covering` one (e.g fKL).
            mcmc_parameters: Additional kwargs passed to `MCMCPosterior`.
            vi_parameters: Additional kwargs passed to `VIPosterior`.
            rejection_sampling_parameters: Additional kwargs passed to
                `RejectionPosterior`.
            importance_sampling_parameters: Additional kwargs passed to
                `ImportanceSamplingPosterior`.
            posterior_parameters: Configuration passed to the init method for the
                posterior. Must be one of the following

                - `VIPosteriorParameters`
                - `MCMCPosteriorParameters`
                - `RejectionPosteriorParameters`
                - `ImportanceSamplingPosteriorParameters`

        Returns:
            Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods
            (the returned log-probability is unnormalized).

        Raises:
            AssertionError: If `density_estimator` is given but is not a
                `MixedDensityEstimator`.
        """
        # Only an explicitly passed estimator is checked here; with `None` the
        # argument is handed to the parent unchanged.
        if density_estimator is not None:
            assert isinstance(density_estimator, MixedDensityEstimator), (
                "net must be of type MixedDensityEstimator but is "
                f"{type(density_estimator)}."
            )

        return super().build_posterior(
            density_estimator=density_estimator,
            prior=prior,
            sample_with=sample_with,
            posterior_parameters=posterior_parameters,
            mcmc_method=mcmc_method,
            vi_method=vi_method,
            mcmc_parameters=mcmc_parameters,
            vi_parameters=vi_parameters,
            rejection_sampling_parameters=rejection_sampling_parameters,
            importance_sampling_parameters=importance_sampling_parameters,
        )