rai_toolbox.perturbations.AdditivePerturbation#

class rai_toolbox.perturbations.AdditivePerturbation(data_or_shape, init_fn=None, *, device=None, dtype=None, delta_ndim=None, **init_fn_kwargs)[source]#

Modifies a piece or batch of data by adding a perturbation: \(x \rightarrow x+\delta\).

Attributes:
delta: torch.Tensor

\(\delta\) is the sole trainable parameter of AdditivePerturbation.
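Because delta is stored as a torch.nn.Parameter, a perturbation model can be handed directly to any optimizer. A minimal sketch, assuming here that AdditivePerturbation behaves as a standard torch.nn.Module whose only parameter is delta, as the examples below suggest:

>>> from rai_toolbox.perturbations import AdditivePerturbation
>>> pert_model = AdditivePerturbation(data_or_shape=(2, 2))  # no init_fn: delta starts zeroed
>>> [p.shape for p in pert_model.parameters()]  # delta is the only parameter
[torch.Size([2, 2])]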

__init__(data_or_shape, init_fn=None, *, device=None, dtype=None, delta_ndim=None, **init_fn_kwargs)[source]#

If provided, the init function should accept the zeroed perturbation tensor as its argument and initialize it in-place.

Parameters:
data_or_shape: Union[Tensor, Tuple[int, …]]

Determines the shape of the perturbation. If a tensor is supplied, its dtype and device are mirrored by the initialized perturbation.

In conjunction with delta_ndim, this determines whether the perturbation is added elementwise to x or broadcast over it.

init_fn: Optional[Callable[[Tensor], None]]

Operates in-place on a zeroed perturbation tensor to determine the final initialization of the perturbation (see the sketch following this parameter list).

device: Optional[tr.device]

If specified, takes precedence over the device associated with data_or_shape.

dtype: Optional[tr.dtype]

If specified, takes precedence over the dtype associated with data_or_shape.

delta_ndim: Optional[int] = None

If a positive number, determines the dimensionality of the perturbation. If a negative number, indicates the ‘offset’ from the dimensionality of data_or_shape. E.g., if data_or_shape has a shape (N, C, H, W), and if delta_ndim=-1, then the perturbation will have shape (C, H, W), and will be applied in a broadcast fashion.

**init_fn_kwargs: Any

Keyword arguments passed to init_fn.
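To make these parameters concrete, here is a minimal sketch (the init function normal_init_ is hypothetical, not part of the toolbox) that passes a shape tuple rather than a tensor; with no tensor to mirror, the explicit dtype argument determines the perturbation's dtype:

>>> import torch as tr
>>> from rai_toolbox.perturbations import AdditivePerturbation

>>> def normal_init_(delta: tr.Tensor) -> None:
...     # hypothetical init_fn: fills the zeroed delta in-place
...     delta.normal_(mean=0.0, std=0.1)

>>> pert_model = AdditivePerturbation(
...     data_or_shape=(3, 2, 2),  # a shape tuple: there is no tensor to mirror
...     init_fn=normal_init_,
...     dtype=tr.float64,         # determines delta's dtype
... )
>>> pert_model.delta.dtype
torch.float64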

Examples

Basic Additive Perturbations

Let’s imagine we have a batch of three shape-(2, 2) images (our toy data will be all ones) that we want to perturb. We’ll randomly initialize a shape-(3, 2, 2) tensor of perturbations to apply additively to the shape-(3, 2, 2) batch.

>>> import torch as tr
>>> from rai_toolbox.perturbations import AdditivePerturbation, uniform_like_l1_n_ball_
>>> data = tr.ones(3, 2, 2)

We provide a generator argument, which is passed through to uniform_like_l1_n_ball_ via **init_fn_kwargs, to control its RNG.

>>> pert_model = AdditivePerturbation(
...     data_or_shape=data,
...     init_fn=uniform_like_l1_n_ball_,
...     generator=tr.Generator().manual_seed(0),  # controls RNG of init
... )

Accessing the initialized perturbations.

>>> pert_model.delta
Parameter containing:
tensor([[[0.0885, 0.0436],
         [0.3642, 0.2720]],

        [[0.3074, 0.1827],
         [0.1440, 0.2624]],

        [[0.3489, 0.0528],
         [0.0539, 0.1767]]], requires_grad=True)

Applying the perturbations to a batch of data.

>>> pert_data = pert_model(data)
>>> pert_data
tensor([[[1.0885, 1.0436],
         [1.3642, 1.2720]],

        [[1.3074, 1.1827],
         [1.1440, 1.2624]],

        [[1.3489, 1.0528],
         [1.0539, 1.1767]]], grad_fn=<AddBackward0>)

When the perturbed data is involved in a computational graph on which auto-diff is performed, gradients are computed for the perturbations.

>>> (pert_data ** 2).sum().backward()
>>> pert_model.delta.grad
tensor([[[2.1770, 2.0871],
         [2.7285, 2.5439]],

        [[2.6148, 2.3653],
         [2.2880, 2.5247]],

        [[2.6978, 2.1056],
         [2.1078, 2.3534]]])
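These gradients can drive a standard optimization loop in which only delta is updated and the data itself is left untouched. A minimal sketch using plain torch.optim.SGD (the toolbox's ParamTransformingOptimizer family, discussed below, serves the same role):

>>> optim = tr.optim.SGD(pert_model.parameters(), lr=0.1)
>>> optim.zero_grad()
>>> loss = (pert_model(data) ** 2).sum()
>>> loss.backward()
>>> optim.step()  # updates delta; data is untouched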

Broadcasted (“Universal”) Perturbations

Suppose that we want to use a single shape-(2, 2) tensor to perturb each datum in a batch. We can create a perturbation model in a similar manner; specifying delta_ndim=-1 indicates that our perturbation should have one fewer dimension than our data: whereas our batch has shape (N, 2, 2), our perturbation model's parameter will have shape (2, 2).

>>> pert_model = AdditivePerturbation(
...     data_or_shape=data,
...     delta_ndim=-1,
...     init_fn=uniform_like_l1_n_ball_,
...     generator=tr.Generator().manual_seed(1),  # controls RNG of init
... )
>>> pert_model.delta
Parameter containing:
tensor([[0.2793, 0.4783],
        [0.4031, 0.3316]], requires_grad=True)

Perturbing a batch of data now performs broadcast-addition of this tensor over the batch.

>>> pert_data = pert_model(data)
>>> pert_data
tensor([[[1.2793, 1.4783],
         [1.4031, 1.3316]],

        [[1.2793, 1.4783],
         [1.4031, 1.3316]],

        [[1.2793, 1.4783],
         [1.4031, 1.3316]]], grad_fn=<AddBackward0>)

Important

Downstream reductions of this broadcast-perturbed data should involve a mean, not a sum, over the batch dimension so that the resulting gradient computed for the perturbation is not scaled by the batch size.

>>> (pert_data ** 2).mean().backward()
>>> pert_model.delta.grad
tensor([[0.6397, 0.7392],
        [0.7015, 0.6658]])
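To verify this scaling concretely, the following sketch reduces the same perturbed batch two ways: a mean over the batch dimension leaves the gradient independent of batch size, whereas a sum scales it by the batch size (here, 3).

>>> grads = {}
>>> for name, reduce_fn in [("batch-mean", lambda t: t.mean(dim=0).sum()),
...                         ("batch-sum", lambda t: t.sum())]:
...     pert_model.delta.grad = None  # discard previously-accumulated gradients
...     reduce_fn(pert_model(data) ** 2).backward()
...     grads[name] = pert_model.delta.grad.clone()
>>> tr.allclose(grads["batch-sum"], 3 * grads["batch-mean"])
True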

Similarly, when using a ParamTransformingOptimizer to optimize this broadcasted perturbation, we should specify param_ndim=None to ensure that the parameter transformations are not broadcasted over our perturbation tensor and/or its gradient, as it has no batch dimension.
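For instance, a sketch using rai_toolbox.optim.L2ProjectedOptim (assuming its documented epsilon and param_ndim keyword arguments): param_ndim=None tells the optimizer to apply its transformation to the whole shape-(2, 2) delta, rather than broadcasting it over a leading batch dimension that this universal perturbation does not have.

>>> from rai_toolbox.optim import L2ProjectedOptim
>>> optim = L2ProjectedOptim(
...     pert_model.parameters(),
...     epsilon=1.0,      # radius of the L2 ball onto which delta is projected
...     param_ndim=None,  # delta has no batch dim: transform the whole tensor
...     lr=0.1,
... )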

Methods

__init__(data_or_shape[, init_fn, device, ...])

If provided, the init function should accept the zeroed perturbation tensor as its argument and initialize it in-place.