rai_toolbox.optim.ChainedParamTransformingOptimizer#

class rai_toolbox.optim.ChainedParamTransformingOptimizer(*transforming_optimizers, params=None, InnerOpt=<class 'torch.optim.sgd.SGD'>, param_ndim=-1, grad_scale=1, grad_bias=0, defaults=None, **inner_opt_kwargs)[source]#

Chains together an arbitrary number of parameter-transforming optimizers, composing their pre- and post-step transformation functions to modify the parameters (and their gradients) in-place. InnerOpt.step() applies the gradient-based update to each parameter.

I.e., passing Opt1, Opt2, ..., OptN to ChainedParamTransformingOptimizer will update a parameter using OptN.fn_(... Opt2.fn_(Opt1.fn_(param)) ...), where fn_ is shorthand for _pre_step_transform_ / _post_step_transform_.
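As a loose conceptual sketch of that composition (chained_update, transforms, and inner_step below are hypothetical stand-ins, not the toolbox's internal API):

def chained_update(param, transforms, inner_step):
    # `transforms` holds the chained optimizers' in-place transform functions,
    # ordered Opt1, Opt2, ..., OptN; `inner_step` stands in for InnerOpt.step().
    for fn_ in transforms:
        fn_(param)   # pre-step transforms, applied left to right
    inner_step()     # gradient-based update of `param` (e.g. SGD)
    for fn_ in transforms:
        fn_(param)   # post-step transforms, applied in the same order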

Notes

ChainedParamTransformingOptimizer mirrors its state with InnerOpt and with all of the user-specified chained gradient-transformers, so that their param_groups, defaults, and state are always in sync.

__init__(*transforming_optimizers, params=None, InnerOpt=<class 'torch.optim.sgd.SGD'>, param_ndim=-1, grad_scale=1, grad_bias=0, defaults=None, **inner_opt_kwargs)[source]#
Parameters:
*transforming_optimizers : InstantiatesTo[ParamTransformingOptimizer]

An arbitrary number of parameter-transforming optimizers, whose _pre_step_transform_ and _post_step_transform_ methods will each be composed from left to right (Opt1, Opt2, ..., OptN -> fN_(... f2_(f1_(grad)) ...)) to modify a parameter before and after it is updated by InnerOpt.step.

params : Optional[Sequence[Tensor] | Iterable[ParamGroup]]

Iterable of parameters to optimize or dicts defining parameter groups

InnerOpt : Type[Optimizer] | Partial[Optimizer], optional (default=`torch.optim.SGD`)

The optimizer that updates the parameters after _pre_step_transform_ has been applied to each of them.

param_ndim : int | None, optional (default=-1)

Determines how a parameter and its gradient are temporarily reshaped prior to being passed to both _pre_step_transform_ and _post_step_transform_. By default, the transformation broadcasts over the tensor’s first dimension in a batch-like style.

  • A positive number determines the dimensionality of the tensor that the transformation will act on.

  • A negative number indicates the ‘offset’ from the dimensionality of the tensor (see “Notes” for examples).

  • None means that the transformation will be applied directly to the tensor without any broadcasting.

See ParamTransformingOptimizer for more details and examples; a rough sketch of how param_ndim is interpreted also follows this parameter list.

grad_scale : float, optional (default=1.0)

Multiplies each gradient in-place after the in-place transformation is performed. This can be specified per param-group.

grad_bias : float, optional (default=0.0)

Added to each gradient in-place after the in-place transformation is performed. This can be specified per param-group.

defaults : Optional[Dict[str, Any]]

Specifies default parameters for all parameter groups.

**inner_opt_kwargs : Any

Named arguments used to initialize InnerOpt.
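To make the param_ndim options above concrete, here is a rough, hedged sketch of how the effective dimensionality could be resolved, based purely on the description above (effective_ndim is a hypothetical helper, not part of rai_toolbox):

import torch as tr

def effective_ndim(param_ndim, tensor):
    # Hypothetical helper illustrating the param_ndim rules described above.
    if param_ndim is None:
        return tensor.ndim               # transform sees the whole tensor; no broadcasting
    if param_ndim < 0:
        return tensor.ndim + param_ndim  # negative values are an offset from tensor.ndim
    return param_ndim                    # positive values are used directly

x = tr.ones(2, 3, 4)     # a ndim-3 parameter
effective_ndim(-1, x)    # -> 2: default; broadcasts over the leading, batch-like dim
effective_ndim(1, x)     # -> 1: the transform acts on 1D slices of the tensor
effective_ndim(None, x)  # -> 3: the transform is applied to the tensor as-is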

Examples

Basic Example

Let’s chain together two gradient-transforming optimizers supplied by rAI-toolbox: TopQGradientOptimizer and ClampedGradientOptimizer

>>> from rai_toolbox.optim import (
... ChainedParamTransformingOptimizer,
... ClampedGradientOptimizer,
... TopQGradientOptimizer,
... )
>>> import torch as tr
>>> from functools import partial
>>> x1 = tr.ones(3, requires_grad=True)  # shape-(3,)

Our optimizer will retain only the gradient elements above the 33rd percentile; the smallest elements will be zeroed. The resulting gradient will then be clamped so that its largest possible entry is 2.8. Finally, the standard SGD optimizer, with lr=1.0, will be used to update the parameter(s) with the transformed gradients.

We specify TopQGradientOptimizer and then ClampedGradientOptimizer; the transformations are applied in order from left to right. Providing per-optimizer defaults is achieved most naturally using functools.partial().

>>> optim = ChainedParamTransformingOptimizer(
...     partial(TopQGradientOptimizer, q=0.33),
...     partial(ClampedGradientOptimizer, clamp_max=2.8),
...     params=[x1],
...     lr=1.0,
...     param_ndim=None, # we don't want any broadcasting to occur
... )
>>> optim
ClampedGradientOptimizer ○ TopQGradientOptimizer [SGD](
Parameter Group 0
    clamp_max: 2.8
    clamp_min: None
    dampening: 0
    dq: 0.0
    grad_bias: 0
    grad_scale: 1
    lr: 1.0
    maximize: False
    momentum: 0
    nesterov: False
    param_ndim: None
    q: 0.33
    weight_decay: 0
)

Let’s verify that optim transforms our gradients as expected.

>>> (tr.tensor([1.0, 2.0, 3.0]) * x1).sum().backward()
>>> optim.step()
>>> x1.grad  # element-0 should be zero'd by top-q; element-2 should be clamped to 2.8
tensor([0.0000, 2.0000, 2.8000])

See that SGD([x1], lr=1.0).step() is used to update our parameters; this can be controlled via the InnerOpt argument.

>>> x1
tensor([ 1.0000, -1.0000, -1.8000], requires_grad=True)
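A different inner optimizer can be supplied via InnerOpt, with its settings passed through **inner_opt_kwargs. For instance (a hedged sketch reusing the imports and x1 from above; Adam and lr=0.1 are merely illustrative choices):

>>> adam_chain = ChainedParamTransformingOptimizer(
...     partial(TopQGradientOptimizer, q=0.33),
...     partial(ClampedGradientOptimizer, clamp_max=2.8),
...     params=[x1],
...     InnerOpt=tr.optim.Adam,  # replaces the default SGD
...     lr=0.1,                  # forwarded to Adam via **inner_opt_kwargs
...     param_ndim=None,
... )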

Adding Parameter Groups

Our chained gradient-transforming optimizers mirror their state with optim and SGD, so we can add parameter groups and the group’s settings will be applied to our chain as expected.

Let’s add a 2D parameter to which we want to apply the top-q sparsification row-wise (via param_ndim=1), retaining only the gradient elements above the 64th percentile.

>>> x2 = tr.ones(2, 3, requires_grad=True)  # shape-(2, 3)
>>> optim.add_param_group(dict(params=x2, param_ndim=1, q=0.64))
>>> optim
ClampedGradientOptimizer ○ TopQGradientOptimizer [SGD](
Parameter Group 0
    clamp_max: 2.8
    clamp_min: None
    dampening: 0
    dq: 0.0
    grad_bias: 0
    grad_scale: 1
    lr: 1.0
    maximize: False
    momentum: 0
    nesterov: False
    param_ndim: None
    q: 0.33
    weight_decay: 0
Parameter Group 1
    clamp_max: 2.8
    clamp_min: None
    dampening: 0
    dq: 0.0
    grad_bias: 0
    grad_scale: 1
    lr: 1.0
    maximize: False
    momentum: 0
    nesterov: False
    param_ndim: 1
    q: 0.64
    weight_decay: 0
)
>>> optim.zero_grad()
>>> (tr.tensor([1.0, 2.0, 3.0]) * (x1 + x2)).sum().backward()
>>> optim.step()
>>> x1.grad
tensor([0.0000, 2.8000, 2.8000])
>>> x2.grad
tensor([[0.0000, 0.0000, 2.8000],
        [0.0000, 0.0000, 2.8000]])
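Since the chained optimizers and InnerOpt mirror their state with optim (see "Notes"), both groups' settings are visible via the standard param_groups attribute. A quick, hedged check (the group dicts also contain the other keys shown in the repr above):

>>> len(optim.param_groups)
2
>>> optim.param_groups[1]["q"], optim.param_groups[1]["param_ndim"]
(0.64, 1)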

Methods

__init__(*transforming_optimizers[, params, ...])
