rai_toolbox.frozen
- class rai_toolbox.frozen(*items)
A context manager/decorator for ‘freezing’ collections of tensors; i.e., requires_grad is set to False on those tensors for the duration of the context, and their previous requires_grad settings are restored when the context exits.
- __init__(*items)
- Parameters:
- *items: tr.Tensor | tr.nn.Module | tr.optim.Optimizer | Iterable[tr.Tensor] | Iterable[Dict[str, Iterable[tr.Tensor]]]
Tensors, modules, optimizers, or param-groups to be frozen. All tensors/parameters must be leaf tensors [1].
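Every accepted type in *items ultimately reduces to a flat collection of tensors. The following is a minimal sketch of one plausible way to do that flattening; the helper name collect_tensors is hypothetical, and this is not necessarily how rai_toolbox does it. The final lines illustrate the leaf-tensor requirement: PyTorch raises a RuntimeError when code tries to change requires_grad on a non-leaf tensor, such as one produced by an operation.

```python
from typing import Any, Iterator

import torch as tr


def collect_tensors(*items: Any) -> Iterator[tr.Tensor]:
    """Hypothetical helper: yield every tensor referenced by ``items``."""
    for item in items:
        if isinstance(item, tr.Tensor):
            yield item
        elif isinstance(item, tr.nn.Module):
            # A module contributes all of its (recursive) parameters.
            yield from item.parameters()
        elif isinstance(item, tr.optim.Optimizer):
            # An optimizer's param-groups are dicts with a "params" entry.
            for group in item.param_groups:
                yield from group["params"]
        elif isinstance(item, dict):
            # A single param-group, e.g. {"params": [...], "lr": 0.1}.
            yield from item["params"]
        else:
            # Any other iterable: recurse, so that iterables of tensors and
            # iterables of param-group dicts are both handled.
            for sub in item:
                yield from collect_tensors(sub)


# Why leaf tensors: requires_grad can only be toggled on leaves.
x = tr.tensor(1.0, requires_grad=True)  # leaf: created directly by the user
y = x * 2                               # non-leaf: produced by an operation
try:
    y.requires_grad = False             # PyTorch rejects this for non-leaves
except RuntimeError as err:
    print(f"RuntimeError: {err}")
```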
References
Examples
>>> import torch as tr
>>> from rai_toolbox.utils._implementations import frozen
Demonstrating frozen as a context manager.
>>> x = tr.tensor(1.0, requires_grad=True)
>>> with frozen(x):
...     print(x.requires_grad)
False
>>> x.requires_grad
True
Demonstrating frozen as a decorator.
>>> x = tr.tensor(1.0, requires_grad=True)
>>> @frozen(x)
... def f():
...     print("hello world")
...     return x.requires_grad
>>> x.requires_grad  # x isn't frozen until f is called
True
>>> f()
hello world
False
>>> x.requires_grad
True
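The examples above rely on a single object working both as a context manager and as a decorator. The following is a minimal sketch of that save/set/restore pattern, built on the standard library's contextlib.ContextDecorator. It is not rai_toolbox's implementation: the class name freeze_flags is hypothetical, it skips the module/optimizer/param-group handling, and it only assumes that each item has a writable requires_grad attribute, as leaf torch.Tensor objects do.

```python
from contextlib import ContextDecorator
from typing import Any, List


class freeze_flags(ContextDecorator):
    """Hypothetical sketch: temporarily set ``requires_grad=False`` on each item.

    Assumes every item exposes a writable ``requires_grad`` attribute
    (e.g. a leaf ``torch.Tensor``); modules/optimizers are not handled here.
    """

    def __init__(self, *tensors: Any) -> None:
        self._tensors = tensors
        # One list of saved flags per active entry (a stack), so that
        # re-entrant use restores each level correctly.
        self._saved: List[List[bool]] = []

    def __enter__(self) -> "freeze_flags":
        # Record the current flags before changing anything.
        self._saved.append([t.requires_grad for t in self._tensors])
        for t in self._tensors:
            t.requires_grad = False
        return self

    def __exit__(self, *exc_info: Any) -> bool:
        # Restore the recorded flags, even if the body raised.
        for t, flag in zip(self._tensors, self._saved.pop()):
            t.requires_grad = flag
        return False  # never suppress exceptions
```

ContextDecorator reuses the same instance on every call of a decorated function, which is why the saved flags are kept as a stack rather than in a single attribute: a decorated function that calls itself then restores each level's flags in the right order.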
Methods
__init__(*items)