Source code for rai_toolbox.augmentations.augmix._implementation

# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
# type: ignore
"""
An implementation of the AugMix augmentation method specified in:

    Hendrycks, Dan, et al. "Augmix: A simple data processing method to improve
    robustness and uncertainty." arXiv preprint arXiv:1912.02781 (2019).

with reference implementation from https://github.com/google-research/augmix
"""

from numbers import Real
from typing import Callable, Optional, Sequence, Tuple, TypeVar, Union

import numpy as np

from rai_toolbox._typing import ArrayLike

# Generic type of the raw datum: what the augmentations consume and produce,
# and what `process_fn` takes as input.
_T1 = TypeVar("_T1")
# Generic type of the processed datum; bounded by `ArrayLike`.
_T2 = TypeVar("_T2", bound=ArrayLike)

# Alias for a processing function that maps a raw datum (_T1) to a processed
# datum (_T2).
C1 = Callable[[_T1], _T2]
# Alias for a collection of augmentations, each mapping a raw datum to a raw
# datum of the same type.
C2 = Sequence[Callable[[_T1], _T1]]


def augment_and_mix(
    datum: _T1,
    *,
    process_fn: Callable[[_T1], _T2],
    augmentations: Sequence[Callable[[_T1], _T1]],
    num_aug_chains: int = 3,
    aug_chain_depth: Union[int, Tuple[int, int]] = (1, 4),
    beta_params: Union[float, Tuple[float, float]] = (1.0, 1.0),
    dirichlet_params: Union[float, Sequence[float]] = 1.0,
    augmentation_choice_probs: Optional[Sequence[float]] = None,
) -> _T2:
    """
    Augments `datum` using a mixture of randomly-composed augmentations via a method
    called "AugMix" [1]_.

    Parameters
    ----------
    datum : T1
        The datum to be augmixed.

    num_aug_chains : int
        The number of independent "augmentation chains" that are used to produce
        the augmixed result. Must be non-negative; if zero, no augmentation occurs.

    aug_chain_depth : Union[int, Tuple[int, int]]
        Determines the number of augmentations that are randomly sampled (with
        replacement) and composed to form each augmentation chain.

        If a tuple of values is provided, these are used as the lower (inclusive)
        and upper (exclusive) bounds on a uniform integer-valued distribution from
        which the depth will be sampled for each augmentation chain.

        E.g.,

        - `aug_chain_depth=2` means that each augmentation chain will consist of
          two (randomly sampled) augmentations composed together.
        - `aug_chain_depth=(1, 4)` means the depth of any given augmentation chain
          is uniformly sampled from [1, 4).

    process_fn : Callable[[T1], T2]
        The preprocessing function applied to both the clean and the augmixed
        datums before they are combined. The return type of `process_fn`
        determines the return type of `augment_and_mix`.

    augmentations : Sequence[Callable[[T1], T1]]
        The collection of datum augmentations that is sampled from (with
        replacement) to form each augmentation chain.

    beta_params : Union[float, Tuple[float, float]]
        The Beta distribution parameters used to draw `m`, which weights the
        convex combination::

            (1 - m) * process_fn(datum) + m * process_fn(augment(datum))

        If a single value is specified, it is used as both parameters for the
        distribution.

    dirichlet_params : Union[float, Sequence[float]]
        The Dirichlet distribution parameters used to weight the `num_aug_chains`
        augmentation chains. If a sequence is provided, its length must match
        `num_aug_chains`.

    augmentation_choice_probs : Optional[Sequence[float]]
        The probabilities associated with sampling each respective entry in
        `augmentations`. The values are normalized to sum to one. If not
        specified, a uniform distribution over all entries of `augmentations`
        is used.

    Returns
    -------
    augmixed : T2
        The augmixed datum.

    Raises
    ------
    ValueError
        If `num_aug_chains` is negative, or if `dirichlet_params` is a sequence
        whose length does not match `num_aug_chains`.

    Notes
    -----
    The following depicts AugMix with N augmentation chains. Each `augchain(...)`
    consists of composed augmentations, where the composition depth is determined
    by `aug_chain_depth`::

        (1 - m) * process_fn(img) +
        m * (w1 * (process_fn ∘ augchain1)(img) + ... + wN * (process_fn ∘ augchainN)(img))

    with

    - m ~ Beta
    - [w1, ..., wN] ~ Dirichlet

    Random values are drawn via NumPy's global random number generator. Thus
    `numpy.random.seed` must be set in order to obtain reproducible results.

    Note that, until PyTorch 1.9.0, there was an issue with using NumPy's global
    RNG in conjunction with DataLoaders that used multiple workers, where identical
    seeds were being used across workers and the same seed was being set at the
    outset of each epoch.

    https://github.com/pytorch/pytorch/pull/56488

    References
    ----------
    .. [1] Hendrycks, Dan, et al. "Augmix: A simple data processing method to
           improve robustness and uncertainty." arXiv preprint arXiv:1912.02781
           (2019).
    """
    # See: https://numpy.org/doc/stable/reference/random/index.html?#module-numpy.random
    # for details about new rng methods in numpy

    # A negative chain count would otherwise (for an integer depth) silently
    # return a down-weighted clean datum with no augmented contribution.
    if num_aug_chains < 0:
        raise ValueError(
            f"`num_aug_chains` must be a non-negative integer, got {num_aug_chains}"
        )

    if isinstance(beta_params, Real):
        beta_params = (float(beta_params), float(beta_params))

    if isinstance(dirichlet_params, Real):
        dirichlet_params = [float(dirichlet_params)] * num_aug_chains
    elif len(dirichlet_params) != num_aug_chains:
        raise ValueError(
            f"The number of specified dirichlet parameters ({len(dirichlet_params)}) "
            f"does not match the mixture width ({num_aug_chains})"
        )

    if num_aug_chains:
        mix_weight = float(np.random.beta(*beta_params))
    else:
        # no augmentations occur
        mix_weight = 0.0

    if augmentation_choice_probs is not None:
        augmentation_choice_probs = np.array(augmentation_choice_probs, dtype=float)
        augmentation_choice_probs /= augmentation_choice_probs.sum()

    width_weights = np.asarray(np.random.dirichlet(dirichlet_params), dtype=np.float32)

    # mixture_depths[i]: number of augmentations applied to branch-i
    # of mixture graph
    if isinstance(aug_chain_depth, Sequence):
        mixture_depths: np.ndarray = np.random.randint(
            *aug_chain_depth, size=num_aug_chains
        )
    else:
        mixture_depths = np.full(width_weights.shape, aug_chain_depth)

    num_augmentations = len(augmentations)

    out: _T2 = (1.0 - mix_weight) * process_fn(datum)

    for depth, weight in zip(mixture_depths, width_weights):
        # augmentations must not mutate original datum
        augmented_datum = datum

        # Sample *indices* rather than the callables themselves: passing the
        # callables to `np.random.choice` coerces them into an ndarray, which
        # fails for augmentations that are themselves sequence-like. Sampling
        # from an integer population draws the same random indices.
        for op_index in np.random.choice(
            num_augmentations, size=depth, replace=True, p=augmentation_choice_probs
        ):
            op: Callable[[_T1], _T1] = augmentations[int(op_index)]
            augmented_datum = op(augmented_datum)

        out += (mix_weight * weight) * process_fn(augmented_datum)

    return out