VectorFieldPosteriorParameters

class VectorFieldPosteriorParameters(max_sampling_batch_size=10000, enable_transform=True, iid_method='auto_gauss', iid_params=None, neural_ode_backend='zuko', neural_ode_kwargs=None)

Bases: PosteriorParameters

Parameters for initializing VectorFieldPosterior.

Fields:
max_sampling_batch_size: Batch size of the samples drawn from the proposal at every iteration.

enable_transform: Whether to transform parameters to unconstrained space during MAP optimization. When False, an identity transform will be returned for theta_transform. True is not supported yet.

iid_method: Which method to use for computing the score in the iid setting. We currently support “fnpe”, “gauss”, “auto_gauss”, and “jac_gauss”.

iid_params: Parameters for the iid method; for the available arguments, see IIDScoreFunction.

neural_ode_backend: The backend to use for the neural ODE. Currently, only “zuko” is supported.

neural_ode_kwargs: Additional keyword arguments for the neural ODE.

Parameters:
  • max_sampling_batch_size (int)

  • enable_transform (bool)

  • iid_method (Literal['fnpe', 'gauss', 'auto_gauss', 'jac_gauss'])

  • iid_params (Dict[str, Any] | None)

  • neural_ode_backend (Literal['zuko'])

  • neural_ode_kwargs (Dict[str, Any] | None)

max_sampling_batch_size: int = 10000
enable_transform: bool = True
iid_method: Literal['fnpe', 'gauss', 'auto_gauss', 'jac_gauss'] = 'auto_gauss'
iid_params: Dict[str, Any] | None = None
neural_ode_backend: Literal['zuko'] = 'zuko'
neural_ode_kwargs: Dict[str, Any] | None = None
validate()
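
The attribute list above maps directly onto a Python dataclass. The sketch below mirrors the documented fields and defaults in a hypothetical stand-in class, VectorFieldParamsSketch, so it runs without sbi installed. It is not the library's source, and making the class frozen is an assumption (consistent with with_param returning a new instance rather than mutating the existing one).

```python
from dataclasses import FrozenInstanceError, dataclass, fields
from typing import Any, Dict, Literal, Optional


# Hypothetical stand-in mirroring the documented fields and defaults.
# Not the sbi implementation; frozen=True is an assumption.
@dataclass(frozen=True)
class VectorFieldParamsSketch:
    max_sampling_batch_size: int = 10_000
    enable_transform: bool = True
    iid_method: Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"] = "auto_gauss"
    iid_params: Optional[Dict[str, Any]] = None
    neural_ode_backend: Literal["zuko"] = "zuko"
    neural_ode_kwargs: Optional[Dict[str, Any]] = None


defaults = VectorFieldParamsSketch()
# Override selected fields at construction time; the rest keep their defaults.
custom = VectorFieldParamsSketch(iid_method="fnpe", max_sampling_batch_size=2_000)

# Field names in declaration order.
field_names = [f.name for f in fields(defaults)]

# A frozen dataclass rejects attribute assignment after construction.
try:
    custom.iid_method = "gauss"
    mutated = True
except FrozenInstanceError:
    mutated = False
```

Note that Literal annotations are not enforced at runtime: VectorFieldParamsSketch(iid_method="bogus") constructs without error, which is one reason an explicit validation step is useful.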

Validate VectorFieldPosteriorParameters fields.
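
The reference only states that validate() checks the fields. To illustrate what such checks could look like, the sketch below validates a subset of the fields on a hypothetical class, ValidatedParamsSketch. The specific rules (a positive integer batch size, iid_method in the documented set, "zuko" as the only backend) are assumptions drawn from the field descriptions, not sbi's actual validation logic.

```python
from dataclasses import dataclass
from typing import Literal, get_args

# Allowed values taken from the documented Literal types.
IIDMethod = Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"]
ODEBackend = Literal["zuko"]


# Hypothetical class; the checks in validate() are assumptions.
@dataclass(frozen=True)
class ValidatedParamsSketch:
    max_sampling_batch_size: int = 10_000
    iid_method: str = "auto_gauss"
    neural_ode_backend: str = "zuko"

    def validate(self) -> None:
        size = self.max_sampling_batch_size
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(size, bool) or not isinstance(size, int) or size <= 0:
            raise ValueError("max_sampling_batch_size must be a positive int.")
        # get_args on a Literal returns the tuple of allowed values.
        if self.iid_method not in get_args(IIDMethod):
            raise ValueError(f"Unsupported iid_method: {self.iid_method!r}")
        if self.neural_ode_backend not in get_args(ODEBackend):
            raise ValueError(f"Unsupported backend: {self.neural_ode_backend!r}")


def is_valid(params: ValidatedParamsSketch) -> bool:
    """Return True if validate() passes, False if it raises ValueError."""
    try:
        params.validate()
        return True
    except ValueError:
        return False
```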

with_param(**kwargs)

Create a new instance of the class with updated field values.

Only allows updates to fields defined in the dataclass. Raises an error if any unknown or invalid field names are passed.

Parameters:

**kwargs – Field-value pairs to override in the new instance.

Returns:

A new instance of the same class with updated values.

Raises:

ValueError – If any of the provided keys are not valid dataclass fields.
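
The behavior described for with_param resembles dataclasses.replace with an up-front check on the field names. A minimal sketch of that behavior on a hypothetical two-field class, ParamsSketch, follows; it illustrates the documented contract under that assumption and is not sbi's implementation. The explicit name check matters because dataclasses.replace on its own raises TypeError, not ValueError, for an unknown keyword.

```python
from dataclasses import dataclass, fields, replace


# Hypothetical two-field class standing in for the real parameters class.
@dataclass(frozen=True)
class ParamsSketch:
    max_sampling_batch_size: int = 10_000
    iid_method: str = "auto_gauss"

    def with_param(self, **kwargs):
        """Return a copy with the given fields overridden."""
        valid = {f.name for f in fields(self)}
        unknown = sorted(set(kwargs) - valid)
        if unknown:
            raise ValueError(f"Invalid field name(s): {unknown}")
        # replace() builds a new instance; the original is left unchanged.
        return replace(self, **kwargs)


base = ParamsSketch()
updated = base.with_param(iid_method="fnpe")
```

Because the result is a new object, base can be reused as a template for several configurations.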