rai_toolbox.freeze

rai_toolbox.freeze(*items)

‘Freezes’ collections of tensors by setting requires_grad=False on each one. Returns a callable that, when called, restores each tensor's original requires_grad value.
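The core idea can be sketched as "record, disable, restore": remember each tensor's requires_grad flag, set it to False, and hand back a closure that puts the recorded flags back. The following is a minimal sketch under that assumption, not rai_toolbox's actual implementation; freeze_tensors is a hypothetical name, and it only accepts an iterable of tensors.

```python
from typing import Callable, Iterable, List, Tuple

import torch as tr


def freeze_tensors(tensors: Iterable[tr.Tensor]) -> Callable[[], None]:
    """Hypothetical sketch: record each tensor's requires_grad flag,
    set it to False, and return a closure that restores the recorded flags."""
    saved: List[Tuple[tr.Tensor, bool]] = []
    seen = set()
    for t in tensors:
        if id(t) in seen:  # skip duplicates so each flag is recorded only once
            continue
        seen.add(id(t))
        saved.append((t, t.requires_grad))
        t.requires_grad_(False)

    def unfreeze() -> None:
        # Restore exactly what was recorded, so tensors that were already
        # frozen before the call stay frozen afterwards.
        for t, flag in saved:
            t.requires_grad_(flag)

    return unfreeze


a = tr.tensor(1.0, requires_grad=True)
b = tr.tensor(2.0)  # requires_grad=False from the start
unfreeze = freeze_tensors([a, b, a])
print(a.requires_grad, b.requires_grad)  # False False
unfreeze()
print(a.requires_grad, b.requires_grad)  # True False
```

The de-duplication step matters: without it, the second occurrence of a would be recorded as False (it was already frozen), and unfreezing would leave it frozen.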

Parameters:
*items: tr.Tensor | tr.nn.Module | tr.optim.Optimizer | Iterable[tr.Tensor] | Iterable[Dict[str, Iterable[tr.Tensor]]]

Tensors, modules, optimizers, or param-groups. All tensors/parameters must be leaf tensors [1].
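All of the accepted input kinds can be reduced to a flat list of tensors before any flag is changed. The sketch below assumes a hypothetical helper, collect_tensors (not a rai_toolbox function; its ValueError is likewise an assumption of this sketch), that performs that flattening and checks is_leaf up front. The trailing lines show why the leaf requirement exists: PyTorch raises a RuntimeError when you try to set requires_grad=False on a tensor produced by an operation.

```python
from typing import List, Mapping

import torch as tr


def collect_tensors(*items) -> List[tr.Tensor]:
    """Hypothetical helper: flatten the accepted input kinds into a list of
    tensors, validating every one before any flag is modified."""
    tensors: List[tr.Tensor] = []
    for item in items:
        if isinstance(item, tr.Tensor):
            tensors.append(item)
        elif isinstance(item, tr.nn.Module):
            tensors.extend(item.parameters())
        elif isinstance(item, tr.optim.Optimizer):
            for group in item.param_groups:
                tensors.extend(group["params"])
        else:
            # An iterable of tensors, or of param-group dicts such as {"params": [...]}
            for entry in item:
                if isinstance(entry, Mapping):
                    tensors.extend(entry["params"])
                else:
                    tensors.append(entry)
    for t in tensors:
        if not t.is_leaf:
            raise ValueError("expected leaf tensors only; got a non-leaf tensor")
    return tensors


# PyTorch itself refuses to clear requires_grad on a non-leaf tensor:
x = tr.tensor(1.0, requires_grad=True)
y = x * 2  # non-leaf: produced by an operation on x
try:
    y.requires_grad_(False)
except RuntimeError as err:
    print("RuntimeError:", err)
```

Validating the whole list before touching any flag means a bad input fails cleanly instead of leaving some tensors frozen and others not.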

Returns:
unfreeze : Callable[[], None]

Call with no arguments to restore the requires_grad states the tensors had before they were frozen.
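Because the returned callable must actually be invoked for the tensors to be restored, it pairs naturally with try/finally. The sketch below packages that pattern as a context manager. frozen_params is a hypothetical name, not part of rai_toolbox, and it inlines the record-and-restore step for a single module so that the block is self-contained.

```python
from contextlib import contextmanager
from typing import Iterator

import torch as tr


@contextmanager
def frozen_params(module: tr.nn.Module) -> Iterator[tr.nn.Module]:
    """Hypothetical sketch: disable gradients for a module's parameters inside
    the `with` block and restore each recorded flag on exit, even on error."""
    saved = [(p, p.requires_grad) for p in module.parameters()]
    for p, _ in saved:
        p.requires_grad_(False)
    try:
        yield module
    finally:
        # Plays the role of calling `unfreeze()`.
        for p, flag in saved:
            p.requires_grad_(flag)


model = tr.nn.Linear(2, 3)
try:
    with frozen_params(model):
        assert not model.weight.requires_grad
        raise RuntimeError("simulated failure mid-computation")
except RuntimeError:
    pass
print(model.weight.requires_grad)  # True: restored despite the exception
```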

Notes

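‘Unfreezing’ the tensors restores their original states faithfully: each tensor's requires_grad flag is set back to the value it had when freeze was called, so a tensor that already had requires_grad=False remains frozen.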

References

Examples

>>> import torch as tr
>>> from rai_toolbox import freeze

Basic behavior

>>> x = tr.tensor(1.0, requires_grad=True)
>>> unfreeze = freeze(x)
>>> x.requires_grad
False
>>> unfreeze()
>>> x.requires_grad
True

Freezing a module

>>> from torch.nn import Linear
>>> m = Linear(2, 3)
>>> m.weight.requires_grad, m.bias.requires_grad
(True, True)
>>> unfreeze = freeze(m)
>>> m.weight.requires_grad, m.bias.requires_grad
(False, False)
>>> unfreeze()
>>> m.weight.requires_grad, m.bias.requires_grad
(True, True)