Plotting functionality

Here we look at the different options for fine-tuning the figures produced by pairplot and marginal_plot.

Let's first draw some samples from the posterior used in a previous tutorial.

import torch
from toy_posterior_for_07_cc import ExamplePosterior

from sbi.analysis import pairplot

posterior = ExamplePosterior()
posterior_samples = posterior.sample((100,))

We start with the default plot and gradually make it prettier.

_ = pairplot(
    posterior_samples,
)

Customisation

A pairplot is split into three regions: the diagonal (diag) and the upper and lower off-diagonal regions (upper and lower). For each region we can choose a plot type (e.g. hist, kde, scatter) and pass the corresponding style options via the matching keyword argument (e.g. upper_kwargs). For figure-wide styling, use fig_kwargs.
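To make the three regions concrete, the library-independent sketch below labels each panel of a dim-by-dim grid, where rows and columns both index parameter dimensions. The helpers region_of and grid_regions are hypothetical illustrations, not part of sbi.

```python
from typing import List


def region_of(row: int, col: int) -> str:
    """Return which pairplot region the panel at (row, col) belongs to."""
    if row == col:
        return "diag"   # 1D marginal of dimension `row`
    if row < col:
        return "upper"  # above the diagonal
    return "lower"      # below the diagonal


def grid_regions(dim: int) -> List[List[str]]:
    """Label every panel of a dim x dim pairplot grid with its region."""
    return [[region_of(r, c) for c in range(dim)] for r in range(dim)]


for labels in grid_regions(3):
    print(labels)
# ['diag', 'upper', 'upper']
# ['lower', 'diag', 'upper']
# ['lower', 'lower', 'diag']
```

For three parameters this gives three diagonal panels and three panels in each off-diagonal region, which is why upper_kwargs and lower_kwargs style three panels each.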

For a closer look at the available options, see the following dataclasses:

  • FigOptions dataclass for figure stylisation.

  • ContourOffDiagOptions, HistOffDiagOptions, KdeOffDiagOptions, PlotOffDiagOptions, ScatterOffDiagOptions dataclasses for styling the upper and lower off-diagonal regions.

  • HistDiagOptions, KdeDiagOptions, ScatterDiagOptions for styling the diagonal region.

You can find the dataclasses in sbi/analysis/plotting_classes.py.

As illustrated below, any matplotlib keyword (such as cmap for images) can be passed directly via the mpl_kwargs entry of upper_kwargs, lower_kwargs, or diag_kwargs.
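One plausible way an options dataclass with an mpl_kwargs field can be combined with built-in defaults is sketched below. HistDiagOptionsSketch, DEFAULT_HIST_MPL_KWARGS and resolve_mpl_kwargs are hypothetical names, and the default values are made up for illustration; sbi's actual classes and defaults may behave differently.

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Dict


@dataclass
class HistDiagOptionsSketch:
    """Hypothetical stand-in for an options dataclass (not sbi's HistDiagOptions)."""

    mpl_kwargs: Dict[str, Any] = field(default_factory=dict)


# Hypothetical defaults, made up for this illustration.
DEFAULT_HIST_MPL_KWARGS: Dict[str, Any] = {"bins": 20, "histtype": "step", "density": False}


def resolve_mpl_kwargs(options: HistDiagOptionsSketch) -> Dict[str, Any]:
    """Overlay the user's matplotlib keywords on the defaults; user values win."""
    merged = dict(DEFAULT_HIST_MPL_KWARGS)
    merged.update(asdict(options)["mpl_kwargs"])
    return merged


opts = HistDiagOptionsSketch(mpl_kwargs={"bins": 10, "histtype": "bar", "alpha": 0.6})
print(resolve_mpl_kwargs(opts))
# {'bins': 10, 'histtype': 'bar', 'density': False, 'alpha': 0.6}
```

Under this assumption, the merged dictionary is what would finally be unpacked into a matplotlib call such as ax.hist(samples, **merged), which is why any keyword accepted by that call can travel through mpl_kwargs.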

Migration Note

Previously, you would pass nested dictionaries to the diag_kwargs, upper_kwargs, lower_kwargs, and fig_kwargs arguments. This is still supported for backward compatibility, but we recommend the dataclasses listed above for clarity and autocompletion.
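Both styles carry the same information. The sketch below, built on a hypothetical ScatterOptionsSketch class rather than sbi's own, shows how a legacy nested dictionary maps onto a dataclass and how the dataclass form rejects a misspelled key at construction time.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Union


@dataclass
class ScatterOptionsSketch:
    """Hypothetical options dataclass with a single mpl_kwargs field (not sbi's class)."""

    mpl_kwargs: Dict[str, Any] = field(default_factory=dict)


def to_options(kwargs: Union[Dict[str, Any], ScatterOptionsSketch]) -> ScatterOptionsSketch:
    """Accept the legacy nested dict or a dataclass instance and return a dataclass."""
    if isinstance(kwargs, ScatterOptionsSketch):
        return kwargs
    # Legacy style {"mpl_kwargs": {...}}: unpack the keys as constructor arguments.
    # A misspelled key raises TypeError here.
    return ScatterOptionsSketch(**kwargs)


legacy = {"mpl_kwargs": {"color": "tab:blue", "s": 20}}
modern = ScatterOptionsSketch(mpl_kwargs={"color": "tab:blue", "s": 20})
print(to_options(legacy) == modern)  # True

try:
    to_options({"mpl_kwrgs": {"s": 20}})  # typo in the key
except TypeError as err:
    print("rejected:", err)
```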

Let's now make a scatter plot for the upper off-diagonal region and a histogram for the diagonal, passing the respective dataclass for each.

from sbi.analysis.plotting_classes import HistDiagOptions, ScatterOffDiagOptions

_ = pairplot(
    posterior_samples,
    limits=[[-3, 3]] * 3,
    figsize=(5, 5),
    diag="hist",
    upper="scatter",
    diag_kwargs=HistDiagOptions(
        mpl_kwargs={
            "color": 'tab:blue',
            "histtype": "bar",
            "bins": 10,
            "edgecolor": 'white',
            "linewidth": 1,
            "alpha": 0.6,
            "fill": True,
        }
    ),
    upper_kwargs=ScatterOffDiagOptions(mpl_kwargs={"color": 'tab:blue', "s": 20, "alpha": 0.8}),
    labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
)
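The limits argument is meant to hold one [low, high] pair per dimension. With Python list multiplication the bracket placement decides what gets repeated, as this plain-Python snippet shows.

```python
# List multiplication repeats the elements of the list it is applied to,
# so the position of the brackets decides what gets repeated.
flat = [[-3, 3] * 3]    # inner list repeated: a single list of six numbers
nested = [[-3, 3]] * 3  # pair repeated: one [low, high] pair per dimension

print(flat)    # [[-3, 3, -3, 3, -3, 3]]
print(nested)  # [[-3, 3], [-3, 3], [-3, 3]]

# `nested` holds three references to the same inner list, which is harmless for
# read-only limits; build independent pairs if you plan to edit them in place.
print(nested[0] is nested[1])  # True
independent = [[-3, 3] for _ in range(3)]
print(independent[0] is independent[1])  # False
```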

Compare two sets of samples

By passing a list of sample sets, we can plot two sets of samples on top of each other.

# draw two independent sets of samples to plot
posterior_samples1 = posterior.sample((20,))
posterior_samples2 = posterior.sample((20,))

_ = pairplot(
    [posterior_samples1, posterior_samples2],
    limits=[[-3, 3]] * 3,
    figsize=(5, 5),
    diag=["hist", "hist"],
    upper=["scatter", "scatter"],
    diag_kwargs=HistDiagOptions(
        mpl_kwargs={
            "bins": 10,
            "edgecolor": 'white',
            "linewidth": 1,
            "alpha": 0.6,
            "histtype": "bar",
            "fill": True,
        }
    ),
    upper_kwargs=ScatterOffDiagOptions(mpl_kwargs={"s": 50, "alpha": 0.8}),
    labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
)
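In the call above, diag and upper get one entry per sample set, while a single HistDiagOptions is shared by both sets. The hypothetical helper per_sample_set below sketches that "one value for all, or one value per set" pattern in plain Python; it illustrates the idea and is not sbi's internal code.

```python
from typing import List, Sequence, TypeVar, Union

T = TypeVar("T")


def per_sample_set(value: Union[T, Sequence[T]], n_sets: int) -> List[T]:
    """Hypothetical helper: repeat a single setting for every sample set,
    or check that a list/tuple already has exactly one entry per set."""
    if isinstance(value, (list, tuple)):
        if len(value) != n_sets:
            raise ValueError(f"expected {n_sets} entries, got {len(value)}")
        return list(value)
    # Strings, dicts and option objects count as a single shared setting.
    return [value] * n_sets


print(per_sample_set("hist", 2))                  # ['hist', 'hist']
print(per_sample_set(["scatter", "contour"], 2))  # ['scatter', 'contour']
```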

Multi-layered plots

We can use the same mechanism to build a multi-layered plot from a single set of samples, e.g. density contours on top of a scatter plot.

from sbi.analysis.plotting_classes import FigOptions

_ = pairplot(
    [posterior_samples, posterior_samples],
    limits=[[-3, 3]] * 3,
    figsize=(5, 5),
    diag=["hist", None],
    upper=["scatter", "contour"],
    diag_kwargs=HistDiagOptions(
        mpl_kwargs={
            "bins": 10,
            "color": 'tab:blue',
            "edgecolor": 'white',
            "linewidth": 1,
            "alpha": 0.6,
            "histtype": "bar",
            "fill": True,
        },
    ),
    upper_kwargs=[
        ScatterOffDiagOptions(
            mpl_kwargs={"color": 'tab:blue', "s": 20, "alpha": 0.8},
        ),
        ScatterOffDiagOptions(mpl_kwargs={"cmap": 'Blues_r', "alpha": 0.8, "colors": None}),
    ],
    labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
    fig_kwargs=FigOptions(despine=dict(offset=0)),
)
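With two copies of the same samples, each region receives one plot type per layer, and None leaves that layer empty (as in diag=["hist", None]). The hypothetical draw_order helper below lists which layers would be drawn in a region. It assumes layers are drawn in list order so that later ones appear on top, which matches matplotlib's default for artists with equal zorder.

```python
from typing import List, Optional


def draw_order(layers: List[Optional[str]]) -> List[str]:
    """Hypothetical helper: the plot types drawn in one region, in drawing order.

    Assumes layers are drawn in list order (later layers on top) and that
    None skips the corresponding sample set in that region.
    """
    return [f"layer {i}: {kind}" for i, kind in enumerate(layers) if kind is not None]


print(draw_order(["scatter", "contour"]))  # ['layer 0: scatter', 'layer 1: contour']
print(draw_order(["hist", None]))          # ['layer 0: hist']
```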

Lower off-diagonal

We can fill the lower off-diagonal region as well, here with a kernel-density estimate (kde).

from sbi.analysis.plotting_classes import KdeOffDiagOptions

_ = pairplot(
    [posterior_samples, posterior_samples],
    limits=[[-3, 3]] * 3,
    figsize=(5, 5),
    diag=["hist", None],
    upper=["scatter", "contour"],
    lower=["kde", None],
    diag_kwargs=HistDiagOptions(
        mpl_kwargs={
            "bins": 10,
            "color": 'tab:blue',
            "edgecolor": 'white',
            "linewidth": 1,
            "alpha": 0.6,
            "histtype": "bar",
            "fill": True,
        }
    ),
    upper_kwargs=[
        ScatterOffDiagOptions(mpl_kwargs={"color": 'tab:blue', "s": 20, "alpha": 0.8}),
        ScatterOffDiagOptions(mpl_kwargs={"cmap": 'Blues_r', "alpha": 0.8, "colors": None}),
    ],
    lower_kwargs=KdeOffDiagOptions(mpl_kwargs={"cmap": "Blues_r"}),
    labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
)
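A kernel-density estimate smooths samples into a continuous density by placing a kernel on every sample. sbi's own KDE code is not shown in this tutorial. The stdlib sketch below illustrates the technique with a 2D Gaussian product kernel and per-dimension bandwidths from Scott's rule. This is a simplification of the full-covariance kernel used by libraries such as scipy.stats.gaussian_kde.

```python
import math
import random
import statistics
from typing import Sequence, Tuple

Point = Tuple[float, float]


def scott_bandwidths(samples: Sequence[Point]) -> Tuple[float, float]:
    """Per-dimension bandwidths from Scott's rule, h_j = sigma_j * n ** (-1 / (d + 4)) with d = 2."""
    factor = len(samples) ** (-1.0 / (2 + 4))
    sigma_x = statistics.stdev([s[0] for s in samples])
    sigma_y = statistics.stdev([s[1] for s in samples])
    return sigma_x * factor, sigma_y * factor


def kde2d(point: Point, samples: Sequence[Point], bandwidths: Tuple[float, float]) -> float:
    """Evaluate a 2D Gaussian KDE with a diagonal (product) kernel at `point`."""
    hx, hy = bandwidths
    norm = 1.0 / (2.0 * math.pi * hx * hy * len(samples))
    total = 0.0
    for sx, sy in samples:
        zx = (point[0] - sx) / hx
        zy = (point[1] - sy) / hy
        total += math.exp(-0.5 * (zx * zx + zy * zy))
    return norm * total


# Usage: 200 draws from a standard 2D normal (fixed seed for reproducibility).
random.seed(0)
samples = [(random.gauss(0.0, 1.0), random.gauss(0.0, 1.0)) for _ in range(200)]
bw = scott_bandwidths(samples)
print(kde2d((0.0, 0.0), samples, bw) > kde2d((2.5, 2.5), samples, bw))  # True
```

Evaluating kde2d on a regular grid yields the surface that a density or contour panel would shade.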

Adding observed data

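We can also add individual points to the plot, e.g. the parameters \(\theta_o\) associated with our observed data. Since the plot lives in parameter space, each point needs one entry per parameter dimension.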

# fake observed data:

theta_o = torch.ones(1, 3)

_ = pairplot(
    [posterior_samples, posterior_samples],
    limits=[[-3, 3]] * 3,
    figsize=(5, 5),
    diag=["hist", None],
    upper=["scatter", "contour"],
    diag_kwargs=HistDiagOptions(
        mpl_kwargs={
            "bins": 10,
            "color": 'tab:blue',
            "edgecolor": 'white',
            "linewidth": 1,
            "alpha": 0.6,
            "histtype": "bar",
            "fill": True,
        }
    ),
    upper_kwargs=[
        ScatterOffDiagOptions(mpl_kwargs={"color": 'tab:blue', "s": 20, "alpha": 0.8}),
        ScatterOffDiagOptions(mpl_kwargs={"cmap": 'Blues_r', "alpha": 0.8, "colors": None}),
    ],
    labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
    points=theta_o,
    fig_kwargs=FigOptions(
        points_labels=[r"$\theta_o$"],
        legend=True,
        points_colors=["purple"],
        points_offdiag=dict(marker="+", markersize=20),
        despine=dict(offset=0),
    ),
)
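Each point is a full parameter vector, so it appears in every panel. The hypothetical helper point_positions below sketches where a single vector would land. It assumes dimension j is on the x-axis and dimension i on the y-axis of off-diagonal panel (i, j). sbi's actual axis convention should be read off the rendered figure.

```python
from typing import Dict, Sequence, Tuple


def point_positions(theta: Sequence[float], dim: int) -> Dict[Tuple[int, int], Tuple[float, ...]]:
    """Hypothetical helper: the location of one parameter vector in every panel."""
    positions: Dict[Tuple[int, int], Tuple[float, ...]] = {}
    for i in range(dim):
        for j in range(dim):
            if i == j:
                # Diagonal panel: a 1D location, e.g. a vertical line at theta[i].
                positions[(i, j)] = (theta[i],)
            else:
                # Off-diagonal panel: (x, y) = (theta[j], theta[i]) under the assumed convention.
                positions[(i, j)] = (theta[j], theta[i])
    return positions


theta = [0.5, 1.0, -2.0]  # hypothetical parameter vector with three entries
pos = point_positions(theta, dim=3)
print(pos[(0, 2)])  # (-2.0, 0.5)
print(pos[(1, 1)])  # (1.0,)
```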

Subsetting the plot

For high-dimensional posteriors, we might only want to visualise a subset of the marginals. To do so, pass a list with the (zero-based) indices of the dimensions to plot to the subset argument of pairplot.

_ = pairplot(
    [posterior_samples, posterior_samples],
    limits=[[-3, 3]] * 3,
    figsize=(5, 5),
    subset=[0, 2],
    diag=["hist", None],
    upper=["scatter", "contour"],
    diag_kwargs=HistDiagOptions(
        mpl_kwargs={
            "bins": 10,
            "color": 'tab:blue',
            "edgecolor": 'white',
            "linewidth": 1,
            "alpha": 0.6,
            "histtype": "bar",
            "fill": True,
        }
    ),
    upper_kwargs=[
        ScatterOffDiagOptions(mpl_kwargs={"color": 'tab:blue', "s": 20, "alpha": 0.8}),
        ScatterOffDiagOptions(mpl_kwargs={"cmap": 'Blues_r', "alpha": 0.8, "colors": None}),
    ],
    labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
    points=theta_o,
    fig_kwargs=FigOptions(
        points_labels=[r"$\theta_o$"],
        legend=True,
        points_colors=["purple"],
        points_offdiag=dict(marker="+", markersize=20),
        despine=dict(offset=0),
    ),
)
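With subset=[0, 2], only dimensions 0 and 2 (\(\theta_1\) and \(\theta_3\)) are kept, so the grid shrinks from 3-by-3 to 2-by-2. Note that the call above still passes the full samples, labels and limits. The plain-Python sketch below, using a hypothetical take_subset helper, mirrors that selection.

```python
from typing import List, Sequence, Tuple


def take_subset(
    samples: Sequence[Sequence[float]], labels: Sequence[str], subset: Sequence[int]
) -> Tuple[List[List[float]], List[str]]:
    """Keep only the dimensions listed in `subset` (zero-based), in the given order."""
    sub_samples = [[row[i] for i in subset] for row in samples]
    sub_labels = [labels[i] for i in subset]
    return sub_samples, sub_labels


samples = [[0.1, 0.2, 0.3], [1.1, 1.2, 1.3]]  # two made-up samples of a 3D parameter
labels = ["theta_1", "theta_2", "theta_3"]
sub_samples, sub_labels = take_subset(samples, labels, subset=[0, 2])
print(sub_samples)  # [[0.1, 0.3], [1.1, 1.3]]
print(sub_labels)   # ['theta_1', 'theta_3']
print(f"{len(sub_labels)}x{len(sub_labels)} grid")  # 2x2 grid
```

For a torch tensor of shape (num_samples, dim), the equivalent column selection would be posterior_samples[:, [0, 2]].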

Plot just the marginals

1D marginals can also be visualised with the marginal_plot function.

from sbi.analysis import marginal_plot

# plot posterior samples
_ = marginal_plot(
    [posterior_samples, posterior_samples],
    limits=[[-3, 3]] * 3,
    subset=[0, 1],
    diag=["hist", None],
    diag_kwargs=HistDiagOptions(
        mpl_kwargs={
            "bins": 10,
            "color": 'tab:blue',
            "edgecolor": 'white',
            "linewidth": 1,
            "alpha": 0.6,
            "histtype": "bar",
            "fill": True,
        }
    ),
    labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
    points=[torch.ones(1, 3)],
    figsize=(4, 2),
    fig_kwargs=FigOptions(
        points_labels=[r"$\theta_o$"],
        legend=True,
        points_colors=["purple"],
        points_offdiag=dict(marker="+", markersize=20),
        despine=dict(offset=0),
    ),
)
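With diag="hist" and "bins": 10, each marginal is an equal-width histogram. The stdlib sketch below implements the binning rule that matplotlib documents for hist with an integer bins and a range: values outside the range are ignored, and only the last bin includes its right edge. Whether sbi passes the plot limits as that range is not shown here, so treat the low and high arguments as an assumption.

```python
from typing import List, Sequence


def hist_counts(values: Sequence[float], bins: int, low: float, high: float) -> List[int]:
    """Equal-width histogram over [low, high]; values outside are ignored,
    and the last bin includes its right edge."""
    width = (high - low) / bins
    counts = [0] * bins
    for v in values:
        if v < low or v > high:
            continue  # outliers are dropped, as with matplotlib's `range`
        idx = int((v - low) / width)
        counts[min(idx, bins - 1)] += 1  # v == high falls into the last bin
    return counts


values = [-3.0, -2.5, 0.1, 0.2, 2.9, 3.0, 4.0]  # made-up 1D samples
print(hist_counts(values, bins=10, low=-3.0, high=3.0))
# [2, 0, 0, 0, 0, 2, 0, 0, 0, 2]
```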