Source code for sbi.utils.potentialutils

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Callable, Dict, Union

import numpy as np
import torch
import torch.distributions.transforms as torch_tf
from torch import Tensor

from sbi.utils.torchutils import ensure_theta_batched


def transformed_potential(
    theta: Union[Tensor, np.ndarray],
    potential_fn: Callable,
    theta_transform: torch_tf.Transform,
    device: str,
    track_gradients: bool = False,
) -> Tensor:
    r"""Return the potential of `theta` in transformed space.

    The incoming parameters live in transformed (i.e. unconstrained) space. They are
    mapped back to the untransformed space with the inverse of `theta_transform`,
    the potential is evaluated there, and the log-abs-determinant of the Jacobian
    of `theta_transform` is subtracted to account for the change of variables.

    In addition, this function takes care of moving the parameters to the correct
    device.

    Args:
        theta: Parameters $\theta$ in transformed space.
        potential_fn: Potential function. It is called as
            `potential_fn(theta, track_gradients=track_gradients)` with `theta` in
            untransformed space.
        theta_transform: Transformation applied before evaluating the `potential_fn`
        device: The device to which to move the parameters before evaluation.
        track_gradients: Whether to track the gradients of the `potential_fn`
            evaluation.

    Returns:
        The potential in transformed space, i.e. the value of `potential_fn` minus
        the log-abs-determinant of the Jacobian, located on `device`.
    """

    # Device is the same for net and prior.
    transformed_theta = ensure_theta_batched(
        torch.as_tensor(theta, dtype=torch.float32)
    ).to(device)

    # Transform `theta` from transformed (i.e. unconstrained) to untransformed
    # space. A separate name is used so the caller-supplied `theta` is not shadowed.
    untransformed_theta = theta_transform.inv(transformed_theta)  # type: ignore
    log_abs_det = theta_transform.log_abs_det_jacobian(
        untransformed_theta, transformed_theta
    )

    posterior_potential = potential_fn(
        untransformed_theta, track_gradients=track_gradients
    )

    # Both terms are moved to `device` explicitly because `potential_fn` and the
    # transform are free to return tensors located elsewhere.
    posterior_potential_transformed = posterior_potential.to(device) - log_abs_det.to(
        device
    )
    return posterior_potential_transformed
def pyro_potential_wrapper(theta: Dict[str, Tensor], potential: Callable) -> Tensor:
    r"""Evaluate pyro-based `theta` under the negative `potential`.

    Args:
        theta: Parameters $\theta$. Only the first value of the dictionary is used.
            The tensor's shape will be (1, shape_of_single_theta) if running a
            single chain or just (shape_of_single_theta) for multiple chains.
        potential: Potential which to evaluate. It is called with the parameter
            tensor as its only argument.

    Returns:
        The negative potential $-[\log r(x_o, \theta) + \log p(\theta)]$.

    Raises:
        StopIteration: If `theta` is an empty dictionary.
    """

    # Only the first entry of the dictionary is read; its key is irrelevant.
    theta_tensor = next(iter(theta.values()))

    # Note the minus to match the pyro potential function requirements.
    return -potential(theta_tensor)