On Data Perturbations

This article provides an overview of common data-optimization problems, posed with model-dependent criteria, that are relevant for assessing and enhancing the robustness of machine learning models. We present a high-level picture of how we designed the rAI-toolbox to solve these problems, and provide pseudo-code illustrating how optimizations over data perturbations adhere strictly to core PyTorch APIs; i.e., a perturbation model is defined as a torch.nn.Module, and its optimizer as a torch.optim.Optimizer.

Model optimization vs. data optimization

Standard machine learning training frameworks are designed to refine the parameters of a model (i.e., its weights), whereas methods for studying a model's robustness and explainability naturally involve analyzing and optimizing over the data (i.e., the inputs to the model and the representations extracted by the model). Optimizing over the data space makes responsible AI workflows more complex than the standard training setting.

For example, consider the standard optimization objective for training a model, \(f_\theta\), parameterized by \(\theta\):

\[\min\limits_{\theta \in \Theta} \mathbb{E}_{(x,y)\sim D} [\mathcal{L}(f_\theta(x),y)],\]

where \(x\) and \(y\) represent the data input and corresponding output, respectively, sampled from a data distribution, \(D\), and \(\mathcal{L}\) is the loss function to be minimized. Note that here, the data samples are fixed, and the search is done over the model’s weight space.
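As a minimal illustration of this objective, the PyTorch sketch below assumes a toy linear classifier as \(f_\theta\) and a small random batch standing in for samples from \(D\). Only `model.parameters()` is handed to the optimizer, so each step updates \(\theta\) while the data tensors stay untouched.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-ins (assumptions): a linear classifier f_theta and a fixed batch (x, y) drawn from D
model = nn.Linear(4, 3)
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))
criterion = nn.CrossEntropyLoss()

# The search is over theta: only the model's parameters are handed to the optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x_before = x.clone()
initial_loss = criterion(model(x), y).item()
for _ in range(50):
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # gradient-descent step on theta; x and y are never modified
final_loss = criterion(model(x), y).item()
```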

In contrast, consider the optimization objective for solving for an adversarial example to fool the model into producing an incorrect output, which is a common practice for assessing the robustness of the model:

\[\max\limits_{\delta \in \Delta} \mathcal{L}(f_\theta(x + \delta),y),\]

where a perturbation, \(\delta\), is optimized to maximize loss against the true output, subject to a constraint set, \(\Delta\). Here, the model parameters are held fixed, and the search is conducted over the data space.
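The sketch below reverses the roles from the training setting, under the same toy assumptions (a linear classifier and a random batch). The model's parameters are frozen, and \(\delta\) is the only tensor given to the optimizer. Because PyTorch optimizers minimize, the loss is negated in order to maximize it. The constraint set \(\Delta\) is omitted for brevity, so this is plain unconstrained gradient ascent.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-ins (assumptions): a fixed classifier f_theta and one labelled batch (x, y)
model = nn.Linear(4, 3)
for p in model.parameters():
    p.requires_grad_(False)  # theta is held fixed
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))
criterion = nn.CrossEntropyLoss()

# The search is over delta: it is the only tensor handed to the optimizer
delta = torch.zeros_like(x, requires_grad=True)
optimizer = torch.optim.SGD([delta], lr=0.1)

weight_before = model.weight.clone()
clean_loss = criterion(model(x), y).item()
for _ in range(10):
    # maximize L(f(x + delta), y) by minimizing its negation
    loss = -criterion(model(x + delta), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
adv_loss = criterion(model(x + delta), y).item()
```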

The Robust AI research community has proposed a plethora of approaches for solving this objective under different loss configurations and constraint sets. One popular approach is iterative projected gradient descent (PGD) on the negative cross-entropy loss, using as the constraint set an \(L^p\)-ball of radius \(\epsilon\) with \(p=1, 2,\) or \(\infty\).
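For concreteness, here is a minimal sketch of PGD under an \(L^\infty\) constraint, which takes the common signed-gradient step and then projects, along with a per-sample \(L^2\) projection. Ascending the cross-entropy is equivalent to descending its negative. The function names, step size, and toy linear model are assumptions made for illustration. The \(L^1\) projection is omitted because it requires a more involved sorting-based algorithm.

```python
import torch
import torch.nn.functional as F
from torch import nn

def project_linf(delta: torch.Tensor, eps: float) -> torch.Tensor:
    # Projection onto the L-inf ball of radius eps: clip every entry to [-eps, eps]
    return delta.clamp(-eps, eps)

def project_l2(delta: torch.Tensor, eps: float) -> torch.Tensor:
    # Projection onto the L2 ball of radius eps, computed separately for each sample
    norms = delta.flatten(1).norm(p=2, dim=1).clamp_min(1e-12)
    scale = (eps / norms).clamp(max=1.0)  # shrink only samples that lie outside the ball
    return delta * scale.view(-1, *([1] * (delta.dim() - 1)))

def pgd_linf(model, x, y, eps, step_size, steps):
    # L-inf PGD: step along the sign of the cross-entropy gradient, then project
    delta = torch.zeros_like(x, requires_grad=True)
    for _ in range(steps):
        loss = F.cross_entropy(model(x + delta), y)
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta = project_linf(delta + step_size * grad.sign(), eps)
        delta.requires_grad_(True)
    return delta.detach()

torch.manual_seed(0)
model = nn.Linear(4, 3)  # toy classifier (assumption)
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))

delta = pgd_linf(model, x, y, eps=0.1, step_size=0.025, steps=10)
clean_loss = F.cross_entropy(model(x), y).item()
adv_loss = F.cross_entropy(model(x + delta), y).item()
```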

Solvers for data perturbations

A variety of other tools exist that implement large libraries of techniques proposed by the research community (such as PGD) in a framework-agnostic API. Their perturbation solvers are often written from scratch and look something like this:

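Notional perturbation solver
from typing import Callable

import torch
from torch import Tensor

def perturbation_solver(
    model: Callable[[Tensor], Tensor],
    data: Tensor,
    target: Tensor,
    lr: float,
    steps: int,
    criterion: Callable[[Tensor, Tensor], Tensor],
    initialize_fn: Callable[[Tensor], Tensor],
    project_fn: Callable[[Tensor], Tensor],
) -> Tensor:
    # initialize perturbation parameter (a fresh leaf tensor that tracks gradients)
    delta = initialize_fn(data).detach().requires_grad_(True)

    for _ in range(steps):
        # perturbation applied manually / in-line
        perturbed_data = data + delta

        # calculate loss
        loss = criterion(model(perturbed_data), target)

        # optimize: gradient *ascent*, since the loss is maximized
        (grad,) = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta = delta + lr * grad  # perturbation updated manually / in-line
            delta = project_fn(delta)
        delta.requires_grad_(True)  # re-enable gradient tracking for the next step
    return delta.detach()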

Note that the code for applying the perturbation and for taking an optimization step is embedded within the solver's for-loop. If a user wanted to swap out the optimization method or use a different perturbation model, they would need to write an entirely new solver.

By adhering to PyTorch APIs, the rAI-toolbox frames the process of solving for a perturbation in terms of the standard workflow for training ML models. That is, we specify perturbation models, which are responsible for initializing, storing, and applying perturbations, and perturbation optimizers, which update the perturbations based on their gradients while also applying normalizations and constraints to the perturbations and their gradients.

rAI-toolbox approach to solving for perturbations
from torch import no_grad
from torch.nn import Module
from torch.optim import Optimizer

# Implements PyTorch Module API
class PerturbationModel(Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # initialize parameters of perturbation model

    def forward(self, x):
        perturbed_data = ...  # use model's parameters to perturb data
        return perturbed_data

# Implements PyTorch Optimizer API
class PerturbationOptimizer(Optimizer):
    def _pre_step_(self, param, **kwds): ...  # e.g., perform gradient-normalization
    def _step_(self, param, **kwds): ...  # perform gradient-based update on parameter
    def _post_step_(self, param, **kwds): ...  # e.g., project updated parameter into constraint set

    @no_grad()
    def step(self, closure=None):
        # visit every parameter via the standard Optimizer attribute `param_groups`
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                self._pre_step_(param)
                self._step_(param)
                self._post_step_(param)
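A concrete, hypothetical instance of this pattern is sketched below. `SignedGradientLinf` reuses the hook names from the pseudo-code: it normalizes each gradient to its sign, takes a descent step, and clips the parameter into an \(L^\infty\) ball of radius `epsilon`. It is an illustration under those assumptions, not the rAI-toolbox's `ParamTransformingOptimizer`. Because it descends, a loss is maximized by passing its negation, as in the usage lines at the end.

```python
import torch
from torch.optim import Optimizer

class SignedGradientLinf(Optimizer):
    """Hypothetical perturbation optimizer built from three hooks."""

    def __init__(self, params, lr=0.01, epsilon=0.1):
        super().__init__(params, dict(lr=lr, epsilon=epsilon))

    def _pre_step_(self, param, group):
        # gradient normalization: replace the gradient by its elementwise sign
        param.grad.sign_()

    def _step_(self, param, group):
        # gradient-descent update using the normalized gradient
        param.add_(param.grad, alpha=-group["lr"])

    def _post_step_(self, param, group):
        # constraint: clip the parameter into the L-inf ball of radius epsilon
        param.clamp_(-group["epsilon"], group["epsilon"])

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                self._pre_step_(param, group)
                self._step_(param, group)
                self._post_step_(param, group)
        return loss

# usage: maximize (p * c).sum() by minimizing its negation
p = torch.nn.Parameter(torch.zeros(3))
c = torch.tensor([1.0, -2.0, 3.0])
opt = SignedGradientLinf([p], lr=0.05, epsilon=0.1)
for _ in range(5):
    opt.zero_grad()
    loss = -(p * c).sum()
    loss.backward()
    opt.step()
```

With these settings, `p` moves by 0.05 per step along the sign of `c` and stops at the boundary of the ball, ending at `[0.1, -0.1, 0.1]`.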

Having framed the perturbation process as a torch.nn.Module whose parameters (e.g., the perturbation itself) are optimized and constrained via the torch.optim.Optimizer API, we can take any standard trainer, e.g.:

A standard PyTorch trainer
def standard_trainer(model, data, target, optimizer, steps, criterion):
   for _ in range(steps):
      # calculate loss
      loss = criterion(model(data), target)

      # optimize
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

and solve for the optimal perturbation via:

Solving for perturbations using a standard PyTorch trainer
from torch.nn import Sequential
from rai_toolbox import freeze

pert_model = PerturbationModel(...)
optim = PerturbationOptimizer(pert_model.parameters(), ...)

ml_model = MyNeuralNetwork(...).eval()
freeze(ml_model)  # freeze ml_model's parameters in place (requires_grad=False)

# model(data) -> ml_model(pert_model(data))
model = Sequential(pert_model, ml_model)

# solve for perturbations
standard_trainer(model, optimizer=optim, data=..., target=..., steps=..., criterion=...)

# solved perturbations are stored in `pert_model`
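
The following self-contained sketch runs that workflow end to end on a toy problem. It makes several assumptions: `AddDelta` and `ClampedSGD` are hypothetical stand-ins for `PerturbationModel` and `PerturbationOptimizer`, a small `nn.Linear` replaces `MyNeuralNetwork`, and setting `requires_grad_(False)` on the model's parameters replaces `rai_toolbox.freeze` so that the example depends only on PyTorch. The trainer loop itself is unchanged; maximization comes from passing a negated cross-entropy as `criterion`.

```python
import torch
import torch.nn.functional as F
from torch import nn

class AddDelta(nn.Module):
    # hypothetical perturbation model: one additive delta shared by the whole batch
    def __init__(self, shape):
        super().__init__()
        self.delta = nn.Parameter(torch.zeros(shape))

    def forward(self, x):
        return x + self.delta

class ClampedSGD(torch.optim.SGD):
    # hypothetical perturbation optimizer: an SGD step, then projection onto an L-inf ball
    def __init__(self, params, lr, epsilon):
        super().__init__(params, lr=lr)
        self.epsilon = epsilon

    @torch.no_grad()
    def step(self, closure=None):
        loss = super().step(closure)
        for group in self.param_groups:
            for p in group["params"]:
                p.clamp_(-self.epsilon, self.epsilon)
        return loss

def standard_trainer(model, data, target, optimizer, steps, criterion):
    # the same loop as the standard PyTorch trainer in this article
    for _ in range(steps):
        loss = criterion(model(data), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def negated_ce(logits, target):
    # minimizing the negated cross-entropy maximizes the cross-entropy
    return -F.cross_entropy(logits, target)

torch.manual_seed(0)
ml_model = nn.Linear(4, 3).eval()  # toy stand-in for MyNeuralNetwork
for p in ml_model.parameters():
    p.requires_grad_(False)  # plain-PyTorch stand-in for rai_toolbox's `freeze`
weights_before = ml_model.weight.clone()

pert_model = AddDelta((4,))
optim = ClampedSGD(pert_model.parameters(), lr=0.1, epsilon=0.5)
model = nn.Sequential(pert_model, ml_model)  # model(data) -> ml_model(pert_model(data))

data = torch.randn(16, 4)
target = torch.randint(0, 3, (16,))
clean_loss = F.cross_entropy(ml_model(data), target).item()
standard_trainer(model, data, target, optim, steps=20, criterion=negated_ce)
adv_loss = F.cross_entropy(model(data), target).item()
```

Only `pert_model.delta` changes; the classifier's weights are identical before and after the solve.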

We can then use pert_model to apply these optimized perturbations to new data:

Perturbing data
data = ...  # some tensor of data
pert_data = pert_model(data)  # applies optimized perturbation to `data`
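Concretely, applying a solved perturbation is just a forward pass. The sketch below assumes an additive perturbation with a single \(\delta\) shared across the batch. The `AddDelta` class and its values are hypothetical stand-ins for a solved `pert_model`. Under this assumption the same stored \(\delta\) applies to new batches of any size, and wrapping the call in `torch.no_grad()` avoids building an autograd graph once the perturbation is no longer being optimized.

```python
import torch
from torch import nn

class AddDelta(nn.Module):
    # hypothetical additive perturbation model with one delta shared by the batch
    def __init__(self, shape):
        super().__init__()
        self.delta = nn.Parameter(torch.zeros(shape))

    def forward(self, x):
        return x + self.delta

pert_model = AddDelta((4,))
with torch.no_grad():
    # pretend these values were produced by the solver
    pert_model.delta.copy_(torch.tensor([0.1, -0.2, 0.0, 0.3]))

# the stored perturbation is reused as-is; no gradients are needed at this stage
with torch.no_grad():
    small_batch = pert_model(torch.zeros(2, 4))
    large_batch = pert_model(torch.randn(64, 4))
```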

The abstractions provided by a perturbation model and a perturbation optimizer yield a natural delegation of functionality, which makes it easy to modify the critical implementation details of this problem. For example, one can modify the optimizer to adjust how the perturbation is constrained or how its gradient is normalized, while the perturbation model controls the random initialization of the perturbation and how it broadcasts over a batch of data. None of these adjustments requires any change to the process by which we actually solve for the perturbations; i.e., we can continue to use standard_trainer or any other gradient-based solver.
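To illustrate this delegation, the sketch below solves the same problem twice with an unchanged trainer and perturbation model, swapping only the optimizer's constraint between an \(L^\infty\) ball and an \(L^2\) ball. `ProjectedSGD`, `linf_ball`, `l2_ball`, and `AddDelta` are hypothetical helpers written for this example, not rAI-toolbox classes. Here the perturbation model stores one \(\delta\) per sample, a choice that lives entirely in the module.

```python
import torch
import torch.nn.functional as F
from torch import nn

class AddDelta(nn.Module):
    # hypothetical perturbation model: one delta per sample in a fixed-size batch
    def __init__(self, shape):
        super().__init__()
        self.delta = nn.Parameter(torch.zeros(shape))

    def forward(self, x):
        return x + self.delta

class ProjectedSGD(torch.optim.SGD):
    # hypothetical perturbation optimizer: SGD step, then a user-supplied projection
    def __init__(self, params, lr, project):
        super().__init__(params, lr=lr)
        self.project = project

    @torch.no_grad()
    def step(self, closure=None):
        loss = super().step(closure)
        for group in self.param_groups:
            for p in group["params"]:
                p.copy_(self.project(p))
        return loss

def linf_ball(eps):
    # clip every entry into [-eps, eps]
    return lambda d: d.clamp(-eps, eps)

def l2_ball(eps):
    # per-sample projection onto the L2 ball of radius eps
    def project(d):
        scale = (eps / d.flatten(1).norm(p=2, dim=1).clamp_min(1e-12)).clamp(max=1.0)
        return d * scale.view(-1, *([1] * (d.dim() - 1)))
    return project

def standard_trainer(model, data, target, optimizer, steps, criterion):
    # unchanged trainer loop
    for _ in range(steps):
        loss = criterion(model(data), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def negated_ce(logits, y):
    return -F.cross_entropy(logits, y)

torch.manual_seed(0)
ml_model = nn.Linear(4, 3).requires_grad_(False)  # toy frozen classifier
data = torch.randn(8, 4)
target = torch.randint(0, 3, (8,))

results = {}
for name, project in [("linf", linf_ball(0.1)), ("l2", l2_ball(0.1))]:
    pert_model = AddDelta(data.shape)
    optim = ProjectedSGD(pert_model.parameters(), lr=0.5, project=project)
    standard_trainer(nn.Sequential(pert_model, ml_model), data, target, optim, 10, negated_ce)
    results[name] = pert_model.delta.detach()
```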

ParamTransformingOptimizer, AdditivePerturbation, and gradient_ascent represent concrete implementations of this design; the reader is advised to consult their reference documentation for further insights into the rAI-toolbox’s approach to solving for data perturbations.