# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.distributions import Distribution
from zuko.distributions import NormalizingFlow
from sbi.inference.potentials.base_potential import BasePotential
from sbi.inference.potentials.vector_field_adaptor import (
get_guidance_method,
get_iid_method,
)
from sbi.neural_nets.estimators import ConditionalVectorFieldEstimator
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
reshape_to_sample_batch_event,
)
from sbi.samplers.ode_solvers import build_neural_ode
from sbi.sbi_types import TorchTransform
from sbi.utils.sbiutils import mcmc_transform, within_support
from sbi.utils.torchutils import ensure_theta_batched
class VectorFieldBasedPotential(BasePotential):
    """Posterior potential built on a conditional vector field estimator.

    Calling an instance evaluates log-probabilities through the neural ODE
    flow(s) built for the current `x_o`; `gradient` evaluates the estimator's
    score, optionally wrapped by a guidance method or an iid method.
    """

    def __init__(
        self,
        vector_field_estimator: ConditionalVectorFieldEstimator,
        prior: Optional[Distribution],  # type: ignore
        x_o: Optional[Tensor] = None,
        device: Union[str, torch.device] = "cpu",
        iid_method: Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"] = "auto_gauss",
        iid_params: Optional[Dict[str, Any]] = None,
        neural_ode_backend: Literal["zuko"] = "zuko",
        neural_ode_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Potential class for vector field estimators. Implements the potential function
        via the probability flow ODE and the gradient via the score estimator. If
        the vector field estimator does not define the score (SCORE_DEFINED = False),
        the gradient is not available and an error is raised.

        For iid observations, the potential sums the log-probabilities of one flow
        per observation and subtracts `(num_iid - 1)` times the prior log-probability
        (see `__call__`).

        Args:
            vector_field_estimator: The neural network modelling the vector field.
                It is switched to eval mode here.
            prior: The prior distribution.
            x_o: The observed data at which to evaluate the posterior.
            device: The device on which to evaluate the potential.
            iid_method: Which method to use for computing the score in the iid setting.
                We currently support "fnpe", "gauss", "auto_gauss", "jac_gauss".
            iid_params: Parameters for the iid method, for arguments see
                `IIDScoreFunction`.
            neural_ode_backend: The backend to use for the neural ODE. Currently,
                only "zuko" is supported.
            neural_ode_kwargs: Additional keyword arguments for the neural ODE.
        """
        self.vector_field_estimator = vector_field_estimator
        self.vector_field_estimator.eval()
        self.iid_method = iid_method
        self.iid_params = iid_params

        # Give every attribute that `set_x` assigns a defined default, so that
        # `__call__` and `gradient` never depend on whether `set_x` has run yet.
        # These are set *before* the base-class constructor on purpose: if that
        # constructor calls `set_x`, its results must not be overwritten here.
        self.guidance_method = None
        self.guidance_params = None
        self.flow = None
        self.flows = None

        neural_ode_kwargs = neural_ode_kwargs or {}
        self.neural_ode = build_neural_ode(
            self.vector_field_estimator.ode_fn,
            self.vector_field_estimator.net,
            self.vector_field_estimator.mean_base,
            self.vector_field_estimator.std_base,
            self.vector_field_estimator.condition_shape,
            backend=neural_ode_backend,
            t_min=self.vector_field_estimator.t_min,
            t_max=self.vector_field_estimator.t_max,
            **neural_ode_kwargs,
        )
        super().__init__(prior, x_o, device=device)

    def to(self, device: Union[str, torch.device]) -> None:
        """
        Moves the vector field estimator, prior and x_o to the given device.

        It also sets the device attribute to the given device. Flows that were
        already built are not rebuilt here.

        Args:
            device: Device to move the vector field estimator, prior and x_o to.
        """
        self.device = device
        self.vector_field_estimator.to(device)
        # Explicit None check: a prior object must not be skipped merely because
        # it happens to evaluate as falsy.
        if self.prior is not None:
            self.prior.to(device)  # type: ignore
        if self._x_o is not None:
            self._x_o = self._x_o.to(device)

    def set_x(
        self,
        x_o: Optional[Tensor],
        x_is_iid: Optional[bool] = False,
        iid_method: Optional[str] = None,
        iid_params: Optional[Dict[str, Any]] = None,
        guidance_method: Optional[str] = None,
        guidance_params: Optional[Dict[str, Any]] = None,
        **ode_kwargs,
    ):
        """
        Set the observed data and whether it is IID.

        Rebuilds the continuous normalizing flow (or one flow per iid observation)
        if the observed data is set.

        Args:
            x_o: The observed data.
            x_is_iid: Whether the observed data is IID (if batch_dim>1).
            iid_method: Which method to use for computing the score in the iid setting.
                We currently support "fnpe", "gauss", "auto_gauss", "jac_gauss".
                If None, the previously stored method is kept.
            iid_params: Parameters for the iid method, for arguments see
                `IIDScoreFunction`. Replaces the stored value on every call,
                also when None.
            guidance_method: Name of the guidance method used by `gradient`.
                While it is set, `__call__` raises `NotImplementedError`.
                Replaces the stored value on every call, also when None.
            guidance_params: Keyword arguments for the guidance method's config.
                Replaces the stored value on every call, also when None.
            ode_kwargs: Additional keyword arguments for the neural ODE.
        """
        super().set_x(x_o, x_is_iid)
        self.iid_method = iid_method or self.iid_method
        self.iid_params = iid_params
        self.guidance_method = guidance_method
        self.guidance_params = guidance_params

        if self._x_o is None:
            # Nothing to condition on, so no flow can be built.
            return
        if x_is_iid:
            self.flows = self.rebuild_flows_for_batch(**ode_kwargs)
        else:
            self.flow = self.rebuild_flow(**ode_kwargs)

    def __call__(
        self,
        theta: Tensor,
        track_gradients: bool = False,
    ) -> Tensor:
        """
        Return the potential (posterior log prob) via probability flow ODE.

        Args:
            theta: The parameters at which to evaluate the potential.
            track_gradients: Whether to track gradients. Default is False.

        Returns:
            The potential function, i.e., the log probability of the posterior.
            Entries outside the prior support are set to `-inf`.

        Raises:
            NotImplementedError: If a guidance method is set.
            ValueError: If no flow has been built yet in the non-iid setting.
        """
        if self.guidance_method is not None:
            raise NotImplementedError(
                "Potential evaluation for guidance is not supported yet."
            )

        theta = ensure_theta_batched(torch.as_tensor(theta))
        theta_density_estimator = reshape_to_sample_batch_event(
            theta, theta.shape[1:], leading_is_sample=True
        )
        self.vector_field_estimator.eval()

        with torch.set_grad_enabled(track_gradients):
            if self.x_is_iid:
                assert self.prior is not None, (
                    "Prior is required for evaluating log_prob with iid observations."
                )
                assert self.flows is not None, (
                    "Flows for each iid x are required for evaluating log_prob."
                )
                num_iid = self.x_o.shape[0]  # number of iid samples
                iid_posteriors_prob = torch.sum(
                    torch.stack(
                        [
                            flow.log_prob(theta_density_estimator).squeeze(-1)
                            for flow in self.flows
                        ],
                        dim=0,
                    ),
                    dim=0,
                )
                # Apply the adjustment for iid observations i.e. we have to subtract
                # (num_iid-1) times the log prior.
                log_probs = iid_posteriors_prob - (num_iid - 1) * self.prior.log_prob(
                    theta_density_estimator
                ).squeeze(-1)
            else:
                if self.flow is None:
                    raise ValueError(
                        "No observed data x_o is available. Please reinitialize "
                        "the potential or manually set self._x_o."
                    )
                log_probs = self.flow.log_prob(theta_density_estimator).squeeze(-1)

            # Force probability to be zero outside prior support.
            in_prior_support = within_support(self.prior, theta)
            masked_log_prob = torch.where(
                in_prior_support,
                log_probs,
                torch.tensor(float("-inf"), dtype=torch.float32, device=self.device),
            )
        return masked_log_prob

    def gradient(
        self,
        theta: Tensor,
        time: Optional[Tensor] = None,
        track_gradients: bool = False,
    ) -> Tensor:
        r"""Returns the potential function gradient for score-based methods.

        Args:
            theta: The parameters at which to evaluate the potential gradient.
            time: The diffusion time. If None, then `t_min` of the
                self.vector_field_estimator is used
                (i.e. we evaluate the gradient of the actual data distribution).
                It is moved to the device of `theta` before use.
            track_gradients: Whether to track gradients. Default is False.

        Returns:
            The gradient of the potential function.

        Raises:
            ValueError: If the score is not defined for this vector field estimator,
                or if no observed data is available.
        """
        if not self.vector_field_estimator.SCORE_DEFINED:
            raise ValueError(
                "Gradient is not available since the score "
                "is not defined for this vector field estimator."
            )

        device = theta.device
        # Put `time` on theta's device once, so the single-observation and the
        # iid branch below receive it on the same device.
        if time is None:
            time = torch.tensor([self.vector_field_estimator.t_min], device=device)
        else:
            time = time.to(device)

        if self._x_o is None:
            raise ValueError(
                "No observed data x_o is available. Please reinitialize "
                "the potential or manually set self._x_o."
            )

        if self.guidance_method is not None:
            score_wrapper, config_cls = get_guidance_method(self.guidance_method)
            config_params = config_cls(**(self.guidance_params or {}))
            # Note to make this cross compatible with IID we need make this
            # wrapper more like a proper estimator.
            vf_estimator = score_wrapper(
                self.vector_field_estimator,
                self.prior,
                config=config_params,
                device=device,
            )
        else:
            vf_estimator = self.vector_field_estimator

        with torch.set_grad_enabled(track_gradients):
            if not self.x_is_iid or self._x_o.shape[0] == 1:
                score = vf_estimator.score(input=theta, condition=self.x_o, t=time)
            else:
                assert self.prior is not None, "Prior is required for iid methods."
                iid_method = get_iid_method(self.iid_method)
                score_fn_iid = iid_method(
                    vf_estimator,
                    self.prior,
                    device=device,
                    **(self.iid_params or {}),
                )
                score = score_fn_iid(theta, self.x_o, time)
        return score

    def rebuild_flow(self, **kwargs) -> NormalizingFlow:
        """
        Rebuilds the continuous normalizing flow. This is used when
        a new default x is set, or to evaluate the log probs at higher precision.

        Args:
            **kwargs: Forwarded to the neural ODE when building the flow.

        Returns:
            The flow conditioned on the current `x_o`.

        Raises:
            ValueError: If no observed data is available.
        """
        if self._x_o is None:
            raise ValueError(
                "No observed data x_o is available. Please reinitialize "
                "the potential or manually set self._x_o."
            )
        x_density_estimator = reshape_to_batch_event(
            self.x_o, event_shape=self.vector_field_estimator.condition_shape
        )
        return self.neural_ode(x_density_estimator, **kwargs)

    def rebuild_flows_for_batch(self, **kwargs) -> List[NormalizingFlow]:
        """
        Rebuilds the continuous normalizing flows for each iid in x_o. This is used when
        a new default x_o is set, or to evaluate the log probs at higher precision.

        Args:
            **kwargs: Forwarded to the neural ODE when building each flow.

        Returns:
            One flow per entry along the first dimension of `x_o`.

        Raises:
            ValueError: If no observed data is available.
        """
        if self._x_o is None:
            raise ValueError(
                "No observed data x_o is available. Please reinitialize "
                "the potential or manually set self._x_o."
            )
        condition_shape = self.vector_field_estimator.condition_shape
        # Iterating the tensor yields one observation per step (first dimension).
        return [
            self.neural_ode(
                reshape_to_batch_event(iid_x, event_shape=condition_shape), **kwargs
            )
            for iid_x in self._x_o
        ]
# Factory function that builds the potential and its parameter transform.
def vector_field_estimator_based_potential(
    vector_field_estimator: ConditionalVectorFieldEstimator,
    prior: Optional[Distribution],
    x_o: Optional[Tensor],
    enable_transform: bool = True,
    **kwargs,
) -> Tuple[VectorFieldBasedPotential, TorchTransform]:
    r"""Build a vector-field-based potential together with a theta transform.

    The device is taken from the first parameter of the estimator and is used
    both for the potential and for the transform.

    Args:
        vector_field_estimator: The neural network modelling the vector field.
        prior: The prior distribution. If None, the identity transform is
            returned instead of a prior-derived one.
        x_o: The observed data at which to evaluate the vector field.
        enable_transform: Forwarded to `mcmc_transform` when a prior is given.
        **kwargs: Additional keyword arguments passed to
            `VectorFieldBasedPotential`.

    Returns:
        The potential function and a transformation that maps
        to unconstrained space.
    """
    first_param = next(vector_field_estimator.parameters())
    estimator_device = str(first_param.device)

    potential = VectorFieldBasedPotential(
        vector_field_estimator, prior, x_o, device=estimator_device, **kwargs
    )

    # Without a prior there is no support to map from, so fall back to identity.
    if prior is None:
        return potential, torch.distributions.transforms.identity_transform

    transform = mcmc_transform(
        prior, device=estimator_device, enable_transform=enable_transform
    )
    return potential, transform
class DifferentiablePotentialFunction(torch.autograd.Function):
    """
    Custom autograd function pairing a potential with an externally supplied
    gradient.

    The forward pass evaluates the given potential callable; the backward pass
    does not backpropagate through that computation but multiplies the incoming
    gradient with the output of the given gradient callable (e.g. the score
    estimator of a `VectorFieldBasedPotential`).
    """

    @staticmethod
    def forward(ctx, input, call_function, gradient_function):
        """
        Evaluate `call_function(input)` and remember what backward needs.

        Args:
            ctx: Autograd context.
            input: Tensor at which the potential is evaluated.
            call_function: Callable computing the potential from `input`.
            gradient_function: Callable computing the gradient from `input`.

        Returns:
            The output of `call_function(input)`.
        """
        # Callables are not tensors, so they are kept as plain ctx attributes.
        ctx.save_for_backward(input)
        ctx.gradient_function = gradient_function
        ctx.call_function = call_function
        return call_function(input)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Chain the upstream gradient with `gradient_function(input)`.

        Returns:
            The gradient w.r.t. `input`, and None for the two callables.
        """
        (theta,) = ctx.saved_tensors
        score = ctx.gradient_function(theta)

        # Append trailing singleton dims so grad_output broadcasts against score.
        missing_dims = len(score.shape) - len(grad_output.shape)
        for _ in range(missing_dims):
            grad_output = grad_output.unsqueeze(-1)

        return grad_output * score, None, None
class CallableDifferentiablePotentialFunction:
    """
    Callable adapter around a `VectorFieldBasedPotential`.

    `torch.autograd.Function` only supports static methods, so the potential
    object cannot be handed to `DifferentiablePotentialFunction` directly. This
    class holds the potential and, when called, passes its `__call__` and
    `gradient` bound methods to `DifferentiablePotentialFunction.apply`.
    """

    def __init__(self, vector_field_based_potential: VectorFieldBasedPotential):
        """Store the potential whose methods are used on each call."""
        self.vector_field_based_potential = vector_field_based_potential

    def __call__(self, input):
        """Evaluate the potential at `input` through the custom autograd function."""
        potential = self.vector_field_based_potential
        log_prob_fn = potential.__call__
        score_fn = potential.gradient
        return DifferentiablePotentialFunction.apply(input, log_prob_fn, score_fn)