# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
# type: ignore
import random
from collections import defaultdict
from typing import (
Any,
Callable,
Collection,
DefaultDict,
Dict,
Iterable,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
import numpy as np
import torch as tr
from torch.nn.functional import softmax
from torchmetrics import Metric
from rai_toolbox import evaluating
from rai_toolbox._utils import get_device
from ._fourier_basis import generate_fourier_bases
# `perturb_batch` returns the same container kind (array or tensor) it is given.
T = TypeVar("T", np.ndarray, tr.Tensor)


def perturb_batch(
    *,
    batch: T,
    basis: np.ndarray,
    basis_norm: float,
    rand_flip_per_channel: bool,
) -> T:
    """
    Perturb every color channel of every image in a batch with one Fourier basis.

    Each channel is updated (out of place) as

        color_channel + rand_sign * basis_norm * basis

    where `rand_sign` is +1 everywhere unless `rand_flip_per_channel` is True.

    Parameters
    ----------
    batch: np.ndarray | torch.Tensor, shape-(N, C, H, W)
        N: batch size, C: number of color channels. The input is never mutated.

    basis: np.ndarray, shape-(H, W)
        Assumed to already be normalized.

    basis_norm: float
        Scale factor applied to `basis` before it is added.

    rand_flip_per_channel: bool
        If True, the std-lib `random` module is used to draw an independent random
        sign (-1 or +1) for each of the N * C to-be-perturbed color channels.

    Returns
    -------
    perturbed_batch: np.ndarray | torch.Tensor, shape-(N, C, H, W)
        Same container type as `batch`. Tensor outputs keep the dtype and device
        of `batch`; array outputs follow NumPy's promotion rules for
        `batch + perturbation` in both branches, so a float64 batch is not
        downcast to the float32 perturbation buffer.
    """
    scaled_basis = basis * basis_norm
    is_tensor = isinstance(batch, tr.Tensor)

    if not rand_flip_per_channel:
        if is_tensor:
            scaled_basis = tr.tensor(
                scaled_basis, dtype=batch.dtype, device=batch.device
            )
        return batch + scaled_basis

    # One sign per (image, channel) pair; randrange(-1, 2, 2) yields -1 or +1.
    num_channels = batch.shape[0] * batch.shape[1]
    signs = np.array(
        [random.randrange(-1, 2, 2) for _ in range(num_channels)],
        dtype="float32",
    )

    # shape-(N*C, H, W) -> shape-(N, C, H, W)
    perturbation = np.multiply.outer(signs, scaled_basis).reshape(batch.shape)

    if is_tensor:
        perturbation = tr.tensor(
            perturbation, dtype=batch.dtype, device=batch.device
        )

    # Out-of-place add: an in-place `perturbation += batch` would pin the result
    # to the perturbation's dtype and disagree with the no-flip branch above.
    return batch + perturbation
def normalize(
    imgs: np.ndarray,
    *,
    mean: Union[float, Sequence[float]],
    std: Union[float, Sequence[float]],
    inplace: bool = False,
):
    """
    Returns (imgs - mean) / std, applied per color channel.

    Parameters
    ----------
    imgs: np.ndarray, shape-(N, C, H, W)

    mean : float | shape-(C,)
        Scalar, or one value per channel.

    std : float | shape-(C,)
        Scalar, or one value per channel.

    inplace: bool, optional (default=False)
        If True, `imgs` itself is overwritten and returned.

    Returns
    -------
    normalized_imgs: np.ndarray, shape-(N, C, H, W)
        When `imgs` has a floating (or complex) dtype the result keeps that
        dtype. For integer images the statistics are left in floating point
        so that, e.g., `mean=0.5` is not truncated to 0; the out-of-place
        result is then floating point.

    Notes
    -----
    `inplace=True` on an integer array raises NumPy's casting error, since the
    normalized values cannot be stored in an integer buffer.
    """
    assert imgs.ndim == 4, "`imgs` must have shape (N, C, H, W)"

    def _per_channel(stat) -> np.ndarray:
        # shape-(C,) (or scalar) -> shape-(1, C, 1, 1) for broadcasting
        arr = np.atleast_1d(stat)[None, :, None, None]
        # Only match the image dtype when that cannot truncate the statistic.
        if np.issubdtype(imgs.dtype, np.inexact):
            arr = arr.astype(imgs.dtype)
        return arr

    mean_arr = _per_channel(mean)
    std_arr = _per_channel(std)

    if inplace:
        imgs -= mean_arr
        imgs /= std_arr
        return imgs

    return (imgs - mean_arr) / std_arr
# One heatmap cell, as produced by `create_heatmaps`:
#   pos     -- the `position` of the Fourier basis that was applied
#   sym_pos -- the `sym_position` reported for that same basis
#   result  -- the metric object while accumulating; the value returned by
#              its `compute()` in the final output
HeatMapEntry = NamedTuple(
    "HeatMapEntry",
    [
        ("pos", Tuple[int, int]),
        ("sym_pos", Tuple[int, int]),
        ("result", Any),
    ],
)
[docs]
def create_heatmaps(
    dataloader: Collection[Tuple[tr.Tensor, tr.Tensor]],
    image_height_width: Tuple[int, int],
    *,
    model: tr.nn.Module,
    metrics: Dict[str, Type[Metric]],
    basis_norm: float,
    rand_flip_per_channel: bool,
    post_pert_batch_transform: Optional[Callable[[tr.Tensor], tr.Tensor]] = None,
    device: Optional[Union[tr.device, str, int]] = None,
    row_col_coords: Optional[Iterable[Tuple[int, int]]] = None,
    factor_2pi_phase_shift: float = 0,
) -> Dict[str, List[HeatMapEntry]]:
    """
    Evaluate `model` on Fourier-perturbed batches, accumulating metrics per basis.

    For every batch in `dataloader` and every basis yielded by
    `generate_fourier_bases`, the batch is perturbed with `perturb_batch`, passed
    through `post_pert_batch_transform` and `model`, and the softmax (over
    dim 1) of the model output is fed, together with the targets, to one metric
    instance per (metric name, basis position).

    Parameters
    ----------
    dataloader : Collection[Tuple[Tensor, Tensor]]
        Yields `(batch, targets)` pairs. It is iterated exactly once.

    image_height_width : Tuple[int, int]
        Unpacked as the leading positional arguments of `generate_fourier_bases`.

    model : torch.nn.Module
        Evaluated inside `evaluating(model)` and `torch.no_grad()`. If `device`
        is given, the model is moved to it (this side effect persists).

    metrics : Dict[str, Type[Metric]]
        Maps a result name to a metric class. Each class is instantiated with no
        arguments, once per basis position, and must support
        `update(preds=..., target=...)` and `compute()`.

    basis_norm : float
        Forwarded to `perturb_batch`.

    rand_flip_per_channel : bool
        Forwarded to `perturb_batch`.

    post_pert_batch_transform : Optional[Callable[[Tensor], Tensor]]
        Applied to each perturbed batch before it is fed to the model.
        Defaults to the identity.

    device : Optional[Union[torch.device, str, int]]
        Device on which the model is run. If omitted, it is taken from
        `get_device(model)`.

    row_col_coords : Optional[Iterable[Tuple[int, int]]]
        Forwarded to `generate_fourier_bases`. It is materialized into a list
        up front, so a one-shot iterable is reused correctly for every batch.

    factor_2pi_phase_shift : float, optional (default=0)
        Forwarded to `generate_fourier_bases`.

    Returns
    -------
    Dict[str, List[HeatMapEntry]]
        For each metric name, one entry per basis position (in first-seen
        order) whose `result` is that metric's `compute()` value over all
        batches. Empty if `dataloader` yields nothing.
    """
    from rai_toolbox._utils.tqdm import tqdm

    # The bases are regenerated for every batch, so a generator passed as
    # `row_col_coords` would be exhausted after the first one.
    coords = None if row_col_coords is None else list(row_col_coords)

    # Progress-bar length is only estimated when the full grid is swept.
    num_bases = (
        None if coords is not None else int(np.prod(image_height_width)) // 2
    )

    if post_pert_batch_transform is None:

        def post_pert_batch_transform(x):
            return x

    with evaluating(model), tr.no_grad():
        if device is not None:
            device = tr.device(device)
            model.to(device=device)
        else:
            device = get_device(model)

        results: DefaultDict[str, Dict[Tuple[int, int], HeatMapEntry]] = defaultdict(
            dict
        )

        for batch, targets in tqdm(dataloader, desc="batch"):
            # Loop-invariant: copy the batch to the device once rather than
            # once per basis. `perturb_batch` does not mutate its input. With
            # a single copy per batch, pinning host memory is no longer needed.
            batch = batch.to(device=device)
            # Metrics are updated with CPU predictions, so targets stay on CPU.
            targets = targets.cpu()

            for basis in tqdm(
                generate_fourier_bases(
                    *image_height_width,
                    dtype="float32",
                    row_col_coords=coords,
                    factor_2pi_phase_shift=factor_2pi_phase_shift,
                ),
                total=num_bases,
                desc="fourier-grid",
                leave=False,
            ):
                p_batch = perturb_batch(
                    batch=batch,
                    basis=basis.basis,
                    basis_norm=basis_norm,
                    rand_flip_per_channel=rand_flip_per_channel,
                )
                p_batch = post_pert_batch_transform(p_batch)

                probs = softmax(model(p_batch), dim=1).cpu()

                for name, metric_type in metrics.items():
                    entries = results[name]
                    if basis.position not in entries:
                        # First time this basis is seen: start a fresh metric.
                        entries[basis.position] = HeatMapEntry(
                            basis.position,
                            sym_pos=basis.sym_position,
                            result=metric_type(),
                        )
                    entries[basis.position].result.update(
                        preds=probs, target=targets
                    )

        # Replace each accumulated metric object by its computed value.
        return {
            name: [
                HeatMapEntry(pos, sym_pos, metric.compute())
                for pos, sym_pos, metric in entries.values()
            ]
            for name, entries in results.items()
        }