nessai.flows.utils

Various utilities for implementing normalising flows.

Module Contents

Classes

MLP

MLP which can be called with context.

Functions

silu(x)

SiLU (Sigmoid-weighted Linear Unit) activation function.

configure_model(config)

Set up the flow from a configuration dictionary.

reset_weights(module)

Reset parameters of a given module in place.

reset_permutations(module)

Resets permutations and linear transforms for a given module in place.

create_linear_transform(linear_transform, features)

Function for creating linear transforms.

Attributes

logger

nessai.flows.utils.logger
nessai.flows.utils.silu(x)

SiLU (Sigmoid-weighted Linear Unit) activation function.

Also known as swish.

Elfwing et al. 2017: https://arxiv.org/abs/1702.03118v3
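SiLU computes x * sigmoid(x) elementwise. The following is a minimal PyTorch sketch of that formula, checked against torch.nn.functional.silu (available in recent PyTorch releases). It assumes tensor inputs and is not necessarily how nessai writes the function internally.

```python
import torch
import torch.nn.functional as F


def silu(x: torch.Tensor) -> torch.Tensor:
    """SiLU / swish activation: x * sigmoid(x), applied elementwise."""
    return x * torch.sigmoid(x)


# Compare the sketch with PyTorch's built-in implementation
x = torch.linspace(-5.0, 5.0, steps=11)
matches_builtin = torch.allclose(silu(x), F.silu(x))
```

Unlike ReLU, SiLU is smooth and slightly negative for negative inputs, which is one reason it is used in flow networks where smooth gradients help.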

nessai.flows.utils.configure_model(config)

Set up the flow from a configuration dictionary.

nessai.flows.utils.reset_weights(module)

Reset parameters of a given module in place.

Uses the reset_parameters method defined by torch.nn.Module subclasses such as torch.nn.Linear.

Also checks the following modules from nflows:

  • nflows.transforms.normalization.BatchNorm

Parameters
module : torch.nn.Module

Module to reset
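The general pattern of resetting weights in place can be sketched with torch.nn.Module.apply, which visits a module and all of its submodules recursively. This is a simplified illustration: the function name reset_weights_sketch is hypothetical, and it omits the special handling of nflows.transforms.normalization.BatchNorm that the documented function performs.

```python
import torch
import torch.nn as nn


def reset_weights_sketch(module: nn.Module) -> None:
    """Re-initialise, in place, every submodule that defines reset_parameters."""

    def _reset(m: nn.Module) -> None:
        # Layers such as nn.Linear define reset_parameters; activations do not
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()

    # apply() calls _reset on every submodule and on the module itself
    module.apply(_reset)


# Usage: zero all parameters, then restore a fresh random initialisation
net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
with torch.no_grad():
    for p in net.parameters():
        p.zero_()
reset_weights_sketch(net)
```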

nessai.flows.utils.reset_permutations(module)

Resets permutations and linear transforms for a given module in place.

Resets using the original initialisation method. This is needed since they do not have a reset_parameters method.

Parameters
module : torch.nn.Module

Module to reset
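To illustrate why permutations need separate handling: a fixed random permutation is typically stored as a buffer rather than a trainable parameter, so reset_parameters does not apply and the permutation must be redrawn with the same rule used at construction. The sketch below assumes exactly that design; RandomPermutationSketch and reset_permutations_sketch are hypothetical names and do not reproduce the nflows classes that the documented function handles.

```python
import torch
import torch.nn as nn


class RandomPermutationSketch(nn.Module):
    """Hypothetical stand-in for a fixed random permutation layer."""

    def __init__(self, features: int):
        super().__init__()
        self.features = features
        # A buffer is saved with the model but is not a trainable parameter
        self.register_buffer("_permutation", torch.randperm(features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[:, self._permutation]


def reset_permutations_sketch(module: nn.Module) -> None:
    """Redraw every permutation in place using the original initialisation rule."""
    for m in module.modules():
        if isinstance(m, RandomPermutationSketch):
            # Assigning to a registered buffer name updates the buffer itself
            m._permutation = torch.randperm(
                m.features, device=m._permutation.device
            )


flow = nn.Sequential(RandomPermutationSketch(10), nn.Linear(10, 10))
reset_permutations_sketch(flow)
```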

class nessai.flows.utils.MLP(in_shape, out_shape, hidden_sizes, activation=F.relu, activate_output=False)

Bases: nflows.nn.nets.MLP

MLP which can be called with context.

forward(self, inputs, context=None)

Forward method that allows for kwargs such as context.

Parameters
inputs : torch.Tensor

Inputs to the MLP

context : None

Conditional inputs; must be None. Only implemented so that the function is compatible with other methods.

Raises
RuntimeError

If the context is not None.
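The context-rejecting behaviour can be sketched with a plain torch.nn.Module. This is an illustrative stand-in, not the nessai class: ContextFreeMLP is a hypothetical name, it takes integer feature counts instead of the in_shape/out_shape arguments of the real signature, and it does not subclass nflows.nn.nets.MLP.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ContextFreeMLP(nn.Module):
    """Hypothetical MLP whose forward accepts, but rejects, a context argument."""

    def __init__(self, in_features, out_features, hidden_sizes,
                 activation=F.relu, activate_output=False):
        super().__init__()
        sizes = [in_features, *hidden_sizes, out_features]
        self.layers = nn.ModuleList(
            [nn.Linear(a, b) for a, b in zip(sizes[:-1], sizes[1:])]
        )
        self.activation = activation
        self.activate_output = activate_output

    def forward(self, inputs, context=None):
        # Accepting the keyword keeps the call signature compatible with
        # conditional networks, but any actual context is an error
        if context is not None:
            raise RuntimeError("Context is not supported by this MLP.")
        x = inputs
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        x = self.layers[-1](x)
        if self.activate_output:
            x = self.activation(x)
        return x


mlp = ContextFreeMLP(3, 2, [8, 8])
out = mlp(torch.randn(5, 3))
```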

nessai.flows.utils.create_linear_transform(linear_transform, features)

Function for creating linear transforms.

Parameters
linear_transform : {‘permutation’, ‘lu’, ‘svd’}

Linear transform to use.

features : int

Number of features.
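The ‘lu’ option refers to a linear transform whose weight matrix is parameterised through an LU decomposition. The appeal for flows is that the log absolute Jacobian determinant becomes a cheap sum over the diagonal of U. The pure-PyTorch sketch below illustrates that idea only: LULinearSketch is a hypothetical name, it is not the implementation nessai or nflows uses, the permutation matrix P of a full PLU decomposition is assumed to be handled by a separate permutation layer, and the softplus-plus-eps diagonal is one assumed way of keeping U invertible.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LULinearSketch(nn.Module):
    """Hypothetical linear transform with weight W = L @ U."""

    def __init__(self, features: int, eps: float = 1e-3):
        super().__init__()
        self.features = features
        self.eps = eps
        # Only the strictly lower/upper parts of these are used
        self.lower_entries = nn.Parameter(torch.zeros(features, features))
        self.upper_entries = nn.Parameter(torch.zeros(features, features))
        self.unconstrained_diag = nn.Parameter(torch.zeros(features))
        self.bias = nn.Parameter(torch.zeros(features))

    def upper_diag(self) -> torch.Tensor:
        # softplus + eps keeps the diagonal of U strictly positive
        return F.softplus(self.unconstrained_diag) + self.eps

    def weight(self) -> torch.Tensor:
        eye = torch.eye(self.features)
        lower = torch.tril(self.lower_entries, diagonal=-1) + eye  # unit diagonal
        upper = torch.triu(self.upper_entries, diagonal=1) + torch.diag(self.upper_diag())
        return lower @ upper

    def forward(self, x: torch.Tensor):
        y = x @ self.weight().T + self.bias
        # det(L) = 1 and det(U) = prod(diag(U)), so log|det W| is a sum
        logabsdet = torch.log(self.upper_diag()).sum() * x.new_ones(x.shape[0])
        return y, logabsdet


torch.manual_seed(0)
layer = LULinearSketch(4)
with torch.no_grad():
    for p in layer.parameters():
        p.normal_()
y, logabsdet = layer(torch.randn(3, 4))
```

A general dense weight would need an O(n^3) determinant at every step; the triangular factorisation reduces this to O(n).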