Source code for rai_toolbox.augmentations.augmix.transforms
# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
from numbers import Integral
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, Union

import torch as tr

from rai_toolbox._typing import ArrayLike

from ._implementation import augment_and_mix

_T1 = TypeVar("_T1")
_T2 = TypeVar("_T2", bound=ArrayLike)

__all__ = ["AugMix", "Fork", "augment_and_mix"]


def _check_non_neg_int(item, name):
    if not isinstance(item, (int, Integral)) or item < 0:
        raise ValueError(f"`{name}` must be a non-negative integer. Got {item}")


class AugMix(tr.nn.Module):
    __doc__ = augment_and_mix.__doc__

    def __init__(
        self,
        process_fn: Callable[[_T1], _T2],
        augmentations: Sequence[Callable[[_T1], _T1]],
        *,
        num_aug_chains: int = 3,
        aug_chain_depth: Union[int, Tuple[int, int]] = (1, 4),
        beta_params: Union[float, Tuple[float, float]] = (1.0, 1.0),
        dirichlet_params: Union[float, Sequence[float]] = 1.0,
        augmentation_choice_probs: Optional[Sequence[float]] = None,
    ):
        super().__init__()
        _check_non_neg_int(num_aug_chains, "num_aug_chains")

        if isinstance(aug_chain_depth, Sequence):
            assert len(aug_chain_depth) == 2, aug_chain_depth
            _check_non_neg_int(aug_chain_depth[0], name="aug_chain_depth[0]")
            _check_non_neg_int(aug_chain_depth[1], name="aug_chain_depth[1]")
            assert aug_chain_depth[0] <= aug_chain_depth[1], aug_chain_depth

        if isinstance(dirichlet_params, Sequence):
            assert len(dirichlet_params) == num_aug_chains

        self.process_fn = process_fn
        self.augmentations = augmentations
        self.num_aug_chains = num_aug_chains
        self.aug_chain_depth = aug_chain_depth
        self.beta_params = beta_params
        self.dirichlet_params = dirichlet_params
        self.augmentation_choice_probs = augmentation_choice_probs

        if augmentation_choice_probs is not None and len(
            augmentation_choice_probs
        ) != len(augmentations):
            raise ValueError(
                f"`len(augmentation_choice_probs)` ({len(augmentation_choice_probs)})"
                f" must match `len(augmentations)` ({len(augmentations)})"
            )

    def forward(self, datum):
        return augment_and_mix(
            datum=datum,
            process_fn=self.process_fn,
            augmentations=self.augmentations,
            num_aug_chains=self.num_aug_chains,
            aug_chain_depth=self.aug_chain_depth,
            beta_params=self.beta_params,
            dirichlet_params=self.dirichlet_params,
            augmentation_choice_probs=self.augmentation_choice_probs,
        )

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "(\naugmentations="
        for t in self.augmentations:
            format_string += "\n"
            format_string += f"    {t},"
        format_string += f"\nprocess_fn={self.process_fn},"
        format_string += f"\nnum_aug_chains={self.num_aug_chains},"
        format_string += f"\naug_chain_depth={self.aug_chain_depth},"
        format_string += f"\nbeta_params={self.beta_params},"
        format_string += f"\ndirichlet_params={self.dirichlet_params},"
        format_string += "\n)"
        return format_string


def _flat_repr(x) -> str:
    out = f"{x}".splitlines()
    if len(out) == 1:
        return out[0]
    else:
        return out[0] + "...)"
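For reference, here is a brief usage sketch of ``AugMix`` in the doctest style used below; it is an illustration, not part of the module source. It assumes torchvision is installed and follows the PIL-image workflow shown in the ``Fork`` docstring: ``process_fn`` converts a PIL image to a tensor, and each augmentation maps a PIL image to a PIL image.

>>> from torchvision.transforms import ToTensor, RandomHorizontalFlip, RandomVerticalFlip
>>> from rai_toolbox.augmentations.augmix import AugMix
>>> augmix = AugMix(
...     process_fn=ToTensor(),
...     augmentations=[RandomHorizontalFlip(), RandomVerticalFlip()],
...     num_aug_chains=3,        # number of augmentation chains to mix
...     aug_chain_depth=(1, 4),  # (lower, upper) bounds on chain depth, per the validation above
... )
>>> # `augmix(pil_image)` applies the augment-and-mix procedure: it blends the
>>> # processed original image with the processed outputs of the chains.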
class Fork(tr.nn.Module):
    """
    Forks an input into an arbitrary number of transform-chains. This can be
    useful for doing consistency-loss workflows.

    Parameters
    ----------
    *forked_transforms : Callable[[Any], Any]
        One transform for each fork to create.

    Examples
    --------
    >>> from rai_toolbox.augmentations.augmix import Fork

    Here are some trivial examples:

    >>> two_fork = Fork(lambda x: x, lambda x: 2 * x)
    >>> two_fork(2)
    (2, 4)

    >>> three_fork = Fork(lambda x: x, lambda x: 2 * x, lambda x: 0 * x)
    >>> three_fork(-1.0)
    (-1.0, -2.0, -0.0)

    Here is a simplified version of the triple-processing used by the AugMix
    paper's consistency loss. It anticipates a PIL image and produces a triplet.

    >>> from torchvision.transforms import ToTensor, RandomHorizontalFlip, RandomVerticalFlip
    >>> from rai_toolbox.augmentations.augmix import AugMix
    >>> augmix = AugMix(
    ...     augmentations=[RandomHorizontalFlip(), RandomVerticalFlip()],
    ...     process_fn=ToTensor(),
    ... )
    >>> Fork(augmix, augmix, ToTensor())
    Fork(
          - ToTensor() ->
    x --> - AugMix(...) ->
          - AugMix(...) ->
    )
    """
    def __init__(self, *forked_transforms: Callable[[Any], Any]):
        super().__init__()

        if not forked_transforms:
            raise ValueError("At least one transform must be passed")

        if not all(callable(t) for t in forked_transforms):
            raise TypeError(
                f"All forked transforms must be callable, got: {forked_transforms}"
            )

        self.forked_transforms = forked_transforms
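The listing ends here; ``Fork.forward`` is not shown. Based on the docstring examples above (e.g. ``two_fork(2)`` returning ``(2, 4)``), a minimal sketch of the forward behavior would be the following; this is an assumption inferred from those examples, not necessarily the library's actual implementation:

>>> def forward(self, x):
...     # Assumed behavior: apply every forked transform to the same input
...     # and return one result per fork, as a tuple.
...     return tuple(transform(x) for transform in self.forked_transforms)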