rai_toolbox.optim.FrankWolfe
- class rai_toolbox.optim.FrankWolfe(params, *, lr=2.0, lmo_scaling_factor=1.0, use_default_lr_schedule=True, div_by_zero_eps=1.1754943508222875e-38)
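Implements the Frank-Wolfe minimization algorithm [1]:

$$w_{k+1} = (1 - l_r)\, w_k + l_r\, s_k$$

where

$$s_k = \arg\min_{s \in \mathcal{S}} \langle s, \nabla f(w_k) \rangle$$

is the linear minimization oracle (LMO) over the feasible set $\mathcal{S}$.

It is critical to note that this optimizer assumes that the `.grad` attribute of each parameter has been modified so as to store the negative of the LMO for that parameter, and not the gradient itself.

References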
- __init__(params, *, lr=2.0, lmo_scaling_factor=1.0, use_default_lr_schedule=True, div_by_zero_eps=1.1754943508222875e-38)
- Parameters:
- params : Iterable
Iterable of tensor parameters to optimize or dicts defining parameter groups.
- lr : float, optional (default=2.0)
Indicates the weight with which the LMO contributes to the parameter update. See `use_default_lr_schedule` for additional details. If `use_default_lr_schedule=False`, then `lr` must be in the domain [0, 1].
- lmo_scaling_factor : float, optional (default=1.0)
A scaling factor applied to the LMO output prior to each step.
- use_default_lr_schedule : bool, optional (default=True)
If `True`, then the per-parameter “learning rate” is rescaled according to a schedule in k, where k is the update index for that parameter.
- div_by_zero_eps : float, optional (default=`torch.finfo(torch.float32).tiny`)
Prevents a division-by-zero error in the learning-rate schedule.
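The learning-rate options above interact roughly as follows. This is a minimal pure-Python sketch, not the rai_toolbox implementation: it assumes the default schedule is $l_r / (l_r + k)$. That formula is our assumption, chosen because it reduces to the textbook Frank-Wolfe step $2/(k+2)$ at the default `lr=2.0` and because its denominator can reach zero (for `lr=0.0` at `k=0`), which is the kind of case `div_by_zero_eps` would guard against. The function name `effective_lr` is hypothetical.

```python
# Hypothetical helper. The default schedule lr / (lr + k) is an ASSUMPTION,
# not confirmed rai_toolbox behavior.
FLOAT32_TINY = 1.1754943508222875e-38  # value of torch.finfo(torch.float32).tiny


def effective_lr(
    lr: float,
    k: int,
    use_default_lr_schedule: bool = True,
    div_by_zero_eps: float = FLOAT32_TINY,
) -> float:
    """Return the weight given to the LMO at update index k (k = 0, 1, 2, ...)."""
    if not use_default_lr_schedule:
        # Without a schedule, lr is used as-is and must keep the update a
        # convex combination of the current point and the LMO output.
        if not 0.0 <= lr <= 1.0:
            raise ValueError("lr must be in [0, 1] when use_default_lr_schedule=False")
        return lr
    # Assumed schedule: lr / (lr + k). Flooring the denominator at
    # div_by_zero_eps avoids 0 / 0 when lr == 0 and k == 0.
    return lr / max(lr + k, div_by_zero_eps)


# With the default lr=2.0 the assumed schedule gives 2 / (k + 2): 1.0, 2/3, 0.5, 0.4
weights = [effective_lr(2.0, k) for k in range(4)]
```

Under this assumption the first update (k = 0) gives the LMO full weight, so the initial parameter value is discarded entirely, which matches the behavior of the textbook step size.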
Methods
- __init__(params, *[, lr, ...])
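To make the update rule and the negative-LMO convention from the class description concrete, here is a minimal, dependency-free Python sketch; it uses neither PyTorch nor rai_toolbox. The feasible set (the probability simplex), the quadratic objective, the textbook step size $2/(k+2)$, and the helper names `simplex_lmo`, `frank_wolfe_step`, and `minimize_on_simplex` are illustrative assumptions, not part of the library. Note how the step consumes the negated LMO output exactly where a gradient-based optimizer would consume the gradient.

```python
from typing import Callable, List

Vector = List[float]


def simplex_lmo(grad: Vector) -> Vector:
    """LMO over the probability simplex: the minimizer of <s, grad> is the
    vertex e_i whose index i holds the smallest gradient entry."""
    i = min(range(len(grad)), key=lambda j: grad[j])
    return [1.0 if j == i else 0.0 for j in range(len(grad))]


def frank_wolfe_step(
    w: Vector, neg_lmo: Vector, alpha: float, lmo_scaling_factor: float = 1.0
) -> Vector:
    """w <- (1 - alpha) * w + alpha * c * s, given neg_lmo = -s (the value the
    optimizer expects to find in .grad) and scaling factor c."""
    return [
        (1.0 - alpha) * wi - alpha * lmo_scaling_factor * gi
        for wi, gi in zip(w, neg_lmo)
    ]


def minimize_on_simplex(
    grad_fn: Callable[[Vector], Vector], w0: Vector, steps: int = 100
) -> Vector:
    w = list(w0)
    for k in range(steps):
        s = simplex_lmo(grad_fn(w))
        neg_lmo = [-si for si in s]  # store -LMO, not the gradient
        alpha = 2.0 / (k + 2)        # textbook Frank-Wolfe step size
        w = frank_wolfe_step(w, neg_lmo, alpha)
    return w


# Example: minimize ||w - t||^2 over the simplex; its gradient is 2 (w - t).
target = [0.2, 0.5, 0.3]
w_final = minimize_on_simplex(
    lambda v: [2.0 * (vi - ti) for vi, ti in zip(v, target)],
    w0=[1.0, 0.0, 0.0],
    steps=1000,
)
```

With `lmo_scaling_factor=1.0` and step sizes in [0, 1], each iterate is a convex combination of the previous iterate and a simplex vertex, so `w` stays on the simplex without any projection step; avoiding projections is the main practical appeal of Frank-Wolfe.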