# Source code for sbi.inference.trainers.npe.npe_c

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Callable, Dict, Literal, Optional, Union

import torch
from torch import Tensor, eye, ones
from torch.distributions import Distribution, MultivariateNormal, Uniform
from torch.utils.tensorboard.writer import SummaryWriter

from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.inference.trainers.npe.npe_base import (
    PosteriorEstimatorTrainer,
)
from sbi.neural_nets.estimators.base import (
    ConditionalDensityEstimator,
    ConditionalEstimatorBuilder,
)
from sbi.neural_nets.estimators.mixture_density_estimator import (
    MixtureDensityEstimator,
)
from sbi.neural_nets.estimators.mog import MoG
from sbi.neural_nets.estimators.shape_handling import (
    reshape_to_batch_event,
    reshape_to_sample_batch_event,
)
from sbi.sbi_types import Tracker
from sbi.utils import (
    batched_mixture_mv,
    batched_mixture_vmv,
    check_dist_class,
    clamp_and_warn,
    del_entries,
    repeat_rows,
)
from sbi.utils.torchutils import BoxUniform, assert_all_finite


class NPE_C(PosteriorEstimatorTrainer):
    r"""Neural Posterior Estimation algorithm (NPE-C) as in Greenberg et al. (2019) [1].

    NPE-C (also known as APT - Automatic Posterior Transformation, aka SNPE-C)
    trains a neural network over multiple rounds to directly approximate the
    posterior for a specific observation x_o. In the first round, NPE-C is
    equivalent to other NPE methods and is fully amortized (direct inference for
    any new observation).

    After the first round, NPE-C automatically selects between two loss variants:
    the non-atomic loss, which is stable and avoids leakage, is used when both the
    last proposal and the density estimator are mixtures of Gaussians and the prior
    is Gaussian or uniform; otherwise the atomic loss (e.g. for flows) is used,
    which is more flexible but may suffer from leakage issues.

    For single-round inference, NPE-A, NPE-B, and NPE-C are equivalent and use
    plain NLL loss.

    [1] *Automatic Posterior Transformation for Likelihood-free Inference*,
        Greenberg et al., ICML 2019, https://arxiv.org/abs/1905.07488.

    Example:
    --------

    ::

        import torch
        from sbi.inference import NPE_C
        from sbi.utils import BoxUniform

        # 1. Setup simulator, prior, and observation
        prior = BoxUniform(low=torch.zeros(3), high=torch.ones(3))
        x_o = torch.randn(1, 3)  # Observed data

        def simulator(theta):
            return theta + torch.randn_like(theta) * 0.1

        # 2. Multi-round inference. Passing `proposal` to `append_simulations`
        # is what lets NPE-C correct for sampling from a non-prior proposal.
        inference = NPE_C(prior=prior)
        proposal = prior

        for round_idx in range(5):
            theta = proposal.sample((100,))
            x = simulator(theta)
            density_estimator = inference.append_simulations(
                theta, x, proposal=proposal
            ).train()
            posterior = inference.build_posterior(density_estimator)
            proposal = posterior.set_default_x(x_o)

        # 3. Sample from final posterior
        samples = posterior.sample((1000,), x=x_o)
    """

    def __init__(
        self,
        prior: Optional[Distribution] = None,
        density_estimator: Union[
            Literal["nsf", "maf", "mdn", "made"],
            ConditionalEstimatorBuilder[ConditionalDensityEstimator],
        ] = "maf",
        device: str = "cpu",
        logging_level: Union[int, str] = "WARNING",
        summary_writer: Optional[SummaryWriter] = None,
        tracker: Optional[Tracker] = None,
        show_progress_bars: bool = True,
    ):
        r"""Initialize NPE-C.

        Args:
            prior: A probability distribution that expresses prior knowledge about
                the parameters, e.g. which ranges are meaningful for them.
            density_estimator: If it is a string, use a pre-configured network of
                the provided type (one of nsf, maf, mdn, made). Alternatively, a
                function that builds a custom neural network, which adheres to
                `ConditionalEstimatorBuilder` protocol can be provided. The
                function will be called with the first batch of simulations
                (theta, x), which can thus be used for shape inference and
                potentially for z-scoring. The density estimator needs to provide
                the methods `.log_prob` and `.sample()` and must return a
                `ConditionalDensityEstimator`.
            device: Training device, e.g., "cpu", "cuda" or "cuda:{0, 1, ...}".
            logging_level: Minimum severity of messages to log. One of the strings
                INFO, WARNING, DEBUG, ERROR and CRITICAL.
            summary_writer: Deprecated alias for the TensorBoard summary writer.
                Use ``tracker`` instead.
            tracker: Tracking adapter used to log training metrics. If None, a
                TensorBoard tracker is used with a default log directory.
            show_progress_bars: Whether to show a progressbar during training.
        """
        # Forward every constructor argument unchanged to the base trainer.
        kwargs = del_entries(locals(), entries=("self", "__class__"))
        super().__init__(**kwargs)

    def train(
        self,
        num_atoms: int = 10,
        training_batch_size: int = 200,
        learning_rate: float = 5e-4,
        validation_fraction: float = 0.1,
        stop_after_epochs: int = 20,
        max_num_epochs: int = 2**31 - 1,
        clip_max_norm: Optional[float] = 5.0,
        calibration_kernel: Optional[Callable] = None,
        resume_training: bool = False,
        force_first_round_loss: bool = False,
        discard_prior_samples: bool = False,
        use_combined_loss: bool = False,
        retrain_from_scratch: bool = False,
        show_train_summary: bool = False,
        dataloader_kwargs: Optional[Dict] = None,
    ) -> ConditionalDensityEstimator:
        r"""Return density estimator that approximates the distribution $p(\theta|x)$.

        Args:
            num_atoms: Number of atoms to use for classification.
            training_batch_size: Training batch size.
            learning_rate: Learning rate for Adam optimizer.
            validation_fraction: The fraction of data to use for validation.
            stop_after_epochs: The number of epochs to wait for improvement on the
                validation set before terminating training.
            max_num_epochs: Maximum number of epochs to run. If reached, we stop
                training even when the validation loss is still decreasing.
                Otherwise, we train until validation loss increases (see also
                ``stop_after_epochs``).
            clip_max_norm: Value at which to clip the total gradient norm in order
                to prevent exploding gradients. Use None for no clipping.
            calibration_kernel: A function to calibrate the loss with respect to
                the simulations ``x``. See Lueckmann, Gonçalves et al.,
                NeurIPS 2017.
            resume_training: Can be used in case training time is limited, e.g. on
                a cluster. If ``True``, the split between train and validation
                set, the optimizer, the number of epochs, and the best validation
                log-prob will be restored from the last time ``.train()`` was
                called.
            force_first_round_loss: If ``True``, train with maximum likelihood,
                i.e., potentially ignoring the correction for using a proposal
                distribution different from the prior.
            discard_prior_samples: Whether to discard samples simulated in round 1,
                i.e. from the prior. Training may be sped up by ignoring such less
                targeted samples.
            use_combined_loss: Whether to train the neural net also on prior
                samples using maximum likelihood in addition to training it on all
                samples using atomic loss. The extra MLE loss helps prevent
                density leaking with bounded priors.
            retrain_from_scratch: Whether to retrain the conditional density
                estimator for the posterior from scratch each round.
            show_train_summary: Whether to print the number of epochs and
                validation loss and leakage after the training.
            dataloader_kwargs: Additional or updated kwargs to be passed to the
                training and validation dataloaders (like, e.g., a collate_fn)

        Returns:
            Density estimator that approximates the distribution $p(\theta|x)$.

        Raises:
            RuntimeError: If no simulations have been appended yet.
        """
        if len(self._data_round_index) == 0:
            raise RuntimeError(
                "No simulations found. You must call .append_simulations() "
                "before calling .train()."
            )

        # WARNING: sneaky trick ahead. We proxy the parent's `train` here,
        # requiring the signature to have `num_atoms`, save it for use below, and
        # continue. It's sneaky because we are using the object (self) as a
        # namespace to pass arguments between functions, and that's implicit state
        # management.
        self._num_atoms = num_atoms
        self._use_combined_loss = use_combined_loss
        # NOTE: `locals()` must be captured before any other local variable is
        # bound, otherwise it would be forwarded to the parent's `train`.
        kwargs = del_entries(
            locals(),
            entries=("self", "__class__", "num_atoms", "use_combined_loss"),
        )

        self._round = max(self._data_round_index)

        if self._round > 0:
            # Set the proposal to the last proposal that was passed by the user.
            # For atomic SNPE, it does not matter what the proposal is. For
            # non-atomic SNPE, we only use the latest data that was passed, i.e.
            # the one from the last proposal.
            proposal = self._proposal_roundwise[-1]
            self.use_non_atomic_loss = (
                isinstance(proposal, DirectPosterior)
                and isinstance(proposal.posterior_estimator, MixtureDensityEstimator)
                and isinstance(self._neural_net, MixtureDensityEstimator)
                and check_dist_class(
                    self._prior, class_to_check=(Uniform, MultivariateNormal)
                )[0]
            )

            algorithm = "non-atomic" if self.use_non_atomic_loss else "atomic"
            print(f"Using SNPE-C with {algorithm} loss")

            if self.use_non_atomic_loss:
                # Take care of z-scoring, pre-compute and store prior terms.
                self._set_state_for_mog_proposal()

        return super().train(**kwargs)

    def _set_state_for_mog_proposal(self) -> None:
        """Set state variables that are used at each training step of non-atomic
        SNPE-C.

        Three things are computed:
        1) Check if z-scoring was requested. We check if the
           MixtureDensityEstimator has an input transform enabled via the
           has_input_transform property.
        2) Define a (potentially standardized) prior. It's standardized if
           z-scoring had been requested.
        3) Compute (Precision * mean) for the prior. This quantity is used at
           every training step if the prior is Gaussian.
        """
        # Check if z-scoring is enabled on the MixtureDensityEstimator.
        assert isinstance(self._neural_net, MixtureDensityEstimator)
        self.z_score_theta = self._neural_net.has_input_transform

        self._set_maybe_z_scored_prior()

        if isinstance(self._maybe_z_scored_prior, MultivariateNormal):
            self.prec_m_prod_prior = torch.mv(
                self._maybe_z_scored_prior.precision_matrix,  # type: ignore
                self._maybe_z_scored_prior.loc,  # type: ignore
            )

    def _set_maybe_z_scored_prior(self) -> None:
        r"""Compute and store potentially standardized prior (if z-scoring was done).

        The proposal posterior is:
        $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$

        Let's denote z-scored theta by `a`: a = (theta - mean) / std
        Then pp'(a|x) = 1/Z_2 * q'(a|x) * prop'(a) / p'(a)$

        The ' indicates that the evaluation occurs in standardized space. The
        constant scaling factor has been absorbed into Z_2. From the above
        equation, we see that we need to evaluate the prior **in standardized
        space**. We build the standardized prior in this function.

        The standardize transform that is applied to the samples theta does not
        use the exact prior mean and std (due to implementation issues). Hence,
        the z-scored prior will not have exactly mean=0 and std=1.
        """
        if not self.z_score_theta:
            self._maybe_z_scored_prior = self._prior
            return

        # The MixtureDensityEstimator standardizes as z = (theta - shift) / scale,
        # where shift (mean) and scale (std) were estimated from training samples.
        assert isinstance(self._neural_net, MixtureDensityEstimator)
        estim_prior_mean = self._neural_net._transform_shift
        estim_prior_std = self._neural_net._transform_scale

        # Compute the discrepancy between the true prior mean and the mean that
        # was empirically estimated from samples (m_e, s_e below):
        # N(theta|m,S) = N((theta-m_e)/s_e | (m-m_e)/s_e, D^{-1} S D^{-1}),
        # with D = diag(s_e).
        almost_zero_mean = (self._prior.mean - estim_prior_mean) / estim_prior_std

        if isinstance(self._prior, MultivariateNormal):
            # Exact covariance of the standardized prior: S_ij / (s_e_i * s_e_j).
            # Its diagonal is the squared "almost one" std; off-diagonal
            # correlations of the prior are preserved. Assumes the transform's
            # scale is a 1D (theta_dim,) tensor, as the previous `torch.diag`
            # construction also required.
            z_scored_cov = self._prior.covariance_matrix / (
                estim_prior_std.unsqueeze(-1) * estim_prior_std.unsqueeze(-2)
            )
            self._maybe_z_scored_prior = MultivariateNormal(
                almost_zero_mean, z_scored_cov
            )
        else:
            almost_one_std = torch.sqrt(self._prior.variance) / estim_prior_std
            # A uniform on [c - r, c + r] has std r / sqrt(3), so r = sqrt(3)*std.
            half_width = almost_one_std * 3.0**0.5
            self._maybe_z_scored_prior = BoxUniform(
                almost_zero_mean - half_width, almost_zero_mean + half_width
            )

    def _log_prob_proposal_posterior(
        self,
        theta: Tensor,
        x: Tensor,
        masks: Tensor,
        proposal: DirectPosterior,
    ) -> Tensor:
        """Return the log-probability of the proposal posterior.

        If the proposal is a MoG, the density estimator is a MoG, and the prior is
        either Gaussian or uniform, we use non-atomic loss. Else, use atomic loss
        (which suffers from leakage).

        Args:
            theta: Batch of parameters θ.
            x: Batch of data.
            masks: Mask that is True for prior samples in the batch in order to
                train them with prior loss.
            proposal: Proposal distribution.

        Returns:
            Log-probability of the proposal posterior.

        Raises:
            ValueError: If the density estimator does not support the selected
                loss variant.
        """
        if self.use_non_atomic_loss:
            if not isinstance(self._neural_net, MixtureDensityEstimator):
                raise ValueError(
                    "The density estimator must be a MixtureDensityEstimator "
                    "for non-atomic loss."
                )
            return self._log_prob_proposal_posterior_mog(theta, x, proposal)

        if not hasattr(self._neural_net, "log_prob"):
            raise ValueError(
                "The neural estimator must have a log_prob method for atomic "
                "loss. It should at best follow the sbi.neural_nets "
                "'ConditionalDensityEstimator' interface."
            )
        return self._log_prob_proposal_posterior_atomic(theta, x, masks)

    def _log_prob_proposal_posterior_atomic(
        self, theta: Tensor, x: Tensor, masks: Tensor
    ) -> Tensor:
        """Return log probability of the proposal posterior for atomic proposals.

        We have two main options when evaluating the proposal posterior.
            (1) Generate atoms from the proposal prior.
            (2) Generate atoms from a more targeted distribution, such as the most
                recent posterior.
        If we choose the latter, it is likely beneficial not to do this in the
        first round, since we would be sampling from a randomly-initialized neural
        density estimator.

        Args:
            theta: Batch of parameters θ.
            x: Batch of data.
            masks: Mask that is True for prior samples in the batch in order to
                train them with prior loss.

        Returns:
            Log-probability of the proposal posterior.

        Raises:
            ValueError: If the batch contains fewer than two parameter sets, since
                at least one contrasting atom is needed per sample.
        """
        batch_size = theta.shape[0]
        if batch_size < 2:
            # With a single sample there is nothing to contrast against, and the
            # sampling probabilities below would be 0/0 = nan.
            raise ValueError(
                "Atomic loss requires a training batch size of at least 2, "
                f"but got a batch of size {batch_size}."
            )

        num_atoms = int(
            clamp_and_warn("num_atoms", self._num_atoms, min_val=2, max_val=batch_size)
        )

        # Each set of parameter atoms is evaluated using the same x,
        # so we repeat rows of the data x, e.g. [1, 2] -> [1, 1, 2, 2]
        repeated_x = repeat_rows(x, num_atoms)

        # To generate the full set of atoms for a given item in the batch,
        # we sample without replacement num_atoms - 1 times from the rest
        # of the theta in the batch (the zero diagonal excludes the item itself).
        probs = ones(batch_size, batch_size) * (1 - eye(batch_size)) / (batch_size - 1)

        choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
        contrasting_theta = theta[choices]

        # We can now create our sets of atoms from the contrasting parameter sets
        # we have generated. The true theta is always atom 0 of each set.
        atomic_theta = torch.cat((theta[:, None, :], contrasting_theta), dim=1).reshape(
            batch_size * num_atoms, -1
        )

        # Get (batch_size * num_atoms) log prob prior evals.
        log_prob_prior = self._prior.log_prob(atomic_theta)
        log_prob_prior = log_prob_prior.reshape(batch_size, num_atoms)
        assert_all_finite(log_prob_prior, "prior eval")

        # Evaluate large batch giving (batch_size * num_atoms) log prob posterior
        # evals.
        atomic_theta = reshape_to_sample_batch_event(
            atomic_theta, atomic_theta.shape[1:]
        )
        repeated_x = reshape_to_batch_event(
            repeated_x, self._neural_net.condition_shape
        )
        log_prob_posterior = self._neural_net.log_prob(atomic_theta, repeated_x)
        assert_all_finite(log_prob_posterior, "posterior eval")
        log_prob_posterior = log_prob_posterior.reshape(batch_size, num_atoms)

        # Compute unnormalized proposal posterior.
        unnormalized_log_prob = log_prob_posterior - log_prob_prior

        # Normalize proposal posterior across discrete set of atoms.
        log_prob_proposal_posterior = unnormalized_log_prob[:, 0] - torch.logsumexp(
            unnormalized_log_prob, dim=-1
        )
        assert_all_finite(log_prob_proposal_posterior, "proposal posterior eval")

        # XXX This evaluates the posterior on _all_ samples; the mask then keeps
        # the plain (non-atomic) term only for prior samples.
        if self._use_combined_loss:
            theta = reshape_to_sample_batch_event(theta, self._neural_net.input_shape)
            x = reshape_to_batch_event(x, self._neural_net.condition_shape)
            log_prob_posterior_non_atomic = self._neural_net.log_prob(theta, x)
            # squeeze to remove sample dimension, which is always one during the
            # loss evaluation of `SNPE_C` (because we have one theta vector per x
            # vector).
            log_prob_posterior_non_atomic = log_prob_posterior_non_atomic.squeeze(
                dim=0
            )
            masks = masks.reshape(-1)
            log_prob_proposal_posterior = (
                masks * log_prob_posterior_non_atomic + log_prob_proposal_posterior
            )

        return log_prob_proposal_posterior

    def _log_prob_proposal_posterior_mog(
        self, theta: Tensor, x: Tensor, proposal: DirectPosterior
    ) -> Tensor:
        """Return log-probability of the proposal posterior for MoG proposal.

        For MoG proposals and MoG density estimators, this can be done in closed
        form and does not require atomic loss (i.e. there will be no leakage
        issues).

        Notation:

        m are mean vectors.
        prec are precision matrices.
        cov are covariance matrices.

        _p at the end indicates that it is the proposal.
        _d indicates that it is the density estimator.
        _pp indicates the proposal posterior.

        All tensors will have shapes (batch_dim, num_components, ...)

        Args:
            theta: Batch of parameters θ.
            x: Batch of data.
            proposal: Proposal distribution.

        Returns:
            Log-probability of the proposal posterior.
        """
        # Get the proposal MoG at the default_x.
        assert isinstance(proposal.posterior_estimator, MixtureDensityEstimator)
        assert proposal.default_x is not None, "Proposal must have default_x set"
        mog_p = proposal.posterior_estimator.get_uncorrected_mog(proposal.default_x)
        norm_logits_p = mog_p.log_weights  # Already normalized
        m_p = mog_p.means
        prec_p = mog_p.precisions

        # Get the density estimator MoG at the training data x.
        assert isinstance(self._neural_net, MixtureDensityEstimator)
        mog_d = self._neural_net.get_uncorrected_mog(x)
        norm_logits_d = mog_d.log_weights  # Already normalized
        m_d = mog_d.means
        prec_d = mog_d.precisions

        # z-score theta if z-scoring was requested.
        theta = self._maybe_z_score_theta(theta)

        # Compute the MoG parameters of the proposal posterior.
        (
            logits_pp,
            m_pp,
            prec_pp,
            cov_pp,
        ) = self._automatic_posterior_transformation(
            norm_logits_p, m_p, prec_p, norm_logits_d, m_d, prec_d
        )

        # Create MoG for proposal posterior and compute log_prob.
        # We need precision_factors for MoG, compute via Cholesky.
        precf_pp = torch.linalg.cholesky(prec_pp, upper=True)
        mog_pp = MoG(
            logits=logits_pp,
            means=m_pp,
            precisions=prec_pp,
            precision_factors=precf_pp,
        )

        # Compute the log_prob of theta under the product.
        log_prob_proposal_posterior = mog_pp.log_prob(theta)
        assert_all_finite(
            log_prob_proposal_posterior,
            """the evaluation of the MoG proposal posterior. This is likely due to a
            numerical instability in the training procedure. Please create an issue
            on Github.""",
        )
        return log_prob_proposal_posterior

    def _automatic_posterior_transformation(
        self,
        logits_p: Tensor,
        means_p: Tensor,
        precisions_p: Tensor,
        logits_d: Tensor,
        means_d: Tensor,
        precisions_d: Tensor,
    ):
        r"""Returns the MoG parameters of the proposal posterior.

        The proposal posterior is:
        $pp(\theta|x) = 1/Z * q(\theta|x) * prop(\theta) / p(\theta)$
        In words: proposal posterior = posterior estimate * proposal / prior.

        If the posterior estimate and the proposal are MoG and the prior is either
        Gaussian or uniform, we can solve this in closed-form. This is implemented
        in this function.

        This function implements Appendix A1 from Greenberg et al. 2019.

        We have to build L*K components. How do we do this?
        Example: proposal has two components, density estimator has three
        components. Let's call the two components of the proposal i,j and the
        three components of the density estimator x,y,z. We have to multiply every
        component of the proposal with every component of the density estimator.
        So, what we do is:
        1) for the proposal, build: i,i,i,j,j,j. Done with torch.repeat_interleave()
        2) for the density estimator, build: x,y,z,x,y,z. Done with torch.repeat()
        3) Multiply them with simple matrix operations.

        Args:
            logits_p: Component weight of each Gaussian of the proposal.
            means_p: Mean of each Gaussian of the proposal.
            precisions_p: Precision matrix of each Gaussian of the proposal.
            logits_d: Component weight for each Gaussian of the density estimator.
            means_d: Mean of each Gaussian of the density estimator.
            precisions_d: Precision matrix of each Gaussian of the density
                estimator.

        Returns:
            (Component weight, mean, precision matrix, covariance matrix) of each
            Gaussian of the proposal posterior. Has L*K terms (proposal has L
            terms, density estimator has K terms).
        """
        precisions_pp, covariances_pp = self._precisions_proposal_posterior(
            precisions_p, precisions_d
        )

        means_pp = self._means_proposal_posterior(
            covariances_pp, means_p, precisions_p, means_d, precisions_d
        )

        logits_pp = self._logits_proposal_posterior(
            means_pp,
            precisions_pp,
            covariances_pp,
            logits_p,
            means_p,
            precisions_p,
            logits_d,
            means_d,
            precisions_d,
        )

        return logits_pp, means_pp, precisions_pp, covariances_pp

    def _precisions_proposal_posterior(
        self, precisions_p: Tensor, precisions_d: Tensor
    ):
        """Return the precisions and covariances of the proposal posterior.

        Args:
            precisions_p: Precision matrices of the proposal distribution.
            precisions_d: Precision matrices of the density estimator.

        Returns:
            (Precisions, Covariances) of the proposal posterior. L*K terms.
        """
        num_comps_p = precisions_p.shape[1]
        num_comps_d = precisions_d.shape[1]

        precisions_p_rep = precisions_p.repeat_interleave(num_comps_d, dim=1)
        precisions_d_rep = precisions_d.repeat(1, num_comps_p, 1, 1)

        precisions_pp = precisions_p_rep + precisions_d_rep
        # Dividing by a Gaussian prior subtracts its precision.
        if isinstance(self._maybe_z_scored_prior, MultivariateNormal):
            precisions_pp -= self._maybe_z_scored_prior.precision_matrix

        covariances_pp = torch.inverse(precisions_pp)

        return precisions_pp, covariances_pp

    def _means_proposal_posterior(
        self,
        covariances_pp: Tensor,
        means_p: Tensor,
        precisions_p: Tensor,
        means_d: Tensor,
        precisions_d: Tensor,
    ):
        """Return the means of the proposal posterior.

        means_pp = C_ix * (P_i * m_i + P_x * m_x - P_o * m_o).

        Args:
            covariances_pp: Covariance matrices of the proposal posterior.
            means_p: Means of the proposal distribution.
            precisions_p: Precision matrices of the proposal distribution.
            means_d: Means of the density estimator.
            precisions_d: Precision matrices of the density estimator.

        Returns:
            Means of the proposal posterior. L*K terms.
        """
        num_comps_p = precisions_p.shape[1]
        num_comps_d = precisions_d.shape[1]

        # First, compute the product P_i * m_i and P_j * m_j
        prec_m_prod_p = batched_mixture_mv(precisions_p, means_p)
        prec_m_prod_d = batched_mixture_mv(precisions_d, means_d)

        # Repeat them to allow for matrix operations: same trick as for the
        # precisions.
        prec_m_prod_p_rep = prec_m_prod_p.repeat_interleave(num_comps_d, dim=1)
        prec_m_prod_d_rep = prec_m_prod_d.repeat(1, num_comps_p, 1)

        # Means = C_ij * (P_i * m_i + P_x * m_x - P_o * m_o).
        summed_cov_m_prod_rep = prec_m_prod_p_rep + prec_m_prod_d_rep
        if isinstance(self._maybe_z_scored_prior, MultivariateNormal):
            summed_cov_m_prod_rep -= self.prec_m_prod_prior

        means_pp = batched_mixture_mv(covariances_pp, summed_cov_m_prod_rep)

        return means_pp

    def _maybe_z_score_theta(self, theta: Tensor) -> Tensor:
        """Return potentially standardized theta if z-scoring was requested."""
        if self.z_score_theta:
            assert isinstance(self._neural_net, MixtureDensityEstimator)
            theta = self._neural_net._transform_input(theta)

        return theta

    @staticmethod
    def _logits_proposal_posterior(
        means_pp: Tensor,
        precisions_pp: Tensor,
        covariances_pp: Tensor,
        logits_p: Tensor,
        means_p: Tensor,
        precisions_p: Tensor,
        logits_d: Tensor,
        means_d: Tensor,
        precisions_d: Tensor,
    ):
        """Return the component weights (i.e. logits) of the proposal posterior.

        Args:
            means_pp: Means of the proposal posterior.
            precisions_pp: Precision matrices of the proposal posterior.
            covariances_pp: Covariance matrices of the proposal posterior.
            logits_p: Component weights (i.e. logits) of the proposal distribution.
            means_p: Means of the proposal distribution.
            precisions_p: Precision matrices of the proposal distribution.
            logits_d: Component weights (i.e. logits) of the density estimator.
            means_d: Means of the density estimator.
            precisions_d: Precision matrices of the density estimator.

        Returns:
            Component weights of the proposal posterior. L*K terms.
        """
        num_comps_p = precisions_p.shape[1]
        num_comps_d = precisions_d.shape[1]

        # Compute log(alpha_i * beta_j)
        logits_p_rep = logits_p.repeat_interleave(num_comps_d, dim=1)
        logits_d_rep = logits_d.repeat(1, num_comps_p)
        logit_factors = logits_p_rep + logits_d_rep

        # Compute sqrt(det()/(det()*det()))
        logdet_covariances_pp = torch.logdet(covariances_pp)
        logdet_covariances_p = -torch.logdet(precisions_p)
        logdet_covariances_d = -torch.logdet(precisions_d)

        # Repeat the proposal and density estimator terms such that there are LK
        # terms. Same trick as has been used above.
        logdet_covariances_p_rep = logdet_covariances_p.repeat_interleave(
            num_comps_d, dim=1
        )
        logdet_covariances_d_rep = logdet_covariances_d.repeat(1, num_comps_p)

        log_sqrt_det_ratio = 0.5 * (
            logdet_covariances_pp
            - (logdet_covariances_p_rep + logdet_covariances_d_rep)
        )

        # Compute for proposal, density estimator, and proposal posterior:
        # mu_i.T * P_i * mu_i
        exponent_p = batched_mixture_vmv(precisions_p, means_p)
        exponent_d = batched_mixture_vmv(precisions_d, means_d)
        exponent_pp = batched_mixture_vmv(precisions_pp, means_pp)

        # Extend proposal and density estimator exponents to get LK terms.
        exponent_p_rep = exponent_p.repeat_interleave(num_comps_d, dim=1)
        exponent_d_rep = exponent_d.repeat(1, num_comps_p)
        exponent = -0.5 * (exponent_p_rep + exponent_d_rep - exponent_pp)

        logits_pp = logit_factors + log_sqrt_det_ratio + exponent

        return logits_pp