General Utilities

PyTorch Utilities

We provide general, PyTorch-centric utilities that reduce boilerplate code and help ensure correctness when using the rAI-toolbox.

to_batch(p, param_ndim)

Returns a view of p reshaped to shape (N, d0, ...), where (d0, ...) has param_ndim entries.
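The reshaping described above can be illustrated with plain PyTorch. The following is a minimal sketch, not the toolbox's implementation: `to_batch_sketch` is a hypothetical stand-in, and it assumes that (d0, ...) refers to the trailing `param_ndim` dimensions of p, with all leading dimensions flattened into N.

```python
import torch


def to_batch_sketch(p: torch.Tensor, param_ndim: int) -> torch.Tensor:
    """Hypothetical stand-in for to_batch; assumes (d0, ...) are the trailing dims of p."""
    if not 1 <= param_ndim <= p.ndim:
        raise ValueError("this sketch only handles 1 <= param_ndim <= p.ndim")
    # Keep the last `param_ndim` dimensions and flatten everything before them into N.
    trailing_shape = p.shape[p.ndim - param_ndim:]
    # Tensor.view returns a view that shares memory with p (p must be contiguous here).
    return p.view(-1, *trailing_shape)


p = torch.arange(24.0).reshape(2, 3, 4)

per_vector = to_batch_sketch(p, param_ndim=1)  # shape (6, 4): six length-4 vectors
per_matrix = to_batch_sketch(p, param_ndim=2)  # shape (2, 3, 4): two 3x4 matrices
whole = to_batch_sketch(p, param_ndim=3)       # shape (1, 2, 3, 4): one full tensor
```

Because the result is a view, in-place writes to it are visible through p as well.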

evaluating(*modules)

A context manager/decorator that places one or more modules in eval mode for the duration of the context.
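The behavior described above can be approximated with `contextlib` and the standard `nn.Module.eval()` / `nn.Module.train()` methods. This is a sketch under assumptions, not the toolbox's code: `evaluating_sketch` is a hypothetical name, and it restores only each top-level module's mode, which `train(mode)` then applies to all of its submodules.

```python
from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def evaluating_sketch(*modules: nn.Module):
    """Hypothetical stand-in for evaluating: eval mode inside, prior mode afterwards."""
    # Remember whether each module was in train mode before entering the context.
    previous_modes = [m.training for m in modules]
    try:
        for m in modules:
            m.eval()
        yield
    finally:
        # Restore the recorded mode even if the body raised an exception.
        for m, was_training in zip(modules, previous_modes):
            m.train(was_training)


model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

mode_before = model.training      # modules start in train mode
with evaluating_sketch(model):
    mode_inside = model.training  # eval mode: dropout is disabled here
mode_after = model.training       # train mode is restored on exit
```

Restoring the mode in a `finally` clause is what prevents a model from being left in eval mode by accident after an error.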

freeze(*items)

'Freezes' collections of tensors by setting requires_grad=False.

frozen(*items)

A context manager/decorator for 'freezing' collections of tensors; i.e. requires_grad is set to False for the tensors during the context.
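To make the freezing behavior concrete, here is a minimal sketch built on `torch.Tensor.requires_grad_`. It is an illustration under assumptions rather than the toolbox's implementation: `frozen_sketch` is a hypothetical name, and it accepts only leaf tensors (such as module parameters), whereas the utilities above are described as operating on collections of tensors.

```python
from contextlib import contextmanager

import torch


@contextmanager
def frozen_sketch(*tensors: torch.Tensor):
    """Hypothetical stand-in for frozen: requires_grad is False inside the context."""
    # Record each tensor's flag so that it can be restored on exit.
    previous_flags = [t.requires_grad for t in tensors]
    try:
        for t in tensors:
            t.requires_grad_(False)
        yield
    finally:
        for t, flag in zip(tensors, previous_flags):
            t.requires_grad_(flag)


layer = torch.nn.Linear(3, 1)
params = list(layer.parameters())  # weight and bias, both leaf tensors

with frozen_sketch(*params):
    flags_inside = [p.requires_grad for p in params]  # all False
flags_after = [p.requires_grad for p in params]       # original flags restored
```

A non-context version in the spirit of freeze would perform only the first loop and leave the flags set to False.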

negate(func)

A wrapper that negates the function's output (i.e., applies the - operator to it).
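Such a wrapper can be sketched in a few lines of pure Python. `negate_sketch` and `squared_error` below are hypothetical names chosen for illustration, and the toolbox's own wrapper may differ in its details.

```python
import functools


def negate_sketch(func):
    """Hypothetical stand-in for negate: returns a function computing -func(...)."""

    @functools.wraps(func)  # keep the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        return -func(*args, **kwargs)

    return wrapper


def squared_error(pred, target):
    return (pred - target) ** 2


negated_error = negate_sketch(squared_error)
result = negated_error(3.0, 1.0)  # -((3.0 - 1.0) ** 2) == -4.0
```

One plausible use is to hand a negated loss to an optimizer that minimizes, so that the original loss is effectively maximized.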