rai_toolbox.perturbations.random_restart
- rai_toolbox.perturbations.random_restart(solver, repeats)
Executes a solver function multiple times, retaining the best perturbation found for each datum.
- Parameters:
- solver : Callable[…, Tuple[Tensor, Tensor]]
The solver whose execution will be repeated.
- repeats : int
The number of times to run solver.
- Returns:
- random_restart_fn : Callable[…, Tuple[Tensor, Tensor]]
Wrapped function that will execute solver repeats times.
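A minimal sketch of the restart-and-retain idea, written in plain Python so it runs without PyTorch. It assumes the wrapped solver returns a (perturbed_data, losses) pair of equal-length sequences, and it treats "best" as the highest loss per datum, which matches the gradient-ascent example in this section. The name random_restart_sketch is hypothetical; the toolbox's own random_restart operates on tensors and its internals may differ.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical solver signature for this sketch: returns (perturbed_data, losses),
# two equal-length sequences with one entry per datum.
Solver = Callable[..., Tuple[Sequence[float], Sequence[float]]]


def random_restart_sketch(
    solver: Solver, repeats: int
) -> Callable[..., Tuple[List[float], List[float]]]:
    """Run `solver` `repeats` times; per datum, keep the result with the highest loss."""
    if repeats < 1:
        raise ValueError("repeats must be >= 1")

    def wrapped(*args, **kwargs) -> Tuple[List[float], List[float]]:
        best_data: List[float] = []
        best_losses: List[float] = []
        for i in range(repeats):
            data, losses = solver(*args, **kwargs)
            if i == 0:
                # The first run seeds the "best so far" results.
                best_data, best_losses = list(data), list(losses)
                continue
            # Element-wise: replace a datum's result only if this run scored higher.
            for j, loss in enumerate(losses):
                if loss > best_losses[j]:
                    best_losses[j] = loss
                    best_data[j] = data[j]
        return best_data, best_losses

    return wrapped
```

For instance, wrapping a solver whose three runs report losses [1.0, 5.0], [2.0, 4.0], and [1.5, 6.0] yields losses [2.0, 6.0]: the first datum comes from the second run and the second datum from the third. Because selection is per datum, the retained results can mix different restarts.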
Examples
Let’s perturb two data points, $x_1 = -1.0$ and $x_2 = 2.0$, to maximize $L(\delta; x) = |x + \delta|$ with respect to $\delta$. Our perturbation model will randomly initialize $\delta_1$ and $\delta_2$, and we will re-run the solver three times, retaining the best perturbation of $x_1$ and $x_2$ respectively.

>>> from functools import partial
>>> import torch as tr
>>> from torch.optim import SGD
>>> from rai_toolbox.perturbations.init import uniform_like_l1_n_ball_
>>> from rai_toolbox.perturbations import AdditivePerturbation, gradient_ascent, random_restart
>>>
>>> def verbose_abs_diff(model_out, target):
...     # used to print out loss at each solver step (for purpose of example)
...     out = (model_out - target).abs()
...     print(out)
...     return out
Configuring a perturbation model that randomly initializes the perturbations:
>>> RNG = tr.Generator().manual_seed(0)
>>> PertModel = partial(AdditivePerturbation, init_fn=uniform_like_l1_n_ball_, generator=RNG)
Creating and running the repeating solver:
>>> gradient_ascent_with_restart = random_restart(gradient_ascent, 3)
>>> perturbed_data, losses = gradient_ascent_with_restart(
...     perturbation_model=PertModel,
...     data=[-1.0, 2.0],
...     target=0.0,
...     model=lambda data: data,
...     criterion=verbose_abs_diff,
...     optimizer=SGD,
...     lr=0.1,
...     steps=1,
... )
tensor([0.5037, 2.7682], grad_fn=<AbsBackward0>)
tensor([0.6037, 2.8682])
tensor([0.9115, 2.1320], grad_fn=<AbsBackward0>)
tensor([1.0115, 2.2320])
tensor([0.6926, 2.6341], grad_fn=<AbsBackward0>)
tensor([0.7926, 2.7341])
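In each printed pair, the second tensor is exactly 0.1 larger than the first. That is consistent with a single ascent step of size lr=0.1: the gradient of $|x + \delta|$ with respect to $\delta$ is $\operatorname{sign}(x + \delta)$, whose magnitude is 1, so one step moves $x + \delta$ a distance of 0.1 further from zero. The sketch below reproduces that arithmetic for $x_1$ in plain Python. It assumes a plain, unnormalized ascent step and is not the toolbox's gradient_ascent implementation; the starting value 0.0885 is inferred, being the $\delta_1$ consistent with both the second restart's printed loss 0.9115 and the retained perturbed value -1.0115.

```python
import math


def sign(v: float) -> int:
    """Return -1, 0, or 1 according to the sign of v."""
    return (v > 0) - (v < 0)


def loss(x: float, delta: float) -> float:
    """L(delta; x) = |x + delta|, the objective being maximized in the example."""
    return abs(x + delta)


def ascent_step(x: float, delta: float, lr: float) -> float:
    """One plain (unnormalized) gradient-ascent step on L w.r.t. delta."""
    grad = sign(x + delta)  # dL/d(delta) = sign(x + delta) wherever x + delta != 0
    return delta + lr * grad


# Inferred initial perturbation for x1 in the second restart (see lead-in).
x1, delta1 = -1.0, 0.0885
new_delta1 = ascent_step(x1, delta1, lr=0.1)

print(round(loss(x1, delta1), 4), round(loss(x1, new_delta1), 4))
print(round(x1 + new_delta1, 4))
```

Since the step always pushes $x + \delta$ away from zero, it never crosses the kink of $|\cdot|$, which is why every loss in the output grows by exactly the learning rate.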
See that for $x_1$ the highest loss is 1.0115, and for $x_2$ it is 2.8682. These maxima should be reflected in losses and perturbed_data, which hold the results retained across the restarts.

>>> losses
tensor([1.0115, 2.8682])
>>> perturbed_data
tensor([-1.0115, 2.8682])
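The retained values can be checked by hand: for each data point, take the restart whose final loss is highest. The snippet below does this over the losses printed in the example, assuming (as the returned losses tensor suggests) that the second tensor printed by each restart is the one compared across restarts.

```python
# Post-step losses printed by each of the three restarts in the example
# (the second tensor of each printed pair).
restart_losses = [
    [0.6037, 2.8682],  # restart 1
    [1.0115, 2.2320],  # restart 2
    [0.7926, 2.7341],  # restart 3
]

# Transpose so that each tuple holds one data point's losses across restarts.
per_point = list(zip(*restart_losses))

# For each data point, pick the (0-indexed) restart with the highest loss.
best_restart = [col.index(max(col)) for col in per_point]
best_losses = [max(col) for col in per_point]

print(best_restart)
print(best_losses)
```

The result selects restart 2 for $x_1$ and restart 1 for $x_2$, matching losses above: the winning perturbations come from different runs.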