# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
"""Various PyTorch utility functions."""
import os
import warnings
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Optional,
OrderedDict,
Tuple,
Union,
)
import numpy as np
import torch
from torch import Tensor, float32
from torch.distributions import Independent, Uniform
from torch.nn import Module
from sbi.sbi_types import Array, OneOrMore
from sbi.utils.typechecks import is_nonnegative_int, is_positive_int
def process_device(device: Union[str, torch.device]) -> str:
    """Validate the requested device and return it as a device string.

    ``"cpu"`` is returned unchanged. The alias ``"gpu"`` is resolved to the
    current CUDA device (``"cuda:<index>"``) if CUDA is available, otherwise to
    ``"mps:0"`` if MPS is available. Any other value is converted to a string
    (when it is a ``torch.device``) and validated with ``check_device``.

    Side effects of the ``"gpu"`` alias: with CUDA, the resolved device is made
    the current CUDA device; with MPS, the environment variable
    ``PYTORCH_ENABLE_MPS_FALLBACK`` is set to ``"1"`` and the default dtype is
    set to ``float32``.

    Args:
        device: Target torch device.

    Returns:
        The processed device string.

    Raises:
        RuntimeError: If ``"gpu"`` is requested but neither CUDA nor MPS is
            available, or (through ``check_device``) if the device is unusable.
    """
    if device == "cpu":
        return "cpu"

    if device != "gpu":
        # Custom device: normalize to a string and check that it is valid.
        if isinstance(device, torch.device):
            device = str(device)
        check_device(device)
        return device

    # The user just passed 'gpu': look for CUDA first, then MPS.
    if torch.cuda.is_available():
        device = f"cuda:{torch.cuda.current_device()}"
        check_device(device)
        torch.cuda.set_device(device)
        return device

    if torch.backends.mps.is_available():
        device = "mps:0"
        # MPS support is not implemented for a number of operations;
        # use CPU as fallback.
        os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
        # MPS framework does not support double precision.
        torch.set_default_dtype(torch.float32)
        check_device(device)
        return device

    raise RuntimeError(
        "Neither CUDA nor MPS is available. "
        "Please make sure to install a version of PyTorch that supports "
        "CUDA or MPS."
    )
def gpu_available() -> bool:
    """Return True if a CUDA or an MPS backend is available."""
    if torch.cuda.is_available():
        return True
    return torch.backends.mps.is_available()
def check_device(device: Union[str, torch.device]) -> None:
    """Check whether the device is valid by creating a small tensor on it.

    Args:
        device: Target torch device.

    Raises:
        RuntimeError: If ``torch.randn(1, device=device)`` fails with a
            ``RuntimeError`` or ``AssertionError``. The original exception is
            chained as the cause.
    """
    try:
        torch.randn(1, device=device)
    except (RuntimeError, AssertionError) as exc:
        # Build the message from adjacent literals so that no raw newlines or
        # source indentation end up in the text shown to the user.
        raise RuntimeError(
            f"Could not instantiate torch.randn(1, device={device}). Make sure "
            "the device is set up properly and that you are passing the "
            "corresponding device string. It should be something like 'cuda', "
            f"'cuda:0', or 'mps'. Error message: {exc}."
        ) from exc
def check_if_prior_on_device(
    device: Union[str, torch.device], prior: Optional[Any] = None
) -> None:
    """Check that samples drawn from the prior live on the training device.

    Draws one sample from the prior and compares its device with the device of
    a tensor created on `device`. If the prior is `None`, nothing is checked.

    Args:
        device: Target torch training device.
        prior: Object with a ``sample`` method that returns a torch `Tensor`.

    Raises:
        AssertionError: If the prior sample is not on the training device.
    """
    if prior is None:
        return

    prior_device = prior.sample((1,)).device
    training_device = torch.zeros(1, device=device).device
    # Raise explicitly instead of using `assert`, so that the check is not
    # removed when Python runs with optimizations enabled (-O).
    if prior_device != training_device:
        raise AssertionError(
            f"Prior device '{prior_device}' must match training device "
            f"'{training_device}'. When training on GPU make sure to "
            "pass a prior initialized on the GPU as well, e.g., "
            "use `.to(device)` for sbi priors or "
            "`prior = torch.distributions.Normal"
            "(torch.zeros(2, device='cuda'), scale=1.0)`."
        )
def infer_module_device(module: torch.nn.Module, fallback: str) -> str:
    """Return the device of the module's first parameter or, failing that, buffer.

    Args:
        module: The module to inspect.
        fallback: Device string returned (with a warning) if the module has
            neither parameters nor buffers.

    Returns:
        Device string, e.g. ``"cpu"`` or ``"cuda:0"``.
    """
    first_param = next(module.parameters(), None)
    if first_param is not None:
        return str(first_param.device)

    first_buffer = next(module.buffers(), None)
    if first_buffer is not None:
        return str(first_buffer.device)

    # Nothing to inspect: tell the caller and use the provided fallback.
    warnings.warn(
        f"{type(module).__name__} has no parameters/buffers; "
        f"falling back to device='{fallback}'.",
        stacklevel=2,
    )
    return fallback
def tile(x: Tensor, n: int) -> Tensor:
    """Flatten `x` and repeat every element `n` times in a row.

    The result is one-dimensional: for ``x = [a, b]`` and ``n = 2`` the output
    is ``[a, a, b, b]``.

    Args:
        x: Input tensor to tile.
        n: Number of times each element is repeated.

    Returns:
        1-D tensor with ``n * x.numel()`` elements.

    Raises:
        TypeError: If `n` is not a positive integer.
    """
    if not is_positive_int(n):
        raise TypeError("Argument `n` must be a positive integer.")

    # Stack n copies of the flat tensor as rows, then read column by column.
    stacked = x.reshape(-1).repeat(n).reshape(n, -1)
    return stacked.transpose(1, 0).reshape(-1)
def sum_except_batch(x: Tensor, num_batch_dims: int = 1) -> Tensor:
    """Sum `x` over every dimension after the first `num_batch_dims` ones.

    Args:
        x: Input tensor.
        num_batch_dims: Number of leading (batch) dimensions to keep.
            Defaults to 1.

    Returns:
        Tensor with all non-batch dimensions summed.

    Raises:
        TypeError: If `num_batch_dims` is not a non-negative integer.
    """
    if not is_nonnegative_int(num_batch_dims):
        raise TypeError("Number of batch dimensions must be a non-negative integer.")

    # NOTE(review): if num_batch_dims >= x.dim() this list is empty; confirm
    # how torch.sum treats an empty `dim` list for the torch version in use.
    dims_to_sum = list(range(num_batch_dims, x.dim()))
    return torch.sum(x, dim=dims_to_sum)
def split_leading_dim(x: Tensor, shape: List[int]) -> Tensor:
    """Replace the leading dimension of `x` by the dimensions in `shape`.

    Args:
        x: Input tensor.
        shape: Sizes the leading dimension is split into.

    Returns:
        Tensor of shape ``(*shape, *x.shape[1:])``.
    """
    leading = torch.Size(shape)
    trailing = x.shape[1:]
    return x.reshape(leading + trailing)
def merge_leading_dims(x: Tensor, num_dims: int) -> Tensor:
    """Collapse the first `num_dims` dimensions of `x` into a single one.

    Args:
        x: Input tensor.
        num_dims: Number of leading dimensions to merge.

    Returns:
        Tensor of shape ``(-1, *x.shape[num_dims:])``.

    Raises:
        TypeError: If `num_dims` is not a positive integer.
        ValueError: If `num_dims` exceeds the number of dimensions of `x`.
    """
    if not is_positive_int(num_dims):
        raise TypeError("Number of leading dims must be a positive integer.")
    if num_dims > x.dim():
        raise ValueError(
            "Number of leading dims can't be greater than total number of dims."
        )

    trailing = x.shape[num_dims:]
    return x.reshape(torch.Size([-1]) + trailing)
def repeat_rows(x: Tensor, num_reps: int) -> Tensor:
    """Repeat each row of `x` `num_reps` times along the leading dimension.

    Rows are repeated consecutively, i.e. rows ``[r0, r1]`` with
    ``num_reps = 2`` become ``[r0, r0, r1, r1]``.

    Args:
        x: Input tensor.
        num_reps: Number of times to repeat each row.

    Returns:
        Tensor of shape ``(x.shape[0] * num_reps, *x.shape[1:])``.

    Raises:
        TypeError: If `num_reps` is not a positive integer.
    """
    if not is_positive_int(num_reps):
        raise TypeError("Number of repetitions must be a positive integer.")

    # Insert a repetition axis after the batch axis, expand it, then fold it
    # back into the leading dimension.
    expanded = x.unsqueeze(1).expand(x.shape[0], num_reps, *x.shape[1:])
    return merge_leading_dims(expanded, num_dims=2)
def tensor2numpy(x: Tensor) -> np.ndarray:
    """Convert a PyTorch tensor to a NumPy array.

    The tensor is detached from the computation graph and moved to the CPU
    before the conversion.

    Args:
        x: Input tensor.

    Returns:
        NumPy array holding the values of `x`.
    """
    detached = x.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()
def logabsdet(x: Tensor) -> Tensor:
    """Return the log of the absolute value of the determinant of `x`.

    Args:
        x: Square matrix tensor.

    Returns:
        Tensor containing the log absolute determinant.
    """
    # torch.logdet() only works for positive determinants, so take the
    # log-magnitude part of slogdet and drop the sign.
    sign_and_logabs = torch.slogdet(x)
    return sign_and_logabs[1]
def random_orthogonal(size: int) -> Tensor:
    """Return a random orthogonal matrix of shape ``(size, size)``.

    The matrix is the Q factor of the QR decomposition of a matrix with
    standard normal entries.

    Args:
        size: Dimension of the square orthogonal matrix.

    Returns:
        Random orthogonal matrix of shape (size, size).
    """
    gaussian = torch.randn(size, size)
    return torch.linalg.qr(gaussian)[0]
def get_num_parameters(model: Module) -> int:
    """Return the total number of elements over all parameters of `model`.

    Every tensor yielded by ``model.parameters()`` is counted, whether or not
    it requires gradients.

    Args:
        model: PyTorch module.

    Returns:
        Total number of parameter elements.
    """
    return sum(torch.numel(parameter) for parameter in model.parameters())
def create_alternating_binary_mask(features: int, even: bool = True) -> Tensor:
    """Create a uint8 mask of length `features` that alternates between 1 and 0.

    Args:
        features: Dimension of mask.
        even: If True, even indices are 1 and odd indices 0.
            If False, vice versa. Defaults to True.

    Returns:
        Alternating binary mask of type torch.Tensor.
    """
    first_one = 0 if even else 1
    mask = torch.zeros(features, dtype=torch.uint8)
    mask[first_one::2] = 1
    return mask
def create_mid_split_binary_mask(features: int) -> Tensor:
    """Create a uint8 mask whose first half is 1 and whose second half is 0.

    For an odd number of features the first (1-valued) half gets the extra
    entry.

    Args:
        features: Dimension of mask.

    Returns:
        Binary mask split at midpoint of type torch.Tensor.
    """
    # Ceil of features / 2.
    midpoint = (features + 1) // 2
    mask = torch.zeros(features, dtype=torch.uint8)
    mask[:midpoint] = 1
    return mask
def create_random_binary_mask(features: int) -> Tensor:
    """Create a uint8 mask with half of its entries, picked at random, set to 1.

    For an odd number of features, the number of 1s is rounded up.

    Args:
        features: Dimension of mask.

    Returns:
        Binary mask with half of its entries set to 1s, of type torch.Tensor.
    """
    num_ones = (features + 1) // 2
    # Uniform weights: every position is equally likely to be chosen.
    chosen = torch.multinomial(
        input=torch.ones(features).float(),
        num_samples=num_ones,
        replacement=False,
    )
    mask = torch.zeros(features, dtype=torch.uint8)
    mask[chosen] = 1
    return mask
def searchsorted(bin_locations: Tensor, inputs: Tensor, eps: float = 1e-6) -> Tensor:
    """Find the index of the bin to which each input value belongs.

    An input belongs to bin ``i`` if it is greater than or equal to edge ``i``
    and smaller than edge ``i + 1``. `eps` is added to the last edge so that an
    input equal to the last edge is assigned to the last bin. `bin_locations`
    itself is left unmodified.

    Args:
        bin_locations: Tensor of bin edges along the last dimension.
        inputs: Tensor of values to search for in the bins.
        eps: Small value added to the last bin edge to ensure correct boundary
            behavior. Defaults to 1e-6.

    Returns:
        Tensor of bin indices for each input value.
    """
    # Work on a copy: modifying the argument in place would shift the caller's
    # last edge by eps on every call.
    padded_edges = bin_locations.clone()
    padded_edges[..., -1] += eps
    return torch.sum(inputs[..., None] >= padded_edges, dim=-1) - 1
def cbrt(x: Tensor) -> Tensor:
    """Element-wise cube root.

    Computed as ``sign(x) * exp(log(|x|) / 3)``, i.e. the root is taken of the
    magnitude and the sign is restored afterwards.

    Args:
        x: Input tensor.

    Returns:
        Element-wise cube root of `x`.
    """
    magnitude_root = torch.exp(torch.log(torch.abs(x)) / 3.0)
    return torch.sign(x) * magnitude_root
def get_temperature(max_value: float, bound: float = 1 - 1e-3) -> Tensor:
    """Return the temperature with ``sigmoid(temperature * max_value) = bound``.

    The result is capped at 1.

    Args:
        max_value: Maximum value of the dataset.
        bound: Target value for sigmoid(temperature * max_value).
            Defaults to 1 - 1e-3.

    Returns:
        Tensor of shape (1,) containing the temperature value, capped at 1.
    """
    cap = torch.ones(1)
    max_value_t = torch.Tensor([max_value])
    bound_t = torch.Tensor([bound])

    # log((1 - bound) / bound) is the negative logit of the bound.
    negative_logit = torch.log1p(-bound_t) - torch.log(bound_t)
    uncapped = -(1 / max_value_t) * negative_logit
    return torch.min(uncapped, cap)
def gaussian_kde_log_eval(samples: Tensor, query: Tensor) -> Tensor:
    """Evaluate the log density of query points under a Gaussian KDE.

    The KDE places an isotropic Gaussian kernel with standard deviation
    ``N ** (-1 / (D + 4))`` (a Scott-type bandwidth factor, not rescaled by the
    spread of the data) on each of the N samples. The computation follows the
    dtype and device of the inputs.

    Args:
        samples: Tensor of shape (N, D) used to build the KDE.
        query: Points to evaluate. Must broadcast against `samples` in
            ``query - samples``, e.g. shape (M, 1, D) for M query points.

    Returns:
        Tensor of log densities; the sample dimension is reduced, so a query of
        shape (M, 1, D) gives a result of shape (M,).
    """
    N, D = samples.shape[0], samples.shape[-1]
    std = N ** (-1 / (D + 4))

    diff = query - samples
    # The kernel precision is I / std**2, so the quadratic form reduces to the
    # squared norm divided by std**2. This avoids building an identity matrix
    # whose dtype or device could differ from that of the inputs.
    log_kernel = -0.5 * torch.sum(diff * diff, dim=-1) / (std**2)
    log_norm = -np.log(N) - (D / 2) * np.log(2 * np.pi) - D * np.log(std)
    return torch.logsumexp(log_kernel + float(log_norm), dim=-1)
def ensure_theta_batched(theta: Tensor) -> Tensor:
    r"""
    Return parameter set theta with a batch dimension.

    A one-dimensional theta of shape (n) is returned with shape (1, n); any
    other theta is returned unchanged.

    Args:
        theta: parameters $\theta$, of shape (n) or (1,n)

    Returns:
        Batched parameter set $\theta$
    """
    return theta.unsqueeze(0) if theta.ndim == 1 else theta
def ensure_x_batched(x: Tensor) -> Tensor:
    """
    Return simulation output x with a leading batch dimension of size one.

    Args:
        x: simulation output of shape (n) or (1,n).

    Returns:
        Batched simulation output x.
    """
    # If the first dimension is larger than one, the batch dimension is assumed
    # to be missing, even though ndim might be > 1 (e.g. for images).
    batch_dim_missing = x.shape[0] > 1 or x.ndim == 1
    return x.unsqueeze(0) if batch_dim_missing else x
def atleast_2d_many(*arys: Array) -> OneOrMore[Tensor]:
    """Return the inputs as tensors with at least two dimensions.

    NumPy arrays are converted to tensors first. Inputs with fewer than two
    dimensions are reshaped to ``(1, -1)``.

    Returns:
        A single tensor if exactly one input is given, otherwise a list with
        one tensor per input.
    """
    if len(arys) != 1:
        return [atleast_2d_many(arr) for arr in arys]  # type: ignore

    (single,) = arys
    if isinstance(single, np.ndarray):
        single = torch.from_numpy(single)
    return atleast_2d(single)
def atleast_2d(t: Tensor) -> Tensor:
    """Return `t` unchanged if it has two or more dimensions, else as (1, -1)."""
    if t.ndim < 2:
        return t.reshape(1, -1)
    return t
def maybe_add_batch_dim_to_size(s: torch.Size) -> torch.Size:
    """
    Prepend a batch dimension of 1 to a size with fewer than two dimensions.

    (N) -> (1,N)
    (1,N) -> (1,N)
    (N,M) -> (N,M)
    (1,N,M) -> (1,N,M)

    Args:
        s: Input size, possibly without batch dimension.

    Returns: Size including a batch dimension.
    """
    if len(s) < 2:
        return torch.Size([1]) + s  # type: ignore
    return s
def atleast_2d_float32_tensor(arr: Union[Tensor, np.ndarray]) -> Tensor:
    """Convert `arr` to a float32 tensor with at least two dimensions."""
    as_float32 = torch.as_tensor(arr, dtype=float32)
    return atleast_2d(as_float32)
def batched_first_of_batch(t: Tensor) -> Tensor:
    """
    Return the first entry of a batch while keeping the batch dimension,
    e.g. a tensor of shape (N, M) gives a tensor of shape (1, M).
    """
    first_entry = slice(0, 1)
    return t[first_entry]
def assert_all_finite(quantity: Tensor, description: str = "tensor") -> None:
    """Raise a ValueError if `quantity` contains any NaN or Inf element.

    Args:
        quantity: Tensor to check.
        description: Name of the tensor used in the error message.
    """
    if torch.isfinite(quantity).all():
        return
    raise ValueError(f"NaN/Inf present in {description}.")
def assert_not_nan_or_plus_inf(quantity: Tensor, description: str = "tensor") -> None:
    """Raise if `quantity` contains any NaN or +Inf element.

    Elements equal to -Inf are allowed.

    Args:
        quantity: Tensor to check.
        description: Name of the tensor used in the error message.

    Raises:
        AssertionError: If a NaN or +Inf element is present.
    """
    has_plus_inf = bool(torch.isposinf(quantity).any())
    has_nan = bool(torch.isnan(quantity).any())
    # Raise explicitly instead of using `assert`, so that the check is not
    # removed when Python runs with optimizations enabled (-O).
    if has_plus_inf or has_nan:
        raise AssertionError(f"NaN/ +Inf present in {description}.")
def _base_recursor(
    obj: object,
    parent: Optional[object] = None,
    key: Optional[Union[str, int]] = None,
    check: Callable[..., bool] = lambda obj: False,
    action: Callable[..., object] = lambda obj: obj,
):
    """Recursive function that traverses objects (e.g. Distributions) and applies
    an action to any encountered object that passes the check.

    Used e.g. to move all tensors within a distribution to a given device.

    Dictionaries, objects with a ``__dict__`` and lists are updated through
    their entries/attributes. Lists and tuples found inside another container
    are rebuilt (keeping their type) and written back to that container. A
    list or tuple passed as the root has no parent, so it is not written back.

    Args:
        obj: An object which serves as root of the traversal.
        parent: The container through which ``obj`` was reached.
        key: The attribute name, dictionary key or list index under which
            ``obj`` is stored in ``parent``.
        check: A function that inputs an object and outputs a boolean.
            If the check evaluates to True, then ``action`` is applied.
        action: A function that specifies an operation on an object and returns
            a modified version.
    """
    if isinstance(obj, Module) and check(obj):
        # The return value is discarded here, so `action` has to modify the
        # module in place.
        action(obj)
    elif isinstance(obj, (Dict, OrderedDict)):
        for k, o in obj.items():
            if check(o):
                obj[k] = action(o)
            else:
                _base_recursor(o, parent=obj, key=k, check=check, action=action)
    elif isinstance(obj, type):
        # Skip class/type objects to avoid modifying immutable C extension types
        # (e.g. torch.LongTensor) which fail on Python 3.13+.
        return
    elif hasattr(obj, "__dict__"):
        for k, o in obj.__dict__.items():
            if check(o):
                setattr(obj, k, action(o))
            else:
                _base_recursor(o, parent=obj, key=k, check=check, action=action)
    elif isinstance(obj, (List, Tuple, Generator)):
        new_obj = []
        for o in obj:
            if check(o):
                new_obj.append(action(o))
            else:
                new_obj.append(o)
                # Recurse with the rebuilt list as parent, so that a nested
                # list/tuple can replace its own slot once it has been rebuilt.
                _base_recursor(
                    o,
                    parent=new_obj,
                    key=len(new_obj) - 1,
                    check=check,
                    action=action,
                )
        if parent is not None and key is not None:
            # NOTE(review): generators cannot be re-instantiated from a list,
            # so this call fails for a generator stored in a container.
            rebuilt = type(obj)(new_obj)  # type: ignore
            if isinstance(parent, dict) or (
                type(parent) is list and isinstance(key, int)
            ):
                # Reached through a dictionary entry or a list slot: write
                # back by item assignment; setattr would be wrong here.
                parent[key] = rebuilt  # type: ignore
            else:
                setattr(parent, key, rebuilt)  # type: ignore
    else:
        return
def move_all_tensor_to_device(obj: object, device: Union[str, torch.device]) -> None:
    """Recursively move all tensors and modules within an object to a device.

    The traversal is delegated to ``_base_recursor``; every encountered
    ``Tensor`` or ``Module`` is moved to the specified device. The traversal
    runs under ``torch.no_grad()``.

    Note:
        Leaf tensors with ``requires_grad=True`` are never moved. They are kept
        as they are if the string form of their device equals the string form
        of `device`; otherwise a ``ValueError`` is raised.

    Args:
        obj: The root object to traverse.
        device: The target device.

    Raises:
        ValueError: If a leaf tensor that requires gradients is on a different
            device than `device`.
    """

    def is_movable(candidate: object) -> bool:
        return isinstance(candidate, (Tensor, Module))

    def relocate(candidate: object) -> object:
        is_grad_leaf = (
            isinstance(candidate, Tensor)
            and candidate.requires_grad
            and candidate.is_leaf
        )
        if not is_grad_leaf:
            return candidate.to(device)  # type: ignore

        # Leaf tensors cannot be moved in place: calling .to would create a
        # copy, which is a non-leaf tensor.
        if str(candidate.device) == str(device):  # type: ignore
            return candidate
        raise ValueError(
            f"Cannot move leaf tensor with requires_grad=True from "
            f"{candidate.device} to {device}. Please initialize it on the "  # type: ignore
            f"correct device."
        )

    with torch.no_grad():
        _base_recursor(obj, check=is_movable, action=relocate)