rai_toolbox.freeze
- rai_toolbox.freeze(*items)
‘Freezes’ collections of tensors by setting requires_grad=False. Returns a callable that, when called, restores the state of the tensors.
- Parameters:
- *items: tr.Tensor | tr.nn.Module | tr.optim.Optimizer | Iterable[tr.Tensor] | Iterable[Dict[str, Iterable[tr.Tensor]]]
Tensors, modules, optimizers, or param-groups. All tensors/parameters must be leaf tensors [1].
- Returns:
- unfreeze: Callable[[], None]
Can be called without any input to restore the states of the frozen tensors.
Notes
‘Unfreezing’ the tensors restores their original states faithfully.
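To illustrate what faithful restoration means, here is a minimal sketch of the record-and-restore pattern in plain PyTorch. It is not the rai_toolbox implementation: `sketch_freeze` is a hypothetical helper written only for this illustration, and it handles nothing beyond a flat sequence of leaf tensors. The point it shows is that a tensor that already had requires_grad=False before freezing keeps that value after unfreezing, rather than being switched on.

```python
from typing import Callable, List, Tuple

import torch as tr


def sketch_freeze(*tensors: tr.Tensor) -> Callable[[], None]:
    """Hypothetical stand-in for illustration; not rai_toolbox's code."""
    # Record each tensor together with the flag it had before freezing.
    saved: List[Tuple[tr.Tensor, bool]] = [(t, t.requires_grad) for t in tensors]

    # Freeze: turn off gradient tracking on every tensor.
    for t in tensors:
        t.requires_grad_(False)

    def unfreeze() -> None:
        # Restore the recorded flag, not a blanket True.
        for t, flag in saved:
            t.requires_grad_(flag)

    return unfreeze


a = tr.tensor(1.0, requires_grad=True)
b = tr.tensor(2.0, requires_grad=False)  # already frozen before the call

unfreeze = sketch_freeze(a, b)
frozen_flags = (a.requires_grad, b.requires_grad)

unfreeze()
restored_flags = (a.requires_grad, b.requires_grad)
```

After the call, both flags are off; after `unfreeze()`, `a` tracks gradients again while `b` remains as it was originally.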
References
Examples
>>> import torch as tr
>>> from rai_toolbox.utils import freeze
Basic behavior
>>> x = tr.tensor(1.0, requires_grad=True)
>>> unfreeze = freeze(x)
>>> x.requires_grad
False
>>> unfreeze()
>>> x.requires_grad
True
Freezing a module
>>> from torch.nn import Linear
>>> m = Linear(2, 3)
>>> m.weight.requires_grad, m.bias.requires_grad
(True, True)
>>> unfreeze = freeze(m)
>>> m.weight.requires_grad, m.bias.requires_grad
(False, False)
>>> unfreeze()
>>> m.weight.requires_grad, m.bias.requires_grad
(True, True)
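The leaf-tensor requirement in the parameter description follows from a PyTorch rule: the requires_grad flag can only be changed on leaf tensors. The sketch below uses only plain PyTorch (no rai_toolbox calls) to show how to check whether a tensor is a leaf and what PyTorch does when the flag is changed on a non-leaf. How `freeze` itself reports a non-leaf input is not shown here and is not assumed.

```python
import torch as tr

x = tr.tensor(1.0, requires_grad=True)  # created directly by the user: a leaf
y = x * 2                               # produced by an operation on x: not a leaf

leaf_flags = (x.is_leaf, y.is_leaf)

# PyTorch refuses to change requires_grad on a non-leaf tensor.
try:
    y.requires_grad_(False)
    non_leaf_error = None
except RuntimeError as err:
    non_leaf_error = err

# Detaching yields a leaf tensor that does not track gradients.
y_detached = y.detach()
```

When a computed tensor needs to be excluded from differentiation, `detach()` is the usual route, since the result is a leaf that is cut off from the graph.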