
# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT

from numbers import Integral
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, Union

import torch as tr

from rai_toolbox._typing import ArrayLike

from ._implementation import augment_and_mix

_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2", bound=ArrayLike)

__all__ = ["AugMix", "Fork", "augment_and_mix"]


def _check_non_neg_int(item, name):
    if not isinstance(item, (int, Integral)) or item < 0:
        raise ValueError(f"`{name}` must be a non-negative integer. Got {item}")


class AugMix(tr.nn.Module):
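    # Reuse the functional implementation's docstring so the class and
    # `augment_and_mix` stay documented in one place.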
    __doc__ = augment_and_mix.__doc__

    def __init__(
        self,
        process_fn: Callable[[_T1], _T2],
        augmentations: Sequence[Callable[[_T1], _T1]],
        *,
        num_aug_chains: int = 3,
        aug_chain_depth: Union[int, Tuple[int, int]] = (1, 4),
        beta_params: Union[float, Tuple[float, float]] = (1.0, 1.0),
        dirichlet_params: Union[float, Sequence[float]] = 1.0,
        augmentation_choice_probs: Optional[Sequence[float]] = None,
    ):
        super().__init__()
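        # Validate hyperparameters eagerly so a misconfigured AugMix fails at
        # construction time rather than mid-training.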
        _check_non_neg_int(num_aug_chains, "num_aug_chains")

        if isinstance(aug_chain_depth, Sequence):
            assert len(aug_chain_depth) == 2, aug_chain_depth
            _check_non_neg_int(aug_chain_depth[0], name="aug_chain_depth[0]")
            _check_non_neg_int(aug_chain_depth[1], name="aug_chain_depth[1]")
            assert aug_chain_depth[0] <= aug_chain_depth[1], aug_chain_depth

        if isinstance(dirichlet_params, Sequence):
            assert len(dirichlet_params) == num_aug_chains

        self.process_fn = process_fn
        self.augmentations = augmentations
        self.num_aug_chains = num_aug_chains
        self.aug_chain_depth = aug_chain_depth
        self.beta_params = beta_params
        self.dirichlet_params = dirichlet_params
        self.augmentation_choice_probs = augmentation_choice_probs

        if augmentation_choice_probs is not None and len(
            augmentation_choice_probs
        ) != len(augmentations):
            raise ValueError(
                f"`len(augmentation_choice_probs)` ({len(augmentation_choice_probs)}) "
                f"must match `len(augmentations)` ({len(augmentations)})"
            )

    def forward(self, datum):
        return augment_and_mix(
            datum=datum,
            process_fn=self.process_fn,
            augmentations=self.augmentations,
            num_aug_chains=self.num_aug_chains,
            aug_chain_depth=self.aug_chain_depth,
            beta_params=self.beta_params,
            dirichlet_params=self.dirichlet_params,
            augmentation_choice_probs=self.augmentation_choice_probs,
        )

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "(\naugmentations="
        for t in self.augmentations:
            format_string += "\n"
            format_string += f"    {t},"
        format_string += f"\nprocess_fn={self.process_fn},"
        format_string += f"\nnum_aug_chains={self.num_aug_chains},"
        format_string += f"\naug_chain_depth={self.aug_chain_depth},"
        format_string += f"\nbeta_params={self.beta_params},"
        format_string += f"\ndirichlet_params={self.dirichlet_params},"
        format_string += f"\naugmentation_choice_probs={self.augmentation_choice_probs},"
        format_string += "\n)"

        return format_string
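
# A minimal usage sketch (illustrative only, not part of this module): `AugMix`
# pairs a pool of augmentations with a `process_fn` that converts the augmented
# datum to a tensor. The torchvision transforms below are one plausible choice
# of augmentations; any callable mapping the input type to itself would do, and
# `pil_image` is a hypothetical PIL input.
#
#     >>> from torchvision.transforms import ToTensor, RandomHorizontalFlip, RandomVerticalFlip
#     >>> augmix = AugMix(
#     ...     process_fn=ToTensor(),
#     ...     augmentations=[RandomHorizontalFlip(), RandomVerticalFlip()],
#     ...     num_aug_chains=3,
#     ...     aug_chain_depth=(1, 4),
#     ... )
#     >>> mixed = augmix(pil_image)  # a convex mix of augmented views, as a tensor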


def _flat_repr(x) -> str:
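    """Collapse `x` to a single-line string: a multi-line representation is
    truncated to its first line plus '...)'."""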
    out = f"{x}".splitlines()
    if len(out) == 1:
        return out[0]
    else:
        return out[0] + "...)"


class Fork(tr.nn.Module):
    """Forks an input into an arbitrary number of transform-chains.

    This can be useful for doing consistency-loss workflows.

    Parameters
    ----------
    *forked_transforms : Callable[[Any], Any]
        One transform for each fork to create.

    Examples
    --------
    >>> from rai_toolbox.augmentations.augmix import Fork

    Here are some trivial examples:

    >>> two_fork = Fork(lambda x: x, lambda x: 2 * x)
    >>> two_fork(2)
    (2, 4)

    >>> three_fork = Fork(lambda x: x, lambda x: 2 * x, lambda x: 0 * x)
    >>> three_fork(-1.0)
    (-1.0, -2.0, -0.0)

    Here is a simplified version of the triple-processing used by the AugMix
    paper's consistency loss. It anticipates a PIL image and produces a triplet.

    >>> from torchvision.transforms import ToTensor, RandomHorizontalFlip, RandomVerticalFlip
    >>> from rai_toolbox.augmentations.augmix import AugMix
    >>> augmix = AugMix(
    ...     augmentations=[RandomHorizontalFlip(), RandomVerticalFlip()],
    ...     process_fn=ToTensor(),
    ... )
    >>> Fork(augmix, augmix, ToTensor())
    Fork(
        - AugMix(...) ->
    x --> - AugMix(...) ->
        - ToTensor() ->
    )
    """

    def __init__(self, *forked_transforms: Callable[[Any], Any]):
        super().__init__()

        if not forked_transforms:
            raise ValueError("At least one transform must be passed")

        if not all(callable(t) for t in forked_transforms):
            raise TypeError(
                f"All forked transforms must be callable, got: {forked_transforms}"
            )

        self.forked_transforms = forked_transforms

    def forward(self, x) -> Tuple[Any, ...]:
        return tuple(f(x) for f in self.forked_transforms)

    def __repr__(self) -> str:
        out = "Fork(\n"
        num_forks = len(self.forked_transforms)
        for n, f in enumerate(self.forked_transforms):
            if num_forks // 2 == n:
                if num_forks % 2 == 1:
                    out += f"x --> - {_flat_repr(f)} ->\n"
                else:
                    out += "x -->\n"
                    out += f"    - {_flat_repr(f)} ->\n"
            else:
                out += f"    - {_flat_repr(f)} ->\n"
        return out + ")"
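
# A hypothetical sketch (names are illustrative, not toolbox API) of how the
# triplet produced by `Fork(ToTensor(), augmix, augmix)` is typically consumed:
# the clean view and the two AugMix views pass through the same model, and
# their predictive distributions are tied together with a consistency loss
# (e.g., the Jensen-Shannon consistency loss from the AugMix paper).
#
#     >>> triplet_transform = Fork(ToTensor(), augmix, augmix)
#     >>> clean, aug1, aug2 = triplet_transform(pil_image)  # hypothetical PIL input
#     >>> # loss = task_loss(model(clean)) + jsd(model(clean), model(aug1), model(aug2))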