rai_toolbox.augmentations.augmix.Fork#

class rai_toolbox.augmentations.augmix.Fork(*forked_transforms)[source]#

Forks an input into an arbitrary number of transform chains, returning one output per chain. This is useful for consistency-loss workflows, which require multiple views of the same input.

Parameters:
*forked_transforms: Callable[[Any], Any]

One transform for each fork to create.

Examples

>>> from rai_toolbox.augmentations.augmix import Fork

Here are some trivial examples:

>>> two_fork = Fork(lambda x: x, lambda x: 2 * x)
>>> two_fork(2)
(2, 4)
>>> three_fork = Fork(lambda x: x, lambda x: 2 * x, lambda x: 0 * x)
>>> three_fork(-1.0)
(-1.0, -2.0, -0.0)

Here is a simplified version of the triplet-processing used by the AugMix paper's consistency loss. It expects a PIL image and produces a triplet of outputs.

>>> from torchvision.transforms import ToTensor, RandomHorizontalFlip, RandomVerticalFlip
>>> from rai_toolbox.augmentations.augmix import AugMix
>>> augmix = AugMix(
...     augmentations=[RandomHorizontalFlip(), RandomVerticalFlip()],
...     process_fn=ToTensor(),
... )
>>> Fork(augmix, augmix, ToTensor())  # calling this on x gives: (augmix(x), augmix(x), ToTensor()(x))
Fork(
      - AugMix(...) ->
x --> - AugMix(...) ->
      - ToTensor() ->
)
__init__(*forked_transforms)[source]#

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Methods

__init__(*forked_transforms)

Initializes internal Module state, shared by both nn.Module and ScriptModule.