rai_toolbox.augmentations.augmix.Fork
- class rai_toolbox.augmentations.augmix.Fork(*forked_transforms)
Forks an input into an arbitrary number of transform chains and returns a tuple holding one output per fork. This is useful for consistency-loss workflows.
- Parameters:
- *forked_transforms: Callable[[Any], Any]
One transform for each fork to create.
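Conceptually, each fork receives the same, unmodified input, and the outputs are collected in the order the transforms were passed. The following is a minimal plain-Python sketch of that behavior, assuming only the semantics shown in the examples below. `SimpleFork` is a hypothetical name used for illustration; it is not the rai_toolbox implementation.

```python
from typing import Any, Callable, Tuple


class SimpleFork:
    """Hypothetical stand-in illustrating the forking idea (not rai_toolbox's Fork)."""

    def __init__(self, *forked_transforms: Callable[[Any], Any]) -> None:
        # Store the transforms in the order given; output order follows it.
        self.forked_transforms: Tuple[Callable[[Any], Any], ...] = forked_transforms

    def __call__(self, x: Any) -> Tuple[Any, ...]:
        # Every branch sees the original input, not the previous branch's output.
        return tuple(f(x) for f in self.forked_transforms)


if __name__ == "__main__":
    fork = SimpleFork(lambda x: x, lambda x: 2 * x, lambda x: x + 1)
    print(fork(3))  # (3, 6, 4)
```

Because the branches are independent rather than chained, the same stochastic transform can be passed twice to obtain two different random views of one input, which is what the AugMix example below relies on.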
Examples
>>> from rai_toolbox.augmentations.augmix import Fork
Here are some trivial examples:
>>> two_fork = Fork(lambda x: x, lambda x: 2 * x)
>>> two_fork(2)
(2, 4)
>>> three_fork = Fork(lambda x: x, lambda x: 2 * x, lambda x: 0 * x)
>>> three_fork(-1.0)
(-1.0, -2.0, -0.0)
Here is a simplified version of the triple processing used by the AugMix paper’s consistency loss. It expects a PIL image as input and produces a triplet.
>>> from torchvision.transforms import ToTensor, RandomHorizontalFlip, RandomVerticalFlip
>>> from rai_toolbox.augmentations.augmix import AugMix
>>> augmix = AugMix(
...     augmentations=[RandomHorizontalFlip(), RandomVerticalFlip()],
...     process_fn=ToTensor(),
... )
>>> Fork(augmix, augmix, ToTensor())
Fork(
      - ToTensor() ->
x --> - AugMix(...) ->
      - AugMix(...) ->
)
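In the AugMix paper, the triplet is consumed by a Jensen-Shannon consistency loss: the model's predicted distributions for the clean view and the two AugMix views are averaged into a mixture M, and the loss is the mean KL divergence of each distribution from M. The sketch below shows that computation in plain Python on a single example's logits. The helper names `softmax`, `kl_divergence`, and `jsd_consistency` are hypothetical and are not part of rai_toolbox; in practice the loss would be computed on batched model outputs and added, with a weighting factor, to the usual classification loss.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    # KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i == 0 contribute 0.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def jsd_consistency(*logit_sets: Sequence[float]) -> float:
    """Jensen-Shannon divergence among the predictions for each forked view."""
    probs = [softmax(logits) for logits in logit_sets]
    n = len(probs)
    # Mixture distribution M: the element-wise average of the per-view distributions.
    mixture = [sum(col) / n for col in zip(*probs)]
    # Mean KL divergence of each view's distribution from the mixture.
    return sum(kl_divergence(p, mixture) for p in probs) / n


if __name__ == "__main__":
    # Agreeing predictions give a loss near zero; disagreeing ones give a larger loss.
    print(jsd_consistency([2.0, 0.5, -1.0], [2.0, 0.5, -1.0], [2.0, 0.5, -1.0]))
    print(jsd_consistency([2.0, 0.5, -1.0], [-1.0, 2.0, 0.5], [0.5, -1.0, 2.0]))
```

The loss is zero only when all views yield the same distribution, so minimizing it pushes the model toward predictions that are stable under the augmentations.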
- __init__(*forked_transforms)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Methods
__init__(*forked_transforms)    Initializes internal Module state, shared by both nn.Module and ScriptModule.