# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
"""
Implementation taken from Lemos et al, 'Sampling-Based Accuracy Testing of
Posterior Estimators for General Inference' https://arxiv.org/abs/2302.03026
The TARP diagnostic is a global diagnostic which can be used to check a
trained posterior against a set of true values of theta.
"""
import warnings
from typing import Callable, Optional, Tuple
import torch
from scipy.stats import kstest
from torch import Tensor
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.utils.diagnostics_utils import (
get_posterior_samples_on_batch,
remove_nans_and_infs_in_x,
)
from sbi.utils.metrics import l2
def run_tarp(
    thetas: Tensor,
    xs: Tensor,
    posterior: NeuralPosterior,
    references: Optional[Tensor] = None,
    num_posterior_samples: int = 1000,
    num_workers: int = 1,
    show_progress_bar: bool = True,
    distance: Callable = l2,
    num_bins: Optional[int] = None,
    z_score_theta: bool = True,
    use_batched_sampling: bool = True,
) -> Tuple[Tensor, Tensor]:
    """
    Estimates coverage of samples given true values thetas with the TARP method.

    Reference: `Lemos, Coogan et al 2023 <https://arxiv.org/abs/2302.03026>`_

    The TARP diagnostic is a global diagnostic which can be used to check a
    trained posterior against a set of true values of theta.

    Args:
        thetas: ground-truth parameters for tarp, simulated from the prior, with
            shape ``(num_tarp_samples, dim_theta)``.
        xs: observed data for tarp, simulated from thetas (one row per theta).
        posterior: a posterior obtained from sbi.
        references: reference points for the coverage regions, with the same
            shape as ``thetas``. If ``None``, one reference point per theta is
            drawn uniformly from the box spanned by the per-dimension minimum
            and maximum of ``thetas`` (see ``get_tarp_references``). Rows
            belonging to invalid (NaN/inf) ``xs`` are dropped from
            ``references`` together with the matching ``thetas``.
        num_posterior_samples: number of approximate posterior samples drawn per
            observation and used for ranking.
        num_workers: number of CPU cores to use for running inference in parallel.
        show_progress_bar: whether to display a progress bar while sampling from
            the posterior.
        distance: the distance metric to use when computing the distance.
            Should be a callable function that accepts two tensors and
            computes the distance between them, e.g. given two tensors
            of shape ``(batch, 3)`` and ``(batch,3)``, this function should
            return ``(batch,1)`` distance values.
            Possible values: ``sbi.utils.metrics.l1`` or
            ``sbi.utils.metrics.l2``. ``l2`` is the default.
        num_bins: number of bins to use for the credibility values.
            If ``None``, then ``num_tarp_samples // 10`` bins are used (where
            ``num_tarp_samples`` is the number of valid ``(theta, x)`` pairs),
            which targets at least 10 samples per bin (requires
            ``num_tarp_samples >= 100``).
        z_score_theta: whether to min-max normalize all parameters (using the
            range of ``thetas``) before the coverage test.
        use_batched_sampling: whether to use batched sampling for posterior samples.

    Returns:
        ecp: Expected coverage probability (``ecp``), see equation 4 of the paper
        alpha: credibility values, see equation 2 of the paper
    """
    if references is not None and references.shape == thetas.shape:
        # NOTE(review): assumes remove_nans_and_infs_in_x selects rows based on
        # xs only, so filtering references against the same xs drops the same
        # rows as for thetas below. Without this, user-provided references keep
        # their original length and _run_tarp's shape check fails as soon as a
        # single x is invalid. If the helper also inspected its first argument,
        # a mismatch would still surface in that shape check, as before.
        references, _ = remove_nans_and_infs_in_x(references, xs)
    thetas, xs = remove_nans_and_infs_in_x(thetas, xs)
    num_tarp_samples, dim_theta = thetas.shape

    if num_tarp_samples < 100:
        warnings.warn(
            "Number of TARP samples should be on the order of 100s to give reliable "
            "results.",
            stacklevel=2,
        )

    posterior_samples = get_posterior_samples_on_batch(
        xs,
        posterior,
        (num_posterior_samples,),
        num_workers,
        show_progress_bar=show_progress_bar,
        use_batched_sampling=use_batched_sampling,
    )
    # Layout expected by _run_tarp: one set of posterior draws per observation.
    assert posterior_samples.shape == (
        num_posterior_samples,
        num_tarp_samples,
        dim_theta,
    ), f"Wrong posterior samples shape for TARP: {posterior_samples.shape}"

    # Sample reference points uniformly if not provided
    if references is None:
        references = get_tarp_references(thetas)

    return _run_tarp(
        posterior_samples, thetas, references, distance, num_bins, z_score_theta
    )
def _run_tarp(
    posterior_samples: Tensor,
    thetas: Tensor,
    references: Tensor,
    distance: Callable = l2,
    num_bins: Optional[int] = None,
    z_score_theta: bool = False,
) -> Tuple[Tensor, Tensor]:
    """
    Estimates coverage of samples given true values theta with the TARP method.

    Reference: `Lemos, Coogan et al 2023 <https://arxiv.org/abs/2302.03026>`_

    The TARP diagnostic is a global diagnostic which can be used to check a
    trained posterior against a set of true values of theta.

    Args:
        posterior_samples: The predicted parameter samples to compute the coverage of,
            these samples are expected to have shape ``(num_samples,
            num_tarp_samples, num_dims)``. These are obtained by sampling a trained
            posterior `num_samples` times. Multiple (posterior) samples for one
            observation are encouraged.
        thetas: The true parameter values theta. Thetas are expected to
            have shape ``(num_tarp_samples, num_dims)``.
        references: the reference points to use for the coverage regions. Must
            have exactly the same shape as ``thetas``, i.e.
            ``(num_tarp_samples, num_dims)``; ``None`` is not accepted here
            (``run_tarp`` draws defaults via ``get_tarp_references``).
        distance: the distance metric to use when computing the distance.
            Should be a callable function that accepts two tensors and
            computes the distance between them, e.g. given two tensors
            of shape ``(batch, 3)`` and ``(batch,3)``, this function should
            return ``(batch,1)`` distance values.
            Possible values: ``sbi.utils.metrics.l1`` or
            ``sbi.utils.metrics.l2``. ``l2`` is the default.
        num_bins: number of bins to use for the credibility values.
            If ``None``, then ``max(1, num_tarp_samples // 10)`` bins are used.
        z_score_theta: whether to min-max normalize all parameters to roughly
            ``[0, 1]`` (using the per-dimension range of ``thetas``) before the
            coverage test.

    Returns:
        ecp: Expected coverage probability (``ecp``), see equation 4 of the
            paper. It has ``num_bins + 1`` entries and starts at ``0``.
        alpha: grid of credibility values (the ``num_bins + 1`` histogram bin
            edges), see equation 2 of the paper
    """
    num_posterior_samples, num_tarp_samples, _ = posterior_samples.shape
    input_device = posterior_samples.device

    assert references.shape == thetas.shape, (
        "references must have the same shape as thetas"
    )

    if num_bins is None:
        # Aim for ~10 coverage values per bin, but never ask for zero bins:
        # torch.histogram rejects bins=0, which the plain integer division
        # produced for fewer than 10 TARP samples.
        num_bins = max(1, num_tarp_samples // 10)

    if z_score_theta:
        # Min-max scaling per parameter dimension, based on the range of the
        # ground-truth thetas; the epsilon avoids division by zero for
        # dimensions where all thetas are equal.
        lo = thetas.min(dim=0, keepdim=True).values  # min over batch
        scale = thetas.max(dim=0, keepdim=True).values - lo + 1e-10
        posterior_samples = (posterior_samples - lo) / scale
        thetas = (thetas - lo) / scale
        references = (references - lo) / scale

    # distances between references and samples
    sample_dists = distance(references, posterior_samples)
    # distances between references and true values
    theta_dists = distance(references, thetas)

    # compute coverage, f in algorithm 2: for every (theta, x) pair, the
    # fraction of posterior samples that lie closer to the reference point than
    # the true theta does.
    coverage_values = (
        torch.sum(sample_dists < theta_dists, dim=0) / num_posterior_samples
    )

    # enforce execution on the CPU due to
    # https://github.com/pytorch/pytorch/issues/69519
    hist, alpha_grid = torch.histogram(
        coverage_values.cpu(), density=True, bins=num_bins
    )
    # return all tensors to input_device to keep contract valid
    hist, alpha_grid = hist.to(input_device), alpha_grid.to(input_device)

    # calculate empirical CDF via cumsum and normalize (bins are equally wide,
    # so normalizing the cumulative densities yields the CDF at the bin edges)
    ecp = torch.cumsum(hist, dim=0) / hist.sum()
    # add 0 to the beginning of the ecp curve to match the alpha grid
    ecp = torch.cat([torch.zeros((1,), device=input_device), ecp])

    return ecp, alpha_grid
def get_tarp_references(thetas: Tensor) -> Tensor:
    """Returns reference points for the TARP diagnostic, sampled from a uniform.

    One reference point is drawn per entry (row) of ``thetas``, uniformly from
    the axis-aligned box spanned by the per-dimension minimum and maximum of
    ``thetas``. Dimensions in which all thetas are equal yield that constant
    value instead of failing.

    Args:
        thetas: ground-truth parameters, presumably of shape
            ``(num_tarp_samples, dim_theta)``.

    Returns:
        Reference points with the same shape, dtype and device as ``thetas``.
    """
    # obtain min/max per dimension of theta
    lo = thetas.min(dim=0).values  # min for each theta dimension
    hi = thetas.max(dim=0).values  # max for each theta dimension
    # Equivalent to torch.distributions.Uniform(lo, hi).sample(...) (same RNG
    # draw, same arithmetic), but Uniform rejects dimensions where lo == hi,
    # e.g. a constant parameter or a single TARP sample.
    sample_shape = torch.Size([thetas.shape[0]]) + lo.shape
    with torch.no_grad():
        unit = torch.rand(sample_shape, dtype=lo.dtype, device=lo.device)
        return lo + unit * (hi - lo)
def check_tarp(
    ecp: Tensor,
    alpha: Tensor,
) -> Tuple[float, float]:
    r"""Check the obtained TARP credibility levels and expected coverage
    probabilities. This will help to uncover underdispersed, well covering or
    overdispersed posteriors.

    Args:
        ecp: expected coverage probabilities computed with the TARP method,
            i.e. first output of ``run_tarp``.
        alpha: credibility levels $\alpha$, i.e. second output of ``run_tarp``.
            Assumed to be a uniformly spaced grid with at least two entries
            (as the histogram bin edges returned by ``run_tarp`` are). The
            tensors may live on any device and may require grad.

    Returns:
        atc: area to curve, the difference between the ecp and alpha curve
            over the upper half of the alpha grid (i.e. alpha values larger
            than 0.5 when the grid spans ``[0, 1]``). This number should be
            close to ``0``. Values larger than ``0`` indicate overdispersed
            distributions (i.e. the estimated posterior is too wide). Values
            smaller than ``0`` indicate underdispersed distributions (i.e. the
            estimated posterior is too narrow). Note, this property of the ecp
            curve can also indicate if the posterior is biased, see figure 2 of
            the paper for details (https://arxiv.org/abs/2302.03026).
        ks prob: p-value for a two sample Kolmogorov-Smirnov test. The null
            hypothesis of this test is that the two distributions (ecp and
            alpha) are identical, i.e. are produced by one common CDF. If they
            were, the p-value should be close to ``1``. Commonly, people reject
            the null if p-value is below 0.05!
    """
    # get the index of the middle of the alpha grid
    midindex = alpha.shape[0] // 2
    # the grid is uniformly spaced, so one spacing serves as the width of
    # every rectangle in the Riemann sum below
    dalpha = alpha[1] - alpha[0]
    atc = ((ecp[midindex:] - alpha[midindex:]) * dalpha).sum().item()

    # scipy needs host-side arrays without autograd history; run_tarp returns
    # its outputs on the device of the posterior samples, which may be a GPU.
    ecp_np = ecp.detach().cpu().numpy()
    alpha_np = alpha.detach().cpu().numpy()
    # Kolmogorov-Smirnov test between ecp and alpha (two-sample, since the
    # second argument is an array rather than a CDF)
    kstest_pvals = float(kstest(ecp_np, alpha_np)[1])  # type: ignore

    return atc, kstest_pvals