sbi.neural_nets.posterior_flow_nn

posterior_flow_nn(model='mlp', z_score_theta='independent', z_score_x='independent', hidden_features=100, num_layers=5, embedding_net=Identity(), time_emb_type='sinusoidal', t_embedding_dim=32, gaussian_baseline=False, **kwargs)

Return a constructor function that builds a FlowMatchingEstimator for flow matching posterior estimation (FMPE).
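A FlowMatchingEstimator learns a velocity field that transports base noise to the posterior. The following is a minimal sketch of the per-sample training target, assuming the common linear path θ_t = (1 - t)·ε + t·θ with ε ~ N(0, I) and t in [0, 1]. sbi's exact path (for example, any small σ_min term) is an assumption here, and the helper names are hypothetical, not part of sbi's API. In training, the network receives θ_t, t, and the embedded x, and is regressed onto the target velocity.

```python
import random

def interpolate(theta, eps, t):
    """Point on the straight path from noise eps (t=0) to theta (t=1)."""
    return [(1.0 - t) * e + t * th for th, e in zip(theta, eps)]

def target_velocity(theta, eps):
    """Time derivative of the linear path: the constant velocity theta - eps."""
    return [th - e for th, e in zip(theta, eps)]

def fm_loss(v_pred, v_target):
    """Squared-error flow matching regression loss for a single sample."""
    return sum((p - q) ** 2 for p, q in zip(v_pred, v_target))

rng = random.Random(0)
theta = [1.5, -0.5]                          # one parameter draw (hypothetical values)
eps = [rng.gauss(0.0, 1.0) for _ in theta]   # base noise sample
t = rng.random()                             # flow time in [0, 1)
theta_t = interpolate(theta, eps, t)         # network input, together with t and x
v_star = target_velocity(theta, eps)         # regression target
loss = fm_loss([0.0, 0.0], v_star)           # loss of a trivial all-zero predictor
```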

Parameters:
  • model (Literal['mlp', 'ada_mlp', 'transformer', 'transformer_cross_attn'] | sbi.utils.vector_field_utils.VectorFieldNet) –

    Type of regression network. One of:

    • ‘mlp’: Fully connected feed-forward network.

    • ‘ada_mlp’: Fully connected feed-forward network with adaptive layer normalization for conditioning.

    • ‘transformer’: Transformer network.

    • ‘transformer_cross_attn’: Transformer with cross-attention.

    • A VectorFieldNet instance: custom network.

    Defaults to ‘mlp’.

  • z_score_theta (Literal['independent', 'structured', 'transform_to_unconstrained', 'none'] | None) – How to z-score theta. Z-scoring theta enables time-dependent normalization, which helps FMPE learn when the distribution of theta is far from N(0, 1). Defaults to ‘independent’.

  • z_score_x (Literal['independent', 'structured', 'transform_to_unconstrained', 'none'] | None) – How to z-score the observations x before they are passed to the embedding network. Defaults to ‘independent’.

  • hidden_features (int) – Number of hidden units per layer. Defaults to 100.

  • num_layers (int) – Number of hidden layers. Defaults to 5.

  • embedding_net (Module) – Embedding network for x (conditioning variable). Defaults to nn.Identity().

  • time_emb_type (Literal['sinusoidal', 'fourier']) – Type of time embedding. Defaults to ‘sinusoidal’.

  • t_embedding_dim (int) – Dimension of the time embedding. Defaults to 32.

  • gaussian_baseline (bool) – If True, use an analytical Gaussian baseline velocity derived from Bayes’ rule; the network then learns only the residual. Defaults to False.

  • **kwargs (Any) – Additional estimator / network arguments. Valid keys are defined by FlowEstimatorConfig; unknown keys raise TypeError.
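The ‘ada_mlp’ option conditions hidden layers through adaptive layer normalization (adaLN). Below is a minimal sketch of the general adaLN technique: normalize a hidden vector, then scale and shift it with values computed from a conditioning vector (for example, the embedded x and t). The weight layout and helper names are hypothetical, and sbi's layer may differ in detail.

```python
import math

def layer_norm(h, eps=1e-5):
    """Normalize h to zero mean and unit variance (no learned affine part)."""
    mean = sum(h) / len(h)
    var = sum((v - mean) ** 2 for v in h) / len(h)
    return [(v - mean) / math.sqrt(var + eps) for v in h]

def linear(w, b, c):
    """Dense layer: w holds one row of weights per output unit."""
    return [sum(wi * ci for wi, ci in zip(row, c)) + bi for row, bi in zip(w, b)]

def ada_layer_norm(h, cond, w_scale, b_scale, w_shift, b_shift):
    """adaLN: (1 + scale(cond)) * LN(h) + shift(cond)."""
    scale = linear(w_scale, b_scale, cond)
    shift = linear(w_shift, b_shift, cond)
    return [(1.0 + s) * n + sh for n, s, sh in zip(layer_norm(h), scale, shift)]

h = [1.0, 2.0, 3.0]                      # hidden activations
cond = [0.5, -1.0]                       # conditioning vector (hypothetical)
zeros_w = [[0.0, 0.0] for _ in h]        # one weight row per hidden unit
zeros_b = [0.0 for _ in h]
# With zero scale and a constant shift of 1, the output is LN(h) + 1.
out = ada_layer_norm(h, cond, zeros_w, zeros_b, zeros_w, [1.0, 1.0, 1.0])
```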
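The ‘independent’ and ‘structured’ z-scoring options differ in which statistics are used. A minimal sketch, assuming the convention sbi documents for its other estimators carries over: ‘independent’ standardizes each dimension with its own mean and standard deviation, while ‘structured’ pools a single mean and standard deviation over all dimensions, which suits related dimensions such as a time series. The helper z_score is hypothetical, and the time-dependent handling of theta and the ‘transform_to_unconstrained’ option are not shown.

```python
import statistics

def z_score(batch, mode="independent", min_std=1e-12):
    """Standardize a batch given as a list of equal-length rows (hypothetical helper)."""
    if mode == "none":
        return [row[:] for row in batch]
    if mode == "independent":
        cols = list(zip(*batch))
        means = [statistics.fmean(c) for c in cols]
        # Guard against constant dimensions with a small floor on the std.
        stds = [max(statistics.pstdev(c), min_std) for c in cols]
        return [[(v - m) / s for v, m, s in zip(row, means, stds)] for row in batch]
    if mode == "structured":
        flat = [v for row in batch for v in row]
        m = statistics.fmean(flat)
        s = max(statistics.pstdev(flat), min_std)
        return [[(v - m) / s for v in row] for row in batch]
    raise ValueError(f"unsupported mode: {mode!r}")

batch = [[1.0, 100.0], [3.0, 300.0]]   # two samples with very different scales per dimension
independent = z_score(batch, "independent")
structured = z_score(batch, "structured")
```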
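The time_emb_type and t_embedding_dim parameters control how the scalar time t is lifted to a vector before it reaches the network. A minimal sketch of a standard transformer-style sinusoidal embedding (half sines, half cosines over geometrically spaced frequencies); the max_period constant and the function name are assumptions, and sbi's exact frequencies may differ. A ‘fourier’ embedding typically draws its frequencies at random instead of fixing them.

```python
import math

def sinusoidal_time_embedding(t, dim=32, max_period=10_000.0):
    """Map a scalar t to a dim-dimensional vector of sines and cosines."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    # Frequencies decay geometrically from 1 toward 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * k / half) for k in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]

emb = sinusoidal_time_embedding(0.25, dim=32)   # matches the default t_embedding_dim
```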
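For intuition on gaussian_baseline: when the target distribution is Gaussian, the flow matching velocity field has a closed form, so a network only needs to learn the deviation from it. A one-dimensional sketch, assuming the linear path θ_t = (1 - t)·ε + t·θ with ε ~ N(0, 1) and a Gaussian target N(μ, σ²). Conditioning the jointly Gaussian pair (θ - ε, θ_t) via Bayes' rule gives v(x, t) = E[θ - ε | θ_t = x] = μ + (tσ² - (1 - t)) / (t²σ² + (1 - t)²) · (x - tμ). How sbi chooses μ and σ, and its exact path convention, are assumptions not covered here; the function names are hypothetical.

```python
def gaussian_baseline_velocity(x, t, mu, sigma):
    """Exact marginal velocity for a 1D Gaussian target N(mu, sigma**2)."""
    var_t = t**2 * sigma**2 + (1.0 - t) ** 2   # Var[theta_t]
    cov = t * sigma**2 - (1.0 - t)             # Cov[theta - eps, theta_t]
    return mu + cov / var_t * (x - t * mu)

def total_velocity(x, t, mu, sigma, residual_net):
    """Baseline plus a learned correction (residual_net is a hypothetical callable)."""
    return gaussian_baseline_velocity(x, t, mu, sigma) + residual_net(x, t)

# Sanity checks at the path endpoints: v(x, 0) = mu - x and v(x, 1) = x.
v0 = gaussian_baseline_velocity(x=0.3, t=0.0, mu=2.0, sigma=0.5)
v1 = gaussian_baseline_velocity(x=0.3, t=1.0, mu=2.0, sigma=0.5)
```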
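The **kwargs behavior matches how Python dataclasses treat undeclared keyword arguments: the generated __init__ raises TypeError for any name it does not declare. A minimal sketch; HypotheticalFlowConfig and build_config are illustrative stand-ins, and whether FlowEstimatorConfig is a dataclass, as well as its actual fields, are assumptions.

```python
from dataclasses import dataclass

@dataclass
class HypotheticalFlowConfig:
    """Illustrative stand-in; the real FlowEstimatorConfig fields may differ."""
    hidden_features: int = 100
    num_layers: int = 5

def build_config(**kwargs):
    # The dataclass __init__ rejects any keyword it does not declare.
    return HypotheticalFlowConfig(**kwargs)

cfg = build_config(num_layers=3)
message = ""
try:
    build_config(num_layer=3)   # typo in the key name
except TypeError as err:
    message = str(err)
```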

Returns:

Constructor function for FMPE.

Return type:

Callable
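Returning a constructor instead of a ready estimator lets the inference object build the network once training data is available, so input dimensions and normalization statistics can be read from a batch. A minimal sketch of this closure pattern; make_builder, TinyEstimator, and the (batch_theta, batch_x) call signature are hypothetical illustrations, not sbi's API.

```python
import statistics

class TinyEstimator:
    """Hypothetical stand-in that records what the builder inferred from data."""
    def __init__(self, theta_dim, x_dim, x_mean, hidden_features):
        self.theta_dim = theta_dim
        self.x_dim = x_dim
        self.x_mean = x_mean
        self.hidden_features = hidden_features

def make_builder(hidden_features=100):
    """Capture hyperparameters now; defer construction until data exists."""
    def build_fn(batch_theta, batch_x):
        # Shapes and statistics are taken from the first training batch.
        x_mean = [statistics.fmean(col) for col in zip(*batch_x)]
        return TinyEstimator(
            theta_dim=len(batch_theta[0]),
            x_dim=len(batch_x[0]),
            x_mean=x_mean,
            hidden_features=hidden_features,
        )
    return build_fn

build_fn = make_builder(hidden_features=64)
estimator = build_fn([[0.1, 0.2], [0.3, 0.4]], [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
```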