rai_toolbox.optim.FrankWolfe

class rai_toolbox.optim.FrankWolfe(params, *, lr=2.0, lmo_scaling_factor=1.0, use_default_lr_schedule=True, div_by_zero_eps=1.1754943508222875e-38)

Implements the Frank-Wolfe minimization algorithm [1].

$$w_{k+1} = (1 - \mathrm{lr})\, w_k + \mathrm{lr}\, s_k$$

where $s_k$ is the output of the linear minimization oracle (LMO) at step $k$.
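The update rule can be illustrated with a small pure-Python sketch that needs neither PyTorch nor rai_toolbox. It assumes a toy problem, minimizing $f(w) = \lVert w - c \rVert^2$ over the probability simplex, whose LMO is the vertex $e_i$ with the smallest gradient entry, and it uses the classic step size $2/(k+2)$. The names `lmo_simplex`, `frank_wolfe`, and the target `c` are hypothetical; this is not the library's implementation.

```python
from typing import List


def lmo_simplex(grad: List[float]) -> List[float]:
    """LMO for the probability simplex: the vertex e_i with the smallest grad_i."""
    i = min(range(len(grad)), key=lambda j: grad[j])
    return [1.0 if j == i else 0.0 for j in range(len(grad))]


def frank_wolfe(c: List[float], num_steps: int) -> List[float]:
    """Minimize f(w) = sum_i (w_i - c_i)^2 over the probability simplex."""
    n = len(c)
    w = [1.0 / n] * n  # feasible start: the barycenter of the simplex
    for k in range(num_steps):
        grad = [2.0 * (wi - ci) for wi, ci in zip(w, c)]
        s = lmo_simplex(grad)  # s_k
        lr = 2.0 / (2.0 + k)   # classic step size; equals 1.0 at k = 0
        # w_{k+1} = (1 - lr) * w_k + lr * s_k is a convex combination,
        # so the iterate never leaves the simplex
        w = [(1.0 - lr) * wi + lr * si for wi, si in zip(w, s)]
    return w


w = frank_wolfe([0.2, 0.3, 0.5], num_steps=5000)
# w is close to [0.2, 0.3, 0.5], the minimizer (c already lies in the simplex)
```

No projection step is needed: feasibility comes from only ever mixing the current iterate with an LMO output.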

It is critical to note that this optimizer assumes that the grad attribute of each parameter has been overwritten to store the negative of the LMO output for that parameter ($-s_k$), not the gradient itself.
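The sign convention can be checked numerically. The plain-Python sketch below assumes the step is carried out as a descent-style subtraction, $w \leftarrow (1 - \mathrm{lr})\, w - \mathrm{lr} \cdot \text{scale} \cdot \text{grad}$; this is an illustrative assumption rather than code from rai_toolbox, and `fw_update_from_grad` is a hypothetical name. When grad holds $-s_k$, the subtraction reproduces the Frank-Wolfe step.

```python
from typing import List


def fw_update_from_grad(
    w: List[float],
    grad: List[float],
    lr: float,
    lmo_scaling_factor: float = 1.0,
) -> List[float]:
    """Descent-style update; assumes grad holds -s_k rather than a true gradient.

    Subtracting lr * scale * grad is the same as adding lr * scale * s_k.
    """
    return [(1.0 - lr) * wi - lr * lmo_scaling_factor * gi for wi, gi in zip(w, grad)]


w = [0.25, -0.5]
s = [1.0, -1.0]           # an LMO output s_k (e.g. a vertex of the box [-1, 1]^2)
grad = [-si for si in s]  # what the grad attribute must hold: -s_k
print(fw_update_from_grad(w, grad, lr=0.5))  # [0.625, -0.75]
# Same as the textbook step (1 - lr) * w + lr * s:
print([(1 - 0.5) * wi + 0.5 * si for wi, si in zip(w, s)])  # [0.625, -0.75]
```

Storing the true gradient instead would push the iterate away from $s_k$, which is why the convention matters.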

References

__init__(params, *, lr=2.0, lmo_scaling_factor=1.0, use_default_lr_schedule=True, div_by_zero_eps=1.1754943508222875e-38)
Parameters:
params : Iterable

Iterable of tensor parameters to optimize or dicts defining parameter groups.

lr : float, optional (default=2.0)

Indicates the weight with which the LMO contributes to the parameter update. See use_default_lr_schedule for additional details. If use_default_lr_schedule=False, then lr must be in the interval [0, 1].

lmo_scaling_factor : float, optional (default=1.0)

A scaling factor applied to $s_k$ prior to each step.

use_default_lr_schedule : bool, optional (default=True)

If True, then the per-parameter “learning rate” is replaced by $\widehat{\mathrm{lr}} = \mathrm{lr} / (\mathrm{lr} + k)$, where $k$ is the update index for that parameter.

div_by_zero_eps : float, optional (default=`torch.finfo(torch.float32).tiny`)

Prevents a division-by-zero error in the learning-rate schedule.
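The following pure-Python sketch models how lr, use_default_lr_schedule, and div_by_zero_eps could combine into the weight given to $s_k$. It is a hypothetical illustration, not the rai_toolbox source: the helper name `effective_lr`, the update index starting at $k = 0$, and the placement of the epsilon in the denominator are all assumptions.

```python
from typing import List

# Value of torch.finfo(torch.float32).tiny, hard-coded so the sketch needs no torch
FLOAT32_TINY = 1.1754943508222875e-38


def effective_lr(
    lr: float,
    k: int,
    use_default_lr_schedule: bool = True,
    div_by_zero_eps: float = FLOAT32_TINY,
) -> float:
    """Hypothetical helper: the weight given to s_k at update index k."""
    if use_default_lr_schedule:
        # lr_hat = lr / (lr + k); eps only matters when lr + k == 0 (assumed placement)
        return lr / (lr + k + div_by_zero_eps)
    if not 0.0 <= lr <= 1.0:
        raise ValueError("lr must be in [0, 1] when use_default_lr_schedule=False")
    return lr


schedule: List[float] = [effective_lr(2.0, k) for k in range(5)]
# approximately [1.0, 0.667, 0.5, 0.4, 0.333], i.e. 2 / (k + 2)
print(effective_lr(0.0, 0))  # 0.0 rather than a ZeroDivisionError
print(effective_lr(0.3, 7, use_default_lr_schedule=False))  # 0.3
```

Under the $k = 0$ assumption, the default lr=2.0 reproduces the classic Frank-Wolfe step size $2/(k+2)$, and the first update has weight 1, so it replaces the initial parameter value with the (scaled) LMO output.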

Methods

__init__(params, *[, lr, ...])
