# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
from numbers import Real
from typing import Any, Iterable, Mapping, Optional, Tuple, TypeVar, Union, cast
import torch as tr
# Type variable used by `value_check`: its `value` parameter and its return
# value share this type, so callers keep the static type of what they passed in.
T = TypeVar("T", bound=Any)
class Unsatisfiable(AssertionError):  # pragma: no cover
    """Raised by `value_check` when the requested bounds describe an empty range.

    This happens when `min_ > max_`, or when `min_ >= max_` and at least one of
    the bounds is exclusive, so that no value could ever satisfy the check.
    """
def get_device(obj: Union[tr.nn.Module, tr.Tensor]) -> tr.device:
    """Return the device that `obj` resides on.

    Parameters
    ----------
    obj : torch.nn.Module | torch.Tensor
        For a module, the device of its first parameter is reported. Buffers
        are not consulted, and a module with no parameters is reported as
        being on the CPU. For a tensor, its own `.device` is returned.

    Returns
    -------
    device : torch.device

    Raises
    ------
    TypeError
        `obj` is neither a module nor a tensor.
    """
    if isinstance(obj, tr.nn.Module):
        first_param = next(obj.parameters(), None)
        return tr.device("cpu") if first_param is None else first_param.device
    if isinstance(obj, tr.Tensor):
        return obj.device
    raise TypeError(  # pragma: no cover
        f"Expected torch.nn.Module or torch.Tensor, got {obj}"
    )
def _safe_name(x: Any) -> str:
    """Return a human-readable name for `x`, for use in error messages.

    Objects exposing `__name__` (e.g. classes) are reported by that name. A
    non-empty tuple, such as the tuple-of-types form accepted by `isinstance`,
    is reported as the comma-separated names of its members (nested tuples
    are flattened). Anything else falls back to `str(x)`.
    """
    if isinstance(x, tuple) and x:
        # A tuple has no `__name__`; without this branch a tuple of types would
        # be rendered as "(<class 'int'>, <class 'float'>)".
        return ", ".join(_safe_name(item) for item in x)
    return getattr(x, "__name__", str(x))
def value_check(
    name: str,
    value: T,
    *,
    type_: Union[type, Tuple[type, ...]] = Real,
    min_: Optional[Union[int, float]] = None,
    max_: Optional[Union[int, float]] = None,
    incl_min: bool = True,
    incl_max: bool = True,
    optional: bool = False,
    lower_name: str = "",
    upper_name: str = "",
) -> T:
    """
    For internal use only.

    Used to check the type of `value`. Numerical types can also be bound-checked.

    Parameters
    ----------
    name : str
        The name of the value being checked; used in error messages.
    value : T
        The value to be checked.
    type_ : type | Tuple[type, ...], optional (default=numbers.Real)
        The type(s) that `value` must be an instance of.
    min_ : Optional[int | float]
        Lower bound on `value`, if specified.
    max_ : Optional[int | float]
        Upper bound on `value`, if specified.
    incl_min : bool, optional (default=True)
        If `True`, `value == min_` is permitted; otherwise the lower bound is strict.
    incl_max : bool, optional (default=True)
        If `True`, `value == max_` is permitted; otherwise the upper bound is strict.
    optional : bool, optional (default=False)
        If `True`, `value=None` is accepted and returned without further checks.
    lower_name : str, optional (default="")
        If provided, the lower bound is reported as `lower_name(= min_)` in
        error messages.
    upper_name : str, optional (default="")
        If provided, the upper bound is reported as `upper_name(= max_)` in
        error messages.

    Returns
    -------
    value : T
        The input, unchanged.

    Raises
    ------
    TypeError
        `value` is not an instance of `type_` (and is not a permitted `None`).
    ValueError
        `value` violates the specified bounds.
    Unsatisfiable
        `min_` and `max_` describe an empty range, so no value could pass.

    Examples
    --------
    >>> value_check("x", 1, type_=str)
    TypeError: `x` must be of type(s) `str`, got 1 (type: int)
    >>> value_check("x", "a", optional=True)
    TypeError: `x` must be None or of type(s) `Real`, got a (type: str)
    >>> value_check("x", 1, min_=20)
    ValueError: `x` must satisfy 20 <= x  Got: 1
    >>> value_check("x", 1, min_=1, incl_min=False)
    ValueError: `x` must satisfy 1 < x  Got: 1
    >>> value_check("x", 1, min_=1, incl_min=True) # ok
    1
    >>> value_check("x", 0.0, min_=-10, max_=10)  # ok
    0.0
    """
    # check internal params
    assert isinstance(name, str), name
    assert min_ is None or isinstance(min_, (int, float)), min_
    assert max_ is None or isinstance(max_, (int, float)), max_
    assert isinstance(incl_min, bool), incl_min
    assert isinstance(incl_max, bool), incl_max

    if optional and value is None:
        return value

    if not isinstance(value, type_):
        # The trailing space in "None or " matters: without it the message
        # read "must be None orof type(s) ...".
        raise TypeError(
            f"`{name}` must be {'None or ' if optional else ''}of type(s) "
            f"`{_safe_name(type_)}`, got {value} (type: {_safe_name(type(value))})"
        )

    if min_ is not None and max_ is not None:
        # A closed range [min_, max_] is non-empty iff min_ <= max_; if either
        # end is exclusive, min_ < max_ is required.
        if incl_max and incl_min:
            if not (min_ <= max_):
                raise Unsatisfiable(f"{min_} <= {max_}")
        elif not min_ < max_:
            raise Unsatisfiable(f"{min_} < {max_}")

    min_satisfied = (
        (min_ <= value if incl_min else min_ < value) if min_ is not None else True
    )
    max_satisfied = (
        (value <= max_ if incl_max else value < max_) if max_ is not None else True
    )

    if not min_satisfied or not max_satisfied:
        lsymb = "<=" if incl_min else "<"
        rsymb = "<=" if incl_max else "<"
        err_msg = f"`{name}` must satisfy"
        if min_ is not None:
            lower: Any = min_
            if lower_name:  # pragma: no cover
                lower = f"{lower_name}(= {min_})"
            err_msg += f" {lower} {lsymb}"
        err_msg += f" {name}"
        if max_ is not None:
            upper: Any = max_
            if upper_name:
                upper = f"{upper_name}(= {max_})"
            err_msg += f" {rsymb} {upper}"
        err_msg += f"  Got: {value}"
        raise ValueError(err_msg)
    return cast(T, value)
def check_param_group_value(
    name: str,
    param_groups: Iterable[Mapping[str, Any]],
    *,
    type_: Union[type, Tuple[type, ...]] = Real,
    min_: Optional[Union[int, float]] = None,
    max_: Optional[Union[int, float]] = None,
    incl_min: bool = True,
    incl_max: bool = True,
    optional: bool = False,
) -> None:
    """Apply `value_check` to the `name` entry of every parameter group.

    Parameters
    ----------
    name : str
        The key to look up in each group; also used in error messages.
    param_groups : Iterable[Mapping[str, Any]]
        The groups to inspect, in order.
    type_, min_, max_, incl_min, incl_max, optional
        Forwarded unchanged to `value_check`.

    Raises
    ------
    KeyError
        A group has no entry for `name`.
    TypeError, ValueError
        Raised by `value_check` for the first group whose value fails the check.
    """
    shared_kwargs = dict(
        type_=type_,
        min_=min_,
        max_=max_,
        incl_min=incl_min,
        incl_max=incl_max,
        optional=optional,
    )
    for group_value in (group[name] for group in param_groups):
        value_check(name, group_value, **shared_kwargs)
def _reshape_to_batch(x: tr.Tensor, param_ndim: Optional[int]) -> tr.Tensor:
    """Return `x` reshaped to `(N, d1, ..., dm)` without copying.

    `m` is the resolved parameter dimensionality: `x.ndim` when `param_ndim` is
    None, `x.ndim + param_ndim` when it is negative, and `param_ndim` otherwise.
    Singleton axes are prepended when `x` has too few dimensions. The leading
    dimensions are flattened into `N` (via `Tensor.view`) when it has too many.
    The result always has at least two dimensions.
    """
    if param_ndim is None:
        n_param = x.ndim
    elif param_ndim < 0:
        n_param = x.ndim + param_ndim
    else:
        n_param = param_ndim

    # Target rank is one batch axis plus `n_param` parameter axes.
    surplus = x.ndim - (n_param + 1)
    if surplus < 0:
        # Too few dims, e.g. (d0,) -> (1, d0): prepend singleton axes.
        x = x[(None,) * -surplus]
    elif surplus > 0:
        # Too many dims, e.g. (d0, d1, d2, d3) with n_param=2 -> (d0*d1, d2, d3).
        kept_shape = x.shape[x.ndim - n_param :]
        x = x.view(-1, *kept_shape)

    n_missing = 2 - x.ndim
    if n_missing > 0:
        # Guarantee at least 2D, e.g. (N,) -> (N, 1).
        x = x.view(*x.shape, *([1] * n_missing))
    return x
def to_batch(p: tr.Tensor, param_ndim: Optional[int]) -> tr.Tensor:
    """
    Returns a view of `p`, reshaped as shape-(N, d0, ...) where (d0, ...)
    has `param_ndim` entries.

    If `p.grad` is set, the returned view's `.grad` is set to an identically
    reshaped view of `p.grad`.

    Parameters
    ----------
    p : Tensor
    param_ndim: Optional[int]
        Determines the shape of the resulting parameter
        - A positive number determines the dimensionality of the tensor that the transformation will act on.
        - A negative number indicates the 'offset' from the dimensionality of the tensor.
        - `None` means that the transformation will be applied to the tensor without any broadcasting. This is equivalent to `param_ndim=p.ndim`

    Returns
    -------
    reshaped_p: Tensor, shape-(N, d0, ...)
        Where
        - (d0, ...) is of length `param_ndim` for `param_ndim > 0`
        - (d0, ...) is (1,) for `param_ndim == 0`
        - (d0, ...) is of length `p.ndim - |param_ndim|` for `param_ndim < 0`

    Notes
    -----
    The reshaping is performed with `Tensor.view`, never a copy. A tensor
    whose strides are incompatible with the requested shape (e.g. some
    non-contiguous tensors) therefore raises a `RuntimeError` rather than
    being silently copied.

    Examples
    --------
    >>> import torch as tr
    >>> x = tr.rand((3, 5, 2))
    >>> to_batch(x, param_ndim=0).shape
    torch.Size([30, 1])
    >>> to_batch(x, param_ndim=1).shape
    torch.Size([15, 2])
    >>> to_batch(x, param_ndim=2).shape
    torch.Size([3, 5, 2])
    >>> to_batch(x, param_ndim=3).shape
    torch.Size([1, 3, 5, 2])
    >>> to_batch(x, param_ndim=None).shape  # same as `param_ndim=x.ndim`
    torch.Size([1, 3, 5, 2])
    >>> to_batch(x, param_ndim=-1).shape
    torch.Size([3, 5, 2])
    >>> to_batch(x, param_ndim=-2).shape
    torch.Size([15, 2])
    >>> to_batch(x, param_ndim=-3).shape
    torch.Size([30, 1])
    """
    vp = _reshape_to_batch(p, param_ndim=param_ndim)
    if p.grad is not None:
        # Reshape the gradient identically so `vp.grad` lines up with `vp`.
        vp.grad = _reshape_to_batch(p.grad, param_ndim=param_ndim)
    # vp (vp.grad) must be a view of p (p.grad). There is
    # not a simple way to assert this.
    # our views must be size-preserving
    assert tr.numel(vp) == tr.numel(p)
    return vp
def validate_param_ndim(param_ndim: Optional[int], p: tr.Tensor) -> None:
    """Check that `param_ndim` is compatible with the dimensionality of `p`.

    `None` is always accepted. Otherwise `abs(param_ndim)` may not exceed
    `p.ndim`.

    Raises
    ------
    ValueError
        `abs(param_ndim) > p.ndim`.
    """
    if param_ndim is None:
        return
    if abs(param_ndim) > p.ndim:
        raise ValueError(
            f"`param_ndim={param_ndim}` specified for parameter "
            f"with ndim={p.ndim} is not valid. `abs(param_ndim) <= "
            f"ndim` must hold."
        )