# Source code for rai_toolbox.augmentations.fourier._implementations
# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
# type: ignore
import random
from collections import defaultdict
from typing import (
    Any,
    Callable,
    Collection,
    DefaultDict,
    Dict,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import numpy as np
import torch as tr
from torch.nn.functional import softmax
from torchmetrics import Metric

from rai_toolbox import evaluating
from rai_toolbox._utils import get_device

from ._fourier_basis import generate_fourier_bases

T = TypeVar("T", np.ndarray, tr.Tensor)


def perturb_batch(
    *,
    batch: T,
    basis: np.ndarray,
    basis_norm: float,
    rand_flip_per_channel: bool,
) -> T:
    """
    Add a scaled Fourier basis array to every color channel of every image.

    Each channel is updated as
    `color_channel += rand_sign * basis_norm * basis`, where `rand_sign` is
    fixed to +1 unless `rand_flip_per_channel` is set.

    Parameters
    ----------
    batch: np.ndarray | tr.Tensor, shape-(N, C, H, W)
        N: batch size, C: number of color channels. The input is not
        modified; a new array/tensor is returned.

    basis: np.ndarray, shape-(H, W)
        Assumed to already be normalized.

    basis_norm: float
        Scalar that multiplies `basis` before it is added.

    rand_flip_per_channel: bool
        If True, the std-lib `random` module is used to draw one random sign
        (-1 or +1) for each of the N * C to-be-perturbed color channels.

    Returns
    -------
    perturbed_batch: np.ndarray | tr.Tensor, shape-(N, C, H, W)
        A tensor input yields a tensor whose perturbation was created with
        the dtype and device of `batch`.
    """
    scaled_basis = basis * basis_norm
    batch_is_tensor = isinstance(batch, tr.Tensor)

    if not rand_flip_per_channel:
        # Same (H, W) perturbation for every channel: rely on broadcasting.
        if batch_is_tensor:
            scaled_basis = tr.tensor(
                scaled_basis, dtype=batch.dtype, device=batch.device
            )
        return batch + scaled_basis

    # One sign per (image, channel) pair; randrange(-1, 2, 2) yields -1 or +1.
    num_channels = batch.shape[0] * batch.shape[1]
    signs = np.array(
        [random.randrange(-1, 2, 2) for _ in range(num_channels)],
        dtype="float32",
    )

    # (N * C,) outer (H, W) -> (N * C, H, W), then folded back to `batch.shape`.
    perturbed = np.multiply.outer(signs, scaled_basis).reshape(batch.shape)
    if batch_is_tensor:
        perturbed = tr.tensor(perturbed, dtype=batch.dtype, device=batch.device)

    # Accumulate into the freshly built perturbation so `batch` stays untouched.
    perturbed += batch
    return perturbed


def normalize(
    imgs: np.ndarray,
    *,
    mean: Union[float, Sequence[float]],
    std: Union[float, Sequence[float]],
    inplace: bool = False,
):
    """
    Returns (imgs - mean) / std

    Parameters
    ----------
    imgs: np.ndarray, shape-(N, C, H, W)

    mean : float | shape-(C,)
        Cast to `imgs.dtype` before use.

    std : float | shape-(C,)
        Cast to `imgs.dtype` before use.

    inplace: bool, optional (default=False)
        If True, `imgs` itself is overwritten and returned.

    Returns
    -------
    normalized_imgs: np.ndarray, shape-(N, C, H, W)
    """
    # Reshape each statistic to (1, C, 1, 1) so it lines up with the channel
    # axis of `imgs` under broadcasting.
    mean_arr, std_arr = (
        np.atleast_1d(stat)[None, :, None, None].astype(imgs.dtype)
        for stat in (mean, std)
    )
    assert imgs.ndim == 4

    if inplace:
        imgs -= mean_arr
        imgs /= std_arr
        return imgs

    return (imgs - mean_arr) / std_arr


class HeatMapEntry(NamedTuple):
    """Record pairing two (int, int) positions with an arbitrary result."""

    # NOTE(review): presumably a position in the Fourier heatmap -- confirm
    # against the code that builds these entries.
    pos: Tuple[int, int]
    # NOTE(review): presumably the symmetric counterpart of `pos` -- confirm.
    sym_pos: Tuple[int, int]
    # Untyped payload stored for this entry.
    result: Any