# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
from abc import ABCMeta, abstractmethod
from functools import wraps
from typing import Callable, Dict, Iterable, TypeVar, Union, cast
from weakref import WeakSet
import torch as tr
from .itertools import flatten_params
# Alias for the type of ``None``.
NoneType = type(None)
# Type variable bound to callables; used by `ContextDecorator.__call__` so that
# decorating a function preserves its static type.
T = TypeVar("T", bound=Callable)
def freeze(
    *items: Union[
        tr.Tensor,
        tr.nn.Module,
        tr.optim.Optimizer,
        Iterable[tr.Tensor],
        Iterable[Dict[str, Iterable[tr.Tensor]]],
    ]
) -> Callable[[], None]:
    """'Freezes' collections of tensors by setting `requires_grad=False`.

    Returns a callable that, when called, restores the state of the tensors.

    Parameters
    ----------
    *items: tr.Tensor | tr.nn.Module | tr.optim.Optimizer | Iterable[tr.Tensor] | Iterable[Dict[str, Iterable[tr.Tensor]]]
        Tensors, modules, optimizers, or param-groups. All tensors/parameters must
        be leaf tensors [1]_ .

    Returns
    -------
    unfreeze : Callable[[], None]
        Can be called without any input to restore the states of the frozen tensors.

    Raises
    ------
    RuntimeError
        If PyTorch refuses to change `requires_grad` for one of the tensors (e.g.
        because it is not a leaf tensor). Any tensors that were already frozen by
        this call are restored to their original states before the error
        propagates, so a failed call leaves no tensor frozen.

    Notes
    -----
    'Unfreezing' the tensors restores their original states faithfully: tensors
    that originally had `requires_grad=True` are reset to `True`, and tensors that
    originally had `requires_grad=False` are reset to `False`.

    The tensors are only referenced weakly, so holding on to the returned
    callable does not keep them alive.

    References
    ----------
    .. [1]  https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf.html

    Examples
    --------
    >>> import torch as tr
    >>> from rai_toolbox.utils import freeze

    Basic behavior

    >>> x = tr.tensor(1.0, requires_grad=True)
    >>> unfreeze = freeze(x)
    >>> x.requires_grad
    False
    >>> unfreeze()
    >>> x.requires_grad
    True

    Freezing a module

    >>> from torch.nn import Linear
    >>> m = Linear(2, 3)
    >>> m.weight.requires_grad, m.bias.requires_grad
    (True, True)
    >>> unfreeze = freeze(m)
    >>> m.weight.requires_grad, m.bias.requires_grad
    (False, False)
    >>> unfreeze()
    >>> m.weight.requires_grad, m.bias.requires_grad
    (True, True)
    """
    # Group every tensor by its *original* `requires_grad` flag. WeakSets both
    # de-duplicate tensors that appear in several items and avoid keeping them alive.
    seen: Dict[bool, "WeakSet[tr.Tensor]"] = {True: WeakSet(), False: WeakSet()}
    for item in items:
        if isinstance(item, tr.nn.Module):
            item = item.parameters()
        elif isinstance(item, tr.optim.Optimizer):
            item = item.param_groups
        for param in flatten_params(item):
            seen[param.requires_grad].add(param)

    def restore_state() -> None:
        # Re-apply each tensor's recorded flag (including tensors that were
        # already frozen, in case their flag was changed in the meantime).
        for requires_grad, params in seen.items():
            for p in params:
                p.requires_grad_(requires_grad)

    try:
        for param in seen[True]:
            param.requires_grad_(False)
    except BaseException:
        # Freezing failed part-way (e.g. a non-leaf tensor); undo the partial
        # freeze so that no tensor is left frozen without a way to unfreeze it.
        restore_state()
        raise
    return restore_state
class ContextDecorator(metaclass=ABCMeta):
    """Abstract base for objects that work both as a context manager and as a
    function decorator.

    Subclasses provide `__enter__`/`__exit__`; calling an instance on a function
    returns a wrapper that executes each call of that function inside
    `with self:`.
    """

    @abstractmethod
    def __enter__(self):  # pragma: no cover
        raise NotImplementedError()

    @abstractmethod
    def __exit__(self, type, value, traceback):  # pragma: no cover
        raise NotImplementedError()

    def __call__(self, func: T) -> T:
        """Return `func` wrapped so that every call runs within this context."""

        def _run_in_context(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        # Copy `__name__`, `__doc__`, etc. from `func` onto the wrapper.
        decorated = wraps(func)(_run_in_context)
        return cast(T, decorated)
class frozen(ContextDecorator):
    """A context manager/decorator for 'freezing' collections of tensors; i.e.
    `requires_grad` is set to `False` for the tensors during the context."""

    def __init__(
        self,
        *items: Union[
            tr.Tensor,
            tr.nn.Module,
            tr.optim.Optimizer,
            Iterable[tr.Tensor],
            Iterable[Dict[str, Iterable[tr.Tensor]]],
        ],
    ) -> None:
        """
        Parameters
        ----------
        *items: tr.Tensor | tr.nn.Module | tr.optim.Optimizer | Iterable[tr.Tensor] | Iterable[Dict[str, Iterable[tr.Tensor]]]
            Tensors, modules, optimizers, or param-groups to be frozen. All tensors/
            parameters must be leaf tensors [1]_ .

        Notes
        -----
        The tensors' `requires_grad` states are recorded each time the context is
        entered, and restored upon the matching exit. The same instance may be
        entered repeatedly and in a nested fashion (e.g. by a recursive function
        decorated with `frozen`).

        Top-level one-shot iterators (e.g. the generator returned by
        `module.parameters()`) are materialized here so that every entry freezes
        the same tensors.

        References
        ----------
        .. [1]  https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf.html

        Examples
        --------
        >>> import torch as tr
        >>> from rai_toolbox.utils._implementations import frozen

        Demonstrating `frozen` as a context manager.

        >>> x = tr.tensor(1.0, requires_grad=True)
        >>> with frozen(x):
        ...    print(x.requires_grad)
        False
        >>> x.requires_grad
        True

        Demonstrating `frozen` as a decorator.

        >>> x = tr.tensor(1.0, requires_grad=True)
        >>> @frozen(x)
        ... def f():
        ...     print("hello world")
        ...     return x.requires_grad
        >>> x.requires_grad # x isn't frozen until f is called
        True
        >>> f()
        hello world
        False
        >>> x.requires_grad
        True

        A generator of parameters is frozen on every call, not just the first.

        >>> from torch.nn import Linear
        >>> m = Linear(1, 1)
        >>> @frozen(m.parameters())
        ... def g():
        ...     return m.weight.requires_grad
        >>> g(), g()
        (False, False)
        >>> m.weight.requires_grad
        True
        """
        from collections.abc import Iterator

        # A one-shot iterator would be exhausted by the first `__enter__`, after
        # which later entries would silently freeze nothing; store its contents.
        self._items = tuple(
            tuple(item) if isinstance(item, Iterator) else item for item in items
        )
        # One restore-callable per active entry; a stack (rather than a single
        # attribute) keeps nested/reentrant entries of this instance from
        # overwriting each other's recorded states.
        self._unfreeze_stack: "list[Callable[[], None]]" = []

    def __enter__(self) -> None:
        self._unfreeze_stack.append(freeze(*self._items))

    def __exit__(self, type, value, traceback) -> None:
        # Restore the states recorded by the matching (most recent) entry.
        unfreeze = self._unfreeze_stack.pop()
        unfreeze()
class evaluating(ContextDecorator):
    """A context manager / decorator that temporarily places one
    or more modules in eval mode during the context."""

    def __init__(self, *modules: tr.nn.Module) -> None:
        """
        Parameters
        ----------
        *modules: tr.nn.Module

        Notes
        -----
        A module's state is restored faithfully; e.g., a module that
        was already in eval mode will not be placed in train mode upon
        leaving the `evaluating` context.

        The train/eval state of each module *and of each of its submodules* is
        recorded when the context is entered (not when `evaluating` is created),
        so mixed train/eval submodules and mode changes made after decorating a
        function are both respected. The same instance may be entered
        repeatedly and in a nested fashion.

        Examples
        --------
        >>> from torch.nn import Linear
        >>> from rai_toolbox import evaluating

        Using `evaluating` as a context manager.

        >>> module = Linear(1, 1)
        >>> module.training
        True
        >>> with evaluating(module):
        ...     print(module.training)
        False
        >>> module.training
        True

        Using `evaluating` as a decorator.

        >>> def f():
        ...     print("hello world")
        ...     return module.training
        >>> f = evaluating(module)(f)
        >>> module.training
        True
        >>> f()
        hello world
        False
        >>> module.training
        True

        The state is recorded upon entering the context, not upon decorating.

        >>> _ = module.eval()
        >>> f()
        hello world
        False
        >>> module.training
        False

        Submodules that were already in eval mode stay in eval mode.

        >>> from torch.nn import Dropout, Sequential
        >>> net = Sequential(Linear(1, 1), Dropout())
        >>> _ = net[1].eval()
        >>> with evaluating(net):
        ...     pass
        >>> net.training, net[1].training
        (True, False)
        """
        # Held weakly so that, e.g., a decorated function does not keep the
        # modules alive.
        self._modules: "WeakSet[tr.nn.Module]" = WeakSet(modules)
        # One snapshot per active entry; each snapshot holds, per module, the
        # `(submodule, training)` pairs recorded on entry.
        self._snapshots: "list[list[list[tuple[tr.nn.Module, bool]]]]" = []

    def __enter__(self) -> None:
        modules = list(self._modules)
        # `Module.modules()` yields the module itself first, then every
        # descendant, so `states[0]` is the top-level module's own state.
        self._snapshots.append(
            [[(sub, sub.training) for sub in m.modules()] for m in modules]
        )
        for m in modules:
            m.eval()

    def __exit__(self, type, value, traceback) -> None:
        for states in self._snapshots.pop():
            (module, was_training), descendants = states[0], states[1:]
            # Call `train` on the top-level module (mirroring `eval` on entry) so
            # any override of `train` still runs; it sets all submodules to the
            # same mode, so submodules recorded in a different mode are reset.
            module.train(was_training)
            for sub, sub_was_training in descendants:
                if sub.training != sub_was_training:
                    sub.training = sub_was_training