Source code for sbi.inference.posteriors.posterior_parameters

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from typing import (
    Any,
    Callable,
    Dict,
    Literal,
    Optional,
    Union,
    cast,
    get_args,
    get_origin,
)

from torch.distributions import Distribution

from sbi.inference.posteriors.vi_posterior import VIPosterior
from sbi.sbi_types import TorchTransform
from sbi.utils.typechecks import (
    is_nonnegative_int,
    is_positive_float,
    is_positive_int,
)


@dataclass(frozen=True)
class PosteriorParameters(ABC):
    """
    Abstract base class for frozen posterior-parameter dataclasses.

    Subclasses declare their configuration as dataclass fields and implement
    `validate()`. After initialization, `__post_init__` checks `Literal[...]`
    fields, coerces `int` / `float` / `bool` fields to their annotated type and
    finally calls `validate()`.
    """

    @abstractmethod
    def validate(self):
        """
        Method for subclasses to override and implement
        custom validation logic. Called at the end of __post_init__.
        """
        ...

    def with_param(self, **kwargs):
        """
        Create a new instance of the class with updated field values.

        Only allows updates to fields defined in the dataclass. Raises an error if
        any unknown or invalid field names are passed. Because the copy is built
        with `dataclasses.replace`, `__post_init__` (and therefore `validate()`)
        runs again on the new instance.

        Args:
            **kwargs: Field-value pairs to override in the new instance.

        Returns:
            A new instance of the same class with updated values.

        Raises:
            ValueError: If any of the provided keys are not valid dataclass fields.
        """

        valid_fields = set(self.__dataclass_fields__)
        for key in kwargs:
            if key not in valid_fields:
                raise ValueError(
                    f"Invalid parameter: '{key}' is not a valid field"
                    f" of {self.__class__.__name__}"
                )
        return replace(self, **kwargs)

    def __post_init__(self) -> None:
        """
        Performs runtime validation and type enforcement after dataclass initialization.

        - Enforces that fields annotated with `Literal[...]` contain valid values.
        - Attempts to cast fields annotated as primitive types (int, float, bool) to
          their expected types by calling the type's constructor on the value.
        - Calls the `validate()` method at the end for additional custom checks.

        Only fields whose annotation is exactly a `Literal[...]` or exactly one of
        `int`, `float`, `bool` are processed; any other annotation (e.g. `Union`
        or `Optional`) is left to `validate()`.

        Raises:
            ValueError: If any field fails its Literal constraint or cannot be cast to
                        the expected primitive type.
        """

        for spec in self.__dataclass_fields__.values():
            field_name = spec.name
            raw_value = getattr(self, field_name)
            annotation = spec.type

            # Check if the value is among the valid choices
            # defined by a Literal annotation
            if get_origin(annotation) is Literal:
                allowed = get_args(annotation)
                if raw_value not in allowed:
                    raise ValueError(
                        f"Field '{field_name}' must be one of {allowed},"
                        f" got {raw_value}"
                    )
                continue

            if annotation not in (int, float, bool):
                continue

            # Attempt to cast primitive type values to ensure type correctness.
            # NOTE: this is a plain constructor call, so int(2.7) becomes 2 and
            # bool("False") becomes True; no error is raised for such inputs.
            target_type = cast(type, annotation)
            try:
                value = target_type(raw_value)
            except Exception as e:
                # Report the field *name*; interpolating the Field object would
                # dump its entire repr into the message.
                raise ValueError(
                    f"Could not convert the value of the field '{field_name}' to "
                    f"the expected type {target_type}."
                ) from e

            # The dataclass is frozen, so the converted value has to be written
            # back via object.__setattr__.
            object.__setattr__(self, field_name, value)

        # Run additional validations specified in subclasses
        self.validate()


@dataclass(frozen=True)
class DirectPosteriorParameters(PosteriorParameters):
    """
    Configuration values used to initialize a DirectPosterior.

    Fields:
        max_sampling_batch_size: Number of samples drawn from the proposal in
            each sampling iteration.
        enable_transform: Whether parameters are transformed to unconstrained
            space during MAP optimization. When False, an identity transform
            will be returned for `theta_transform`.
    """

    max_sampling_batch_size: int = 10_000
    enable_transform: bool = True

    def validate(self):
        """Ensure `max_sampling_batch_size` is a positive integer."""
        if is_positive_int(self.max_sampling_batch_size):
            return
        raise ValueError("max_sampling_batch_size must be greater than 0.")
@dataclass(frozen=True)
class FilteredDirectPosteriorParameters(PosteriorParameters):
    """Configuration values used to initialize a `FilteredDirectPosterior`.

    Fields:
        max_sampling_batch_size: Number of samples drawn from the proposal in
            each sampling iteration.
        enable_transform: Whether parameters are transformed to unconstrained
            space during MAP optimization. When False, an identity transform
            will be returned for `theta_transform`.
        filter_size: Number of context simulations retained after filtering.
            Must be greater than 1.
        filter_type: Filtering strategy. Either `"knn"`, `"first"`, or a
            callable returning context indices.
    """

    max_sampling_batch_size: int = 10_000
    enable_transform: bool = True
    filter_size: int = 2048
    # A Union annotation is not checked by the base-class Literal enforcement,
    # so the allowed string values are verified in `validate` instead.
    filter_type: Union[Literal["knn", "first"], Callable] = "knn"

    def validate(self):
        """Check the batch size, the filter size and the filter strategy."""
        if not is_positive_int(self.max_sampling_batch_size):
            raise ValueError("max_sampling_batch_size must be greater than 0.")

        # filter_size > 1 is expressed as "filter_size - 1 is a positive int".
        if not is_positive_int(self.filter_size - 1):
            raise ValueError("filter_size must be greater than 1.")

        strategy = self.filter_type
        is_named_strategy = isinstance(strategy, str) and strategy in {
            "knn",
            "first",
        }
        if is_named_strategy or callable(strategy):
            return
        raise ValueError(
            "filter_type must be one of ['knn', 'first'] or a callable."
        )
@dataclass(frozen=True)
class ImportanceSamplingPosteriorParameters(PosteriorParameters):
    """
    Configuration values used to initialize an ImportanceSamplingPosterior.

    Fields:
        theta_transform: Optional transformation applied to parameters. It is
            only used when calling `.map()`.
        method: Either of [`sir`|`importance`]. Determines the behavior of the
            `.sample()` method. With `sir`, approximate posterior samples are
            generated with sampling importance resampling (SIR). With
            `importance`, `.sample()` returns a tuple of samples and their
            corresponding importance weights.
        oversampling_factor: Number of proposed samples from which only one is
            selected based on its importance weight.
        max_sampling_batch_size: Number of samples drawn from the proposal in
            each sampling iteration.
    """

    theta_transform: Optional[TorchTransform] = None
    method: Literal["sir", "importance"] = "sir"
    oversampling_factor: int = 32
    max_sampling_batch_size: int = 10_000

    def validate(self):
        """Check the transform type and the two positive-integer settings."""
        transform = self.theta_transform
        if transform is not None and not isinstance(transform, TorchTransform):
            raise TypeError(
                "theta_transform must be either None or of type TorchTransform"
            )

        if not is_positive_int(self.oversampling_factor):
            raise ValueError("oversampling_factor must be greater than 0.")

        if not is_positive_int(self.max_sampling_batch_size):
            raise ValueError("max_sampling_batch_size must be greater than 0.")
@dataclass(frozen=True)
class MCMCPosteriorParameters(PosteriorParameters):
    """
    Configuration values used to initialize an MCMCPosterior.

    Fields:
        method: Method used for MCMC sampling, one of `slice_np`,
            `slice_np_vectorized`, `hmc_pyro`, `nuts_pyro`, `slice_pymc`,
            `hmc_pymc`, `nuts_pymc`. `slice_np` is a custom numpy
            implementation of slice sampling. `slice_np_vectorized` is
            identical to `slice_np`, but if `num_chains>1`, the chains are
            vectorized for `slice_np_vectorized` whereas they are run
            sequentially for `slice_np`. The samplers ending on `_pyro` use
            Pyro, and likewise the samplers ending on `_pymc` use PyMC.
        thin: The thinning factor for the chain. The dataclass default is -1;
            `validate` accepts -1 or an integer from 1 to 10.
            NOTE(review): earlier documentation described the default as 1 (no
            thinning); presumably -1 is resolved by the consumer - confirm.
        warmup_steps: The initial number of samples to discard.
        num_chains: The number of chains. Should generally be at most
            `num_workers - 1`.
        init_strategy: The initialisation strategy for chains; `proposal` will
            draw init locations from `proposal`, whereas `sir` will use
            Sequential-Importance-Resampling (SIR). SIR initially samples
            `init_strategy_num_candidates` from the `proposal`, evaluates all
            of them under the `potential_fn` and `proposal`, and then
            resamples the initial locations with weights proportional to
            `exp(potential_fn - proposal.log_prob)`. `resample` is the same as
            `sir` but uses `exp(potential_fn)` as weights.
        init_strategy_parameters: Dictionary of keyword arguments passed to
            the init strategy, e.g., for `init_strategy=sir` this could be
            `num_candidate_samples`, i.e., the number of candidates to find
            init locations (internal default is `1000`), or `device`.
        num_workers: Number of cpu cores used to parallelize mcmc.
        mp_context: Multiprocessing start method, either `"fork"` or `"spawn"`
            (default), used by Pyro and PyMC samplers. `"fork"` can be
            significantly faster than `"spawn"` but is only supported on
            POSIX-based systems (e.g. Linux and macOS, not Windows).
    """

    method: Literal[
        "slice_np",
        "slice_np_vectorized",
        "hmc_pyro",
        "nuts_pyro",
        "slice_pymc",
        "hmc_pymc",
        "nuts_pymc",
    ] = "slice_np_vectorized"
    thin: int = -1
    warmup_steps: int = 200
    num_chains: int = 20
    init_strategy: Literal["proposal", "sir", "resample"] = "resample"
    init_strategy_parameters: Optional[Dict[str, Any]] = None
    num_workers: int = 1
    mp_context: Literal["fork", "spawn"] = "spawn"

    def validate(self):
        """Check the init-strategy kwargs, thinning, warmup, chains and workers."""
        strategy_kwargs = self.init_strategy_parameters
        if strategy_kwargs is not None and not isinstance(strategy_kwargs, Dict):
            raise TypeError(
                "init_strategy_parameters must be either None or of type Dict"
            )

        # -1 is accepted in addition to the 1..10 range.
        thin_is_valid = self.thin == -1 or 1 <= self.thin <= 10
        if not thin_is_valid:
            raise ValueError("thin must be a value between 10 to 1, or -1.")

        if not is_nonnegative_int(self.warmup_steps):
            raise ValueError("warmup_steps must be greater than or equal to 0.")

        if not is_positive_int(self.num_chains):
            raise ValueError("num_chains must be greater than 0.")

        if not is_positive_int(self.num_workers):
            raise ValueError("num_workers must be greater than 0.")
@dataclass(frozen=True)
class RejectionPosteriorParameters(PosteriorParameters):
    """
    Configuration values used to initialize a RejectionPosterior.

    Fields:
        max_sampling_batch_size: Number of samples drawn from the proposal in
            each sampling iteration.
        num_samples_to_find_max: The number of samples that are used to find
            the maximum of the `potential_fn / proposal` ratio.
        num_iter_to_find_max: The number of gradient ascent iterations to find
            the maximum of the `potential_fn / proposal` ratio.
        m: Multiplier to the `potential_fn / proposal` ratio.
    """

    max_sampling_batch_size: int = 10_000
    num_samples_to_find_max: int = 10_000
    num_iter_to_find_max: int = 100
    m: float = 1.2

    def validate(self):
        """Run each field's range check in order; raise on the first failure."""
        # (predicate, value, message) triples; predicates are only called
        # inside the loop, so a failing check stops before later ones run.
        ordered_checks = (
            (
                is_positive_int,
                self.max_sampling_batch_size,
                "max_sampling_batch_size must be greater than 0.",
            ),
            (
                is_positive_int,
                self.num_samples_to_find_max,
                "num_samples_to_find_max must be greater than 0.",
            ),
            (
                is_nonnegative_int,
                self.num_iter_to_find_max,
                "num_iter_to_find_max must be greater than or equal to 0.",
            ),
            (
                is_positive_float,
                self.m,
                "m must be greater than 0.",
            ),
        )
        for predicate, value, message in ordered_checks:
            if not predicate(value):
                raise ValueError(message)
@dataclass(frozen=True)
class VectorFieldPosteriorParameters(PosteriorParameters):
    """
    Configuration values used to initialize a VectorFieldPosterior.

    Fields:
        max_sampling_batch_size: Number of samples drawn from the proposal in
            each sampling iteration.
        enable_transform: Whether parameters are transformed to unconstrained
            space during MAP optimization. When False, an identity transform
            will be returned for `theta_transform`.
            NOTE(review): earlier documentation states that True is not
            supported yet, although True is the default here - confirm.
        iid_method: Which method to use for computing the score in the iid
            setting. Currently supported: "fnpe", "gauss", "auto_gauss",
            "jac_gauss".
        iid_params: Parameters for the iid method, for arguments see
            `IIDScoreFunction`.
        neural_ode_backend: The backend to use for the neural ODE. Currently,
            only "zuko" is supported.
        neural_ode_kwargs: Additional keyword arguments for the neural ODE.
    """

    max_sampling_batch_size: int = 10_000
    enable_transform: bool = True

    # The fields below are passed from VectorfieldPosterior as keyword
    # arguments to the VectorFieldBasedPotential __init__ method.
    iid_method: Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"] = "auto_gauss"
    iid_params: Optional[Dict[str, Any]] = None
    neural_ode_backend: Literal["zuko"] = "zuko"
    neural_ode_kwargs: Optional[Dict[str, Any]] = None

    def validate(self):
        """Check both optional dict fields and the sampling batch size."""
        iid_kwargs = self.iid_params
        if iid_kwargs is not None and not isinstance(iid_kwargs, Dict):
            raise TypeError("iid_params must be either None or of type Dict")

        ode_kwargs = self.neural_ode_kwargs
        if ode_kwargs is not None and not isinstance(ode_kwargs, Dict):
            raise TypeError("neural_ode_kwargs must be either None or of type Dict")

        if is_positive_int(self.max_sampling_batch_size):
            return
        raise ValueError("max_sampling_batch_size must be greater than 0.")
@dataclass(frozen=True)
class VIPosteriorParameters(PosteriorParameters):
    """
    Configuration values for VIPosterior, covering single-x and amortized VI.

    Fields:
        q: Variational distribution. Either a string specifying the flow type
            [maf, nsf, naf, unaf, nice, sospf, gf, gaussian, gaussian_diag], a
            `Distribution`, a `VIPosterior` object, or a `Callable` builder
            function. For amortized VI, only string flow types are supported.
            If q is already a `VIPosterior`, arguments are copied from it
            (relevant for multi-round training).
            Note: For 1D problems, prefer "gf" (mixture of Gaussians) or
            "gaussian" as autoregressive flows may be unstable.
        vi_method: Variational method for fitting q to the posterior. Options:
            [rKL, fKL, IW, alpha]. Some are "mode seeking" (rKL, alpha > 1)
            and some are "mass covering" (fKL, IW, alpha < 1). Currently only
            used for single-x VI; amortized VI uses ELBO (rKL).
        num_transforms: Number of transforms in the normalizing flow. Used for
            both single-x VI (via set_q/train) and amortized VI (via
            train_amortized).
        hidden_features: Hidden layer size in the flow networks. Used for both
            single-x VI and amortized VI.
        z_score_theta: Method for z-scoring θ (the parameters being sampled).
            One of "none", "independent", "structured". Used for both
            single-x VI and amortized VI. Use "structured" for parameters with
            correlations.
        z_score_x: Method for z-scoring x (the conditioning observation). One
            of "none", "independent", "structured". Only used for amortized VI
            (train_amortized). Use "structured" for structured data like
            images.

    Note:
        For custom distributions that lack `parameters()` and `modules()`
        methods, pass these via
        `VIPosterior.set_q(q, parameters=..., modules=...)` instead.
    """

    # A Union annotation is not checked by the base-class Literal enforcement,
    # so the allowed string values of `q` are verified in `validate` instead.
    q: Union[
        Literal[
            "maf",
            "nsf",
            "naf",
            "unaf",
            "nice",
            "sospf",
            "gf",
            "gaussian",
            "gaussian_diag",
        ],
        Distribution,
        "VIPosterior",
        Callable,
    ] = "maf"
    vi_method: Literal["rKL", "fKL", "IW", "alpha"] = "rKL"
    num_transforms: int = 5
    hidden_features: int = 50
    z_score_theta: Literal["none", "independent", "structured"] = "independent"
    z_score_x: Literal["none", "independent", "structured"] = "independent"

    def validate(self):
        """Check `q` and the lower bounds of the two flow-size fields."""
        valid_q = {
            "nsf",
            "maf",
            "naf",
            "unaf",
            "nice",
            "sospf",
            "gf",
            "gaussian",
            "gaussian_diag",
        }

        variational_family = self.q
        if isinstance(variational_family, str):
            # Strings must name one of the known flow types.
            if variational_family not in valid_q:
                raise ValueError(f"If `q` is a string, it must be one of {valid_q}")
        elif not isinstance(variational_family, (Distribution, VIPosterior, Callable)):
            raise TypeError(
                "q must be either of type Distribution, VIPosterior, or Callable"
            )

        if self.num_transforms < 1:
            raise ValueError(f"num_transforms must be >= 1, got {self.num_transforms}")

        if self.hidden_features < 1:
            raise ValueError(
                f"hidden_features must be >= 1, got {self.hidden_features}"
            )