
class rai_toolbox.optim.ParamTransformingOptimizer(params=None, InnerOpt=<class 'torch.optim.sgd.SGD'>, *, param_ndim=-1, grad_scale=1.0, grad_bias=0.0, defaults=None, **inner_opt_kwargs)

An optimizer that performs an in-place transformation on each parameter, both before and after performing the gradient-based update on each parameter via InnerOpt.step:

_pre_step_transform_(param)
param = InnerOpt.step(param, ...)
_post_step_transform_(param)

Note that _pre_step_transform_ and _post_step_transform_ can be used to update a parameter and/or its gradient. Also, this optimizer exposes param_ndim as a means of controlling how these transforms broadcast (if at all) over any given tensor.


ParamTransformingOptimizer mirrors state with InnerOpt so that their param_groups, defaults, and state are always in sync.

ParamTransformingOptimizer is designed to be combined with other, standard gradient-based optimizers (e.g., Adam) via composition, rather than through inheritance. That is, ParamTransformingOptimizer(InnerOpt=<...>) applies _pre_step_transform_ to a parameter, then uses InnerOpt.step(...) to update that parameter, and finally applies _post_step_transform_ to the parameter.

If a closure is supplied to the step(...) method, then _pre_step_transform_ is applied after the closure is called and before the parameters are updated.
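The order of operations described above can be sketched in plain Python. The two classes below are hypothetical, heavily simplified stand-ins (the names ToyInnerSGD and ToyTransformingOptimizer are illustrative assumptions; parameters are plain lists of floats and there is no param-group or reshaping logic). They only record the order in which the closure, the two transforms, and the inner step run, and are not the library's implementation.

```python
from typing import Callable, List, Optional


class ToyInnerSGD:
    """Hypothetical inner optimizer: plain SGD on lists of floats."""

    def __init__(self, params: List[List[float]], grads: List[List[float]], lr: float) -> None:
        self.params = params
        self.grads = grads
        self.lr = lr

    def step(self) -> None:
        for p, g in zip(self.params, self.grads):
            for i in range(len(p)):
                p[i] -= self.lr * g[i]


class ToyTransformingOptimizer:
    """Hypothetical wrapper that mimics the documented order of operations."""

    def __init__(self, inner: ToyInnerSGD) -> None:
        self.inner = inner          # composition: the inner optimizer performs the update
        self.calls: List[str] = []  # log of what ran, in order

    def _pre_step_transform_(self, param: List[float]) -> None:
        self.calls.append("pre")    # no-op transform; only logs that it ran

    def _post_step_transform_(self, param: List[float]) -> None:
        self.calls.append("post")   # no-op transform; only logs that it ran

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        loss = None
        if closure is not None:
            loss = closure()        # 1. the closure is evaluated first
            self.calls.append("closure")
        for p in self.inner.params:
            self._pre_step_transform_(p)   # 2. pre-step transform on each param
        self.inner.step()                  # 3. inner optimizer updates the params
        self.calls.append("inner step")
        for p in self.inner.params:
            self._post_step_transform_(p)  # 4. post-step transform on each param
        return loss


params = [[1.0, 2.0]]
grads = [[0.5, -0.5]]
opt = ToyTransformingOptimizer(ToyInnerSGD(params, grads, lr=0.1))
opt.step(closure=lambda: 3.0)
print(opt.calls)  # ['closure', 'pre', 'inner step', 'post']
```

In the real optimizer the transforms also receive the (reshaped) parameter's param group, and a closure would typically recompute gradients; both details are omitted from this toy.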


_pre_step_transform_(param, optim_group)

Applies an in-place transform on each parameter in the given param group before that parameter has been updated via InnerOpt.step.

_post_step_transform_(param, optim_group)

Applies an in-place transform on each parameter in the given param group after that parameter has been updated via InnerOpt.step.


__init__(params=None, InnerOpt=<class 'torch.optim.sgd.SGD'>, *, param_ndim=-1, grad_scale=1.0, grad_bias=0.0, defaults=None, **inner_opt_kwargs)
params : Sequence[Tensor] | Iterable[ParamGroup]

Iterable of parameters to optimize, or dicts defining parameter groups.

InnerOpt : Type[Optimizer] | Partial[Optimizer], optional (default=`torch.optim.SGD`)

The optimizer that updates the parameters after their gradients have been transformed.

param_ndim : int | None, optional (default=-1)

Determines how a parameter and its gradient are temporarily reshaped prior to being passed to both _pre_step_transform_ and _post_step_transform_. By default, the transformation broadcasts over the tensor’s first dimension in a batch-like style.

  • A non-negative number determines the dimensionality of the tensor that the transformation will act on.

  • A negative number indicates the ‘offset’ from the dimensionality of the tensor (see “Notes” for examples).

  • None means that the transformation will be applied directly to the tensor without any broadcasting.

See “Notes” for more details.

grad_scale : float, optional (default=1.0)

Multiplies each gradient in-place after the pre-step transformation is performed. This can be specified per param-group.

grad_bias : float, optional (default=0.0)

Added to each gradient in-place after the pre-step transformation is performed. This can be specified per param-group.

defaults : Optional[Dict[str, Any]]

Specifies default parameters for all parameter groups.

**inner_opt_kwargs

Named arguments used to initialize InnerOpt.
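A minimal sketch of how per-group grad_scale and grad_bias values could be applied to a gradient once the pre-step transform has run. Two details here are assumptions, not documented behavior: that the scale is applied before the bias (g ← grad_scale * g + grad_bias), and that a value missing from a param group falls back to the optimizer-wide default. The helper apply_grad_scale_and_bias_ is hypothetical, and gradients are plain lists of floats.

```python
from typing import Any, Dict, List


def apply_grad_scale_and_bias_(
    grad: List[float],
    group: Dict[str, Any],
    default_scale: float = 1.0,
    default_bias: float = 0.0,
) -> None:
    """Hypothetical helper: scales, then biases, a gradient in-place."""
    # Per-group values take precedence over the optimizer-wide defaults (assumption).
    scale = group.get("grad_scale", default_scale)
    bias = group.get("grad_bias", default_bias)
    for i, g in enumerate(grad):
        grad[i] = scale * g + bias


# Two param groups: the first overrides grad_scale, the second uses the defaults.
groups = [{"grad_scale": -1.0}, {}]
grads = [[2.0, -3.0], [2.0, -3.0]]
for grad, group in zip(grads, groups):
    apply_grad_scale_and_bias_(grad, group)
print(grads)  # [[-2.0, 3.0], [2.0, -3.0]]
```

Because the scale multiplies the gradient, a negative grad_scale such as -1.0 flips the update direction for that group, turning descent into ascent.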


Additional Explanation of `param_ndim`

Consider a parameter of shape (d0, d1, d2, d3).

If param_ndim=0, then the parameter and its gradient will be temporarily reshaped to a shape-(d0 * d1 * d2 * d3, 1) so that the transformation will be applied elementwise to the tensor.

If param_ndim=1 (or param_ndim=-3), then the parameter and its gradient will be temporarily reshaped to a shape-(d0 * d1 * d2, d3) so that the transformation will be broadcast over each shape-(d3,) sub-tensor.

If param_ndim=2 (or param_ndim=-2), then the parameter and its gradient will be temporarily reshaped to a shape-(d0 * d1, d2, d3) so that the transformation will be broadcast over each shape-(d2, d3) sub-tensor.

If param_ndim=3 (or param_ndim=-1), then the parameter and its gradient keep their original shape-(d0, d1, d2, d3), and the transformation will be broadcast over each shape-(d1, d2, d3) sub-tensor.

If param_ndim=4 (or param_ndim=None), then the parameter and its gradient will be temporarily reshaped to a shape-(1, d0, d1, d2, d3) so that the transformation will be applied to the shape-(d0, d1, d2, d3) tensor without broadcasting.
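The shape rules above fit in a few lines of plain Python. The helper temporary_shape below is hypothetical (it is not part of rai_toolbox) and follows the cases listed here literally, including treating None like param_ndim equal to the tensor's dimensionality; raising ValueError for out-of-range values is an assumption.

```python
import math
from typing import Optional, Tuple


def temporary_shape(shape: Tuple[int, ...], param_ndim: Optional[int]) -> Tuple[int, ...]:
    """Hypothetical helper: the shape a parameter takes before a transform is applied."""
    ndim = len(shape)
    # None is treated like param_ndim == ndim, following the param_ndim=4 / None case.
    k = ndim if param_ndim is None else param_ndim
    if k < 0:
        k += ndim  # negative values are offsets from the tensor's dimensionality
    if not 0 <= k <= ndim:
        raise ValueError(f"param_ndim={param_ndim} is out of range for a {ndim}-D tensor")
    if k == 0:
        return (math.prod(shape), 1)  # elementwise: one shape-(1,) "row" per element
    # Collapse the leading (ndim - k) dims into one batch dim; keep the trailing k dims.
    # When k == ndim, math.prod of an empty slice is 1, giving shape-(1, d0, ..., d3).
    return (math.prod(shape[: ndim - k]), *shape[ndim - k :])


d = (2, 3, 4, 5)  # stands in for (d0, d1, d2, d3)
for p in [0, 1, -3, 2, -2, 3, -1, 4, None]:
    print(p, temporary_shape(d, p))
```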


Creating a gradient-transforming optimizer

Let’s create a gradient-transforming optimizer that replaces the gradient of each parameter with the elementwise sign of the gradient (±1) prior to performing the step of the inner optimizer:

>>> import torch as tr
>>> from rai_toolbox.optim import ParamTransformingOptimizer
>>> class SignedGradientOptim(ParamTransformingOptimizer):
...     def _pre_step_transform_(self, param: tr.Tensor, **_kwds) -> None:
...         if param.grad is None:
...             return
...         tr.sign(param.grad, out=param.grad)  # operates in-place

Now we’ll use this optimizer, with torch.optim.AdamW providing the actual parameter-update functionality, to update the parameter.

>>> x = tr.tensor([-10.0, 10.0], requires_grad=True)
>>> optim = SignedGradientOptim([x], InnerOpt=tr.optim.AdamW, lr=1.0)

Using x in a calculation and computing an associated gradient for it:

>>> (10_000 * x).sum().backward()

Updating x using our grad-sign + AdamW optimizer:

>>> optim.step()
>>> x
tensor([-10.9000,   8.9000], requires_grad=True)

This was a simple optimizer which did not involve any broadcasting in the gradient transformation; the next example will involve broadcasting.
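The size of an AdamW update like this one can be checked by hand. On Adam's first step the bias-corrected moments are m̂ = g and v̂ = g², so the update direction m̂ / (√v̂ + ε) equals sign(g) up to ε; AdamW additionally applies decoupled weight decay, x ← x(1 − lr·λ), before that update, where λ = 0.01 is PyTorch's default weight_decay for AdamW. The pure-Python function below is a hypothetical helper (adamw_first_step) covering only that single first step with default hyperparameters and no amsgrad; it shows how the result depends on the learning rate.

```python
import math
from typing import List, Tuple


def adamw_first_step(
    x: List[float],
    g: List[float],
    lr: float,
    betas: Tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0.01,
) -> List[float]:
    """Hypothetical helper: one AdamW step starting from zero moment estimates."""
    beta1, beta2 = betas
    out = []
    for xi, gi in zip(x, g):
        xi *= 1 - lr * weight_decay  # decoupled weight decay, applied first
        m_hat = (1 - beta1) * gi / (1 - beta1)  # bias-corrected first moment, equals gi
        v_hat = (1 - beta2) * gi**2 / (1 - beta2)  # bias-corrected second moment, equals gi**2
        out.append(xi - lr * m_hat / (math.sqrt(v_hat) + eps))
    return out


# Signed gradients of +1, as produced by SignedGradientOptim for (10_000 * x).sum()
print(adamw_first_step([-10.0, 10.0], [1.0, 1.0], lr=1.0))  # approximately [-10.9, 8.9]
print(adamw_first_step([-10.0, 10.0], [1.0, 1.0], lr=0.1))  # approximately [-10.09, 9.89]
```

For the same reason, the sign transform barely changes Adam's very first step; its effect appears on later steps, once the moment estimates mix gradients of different magnitudes.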

Controlling the gradient transformation with param_ndim

To understand the role of param_ndim, let’s design an optimizer that normalizes a parameter’s gradient by its max value, along some user-specified dimension, prior to performing the gradient-based update to its parameter.

>>> class MaxNormedGradientOptim(ParamTransformingOptimizer):
...     def _pre_step_transform_(self, param: tr.Tensor, **_kwds) -> None:
...         if param.grad is None:
...             return
...         g = param.grad.flatten(1) # (N, d1, ..., dm) -> (N, d1 * ... * dm)
...         max_norms = tr.max(g, dim=1).values
...         max_norms = max_norms.view(-1, *([1] * (param.ndim - 1)))  # reshape to (N, 1, ..., 1) to broadcast against the grad
...         param.grad /= tr.clamp(max_norms, 1e-20, None)  # clamp to prevent div by 0

Note that we design _pre_step_transform_ to operate in-place on the gradient and that we treat the gradient as if it has a shape (N, d1, ..., dm), where we want to compute the max over each of the N sub-tensors of shape-(d1, ..., dm).

Critically, we did not use param_ndim at all in this method; ParamTransformingOptimizer assumes that we designed this method to broadcast in a batch-style, as we did, and it automatically leverages param_ndim to reshape the parameter and its gradient appropriately prior to calling _pre_step_transform_.

Now we will create a shape-(2, 2) parameter to see how MaxNormedGradientOptim can compute the max-norm over various dimensions of the parameter. Let’s print out the transformed gradient for each of param_ndim=0, 1, and 2.

>>> x = tr.tensor([[1.0, 2.0],
...                [20.0, 10.0]], requires_grad=True)
>>> for param_ndim in [0, 1, 2]:
...     optim = MaxNormedGradientOptim([x], param_ndim=param_ndim, InnerOpt=tr.optim.SGD, lr=0.0)
...     loss = (x * x).sum()
...     loss.backward()
...     optim.step()
...     print(f"param_ndim: {param_ndim}, normed grad:\n{x.grad}")
...     optim.zero_grad()
param_ndim: 0, normed grad:
tensor([[1., 1.],
        [1., 1.]])
param_ndim: 1, normed grad:
tensor([[0.5000, 1.0000],
        [1.0000, 0.5000]])
param_ndim: 2, normed grad:
tensor([[0.0500, 0.1000],
        [1.0000, 0.5000]])

Note that param_ndim=0 applies the max-norm elementwise, whereas param_ndim=1 applies the max-norm to each 1D row of the gradient, and param_ndim=2 applies the max-norm over the entire 2D gradient.
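These three results can be reproduced without torch by splitting the row-major flattened gradient into the N sub-tensors implied by each param_ndim and dividing every group by its own maximum. This is a hypothetical pure-Python sketch for non-negative param_ndim (max_normalize is an illustrative name), not how ParamTransformingOptimizer reshapes tensors internally.

```python
import math
from typing import List


def max_normalize(flat_grad: List[float], shape: List[int], param_ndim: int) -> List[float]:
    """Hypothetical sketch: divide each of the N sub-tensors by its max value."""
    # Each sub-tensor spans the trailing `param_ndim` dims; for param_ndim=0 the
    # slice is empty and math.prod returns 1, i.e. every element is its own group.
    group_size = math.prod(shape[len(shape) - param_ndim :])
    out: List[float] = []
    for start in range(0, len(flat_grad), group_size):
        group = flat_grad[start : start + group_size]
        peak = max(max(group), 1e-20)  # same clamp as the doctest: avoids division by 0
        out.extend(v / peak for v in group)
    return out


x = [1.0, 2.0, 20.0, 10.0]  # row-major flattening of [[1, 2], [20, 10]]
grad = [2 * v for v in x]   # the gradient of (x * x).sum() is 2 * x
for param_ndim in [0, 1, 2]:
    print(param_ndim, max_normalize(grad, [2, 2], param_ndim))
```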

Creating a parameter-constraining optimizer

Let’s create an optimizer that clamps each parameter’s values so that they all fall within [-1, 1] after performing its gradient-based step on the parameter.

>>> import torch as tr
>>> from rai_toolbox.optim import ParamTransformingOptimizer
>>> class ClampedParamOptim(ParamTransformingOptimizer):
...     def _post_step_transform_(self, param: tr.Tensor, optim_group: dict) -> None:
...         param.clamp_(min=-1.0, max=1.0)  # note: clamp occurs in-place
>>> x = tr.tensor([-10., 1.], requires_grad=True)
>>> optim = ClampedParamOptim([x], lr=0.1)  # InnerOpt=SGD by default
>>> x.backward(gradient=tr.tensor([-1., 1.]))
>>> optim.step()  # parameters updated via SGD.step() and then clamped
>>> x
tensor([-1.0000,  0.9000], requires_grad=True)

Note that this is a particularly simple function, which acts elementwise on each parameter, and thus does not require us to include param_ndim in the optimizer’s param-groups.
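By contrast, a post-step transform that is not elementwise does depend on param_ndim. As a hypothetical illustration (the function name and the choice of constraint are assumptions, not part of rai_toolbox), the pure-Python sketch below rescales each row of a 2D parameter back onto the unit L2 ball, which is what a batch-style _post_step_transform_ would do to each shape-(d1,) sub-tensor of a shape-(d0, d1) parameter under param_ndim=1. Under param_ndim=2, a batch-style version of the same transform would instead see the whole matrix as a single sub-tensor and rescale it by its overall norm.

```python
import math
from typing import List


def project_rows_to_unit_ball_(rows: List[List[float]], max_norm: float = 1.0) -> None:
    """Hypothetical post-step transform: rescale, in-place, each row whose L2 norm exceeds max_norm."""
    for row in rows:
        norm = math.sqrt(sum(v * v for v in row))
        if norm > max_norm:
            scale = max_norm / norm
            for i in range(len(row)):
                row[i] *= scale


param = [[3.0, 4.0], [0.3, 0.4]]  # row norms: 5.0 and 0.5
project_rows_to_unit_ball_(param)
print(param)  # first row rescaled to norm 1; second row already inside the ball, so unchanged
```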

_pre_step_transform_(param, optim_group)

Applies an in-place transform on each parameter in the given param group before that parameter has been updated via InnerOpt.step.

This defaults to a no-op.

param : torch.Tensor, shape-(N, d0, ...)

The parameter to be modified in-place.

param and param.grad will have been reshaped to have a shape-(N, d0, ...) where (d0, ...) contains param_ndim entries.

optim_group : Dict[str, Any]

The parameter group associated with param.


This transform should always be designed to broadcast over the leading dimension of the tensor being modified. That is, each parameter/gradient should be assumed to have the shape-(N, d0, ...), and the transformation should be applied in-place to each shape-(d0, ...) sub-tensor.

Prior to calling _pre_step_transform_, ParamTransformingOptimizer will temporarily reshape each parameter and its gradient, in accordance with the value specified for param_ndim, such that the shape-(d0, ...) tensor contains param_ndim entries.

In the case where param_ndim=0, the transformation will be applied to a shape-(T, 1) tensor, where T corresponds to the total number of elements in the tensor.

_post_step_transform_(param, optim_group)

Applies an in-place transform on each parameter in the given param group after that parameter has been updated via InnerOpt.step.

This defaults to a no-op.

param : torch.Tensor, shape-(N, d0, ...)

The parameter to be modified in-place.

param and param.grad will have been reshaped to have a shape-(N, d0, ...) where (d0, ...) contains param_ndim entries.

optim_group : Dict[str, Any]

The parameter group associated with param.


This transform should always be designed to broadcast over the leading dimension of the tensor being modified. That is, each parameter/gradient should be assumed to have the shape-(N, d0, ...), and the transformation should be applied in-place to each shape-(d0, ...) sub-tensor.

Prior to calling _post_step_transform_, ParamTransformingOptimizer will temporarily reshape each parameter and its gradient, in accordance with the value specified for param_ndim, such that the shape-(d0, ...) tensor contains param_ndim entries.

In the case where param_ndim=0, the transformation will be applied to a shape-(T, 1) tensor, where T corresponds to the total number of elements in the tensor.


Update each parameter in-place by calling _pre_step_transform_ on the parameter.

This is called automatically by step() before InnerOpt.step() has been called.


Update each parameter in-place by calling _post_step_transform_ on the parameter.

This is called automatically by step() after InnerOpt.step() has been called.
