# Source code for rai_toolbox.perturbations.models

# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT

from abc import abstractmethod
from typing import Any, Callable, Iterator, Optional, Sequence, Union

import torch as tr
from torch import Tensor, nn
from torch.nn.parameter import Parameter
from typing_extensions import Protocol, runtime_checkable


@runtime_checkable
class PerturbationModel(Protocol):
    """Structural (duck-typed) interface for perturbation models.

    A perturbation model is a callable that maps a tensor of data to a
    same-shaped perturbed tensor, and exposes its trainable parameters via
    ``parameters`` — the same surface that ``torch.nn.Module`` provides.
    """

    def __init__(self, *args, **kwargs) -> None:  # pragma: no cover
        ...

    @abstractmethod
    def __call__(self, data: Tensor) -> Tensor:  # pragma: no cover
        """Produce the perturbed version of ``data``.

        Parameters
        ----------
        data: Tensor
            The data to perturb.

        Returns
        -------
        Tensor
            The perturbed data with the same shape as `data`.
        """

    def parameters(
        self, recurse: bool = True
    ) -> Iterator[Parameter]:  # pragma: no cover
        ...
class AdditivePerturbation(nn.Module, PerturbationModel):
    r"""Modifies a piece or batch of data by adding a perturbation:
    :math:`x \rightarrow x+\delta`.

    Attributes
    ----------
    delta : torch.Tensor
        :math:`\delta` is the sole trainable parameter of
        `AdditivePerturbation`.
    """

    def __init__(
        self,
        data_or_shape: Union[Tensor, Sequence[int]],
        init_fn: Optional[Callable[[Tensor], None]] = None,
        *,
        device: Optional[tr.device] = None,
        dtype: Optional[tr.dtype] = None,
        delta_ndim: Optional[int] = None,
        **init_fn_kwargs: Any,
    ) -> None:
        """Initialize the perturbation parameter ``delta``.

        Parameters
        ----------
        data_or_shape: Union[Tensor, Sequence[int]]
            Determines the shape of the perturbation. If a tensor is
            supplied, its dtype, device, and layout are mirrored by the
            initialized perturbation. Combined with `delta_ndim`, this
            controls whether the perturbation adds elementwise or
            broadcast-adds over the data.

        init_fn: Optional[Callable[[Tensor], None]]
            Operates in-place on a zero'd perturbation tensor to determine
            the final initialization of the perturbation.

        device: Optional[tr.device]
            If specified, takes precedent over the device associated with
            `data_or_shape`.

        dtype: Optional[tr.dtype]
            If specified, takes precedent over the dtype associated with
            `data_or_shape`.

        delta_ndim: Optional[int]
            If non-negative, determines the dimensionality of the
            perturbation (the trailing `delta_ndim` dims of the base shape).
            If negative, indicates the 'offset' from the dimensionality of
            `data_or_shape`. E.g., if `data_or_shape` has shape
            (N, C, H, W) and `delta_ndim=-1`, then the perturbation has
            shape (C, H, W) and is applied in a broadcast fashion.

        **init_fn_kwargs: Any
            Keyword arguments passed to `init_fn`.

        Raises
        ------
        TypeError
            Keyword arguments were supplied for `init_fn`, but no `init_fn`
            was provided.

        ValueError
            `delta_ndim` requests more dimensions than `data_or_shape`
            provides.

        Examples
        --------
        A shape-(3, 2, 2) batch perturbed elementwise:

        >>> import torch as tr
        >>> pert_model = AdditivePerturbation(data_or_shape=(3, 2, 2))
        >>> pert_model(tr.ones(3, 2, 2)).shape
        torch.Size([3, 2, 2])

        A "universal" perturbation — one fewer dimension than the data —
        broadcast-added over the batch:

        >>> pert_model = AdditivePerturbation((3, 2, 2), delta_ndim=-1)
        >>> pert_model.delta.shape
        torch.Size([2, 2])

        .. important::

           Downstream reductions of broadcast-perturbed data should use a
           mean — not a sum — over the batch dimension so that the gradient
           computed for the perturbation is not scaled by batch-size.
           Similarly, a `~rai_toolbox.optim.ParamTransformingOptimizer`
           optimizing a broadcast perturbation should be given
           `param_ndim=None`, since the parameter has no batch dimension.
        """
        super().__init__()

        # Mirror dtype/device/layout from a supplied tensor; an explicit
        # `device`/`dtype` argument overrides the mirrored values below.
        _init_kwargs = {}
        if isinstance(data_or_shape, tr.Tensor):
            shape = data_or_shape.shape
            _init_kwargs.update(
                {
                    "dtype": data_or_shape.dtype,
                    "device": data_or_shape.device,
                    "layout": data_or_shape.layout,
                }
            )
        else:
            shape = tuple(data_or_shape)

        if device is not None:
            _init_kwargs["device"] = device

        if dtype is not None:
            _init_kwargs["dtype"] = dtype

        if delta_ndim is not None:
            offset = len(shape) - delta_ndim if delta_ndim >= 0 else -delta_ndim
            # Guard against |delta_ndim| > len(shape): previously a negative
            # or overlarge offset was passed to the slice, silently yielding
            # a wrong-shaped perturbation instead of an error.
            if not 0 <= offset <= len(shape):
                raise ValueError(
                    f"`delta_ndim={delta_ndim}` is incompatible with a "
                    f"perturbation of base shape {tuple(shape)}."
                )
            shape = shape[offset:]

        self.delta = Parameter(tr.zeros(shape, **_init_kwargs))
        del _init_kwargs

        if init_fn is not None:
            init_fn(self.delta, **init_fn_kwargs)
        elif init_fn_kwargs:
            raise TypeError(
                f"No `init_fn` was specified, but the keyword arguments "
                f"{init_fn_kwargs} were provided."
            )

    def forward(self, data: Tensor) -> Tensor:
        """Add perturbation to data.

        Parameters
        ----------
        data: Tensor
            The data to perturb.

        Returns
        -------
        Tensor
            The perturbed data
        """
        return data + self.delta