# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
from abc import abstractmethod
from typing import Any, Callable, Iterator, Optional, Sequence, Union
import torch as tr
from torch import Tensor, nn
from torch.nn.parameter import Parameter
from typing_extensions import Protocol, runtime_checkable
[docs]
@runtime_checkable
class PerturbationModel(Protocol):
    """Structural interface for perturbation models.

    An object conforms when it can be called on a tensor and provides a
    `parameters` method. Because the protocol is runtime-checkable,
    `isinstance` checks only verify that these members are present.
    """

    def __init__(self, *args, **kwargs) -> None:  # pragma: no cover
        ...

    @abstractmethod
    def __call__(self, data: Tensor) -> Tensor:  # pragma: no cover
        """Perturb `data` and return the perturbed result.

        Parameters
        ----------
        data: Tensor
            The data to perturb.

        Returns
        -------
        Tensor
            The perturbed data with the same shape as `data`.
        """
        ...

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:  # pragma: no cover
        ...
[docs]
class AdditivePerturbation(nn.Module, PerturbationModel):
    r"""Modifies a piece or batch of data by adding a perturbation: :math:`x \rightarrow x+\delta`.

    Attributes
    ----------
    delta : torch.Tensor
        :math:`\delta` is the sole trainable parameter of `AdditivePerturbation`.
    """

    def __init__(
        self,
        data_or_shape: Union[Tensor, Sequence[int]],
        init_fn: Optional[Callable[[Tensor], None]] = None,
        *,
        device: Optional[tr.device] = None,
        dtype: Optional[tr.dtype] = None,
        delta_ndim: Optional[int] = None,
        **init_fn_kwargs: Any,
    ) -> None:
        """The init function should support data as the argument to initialize
        perturbations.

        Parameters
        ----------
        data_or_shape: Union[Tensor, Sequence[int]]
            Determines the shape of the perturbation. If a tensor is supplied, its
            dtype, device, and layout are mirrored by the initialized perturbation.

            This parameter can be modified to control whether the perturbation adds
            elementwise or broadcast-adds over `x`.

        init_fn: Optional[Callable[[Tensor], None]]
            Operates in-place on a zero'd perturbation tensor to determine the
            final initialization of the perturbation.

        device: Optional[tr.device]
            If specified, takes precedence over the device associated with
            `data_or_shape`.

        dtype: Optional[tr.dtype] = None
            If specified, takes precedence over the dtype associated with
            `data_or_shape`.

        delta_ndim: Optional[int] = None
            If a non-negative number, determines the dimensionality of the
            perturbation; the trailing `delta_ndim` dimensions of the shape are kept.

            If a negative number, indicates the 'offset' from the dimensionality of
            `data_or_shape`. E.g., if `data_or_shape` has a shape (N, C, H, W), and
            if `delta_ndim=-1`, then the perturbation will have shape (C, H, W),
            and will be applied in a broadcast fashion.

            Must satisfy `-ndim <= delta_ndim <= ndim`, where `ndim` is the number
            of dimensions described by `data_or_shape`.

        **init_fn_kwargs: Any
            Keyword arguments passed to `init_fn`.

        Raises
        ------
        ValueError
            If `delta_ndim` requests more dimensions than, or an offset larger
            than, the number of dimensions described by `data_or_shape`.

        TypeError
            If keyword arguments for `init_fn` are provided but `init_fn` is not.

        Examples
        --------
        **Basic Additive Perturbations**

        Let's imagine we have a batch of three shape-`(2, 2)` images (our toy data will
        be all ones) that we want to perturb. We'll randomly initialize a shape-`(3, 2,
        2)` tensor of perturbations to apply additively to the shape-`(3, 2, 2)` batch.

        >>> import torch as tr
        >>> from rai_toolbox.perturbations import AdditivePerturbation, uniform_like_l1_n_ball_
        >>> data = tr.ones(3, 2, 2)

        We provide a `generator` argument to control the RNG in
        `~rai_toolbox.perturbations.uniform_like_l1_n_ball_`.

        >>> pert_model = AdditivePerturbation(
        ...     data_or_shape=data,
        ...     init_fn=uniform_like_l1_n_ball_,
        ...     generator=tr.Generator().manual_seed(0),  # controls RNG of init
        ... )

        Accessing the initialized perturbations.

        >>> pert_model.delta
        Parameter containing:
        tensor([[[0.0885, 0.0436],
                 [0.3642, 0.2720]],
        .
                [[0.3074, 0.1827],
                 [0.1440, 0.2624]],
        .
                [[0.3489, 0.0528],
                 [0.0539, 0.1767]]], requires_grad=True)

        Applying the perturbations to a batch of data.

        >>> pert_data = pert_model(data)
        >>> pert_data
        tensor([[[1.0885, 1.0436],
                 [1.3642, 1.2720]],
        .
                [[1.3074, 1.1827],
                 [1.1440, 1.2624]],
        .
                [[1.3489, 1.0528],
                 [1.0539, 1.1767]]], grad_fn=<AddBackward0>)

        Involving the perturbed data in a computational graph where auto-diff is
        performed, then gradients are computed for the perturbations.

        >>> (pert_data ** 2).sum().backward()
        >>> pert_model.delta.grad
        tensor([[[2.1770, 2.0871],
                 [2.7285, 2.5439]],
        .
                [[2.6148, 2.3653],
                 [2.2880, 2.5247]],
        .
                [[2.6978, 2.1056],
                 [2.1078, 2.3534]]])

        **Broadcasted ("Universal") Perturbations**

        Suppose that we want to use a single shape-`(2, 2)` tensor to perturb each datum
        in a batch. We can create a perturbation model in a similar manner, but
        specifying `delta_ndim=-1` indicates that our perturbation should have one
        fewer dimension than our data; whereas our batch has shape-`(N, 2, 2)`, our
        perturbation model's parameter will have shape-`(2, 2)`

        >>> pert_model = AdditivePerturbation(
        ...     data_or_shape=data,
        ...     delta_ndim=-1,
        ...     init_fn=uniform_like_l1_n_ball_,
        ...     generator=tr.Generator().manual_seed(1),  # controls RNG of init
        ... )
        >>> pert_model.delta
        Parameter containing:
        tensor([[0.2793, 0.4783],
                [0.4031, 0.3316]], requires_grad=True)

        Perturbing a batch of data now performs broadcast-addition of this tensor over
        the batch.

        >>> pert_data = pert_model(data)
        >>> pert_data
        tensor([[[1.2793, 1.4783],
                 [1.4031, 1.3316]],
        .
                [[1.2793, 1.4783],
                 [1.4031, 1.3316]],
        .
                [[1.2793, 1.4783],
                 [1.4031, 1.3316]]], grad_fn=<AddBackward0>)

        .. important::

           Downstream reductions of this broadcast-perturbed data should involve
           a mean – not a sum – over the batch dimension so that the resulting gradient
           computed for the perturbation is not scaled by batch-size.

        >>> (pert_data ** 2).mean().backward()
        >>> pert_model.delta.grad
        tensor([[0.6397, 0.7392],
                [0.7015, 0.6658]])

        Similarly, when using a `~rai_toolbox.optim.ParamTransformingOptimizer` to
        optimize this broadcasted perturbation, we should specify `param_ndim=None`
        to ensure that the parameter transformations are not broadcasted over our
        perturbation tensor and/or its gradient, as it has no batch dimension.
        """
        super().__init__()

        # Tensor-creation options forwarded to `tr.zeros`.
        _init_kwargs = {}

        if isinstance(data_or_shape, tr.Tensor):
            shape = data_or_shape.shape
            # Mirror the properties of the supplied data by default.
            _init_kwargs.update(
                {
                    "dtype": data_or_shape.dtype,
                    "device": data_or_shape.device,
                    "layout": data_or_shape.layout,
                }
            )
        else:
            shape = tuple(data_or_shape)

        # Explicit arguments override whatever was mirrored from the data.
        if device is not None:
            _init_kwargs["device"] = device

        if dtype is not None:
            _init_kwargs["dtype"] = dtype

        if delta_ndim is not None:
            ndim = len(shape)
            # Without this check an out-of-range value produces a negative (or
            # oversized) slice offset, and slicing silently returns a shape that
            # does not match what was requested.
            if not -ndim <= delta_ndim <= ndim:
                raise ValueError(
                    f"`delta_ndim` must satisfy {-ndim} <= delta_ndim <= {ndim} "
                    f"for data of shape {tuple(shape)}, got: {delta_ndim}"
                )
            # Non-negative: keep the trailing `delta_ndim` dims.
            # Negative: drop the leading `abs(delta_ndim)` dims.
            offset = ndim - delta_ndim if delta_ndim >= 0 else -delta_ndim
            shape = shape[offset:]

        self.delta = Parameter(tr.zeros(shape, **_init_kwargs))
        del _init_kwargs

        if init_fn is not None:
            init_fn(self.delta, **init_fn_kwargs)
        elif init_fn_kwargs:
            # Keyword arguments with nowhere to go indicate a caller mistake.
            raise TypeError(
                f"No `init_fn` was specified, but the keyword arguments "
                f"{init_fn_kwargs} were provided."
            )

    def forward(self, data: Tensor) -> Tensor:
        """Add perturbation to data.

        Parameters
        ----------
        data: Tensor
            The data to perturb.

        Returns
        -------
        Tensor
            The perturbed data
        """
        return data + self.delta