Source code for sbi.analysis.conditional_density

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Any, Callable, List, Optional, Tuple, Union
from warnings import warn

import torch
import torch.distributions.transforms as torch_tf
from torch import Tensor
from torch.distributions import Distribution

from sbi.inference.potentials.base_potential import BasePotential
from sbi.neural_nets.estimators.mixture_density_estimator import (
    MixtureDensityEstimator,
)
from sbi.neural_nets.estimators.mog import MoG
from sbi.sbi_types import Shape, TorchTransform
from sbi.utils.conditional_density_utils import (
    ConditionedPotential,
    RestrictedPriorForConditional,
    RestrictedTransformForConditional,
    compute_corrcoeff,
    condition_mog,
    extract_and_transform_mog,
)
from sbi.utils.torchutils import atleast_2d_float32_tensor, ensure_theta_batched


def eval_conditional_density(
    density: Any,
    condition: Tensor,
    limits: Tensor,
    dim1: int,
    dim2: int,
    resolution: int = 50,
    eps_margins1: Union[Tensor, float] = 1e-32,
    eps_margins2: Union[Tensor, float] = 1e-32,
    return_raw_log_prob: bool = False,
) -> Tensor:
    r"""Evaluate the unnormalized conditional density on a regular grid.

    All parameters except `dim1` and `dim2` are held fixed at the values given in
    `condition`; the joint density is then evaluated while `dim1` and `dim2` sweep
    an evenly spaced grid. This yields the conditional up to a constant because
        $p(x1 | x2) = p(x1, x2) / p(x2) \propto p(x1, x2)$

    Args:
        density: Object exposing a `.log_prob()` method that accepts a batch of
            parameter sets.
        condition: Parameter set providing the fixed values, expected to be of
            shape (1, dim_theta), e.g. a posterior sample. Its entries at `dim1`
            and `dim2` are overwritten by the grid and therefore ignored.
        limits: Evaluation bounds of shape (dim_theta, 2), holding the lower bound
            in column 0 and the upper bound in column 1.
        dim1: First dimension along which the conditional is evaluated.
        dim2: Second dimension along which the conditional is evaluated. Pass the
            same value as `dim1` to obtain a one-dimensional conditional.
        resolution: Number of grid points per dimension.
        eps_margins1: Margin by which the grid along `dim1` is pulled inwards from
            both of its limits, so that the density is not evaluated exactly at
            the bounds.
        eps_margins2: Same as `eps_margins1`, but for the grid along `dim2`.
        return_raw_log_prob: If `True`, return the log-probabilities on the grid.
            If `False`, return `exp(log_prob - max_log_prob)`, i.e. probabilities
            rescaled by the grid maximum for numerical stability.

    Returns:
        Conditional (log-)probabilities. Shape (resolution) if `dim1 == dim2`,
        otherwise (resolution, resolution), where rows follow the grid along
        `dim2` and columns follow the grid along `dim1`.
    """
    condition = ensure_theta_batched(condition)
    device = condition.device

    def _axis_grid(dim: int, eps: Union[Tensor, float]) -> Tensor:
        # Evenly spaced points between the limits of `dim`, shrunk by `eps` at
        # both ends.
        lower = float(limits[dim, 0] + eps)
        upper = float(limits[dim, 1] - eps)
        return torch.linspace(lower, upper, resolution, device=device)

    grid1 = _axis_grid(dim1, eps_margins1)
    grid2 = _axis_grid(dim2, eps_margins2)

    if dim1 == dim2:
        # One-dimensional conditional: a single sweep along `dim1`.
        thetas = condition.repeat(resolution, 1)
        thetas[:, dim1] = grid1
        log_probs = density.log_prob(thetas)
    else:
        # Two-dimensional conditional: `dim1` cycles fastest and `dim2` slowest,
        # so the reshape below puts `dim2` on the rows and `dim1` on the columns.
        thetas = condition.repeat(resolution**2, 1)
        thetas[:, dim1] = grid1.repeat(resolution)
        thetas[:, dim2] = grid2.repeat_interleave(resolution)
        log_probs = density.log_prob(thetas).reshape(resolution, resolution)

    if return_raw_log_prob:
        return log_probs

    # Shift by the maximum before exponentiating for numerical stability.
    return torch.exp(log_probs - log_probs.max())


def conditional_corrcoeff(
    density: Any,
    limits: Tensor,
    condition: Tensor,
    subset: Optional[List[int]] = None,
    resolution: int = 50,
) -> Tensor:
    r"""Returns the conditional correlation matrix of a distribution.

    To compute the conditional distribution, we condition all but two parameters to
    values from `condition`, and then compute the Pearson correlation coefficient
    $\rho$ between the remaining two parameters under the distribution `density`.
    We do so for every pair of parameters specified in `subset`, thus creating a
    matrix containing conditional correlations between any pair of parameters.

    If `condition` is a batch of conditions, this function computes the conditional
    correlation matrix for each one of them and returns the mean.

    Args:
        density: Probability density function with `.log_prob()` function. If it
            has a `_device` attribute, all computations are run on that device,
            otherwise on the CPU.
        limits: Limits within which to evaluate the `density`.
        condition: Values to condition the `density` on. Either a single parameter
            set or a batch of them. If a batch of conditions is passed, we compute
            the conditional correlation matrix for each of them and return the
            average conditional correlation matrix.
        subset: Evaluate the conditional distribution only on a subset of
            dimensions. If `None` this function uses all dimensions. Row/column
            `i` of the returned matrix corresponds to `subset[i]`, in the order
            given (the subset does not have to be sorted).
        resolution: Number of grid points on which the conditional distribution is
            evaluated. A higher value increases the accuracy of the estimated
            correlation but also increases the computational cost.

    Returns:
        Average conditional correlation matrix of shape either `(num_dim, num_dim)`
        or `(len(subset), len(subset))` if `subset` was specified. With fewer than
        two dimensions there is no pair to correlate and the identity matrix of
        the corresponding size is returned.

    Raises:
        ValueError: If `subset` contains the same dimension more than once.
    """
    device = getattr(density, "_device", "cpu")

    # Accept a single, unbatched condition as well as a batch of conditions.
    condition = ensure_theta_batched(condition)

    dims = list(subset) if subset is not None else list(range(condition.shape[1]))
    num_dims = len(dims)
    if len(set(dims)) != num_dims:
        raise ValueError(f"`subset` must not contain duplicate dimensions: {dims}.")

    # Without at least two dimensions there is nothing to correlate.
    if num_dims < 2:
        return torch.eye(num_dims, device=device)

    limits = limits.to(device)

    # Enumerate pairs by their *position* in `dims` (i < j). This is exactly the
    # row-major order of the strictly upper triangular entries produced by
    # `torch.triu_indices` below, so every coefficient ends up in the right cell
    # even if `subset` is not sorted.
    dim_pairs = [
        (dim1, dim2)
        for position, dim1 in enumerate(dims)
        for dim2 in dims[position + 1 :]
    ]

    correlations_per_condition = []
    for cond in condition:
        cond = cond.to(device)
        correlations_per_condition.append(
            torch.stack([
                compute_corrcoeff(
                    eval_conditional_density(
                        density,
                        cond,
                        limits,
                        dim1=dim1,
                        dim2=dim2,
                        resolution=resolution,
                    ),
                    limits[[dim1, dim2]],
                )
                for dim1, dim2 in dim_pairs
            ])
        )

    # Vector holding the averaged upper triangular entries.
    average_correlations = torch.mean(torch.stack(correlations_per_condition), dim=0)

    # Assemble the vector into a matrix.
    av_correlation_matrix = torch.zeros((num_dims, num_dims), device=device)
    triu_indices = torch.triu_indices(
        row=num_dims, col=num_dims, offset=1, device=device
    )
    av_correlation_matrix[triu_indices[0], triu_indices[1]] = average_correlations

    # Make the matrix symmetric by copying upper diagonal to lower diagonal.
    av_correlation_matrix = torch.triu(av_correlation_matrix) + torch.tril(
        av_correlation_matrix.T
    )
    # Every parameter is perfectly correlated with itself.
    av_correlation_matrix.fill_diagonal_(1.0)
    return av_correlation_matrix
class ConditionedMDN:
    r"""Conditional mixture-of-gaussians that can be sampled and evaluated."""

    def __init__(
        self,
        mdn: MixtureDensityEstimator,
        x_o: Tensor,
        condition: Tensor,
        dims_to_sample: List[int],
    ) -> None:
        r"""Build the conditioned mixture of Gaussians from a trained MDN.

        Args:
            mdn: MixtureDensityEstimator that models $p(\theta|x)$.
            x_o: The datapoint at which the `mdn` is evaluated.
            condition: Parameter set that all dimensions not specified in
                `dims_to_sample` will be fixed to. Should contain dim_theta
                elements, i.e. it could e.g. be a sample from the posterior
                distribution. The entries at all `dims_to_sample` will be ignored.
            dims_to_sample: Which dimensions to sample from. The dimensions not
                specified in `dims_to_sample` will be fixed to values given in
                `condition`.
        """
        condition = atleast_2d_float32_tensor(condition)

        # Mixture parameters of the unconditioned estimator at `x_o`.
        joint_logits, joint_means, joint_prec_factors, _ = extract_and_transform_mog(
            estimator=mdn, context=x_o
        )

        # Restrict the mixture to `dims_to_sample`, given the fixed `condition`.
        logits, means, prec_factors, _ = condition_mog(
            condition, dims_to_sample, joint_logits, joint_means, joint_prec_factors
        )

        # Precision matrices are rebuilt from their factors as `factor^T @ factor`.
        precisions = torch.matmul(prec_factors.transpose(2, 3), prec_factors)

        # Store the conditioned MoG for sampling and evaluation.
        self._mog = MoG(
            logits=logits,
            means=means,
            precisions=precisions,
            precision_factors=prec_factors,
        )

    def sample(self, sample_shape: Shape = torch.Size()) -> Tensor:
        """Draw samples from the conditioned MoG.

        Args:
            sample_shape: Shape prefix for samples.

        Returns:
            Detached samples, intended to have shape (*sample_shape, dim) where dim
            is the number of free dimensions (those in dims_to_sample).
        """
        # `MoG.sample` is expected to return (*sample_shape, batch_size, dim). A
        # `ConditionedMDN` describes a single conditioned distribution, so the
        # batch dimension has size one and is squeezed out for convenience.
        return self._mog.sample(torch.Size(sample_shape)).squeeze(-2).detach()

    def log_prob(self, theta: Tensor) -> Tensor:
        """Evaluate the log-probability of `theta` under the conditioned MoG.

        Args:
            theta: Parameters to evaluate, shape (dim,) or (batch_size, dim). A
                one-dimensional `theta` is evaluated as a batch of size one.

        Returns:
            Log-probabilities exactly as returned by `MoG.log_prob` for the
            batched `theta` (presumably one value per batch entry).
        """
        batched_theta = theta.unsqueeze(0) if theta.dim() == 1 else theta

        # `MoG.log_prob` is relied upon to broadcast the stored mixture (batch
        # size one) against the batch dimension of `theta`.
        return self._mog.log_prob(batched_theta)


def conditonal_potential(
    potential_fn: BasePotential,
    theta_transform: TorchTransform,
    prior: Distribution,
    condition: Tensor,
    dims_to_sample: List[int],
) -> Tuple[Callable, torch_tf.Transform, Any]:
    """Deprecated, misspelled alias of `conditional_potential`.

    Kept only for backwards compatibility (notice the missing `i` in the name; the
    original note ties the rename to v0.19.0). Emits a warning and forwards all
    arguments unchanged to `conditional_potential`, whose result is returned.
    """
    deprecation_message = (
        "The misspelled function `conditonal_potential` will be removed in a future "
        "release of sbi. Please use `conditional_potential` (spelled correctly)."
    )
    # `stacklevel=2` attributes the warning to the caller of this alias.
    warn(deprecation_message, stacklevel=2)

    return conditional_potential(
        potential_fn, theta_transform, prior, condition, dims_to_sample
    )
def conditional_potential(
    potential_fn: BasePotential,
    theta_transform: TorchTransform,
    prior: Distribution,
    condition: Tensor,
    dims_to_sample: List[int],
) -> Tuple[Callable, torch_tf.Transform, Any]:
    r"""Build everything needed to sample from a conditional potential.

    The conditional potential is
    $p(\theta_i | \theta_j, x_o) \propto p(\theta | x_o)$, regarded as a function
    of $\theta_i$ only. Alongside the conditioned potential, this function returns
    a parameter transform and a prior that are both restricted to the dimensions
    being sampled.

    Args:
        potential_fn: The potential function to be conditioned.
        theta_transform: The parameter transformation that should be reduced (by
            ignoring dimensions not contained in `dims_to_sample`).
        prior: The prior distribution that should be reduced (by ignoring
            dimensions not contained in `dims_to_sample`).
        condition: Parameter set that all dimensions not specified in
            `dims_to_sample` will be fixed to. Should contain dim_theta elements,
            i.e. it could e.g. be a sample from the posterior distribution. The
            entries at all `dims_to_sample` will be ignored.
        dims_to_sample: Which dimensions to sample from. The dimensions not
            specified in `dims_to_sample` will be fixed to values given in
            `condition`.

    Returns:
        A conditioned potential function, conditioned parameter transformation, and
        a marginalised prior.
    """
    # The transform wrapper receives `condition` exactly as passed in; only the
    # potential below gets the version converted by `atleast_2d_float32_tensor`.
    transform_for_free_dims = RestrictedTransformForConditional(
        theta_transform, condition, dims_to_sample
    )

    batched_condition = atleast_2d_float32_tensor(condition)
    potential_for_free_dims = ConditionedPotential(
        potential_fn,
        batched_condition,
        dims_to_sample,  # type: ignore
    )

    prior_for_free_dims = RestrictedPriorForConditional(prior, dims_to_sample)

    return potential_for_free_dims, transform_for_free_dims, prior_for_free_dims