Source code for rai_toolbox._utils.stateful

# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT

from abc import ABCMeta, abstractmethod
from functools import wraps
from typing import Callable, Dict, Iterable, TypeVar, Union, cast
from weakref import WeakSet

import torch as tr

from .itertools import flatten_params

T = TypeVar("T", bound=Callable)
NoneType = type(None)


def freeze(
    *items: Union[
        tr.Tensor,
        tr.nn.Module,
        tr.optim.Optimizer,
        Iterable[tr.Tensor],
        Iterable[Dict[str, Iterable[tr.Tensor]]],
    ]
) -> Callable[[], None]:
    """'Freezes' collections of tensors by setting `requires_grad=False`.

    Returns a callable that, when called, restores the state of the tensors.

    Parameters
    ----------
    *items: tr.Tensor | tr.nn.Module | tr.optim.Optimizer | Iterable[tr.Tensor] | Iterable[Dict[str, Iterable[tr.Tensor]]]
        Tensors, modules, optimizers, or param-groups. All tensors/parameters
        must be leaf tensors [1]_.

    Returns
    -------
    unfreeze : Callable[[], None]
        Can be called without any input to restore the states of the frozen
        tensors.

    Notes
    -----
    'Unfreezing' the tensors restores their original states faithfully.

    References
    ----------
    .. [1] https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf.html

    Examples
    --------
    >>> import torch as tr
    >>> from rai_toolbox.utils import freeze

    Basic behavior

    >>> x = tr.tensor(1.0, requires_grad=True)
    >>> unfreeze = freeze(x)
    >>> x.requires_grad
    False
    >>> unfreeze()
    >>> x.requires_grad
    True

    Freezing a module

    >>> from torch.nn import Linear
    >>> m = Linear(2, 3)
    >>> m.weight.requires_grad, m.bias.requires_grad
    (True, True)
    >>> unfreeze = freeze(m)
    >>> m.weight.requires_grad, m.bias.requires_grad
    (False, False)
    >>> unfreeze()
    >>> m.weight.requires_grad, m.bias.requires_grad
    (True, True)
    """
    # Track each tensor under its *current* `requires_grad` state so that
    # `restore_state` can faithfully undo the freeze.
    seen = {True: WeakSet(), False: WeakSet()}

    for item in items:
        if isinstance(item, tr.nn.Module):
            item = item.parameters()
        elif isinstance(item, tr.optim.Optimizer):
            item = item.param_groups

        for param in flatten_params(item):
            seen[param.requires_grad].add(param)

    for param in seen[True]:
        param.requires_grad_(False)

    def restore_state():
        for requires_grad in (True, False):
            for p in seen[requires_grad]:
                p.requires_grad_(requires_grad)

    return restore_state
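
# A minimal usage sketch (not part of the original module): `freeze` also
# accepts an optimizer directly and walks its `param_groups`. `Linear` and
# `SGD` are standard torch APIs; the rest is hypothetical demo code.
#
# >>> from torch.nn import Linear
# >>> from torch.optim import SGD
# >>> model = Linear(2, 2)
# >>> opt = SGD(model.parameters(), lr=0.1)
# >>> unfreeze = freeze(opt)  # freezes every parameter in opt.param_groups
# >>> model.weight.requires_grad, model.bias.requires_grad
# (False, False)
# >>> unfreeze()
# >>> model.weight.requires_grad, model.bias.requires_grad
# (True, True)
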
class ContextDecorator(metaclass=ABCMeta):
    @abstractmethod
    def __enter__(self):  # pragma: no cover
        raise NotImplementedError()

    @abstractmethod
    def __exit__(self, type, value, traceback):  # pragma: no cover
        raise NotImplementedError()

    def __call__(self, func: T) -> T:
        @wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return cast(T, wrapper)
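
# A minimal sketch (not part of the original module) of how subclassing
# `ContextDecorator` yields an object that works both as a context manager
# and as a decorator: `__call__` wraps the decorated function so that each
# call runs inside `with self:`. `announcing` is a hypothetical demo class.
#
# >>> class announcing(ContextDecorator):
# ...     def __enter__(self):
# ...         print("entering")
# ...     def __exit__(self, type, value, traceback):
# ...         print("exiting")
#
# >>> with announcing():
# ...     print("inside")
# entering
# inside
# exiting
#
# >>> @announcing()
# ... def f():
# ...     print("inside f")
# >>> f()
# entering
# inside f
# exiting
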
class frozen(ContextDecorator):
    """A context manager/decorator for 'freezing' collections of tensors;
    i.e. `requires_grad` is set to `False` for the tensors during the
    context."""

    def __init__(
        self,
        *items: Union[
            tr.Tensor,
            tr.nn.Module,
            tr.optim.Optimizer,
            Iterable[tr.Tensor],
            Iterable[Dict[str, Iterable[tr.Tensor]]],
        ],
    ) -> None:
        """
        Parameters
        ----------
        *items: tr.Tensor | tr.nn.Module | tr.optim.Optimizer | Iterable[tr.Tensor] | Iterable[Dict[str, Iterable[tr.Tensor]]]
            Tensors, modules, optimizers, or param-groups to be frozen. All
            tensors/parameters must be leaf tensors [1]_.

        References
        ----------
        .. [1] https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf.html

        Examples
        --------
        >>> import torch as tr
        >>> from rai_toolbox.utils._implementations import frozen

        Demonstrating `frozen` as a context manager.

        >>> x = tr.tensor(1.0, requires_grad=True)
        >>> with frozen(x):
        ...     print(x.requires_grad)
        False
        >>> x.requires_grad
        True

        Demonstrating `frozen` as a decorator.

        >>> x = tr.tensor(1.0, requires_grad=True)
        >>> @frozen(x)
        ... def f():
        ...     print("hello world")
        ...     return x.requires_grad
        >>> x.requires_grad  # x isn't frozen until f is called
        True
        >>> f()
        hello world
        False
        >>> x.requires_grad
        True
        """
        self._items = items

    def __enter__(self) -> None:
        self._unfreeze = freeze(*self._items)

    def __exit__(self, type, value, traceback) -> None:
        self._unfreeze()
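
# A minimal sketch (not part of the original module): because `frozen` is a
# `ContextDecorator`, one instance can guard a whole function. Inside the
# context, backpropagation does not populate `.grad` for the frozen module's
# parameters. `Linear` is a standard torch API; `step` is a hypothetical
# demo function.
#
# >>> from torch.nn import Linear
# >>> model = Linear(2, 1)
# >>> x = tr.ones(1, 2, requires_grad=True)
#
# >>> @frozen(model)
# ... def step():
# ...     model(x).sum().backward()
#
# >>> step()
# >>> model.weight.grad is None  # no grad accumulated on frozen params
# True
# >>> x.grad is not None  # but the input still receives a gradient
# True
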
class evaluating(ContextDecorator):
    """A context manager/decorator that temporarily places one or more
    modules in eval mode during the context."""

    def __init__(self, *modules: tr.nn.Module) -> None:
        """
        Parameters
        ----------
        *modules: tr.nn.Module

        Notes
        -----
        A module's state is restored faithfully; e.g., a module that was
        already in eval mode will not be placed in train mode upon leaving
        the `evaluating` context.

        Examples
        --------
        >>> from torch.nn import Linear
        >>> from rai_toolbox import evaluating

        Using `evaluating` as a context manager.

        >>> module = Linear(1, 1)
        >>> module.training
        True
        >>> with evaluating(module):
        ...     print(module.training)
        False
        >>> module.training
        True

        Using `evaluating` as a decorator.

        >>> def f():
        ...     print("hello world")
        ...     return module.training
        >>> f = evaluating(module)(f)
        >>> module.training
        True
        >>> f()
        hello world
        False
        >>> module.training
        True
        """
        # Record each module under its *current* training mode so that
        # `__exit__` can restore it faithfully.
        self._states: Dict[bool, WeakSet[tr.nn.Module]] = {
            True: WeakSet(),
            False: WeakSet(),
        }
        self._states[True].update(m for m in modules if m.training)
        self._states[False].update(m for m in modules if not m.training)

    def __enter__(self) -> None:
        for train_status in self._states:
            for m in self._states[train_status]:
                m.eval()

    def __exit__(self, type, value, traceback) -> None:
        for train_status in self._states:
            for module in self._states[train_status]:
                module.train(train_status)
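
# A minimal sketch (not part of the original module): `evaluating` accepts
# multiple modules and restores each one's original mode on exit, so a
# module that entered the context already in eval mode stays in eval mode.
# `Linear` is a standard torch API.
#
# >>> from torch.nn import Linear
# >>> m1, m2 = Linear(1, 1), Linear(1, 1)
# >>> _ = m2.eval()  # m2 starts in eval mode
# >>> with evaluating(m1, m2):
# ...     (m1.training, m2.training)
# (False, False)
# >>> (m1.training, m2.training)  # m2 was not flipped to train mode
# (True, False)
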