rai_toolbox.frozen

class rai_toolbox.frozen(*items)

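A context manager/decorator for ‘freezing’ collections of tensors: requires_grad is set to False for the tensors for the duration of the context (or of the decorated function call), and their original requires_grad values are restored on exit.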

__init__(*items)
Parameters:
*items: tr.Tensor | tr.nn.Module | tr.optim.Optimizer | Iterable[tr.Tensor] | Iterable[Dict[str, Iterable[tr.Tensor]]]

Tensors, modules, optimizers, or param-groups to be frozen. All tensors/parameters must be leaf tensors [1].

References

Examples

>>> import torch as tr
>>> from rai_toolbox import frozen

Demonstrating frozen as a context manager.

>>> x = tr.tensor(1.0, requires_grad=True)
>>> with frozen(x):
...     print(x.requires_grad)
False
>>> x.requires_grad
True

Demonstrating frozen as a decorator.

>>> x = tr.tensor(1.0, requires_grad=True)
>>> @frozen(x)
... def f():
...     print("hello world")
...     return x.requires_grad
>>> x.requires_grad # x isn't frozen until f is called
True
>>> f()
hello world
False
>>> x.requires_grad
True
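
Sketching the freeze-and-restore pattern.

The examples above show the flag being switched off on entry and restored on exit. The following is a minimal sketch of that pattern built on the standard library's contextlib.ContextDecorator. It is not rai_toolbox's implementation: the name frozen_sketch is hypothetical, it only handles objects that expose a requires_grad attribute directly (no modules, optimizers, or param-groups), and SimpleNamespace objects stand in for leaf tensors so that the snippet runs without PyTorch.

```python
from contextlib import ContextDecorator
from types import SimpleNamespace


class frozen_sketch(ContextDecorator):
    """Hypothetical illustration only; not the rai_toolbox implementation."""

    def __init__(self, *items):
        self._items = items
        self._saved = []

    def __enter__(self):
        # Record each item's current flag, then switch it off.
        self._saved = [item.requires_grad for item in self._items]
        for item in self._items:
            item.requires_grad = False
        return self

    def __exit__(self, exc_type, exc, tb):
        # Restore the recorded flags, even if the body raised.
        for item, flag in zip(self._items, self._saved):
            item.requires_grad = flag
        return False  # do not suppress exceptions


# Stand-ins for leaf tensors (assumption: only `requires_grad` matters here).
x = SimpleNamespace(requires_grad=True)
y = SimpleNamespace(requires_grad=False)  # already frozen before the context

# Context-manager use: both flags are False inside the block.
with frozen_sketch(x, y):
    inside = (x.requires_grad, y.requires_grad)
after = (x.requires_grad, y.requires_grad)


# Decorator use: x is frozen only while f is executing.
@frozen_sketch(x)
def f():
    return x.requires_grad


during_call = f()
```

Saving the prior flags, rather than unconditionally setting them back to True, means an item that was already frozen (y above) stays frozen after the context exits.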
