rai_toolbox.optim.L2ProjectedOptim#
- class rai_toolbox.optim.L2ProjectedOptim(params, InnerOpt=<class 'torch.optim.sgd.SGD'>, *, epsilon=<required parameter>, param_ndim=-1, grad_scale=1.0, grad_bias=0.0, defaults=None, div_by_zero_eps=1.1754943508222875e-38, **inner_opt_kwargs)[source]#
A gradient-transforming optimizer that constrains the updated parameters to lie within an \(\epsilon\)-sized ball in \(L^2\) space centered on the origin.
A step with this optimizer normalizes the gradient by its \(L^2\)-norm prior to using InnerOpt.step to update the corresponding parameter. Each parameter is then projected into the constraint set. The transformation/projection is applied to the gradient/parameter in accordance with param_ndim.
- __init__(params, InnerOpt=<class 'torch.optim.sgd.SGD'>, *, epsilon=<required parameter>, param_ndim=-1, grad_scale=1.0, grad_bias=0.0, defaults=None, div_by_zero_eps=1.1754943508222875e-38, **inner_opt_kwargs)[source]#
- Parameters:
- params : Sequence[Tensor] | Iterable[Mapping[str, Any]]
Iterable of parameters or dicts defining parameter groups.
- InnerOpt : Type[Optimizer] | Partial[Optimizer], optional (default=`torch.optim.SGD`)
The optimizer that updates the parameters after their gradients have been transformed.
- epsilon : float
Specifies the size of the L2-space ball that all parameters will be projected into, post-optimization step.
- param_ndim : Union[int, None], optional (default=-1)
Determines how a parameter and its gradient are temporarily reshaped prior to being passed to both _pre_step_transform_ and _post_step_transform_. By default, the transformation broadcasts over the tensor’s first dimension in a batch-like style. This can be specified per param-group.
A positive number determines the dimensionality of the tensor that the transformation will act on. A negative number indicates the ‘offset’ from the dimensionality of the tensor (see “Notes” for examples). None means that the transformation will be applied directly to the tensor without any broadcasting.
See ParamTransformingOptimizer for more details and examples. An additional sketch of the default behavior is included at the end of the Examples section below.
- grad_scale : float, optional (default=1.0)
Multiplies each gradient in-place after the in-place transformation is performed. This can be specified per param-group (see the sketch at the end of the Examples section below).
- grad_bias : float, optional (default=0.0)
Added to each gradient in-place after the in-place transformation is performed. This can be specified per param-group.
- defaults : Optional[Dict[str, Any]]
Specifies default parameters for all parameter groups.
- div_by_zero_eps : float, optional (default=`torch.finfo(torch.float32).tiny`)
A lower bound used to clamp the normalization factor to prevent div-by-zero.
- **inner_opt_kwargs : Any
Named arguments used to initialize InnerOpt.
Examples
Let’s create an optimizer that normalizes all parameter gradients using their \(L^2\)-norm, and then updates the parameters with a standard SGD-step with a learning rate of 1.0. After the step, each parameter will be projected into an \(L^2\)-ball of radius 0.8.
>>> import torch as tr
>>> from rai_toolbox.optim import L2ProjectedOptim
Creating a parameter for our optimizer to update, and our optimizer. We want the norm to be computed over the entire gradient tensor – without broadcasting – so we specify param_ndim=None. This also controls the projection behavior.
>>> x = tr.tensor([-1.0, 1.0], requires_grad=True)
>>> optim = L2ProjectedOptim([x], param_ndim=None, InnerOpt=tr.optim.SGD, lr=1.0, epsilon=0.8)
Performing a simple calculation with x and performing backprop to create a gradient.
>>> (tr.tensor([2.0, 2.0]) * x).sum().backward()
>>> x.grad # the un-normed gradient
tensor([2., 2.])
Performing a step with our optimizer transforms the gradient in-place, updates the parameter using SGD([x], lr=1.0).step(), and then projects the parameter into the constraint set.
>>> optim.step()
>>> x.grad # the normalized gradient
tensor([0.7071, 0.7071])
>>> x # the updated parameter
tensor([-0.7885, 0.1353], requires_grad=True)
>>> x.norm(p=2).item() # `x` lies on the L2-ball of radius 0.8
0.800000011920929
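To make the three stages above concrete (normalize the gradient, take the inner SGD step, project onto the \(\epsilon\)-ball), here is a minimal manual sketch that reproduces the same update by hand. It is purely illustrative and not part of the optimizer’s API; the final check is what we expect given the documented behavior.
>>> y = tr.tensor([-1.0, 1.0])        # same starting point as `x` above
>>> g = tr.tensor([2.0, 2.0])         # the raw gradient
>>> y = y - 1.0 * (g / g.norm(p=2))   # normalize the gradient, then take an SGD step with lr=1.0
>>> y = y * (0.8 / y.norm(p=2))       # the norm exceeds 0.8 here, so projection is a rescale onto the ball
>>> tr.allclose(y, x.detach())        # agrees with the optimizer's update above
True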
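As a sketch of the default param_ndim=-1 behavior described above (the values here are chosen for illustration and are not taken from the toolbox’s documentation), the transformation and projection broadcast over a 2D parameter’s first dimension: each row of the gradient is normalized independently, and each row of the parameter is projected into its own \(\epsilon\)-ball.
>>> w = tr.ones(2, 3, requires_grad=True)   # 2D parameter; rows are treated independently
>>> optim = L2ProjectedOptim([w], InnerOpt=tr.optim.SGD, lr=1.0, epsilon=1.0)
>>> (w * tr.arange(6.0).view(2, 3)).sum().backward()
>>> optim.step()
>>> tr.allclose(w.grad.norm(dim=1, p=2), tr.ones(2))    # each row of the gradient has unit L2-norm
True
>>> bool((w.detach().norm(dim=1, p=2) <= 1.0 + 1e-6).all())  # each row lies within the radius-1 ball
True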
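Finally, a small sketch of grad_scale (again, the numbers are illustrative rather than taken from the library’s docs): per the parameter description above, the scaling is applied to the gradient after the \(L^2\) normalization, so the stored gradient ends up with norm equal to grad_scale rather than 1.
>>> p = tr.tensor([3.0, 4.0], requires_grad=True)
>>> optim = L2ProjectedOptim([p], param_ndim=None, InnerOpt=tr.optim.SGD,
...                          lr=1.0, epsilon=10.0, grad_scale=2.0)
>>> p.sum().backward()          # raw gradient: tensor([1., 1.])
>>> optim.step()
>>> p.grad                      # normalized to unit L2-norm, then scaled in-place by grad_scale=2.0
tensor([1.4142, 1.4142])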
Methods
__init__(params[, InnerOpt, epsilon, ...])