VIPosteriorParameters
- class VIPosteriorParameters(q='maf', vi_method='rKL', num_transforms=5, hidden_features=50, z_score_theta='independent', z_score_x='independent')
Bases: PosteriorParameters
Parameters for VIPosterior, supporting both single-x and amortized VI.
- Fields:
- q: Variational distribution. Either a string specifying the flow type [maf, nsf, naf, unaf, nice, sospf, gf, gaussian, gaussian_diag], a Distribution, a VIPosterior object, or a Callable builder function. For amortized VI, only string flow types are supported. If q is already a VIPosterior, its arguments are copied from it (relevant for multi-round training). Note: for 1D problems, prefer “gf” (mixture of Gaussians) or “gaussian”, because autoregressive flows may be unstable.
- vi_method: Variational method for fitting q to the posterior. Options: [rKL, fKL, IW, alpha]. Some are “mode seeking” (rKL, alpha > 1) and others are “mass covering” (fKL, IW, alpha < 1). Currently used only for single-x VI; amortized VI uses the ELBO (rKL).
- num_transforms: Number of transforms in the normalizing flow. Used for both single-x VI (via set_q/train) and amortized VI (via train_amortized).
- hidden_features: Hidden layer size in the flow networks. Used for both single-x VI and amortized VI.
- z_score_theta: Method for z-scoring θ (the parameters being sampled). One of “none”, “independent”, “structured”. Used for both single-x VI and amortized VI. Use “structured” for parameters with correlations.
- z_score_x: Method for z-scoring x (the conditioning observation). One of “none”, “independent”, “structured”. Only used for amortized VI (train_amortized). Use “structured” for structured data such as images.
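For reference, the mode-seeking versus mass-covering distinction for the two KL options follows from their standard textbook definitions, written here with q(θ) the variational distribution and p(θ | x) the posterior. This is the general formulation only; how sbi estimates each objective internally is not shown.

```latex
% rKL (reverse KL): expectation under q; minimizing it is equivalent to maximizing the ELBO
\mathrm{KL}\big(q(\theta)\,\|\,p(\theta \mid x)\big)
  = \mathbb{E}_{\theta \sim q}\big[\log q(\theta) - \log p(\theta \mid x)\big]
  = \log p(x) - \mathrm{ELBO}(q),
\qquad
\mathrm{ELBO}(q) = \mathbb{E}_{\theta \sim q}\big[\log p(x, \theta) - \log q(\theta)\big]

% fKL (forward KL): expectation under the posterior
\mathrm{KL}\big(p(\theta \mid x)\,\|\,q(\theta)\big)
  = \mathbb{E}_{\theta \sim p(\theta \mid x)}\big[\log p(\theta \mid x) - \log q(\theta)\big]
```

Because rKL averages over q, it strongly penalizes q for placing mass where p(θ | x) is near zero but not for missing a posterior mode, so q tends to lock onto modes. fKL averages over the posterior, so q is penalized wherever the posterior has mass that q misses, which pushes q to cover it.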
Note
For custom distributions that lack parameters() and modules() methods, pass the trainable parameters and modules explicitly via VIPosterior.set_q(q, parameters=…, modules=…) instead.
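Whether this note applies to a given q comes down to duck typing. The following is a hypothetical sketch: exposes_params_and_modules, PlainDistribution, and ModuleLikeDistribution are illustrative names, not sbi APIs, and the check only tests for callable parameters()/modules() attributes rather than reproducing sbi's own logic.

```python
from typing import Any


def exposes_params_and_modules(q: Any) -> bool:
    """Hypothetical check: does q offer callable parameters() and modules()?"""
    return callable(getattr(q, "parameters", None)) and callable(
        getattr(q, "modules", None)
    )


class PlainDistribution:
    """Hypothetical custom distribution with no parameters()/modules()."""

    def __init__(self, loc: float, scale: float) -> None:
        self.loc, self.scale = loc, scale


class ModuleLikeDistribution(PlainDistribution):
    """Hypothetical distribution that also exposes its trainable state."""

    def parameters(self):
        # Yield the values an optimizer would update.
        return iter([self.loc, self.scale])

    def modules(self):
        # Yield the objects that own those values.
        return iter([self])
```

When the check is False (as for PlainDistribution), the note above applies: collect the trainable tensors and modules yourself and pass them to VIPosterior.set_q(q, parameters=…, modules=…).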
- Parameters:
q (Literal['maf', 'nsf', 'naf', 'unaf', 'nice', 'sospf', 'gf', 'gaussian', 'gaussian_diag'] | torch.distributions.distribution.Distribution | sbi.inference.posteriors.vi_posterior.VIPosterior | typing.Callable)
vi_method (Literal['rKL', 'fKL', 'IW', 'alpha'])
num_transforms (int)
hidden_features (int)
z_score_theta (Literal['none', 'independent', 'structured'])
z_score_x (Literal['none', 'independent', 'structured'])
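The defaults and allowed string values listed above can be mirrored in a plain dataclass. The sketch below is a hypothetical stand-in: VIParamsSketch is not part of sbi, it assumes string options are validated at construction time, and it does not reproduce sbi's actual class, validation, or import path.

```python
from dataclasses import dataclass
from typing import Any, Literal, get_args

# Allowed string values, copied from the parameter types above.
FlowType = Literal[
    "maf", "nsf", "naf", "unaf", "nice", "sospf", "gf", "gaussian", "gaussian_diag"
]
VIMethod = Literal["rKL", "fKL", "IW", "alpha"]
ZScore = Literal["none", "independent", "structured"]


@dataclass(frozen=True)
class VIParamsSketch:
    """Hypothetical stand-in for VIPosteriorParameters (not the sbi class)."""

    # A FlowType string, or a Distribution / VIPosterior / builder callable.
    q: Any = "maf"
    vi_method: VIMethod = "rKL"
    num_transforms: int = 5
    hidden_features: int = 50
    z_score_theta: ZScore = "independent"
    z_score_x: ZScore = "independent"

    def __post_init__(self) -> None:
        # Only a string q is checked against the flow names; objects pass through.
        if isinstance(self.q, str) and self.q not in get_args(FlowType):
            raise ValueError(f"unknown flow type: {self.q!r}")
        if self.vi_method not in get_args(VIMethod):
            raise ValueError(f"unknown vi_method: {self.vi_method!r}")
        for name in ("z_score_theta", "z_score_x"):
            if getattr(self, name) not in get_args(ZScore):
                raise ValueError(f"unknown {name}: {getattr(self, name)!r}")
```

For example, VIParamsSketch(q="nsf", num_transforms=8) keeps every other default, while VIParamsSketch(q="not_a_flow") raises ValueError.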
- q: Literal['maf', 'nsf', 'naf', 'unaf', 'nice', 'sospf', 'gf', 'gaussian', 'gaussian_diag'] | Distribution | VIPosterior | Callable = 'maf'
- with_param(**kwargs)
Create a new instance of the class with updated field values.
Only allows updates to fields defined in the dataclass. Raises an error if any unknown or invalid field names are passed.
- Parameters:
**kwargs – Field-value pairs to override in the new instance.
- Returns:
A new instance of the same class with updated values.
- Raises:
ValueError – If any of the provided keys are not valid dataclass fields.
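The contract described for with_param (copy with overrides, reject unknown field names) can be sketched with the standard library. This is a minimal sketch assuming a frozen dataclass and dataclasses.replace; ParamsWithUpdates is a hypothetical two-field class, and sbi's own implementation may differ.

```python
from dataclasses import dataclass, fields, replace
from typing import Any


@dataclass(frozen=True)
class ParamsWithUpdates:
    """Hypothetical dataclass illustrating the documented with_param contract."""

    num_transforms: int = 5
    hidden_features: int = 50

    def with_param(self, **kwargs: Any) -> "ParamsWithUpdates":
        # Reject any key that is not a declared dataclass field.
        valid = {f.name for f in fields(self)}
        unknown = set(kwargs) - valid
        if unknown:
            raise ValueError(f"Invalid field(s): {sorted(unknown)}")
        # replace() builds a new instance; the original is left unchanged.
        return replace(self, **kwargs)
```

Returning a new instance instead of mutating in place means a base configuration can be shared safely, for example ParamsWithUpdates().with_param(hidden_features=128), while a misspelled key such as hiden_features fails loudly rather than being silently ignored.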