rai_toolbox.perturbations.random_restart

rai_toolbox.perturbations.random_restart(solver, repeats)[source]

Executes a solver function multiple times, retaining the best perturbations found across the runs.

Parameters:
solver : Callable[…, Tuple[Tensor, Tensor]]

The solver whose execution will be repeated.

repeats : int

The number of times to run solver.

Returns:
random_restart_fn : Callable[…, Tuple[Tensor, Tensor]]

Wrapped function that will execute solver repeats times.
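The restart-and-retain logic can be sketched in plain Python. This is a hypothetical simplification, not the toolbox's actual implementation (which operates on batched PyTorch tensors): the wrapper reruns the solver and keeps, for each datum independently, the perturbed value whose loss is highest.

```python
def sketch_random_restart(solver, repeats):
    # Hypothetical sketch of the restart wrapper: rerun `solver` and keep,
    # per datum, the perturbed value that achieved the highest loss.
    def wrapped(*args, **kwargs):
        best_pert, best_loss = None, None
        for _ in range(repeats):
            pert, loss = solver(*args, **kwargs)
            if best_loss is None:
                best_pert, best_loss = list(pert), list(loss)
            else:
                for i, new_loss in enumerate(loss):
                    if new_loss > best_loss[i]:
                        best_loss[i] = new_loss
                        best_pert[i] = pert[i]
        return best_pert, best_loss
    return wrapped

# Toy deterministic "solver" that yields a different result on each call,
# standing in for a randomly-initialized gradient-based solver.
runs = iter([
    ([0.1, 0.2], [0.1, 0.2]),
    ([0.3, -0.1], [0.3, 0.1]),
    ([0.2, 0.4], [0.2, 0.4]),
])
pert, losses = sketch_random_restart(lambda: next(runs), 3)()
print(pert, losses)  # [0.3, 0.4] [0.3, 0.4]
```

Datum 0 attains its best loss (0.3) on the second run and datum 1 attains its best loss (0.4) on the third, so the retained perturbations mix results from different restarts.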

Examples

Let’s perturb two data points, x1=-1.0 and x2=2.0, to maximize L(δ; x) = |x + δ| w.r.t. δ. Our perturbation model will randomly initialize δ1 and δ2, and we will re-run the solver three times, retaining the best perturbation of x1 and x2, respectively.

>>> from functools import partial
>>> import torch as tr
>>> from torch.optim import SGD
>>> from rai_toolbox.perturbations.init import uniform_like_l1_n_ball_
>>> from rai_toolbox.perturbations import AdditivePerturbation, gradient_ascent, random_restart
>>>
>>> def verbose_abs_diff(model_out, target):
...     # used to print out loss at each solver step (for purpose of example)
...     out = (model_out - target).abs()
...     print(out)
...     return out

Configuring a perturbation model that randomly initializes its perturbations.

>>> RNG = tr.Generator().manual_seed(0)
>>> PertModel = partial(AdditivePerturbation, init_fn=uniform_like_l1_n_ball_, generator=RNG)

Creating and running the repeating solver.

>>> gradient_ascent_with_restart = random_restart(gradient_ascent, 3)
>>> perturbed_data, losses = gradient_ascent_with_restart(
...    perturbation_model=PertModel,
...    data=[-1.0, 2.0],
...    target=0.0,
...    model=lambda data: data,
...    criterion=verbose_abs_diff,
...    optimizer=SGD,
...    lr=0.1,
...    steps=1,
... )
tensor([0.5037, 2.7682], grad_fn=<AbsBackward0>)
tensor([0.6037, 2.8682])
tensor([0.9115, 2.1320], grad_fn=<AbsBackward0>)
tensor([1.0115, 2.2320])
tensor([0.6926, 2.6341], grad_fn=<AbsBackward0>)
tensor([0.7926, 2.7341])

Note that for x1 the highest final loss is 1.0115, and for x2 it is 2.8682. These values are reflected in the losses and perturbed_data that were retained across the restarts.

>>> losses
tensor([1.0115, 2.8682])
>>> perturbed_data
tensor([-1.0115,  2.8682])
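As a sanity check, the retained losses equal the elementwise maximum over the final (no-grad) loss tensors printed after each restart above:

```python
# Final losses printed after each of the three restarts, copied from the
# example output above.
run_losses = [
    [0.6037, 2.8682],  # restart 1
    [1.0115, 2.2320],  # restart 2
    [0.7926, 2.7341],  # restart 3
]

# random_restart retains, per datum, the maximum loss across restarts.
best = [max(per_datum) for per_datum in zip(*run_losses)]
print(best)  # [1.0115, 2.8682]
```

This matches the returned losses tensor, with x1's best coming from the second restart and x2's best from the first.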