# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
import warnings
from typing import Any, Dict, Literal, Optional, Union
from torch.distributions import Distribution
from torch.utils.tensorboard.writer import SummaryWriter
from sbi import utils as utils
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.posteriors.posterior_parameters import VectorFieldPosteriorParameters
from sbi.inference.trainers.vfpe.base_vf_inference import (
VectorFieldTrainer,
)
from sbi.neural_nets.estimators.base import (
ConditionalEstimatorBuilder,
ConditionalVectorFieldEstimator,
)
from sbi.neural_nets.factory import posterior_flow_nn
from sbi.sbi_types import Tracker
[docs]
class FMPE(VectorFieldTrainer):
r"""Flow Matching Posterior Estimation (FMPE) [1].
FMPE trains a continuous normalizing flow (CNF) to transform samples from the
prior distribution to the posterior distribution using flow matching. Instead of
maximum likelihood, it trains a vector field to match the marginal vector field
of a conditional flow that interpolates between the prior and posterior. The
neural network architecture for the vector field is not constrained like for
flows and can be any expressive network. Sampling is performed by solving an ODE,
which can be slower than flow-based NPE, but log_prob evaluation can also be slower.
NOTE: FMPE does not support multi-round inference with flexible proposals yet.
You can try multi-round with truncated proposals, but this is not tested.
[1] Flow Matching for Generative Modeling, Lipman et al., ICLR 2023,
https://arxiv.org/abs/2210.02747
Example:
--------
::
import torch
from sbi.inference import FMPE
from sbi.utils import BoxUniform
# 1. Setup prior and simulate data
prior = BoxUniform(low=torch.zeros(3), high=torch.ones(3))
theta = prior.sample((100,))
x = theta + torch.randn_like(theta) * 0.1
# 2. Train flow matching estimator
inference = FMPE(prior=prior)
flow_estimator = inference.append_simulations(theta, x).train()
# 3. Build posterior (uses ODE solver for sampling)
posterior = inference.build_posterior(flow_estimator)
# 4. Sample from posterior
x_o = torch.randn(1, 3)
samples = posterior.sample((1000,), x=x_o)
"""
def __init__(
self,
prior: Optional[Distribution] = None,
vf_estimator: Union[
Literal["mlp", "ada_mlp", "transformer", "transformer_cross_attn"],
ConditionalEstimatorBuilder[ConditionalVectorFieldEstimator],
] = "mlp",
density_estimator: Optional[
ConditionalEstimatorBuilder[ConditionalVectorFieldEstimator]
] = None,
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[SummaryWriter] = None,
tracker: Optional[Tracker] = None,
show_progress_bars: bool = True,
) -> None:
"""Initialization method for the FMPE class.
Args:
prior: Prior distribution.
vf_estimator: Neural network architecture used to learn the
vector field estimator. Can be a string (e.g. 'mlp', 'ada_mlp',
'transformer' or 'transformer_cross_attn') or a callable that implements
the `ConditionalEstimatorBuilder` protocol with `__call__` that receives
`theta` and `x` and returns a `ConditionalVectorFieldEstimator`.
To configure estimator-level options, use `posterior_flow_nn` to
build a custom callable and pass it here.
density_estimator: Deprecated. Use `vf_estimator` instead.
device: Device to use for training.
logging_level: Logging level.
summary_writer: Deprecated alias for the TensorBoard summary writer.
Use ``tracker`` instead.
tracker: Tracking adapter used to log training metrics. If None, a
TensorBoard tracker is used with a default log directory.
show_progress_bars: Whether to show progress bars.
"""
if density_estimator is not None:
warnings.warn(
"`density_estimator` is deprecated and will be removed in a future "
"release. Use `vf_estimator` instead.",
FutureWarning,
stacklevel=2,
)
vf_estimator = density_estimator
super().__init__(
prior=prior,
device=device,
logging_level=logging_level,
summary_writer=summary_writer,
tracker=tracker,
show_progress_bars=show_progress_bars,
vector_field_estimator_builder=vf_estimator,
)
# When vf_estimator is a string, build the default neural net.
if isinstance(vf_estimator, str):
self._build_neural_net = self._build_default_nn_fn(model=vf_estimator)
[docs]
def build_posterior(
self,
vector_field_estimator: Optional[ConditionalVectorFieldEstimator] = None,
prior: Optional[Distribution] = None,
sample_with: Literal["ode", "sde"] = "ode",
vectorfield_sampling_parameters: Optional[Dict[str, Any]] = None,
posterior_parameters: Optional[VectorFieldPosteriorParameters] = None,
) -> NeuralPosterior:
r"""Build posterior from the flow matching estimator.
Note that this is the same as the NPSE posterior, but the sample_with method
is set to "ode" by default.
For FMPE, the posterior distribution that is returned here implements
the following functionality over the raw neural density estimator:
- correct the calculation of the log probability such that
samples outside of the prior bounds have log probability -inf.
- reject samples that lie outside of the prior bounds.
Args:
vector_field_estimator: The flow matching estimator that
the posterior is based on. If `None`, use the latest neural
flow matching estimator that was trained.
prior: Prior distribution.
sample_with: Method to use for sampling from the posterior.
Can be one of 'ode' (default) or 'sde'. The 'ode' method uses
the velocity field to define a probabilistic ODE and solves it
with a numerical ODE solver. The 'sde' method uses the score to
do a Langevin diffusion step.
vectorfield_sampling_parameters: Additional keyword arguments passed to
`VectorFieldPosterior`.
posterior_parameters: Configuration passed to the init method for
VectorFieldPosterior.
Returns:
Posterior $p(\theta|x)$ with `.sample()` and `.log_prob()` methods.
"""
return super().build_posterior(
estimator=vector_field_estimator,
prior=prior,
sample_with=sample_with,
vectorfield_sampling_parameters=vectorfield_sampling_parameters,
posterior_parameters=posterior_parameters,
)
def _build_default_nn_fn(
self,
model: Literal["mlp", "ada_mlp", "transformer", "transformer_cross_attn"],
) -> ConditionalEstimatorBuilder[ConditionalVectorFieldEstimator]:
return posterior_flow_nn(model=model)