# Source code for rai_toolbox.augmentations.fourier._implementations

# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
# type: ignore
import random
from collections import defaultdict
from typing import (
    Any,
    Callable,
    Collection,
    DefaultDict,
    Dict,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import numpy as np
import torch as tr
from torch.nn.functional import softmax
from torchmetrics import Metric

from rai_toolbox import evaluating
from rai_toolbox._utils import get_device

from ._fourier_basis import generate_fourier_bases

T = TypeVar("T", np.ndarray, tr.Tensor)


def perturb_batch(
    *,
    batch: T,
    basis: np.ndarray,
    basis_norm: float,
    rand_flip_per_channel: bool,
) -> T:
    """
    Perturb every color channel of every image in a batch with one Fourier basis.

    Each channel is updated as::

        color_channel + rand_sign * basis_norm * basis

    where `rand_sign` is +1 for every channel unless `rand_flip_per_channel`
    is set. The input `batch` is never modified; a new array/tensor is returned.

    Parameters
    ----------
    batch : np.ndarray | torch.Tensor, shape-(N, C, H, W)
        N: batch size, C: number of color channels.

    basis : np.ndarray, shape-(H, W)
        Assumed to already be normalized.

    basis_norm : float
        Scale factor multiplied into `basis` before it is added.

    rand_flip_per_channel : bool
        If True, the std-lib `random` module is used to draw an independent
        random sign (-1 or +1) for each of the N * C to-be-perturbed color
        channels.

    Returns
    -------
    perturbed_batch : np.ndarray | torch.Tensor, shape-(N, C, H, W)
        Same container type as `batch`. For tensors the perturbation is cast
        to the dtype/device of `batch` before the addition. For arrays the
        result dtype follows NumPy's promotion of `batch` and the scaled
        basis, identically in the flipped and un-flipped paths.
    """
    scaled_basis = basis * basis_norm
    is_tensor = isinstance(batch, tr.Tensor)

    if not rand_flip_per_channel:
        if is_tensor:
            scaled_basis = tr.tensor(
                scaled_basis, dtype=batch.dtype, device=batch.device
            )
        return batch + scaled_basis

    # One sign per (image, channel) pair; randrange(-1, 2, 2) yields -1 or +1.
    num_channels = batch.shape[0] * batch.shape[1]
    signs = np.array(
        [random.randrange(-1, 2, 2) for _ in range(num_channels)],
        dtype="float32",
    )

    # shape-(N*C, H, W) -> shape-(N, C, H, W)
    perturbation = np.multiply.outer(signs, scaled_basis).reshape(batch.shape)

    if is_tensor:
        perturbation = tr.tensor(
            perturbation, dtype=batch.dtype, device=batch.device
        )

    # Out-of-place add: an in-place `perturbation += batch` would silently
    # down-cast e.g. a float64 array batch into the float32 perturbation buffer.
    return batch + perturbation


def normalize(
    imgs: np.ndarray,
    *,
    mean: Union[float, Sequence[float]],
    std: Union[float, Sequence[float]],
    inplace: bool = False,
):
    """
    Standardize a batch of images channel-wise: (imgs - mean) / std.

    Parameters
    ----------
    imgs : np.ndarray, shape-(N, C, H, W)
        Batch of images; must be 4-dimensional.
    mean : float | shape-(C,)
        Value(s) subtracted from each channel. Cast to `imgs.dtype`.
    std : float | shape-(C,)
        Value(s) each channel is divided by. Cast to `imgs.dtype`.
    inplace : bool, optional (default=False)
        If True, `imgs` itself is overwritten and returned; otherwise a new
        array is allocated.

    Returns
    -------
    normalized_imgs : np.ndarray, shape-(N, C, H, W)
    """

    def _per_channel(stat):
        # shape-(1, C, 1, 1) so the statistic broadcasts along the channel axis
        return np.atleast_1d(stat)[None, :, None, None].astype(imgs.dtype)

    shift = _per_channel(mean)
    scale = _per_channel(std)

    assert imgs.ndim == 4

    if inplace:
        imgs -= shift
        imgs /= scale
        return imgs

    return (imgs - shift) / scale


class HeatMapEntry(NamedTuple):
    """
    One heatmap cell: the position of a Fourier basis paired with a result.

    `create_heatmaps` uses this record in two stages: while iterating over
    batches, `result` holds a live metric instance that is updated in place;
    in the returned lists, `result` holds the output of that metric's
    `.compute()` call.
    """

    # Taken from `basis.position` of the Fourier basis used for the perturbation.
    pos: Tuple[int, int]
    # Taken from `basis.sym_position`; presumably the symmetric counterpart of
    # `pos` in the Fourier grid -- confirm against `generate_fourier_bases`.
    sym_pos: Tuple[int, int]
    # Metric instance during accumulation, computed metric value afterwards.
    result: Any


def create_heatmaps(
    dataloader: Collection[Tuple[tr.Tensor, tr.Tensor]],
    image_height_width: Tuple[int, int],
    *,
    model: tr.nn.Module,
    metrics: Dict[str, Type[Metric]],
    basis_norm: float,
    rand_flip_per_channel: bool,
    post_pert_batch_transform: Optional[Callable[[tr.Tensor], tr.Tensor]] = None,
    device: Optional[Union[tr.device, str, int]] = None,
    row_col_coords: Optional[Iterable[Tuple[int, int]]] = None,
    factor_2pi_phase_shift: float = 0,
) -> Dict[str, List[HeatMapEntry]]:
    """
    Evaluate a model on Fourier-perturbed batches and collect per-basis metrics.

    For every `(batch, targets)` pair in `dataloader` and every basis yielded
    by `generate_fourier_bases`, the batch is perturbed with `perturb_batch`,
    passed through `post_pert_batch_transform`, and fed to `model`. The
    softmax (over dim 1) of the model output and the targets are handed to one
    metric instance per (metric name, basis position) pair. Once all batches
    have been processed, every metric is reduced with `.compute()`.

    The whole evaluation runs inside `evaluating(model)` and `torch.no_grad()`.

    Parameters
    ----------
    dataloader : Collection[Tuple[Tensor, Tensor]]
        Yields `(batch, targets)` pairs. The basis generator is re-created for
        each batch.

    image_height_width : Tuple[int, int]
        Unpacked as the leading positional arguments of
        `generate_fourier_bases`.

    model : torch.nn.Module
        Called on the perturbed (and transformed) batch.

    metrics : Dict[str, Type[Metric]]
        Maps a metric name to a metric *class*; it is instantiated with no
        arguments once per basis position and updated via
        `update(preds=..., target=...)` with CPU tensors.

    basis_norm : float
        Forwarded to `perturb_batch`.

    rand_flip_per_channel : bool
        Forwarded to `perturb_batch`.

    post_pert_batch_transform : Optional[Callable[[Tensor], Tensor]]
        Applied to the perturbed batch before the model is called. Defaults to
        the identity.

    device : Optional[Union[torch.device, str, int]]
        If given, `model` is moved to this device (the move is not undone on
        exit). Otherwise the device is obtained from `get_device(model)`.

    row_col_coords : Optional[Iterable[Tuple[int, int]]]
        Forwarded to `generate_fourier_bases`. When None, the inner progress
        bar's total is set to `prod(image_height_width) // 2`; otherwise the
        total is left unknown.

    factor_2pi_phase_shift : float, optional (default=0)
        Forwarded to `generate_fourier_bases`.

    Returns
    -------
    Dict[str, List[HeatMapEntry]]
        For each metric name, one entry per basis position whose `result` is
        the computed metric value. Empty if `dataloader` or `metrics` is empty.
    """
    from rai_toolbox._utils.tqdm import tqdm

    # Length shown by the inner progress bar; unknown for explicit coordinates.
    _outer_total = (
        None if row_col_coords is not None else int(np.prod(image_height_width)) // 2
    )

    with evaluating(model), tr.no_grad():
        if device is not None:
            device = tr.device(device)
            model.to(device=device)
        else:
            device = get_device(model)

        if post_pert_batch_transform is None:

            def post_pert_batch_transform(x):
                return x

        # metric name -> basis position -> entry holding a live metric instance
        results: DefaultDict[str, Dict[Tuple[int, int], HeatMapEntry]] = defaultdict(
            dict
        )

        for batch, targets in tqdm(dataloader, desc="batch"):
            # `perturb_batch` returns a new tensor and never mutates its input,
            # so one device copy per batch can be shared by every basis instead
            # of re-transferring the batch on each inner iteration.
            device_batch = batch.to(device=device)
            # Metrics are updated with CPU tensors; convert the targets once.
            targets = targets.cpu()

            for basis in tqdm(
                generate_fourier_bases(
                    *image_height_width,
                    dtype="float32",
                    row_col_coords=row_col_coords,
                    factor_2pi_phase_shift=factor_2pi_phase_shift,
                ),
                total=_outer_total,
                desc="fourier-grid",
                leave=False,
            ):
                p_batch = perturb_batch(
                    batch=device_batch,
                    basis=basis.basis,
                    basis_norm=basis_norm,
                    rand_flip_per_channel=rand_flip_per_channel,
                )
                p_batch = post_pert_batch_transform(p_batch)

                probs = softmax(model(p_batch), dim=1).cpu()

                for name, M in metrics.items():
                    per_position = results[name]
                    if basis.position not in per_position:
                        per_position[basis.position] = HeatMapEntry(
                            basis.position,
                            sym_pos=basis.sym_position,
                            result=M(),
                        )
                    per_position[basis.position].result.update(
                        preds=probs, target=targets
                    )

    # Swap each accumulated metric instance for its computed value.
    out: Dict[str, List[HeatMapEntry]] = {}
    for metric_name, r in results.items():
        out[metric_name] = [
            HeatMapEntry(p, ps, metric.compute()) for p, ps, metric in r.values()
        ]
    return out