Source code for rai_toolbox.optim.optimizer
# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
import inspect
from abc import ABCMeta
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, cast, overload
import torch
from torch import Tensor
from torch.optim import SGD, Optimizer
from typing_extensions import TypedDict
from rai_toolbox._typing import (
InstantiatesTo,
Optimizer as Opt,
OptimizerType,
OptimParams,
Partial,
instantiates_to,
)
from rai_toolbox._utils import validate_param_ndim as _validate_param_ndim
# Type variable for the value returned by a closure handed to `step`:
# a tensor, a float, or `None`.
_T = TypeVar("_T", bound=Optional[Union[Tensor, float]])

# Names exported by `from <module> import *`.
__all__ = ["ParamTransformingOptimizer", "ChainedParamTransformingOptimizer"]

# The default value that `torch.optim.SGD` declares for its `lr` argument,
# read from the constructor signature.
# NOTE(review): presumably meant to capture torch's "required" sentinel; the
# actual value depends on the installed torch version -- confirm.
REQUIRED: Any = inspect.signature(SGD).parameters["lr"].default
class DatumParamGroup(TypedDict):
    """Typed description of the param-group entries read by this module."""

    # Tensors that belong to the group.
    params: List[Tensor]
    # Controls how each tensor (and its gradient) is reshaped before a
    # transform is applied; `None` means no broadcasting.
    param_ndim: Optional[int]
    # Factor multiplied into each gradient after the pre-step transform.
    grad_scale: float
    # Offset added to each gradient after the pre-step transform.
    grad_bias: float
def _shares_memory(x: Tensor, y: Tensor) -> bool:
    """Return `True` if `x` and `y` are backed by the same underlying storage.

    Only the address of each tensor's storage is compared, so any two views of
    one storage (regardless of offset, shape, or stride) compare as sharing
    memory.

    Parameters
    ----------
    x : Tensor
    y : Tensor

    Returns
    -------
    bool
    """

    def _storage_ptr(t: Tensor) -> int:
        # `Tensor.storage()` returns a `TypedStorage`, which is deprecated and
        # emits a warning on recent PyTorch releases. Prefer the untyped
        # storage when it is available and fall back for older versions.
        get_storage = getattr(t, "untyped_storage", None)
        if get_storage is None:
            get_storage = t.storage
        return get_storage().data_ptr()

    return _storage_ptr(x) == _storage_ptr(y)
def _reshape_to_batch(x: Tensor, param_ndim: Optional[int]) -> Tensor:
    """Return a view of `x` with shape `(N, d1, ..., dm)`, where `(d1, ..., dm)`
    spans `param_ndim` dimensions.

    Leading singleton dimensions are prepended, or leading dimensions are merged
    into `N`, as required. The result always has at least two dimensions.
    """
    # Resolve `param_ndim` into the number of trailing dimensions that make up
    # each sub-tensor: `None` means "all of them", negatives are offsets.
    if param_ndim is None:
        trailing = x.ndim
    elif param_ndim < 0:
        trailing = param_ndim + x.ndim
    else:
        trailing = param_ndim

    # One extra (leading) dimension is needed for the batch axis `N`.
    target_ndim = trailing + 1

    if x.ndim < target_ndim:
        # Too few dimensions, e.g. (d0,) -> (1, d0): prepend singleton axes.
        new_axes = (None,) * (target_ndim - x.ndim)
        x = x[new_axes]
    elif x.ndim > target_ndim:
        # Too many dimensions, e.g. (d0, d1, d2, d3) -> (d0 * d1, d2, d3):
        # fold the leading axes into a single batch axis.
        kept_shape = x.shape[x.ndim - trailing :]
        x = x.view(-1, *kept_shape)

    # Guarantee at least 2D, e.g. (N,) -> (N, 1).
    missing = 2 - x.ndim
    if missing > 0:
        x = x.view(*x.shape, *((1,) * missing))
    return x
def _to_batch(p: Tensor, param_ndim: Optional[int]) -> Tensor:
    """
    Return a view of `p` with shape `(N, d0, ...)`, where `(d0, ...)` has
    `param_ndim` entries. If `p.grad` is set, the returned view carries a
    correspondingly reshaped view of that gradient as its own `.grad`.

    Parameters
    ----------
    p : Tensor

    param_ndim: Optional[int]
        Determines the shape of the resulting parameter

        - A positive number determines the dimensionality of the tensor that the transformation will act on.
        - A negative number indicates the 'offset' from the dimensionality of the tensor.
        - `None` means that the transformation will be applied to the tensor without any broadcasting.

    Returns
    -------
    reshaped_p: Tensor

    Examples
    --------
    >>> import torch as tr
    >>> x = tr.rand((3, 5, 2))

    >>> _to_batch(x, 0).shape
    torch.Size([30, 1])

    >>> _to_batch(x, 1).shape
    torch.Size([15, 2])

    >>> _to_batch(x, 2).shape
    torch.Size([3, 5, 2])

    >>> _to_batch(x, None).shape
    torch.Size([1, 3, 5, 2])

    >>> _to_batch(x, -1).shape
    torch.Size([3, 5, 2])

    >>> _to_batch(x, -2).shape
    torch.Size([15, 2])

    >>> _to_batch(x, -3).shape
    torch.Size([30, 1])
    """
    batched = _reshape_to_batch(p, param_ndim=param_ndim)

    grad = p.grad
    if grad is not None:
        # The gradient is reshaped in exactly the same way as the parameter.
        batched.grad = _reshape_to_batch(grad, param_ndim=param_ndim)

    # `batched` (and its grad) must be views of `p` (and its grad). There is no
    # simple way to assert that directly, but a view is at least
    # size-preserving.
    assert torch.numel(batched) == torch.numel(p)
    return batched
[docs]
class ParamTransformingOptimizer(Optimizer, metaclass=ABCMeta):
r"""An optimizer that performs an in-place transformation to
each parameter, both before and after performing the gradient-based update on each
parameter via `InnerOptim.step`::
_pre_step_transform_(param)
param = InnerOptim.step(param, ...)
_post_step_transform_(param)
Note that `_pre_step_transform_` and `_post_step_transform_` can be used to update
a parameter and/or its gradient. Also, this optimizer exposes `param_ndim` as a
means of controlling how these transforms broadcast (if at all) over any given
tensor.
Notes
-----
`ParamTransformingOptimizer` mirrors state with `InnerOpt` so that their
`param_groups`, `defaults`, and `state` are always in sync.
`ParamTransformingOptimizer` is designed to be combined with other,
standard gradient-based optimizers (e.g., Adam) via composition, rather than
through inheritance. I.e., `ParamTransformingOptimizer(InnerOpt=<...>)` will apply
`_pre_step_transform_` on a parameter, and then use `InnerOpt.step(...)` to update
said parameter, and finally will apply `_post_step_transform_` to the parameter.
If a closure is supplied to the `.step(...)` method, then the `_pre_step_transform_`
is applied after the closure call and prior to the parameter steps.
Methods
-------
_pre_step_transform_
_post_step_transform_
project
See Also
--------
ChainedParamTransformingOptimizer
"""
param_groups: List[DatumParamGroup]
[docs]
def __init__(
    self,
    params: Optional[OptimParams] = None,
    InnerOpt: Union[Opt, Partial[Opt], OptimizerType] = SGD,
    *,
    param_ndim: Union[int, None] = -1,
    grad_scale: float = 1.0,
    grad_bias: float = 0.0,
    defaults: Optional[Dict[str, Any]] = None,
    **inner_opt_kwargs,
) -> None:
    r"""
    Parameters
    ----------
    params : Sequence[Tensor] | Iterable[ParamGroup]
        Iterable of parameters to optimize or dicts defining parameter groups.
        May be omitted only when `InnerOpt` is an already-instantiated
        optimizer, in which case that optimizer's param groups are used.

    InnerOpt : Type[Optimizer] | Partial[Optimizer] | Optimizer, optional (default=`torch.optim.SGD`)
        The optimizer that updates the parameters after their gradients have
        been transformed.

    param_ndim : int | None, optional (default=-1)
        Determines how a parameter and its gradient are temporarily reshaped prior
        to being passed to both `_pre_step_transform_` and `_post_step_transform_`.
        By default, the transformation broadcasts over the tensor's first dimension
        in a batch-like style.

        - A positive number determines the dimensionality of the tensor that the transformation will act on.
        - A negative number indicates the 'offset' from the dimensionality of the tensor (see "Notes" for examples).
        - `None` means that the transformation will be applied directly to the tensor without any broadcasting.

        See "Notes" for more details.

    grad_scale : float, optional (default=1.0)
        Multiplies each gradient in-place after the pre-step transformation is
        performed. This can be specified per param-group.

    grad_bias : float, optional (default=0.0)
        Added to each gradient in-place after the pre-step transformation is
        performed. This can be specified per param-group.

    defaults : Optional[Dict[str, Any]]
        Specifies default parameters for all parameter groups. The supplied
        dictionary is copied; it is not modified.

    **inner_opt_kwargs : Any
        Named arguments used to initialize `InnerOpt`.

    Raises
    ------
    TypeError
        If `InnerOpt` is neither an optimizer type (or partial) nor an optimizer
        instance, if `params` is `None` while `InnerOpt` is un-instantiated, or
        if a param group holds an invalid `param_ndim`, `grad_scale`, or
        `grad_bias`.

    Notes
    -----
    .. _param-ndim-add:

    **Additional Explanation of `param_ndim`**

    Consider a parameter of shape `(d0, d1, d2, d3)`.

    If `param_ndim=0`, then the parameter and its gradient will be temporarily
    reshaped to a shape-`(d0 * d1 * d2 * d3, 1)` so that the transformation will be
    applied elementwise to the tensor.

    If `param_ndim=1` (or `param_ndim=-3`), then the parameter and its gradient
    will be temporarily reshaped to a shape-`(d0 * d1 * d2, d3)` so that the
    transformation will be broadcast over each shape-`(d3,)` sub-tensor.

    If `param_ndim=2` (or `param_ndim=-2`), then the parameter and its gradient
    will be temporarily reshaped to a shape-`(d0 * d1, d2, d3)` so that the
    transformation will be broadcast over each shape-`(d2, d3)` sub-tensor.

    If `param_ndim=3` (or `param_ndim=-1`), then the parameter and its gradient
    will be temporarily reshaped to a shape-`(d0, d1, d2, d3)` so that the
    transformation will be broadcast over each shape-`(d1, d2, d3)` sub-tensor.

    If `param_ndim=4` (or `param_ndim=None`), then the parameter and its gradient
    will be temporarily reshaped to a shape-`(1, d0, d1, d2, d3)` so that the
    transformation will be applied to the shape-`(d0, d1, d2, d3)` tensor without
    broadcasting.

    Examples
    --------
    **Creating a gradient-transforming optimizer**

    Let's create a gradient-transforming optimizer that replaces the gradient
    of each parameter with the elementwise sign of the gradient (:math:`\pm 1`)
    prior to performing the step of the inner optimizer:

    >>> import torch as tr
    >>> from rai_toolbox.optim import ParamTransformingOptimizer

    >>> class SignedGradientOptim(ParamTransformingOptimizer):
    ...
    ...     def _pre_step_transform_(self, param: tr.Tensor, **_kwds) -> None:
    ...         if param.grad is None:
    ...             return
    ...         tr.sign(param.grad, out=param.grad)  # operates in-place

    Now we'll use this optimizer – with `torch.optim.AdamW` providing the actual
    parameter-update functionality – to update the parameter.

    >>> x = tr.tensor([-10.0, 10.0], requires_grad=True)
    >>> optim = SignedGradientOptim([x], InnerOpt=tr.optim.AdamW, lr=0.1)

    Use `x` in a calculation and compute an associated gradient for it:

    >>> (10_000 * x).sum().backward()

    Updating `x` using our grad-sign + AdamW optimizer:

    >>> optim.step()
    >>> x
    tensor([-10.9000, 8.9000], requires_grad=True)

    This was a simple optimizer which did not involve any broadcasting in the
    gradient transformation; the next example will involve broadcasting.

    **Controlling the gradient transformation with param_ndim**

    To understand the role of `param_ndim` let's design an optimizer that
    normalizes a parameter's gradient by its max value – along some user-specified
    dimension – prior to performing the gradient-based update to its parameter.

    >>> class MaxNormedGradientOptim(ParamTransformingOptimizer):
    ...
    ...     def _pre_step_transform_(self, param: tr.Tensor, **_kwds) -> None:
    ...         if param.grad is None:
    ...             return
    ...
    ...         g = param.grad.flatten(1) # (N, d1, ..., dm) -> (N, d1 * ... * dm)
    ...         max_norms = tr.max(g, dim=1).values
    ...         max_norms = max_norms.view(-1, *([1] * (param.ndim - 1)))  # reshape to have dimensionality-m
    ...         param.grad /= tr.clamp(max_norms, 1e-20, None)  # clamp to prevent div by 0

    Note that we design `_pre_step_transform_` to operate in-place on the gradient
    and that we treat the gradient as if it has a shape `(N, d1, ..., dm)`, where we
    want to compute the max over each of the `N` sub-tensors of
    shape-`(d1, ..., dm)`.

    Critically, we did not use `param_ndim` at all in this method;
    `ParamTransformingOptimizer` assumes that we designed this method to broadcast
    in a batch-style, as we did, and it automatically leverages `param_ndim`
    to reshape the parameter and its gradient appropriately prior to calling
    `_pre_step_transform_`.

    Now we will create a shape-`(2, 2)` parameter to see how `MaxNormedGradientOptim`
    can compute the max-norm over various dimensions of the parameter. Let's print
    out the transformed gradient when we use each of `param_ndim`: `0`, `1`, or `2`.

    >>> x = tr.tensor([[1.0, 2.0],
    ...                [20.0, 10.0]], requires_grad=True)

    >>> for param_ndim in [0, 1, 2]:
    ...     optim = MaxNormedGradientOptim([x], param_ndim=param_ndim, InnerOpt=tr.optim.SGD, lr=0.0)
    ...
    ...     loss = (x * x).sum()
    ...     loss.backward()
    ...     optim.step()
    ...     print(f"param_ndim: {param_ndim}, normed grad:\n{x.grad}\n..")
    ...     optim.zero_grad()
    param_ndim: 0, normed grad:
    tensor([[1., 1.],
            [1., 1.]])
    ..
    param_ndim: 1, normed grad:
    tensor([[0.5000, 1.0000],
            [1.0000, 0.5000]])
    ..
    param_ndim: 2, normed grad:
    tensor([[0.0500, 0.1000],
            [1.0000, 0.5000]])

    See that `param_ndim=0` applies the max-norm elementwise, whereas `param_ndim=1`
    applies the max-norm to each 1D row of the gradient, and `param_ndim=2` applies
    the max-norm over the entire 2D gradient.

    **Creating a parameter-constraining optimizer**

    Let's create an optimizer that clamps each parameter's values so that
    they all fall within `[-1, 1]` after performing its gradient-based step on the parameter.

    >>> import torch as tr
    >>> from rai_toolbox.optim import ParamTransformingOptimizer

    >>> class ClampedParamOptim(ParamTransformingOptimizer):
    ...     def _post_step_transform_(self, param: tr.Tensor, optim_group: dict) -> None:
    ...         param.clamp_(min=-1.0, max=1.0)  # note: clamp occurs in-place

    >>> x = tr.tensor([-10., 1.], requires_grad=True)
    >>> optim = ClampedParamOptim([x], lr=0.1)  # InnerOpt=SGD by default
    >>> x.backward(gradient=tr.tensor([-1., 1.]))
    >>> optim.step()  # parameters updated via SGD.step() and then clamped
    >>> x
    tensor([-1.0000, 0.9000], requires_grad=True)

    Note that this is a particularly simple function, which acts elementwise on
    each parameter, and thus does not require us to include `param_ndim` in the
    optimizer's param-groups.
    """
    # Work on a copy so that a dictionary supplied by the caller is never
    # mutated by the `setdefault` calls below.
    defaults = {} if defaults is None else dict(defaults)
    defaults.setdefault("param_ndim", param_ndim)
    defaults.setdefault("grad_scale", grad_scale)
    defaults.setdefault("grad_bias", grad_bias)

    if instantiates_to(InnerOpt, Optimizer):
        # `InnerOpt` is a type / partial: build our own param groups first and
        # then instantiate the inner optimizer on those same group objects.
        if params is None:
            raise TypeError(
                "`params` cannot be `None` when `InnerOpt` is an un-instantiated "
                "optimizer type."
            )
        super().__init__(params, defaults)  # type: ignore
        self.inner_opt = InnerOpt(self.param_groups, **inner_opt_kwargs)  # type: ignore
    elif isinstance(InnerOpt, Optimizer):
        # `InnerOpt` is already an optimizer: adopt its param groups.
        self.inner_opt = InnerOpt
        super().__init__(self.inner_opt.param_groups, defaults)
    else:
        raise TypeError(
            f"`InnerOpt` must be an Optimizer type or instance, got: {InnerOpt}"
        )

    # ensure inner-opt's defaults include those of `self`
    self.inner_opt.defaults.update(self.defaults)

    # state of `self` must mirror that of inner-opt
    self.__setstate__(self.inner_opt.__getstate__())  # type: ignore

    for group in self.param_groups:
        group_ndim = group["param_ndim"]
        if group_ndim is not None and not isinstance(group_ndim, int):
            raise TypeError(
                f"`param_ndim` must be an int or None, got: {group_ndim}"
            )

        if not isinstance(group["grad_scale"], (float, int)):
            raise TypeError(
                f"grad_scale must be a float, got {group['grad_scale']}"
            )

        if not isinstance(group["grad_bias"], (float, int)):
            raise TypeError(f"grad_bias must be a float, got {group['grad_bias']}")

        for p in group["params"]:
            _validate_param_ndim(param_ndim=group_ndim, p=p)
def state_dict(self) -> dict:
    """Return the state dict of the wrapped inner optimizer."""
    inner = self.inner_opt
    return inner.state_dict()
def __setstate__(self, state: dict):
    """Load `state` into the inner optimizer, then mirror the inner
    optimizer's resulting state onto `self`."""
    inner = self.inner_opt
    inner.__setstate__(state)
    mirrored = inner.__getstate__()  # type: ignore
    super().__setstate__(mirrored)
def __getstate__(self) -> dict:
    """Return the pickling state of the wrapped inner optimizer."""
    inner = self.inner_opt
    return inner.__getstate__()  # type: ignore
def __repr__(self) -> str:
    """Return the standard optimizer repr with the inner optimizer's class
    name inserted, in square brackets, before the first opening parenthesis."""
    inner_name = type(self.inner_opt).__name__
    base_repr = super().__repr__()
    return base_repr.replace("(", f"[{inner_name}](", 1)
[docs]
def _pre_step_transform_(
    self, param: Tensor, optim_group: DatumParamGroup
) -> None:  # pragma: no cover
    """In-place hook applied to each parameter **before** it is updated via
    `InnerOpt.step`.

    The default implementation does nothing; subclasses override it.

    Parameters
    ----------
    param : torch.Tensor, shape-(N, d0, ...)
        The parameter to be modified in-place.

        Both `param` and `param.grad` arrive already reshaped to
        shape-`(N, d0, ...)`, where `(d0, ...)` contains `param_ndim` entries.

    optim_group : Dict[str, Any]
        The parameter group that `param` belongs to.

    Notes
    -----
    An override should *always* broadcast over the leading dimension of the
    tensor that it modifies: treat each parameter/gradient as having
    shape-`(N, d0, ...)` and apply the transformation - in-place - to each
    shape-`(d0, ...)` sub-tensor.

    `ParamTransformingOptimizer` takes care of the temporary reshaping – based on
    the value of `param_ndim` – before this method is called, so that the
    shape-`(d0, ...)` sub-tensor contains `param_ndim` entries.

    When `param_ndim=0`, the transformation receives a shape-`(T, 1)` tensor,
    where `T` is the total number of elements in the tensor."""
    # Default hook: intentionally a no-op, so the arguments are discarded.
    del param, optim_group
[docs]
def _post_step_transform_(
    self, param: Tensor, optim_group: DatumParamGroup
) -> None:  # pragma: no cover
    """In-place hook applied to each parameter **after** it has been updated via
    `InnerOpt.step`.

    The default implementation does nothing; subclasses override it.

    Parameters
    ----------
    param : torch.Tensor, shape-(N, d0, ...)
        The parameter to be modified in-place.

        Both `param` and `param.grad` arrive already reshaped to
        shape-`(N, d0, ...)`, where `(d0, ...)` contains `param_ndim` entries.

    optim_group : Dict[str, Any]
        The parameter group that `param` belongs to.

    Notes
    -----
    An override should *always* broadcast over the leading dimension of the
    tensor that it modifies: treat each parameter/gradient as having
    shape-`(N, d0, ...)` and apply the transformation - in-place - to each
    shape-`(d0, ...)` sub-tensor.

    `ParamTransformingOptimizer` takes care of the temporary reshaping – based on
    the value of `param_ndim` – before this method is called, so that the
    shape-`(d0, ...)` sub-tensor contains `param_ndim` entries.

    When `param_ndim=0`, the transformation receives a shape-`(T, 1)` tensor,
    where `T` is the total number of elements in the tensor.
    """
    # Default hook: intentionally a no-op, so the arguments are discarded.
    del param, optim_group
[docs]
@torch.no_grad()
def _apply_post_step_transform_(self) -> None:
    """Call `_post_step_transform_` on a batch-shaped view of every parameter
    in every param group, so that each parameter is updated in-place.

    `.step()` invokes this automatically once `InnerOpt.step()` has run."""
    for optim_group in self.param_groups:
        ndim = optim_group["param_ndim"]
        for tensor in optim_group["params"]:
            batched = _to_batch(tensor, ndim)
            self._post_step_transform_(param=batched, optim_group=optim_group)
[docs]
@torch.no_grad()
def _apply_pre_step_transform_(self) -> None:
    """Update each parameter in-place by calling `_pre_step_transform_` on the
    parameter, then apply the group's `grad_scale` and `grad_bias` to its
    gradient.

    Parameters whose gradient is `None` are skipped.

    This is called automatically by `.step()` before `InnerOpt.step()` has been
    called.

    Raises
    ------
    ValueError
        If `_pre_step_transform_` replaced (or removed) the gradient tensor
        instead of writing to it in-place."""
    for group in self.param_groups:
        for orig_p in group["params"]:
            orig_grad = orig_p.grad
            if orig_grad is None:
                continue

            # `p` / `p.grad` are batch-shaped views of `orig_p` / `orig_grad`.
            p = _to_batch(orig_p, group["param_ndim"])
            self._pre_step_transform_(p, optim_group=group)

            # Validate *before* scaling / biasing: if the transform cleared
            # `p.grad`, the in-place arithmetic below would otherwise fail with
            # an uninformative `TypeError` instead of this error.
            if p.grad is None or not _shares_memory(orig_grad, p.grad):
                raise ValueError(
                    f"`{type(self).__name__}._pre_step_transform_` did "
                    " not modify the gradient of the parameter in-place."
                    " \nNote that setting `p.grad` directly replaces the"
                    " tensor, rather than writing to the tensor."
                )

            # In-place ops on the view write through to the original gradient.
            if group["grad_scale"] != 1.0:
                p.grad *= group["grad_scale"]

            if group["grad_bias"] != 0.0:
                p.grad += group["grad_bias"]
@torch.no_grad()
def _create_closure(self, closure: Callable[[], _T]) -> Callable[[], Optional[_T]]:
    """Wrap `closure` so that the pre-step transform runs right after it.

    The returned callable evaluates `closure` with gradients enabled, applies
    `_apply_pre_step_transform_`, and returns whatever `closure` returned."""

    def wrapped_closure():
        with torch.enable_grad():
            result = closure()
        self._apply_pre_step_transform_()
        return result

    return wrapped_closure
@overload
def step(self, closure: Callable[[], _T]) -> _T:  # pragma: no cover
    ...

@overload
def step(self) -> None:  # pragma: no cover
    ...

@overload
def step(
    self, closure: Optional[Callable[[], _T]] = None
) -> Optional[_T]:  # pragma: no cover
    ...

@torch.no_grad()
def step(self, closure=None):
    """Run one optimization step: pre-step transform, `InnerOpt.step`, then
    post-step transform.

    When `closure` is given, it is wrapped so that the pre-step transform is
    applied right after the closure is evaluated; the value returned by the
    inner optimizer's `step` is returned. Without a closure, `None` is returned.
    """
    if closure is None:
        self._apply_pre_step_transform_()
        self.inner_opt.step()
        loss = None
    else:
        wrapped = self._create_closure(closure)
        loss = self.inner_opt.step(wrapped)  # type: ignore

    self._apply_post_step_transform_()
    return cast(Optional[Union[float, Tensor]], loss)
[docs]
class ChainedParamTransformingOptimizer(ParamTransformingOptimizer):
"""Chains together an arbitrary number of parameter-transforming optimizers,
composing their pre- and post-step transformation functions to modify the parameters
(and their gradients) in-place. `InnerOpt.step()` applies the gradient-based update
to each parameter.
I.e., passing `Opt1, Opt2, ..., OptN` to `ChainedParamTransformingOptimizer` will
update a parameter using: `OptN.fn_(...Opt2.fn_(Opt1.fn_(param)))`,
where `fn_` is a shorthand for `_pre_step_transform_` / `_post_step_transform_`.
Notes
-----
`ChainedParamTransformingOptimizer` mirrors state with `InnerOpt`, and with all of
the user-specified chained gradient-transformers, so that their `param_groups`,
`defaults`, and `state` are always in sync.
See Also
--------
ParamTransformingOptimizer
"""
[docs]
def __init__(
    self,
    *transforming_optimizers: InstantiatesTo[ParamTransformingOptimizer],
    params: Optional[OptimParams] = None,
    InnerOpt: Union[Opt, Partial[Opt], OptimizerType] = SGD,
    param_ndim: Union[int, None] = -1,
    grad_scale: float = 1,
    grad_bias: float = 0,
    defaults: Optional[Dict[str, Any]] = None,
    **inner_opt_kwargs,
) -> None:
    r"""
    Parameters
    ----------
    *transforming_optimizers: InstantiatesTo[ParamTransformingOptimizer],
        An arbitrary number of parameter-transforming optimizers, whose
        `_pre_step_transform_` and `_post_step_transform_` methods, respectively,
        will be composed from left to right –
        `Opt1, Opt2, ..., OptN -> fN_(...f2_(f1_(grad)))` – to modify a parameter
        prior to / after being updated by `InnerOpt.step`

    params : Optional[Sequence[Tensor] | Iterable[ParamGroup]]
        Iterable of parameters to optimize or dicts defining parameter groups

    InnerOpt : Type[Optimizer] | Partial[Optimizer], optional (default=`torch.optim.SGD`)
        The optimizer that updates the parameters after `_pre_step_transform_` has
        been applied to each of them.

    param_ndim : int | None, optional (default=-1)
        Determines how a parameter and its gradient are temporarily reshaped prior
        to being passed to both `_pre_step_transform_` and `_post_step_transform_`.
        By default, the transformation broadcasts over the tensor's first dimension
        in a batch-like style.

        - A positive number determines the dimensionality of the tensor that the transformation will act on.
        - A negative number indicates the 'offset' from the dimensionality of the tensor (see "Notes" for examples).
        - `None` means that the transformation will be applied directly to the tensor without any broadcasting.

        See `ParamTransformingOptimizer` for more details and examples.

    grad_scale : float, optional (default=1.0)
        Multiplies each gradient in-place after the in-place transformation is
        performed. This can be specified per param-group.

    grad_bias : float, optional (default=0.0)
        Added to each gradient in-place after the in-place transformation is
        performed. This can be specified per param-group.

    defaults : Optional[Dict[str, Any]]
        Specifies default parameters for all parameter groups.

    **inner_opt_kwargs : Any
        Named arguments used to initialize `InnerOpt`.

    Raises
    ------
    TypeError
        If any entry of `transforming_optimizers` does not instantiate to a
        `ParamTransformingOptimizer`.

    Examples
    --------
    **Basic Example**

    Let's chain together two gradient-transforming optimizers supplied by rAI-toolbox:
    `TopQGradientOptimizer` and `ClampedGradientOptimizer`

    >>> from rai_toolbox.optim import (
    ...     ChainedParamTransformingOptimizer,
    ...     ClampedGradientOptimizer,
    ...     TopQGradientOptimizer,
    ... )
    >>> import torch as tr
    >>> from functools import partial

    >>> x1 = tr.ones(3, requires_grad=True)  # shape-(3,)

    Our optimizer will retain only the top-33rd percentile elements in the gradient:
    the smallest elements will be zero'd. Then the resulting gradient will be
    clamped so that its largest possible entry is `2.8`. Finally, the standard `SGD`
    optimizer will be used, with `lr=1.0`, to update the parameter(s) using the
    transformed gradients.

    We specify `TopQGradientOptimizer` and then `ClampedGradientOptimizer`; the
    transformations are applied in order from left to right. Providing per-optimizer
    defaults is achieved most naturally using :py:func:`functools.partial`.

    >>> optim = ChainedParamTransformingOptimizer(
    ...     partial(TopQGradientOptimizer, q=0.33),
    ...     partial(ClampedGradientOptimizer, clamp_max=2.8),
    ...     params=[x1],
    ...     lr=1.0,
    ...     param_ndim=None,  # we don't want any broadcasting to occur
    ... )
    >>> optim
    ClampedGradientOptimizer ○ TopQGradientOptimizer [SGD](
    Parameter Group 0
        clamp_max: 2.8
        clamp_min: None
        dampening: 0
        dq: 0.0
        grad_bias: 0
        grad_scale: 1
        lr: 1.0
        maximize: False
        momentum: 0
        nesterov: False
        param_ndim: None
        q: 0.33
        weight_decay: 0
    )

    Let's verify that `optim` transforms our gradients as-expected.

    >>> (tr.tensor([1.0, 2.0, 3.0]) * x1).sum().backward()
    >>> optim.step()
    >>> x1.grad  # element-0 should be zero'd by top-q; element-2 should be clamped to 2.8
    tensor([0.0000, 2.0000, 2.8000])

    See that `SGD([x1], lr=1.0).step()` is used to update our parameters; this can
    be controlled via the `InnerOpt` argument.

    >>> x1
    tensor([ 1.0000, -1.0000, -1.8000], requires_grad=True)

    **Adding Parameter Groups**

    Our chained gradient-transforming optimizers mirror their states with `optim`
    and `SGD`, thus we can add parameter groups and the group's settings will be
    applied to our chain as-expected.

    Let's add a 2D parameter, where we want to apply the top-q sparsification
    row-wise (via `param_ndim=1`), and retain only 64th-percentile gradient
    elements.

    >>> x2 = tr.ones(2, 3, requires_grad=True)  # shape-(2, 3)
    >>> optim.add_param_group(dict(params=x2, param_ndim=1, q=0.64))
    >>> optim
    ClampedGradientOptimizer ○ TopQGradientOptimizer [SGD](
    Parameter Group 0
        clamp_max: 2.8
        clamp_min: None
        dampening: 0
        dq: 0.0
        grad_bias: 0
        grad_scale: 1
        lr: 1.0
        maximize: False
        momentum: 0
        nesterov: False
        param_ndim: None
        q: 0.33
        weight_decay: 0
    Parameter Group 1
        clamp_max: 2.8
        clamp_min: None
        dampening: 0
        dq: 0.0
        grad_bias: 0
        grad_scale: 1
        lr: 1.0
        maximize: False
        momentum: 0
        nesterov: False
        param_ndim: 1
        q: 0.64
        weight_decay: 0
    )

    >>> optim.zero_grad()
    >>> (tr.tensor([1.0, 2.0, 3.0]) * (x1 + x2)).sum().backward()
    >>> optim.step()
    >>> x1.grad
    tensor([0.0000, 2.8000, 2.8000])
    >>> x2.grad
    tensor([[0.0000, 0.0000, 2.8000],
            [0.0000, 0.0000, 2.8000]])
    """
    # `__setstate__` iterates over `self._chain`, and the base-class
    # initializer calls `__setstate__`; the attribute must exist beforehand.
    self._chain = ()

    super().__init__(
        params,
        InnerOpt,
        param_ndim=param_ndim,
        grad_scale=grad_scale,
        grad_bias=grad_bias,
        defaults=defaults,
        **inner_opt_kwargs,
    )

    all_valid = all(
        instantiates_to(candidate, ParamTransformingOptimizer)
        for candidate in transforming_optimizers
    )
    if not all_valid:
        raise TypeError(
            f"*transforming_optimizers must contain "
            f"`Type[ParamTransformingOptimizer]`, got: {transforming_optimizers}"
        )

    # Each chain member wraps the very same inner optimizer and is seeded with
    # the defaults of `self`.
    members = [
        make_opt(params, InnerOpt=self.inner_opt, defaults=self.defaults)
        for make_opt in transforming_optimizers
    ]
    self._chain = tuple(members)
def _pre_step_transform_(self, param: Tensor, optim_group: DatumParamGroup) -> None:
    """Apply every chain member's `_pre_step_transform_` to `param`, in the
    order the members were supplied."""
    # chain [f1, f2, f3] is applied as f3(f2(f1(param)))
    for member in self._chain:
        member._pre_step_transform_(param=param, optim_group=optim_group)
def _post_step_transform_(
    self, param: Tensor, optim_group: DatumParamGroup
) -> None:
    """Apply every chain member's `_post_step_transform_` to `param`, in the
    order the members were supplied."""
    # chain [f1, f2, f3] is applied as f3(f2(f1(param)))
    for member in self._chain:
        member._post_step_transform_(param=param, optim_group=optim_group)
def __setstate__(self, state: dict):
    """Load `state` into the inner optimizer and propagate the inner
    optimizer's resulting state to every chain member and to `self`."""
    # synchronize state between `self`, members of `self._chain`,
    # and `self.inner_opt`
    inner = self.inner_opt
    inner.__setstate__(state)
    synced = inner.__getstate__()  # type: ignore

    for member in self._chain:
        member.__setstate__(synced)

    super().__setstate__(synced)
def __repr__(self) -> str:
    """Return the parent repr with this class's name replaced by the
    composition of the chained optimizers' class names.

    The names are joined with `" ○ "` and listed in reverse chain order, so the
    optimizer applied last appears first. If the chain is empty, the parent repr
    is returned unchanged.
    """
    base_repr = super().__repr__()

    if not self._chain:
        # Nothing to compose: keep the class name instead of replacing it with
        # an empty string, which would leave the repr without any name.
        return base_repr

    composed = " ○ ".join(type(c).__name__ for c in reversed(self._chain))
    # Only the header occurrence of the class name is meant to be replaced.
    return base_repr.replace(type(self).__name__, composed, 1)