Source code for rai_toolbox.perturbations.solvers

# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT

import functools
from typing import Any, Callable, List, Optional, Tuple, Union

import torch as tr
from torch import Tensor
from torch.nn import CrossEntropyLoss, Module
from torch.optim import Optimizer as _TorchOptim

from rai_toolbox import negate
from rai_toolbox._typing import (
    ArrayLike,
    InstantiatesTo,
    Optimizer,
    OptimizerType,
    Partial,
    instantiates_to,
)
from rai_toolbox._utils import value_check
from rai_toolbox._utils.stateful import evaluating, frozen
from rai_toolbox.perturbations import AdditivePerturbation, PerturbationModel


def gradient_ascent(
    *,
    model: Callable[[Tensor], Tensor],
    data: ArrayLike,
    target: ArrayLike,
    optimizer: Union[Optimizer, OptimizerType, Partial[Optimizer]],
    steps: int,
    perturbation_model: Union[
        PerturbationModel, InstantiatesTo[PerturbationModel]
    ] = AdditivePerturbation,
    targeted: bool = False,
    use_best: bool = False,
    criterion: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
    reduction_fn: Callable[[Tensor], Tensor] = tr.sum,
    **optim_kwargs: Any,
) -> Tuple[Tensor, Tensor]:
    """Solve for a set of perturbations for a given set of data and a model, and
    then apply those perturbations to the data.

    This performs, for `steps` iterations, the following optimization::

       optim = optimizer(perturbation_model.parameters)
       pert_data = perturbation_model(data)
       loss = criterion(model(pert_data), target)
       loss = (1 if targeted else -1) * loss  # default: targeted=False
       optim.step()

    Note that, by default, this perturbs the data away from `target` (i.e., this
    performs gradient *ascent*), given a standard loss function that seeks to
    minimize the difference between the model's output and the target. See
    `targeted` to toggle this behavior.

    Parameters
    ----------
    model : Callable[[Tensor], Tensor]
        Differentiable function that processes the (perturbed) data prior to
        computing the loss. If `model` is a `torch.nn.Module`, then its weights
        will be frozen and it will be set to eval mode during the
        perturbation-solve phase.

    data : ArrayLike, shape-(N, ...)
        The input data to perturb.

    target : ArrayLike, shape-(N, ...)
        If `targeted==False` (default), then this is the target to perturb away
        from. If `targeted==True`, then this is the target to perturb toward.

    optimizer : Optimizer | Type[Optimizer] | Partial[Optimizer]
        The optimizer to use for updating the perturbation model. If `optimizer`
        is uninstantiated, it will be instantiated as
        `optimizer(perturbation_model.parameters(), **optim_kwargs)`

    steps : int
        Number of projected gradient steps.

    perturbation_model : PerturbationModel | Type[PerturbationModel], optional (default=AdditivePerturbation)
        A `torch.nn.Module` whose parameters are updated by the solver. Its
        forward-pass applies the perturbation to the data. Default is
        `AdditivePerturbation`, which simply adds the perturbation to the data.
        The perturbation model should not modify the data in-place.

        If `perturbation_model` is a type, then it will be instantiated as
        `perturbation_model(data)`.

    criterion : Optional[Callable[[Tensor, Tensor], Tensor]]
        The criterion to use for calculating the loss **per-datum**. I.e., for a
        shape-(N, ...) batch of data, `criterion` should return a shape-(N,)
        tensor of loss values – one for each datum in the batch.

        If `None`, then `CrossEntropyLoss(reduction="none")` is used.

    targeted : bool (default: False)
        If `True`, then perturb towards the defined `target`, otherwise move away
        from `target`. Note: the default (`targeted=False`) implements gradient
        *ascent*. To perform gradient *descent*, set `targeted=True`.

    use_best : bool (default: False)
        Whether to only report the best perturbation over all steps. Note: this
        requires `criterion` to output a loss per sample, e.g., set
        `reduction="none"`.

    reduction_fn : Callable[[Tensor], Tensor], optional (default=torch.sum)
        Used to reduce the shape-(N,) per-datum loss to a scalar. This should be
        set to `torch.mean` when solving for a "universal" perturbation.

    **optim_kwargs : Any
        Keyword arguments passed to `optimizer` when it is instantiated.

    Returns
    -------
    xadv, losses : tuple[Tensor, Tensor], shape-(N, ...), shape-(N, ...)
        The perturbed data; if `use_best==True` then this is the best
        perturbation, based on the loss, across all steps.

        The loss for each perturbed data point; if `use_best==True` then this is
        the best loss across all steps.

    Notes
    -----
    `model` is automatically set to eval-mode and its parameters are set to
    `requires_grad=False` within the context of this function.

    Examples
    --------
    Let's perturb two data points, `x1=-1.0` and `x2=2.0`, to maximize
    `L(δ; x) = |x + δ|` w.r.t `δ`. We will use five standard gradient steps,
    using a learning rate of 0.1.

    The default perturbation model is simply additive: `x -> x + δ`. This solver
    is refining `δ1` and `δ2`, whose initial values are 0 by default, to maximize
    `L(x) = |x|` for `x1` and `x2`, respectively. Thus we should find that our
    solved perturbations modify our data as: `-1.0 -> -1.5` and `2.0 -> 2.5`,
    respectively.

    >>> from rai_toolbox.perturbations import gradient_ascent
    >>> from torch.optim import SGD
    >>> identity_model = lambda data: data
    >>> abs_diff = lambda model_out, target: (model_out - target).abs()
    >>> perturbed_data, losses = gradient_ascent(
    ...     data=[-1.0, 2.0],
    ...     target=0.0,
    ...     model=identity_model,
    ...     criterion=abs_diff,
    ...     optimizer=SGD,
    ...     lr=0.1,
    ...     steps=5,
    ... )
    >>> perturbed_data
    tensor([-1.5000,  2.5000])

    We can instead specify `targeted=True` and perform gradient *descent*. Here,
    the perturbations we solve for should modify our data as: -1.0 -> -0.5 and
    2.0 -> 1.5, respectively.

    >>> perturbed_data, losses = gradient_ascent(
    ...     data=[-1.0, 2.0],
    ...     target=0.0,
    ...     model=identity_model,
    ...     criterion=abs_diff,
    ...     optimizer=SGD,
    ...     lr=0.1,
    ...     steps=5,
    ...     targeted=True,
    ... )
    >>> perturbed_data
    tensor([-0.5000,  1.5000])

    **Accessing the perturbations**

    To gain direct access to the solved perturbations, we can provide our own
    perturbation model to the solver. Let's solve the same optimization problem,
    but provide our own instance of `AdditivePerturbation`.

    >>> from rai_toolbox.perturbations import AdditivePerturbation
    >>> pert_model = AdditivePerturbation(data_or_shape=(2,))
    >>> perturbed_data, losses = gradient_ascent(
    ...     perturbation_model=pert_model,
    ...     data=[-1.0, 2.0],
    ...     target=0.0,
    ...     model=identity_model,
    ...     criterion=abs_diff,
    ...     optimizer=SGD,
    ...     lr=0.1,
    ...     steps=5,
    ... )
    >>> perturbed_data
    tensor([-1.5000,  2.5000])

    Now we can access the values that were solved for `δ1` and `δ2`.

    >>> pert_model.delta
    Parameter containing:
    tensor([-0.5000,  0.5000], requires_grad=True)
    """
    data = tr.as_tensor(data)
    target = tr.as_tensor(target)

    if not data.is_leaf:
        data = data.detach()

    if not target.is_leaf:
        target = target.detach()

    # Initialize
    best_loss = None
    best_x = None

    if criterion is None:
        criterion = CrossEntropyLoss(reduction="none")

    if not targeted:
        # maximize the objective function
        criterion = negate(criterion)

    if instantiates_to(perturbation_model, PerturbationModel):
        pmodel = perturbation_model(data)
    else:
        if not isinstance(perturbation_model, PerturbationModel):
            raise TypeError(
                f"`perturbation_model` must satisfy the `PerturbationModel`"
                f" protocol, got: {perturbation_model}"
            )
        pmodel = perturbation_model

    if instantiates_to(optimizer, _TorchOptim):
        optim = optimizer(pmodel.parameters(), **optim_kwargs)
    else:
        if not isinstance(optimizer, _TorchOptim):
            raise TypeError(
                f"`optimizer` must be an instance of Optimizer or must instantiate "
                f"to Optimizer; got: {optimizer}"
            )
        if instantiates_to(perturbation_model, PerturbationModel):
            raise TypeError(
                "An initialized optimizer can only be passed to the solver in "
                "combination with an initialized perturbation model."
            )
        if optim_kwargs:
            raise TypeError(
                "**optim_kwargs were specified, but the provided `optimizer` has "
                "already been instantiated"
            )
        optim = optimizer

    to_freeze: List[Any] = [data, target]

    if isinstance(model, Module):
        to_freeze.append(model)

    # don't pass non nn.Module to frozen/eval
    packed_model = (model,) if isinstance(model, Module) else ()

    # Projected Gradient Descent
    with frozen(*to_freeze), evaluating(*packed_model), tr.enable_grad():
        for _ in range(steps):
            # Calculate the gradient of loss
            xadv = pmodel(data)
            logits = model(xadv)
            losses = criterion(logits, target)
            loss = reduction_fn(losses)

            # Update the perturbation
            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()

            if use_best:
                if (
                    (losses.ndim == 0 and data.ndim > 0)
                    or losses.ndim > data.ndim
                    or losses.shape != data.shape[: losses.ndim]
                ):
                    raise ValueError(
                        f"`use_best=True` but `criterion` does not output a per-datum loss. "
                        f"I.e. `criterion` returned a tensor of shape-{tuple(losses.shape)} for a "
                        f"batch of shape-{tuple(data.shape)}. Expected a tensor of "
                        f"shape-{(len(data),)} or greater."
                    )
                best_loss, best_x = _replace_best(losses, best_loss, xadv, best_x)

        # free up memory
        optim.zero_grad(set_to_none=True)

        # Final evaluation
        with tr.no_grad():
            xadv = pmodel(data)
            logits = model(xadv)
            losses = criterion(logits, target)

    if use_best:
        # the loss is negated when `targeted=False`, so min-loss is always best
        losses, xadv = _replace_best(losses, best_loss, xadv, best_x, min=True)

    if not targeted:
        # The returned loss is always relative to the original criterion.
        # E.g. an adversarial perturbation with a "low" negated loss will yield
        # a high loss.
        losses = -1 * losses

    return xadv.detach(), losses.detach()
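
# --- Illustrative sketch (not part of this module) ---------------------------
# The docstring examples above use an identity "model"; the sketch below shows
# the more typical classification setting, relying on the default
# CrossEntropyLoss(reduction="none") criterion. `net`, the random batch, and the
# hyperparameters are placeholder assumptions, not values prescribed by the
# toolbox.
#
#   import torch as tr
#   from torch.nn import Linear
#   from torch.optim import SGD
#
#   from rai_toolbox.perturbations import gradient_ascent
#
#   net = Linear(4, 3)            # toy classifier: 4 features -> 3 classes
#   x = tr.randn(8, 4)            # shape-(N=8, 4) batch
#   y = tr.randint(0, 3, (8,))    # integer class labels, shape-(8,)
#
#   x_adv, losses = gradient_ascent(
#       model=net,        # frozen and set to eval mode during the solve
#       data=x,
#       target=y,
#       optimizer=SGD,    # instantiated internally against the perturbation params
#       lr=0.5,
#       steps=10,
#       use_best=True,    # keep the highest-loss perturbation seen per datum
#   )
#   # x_adv has the same shape as x; losses holds the per-datum cross-entropy
# ------------------------------------------------------------------------------
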
def random_restart(
    solver: Callable[..., Tuple[Tensor, Tensor]],
    repeats: int,
) -> Callable[..., Tuple[Tensor, Tensor]]:
    """Executes a solver function multiple times, saving out the best perturbations.

    Parameters
    ----------
    solver : Callable[..., Tuple[Tensor, Tensor]]
        The solver whose execution will be repeated.

    repeats : int
        The number of times to run `solver`.

    Returns
    -------
    random_restart_fn : Callable[..., Tuple[Tensor, Tensor]]
        Wrapped function that will execute `solver` `repeats` times.

    Examples
    --------
    Let's perturb two data points, `x1=-1.0` and `x2=2.0`, to maximize
    `L(δ; x) = |x + δ|` w.r.t `δ`. Our perturbation model will randomly
    initialize `δ1` and `δ2`, and we will re-run the solver three times –
    retaining the best perturbation of `x1` and `x2`, respectively.

    >>> from functools import partial
    >>> import torch as tr
    >>> from torch.optim import SGD
    >>> from rai_toolbox.perturbations.init import uniform_like_l1_n_ball_
    >>> from rai_toolbox.perturbations import AdditivePerturbation, gradient_ascent, random_restart
    >>>
    >>> def verbose_abs_diff(model_out, target):
    ...     # used to print out loss at each solver step (for purpose of example)
    ...     out = (model_out - target).abs()
    ...     print(out)
    ...     return out

    Configuring a perturbation model to randomly initialize the perturbations.

    >>> RNG = tr.Generator().manual_seed(0)
    >>> PertModel = partial(AdditivePerturbation, init_fn=uniform_like_l1_n_ball_, generator=RNG)

    Creating and running the repeating solver.

    >>> gradient_ascent_with_restart = random_restart(gradient_ascent, 3)
    >>> perturbed_data, losses = gradient_ascent_with_restart(
    ...     perturbation_model=PertModel,
    ...     data=[-1.0, 2.0],
    ...     target=0.0,
    ...     model=lambda data: data,
    ...     criterion=verbose_abs_diff,
    ...     optimizer=SGD,
    ...     lr=0.1,
    ...     steps=1,
    ... )
    tensor([0.5037, 2.7682], grad_fn=<AbsBackward0>)
    tensor([0.6037, 2.8682])
    tensor([0.9115, 2.1320], grad_fn=<AbsBackward0>)
    tensor([1.0115, 2.2320])
    tensor([0.6926, 2.6341], grad_fn=<AbsBackward0>)
    tensor([0.7926, 2.7341])

    See that for `x1` the highest loss is `1.0115`, and for `x2` it is `2.8682`.
    This should be reflected in the `losses` and `perturbed_data` that were
    retained across the restarts.

    >>> losses
    tensor([1.0115, 2.8682])
    >>> perturbed_data
    tensor([-1.0115,  2.8682])
    """
    value_check("repeats", repeats, min_=1, incl_min=True)

    @functools.wraps(solver)
    def random_restart_fn(*args, **kwargs) -> Tuple[Tensor, Tensor]:
        targeted = kwargs.get("targeted", False)

        best_x = None
        best_loss = None

        for _ in range(repeats):
            # run the attack
            xadv, losses = solver(*args, **kwargs)

            # Save best loss for each data point
            best_loss, best_x = _replace_best(
                loss=losses,
                best_loss=best_loss,
                data=xadv,
                best_data=best_x,
                min=targeted,
            )

        assert best_x is not None
        assert best_loss is not None
        return best_x, best_loss

    return random_restart_fn
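
# --- Illustrative sketch (not part of this module) ---------------------------
# Because `random_restart_fn` reads `targeted` from its kwargs and passes
# `min=targeted` to `_replace_best`, a targeted solve keeps the restart with the
# *lowest* loss, while an untargeted solve keeps the highest. The perturbation
# model below (names and hyperparameters are assumptions for illustration) is
# re-instantiated, and hence randomly re-initialized, on every restart, which is
# what makes restarting useful.
#
#   from functools import partial
#   from torch.optim import SGD
#
#   from rai_toolbox.perturbations import AdditivePerturbation, gradient_ascent, random_restart
#   from rai_toolbox.perturbations.init import uniform_like_l1_n_ball_
#
#   PertModel = partial(AdditivePerturbation, init_fn=uniform_like_l1_n_ball_)
#   solver = random_restart(gradient_ascent, repeats=5)
#
#   x_adv, losses = solver(
#       perturbation_model=PertModel,  # re-initialized on each restart
#       data=[-1.0, 2.0],
#       target=0.0,
#       model=lambda data: data,
#       criterion=lambda out, tgt: (out - tgt).abs(),
#       optimizer=SGD,
#       lr=0.1,
#       steps=5,
#       targeted=True,                 # wrapper now retains the minimum-loss result
#   )
# ------------------------------------------------------------------------------
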
# A function that updates the best loss and best input
def _replace_best(
    loss: Tensor,
    best_loss: Optional[Tensor],
    data: Tensor,
    best_data: Optional[Tensor],
    min: bool = True,
) -> Tuple[Tensor, Tensor]:
    """Returns the data with the smallest (or largest) loss.

    Parameters
    ----------
    loss : Tensor, shape-(N, ...)
        N: batch size

    best_loss : Optional[Tensor], shape-(N, ...)
        N: batch size

    data : Tensor, shape-(N, ...)
        N: batch size

    best_data : Optional[Tensor], shape-(N, ...)
        N: batch size

    min : bool (default: True)
        Whether best is minimum (True) or maximum (False)

    Returns
    -------
    best_loss, best_data : Tuple[Tensor, Tensor]
    """
    if best_loss is None:
        best_data = data
        best_loss = loss
    else:
        assert best_data is not None
        if min:
            replace = loss < best_loss
        else:
            replace = loss > best_loss

        best_data[replace] = data[replace]
        best_loss[replace] = loss[replace]

    return best_loss, best_data
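
# --- Illustrative sketch (not part of this module) ---------------------------
# Per-datum bookkeeping performed by `_replace_best`: each entry of
# `best_loss`/`best_data` is swapped (in-place) only where the new loss improves
# on the stored one. The toy tensors below are assumptions for illustration.
#
#   import torch as tr
#
#   best_loss, best_data = _replace_best(
#       loss=tr.tensor([0.2, 0.9]),
#       best_loss=tr.tensor([0.5, 0.5]),
#       data=tr.tensor([1.0, 2.0]),
#       best_data=tr.tensor([3.0, 4.0]),
#       min=True,
#   )
#   # only the first datum improves (0.2 < 0.5), so:
#   # best_loss -> tensor([0.2000, 0.5000]); best_data -> tensor([1.0000, 4.0000])
# ------------------------------------------------------------------------------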