rai_toolbox.optim.ParamTransformingOptimizer#
- class rai_toolbox.optim.ParamTransformingOptimizer(params=None, InnerOpt=<class 'torch.optim.sgd.SGD'>, *, param_ndim=-1, grad_scale=1.0, grad_bias=0.0, defaults=None, **inner_opt_kwargs)[source]#
An optimizer that performs an in-place transformation to each parameter, both before and after performing the gradient-based update on each parameter via InnerOpt.step:

    _pre_step_transform_(param)
    param = InnerOpt.step(param, ...)
    _post_step_transform_(param)

Note that _pre_step_transform_ and _post_step_transform_ can be used to update a parameter and/or its gradient. Also, this optimizer exposes param_ndim as a means of controlling how these transforms broadcast (if at all) over any given tensor.
Notes
ParamTransformingOptimizer mirrors state with InnerOpt so that their param_groups, defaults, and state are always in sync.

ParamTransformingOptimizer is designed to be combined with other, standard gradient-based optimizers (e.g., Adam) via composition, rather than through inheritance. That is, ParamTransformingOptimizer(InnerOpt=<...>) will apply _pre_step_transform_ to a parameter, then use InnerOpt.step(...) to update that parameter, and finally apply _post_step_transform_ to the parameter.

If a closure is supplied to the step(...) method, then _pre_step_transform_ is applied after the closure call and prior to the parameter steps.
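As a minimal sketch of the composition pattern described above (using only the documented constructor arguments and the default no-op transforms), one can wrap torch.optim.Adam as the inner optimizer; the step below then behaves like a plain Adam step:

>>> import torch as tr
>>> from rai_toolbox.optim import ParamTransformingOptimizer
>>> x = tr.tensor([1.0, -2.0], requires_grad=True)
>>> optim = ParamTransformingOptimizer([x], InnerOpt=tr.optim.Adam, lr=0.1)
>>> (x ** 2).sum().backward()
>>> optim.step()  # _pre_step_transform_ -> Adam.step() -> _post_step_transform_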
Methods

_pre_step_transform_(param, optim_group)
    Applies an in-place transform on each parameter in the given param group before that parameter has been updated via InnerOpt.step.

_post_step_transform_(param, optim_group)
    Applies an in-place transform on each parameter in the given param group after that parameter has been updated via InnerOpt.step.

project
- __init__(params=None, InnerOpt=<class 'torch.optim.sgd.SGD'>, *, param_ndim=-1, grad_scale=1.0, grad_bias=0.0, defaults=None, **inner_opt_kwargs)[source]#
- Parameters:
- params : Sequence[Tensor] | Iterable[ParamGroup]
    Iterable of parameters to optimize or dicts defining parameter groups.
- InnerOpt : Type[Optimizer] | Partial[Optimizer], optional (default=`torch.optim.SGD`)
    The optimizer that updates the parameters after their gradients have been transformed.
- param_ndim : int | None, optional (default=-1)
    Determines how a parameter and its gradient are temporarily reshaped prior to being passed to both _pre_step_transform_ and _post_step_transform_. By default, the transformation broadcasts over the tensor’s first dimension in a batch-like style.
    - A positive number determines the dimensionality of the tensor that the transformation will act on.
    - A negative number indicates the ‘offset’ from the dimensionality of the tensor (see “Notes” for examples).
    - None means that the transformation will be applied directly to the tensor without any broadcasting.
    See “Notes” for more details.
- grad_scale : float, optional (default=1.0)
    Multiplies each gradient in-place after the pre-step transformation is performed. This can be specified per param-group; see the sketch following this parameter list.
- grad_bias : float, optional (default=0.0)
    Added to each gradient in-place after the pre-step transformation is performed. This can be specified per param-group.
- defaults : Optional[Dict[str, Any]]
    Specifies default parameters for all parameter groups.
- **inner_opt_kwargs : Any
    Named arguments used to initialize InnerOpt.
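As a minimal sketch of grad_scale under its documented semantics (using the default no-op transforms, with SGD as the inner optimizer), the gradient is scaled in-place before the inner step, so the update below is expected to behave like x -= lr * grad_scale * x.grad:

>>> import torch as tr
>>> from rai_toolbox.optim import ParamTransformingOptimizer
>>> x = tr.tensor([1.0, -1.0], requires_grad=True)
>>> optim = ParamTransformingOptimizer([x], InnerOpt=tr.optim.SGD, lr=0.1, grad_scale=2.0)
>>> x.backward(gradient=tr.tensor([1.0, 1.0]))
>>> optim.step()  # grad is scaled in-place by 2.0, then SGD applies x -= 0.1 * grad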
Notes
Additional Explanation of `param_ndim`
Consider a parameter of shape (d0, d1, d2, d3).

- If param_ndim=0, then the parameter and its gradient will be temporarily reshaped to a shape-(d0 * d1 * d2 * d3, 1) so that the transformation will be applied elementwise to the tensor.
- If param_ndim=1 (or param_ndim=-3), then the parameter and its gradient will be temporarily reshaped to a shape-(d0 * d1 * d2, d3) so that the transformation will be broadcast over each shape-(d3,) sub-tensor.
- If param_ndim=2 (or param_ndim=-2), then the parameter and its gradient will be temporarily reshaped to a shape-(d0 * d1, d2, d3) so that the transformation will be broadcast over each shape-(d2, d3) sub-tensor.
- If param_ndim=3 (or param_ndim=-1), then the parameter and its gradient will retain their shape-(d0, d1, d2, d3) so that the transformation will be broadcast over each shape-(d1, d2, d3) sub-tensor.
- If param_ndim=4 (or param_ndim=None), then the parameter and its gradient will be temporarily reshaped to a shape-(1, d0, d1, d2, d3) so that the transformation will be applied to the shape-(d0, d1, d2, d3) tensor without broadcasting.
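To make this reshaping rule concrete, here is a small illustration – using plain torch – of the shapes that a (2, 3, 4, 5) parameter would temporarily take on. This mimics the documented rule; it is not the optimizer’s internal code:

>>> import torch as tr
>>> p = tr.zeros(2, 3, 4, 5)       # a parameter of shape-(d0, d1, d2, d3)
>>> p.reshape(-1, 1).shape         # param_ndim=0: elementwise
torch.Size([120, 1])
>>> p.reshape(-1, 5).shape         # param_ndim=1: broadcast over shape-(d3,) sub-tensors
torch.Size([24, 5])
>>> p.reshape(-1, 4, 5).shape      # param_ndim=2: broadcast over shape-(d2, d3) sub-tensors
torch.Size([6, 4, 5])
>>> p.unsqueeze(0).shape           # param_ndim=4 or None: no broadcasting
torch.Size([1, 2, 3, 4, 5])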
Examples

Creating a gradient-transforming optimizer

Let’s create a gradient-transforming optimizer that replaces the gradient of each parameter with the elementwise sign of the gradient (±1) prior to performing the step of the inner optimizer:

>>> import torch as tr
>>> from rai_toolbox.optim import ParamTransformingOptimizer

>>> class SignedGradientOptim(ParamTransformingOptimizer):
...
...     def _pre_step_transform_(self, param: tr.Tensor, **_kwds) -> None:
...         if param.grad is None:
...             return
...         tr.sign(param.grad, out=param.grad)  # operates in-place
Now we’ll use this optimizer – with torch.optim.AdamW providing the actual parameter-update functionality – to update the parameter.

>>> x = tr.tensor([-10.0, 10.0], requires_grad=True)
>>> optim = SignedGradientOptim([x], InnerOpt=tr.optim.AdamW, lr=0.1)
Using x in a calculation, we compute an associated gradient for it:

>>> (10_000 * x).sum().backward()
Updating x using our grad-sign + AdamW optimizer:

>>> optim.step()
>>> x
tensor([-10.9000, 8.9000], requires_grad=True)
This was a simple optimizer which did not involve any broadcasting in the gradient transformation; the next example will involve broadcasting.
Controlling the gradient transformation with param_ndim
To understand the role of param_ndim, let’s design an optimizer that normalizes a parameter’s gradient by its max value – along some user-specified dimension – prior to performing the gradient-based update to its parameter.

>>> class MaxNormedGradientOptim(ParamTransformingOptimizer):
...
...     def _pre_step_transform_(self, param: tr.Tensor, **_kwds) -> None:
...         if param.grad is None:
...             return
...
...         g = param.grad.flatten(1)  # (N, d1, ..., dm) -> (N, d1 * ... * dm)
...         max_norms = tr.max(g, dim=1).values
...         max_norms = max_norms.view(-1, *([1] * (param.ndim - 1)))  # reshape to broadcast over each shape-(d1, ..., dm) sub-tensor
...         param.grad /= tr.clamp(max_norms, 1e-20, None)  # clamp to prevent div by 0
Note that we design _pre_step_transform_ to operate in-place on the gradient, and that we treat the gradient as if it has a shape-(N, d1, ..., dm), where we want to compute the max over each of the N sub-tensors of shape-(d1, ..., dm).
Critically, we did not use param_ndim at all in this method; ParamTransformingOptimizer assumes that we designed this method to broadcast in a batch-style, as we did, and it automatically leverages param_ndim to reshape the parameter and its gradient appropriately prior to calling _pre_step_transform_.

Now we will create a shape-(2, 2) parameter to see how MaxNormedGradientOptim can compute the max-norm over various dimensions of the parameter. Let’s print out the transformed gradient when we use each of param_ndim: 0, 1, or 2.

>>> x = tr.tensor([[1.0, 2.0],
...                [20.0, 10.0]], requires_grad=True)

>>> for param_ndim in [0, 1, 2]:
...     optim = MaxNormedGradientOptim([x], param_ndim=param_ndim, InnerOpt=tr.optim.SGD, lr=0.0)
...
...     loss = (x * x).sum()
...     loss.backward()
...     optim.step()
...     print(f"param_ndim: {param_ndim}, normed grad:\n{x.grad}\n..")
...     optim.zero_grad()
param_ndim: 0, normed grad:
tensor([[1., 1.],
        [1., 1.]])
..
param_ndim: 1, normed grad:
tensor([[0.5000, 1.0000],
        [1.0000, 0.5000]])
..
param_ndim: 2, normed grad:
tensor([[0.0500, 0.1000],
        [1.0000, 0.5000]])
..
See that param_ndim=0 applies the max-norm elementwise, whereas param_ndim=1 applies the max-norm to each 1D row of the gradient, and param_ndim=2 applies the max-norm over the entire 2D gradient.

Creating a parameter-constraining optimizer
Let’s create an optimizer that clamps each parameter’s values so that they all fall within [-1, 1] after performing its gradient-based step on the parameter.

>>> import torch as tr
>>> from rai_toolbox.optim import ParamTransformingOptimizer

>>> class ClampedParamOptim(ParamTransformingOptimizer):
...     def _post_step_transform_(self, param: tr.Tensor, optim_group: dict) -> None:
...         param.clamp_(min=-1.0, max=1.0)  # note: clamp occurs in-place
>>> x = tr.tensor([-10., 1.], requires_grad=True)
>>> optim = ClampedParamOptim([x], lr=0.1)  # InnerOpt=SGD by default
>>> x.backward(gradient=tr.tensor([-1., 1.]))
>>> optim.step()  # parameters updated via SGD.step() and then clamped
>>> x
tensor([-1.0000, 0.9000], requires_grad=True)
Note that this is a particularly simple function, which acts elementwise on each parameter, and thus does not require us to include param_ndim in the optimizer’s param-groups.
- _pre_step_transform_(param, optim_group)[source]#
Applies an in-place transform on each parameter in the given param group before that parameter has been updated via InnerOpt.step.

This defaults to a no-op.
- Parameters:
- param : torch.Tensor, shape-(N, d0, ...)
    The parameter to be modified in-place. param and param.grad will have been reshaped to have a shape-(N, d0, ...) where (d0, ...) contains param_ndim entries.
- optim_group : Dict[str, Any]
    The parameter group associated with param.
Notes
This transform should always be designed to broadcast over the leading dimension of the tensor being modified. That is, each parameter/gradient should be assumed to have the shape-(N, d0, ...) and the transformation should be applied - in-place - to each shape-(d0, ...) sub-tensor.

Prior to calling _pre_step_transform_, ParamTransformingOptimizer will temporarily reshape each parameter and its gradient to have the appropriate shape – in accordance with the value specified for param_ndim – such that the shape-(d0, ...) tensor contains param_ndim entries.

In the case where param_ndim=0, the transformation will be applied to a shape-(T, 1) tensor, where T corresponds to the total number of elements in the tensor.
- _post_step_transform_(param, optim_group)[source]#
Applies an in-place transform on each parameter in the given param group after that parameter has been updated via InnerOpt.step.

This defaults to a no-op.
- Parameters:
- param : torch.Tensor, shape-(N, d0, ...)
    The parameter to be modified in-place. param and param.grad will have been reshaped to have a shape-(N, d0, ...) where (d0, ...) contains param_ndim entries.
- optim_group : Dict[str, Any]
    The parameter group associated with param.
Notes
This transform should always be designed to broadcast over the leading dimension of the tensor being modified. That is, each parameter/gradient should be assumed to have the shape-(N, d0, ...) and the transformation should be applied - in-place - to each shape-(d0, ...) sub-tensor.

Prior to calling _post_step_transform_, ParamTransformingOptimizer will temporarily reshape each parameter and its gradient to have the appropriate shape – in accordance with the value specified for param_ndim – such that the shape-(d0, ...) tensor contains param_ndim entries.

In the case where param_ndim=0, the transformation will be applied to a shape-(T, 1) tensor, where T corresponds to the total number of elements in the tensor.
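As a hedged illustration of this contract, a hypothetical subclass might project each shape-(d0, ...) sub-tensor onto an L2 ball after the inner-optimizer step; the class name, the epsilon setting, and the projection logic below are illustrative sketches, not part of the toolbox’s API:

>>> import torch as tr
>>> from rai_toolbox.optim import ParamTransformingOptimizer

>>> class L2ProjectedParamOptim(ParamTransformingOptimizer):
...     def _post_step_transform_(self, param: tr.Tensor, optim_group: dict) -> None:
...         epsilon = optim_group.get("epsilon", 1.0)  # illustrative per-group setting
...         # param arrives with shape-(N, d0, ...): compute one L2 norm per leading entry
...         norms = param.flatten(1).norm(dim=1).clamp(min=1e-20)
...         scale = (epsilon / norms).clamp(max=1.0)  # only shrink; never grow
...         param.mul_(scale.view(-1, *([1] * (param.ndim - 1))))  # in-place projection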
- _apply_pre_step_transform_()[source]#
Update each parameter in-place by calling _pre_step_transform_ on the parameter.

This is called automatically by step() before InnerOpt.step() has been called.
- _apply_post_step_transform_()[source]#
Update each parameter in-place by calling _post_step_transform_ on the parameter.

This is called automatically by step() after InnerOpt.step() has been called.
Methods

__init__([params, InnerOpt, param_ndim, ...])

_pre_step_transform_(param, optim_group)
    Applies an in-place transform on each parameter in the given param group before that parameter has been updated via InnerOpt.step.

_post_step_transform_(param, optim_group)
    Applies an in-place transform on each parameter in the given param group after that parameter has been updated via InnerOpt.step.

_apply_pre_step_transform_()
    Update each parameter in-place by calling _pre_step_transform_ on the parameter.

_apply_post_step_transform_()
    Update each parameter in-place by calling _post_step_transform_ on the parameter.