rai_toolbox.augmentations.augmix.augment_and_mix

rai_toolbox.augmentations.augmix.augment_and_mix(datum, *, process_fn, augmentations, num_aug_chains=3, aug_chain_depth=(1, 4), beta_params=(1.0, 1.0), dirichlet_params=1.0, augmentation_choice_probs=None)

Augments datum using a mixture of randomly-composed augmentations via a method called “AugMix” [1].

Parameters:
datum: T1

The datum to be augmixed.

num_aug_chains: int

The number of independent “augmentation chains” that are used to produce the augmixed result.

aug_chain_depth: Union[int, Tuple[int, int]]

Determines the number of augmentations that are randomly sampled (with replacement) and composed to form each augmentation chain.

If a tuple of values is provided, they are used as the lower (inclusive) and upper (exclusive) bounds of a uniform integer-valued distribution from which the depth is sampled for each augmentation chain.

E.g.,

  • aug_chain_depth=2 means that each augmentation chain will consist of two (randomly sampled) augmentations composed together.

  • aug_chain_depth=(1, 4) means that the depth of any given augmentation chain is uniformly sampled from [1, 4), i.e., from {1, 2, 3}.
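The depth-sampling rule can be illustrated with a minimal sketch. The helper name sample_chain_depth is hypothetical and is not part of rai_toolbox; the sketch only assumes that the depth is drawn from NumPy's global RNG with an exclusive upper bound, as described above.

```python
import numpy as np


def sample_chain_depth(aug_chain_depth):
    """Hypothetical helper: resolve `aug_chain_depth` to a concrete depth."""
    if isinstance(aug_chain_depth, int):
        # A plain int means every chain has the same fixed depth.
        return aug_chain_depth
    low, high = aug_chain_depth
    # numpy.random.randint draws from [low, high): `high` is excluded.
    return int(np.random.randint(low, high))


np.random.seed(0)
depths = [sample_chain_depth((1, 4)) for _ in range(1000)]
fixed_depth = sample_chain_depth(2)
```

Under these assumptions, every entry of depths is 1, 2, or 3, and a depth of 4 is never produced.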

process_fn: Callable[[T1], T2]

The preprocessing function applied to both the clean and augmixed datums before they are combined. The return type of process_fn determines the return type of augment_and_mix.

augmentations: Sequence[Callable[[T1], T1]]

The collection of datum augmentations that is sampled from (with replacement) to form each augmentation chain.

beta_params: Tuple[float, float]

The parameters of the Beta distribution from which m is drawn; m weights the convex combination:

(1 - m) * process_fn(datum) + m * process_fn(augment(datum))

If a single value is specified, it is used as both parameters for the distribution.

dirichlet_params: Union[float, Sequence[float]]

The Dirichlet distribution parameters used to weight the num_aug_chains augmentation chains. If a sequence is provided, its length must match num_aug_chains.

augmentation_choice_probs: Optional[Sequence[float]]

The probabilities associated with sampling each respective entry in augmentations. If not specified, a uniform distribution over all entries of augmentations is used.
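As a rough illustration of how augmentations and augmentation_choice_probs interact, the sketch below samples one augmentation chain with replacement and composes it. The helper sample_chain and the toy augmentations add_one and double are hypothetical names introduced for this example; the library's actual implementation may differ.

```python
import numpy as np


def sample_chain(augmentations, depth, probs=None):
    """Hypothetical helper: draw `depth` augmentations with replacement and compose them."""
    # Sample indices rather than the callables themselves; p=None means uniform.
    idx = np.random.choice(len(augmentations), size=depth, replace=True, p=probs)
    chosen = [augmentations[i] for i in idx]

    def chain(x):
        # Apply the sampled augmentations in the order they were drawn.
        for aug in chosen:
            x = aug(x)
        return x

    return chain


def add_one(x):
    return x + 1


def double(x):
    return x * 2


np.random.seed(0)
# With probabilities [1.0, 0.0], only `add_one` can ever be selected.
chain = sample_chain([add_one, double], depth=3, probs=[1.0, 0.0])
result = chain(0)
```

Because the same entry can be drawn more than once, a chain of depth 3 here applies add_one three times.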

Returns:
augmixed: T2

The augmixed datum.

Notes

The following depicts AugMix with N augmentation chains. Each augchain(...) consists of composed augmentations, where the composition depth is determined by aug_chain_depth:

(1 - m) * process_fn(img) + m * (w1 * (process_fn ∘ augchain1)(img) + ... + wN * (process_fn ∘ augchainN)(img))

with

  • m ~ Beta

  • [w1, …, wN] ~ Dirichlet
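The expression above can be sketched in NumPy as follows. This is an illustrative re-implementation under stated assumptions, not the library's code: the function name augment_and_mix_sketch is hypothetical, the datum is assumed to be a NumPy array, and augmentations are sampled uniformly.

```python
import numpy as np


def augment_and_mix_sketch(img, process_fn, augmentations, num_aug_chains=3,
                           aug_chain_depth=(1, 4), beta_params=(1.0, 1.0),
                           dirichlet_params=1.0):
    """Illustrative sketch of the AugMix formula; not the rai_toolbox implementation."""
    # [w1, ..., wN] ~ Dirichlet; a scalar parameter is repeated for every chain.
    alpha = np.ones(num_aug_chains) * np.asarray(dirichlet_params, dtype=float)
    weights = np.random.dirichlet(alpha)
    # m ~ Beta
    m = np.random.beta(*beta_params)

    mixed = 0.0
    for w in weights:
        if isinstance(aug_chain_depth, int):
            depth = aug_chain_depth
        else:
            low, high = aug_chain_depth
            depth = np.random.randint(low, high)  # upper bound is exclusive
        x = img
        # Sample `depth` augmentations with replacement and compose them.
        for i in np.random.randint(0, len(augmentations), size=depth):
            x = augmentations[i](x)
        mixed = mixed + w * process_fn(x)

    # Convex combination of the clean and the mixed, augmented datum.
    return (1 - m) * process_fn(img) + m * mixed


np.random.seed(0)
img = np.arange(6, dtype=float).reshape(2, 3)
out = augment_and_mix_sketch(
    img,
    process_fn=lambda x: x / 255.0,
    augmentations=[lambda x: x],  # identity augmentation
)
```

Since the Dirichlet weights sum to one and m lies in [0, 1], using only an identity augmentation returns process_fn(img) unchanged, which is a convenient sanity check for the mixing arithmetic.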

Random values are drawn via NumPy’s global random number generator, so numpy.random.seed must be called beforehand in order to obtain reproducible results. Note that, prior to PyTorch 1.9.0, there was an issue with using NumPy’s global RNG in conjunction with DataLoaders that used multiple workers: identical seeds were used across workers, and the same seed was set at the outset of each epoch.

pytorch/pytorch#56488
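The effect of seeding the global RNG can be checked with a short sketch. The function draw_mixing_weights is a hypothetical stand-in for the random draws described above; it is not part of rai_toolbox.

```python
import numpy as np


def draw_mixing_weights(num_aug_chains=3):
    """Hypothetical stand-in for draws made from NumPy's global RNG."""
    m = np.random.beta(1.0, 1.0)
    w = np.random.dirichlet(np.ones(num_aug_chains))
    return m, w


# Re-seeding the global RNG with the same value reproduces the same draws.
np.random.seed(42)
m1, w1 = draw_mixing_weights()
np.random.seed(42)
m2, w2 = draw_mixing_weights()
```

Without the second call to numpy.random.seed, the two sets of draws would generally differ.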

References

[1]

Hendrycks, Dan, et al. “AugMix: A simple data processing method to improve robustness and uncertainty.” arXiv preprint arXiv:1912.02781 (2019).