sbi.analysis.conditional_pairplot


conditional_pairplot(density, condition, limits, points=None, subset=None, resolution=50, figsize=(10, 10), labels=None, ticks=None, fig=None, axes=None, **kwargs)

Plot conditional distribution given all other parameters.

The conditionals can be interpreted as slices through the density at a location given by condition.

For example, say we have a 3D density with parameters \(\theta_0\), \(\theta_1\), \(\theta_2\) and a condition \(c\) passed in the condition argument. For the diagonal plot of \(\theta_0\), this plots the conditional \(p(\theta_0 | \theta_1=c[1], \theta_2=c[2])\). For the upper-diagonal plot of \(\theta_1\) and \(\theta_2\), it plots \(p(\theta_1, \theta_2 | \theta_0=c[0])\). All other diagonal and upper-diagonal plots are built in the same way.
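To make the slicing concrete, here is a minimal sketch that evaluates a toy 3D density on a grid as described in the example above. It uses NumPy instead of PyTorch, and `GaussianDensity`, `conditional_1d`, and `conditional_2d` are hypothetical helpers written for illustration; they are not sbi functions, and the grid construction is an assumption about how such slices can be computed, not a copy of sbi's internals.

```python
import numpy as np


class GaussianDensity:
    """Hypothetical toy density with a log_prob() method (multivariate normal)."""

    def __init__(self, mean, cov):
        self.mean = np.asarray(mean, dtype=float)
        cov = np.asarray(cov, dtype=float)
        self.precision = np.linalg.inv(cov)
        _, logdet = np.linalg.slogdet(cov)
        self.log_norm = -0.5 * (len(self.mean) * np.log(2 * np.pi) + logdet)

    def log_prob(self, theta):
        # theta: (batch, dim) -> log-density of shape (batch,)
        d = theta - self.mean
        return self.log_norm - 0.5 * np.einsum("bi,ij,bj->b", d, self.precision, d)


def conditional_1d(density, condition, dim, limits, resolution=50):
    """Diagonal plot: p(theta_dim | all other theta fixed to condition)."""
    grid = np.linspace(limits[dim][0], limits[dim][1], resolution)
    thetas = np.repeat(condition, resolution, axis=0)  # one copy of condition per grid point
    thetas[:, dim] = grid  # vary only the plotted parameter
    log_p = density.log_prob(thetas)
    p = np.exp(log_p - log_p.max())  # shift by the max for numerical stability
    return grid, p / p.sum()  # normalize so the grid values sum to one


def conditional_2d(density, condition, dim1, dim2, limits, resolution=50):
    """Upper-diagonal plot: p(theta_dim1, theta_dim2 | other theta fixed to condition)."""
    g1 = np.linspace(limits[dim1][0], limits[dim1][1], resolution)
    g2 = np.linspace(limits[dim2][0], limits[dim2][1], resolution)
    a, b = np.meshgrid(g1, g2, indexing="ij")  # a[i, j] = g1[i], b[i, j] = g2[j]
    thetas = np.repeat(condition, resolution * resolution, axis=0)
    thetas[:, dim1] = a.ravel()
    thetas[:, dim2] = b.ravel()
    log_p = density.log_prob(thetas).reshape(resolution, resolution)
    p = np.exp(log_p - log_p.max())
    return g1, g2, p / p.sum()


# 3D example: theta_0 and theta_1 are correlated, theta_2 is independent.
density = GaussianDensity(mean=[0.0, 0.0, 0.0],
                          cov=[[1.0, 0.8, 0.0], [0.8, 1.0, 0.0], [0.0, 0.0, 1.0]])
condition = np.array([[0.5, 1.0, 0.0]])  # c, shape (1, dim_theta)
limits = [[-3, 3], [-2, 2], [-2, 2]]

# Diagonal of theta_0: p(theta_0 | theta_1 = c[1], theta_2 = c[2])
grid, p0 = conditional_1d(density, condition, dim=0, limits=limits, resolution=61)
# Upper diagonal of (theta_1, theta_2): p(theta_1, theta_2 | theta_0 = c[0])
g1, g2, p12 = conditional_2d(density, condition, 1, 2, limits, resolution=41)
```

For this correlated Gaussian, the diagonal slice of \(\theta_0\) peaks near \(0.8 \cdot c[1] = 0.8\), and the upper-diagonal slice of \((\theta_1, \theta_2)\) peaks near \((0.8 \cdot c[0], 0) = (0.4, 0)\), which are the analytic Gaussian conditional means. Changing `condition` moves these slices, which is exactly what distinguishes a conditional plot from a marginal one.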

Parameters:
  • density (Any) – Probability density with a log_prob() method.

  • condition (Tensor) – Values to which all parameters other than the one or two being plotted are fixed. The condition should have shape (1, dim_theta); for example, it could be a sample from the posterior distribution.

  • limits (List | Tensor) – Limits between which each parameter is evaluated.

  • points (List[ndarray] | List[Tensor] | ndarray | Tensor | None) – Additional points to scatter.

  • subset (List[int] | None) – List containing the dimensions to plot. E.g. subset=[1,3] will plot only the 1st and 3rd dimensions and discard the 0th and 2nd (and, if they exist, the 4th, 5th, and so on).

  • resolution (int) – Number of grid points per dimension at which the pdf is evaluated.

  • figsize (Tuple) – Size of the entire figure.

  • labels (List[str] | None) – List of strings specifying the names of the parameters.

  • ticks (List | Tensor | None) – Position of the ticks.

  • points_colors – Colors of the points.

  • fig – matplotlib figure to plot on.

  • axes – matplotlib axes corresponding to fig.

  • **kwargs – Additional arguments to adjust the plot, e.g., samples_colors, points_colors, and many more; see _get_default_opts() in the source code of sbi.analysis.plot for details.
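The interplay of `subset` and `condition` can be sketched as a list of plots, each recording which parameters it varies and which it holds fixed. `describe_panels` and `count_log_prob_evaluations` are hypothetical helpers, not part of sbi. They assume that dimensions dropped by `subset` are still fixed to their values in `condition`, and that each pairwise plot is evaluated on a `resolution` by `resolution` grid.

```python
from itertools import combinations


def describe_panels(dim_theta, subset=None):
    """Hypothetical: list (plotted_dims, fixed_dims) for every plot drawn."""
    dims = list(range(dim_theta)) if subset is None else list(subset)
    panels = []
    # Diagonal plots: one parameter varies, all others are fixed to condition.
    for i in dims:
        fixed = [k for k in range(dim_theta) if k != i]
        panels.append(((i,), fixed))
    # Upper-diagonal plots: two parameters vary, all others are fixed.
    for i, j in combinations(dims, 2):
        fixed = [k for k in range(dim_theta) if k not in (i, j)]
        panels.append(((i, j), fixed))
    return panels


def count_log_prob_evaluations(panels, resolution):
    """Hypothetical cost estimate: resolution points per 1D plot, resolution**2 per 2D plot."""
    return sum(resolution ** len(plotted) for plotted, _ in panels)


panels = describe_panels(5, subset=[1, 3])
for plotted, fixed in panels:
    print(f"vary {plotted}, fix {fixed} to condition")
print(count_log_prob_evaluations(panels, resolution=50))
```

Under these assumptions, plotting subset=[1,3] of a 5D density with resolution=50 takes 2 · 50 + 2500 = 2600 density evaluations, whereas plotting all five dimensions takes 5 · 50 + 10 · 2500 = 25250, so a small subset or a lower resolution can make plotting considerably cheaper.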

Returns: figure and axes of the conditional distribution plot