Source code for sbi.inference.posteriors.mcmc_posterior

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
import inspect
import warnings
from copy import deepcopy
from functools import partial
from math import ceil
from typing import Any, Callable, Dict, Literal, Optional, Union
from warnings import warn

import torch
import torch.distributions.transforms as torch_tf
from joblib import Parallel, delayed
from numpy import ndarray
from torch import Tensor
from torch import multiprocessing as mp
from tqdm.auto import tqdm

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.neural_nets.estimators.shape_handling import reshape_to_batch_event
from sbi.samplers.mcmc import (
    IterateParameters,
    SliceSamplerSerial,
    SliceSamplerVectorized,
    proposal_init,
    resample_given_potential_fn,
    sir_init,
)
from sbi.sbi_types import Shape, TorchTransform
from sbi.utils import mcmc_transform
from sbi.utils.potentialutils import pyro_potential_wrapper, transformed_potential
from sbi.utils.torchutils import ensure_theta_batched, tensor2numpy


[docs] class MCMCPosterior(NeuralPosterior): r"""Provides MCMC to sample from the posterior. SNLE or SNRE train neural networks to approximate the likelihood(-ratios). `MCMCPosterior` allows to sample from the posterior with MCMC. """ def __init__( self, potential_fn: Union[Callable, BasePotential], proposal: Any, theta_transform: Optional[TorchTransform] = None, method: Literal[ "slice_np", "slice_np_vectorized", "hmc_pyro", "nuts_pyro", "slice_pymc", "hmc_pymc", "nuts_pymc", ] = "slice_np_vectorized", thin: int = -1, warmup_steps: int = 200, num_chains: int = 20, init_strategy: Literal["proposal", "sir", "resample"] = "resample", init_strategy_parameters: Optional[Dict[str, Any]] = None, init_strategy_num_candidates: Optional[int] = None, num_workers: int = 1, mp_context: Literal["fork", "spawn"] = "spawn", device: Optional[Union[str, torch.device]] = None, x_shape: Optional[torch.Size] = None, ): """ Args: potential_fn: The potential function from which to draw samples. Must be a `BasePotential` or a `Callable` which takes `theta` and `x_o` as inputs. proposal: Proposal distribution that is used to initialize the MCMC chain. theta_transform: Transformation that will be applied during sampling. Allows to perform MCMC in unconstrained space. method: Method used for MCMC sampling, one of `slice_np`, `slice_np_vectorized`, `hmc_pyro`, `nuts_pyro`, `slice_pymc`, `hmc_pymc`, `nuts_pymc`. `slice_np` is a custom numpy implementation of slice sampling. `slice_np_vectorized` is identical to `slice_np`, but if `num_chains>1`, the chains are vectorized for `slice_np_vectorized` whereas they are run sequentially for `slice_np`. The samplers ending on `_pyro` are using Pyro, and likewise the samplers ending on `_pymc` are using PyMC. thin: The thinning factor for the chain, default 1 (no thinning). warmup_steps: The initial number of samples to discard. num_chains: The number of chains. Should generally be at most `num_workers - 1`. 
init_strategy: The initialisation strategy for chains; `proposal` will draw init locations from `proposal`, whereas `sir` will use Sequential- Importance-Resampling (SIR). SIR initially samples `init_strategy_num_candidates` from the `proposal`, evaluates all of them under the `potential_fn` and `proposal`, and then resamples the initial locations with weights proportional to `exp(potential_fn - proposal.log_prob`. `resample` is the same as `sir` but uses `exp(potential_fn)` as weights. init_strategy_parameters: Dictionary of keyword arguments passed to the init strategy, e.g., for `init_strategy=sir` this could be `num_candidate_samples`, i.e., the number of candidates to find init locations (internal default is `1000`), or `device`. init_strategy_num_candidates: Number of candidates to find init locations in `init_strategy=sir` (deprecated, use init_strategy_parameters instead). num_workers: number of cpu cores used to parallelize mcmc mp_context: Multiprocessing start method, either `"fork"` or `"spawn"` (default), used by Pyro and PyMC samplers. `"fork"` can be significantly faster than `"spawn"` but is only supported on POSIX-based systems (e.g. Linux and macOS, not Windows). device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None, `potential_fn.device` is used. x_shape: Deprecated, should not be passed. 
""" if method == "slice": warn( "The Pyro-based slice sampler is deprecated, and the method `slice` " "has been changed to `slice_np`, i.e., the custom " "numpy-based slice sampler.", DeprecationWarning, stacklevel=2, ) method = "slice_np" thin = _process_thin_default(thin) super().__init__( potential_fn, theta_transform=theta_transform, device=device, x_shape=x_shape, ) self.proposal = proposal self.method = method self.thin = thin self.warmup_steps = warmup_steps self.num_chains = num_chains self.init_strategy = init_strategy self.init_strategy_parameters = init_strategy_parameters or {} self.num_workers = num_workers self.mp_context = mp_context self._posterior_sampler = None # Hardcode parameter name to reduce clutter kwargs. self.param_name = "theta" self.x_shape = x_shape if init_strategy_num_candidates is not None: warn( "Passing `init_strategy_num_candidates` is deprecated as of sbi " "v0.19.0. Instead, use e.g., `init_strategy_parameters " f"={'num_candidate_samples': 1000}`", stacklevel=2, ) self.init_strategy_parameters["num_candidate_samples"] = ( init_strategy_num_candidates ) self.potential_ = self._prepare_potential(method) self._purpose = ( "It provides MCMC to .sample() from the posterior and " "can evaluate the _unnormalized_ posterior density with .log_prob()." )
[docs] def to(self, device: Union[str, torch.device]) -> None: """Moves potential_fn, proposal, x_o and theta_transform to the specified device. Reinstantiates the posterior and resets the default x_o. Args: device: Device to move the posterior to. """ self.device = device self.potential_fn.to(device) # type: ignore self.proposal.to(device) x_o = None if hasattr(self, "_x") and (self._x is not None): x_o = self._x.to(device) self.theta_transform = mcmc_transform(self.proposal, device=device) super().__init__( self.potential_fn, theta_transform=self.theta_transform, device=device, x_shape=self.x_shape, ) # super().__init__ erases the self._x, so we need to set it again if x_o is not None: self.set_default_x(x_o) self.potential_ = self._prepare_potential(self.method)
@property def mcmc_method(self) -> str: """Returns MCMC method.""" return self._mcmc_method @mcmc_method.setter def mcmc_method(self, method: str) -> None: """See `set_mcmc_method`.""" self.set_mcmc_method(method) @property def posterior_sampler(self): """Returns sampler created by `sample`.""" return self._posterior_sampler
[docs] def set_mcmc_method(self, method: str) -> "NeuralPosterior": """Sets sampling method to for MCMC and returns `NeuralPosterior`. Args: method: Method to use. Returns: `NeuralPosterior` for chainable calls. """ self._mcmc_method = method return self
[docs] def log_prob( self, theta: Tensor, x: Optional[Tensor] = None, track_gradients: bool = False ) -> Tensor: r"""Returns the log-probability of theta under the posterior. Args: theta: Parameters $\theta$. track_gradients: Whether the returned tensor supports tracking gradients. This can be helpful for e.g. sensitivity analysis, but increases memory consumption. Returns: `len($\theta$)`-shaped log-probability. """ warn( "`.log_prob()` is deprecated for methods that can only evaluate the " "log-probability up to a normalizing constant. Use `.potential()` instead.", stacklevel=2, ) warn("The log-probability is unnormalized!", stacklevel=2) x = self._x_else_default_x(x) self.potential_fn.set_x(x, x_is_iid=True) theta = ensure_theta_batched(torch.as_tensor(theta)) return self.potential_fn( theta.to(self._device), track_gradients=track_gradients )
[docs] def sample( self, sample_shape: Shape = torch.Size(), x: Optional[Tensor] = None, method: Optional[str] = None, thin: Optional[int] = None, warmup_steps: Optional[int] = None, num_chains: Optional[int] = None, init_strategy: Optional[str] = None, init_strategy_parameters: Optional[Dict[str, Any]] = None, num_workers: Optional[int] = None, mp_context: Optional[str] = None, show_progress_bars: bool = True, ) -> Tensor: r"""Draw samples from the approximate posterior distribution $p(\theta|x)$. Args: sample_shape: Desired shape of samples that are drawn from posterior. If sample_shape is multidimensional we simply draw `sample_shape.numel()` samples and then reshape into the desired shape. x: Conditioning observation $x_o$. If not provided, uses the default `x` set via `.set_default_x()`. method: MCMC method to use. One of `slice_np`, `slice_np_vectorized`, `hmc_pyro`, `nuts_pyro`, `slice_pymc`, `hmc_pymc`, `nuts_pymc`. If not provided, uses the method specified at initialization. thin: Thinning factor for the chain. If not provided, uses the value specified at initialization. warmup_steps: Number of warmup steps to discard. If not provided, uses the value specified at initialization. num_chains: Number of MCMC chains to run. If not provided, uses the value specified at initialization. init_strategy: Initialization strategy for chains (`proposal`, `sir`, or `resample`). If not provided, uses the value specified at initialization. init_strategy_parameters: Parameters for the initialization strategy. If not provided, uses the value specified at initialization. num_workers: Number of CPU cores for parallelization. If not provided, uses the value specified at initialization. mp_context: Multiprocessing context (`fork` or `spawn`). If not provided, uses the value specified at initialization. show_progress_bars: Whether to show sampling progress monitor. Returns: Samples from posterior. 
""" x = self._x_else_default_x(x) self.potential_fn.set_x(x, x_is_iid=True) # Replace arguments that were not passed with their default. method = self.method if method is None else method thin = self.thin if thin is None else thin warmup_steps = self.warmup_steps if warmup_steps is None else warmup_steps num_chains = self.num_chains if num_chains is None else num_chains init_strategy = self.init_strategy if init_strategy is None else init_strategy num_workers = self.num_workers if num_workers is None else num_workers mp_context = self.mp_context if mp_context is None else mp_context init_strategy_parameters = ( self.init_strategy_parameters if init_strategy_parameters is None else init_strategy_parameters ) self.potential_ = self._prepare_potential(method) # type: ignore initial_params = self._get_initial_params( init_strategy, # type: ignore num_chains, # type: ignore num_workers, show_progress_bars, **init_strategy_parameters, ) num_samples = torch.Size(sample_shape).numel() track_gradients = method in ("hmc_pyro", "nuts_pyro", "hmc_pymc", "nuts_pymc") with torch.set_grad_enabled(track_gradients): if method in ("slice_np", "slice_np_vectorized"): transformed_samples = self._slice_np_mcmc( num_samples=num_samples, potential_function=self.potential_, initial_params=initial_params, thin=thin, # type: ignore warmup_steps=warmup_steps, # type: ignore vectorized=(method == "slice_np_vectorized"), interchangeable_chains=True, num_workers=num_workers, show_progress_bars=show_progress_bars, ) elif method in ("hmc_pyro", "nuts_pyro"): transformed_samples = self._pyro_mcmc( num_samples=num_samples, potential_function=self.potential_, initial_params=initial_params, mcmc_method=method, # type: ignore thin=thin, # type: ignore warmup_steps=warmup_steps, # type: ignore num_chains=num_chains, show_progress_bars=show_progress_bars, mp_context=mp_context, ) elif method in ("hmc_pymc", "nuts_pymc", "slice_pymc"): transformed_samples = self._pymc_mcmc( num_samples=num_samples, 
potential_function=self.potential_, initial_params=initial_params, mcmc_method=method, # type: ignore thin=thin, # type: ignore warmup_steps=warmup_steps, # type: ignore num_chains=num_chains, show_progress_bars=show_progress_bars, mp_context=mp_context, ) else: raise NameError(f"The sampling method {method} is not implemented!") samples = self.theta_transform.inv(transformed_samples) # NOTE: Currently MCMCPosteriors will require a single dimension for the # parameter dimension. With recent ConditionalDensity(Ratio) estimators, we # can have multiple dimensions for the parameter dimension. samples = samples.reshape((*sample_shape, -1)) # type: ignore return samples
[docs] def sample_batched( self, sample_shape: Shape, x: Tensor, method: Optional[str] = None, thin: Optional[int] = None, warmup_steps: Optional[int] = None, num_chains: Optional[int] = None, init_strategy: Optional[str] = None, init_strategy_parameters: Optional[Dict[str, Any]] = None, num_workers: Optional[int] = None, mp_context: Optional[str] = None, show_progress_bars: bool = True, ) -> Tensor: r"""Draw samples from the posteriors for a batch of different xs. Given a batch of observations `[x_1, ..., x_B]`, this method samples from posteriors $p(\theta|x_1), \ldots, p(\theta|x_B)$ in a vectorized manner. Check the `__init__()` method for a description of all arguments as well as their default values. Args: sample_shape: Desired shape of samples that are drawn from the posterior given every observation. x: A batch of observations, of shape `(batch_dim, event_shape_x)`. `batch_dim` corresponds to the number of observations to be drawn. method: Method used for MCMC sampling, e.g., "slice_np_vectorized". thin: The thinning factor for the chain, default 1 (no thinning). warmup_steps: The initial number of samples to discard. num_chains: The number of chains used for each `x` passed in the batch. init_strategy: The initialisation strategy for chains. init_strategy_parameters: Dictionary of keyword arguments passed to the init strategy. num_workers: number of cpu cores used to parallelize initial parameter generation and mcmc sampling. mp_context: Multiprocessing start method, either `"fork"` or `"spawn"` show_progress_bars: Whether to show sampling progress monitor. Returns: Samples from the posteriors of shape (*sample_shape, B, *input_shape) """ # Replace arguments that were not passed with their default. 
method = self.method if method is None else method thin = self.thin if thin is None else thin warmup_steps = self.warmup_steps if warmup_steps is None else warmup_steps num_chains = self.num_chains if num_chains is None else num_chains init_strategy = self.init_strategy if init_strategy is None else init_strategy num_workers = self.num_workers if num_workers is None else num_workers mp_context = self.mp_context if mp_context is None else mp_context init_strategy_parameters = ( self.init_strategy_parameters if init_strategy_parameters is None else init_strategy_parameters ) assert method == "slice_np_vectorized", ( "Batched sampling only supported for vectorized samplers!" ) # warn if num_chains is larger than num requested samples if num_chains > torch.Size(sample_shape).numel(): warnings.warn( "The passed number of MCMC chains is larger than the number of " f"requested samples: {num_chains} > {torch.Size(sample_shape).numel()}," f" resetting it to {torch.Size(sample_shape).numel()}.", stacklevel=2, ) num_chains = torch.Size(sample_shape).numel() # custom shape handling to make sure to match the batch size of x and theta # without unnecessary combinations. if len(x.shape) == 1: x = x.unsqueeze(0) batch_size = x.shape[0] x = reshape_to_batch_event(x, event_shape=x.shape[1:]) # For batched sampling, we want `num_chains` for each observation in the batch. # Here we repeat the observations ABC -> AAABBBCCC, so that the chains are # in the order of the observations. x_ = x.repeat_interleave(num_chains, dim=0) self.potential_fn.set_x(x_, x_is_iid=False) self.potential_ = self._prepare_potential(method) # type: ignore # For each observation in the batch, we have num_chains independent chains. num_chains_extended = batch_size * num_chains if num_chains_extended > 100: warnings.warn( "Note that for batched sampling, we use num_chains many chains " "for each x in the batch. 
With the given settings, this results " f"in a large number of chains ({num_chains_extended}), which can " "be slow and memory-intensive for vectorized MCMC. Consider " "reducing the number of chains or batch size.", stacklevel=2, ) init_strategy_parameters["num_return_samples"] = num_chains_extended initial_params = self._get_initial_params_batched( x, init_strategy, # type: ignore num_chains, # type: ignore num_workers, show_progress_bars, **init_strategy_parameters, ) # We need num_samples from each posterior in the batch num_samples = torch.Size(sample_shape).numel() * batch_size with torch.set_grad_enabled(False): transformed_samples = self._slice_np_mcmc( num_samples=num_samples, potential_function=self.potential_, initial_params=initial_params, thin=thin, # type: ignore warmup_steps=warmup_steps, # type: ignore vectorized=(method == "slice_np_vectorized"), interchangeable_chains=False, num_workers=num_workers, show_progress_bars=show_progress_bars, ) # (num_chains_extended, samples_per_chain, *input_shape) samples_per_chain: Tensor = self.theta_transform.inv(transformed_samples) # type: ignore dim_theta = samples_per_chain.shape[-1] # We need to collect samples for each x from the respective chains. # However, using samples.reshape(*sample_shape, batch_size, dim_theta) # does not combine the samples in the right order, since this mixes # samples that belong to different `x`. The following permute is a # workaround to reshape the samples in the right order. 
samples_per_x = samples_per_chain.reshape(( batch_size, # We are flattening the sample shape here using -1 because we might have # generated more samples than requested (more chains, or multiple of # chains not matching sample_shape) -1, dim_theta, )).permute(1, 0, -1) # Shape is now (-1, batch_size, dim_theta) # We can now select the number of requested samples samples = samples_per_x[: torch.Size(sample_shape).numel()] # and reshape into (*sample_shape, batch_size, dim_theta) samples = samples.reshape((*sample_shape, batch_size, dim_theta)) return samples
def _build_mcmc_init_fn( self, proposal: Any, potential_fn: Callable, transform: torch_tf.Transform, init_strategy: str, **kwargs, ) -> Callable: """Return function that, when called, creates an initial parameter set for MCMC. Args: proposal: Proposal distribution. potential_fn: Potential function that the candidate samples are weighted with. init_strategy: Specifies the initialization method. Either of [`proposal`|`sir`|`resample`|`latest_sample`]. kwargs: Passed on to init function. This way, init specific keywords can be set through `mcmc_parameters`. Unused arguments will be absorbed by the intitialization method. Returns: Initialization function. """ if init_strategy == "proposal" or init_strategy == "prior": if init_strategy == "prior": warn( "You set `init_strategy=prior`. As of sbi v0.18.0, this is " "deprecated and it will be removed in a future release. Use " "`init_strategy=proposal` instead.", stacklevel=2, ) return lambda: proposal_init(proposal, transform=transform, **kwargs) elif init_strategy == "sir": warn( "As of sbi v0.19.0, the behavior of the SIR initialization for MCMC " "has changed. If you wish to restore the behavior of sbi v0.18.0, set " "`init_strategy='resample'.`", stacklevel=2, ) return lambda: sir_init( proposal, potential_fn, transform=transform, **kwargs ) elif init_strategy == "resample": return lambda: resample_given_potential_fn( proposal, potential_fn, transform=transform, **kwargs ) elif init_strategy == "latest_sample": latest_sample = IterateParameters(self._mcmc_init_params, **kwargs) return latest_sample else: raise NotImplementedError def _get_initial_params( self, init_strategy: str, num_chains: int, num_workers: int, show_progress_bars: bool, **kwargs, ) -> Tensor: """Return initial parameters for MCMC obtained with given init strategy. Parallelizes across CPU cores only for resample and SIR. Args: init_strategy: Specifies the initialization method. Either of [`proposal`|`sir`|`resample`|`latest_sample`]. 
num_chains: number of MCMC chains, generates initial params for each num_workers: number of CPU cores for parallization show_progress_bars: whether to show progress bars for SIR init kwargs: Passed on to `_build_mcmc_init_fn`. Returns: Tensor: initial parameters, one for each chain """ # Build init function init_fn = self._build_mcmc_init_fn( self.proposal, self.potential_fn, transform=self.theta_transform, init_strategy=init_strategy, # type: ignore **kwargs, ) # Parallelize inits for resampling only. if num_workers > 1 and (init_strategy == "resample" or init_strategy == "sir"): def seeded_init_fn(seed): torch.manual_seed(seed) return init_fn() seeds = torch.randint(high=2**31, size=(num_chains,)) # Generate initial params parallelized over num_workers. initial_params = list( tqdm( Parallel(return_as="generator", n_jobs=num_workers)( delayed(seeded_init_fn)(seed) for seed in seeds ), total=len(seeds), desc=f"Generating {num_chains} MCMC inits via {init_strategy} " "strategy", disable=not show_progress_bars, ) ) initial_params = torch.cat(initial_params) # type: ignore else: initial_params = torch.cat( [ init_fn() for _ in tqdm( range(num_chains), desc=f"Generating {num_chains} MCMC inits via {init_strategy} " "strategy", disable=not show_progress_bars, ) ] # type: ignore ) assert initial_params.shape[0] == num_chains, "Initial params shape mismatch." return initial_params def _get_initial_params_batched( self, x: torch.Tensor, init_strategy: str, num_chains_per_x: int, num_workers: int, show_progress_bars: bool, **kwargs, ) -> Tensor: """Return initial parameters for MCMC for a batch of `x`, obtained with given init strategy. Parallelizes across CPU cores only for resample and SIR. Args: x: Batch of observations to create different initial parameters for. init_strategy: Specifies the initialization method. Either of [`proposal`|`sir`|`resample`|`latest_sample`]. 
num_chains_per_x: number of MCMC chains for each x, generates initial params for each x num_workers: number of CPU cores for parallization show_progress_bars: whether to show progress bars for SIR init kwargs: Passed on to `_build_mcmc_init_fn`. Returns: Tensor: initial parameters, one for each chain """ potential_ = deepcopy(self.potential_fn) initial_params = [] init_fn = self._build_mcmc_init_fn( self.proposal, potential_fn=potential_, transform=self.theta_transform, init_strategy=init_strategy, # type: ignore **kwargs, ) for xi in x: # Build init function potential_.set_x(xi) # Parallelize inits for resampling or sir. if num_workers > 1 and ( init_strategy == "resample" or init_strategy == "sir" ): def seeded_init_fn(seed): torch.manual_seed(seed) return init_fn() seeds = torch.randint(high=2**31, size=(num_chains_per_x,)) # Generate initial params parallelized over num_workers. initial_params = initial_params + list( tqdm( Parallel(return_as="generator", n_jobs=num_workers)( delayed(seeded_init_fn)(seed) for seed in seeds ), total=len(seeds), desc=f"""Generating {num_chains_per_x} MCMC inits with {num_workers} workers.""", disable=not show_progress_bars, ) ) else: initial_params = initial_params + [ init_fn() for _ in range(num_chains_per_x) ] # type: ignore initial_params = torch.cat(initial_params) return initial_params def _slice_np_mcmc( self, num_samples: int, potential_function: Callable, initial_params: Tensor, thin: int, warmup_steps: int, vectorized: bool = False, interchangeable_chains=True, num_workers: int = 1, init_width: Union[float, ndarray] = 0.01, show_progress_bars: bool = True, ) -> Tensor: """Custom implementation of slice sampling using Numpy. Args: num_samples: Desired number of samples. potential_function: A callable **class**. initial_params: Initial parameters for MCMC chain. thin: Thinning (subsampling) factor, default 1 (no thinning). warmup_steps: Initial number of samples to discard. 
vectorized: Whether to use a vectorized implementation of the `SliceSampler`. interchangeable_chains: Whether chains are interchangeable, i.e., whether we can mix samples between chains. num_workers: Number of CPU cores to use. init_width: Inital width of brackets. show_progress_bars: Whether to show a progressbar during sampling; can only be turned off for vectorized sampler. Returns: Tensor of shape (num_samples, shape_of_single_theta). """ num_chains, dim_samples = initial_params.shape if not vectorized: SliceSamplerMultiChain = SliceSamplerSerial else: SliceSamplerMultiChain = SliceSamplerVectorized def multi_obs_potential(params): # Params are of shape (num_chains * num_obs, event). all_potentials = potential_function(params) # Shape: (num_chains, num_obs) return all_potentials.flatten() posterior_sampler = SliceSamplerMultiChain( init_params=tensor2numpy(initial_params), log_prob_fn=multi_obs_potential, num_chains=num_chains, thin=thin, verbose=show_progress_bars, num_workers=num_workers, init_width=init_width, ) warmup_ = warmup_steps * thin num_samples_ = ceil((num_samples * thin) / num_chains) # Run mcmc including warmup samples = posterior_sampler.run(warmup_ + num_samples_) samples = samples[:, warmup_steps:, :] # discard warmup steps samples = torch.from_numpy(samples) # chains x samples x dim # Save posterior sampler. self._posterior_sampler = posterior_sampler # Save sample as potential next init (if init_strategy == 'latest_sample'). self._mcmc_init_params = samples[:, -1, :].reshape(num_chains, dim_samples) # Update: If chains are interchangeable, return concatenated samples. Otherwise # return samples per chain. if interchangeable_chains: # Collect samples from all chains. 
samples = samples.reshape(-1, dim_samples)[:num_samples] return samples.type(torch.float32).to(self._device) def _pyro_mcmc( self, num_samples: int, potential_function: Callable, initial_params: Tensor, mcmc_method: str = "nuts_pyro", thin: int = -1, warmup_steps: int = 200, num_chains: Optional[int] = 1, show_progress_bars: bool = True, mp_context: str = "spawn", ) -> Tensor: r"""Return samples obtained using Pyro's HMC or NUTS sampler. Args: num_samples: Desired number of samples. potential_function: A callable **class**. A class, but not a function, is picklable for Pyro MCMC to use it across chains in parallel, even when the potential function requires evaluating a neural network. initial_params: Initial parameters for MCMC chain. mcmc_method: Pyro MCMC method to use, either `"hmc_pyro"` or `"nuts_pyro"` (default). thin: Thinning (subsampling) factor, default 1 (no thinning). warmup_steps: Initial number of samples to discard. num_chains: Whether to sample in parallel. If None, use all but one CPU. show_progress_bars: Whether to show a progressbar during sampling. Returns: Tensor of shape (num_samples, shape_of_single_theta). """ try: from pyro.infer.mcmc import HMC, NUTS from pyro.infer.mcmc.api import MCMC except ImportError: raise ImportError( "pyro-ppl is required for Pyro-based MCMC. " "Install it with: pip install 'sbi[pyro]'" ) from None thin = _process_thin_default(thin) num_chains = mp.cpu_count() - 1 if num_chains is None else num_chains kernels = dict(hmc_pyro=HMC, nuts_pyro=NUTS) sampler = MCMC( kernel=kernels[mcmc_method](potential_fn=potential_function), num_samples=ceil((thin * num_samples) / num_chains), warmup_steps=warmup_steps, initial_params={self.param_name: initial_params}, num_chains=num_chains, mp_context=mp_context, disable_progbar=not show_progress_bars, transforms={}, ) sampler.run() samples = next(iter(sampler.get_samples().values())).reshape( -1, initial_params.shape[1], # .shape[1] = dim of theta ) # Save posterior sampler. 
self._posterior_sampler = sampler samples = samples[::thin][:num_samples] return samples.detach() def _pymc_mcmc( self, num_samples: int, potential_function: Callable, initial_params: Tensor, mcmc_method: str = "nuts_pymc", thin: int = -1, warmup_steps: int = 200, num_chains: Optional[int] = 1, show_progress_bars: bool = True, mp_context: str = "spawn", ) -> Tensor: r"""Return samples obtained using PyMC's HMC, NUTS or slice samplers. Args: num_samples: Desired number of samples. potential_function: A callable **class**. A class, but not a function, is picklable for PyMC MCMC to use it across chains in parallel, even when the potential function requires evaluating a neural network. initial_params: Initial parameters for MCMC chain. mcmc_method: mcmc_method: Pyro MCMC method to use, either `"hmc_pymc"` or `"slice_pymc"`, or `"nuts_pymc"` (default). thin: Thinning (subsampling) factor, default 1 (no thinning). warmup_steps: Initial number of samples to discard. num_chains: Whether to sample in parallel. If None, use all but one CPU. show_progress_bars: Whether to show a progressbar during sampling. Returns: Tensor of shape (num_samples, shape_of_single_theta). """ try: from sbi.samplers.mcmc.pymc_wrapper import PyMCSampler except ImportError: raise ImportError( "pymc is required for PyMC-based MCMC. 
" "Install it with: pip install 'sbi[pymc]'" ) from None thin = _process_thin_default(thin) num_chains = mp.cpu_count() - 1 if num_chains is None else num_chains steps = dict(slice_pymc="slice", hmc_pymc="hmc", nuts_pymc="nuts") sampler = PyMCSampler( potential_fn=potential_function, step=steps[mcmc_method], initvals=tensor2numpy(initial_params), draws=ceil((thin * num_samples) / num_chains), tune=warmup_steps, chains=num_chains, mp_ctx=mp_context, progressbar=show_progress_bars, param_name=self.param_name, device=self._device, ) samples = sampler.run() samples = torch.from_numpy(samples).to(dtype=torch.float32, device=self._device) samples = samples.reshape(-1, initial_params.shape[1]) # Save posterior sampler. self._posterior_sampler = sampler samples = samples[::thin][:num_samples] return samples def _prepare_potential(self, method: str) -> Callable: """Combines potential and transform and takes care of gradients and pyro. Args: method: Which MCMC method to use. Returns: A potential function that is ready to be used in MCMC. """ if method in ("hmc_pyro", "nuts_pyro"): track_gradients = True pyro = True elif method in ("hmc_pymc", "nuts_pymc"): track_gradients = True pyro = False elif method in ("slice_np", "slice_np_vectorized", "slice_pymc"): track_gradients = False pyro = False else: if "hmc" in method or "nuts" in method: warn( "The kwargs 'hmc' and 'nuts' are deprecated. Use 'hmc_pyro', " "'nuts_pyro', 'hmc_pymc', or 'nuts_pymc' instead.", DeprecationWarning, stacklevel=2, ) raise NotImplementedError(f"MCMC method {method} is not implemented.") prepared_potential = partial( transformed_potential, potential_fn=self.potential_fn, theta_transform=self.theta_transform, device=self._device, track_gradients=track_gradients, ) if pyro: prepared_potential = partial( pyro_potential_wrapper, potential=prepared_potential ) return prepared_potential
[docs] def map( self, x: Optional[Tensor] = None, num_iter: int = 1_000, num_to_optimize: int = 100, learning_rate: float = 0.01, init_method: Union[str, Tensor] = "proposal", num_init_samples: int = 1_000, save_best_every: int = 10, show_progress_bars: bool = False, force_update: bool = False, ) -> Tensor: r"""Returns the maximum-a-posteriori estimate (MAP). The method can be interrupted (Ctrl-C) when the user sees that the log-probability converges. The best estimate will be saved in `self._map` and can be accessed with `self.map()`. The MAP is obtained by running gradient ascent from a given number of starting positions (samples from the posterior with the highest log-probability). After the optimization is done, we select the parameter set that has the highest log-probability after the optimization. Warning: The default values used by this function are not well-tested. They might require hand-tuning for the problem at hand. For developers: if the prior is a `BoxUniform`, we carry out the optimization in unbounded space and transform the result back into bounded space. Args: x: Deprecated - use `.set_default_x()` prior to `.map()`. num_iter: Number of optimization steps that the algorithm takes to find the MAP. learning_rate: Learning rate of the optimizer. init_method: How to select the starting parameters for the optimization. If it is a string, it can be either [`posterior`, `prior`], which samples the respective distribution `num_init_samples` times. If it is a tensor, the tensor will be used as init locations. num_init_samples: Draw this number of samples from the posterior and evaluate the log-probability of all of them. num_to_optimize: From the drawn `num_init_samples`, use the `num_to_optimize` with highest log-probability as the initial points for the optimization. save_best_every: The best log-probability is computed, saved in the `map`-attribute, and printed every `save_best_every`-th iteration. 
Computing the best log-probability creates a significant overhead (thus, the default is `10`.) show_progress_bars: Whether to show a progressbar during sampling from the posterior. force_update: Whether to re-calculate the MAP when x is unchanged and have a cached value. log_prob_kwargs: Will be empty for SNLE and SNRE. Will contain {'norm_posterior': True} for SNPE. Returns: The MAP estimate. """ return super().map( x=x, num_iter=num_iter, num_to_optimize=num_to_optimize, learning_rate=learning_rate, init_method=init_method, num_init_samples=num_init_samples, save_best_every=save_best_every, show_progress_bars=show_progress_bars, force_update=force_update, )
def __getstate__(self) -> Dict:
    """Return a picklable snapshot of this `MCMCPosterior`'s attributes.

    The snapshot is a shallow copy of the instance `__dict__` in which the
    `_posterior_sampler` entry is replaced by `None`, because the sampler object
    may not be picklable. The live instance itself is not modified.

    Returns:
        Dict: Shallow copy of the instance attributes with `_posterior_sampler`
        set to `None`.
    """
    # In a dict display a repeated key keeps its original position but takes the
    # later value, which matches the copy-then-assign semantics exactly.
    return {**self.__dict__, "_posterior_sampler": None}
def _process_thin_default(thin: int) -> int:
    """
    Map the `thin=-1` sentinel (the `MCMCPosterior` default) to a real thinning
    factor.

    No warning is emitted; the sentinel is silently replaced by `1`.

    Args:
        thin: Thinning (subsampling) factor. `1` disables thinning and `-1` is
            the "use the default" sentinel.

    Returns:
        The thinning factor to use: `1` if `thin == -1`, otherwise `thin`
        unchanged.
    """
    if thin == -1:
        thin = 1
    return thin


def _num_required_args(func: Callable) -> int:
    """
    Count the positional parameters of `func` that a caller must supply.

    A parameter is counted when it satisfies both of these conditions:

    (1) It can be passed positionally, i.e. it is positional-only or
        positional-or-keyword. `*args`, `**kwargs` and keyword-only parameters
        are excluded.
    (2) It has no default value.

    Args:
        func: A callable whose signature can be introspected with
            `inspect.signature`.

    Returns:
        Number of required positional arguments.

    Raises:
        ValueError, TypeError: Propagated from `inspect.signature` when no
            signature can be obtained for `func`.
    """
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    return sum(
        1
        for param in inspect.signature(func).parameters.values()
        if param.kind in positional_kinds and param.default is inspect.Parameter.empty
    )


def build_from_potential(
    potential_fn: Callable, prior: Any, x: Optional[Tensor] = None, **kwargs
) -> MCMCPosterior:
    """
    Build an `MCMCPosterior` from a user-defined potential function and a prior.

    The user-defined potential can be conditional, taking two required
    positional arguments `(theta, x)`. It can also be unconditional, taking a
    single required positional argument `(theta)`.

    Args:
        potential_fn: User-defined potential function taking `(theta)` or
            `(theta, x)` as required positional arguments.
        prior: Prior distribution. It is used to build the transformation to
            unconstrained space (via `mcmc_transform`), and it is passed to
            `MCMCPosterior` as the proposal used for chain initialization.
        x: Observation to condition on. Required if `potential_fn` is
            conditional; ignored (with a warning) if it is unconditional.
        **kwargs: Additional keyword arguments forwarded to `MCMCPosterior`,
            e.g. `method`, `num_chains`, `thin`, `warmup_steps`.

    Returns:
        The `MCMCPosterior` object, with its default `x` already set. For a
        conditional potential this is `x`; for an unconditional one it is a
        dummy `torch.zeros(1)`.

    Raises:
        AssertionError: If `potential_fn` does not take 1 or 2 required
            positional arguments, or if it is conditional and `x` is None.
            These checks are `assert`s and are skipped under `python -O`.
    """
    # Build the transformation to unrestricted space for sampling.
    transform = mcmc_transform(prior)

    # potential_fn must take 1 or 2 required arguments: (theta) or (theta, x).
    num_args = _num_required_args(potential_fn)
    assert 0 < num_args < 3, "potential_fn must take 1-2 required arguments"

    if num_args == 2:
        # Conditional potential: x is required here so it can be installed as the
        # default x. One could instead require the user to set x before sampling.
        assert x is not None, "x must be provided if potential_fn is conditional"
        # Forward **kwargs in this branch too. Previously they were silently
        # dropped here, so sampler settings were ignored for conditional potentials.
        posterior = MCMCPosterior(
            potential_fn, prior, theta_transform=transform, **kwargs
        )
        posterior.set_default_x(x)
        return posterior

    # Unconditional potential. Any x the caller passed is not used, so say so
    # accurately instead of claiming that x was not provided.
    if x is None:
        message = "x has not been provided. Using unconditional potential function."
    else:
        message = (
            "x was provided but potential_fn takes only theta; ignoring x and "
            "using unconditional potential function."
        )
    warn(message, UserWarning, stacklevel=2)

    # MCMCPosterior calls callable potentials with (theta, x_o); drop x here.
    def unconditional_potential_fn(theta, x):
        return potential_fn(theta)

    posterior = MCMCPosterior(
        unconditional_potential_fn, prior, theta_transform=transform, **kwargs
    )
    # Set a dummy default x; the wrapped potential never reads it.
    # NOTE(review): presumably required so that sampling works without an
    # explicit x; verify.
    posterior.set_default_x(torch.zeros(1))
    return posterior