rai_toolbox.perturbations.AdditivePerturbation
- class rai_toolbox.perturbations.AdditivePerturbation(data_or_shape, init_fn=None, *, device=None, dtype=None, delta_ndim=None, **init_fn_kwargs)[source]
Modifies a piece or batch of data by adding a perturbation: \(x \rightarrow x+\delta\).
- Attributes:
    - delta: torch.Tensor
        \(\delta\) is the sole trainable parameter of AdditivePerturbation.
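A minimal sketch of the model's behavior, assuming that omitting init_fn leaves \(\delta\) zero-initialized: the forward pass simply returns x + delta.

>>> import torch as tr
>>> from rai_toolbox.perturbations import AdditivePerturbation
>>> pert_model = AdditivePerturbation(data_or_shape=tr.ones(2))
>>> isinstance(pert_model.delta, tr.nn.Parameter)  # the sole trainable parameter
True
>>> out = pert_model(tr.ones(2))  # computes x + delta
>>> out.shape
torch.Size([2])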
- __init__(data_or_shape, init_fn=None, *, device=None, dtype=None, delta_ndim=None, **init_fn_kwargs)[source]
init_fn should accept the zero'd perturbation tensor as its argument and initialize the perturbation in-place.
- Parameters:
    - data_or_shape: Union[Tensor, Tuple[int, ...]]
        Determines the shape of the perturbation. If a tensor is supplied, its dtype and device are mirrored by the initialized perturbation. This parameter can be modified to control whether the perturbation adds elementwise or broadcast-adds over x.
    - init_fn: Optional[Callable[[Tensor], None]]
        Operates in-place on a zero'd perturbation tensor to determine the final initialization of the perturbation.
    - device: Optional[tr.device]
        If specified, takes precedence over the device associated with data_or_shape.
    - dtype: Optional[tr.dtype] = None
        If specified, takes precedence over the dtype associated with data_or_shape.
    - delta_ndim: Optional[int] = None
        If a positive number, determines the dimensionality of the perturbation. If a negative number, indicates the 'offset' from the dimensionality of data_or_shape. E.g., if data_or_shape has a shape (N, C, H, W) and delta_ndim=-1, then the perturbation will have shape (C, H, W) and will be applied in a broadcast fashion.
    - **init_fn_kwargs: Any
        Keyword arguments passed to init_fn.
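As a sketch of the init_fn contract described above: any in-place function of the perturbation tensor can be supplied, with extra keyword arguments forwarded via **init_fn_kwargs. Here small_normal_init_ is a hypothetical initializer, not part of the toolbox.

>>> import torch as tr
>>> from rai_toolbox.perturbations import AdditivePerturbation
>>> def small_normal_init_(delta, std=0.01):  # hypothetical init_fn
...     with tr.no_grad():
...         delta.normal_(mean=0.0, std=std)  # modify the zero'd tensor in-place
>>> pert_model = AdditivePerturbation(
...     data_or_shape=(3, 2, 2),
...     init_fn=small_normal_init_,
...     std=0.05,  # forwarded to init_fn via **init_fn_kwargs
... )
>>> pert_model.delta.shape
torch.Size([3, 2, 2])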
Examples
Basic Additive Perturbations
Let's imagine we have a batch of three shape-(2, 2) images (our toy data will be all ones) that we want to perturb. We'll randomly initialize a shape-(3, 2, 2) tensor of perturbations to apply additively to the shape-(3, 2, 2) batch.

>>> import torch as tr
>>> from rai_toolbox.perturbations import AdditivePerturbation, uniform_like_l1_n_ball_
>>> data = tr.ones(3, 2, 2)
We provide a generator argument to control the RNG in uniform_like_l1_n_ball_.

>>> pert_model = AdditivePerturbation(
...     data_or_shape=data,
...     init_fn=uniform_like_l1_n_ball_,
...     generator=tr.Generator().manual_seed(0),  # controls RNG of init
... )
Accessing the initialized perturbations.
>>> pert_model.delta
Parameter containing:
tensor([[[0.0885, 0.0436],
         [0.3642, 0.2720]],

        [[0.3074, 0.1827],
         [0.1440, 0.2624]],

        [[0.3489, 0.0528],
         [0.0539, 0.1767]]], requires_grad=True)
Applying the perturbations to a batch of data.
>>> pert_data = pert_model(data)
>>> pert_data
tensor([[[1.0885, 1.0436],
         [1.3642, 1.2720]],

        [[1.3074, 1.1827],
         [1.1440, 1.2624]],

        [[1.3489, 1.0528],
         [1.0539, 1.1767]]], grad_fn=<AddBackward0>)
When the perturbed data is involved in a computational graph over which auto-diff is performed, gradients are computed for the perturbations.
>>> (pert_data ** 2).sum().backward()
>>> pert_model.delta.grad
tensor([[[2.1770, 2.0871],
         [2.7285, 2.5439]],

        [[2.6148, 2.3653],
         [2.2880, 2.5247]],

        [[2.6978, 2.1056],
         [2.1078, 2.3534]]])
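Because delta is a standard trainable parameter, it can be updated by any gradient-based optimizer. A minimal sketch using plain torch.optim.SGD as a stand-in for the toolbox's specialized optimizers:

>>> opt = tr.optim.SGD(pert_model.parameters(), lr=0.1)
>>> opt.zero_grad()
>>> loss = (pert_model(data) ** 2).sum()
>>> loss.backward()
>>> opt.step()  # updates pert_model.delta in-place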
Broadcasted (“Universal”) Perturbations
Suppose that we want to use a single shape-(2, 2) tensor to perturb each datum in a batch. We can create a perturbation model in a similar manner, but specifying delta_ndim=-1 indicates that our perturbation should have one fewer dimension than our data; whereas our batch has shape (N, 2, 2), our perturbation model's parameter will have shape (2, 2).

>>> pert_model = AdditivePerturbation(
...     data_or_shape=data,
...     delta_ndim=-1,
...     init_fn=uniform_like_l1_n_ball_,
...     generator=tr.Generator().manual_seed(1),  # controls RNG of init
... )
>>> pert_model.delta
Parameter containing:
tensor([[0.2793, 0.4783],
        [0.4031, 0.3316]], requires_grad=True)
Perturbing a batch of data now performs broadcast-addition of this tensor over the batch.
>>> pert_data = pert_model(data)
>>> pert_data
tensor([[[1.2793, 1.4783],
         [1.4031, 1.3316]],

        [[1.2793, 1.4783],
         [1.4031, 1.3316]],

        [[1.2793, 1.4783],
         [1.4031, 1.3316]]], grad_fn=<AddBackward0>)
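Each datum in the batch receives an identical perturbation:

>>> tr.equal(pert_data[0], pert_data[1]) and tr.equal(pert_data[1], pert_data[2])
True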
Important
Downstream reductions of this broadcast-perturbed data should involve a mean – not a sum – over the batch dimension so that the resulting gradient computed for the perturbation is not scaled by batch-size.
>>> (pert_data ** 2).mean().backward()
>>> pert_model.delta.grad
tensor([[0.6397, 0.7392],
        [0.7015, 0.6658]])
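We can check the scaling concretely: a sum-reduction produces a gradient that is larger by a factor of the number of averaged elements, data.numel() == 12, which includes the batch dimension.

>>> grad_from_mean = pert_model.delta.grad.clone()
>>> pert_model.delta.grad = None
>>> (pert_model(data) ** 2).sum().backward()
>>> tr.allclose(pert_model.delta.grad, grad_from_mean * data.numel())
True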
Similarly, when using a ParamTransformingOptimizer to optimize this broadcasted perturbation, we should specify param_ndim=None to ensure that the parameter transformations are not broadcasted over our perturbation tensor and/or its gradient, as it has no batch dimension.
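A sketch of this pattern, assuming the InnerOpt and param_ndim arguments documented for the toolbox's optimizers:

>>> from rai_toolbox.optim import ParamTransformingOptimizer
>>> optim = ParamTransformingOptimizer(
...     pert_model.parameters(),
...     InnerOpt=tr.optim.SGD,
...     param_ndim=None,  # transform the whole (2, 2) parameter; it has no batch dim
...     lr=0.1,
... )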
Methods

- __init__(data_or_shape[, init_fn, device, ...])
    init_fn should accept the zero'd perturbation tensor as its argument and initialize the perturbation in-place.