rai_toolbox.optim.ParamTransformingOptimizer#
- class rai_toolbox.optim.ParamTransformingOptimizer(params=None, InnerOpt=<class 'torch.optim.sgd.SGD'>, *, param_ndim=-1, grad_scale=1.0, grad_bias=0.0, defaults=None, **inner_opt_kwargs)[source]#
 An optimizer that performs an in-place transformation to each parameter, both before and after performing the gradient-based update on each parameter via InnerOpt.step:

    _pre_step_transform_(param)
    param = InnerOpt.step(param, ...)
    _post_step_transform_(param)

Note that _pre_step_transform_ and _post_step_transform_ can be used to update a parameter and/or its gradient. Also, this optimizer exposes param_ndim as a means of controlling how these transforms broadcast (if at all) over any given tensor.
Notes
ParamTransformingOptimizer mirrors state with InnerOpt so that their param_groups, defaults, and state are always in sync.

ParamTransformingOptimizer is designed to be combined with other, standard gradient-based optimizers (e.g., Adam) via composition, rather than through inheritance. I.e., ParamTransformingOptimizer(InnerOpt=<...>) will apply _pre_step_transform_ on a parameter, and then use InnerOpt.step(...) to update said parameter, and finally will apply _post_step_transform_ to the parameter.

If a closure is supplied to the step(...) method, then the _pre_step_transform_ is applied after the closure call and prior to the parameter steps.
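For instance, since both transforms default to no-ops, the base class used directly simply defers to its inner optimizer. The following is a minimal sketch (not taken from the toolbox's own examples) of that composition:

>>> import torch as tr
>>> from rai_toolbox.optim import ParamTransformingOptimizer
>>> x = tr.tensor([1.0, 2.0], requires_grad=True)
>>> optim = ParamTransformingOptimizer([x], InnerOpt=tr.optim.SGD, lr=0.5)
>>> x.sum().backward()  # each gradient entry is 1.0
>>> optim.step()  # no-op pre-transform -> SGD.step() -> no-op post-transform
>>> x
tensor([0.5000, 1.5000], requires_grad=True)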
Methods

_pre_step_transform_(param, optim_group)
    Applies an in-place transform on each parameter in the given param group before that parameter has been updated via InnerOpt.step.

_post_step_transform_(param, optim_group)
    Applies an in-place transform on each parameter in the given param group after that parameter has been updated via InnerOpt.step.
- __init__(params=None, InnerOpt=<class 'torch.optim.sgd.SGD'>, *, param_ndim=-1, grad_scale=1.0, grad_bias=0.0, defaults=None, **inner_opt_kwargs)[source]#
 - Parameters:
 - params : Sequence[Tensor] | Iterable[ParamGroup]
  Iterable of parameters to optimize or dicts defining parameter groups.
 - InnerOpt : Type[Optimizer] | Partial[Optimizer], optional (default=`torch.optim.SGD`)
  The optimizer that updates the parameters after their gradients have been transformed.
 - param_ndim : int | None, optional (default=-1)
  Determines how a parameter and its gradient are temporarily reshaped prior to being passed to both _pre_step_transform_ and _post_step_transform_. By default, the transformation broadcasts over the tensor's first dimension in a batch-like style.

  A positive number determines the dimensionality of the tensor that the transformation will act on.
  A negative number indicates the 'offset' from the dimensionality of the tensor (see "Notes" for examples).
  None means that the transformation will be applied directly to the tensor without any broadcasting.

  See "Notes" for more details.
 - grad_scale : float, optional (default=1.0)
  Multiplies each gradient in-place after the pre-step transformation is performed. This can be specified per param-group (see the sketch following this parameter list).
 - grad_bias : float, optional (default=0.0)
  Added to each gradient in-place after the pre-step transformation is performed. This can be specified per param-group.
 - defaults : Optional[Dict[str, Any]]
  Specifies default parameters for all parameter groups.
 - **inner_opt_kwargs : Any
  Named arguments used to initialize InnerOpt.
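As a brief, hedged sketch of how these arguments interact (relying only on the behaviors documented above): grad_scale=-1.0 flips each gradient in-place after the (here, no-op) pre-step transform, so the inner SGD ascends rather than descends, and lr is forwarded to SGD via **inner_opt_kwargs:

>>> import torch as tr
>>> from rai_toolbox.optim import ParamTransformingOptimizer
>>> x = tr.tensor([1.0], requires_grad=True)
>>> optim = ParamTransformingOptimizer([x], InnerOpt=tr.optim.SGD, grad_scale=-1.0, lr=0.5)
>>> x.sum().backward()  # grad: 1.0
>>> optim.step()  # grad scaled in-place to -1.0, then SGD: x -> 1.0 - 0.5 * (-1.0)
>>> x
tensor([1.5000], requires_grad=True)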
Notes
Additional Explanation of `param_ndim`
Consider a parameter of shape (d0, d1, d2, d3).

If param_ndim=0, then the parameter and its gradient will be temporarily reshaped to a shape-(d0 * d1 * d2 * d3, 1) so that the transformation will be applied elementwise to the tensor.

If param_ndim=1 (or param_ndim=-3), then the parameter and its gradient will be temporarily reshaped to a shape-(d0 * d1 * d2, d3) so that the transformation will be broadcast over each shape-(d3,) sub-tensor.

If param_ndim=2 (or param_ndim=-2), then the parameter and its gradient will be temporarily reshaped to a shape-(d0 * d1, d2, d3) so that the transformation will be broadcast over each shape-(d2, d3) sub-tensor.

If param_ndim=3 (or param_ndim=-1), then the parameter and its gradient will be temporarily reshaped to a shape-(d0, d1, d2, d3) so that the transformation will be broadcast over each shape-(d1, d2, d3) sub-tensor.

If param_ndim=4 (or param_ndim=None), then the parameter and its gradient will be temporarily reshaped to a shape-(1, d0, d1, d2, d3) so that the transformation will be applied to the shape-(d0, d1, d2, d3) tensor without broadcasting.
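The reshaping rule above can be summarized with a small, hypothetical helper; this is only an illustration of the documented behavior and is not part of the toolbox's API:

>>> import torch as tr
>>> def batched_view(x: tr.Tensor, param_ndim: int) -> tr.Tensor:
...     # Hypothetical sketch: flatten all leading dimensions into one batch
...     # dimension so that the trailing `param_ndim` dimensions are what the
...     # transform broadcasts over.
...     if param_ndim < 0:
...         param_ndim += x.ndim
...     if param_ndim == 0:
...         return x.reshape(-1, 1)  # elementwise
...     return x.reshape(-1, *x.shape[x.ndim - param_ndim:])
>>> x = tr.ones(2, 3, 4, 5)    # a shape-(d0, d1, d2, d3) parameter
>>> batched_view(x, 0).shape   # elementwise
torch.Size([120, 1])
>>> batched_view(x, 2).shape   # broadcast over each shape-(d2, d3) sub-tensor
torch.Size([6, 4, 5])
>>> batched_view(x, 4).shape   # no broadcasting over sub-tensors
torch.Size([1, 2, 3, 4, 5])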
Examples

Creating a gradient-transforming optimizer
Let’s create a gradient-transforming optimizer that replaces the gradient of each parameter with the elementwise sign of the gradient (±1) prior to performing the step of the inner optimizer:
>>> import torch as tr
>>> from rai_toolbox.optim import ParamTransformingOptimizer
>>> class SignedGradientOptim(ParamTransformingOptimizer):
...
...     def _pre_step_transform_(self, param: tr.Tensor, **_kwds) -> None:
...         if param.grad is None:
...             return
...         tr.sign(param.grad, out=param.grad)  # operates in-place
Now we’ll use this optimizer – with torch.optim.AdamW providing the actual parameter-update functionality – to update the parameter.

>>> x = tr.tensor([-10.0, 10.0], requires_grad=True)
>>> optim = SignedGradientOptim([x], InnerOpt=tr.optim.AdamW, lr=0.1)
Using x in a calculation and computing an associated gradient for it:

>>> (10_000 * x).sum().backward()
Updating x using our grad-sign + AdamW optimizer:

>>> optim.step()
>>> x
tensor([-10.0900,   9.8900], requires_grad=True)
This was a simple optimizer which did not involve any broadcasting in the gradient transformation; the next example will involve broadcasting.
Controlling the gradient transformation with param_ndim
To understand the role of param_ndim, let’s design an optimizer that normalizes a parameter’s gradient by its max value – along some user-specified dimension – prior to performing the gradient-based update to its parameter.

>>> class MaxNormedGradientOptim(ParamTransformingOptimizer):
...
...     def _pre_step_transform_(self, param: tr.Tensor, **_kwds) -> None:
...         if param.grad is None:
...             return
...
...         g = param.grad.flatten(1)  # (N, d1, ..., dm) -> (N, d1 * ... * dm)
...         max_norms = tr.max(g, dim=1).values
...         max_norms = max_norms.view(-1, *([1] * (param.ndim - 1)))  # reshape to have dimensionality-m
...         param.grad /= tr.clamp(max_norms, 1e-20, None)  # clamp to prevent div by 0
Note that we design _pre_step_transform_ to operate in-place on the gradient and that we treat the gradient as if it has a shape (N, d1, ..., dm), where we want to compute the max over each of the N sub-tensors of shape-(d1, ..., dm).

Critically, we did not use param_ndim at all in this method; ParamTransformingOptimizer assumes that we designed this method to broadcast in a batch-style, as we did, and it automatically leverages param_ndim to reshape the parameter and its gradient appropriately prior to calling _pre_step_transform_.

Now we will create a shape-(2, 2) parameter to see how MaxNormedGradientOptim can compute the max-norm over various dimensions of the parameter. Let’s print out the transformed gradient when we use each of param_ndim: 0, 1, or 2.

>>> x = tr.tensor([[1.0, 2.0],
...                [20.0, 10.0]], requires_grad=True)
>>> for param_ndim in [0, 1, 2]:
...     optim = MaxNormedGradientOptim([x], param_ndim=param_ndim, InnerOpt=tr.optim.SGD, lr=0.0)
...
...     loss = (x * x).sum()
...     loss.backward()
...     optim.step()
...     print(f"param_ndim: {param_ndim}, normed grad:\n{x.grad}\n..")
...     optim.zero_grad()
param_ndim: 0, normed grad:
tensor([[1., 1.],
        [1., 1.]])
..
param_ndim: 1, normed grad:
tensor([[0.5000, 1.0000],
        [1.0000, 0.5000]])
..
param_ndim: 2, normed grad:
tensor([[0.0500, 0.1000],
        [1.0000, 0.5000]])
..
See that param_ndim=0 applies the max-norm elementwise, whereas param_ndim=1 applies the max-norm to each 1D row of the gradient, and param_ndim=2 applies the max-norm over the entire 2D gradient.

Creating a parameter-constraining optimizer
Let’s create an optimizer that clamps each parameter’s values so that they all fall within [-1, 1] after performing its gradient-based step on the parameter.

>>> import torch as tr
>>> from rai_toolbox.optim import ParamTransformingOptimizer
>>> class ClampedParamOptim(ParamTransformingOptimizer):
...     def _post_step_transform_(self, param: tr.Tensor, optim_group: dict) -> None:
...         param.clamp_(min=-1.0, max=1.0)  # note: clamp occurs in-place
>>> x = tr.tensor([-10., 1.], requires_grad=True)
>>> optim = ClampedParamOptim([x], lr=0.1)  # InnerOpt=SGD by default
>>> x.backward(gradient=tr.tensor([-1., 1.]))
>>> optim.step()  # parameters updated via SGD.step() and then clamped
>>> x
tensor([-1.0000,  0.9000], requires_grad=True)
Note that this is a particularly simple function, which acts elementwise on each parameter, and thus does not require us to include param_ndim in the optimizer’s param-groups.
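For completeness, a hedged sketch: explicitly passing param_ndim=None makes the transform see each whole parameter tensor, which, for an elementwise clamp like this one, yields the same clamped result:

>>> y = tr.tensor([[-2.0, 0.5],
...                [ 3.0, -0.2]], requires_grad=True)
>>> optim = ClampedParamOptim([y], lr=0.0, param_ndim=None)  # transform applied to the full tensor
>>> y.sum().backward()
>>> optim.step()  # SGD step with lr=0.0 leaves y unchanged; the clamp still applies
>>> y
tensor([[-1.0000,  0.5000],
        [ 1.0000, -0.2000]], requires_grad=True)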
- _pre_step_transform_(param, optim_group)[source]#
 Applies an in-place transform on each parameter in the given param group before that parameter has been updated via InnerOpt.step.

This defaults to a no-op.
- Parameters:
 - param : torch.Tensor, shape-(N, d0, ...)
  The parameter to be modified in-place. param and param.grad will have been reshaped to have a shape-(N, d0, ...) where (d0, ...) contains param_ndim entries.
 - optim_group : Dict[str, Any]
  The parameter group associated with param.
Notes
This transform should always be designed to broadcast over the leading dimension of the tensor being modified. That is, each parameter/gradient should be assumed to have the shape-(N, d0, ...) and the transformation should be applied - in-place - to each shape-(d0, ...) sub-tensor.

Prior to calling _pre_step_transform_, ParamTransformingOptimizer will temporarily reshape each parameter and its gradient to have the appropriate shape – in accordance with the value specified for param_ndim – such that the shape-(d0, ...) tensor contains param_ndim entries.

In the case where param_ndim=0, the transformation will be applied to a shape-(T, 1) tensor, where T corresponds to the total number of elements in the tensor.
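For instance, a pre-step transform that honors this contract might L2-normalize the gradient of each shape-(d0, ...) sub-tensor. The following is a sketch only (it is not the implementation of any optimizer shipped with the toolbox); it reduces over every dimension except the leading one:

>>> import torch as tr
>>> from rai_toolbox.optim import ParamTransformingOptimizer
>>> class MyL2NormedGradientOptim(ParamTransformingOptimizer):
...     def _pre_step_transform_(self, param: tr.Tensor, **_kwds) -> None:
...         if param.grad is None:
...             return
...         # `param.grad` arrives reshaped to (N, d0, ...): compute one norm per
...         # leading-dimension entry, then divide the gradient in-place.
...         g = param.grad.flatten(1)               # (N, d0 * ...)
...         norms = g.norm(dim=1).clamp_min(1e-20)  # shape-(N,)
...         param.grad /= norms.view(-1, *([1] * (param.ndim - 1)))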
- _post_step_transform_(param, optim_group)[source]#
 Applies an in-place transform on each parameter in the given param group after that parameter has been updated via InnerOpt.step.

This defaults to a no-op.
- Parameters:
 - param : torch.Tensor, shape-(N, d0, ...)
  The parameter to be modified in-place. param and param.grad will have been reshaped to have a shape-(N, d0, ...) where (d0, ...) contains param_ndim entries.
 - optim_group : Dict[str, Any]
  The parameter group associated with param.
Notes
This transform should always be designed to broadcast over the leading dimension of the tensor being modified. That is, each parameter/gradient should be assumed to have the shape-(N, d0, ...) and the transformation should be applied - in-place - to each shape-(d0, ...) sub-tensor.

Prior to calling _post_step_transform_, ParamTransformingOptimizer will temporarily reshape each parameter and its gradient to have the appropriate shape – in accordance with the value specified for param_ndim – such that the shape-(d0, ...) tensor contains param_ndim entries.

In the case where param_ndim=0, the transformation will be applied to a shape-(T, 1) tensor, where T corresponds to the total number of elements in the tensor.
- _apply_pre_step_transform_()[source]#
 Update each parameter in-place by calling _pre_step_transform_ on the parameter.

This is called automatically by step() before InnerOpt.step() has been called.
- _apply_post_step_transform_()[source]#
 Update each parameter in-place by calling _post_step_transform_ on the parameter.

This is called automatically by step() after InnerOpt.step() has been called.
Methods

__init__([params, InnerOpt, param_ndim, ...])

_pre_step_transform_(param, optim_group)
    Applies an in-place transform on each parameter in the given param group before that parameter has been updated via InnerOpt.step.

_post_step_transform_(param, optim_group)
    Applies an in-place transform on each parameter in the given param group after that parameter has been updated via InnerOpt.step.

_apply_pre_step_transform_()
    Update each parameter in-place by calling _pre_step_transform_ on the parameter.

_apply_post_step_transform_()
    Update each parameter in-place by calling _post_step_transform_ on the parameter.