rai_toolbox.augmentations.augmix.augment_and_mix
- rai_toolbox.augmentations.augmix.augment_and_mix(datum, *, process_fn, augmentations, num_aug_chains=3, aug_chain_depth=(1, 4), beta_params=(1.0, 1.0), dirichlet_params=1.0, augmentation_choice_probs=None)
Augments datum using a mixture of randomly-composed augmentations via a method called “AugMix” [1].
- Parameters:
- datum: T1
The datum to be augmixed.
- num_aug_chains: int
The number of independent “augmentation chains” that are used to produce the augmixed result.
- aug_chain_depth: Union[int, Tuple[int, int]]
Determines the number of augmentations that are randomly sampled (with replacement) and composed to form each augmentation chain.
If a tuple of values is provided, these are used as the lower (inclusive) and upper (exclusive) bounds of a uniform integer-valued distribution from which the depth is sampled for each augmentation chain.
E.g., aug_chain_depth=2 means that each augmentation chain will consist of two (randomly sampled) augmentations composed together. aug_chain_depth=(1, 4) means the depth of any given augmentation chain is uniformly sampled from [1, 4).
- process_fn: Callable[[T1], T2]
The preprocessing function applied to both the clean and augmixed datums before they are combined. The return type of process_fn determines the return type of augment_and_mix.
- augmentations: Sequence[Callable[[T1], T1]]
The collection of datum augmentations that is sampled from (with replacement) to form each augmentation chain.
- beta_params: Tuple[float, float]
The Beta distribution parameters used to draw m, which weights the convex combination:
(1 - m) * process_fn(datum) + m * process_fn(augment(datum))
If a single value is specified, it is used as both parameters for the distribution.
- dirichlet_params: Union[float, Sequence[float]]
The Dirichlet distribution parameters used to weight the num_aug_chains augmentation chains. If a sequence is provided, its length must match num_aug_chains.
- augmentation_choice_probs: Optional[Sequence[float]]
The probabilities associated with sampling each respective entry in augmentations. If not specified, a uniform distribution over all entries of augmentations is used.
- Returns:
- augmixed: T2
The augmixed datum.
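The way aug_chain_depth and augmentation_choice_probs shape each augmentation chain can be sketched in plain NumPy. This is only an illustration under stated assumptions: sample_chain and apply_chain are hypothetical helpers, not part of rai_toolbox, and the library's internal sampling may differ in detail.

```python
import numpy as np


def sample_chain(augmentations, aug_chain_depth=(1, 4), augmentation_choice_probs=None):
    """Hypothetical helper: pick the augmentations that make up one chain."""
    if isinstance(aug_chain_depth, int):
        depth = aug_chain_depth
    else:
        low, high = aug_chain_depth
        # np.random.randint excludes the upper bound, i.e. it samples from [low, high)
        depth = np.random.randint(low, high)
    # Sample indices with replacement; p=None means a uniform choice
    idx = np.random.choice(
        len(augmentations), size=depth, replace=True, p=augmentation_choice_probs
    )
    return [augmentations[i] for i in idx]


def apply_chain(chain, datum):
    """Hypothetical helper: compose the sampled augmentations in order."""
    for aug in chain:
        datum = aug(datum)
    return datum


np.random.seed(0)
# Toy augmentations acting on floats (stand-ins for real image transforms)
augs = [lambda x: x + 1.0, lambda x: x * 2.0, lambda x: -x]
chains = [sample_chain(augs) for _ in range(3)]
depths = [len(c) for c in chains]  # each depth lies in [1, 4)
```

With the default aug_chain_depth=(1, 4), every sampled chain contains one, two, or three augmentations; passing an integer fixes the depth instead.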
Notes
The following depicts AugMix with N augmentation chains. Each augchain(...) consists of composed augmentations, where the composition depth is determined by aug_chain_depth:
(1 - m) * process_fn(img) + m * (w1 * (process_fn ∘ augchain1)(img) + ... + wN * (process_fn ∘ augchainN)(img))
with
m ~ Beta
[w1, …, wN] ~ Dirichlet
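The mixing formula can be written out as a short NumPy sketch. It assumes the chains are already built and passed in as callables; mix_chains is a hypothetical name used for illustration and is not the library's implementation.

```python
import numpy as np


def mix_chains(datum, process_fn, chains, beta_params=(1.0, 1.0), dirichlet_params=1.0):
    """Hypothetical sketch of the AugMix convex combination."""
    n = len(chains)
    # A scalar Dirichlet parameter is repeated once per chain;
    # a sequence of length n is used as-is
    alpha = np.full(n, dirichlet_params, dtype=float)
    m = np.random.beta(*beta_params)  # m ~ Beta
    w = np.random.dirichlet(alpha)    # [w1, ..., wN] ~ Dirichlet, sums to 1
    clean = process_fn(datum)
    # Weighted sum of the processed output of every chain
    mixed = sum(w_i * process_fn(chain(datum)) for w_i, chain in zip(w, chains))
    return (1 - m) * clean + m * mixed


np.random.seed(0)
datum = np.array([1.0, 2.0, 3.0])
process_fn = lambda x: x / 10.0  # toy preprocessing step
chains = [lambda x: x + 1.0, lambda x: x * 2.0, lambda x: x[::-1]]
result = mix_chains(datum, process_fn, chains)
```

Because the Dirichlet weights sum to one and m lies in [0, 1], the result is a convex combination of the processed clean datum and the processed chain outputs. If every chain were the identity, the output would equal process_fn(datum).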
Random values are drawn via NumPy’s global random number generator. Thus numpy.random.seed must be called in order to obtain reproducible results. Note that, until PyTorch 1.9.0, there was an issue with using NumPy’s global RNG in conjunction with DataLoaders that used multiple workers: identical seeds were used across workers, and the same seed was set at the outset of each epoch.
References
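A small sketch of what seeding NumPy's global RNG buys: the same seed yields the same Beta and Dirichlet draws. draw_mixing_values and seed_worker are hypothetical helpers. seed_worker is one possible shape for a per-worker reseeding function of the kind that could be passed as a DataLoader's worker_init_fn on older PyTorch versions; as written it gives each worker a distinct seed but does not vary the seed across epochs unless base_seed is changed.

```python
import numpy as np


def draw_mixing_values(num_aug_chains=3):
    """Hypothetical helper: draw m and the chain weights from the global RNG."""
    m = np.random.beta(1.0, 1.0)
    w = np.random.dirichlet([1.0] * num_aug_chains)
    return m, w


# Seeding the global RNG before each call makes the draws repeatable
np.random.seed(42)
m1, w1 = draw_mixing_values()
np.random.seed(42)
m2, w2 = draw_mixing_values()
same_draws = (m1 == m2) and np.array_equal(w1, w2)


def seed_worker(worker_id, base_seed=0):
    """Hypothetical per-worker seeding: offset a base seed by the worker id."""
    np.random.seed(base_seed + worker_id)
```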
[1] Hendrycks, Dan, et al. “AugMix: A simple data processing method to improve robustness and uncertainty.” arXiv preprint arXiv:1912.02781 (2019).