# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
import functools
from typing import Any, Callable, List, Optional, Tuple, Union
import torch as tr
from torch import Tensor
from torch.nn import CrossEntropyLoss, Module
from torch.optim import Optimizer as _TorchOptim
from rai_toolbox import negate
from rai_toolbox._typing import (
    ArrayLike,
    InstantiatesTo,
    Optimizer,
    OptimizerType,
    Partial,
    instantiates_to,
)
from rai_toolbox._utils import value_check
from rai_toolbox._utils.stateful import evaluating, frozen
from rai_toolbox.perturbations import AdditivePerturbation, PerturbationModel
[docs]
def gradient_ascent(
    *,
    model: Callable[[Tensor], Tensor],
    data: ArrayLike,
    target: ArrayLike,
    optimizer: Union[Optimizer, OptimizerType, Partial[Optimizer]],
    steps: int,
    perturbation_model: Union[
        PerturbationModel, InstantiatesTo[PerturbationModel]
    ] = AdditivePerturbation,
    targeted: bool = False,
    use_best: bool = False,
    criterion: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
    reduction_fn: Callable[[Tensor], Tensor] = tr.sum,
    **optim_kwargs: Any,
) -> Tuple[Tensor, Tensor]:
    """Solve for a set of perturbations for a given set of data and a model,
    and then apply those perturbations to the data.
    This performs, for `steps` iterations, the following optimization::
       optim = optimizer(perturbation_model.parameters)
       pert_data = perturbation_model(data)
       loss = criterion(model(pert_data), target)
       loss = (1 if targeted else -1) * loss  # default: targeted=False
       optim.step()
    Note that, by default, this perturbs the data away from `target` (i.e., this performs
    gradient *ascent*), given a standard loss function that seeks to minimize the
    difference between the model's output and the target. See `targeted` to toggle this
    behavior.
    Parameters
    ----------
    model : Callable[[Tensor], Tensor]
        Differentiable function that processes the (perturbed) data prior to computing
        the loss.
        If `model` is a `torch.nn.Module`, then its weights will be frozen and it will
        be set to eval mode during the perturbation-solve phase.
    data : ArrayLike, shape-(N, ...)
        The input data to perturb.
    target : ArrayLike, shape-(N, ...)
        If `targeted==False` (default), then this is the target to perturb away from.
        If `targeted==True`, then this is the target to perturb toward.
    optimizer : Optimizer | Type[Optimizer] | Partial[Optimizer]
        The optimizer to use for updating the perturbation model.
        If `optimizer` is uninstantiated, it will be instantiated as
        `optimizer(perturbation_model.parameters(), **optim_kwargs)`
    steps : int
        Number of projected gradient steps.
    perturbation_model : PerturbationModel | Type[PerturbationModel], optional (default=AdditivePerturbation)
        A `torch.nn.Module` whose parameters are updated by the solver. Its forward-pass
        applies the perturbation to the data. Default is
        `AdditivePerturbation`, which simply adds the perturbation to the data.
        The perturbation model should not modify the data in-place.
        If `perturbation_model` is a type, then it will be instantiated as
        `perturbation_model(data)`.
    criterion : Optional[Callable[[Tensor, Tensor], Tensor]]
        The criterion to use for calculating the loss **per-datum**.  I.e., for a
        shape-(N, ...) batch of data, `criterion` should return a shape-(N,) tensor of
        loss values – one for each datum in the batch.
        If `None`, then `CrossEntropyLoss(reduction=None)` is used.
    targeted : bool (default: False)
        If `True`, then perturb towards the defined `target`, otherwise move away from
        `target`.
        Note: Default (`targeted=False`) implements gradient *ascent*.
        To perform gradient *descent*, set `targeted=True`.
    use_best : bool (default: True)
        Whether to only report the best perturbation over all steps.
        Note: Requires criterion to output a loss per sample, e.g., set
        `reduction="none"`.
    reduction_fn : Callable[[Tensor], Tensor], optional (default=torch.sum)
        Used to reduce the shape-(N,) per-datum loss to a scalar. This should be
        set to `torch.mean` when solving for a "universal" perturbation.
    **optim_kwargs : Any
       Keyword arguments passed to `optimizer` when it is instatiated.
    Returns
    -------
    xadv, losses : tuple[Tensor, Tensor], shape-(N, ...), shape-(N, ...)
        The perturbed data, if `use_best==True` then this is the best perturbation
        based on the loss across all steps.
        The loss for each perturbed data point, if `use_best==True` then this is the
        best loss across all steps.
    Notes
    -----
    `model` is automatically set to eval-mode and its parameters are set to
    `requires_grad=False` within the context of this function.
    Examples
    --------
    Let's perturb two data points, `x1=-1.0` and `x2=2.0`, to maximize
    `L(δ; x) = |x + δ|` w.r.t `δ`. We will use five standard gradient steps, using a
    learning  rate of 0.1. The default perturbation model is simply additive:
    `x -> x + δ`.
    This solver is refining `δ1` and `δ2`, whose initial values are 0 by default, to
    maximize `L(x) = |x|` for `x1` and `x2`, respectively. Thus we should find that our
    solved perturbations modify our data as: `x-1.0 -> -1.5` and `2.0 -> 2.5`,
    respectively.
    >>> from rai_toolbox.perturbations import gradient_ascent
    >>> from torch.optim import SGD
    >>> identity_model = lambda data: data
    >>> abs_diff = lambda model_out, target: (model_out - target).abs()
    >>> perturbed_data, losses = gradient_ascent(
    ...    data=[-1.0, 2.0],
    ...    target=0.0,
    ...    model=identity_model,
    ...    criterion=abs_diff,
    ...    optimizer=SGD,
    ...    lr=0.1,
    ...    steps=5,
    ... )
    >>> perturbed_data
    tensor([-1.5000,  2.5000])
    We can instead specify `targeted=True` and perform gradient *descent*.
    Here, the perturbations we solve for should modify our data as:
    -1.0 -> -0.5 and 2.0 -> 1.5, respectively.
    >>> perturbed_data, losses = gradient_ascent(
    ...    data=[-1.0, 2.0],
    ...    target=0.0,
    ...    model=identity_model,
    ...    criterion=abs_diff,
    ...    optimizer=SGD,
    ...    lr=0.1,
    ...    steps=5,
    ...    targeted=True,
    ... )
    >>> perturbed_data
    tensor([-0.5000,  1.5000])
    **Accessing the perturbations**
    To gain direct access to the solved perturbations, we can provide our own
    perturbation model to the solver. Let's solve the same optimization problem, but
    provide our own instance of `AdditivePerturbation`
    >>> from rai_toolbox.perturbations import AdditivePerturbation
    >>> pert_model = AdditivePerturbation(data_or_shape=(2,))
    >>> perturbed_data, losses = gradient_ascent(
    ...    perturbation_model=pert_model,
    ...    data=[-1.0, 2.0],
    ...    target=0.0,
    ...    model=identity_model,
    ...    criterion=abs_diff,
    ...    optimizer=SGD,
    ...    lr=0.1,
    ...    steps=5,
    ... )
    >>> perturbed_data
    tensor([-1.5000,  2.5000])
    Now we can access the values that were solved for `δ1` and `δ2`.
    >>> pert_model.delta
    Parameter containing:
    tensor([-0.5000,  0.5000], requires_grad=True)
    """
    data = tr.as_tensor(data)
    target = tr.as_tensor(target)
    if not data.is_leaf:
        data = data.detach()
    if not target.is_leaf:
        target = target.detach()
    # Initialize
    best_loss = None
    best_x = None
    if criterion is None:
        criterion = CrossEntropyLoss(reduction="none")
    if not targeted:
        # maximize the objective function
        criterion = negate(criterion)
    if instantiates_to(perturbation_model, PerturbationModel):
        pmodel = perturbation_model(data)
    else:
        if not isinstance(perturbation_model, PerturbationModel):
            raise TypeError(
                f"`perturbation_model` must be satisfy the `PerturbationModel`"
                f" protocol, got: {perturbation_model}"
            )
        pmodel = perturbation_model
    if instantiates_to(optimizer, _TorchOptim):
        optim = optimizer(pmodel.parameters(), **optim_kwargs)
    else:
        if not isinstance(optimizer, _TorchOptim):
            raise TypeError(
                f"`optimizer` must be an instance of Optimizer or must instantiate to "
                f"Optimizer; got: {optimizer}"
            )
        if instantiates_to(perturbation_model, PerturbationModel):
            raise TypeError(
                "An initialized optimizer can only be passed to the solver in "
                "combination with an initialized perturbation model."
            )
        if optim_kwargs:
            raise TypeError(
                "**optim_kwargs were specified, but the provided `optimizer` has "
                "already been instaniated"
            )
        optim = optimizer
    to_freeze: List[Any] = [data, target]
    if isinstance(model, Module):
        to_freeze.append(model)
    # don't pass non nn.Module to frozen/eval
    packed_model = (model,) if isinstance(model, Module) else ()
    # Projected Gradient Descent
    with frozen(*to_freeze), evaluating(*packed_model), tr.enable_grad():
        for _ in range(steps):
            # Calculate the gradient of loss
            xadv = pmodel(data)
            logits = model(xadv)
            losses = criterion(logits, target)
            loss = reduction_fn(losses)
            # Update the perturbation
            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()
            if use_best:
                if (
                    (losses.ndim == 0 and data.ndim > 0)
                    or losses.ndim > data.ndim
                    or losses.shape != data.shape[: losses.ndim]
                ):
                    raise ValueError(
                        f"`use_best=True` but `criterion` does not output a per-datum-loss. "
                        f"I.e. `criterion` returned a tensor of shape-{tuple(losses.shape)} for a "
                        f"batch of shape-{tuple(data.shape)}. Expected a tensor of "
                        f"shape-{(len(data),)} or greater."
                    )
                best_loss, best_x = _replace_best(losses, best_loss, xadv, best_x)
        # free up memory
        optim.zero_grad(set_to_none=True)
        # Final evaluation
        with tr.no_grad():
            xadv = pmodel(data)
            logits = model(xadv)
            losses = criterion(logits, target)
            if use_best:
                # we negate the loss when `targeted=True` so min-loss is always best
                losses, xadv = _replace_best(losses, best_loss, xadv, best_x, min=True)
    if not targeted:
        # The returned loss is always relative to the original criterion.
        # E.g. an adversarial perturbation with a "low" negated loss will yield
        # a high loss.
        losses = -1 * losses
    return xadv.detach(), losses.detach() 
[docs]
def random_restart(
    solver: Callable[..., Tuple[Tensor, Tensor]],
    repeats: int,
) -> Callable[..., Tuple[Tensor, Tensor]]:
    """Executes a solver function multiple times, saving out the best perturbations.
    Parameters
    ----------
    solver : Callable[..., Tuple[Tensor, Tensor]]
        The solver whose execution will be repeated.
    repeats : int
        The number of times to run `solver`
    Returns
    -------
    random_restart_fn : Callable[..., Tuple[Tensor, Tensor]]
        Wrapped function that will execute `solver` `repeats` times.
    Examples
    --------
    Let's perturb two data points, `x1=-1.0` and `x2=2.0`, to maximize
    `L(δ; x) = |x + δ|` w.r.t `δ`. Our perturbation will randomly initialize `δ1` and
    `δ2` and we will re-run the solver three times – retaining the best perturbation of
    `x1` and `x2` respectively.
    >>> from functools import partial
    >>> import torch as tr
    >>> from torch.optim import SGD
    >>> from rai_toolbox.perturbations.init import uniform_like_l1_n_ball_
    >>> from rai_toolbox.perturbations import AdditivePerturbation, gradient_ascent, random_restart
    >>>
    >>> def verbose_abs_diff(model_out, target):
    ...     # used to print out loss at each solver step (for purpose of example)
    ...     out =  (model_out - target).abs()
    ...     print(out)
    ...     return out
    Configuring a peturbation model to randomly initialize the perturbations.
    >>> RNG = tr.Generator().manual_seed(0)
    >>> PertModel = partial(AdditivePerturbation, init_fn=uniform_like_l1_n_ball_, generator=RNG)
    Creating and running repeating solver.
    >>> gradient_ascent_with_restart = random_restart(gradient_ascent, 3)
    >>> perturbed_data, losses = gradient_ascent_with_restart(
    ...    perturbation_model=PertModel,
    ...    data=[-1.0, 2.0],
    ...    target=0.0,
    ...    model=lambda data: data,
    ...    criterion=verbose_abs_diff,
    ...    optimizer=SGD,
    ...    lr=0.1,
    ...    steps=1,
    ... )
    tensor([0.5037, 2.7682], grad_fn=<AbsBackward0>)
    tensor([0.6037, 2.8682])
    tensor([0.9115, 2.1320], grad_fn=<AbsBackward0>)
    tensor([1.0115, 2.2320])
    tensor([0.6926, 2.6341], grad_fn=<AbsBackward0>)
    tensor([0.7926, 2.7341])
    See that for `x1` the highest loss is `1.0115`, and for `x2` it is `2.8682`. This
    should be reflected `losses` and `perturbed_data` that were retained across the
    restarts.
    >>> losses
    tensor([1.0115, 2.8682])
    >>> perturbed_data
    tensor([-1.0115,  2.8682])
    """
    value_check("repeats", repeats, min_=1, incl_min=True)
    @functools.wraps(solver)
    def random_restart_fn(*args, **kwargs) -> Tuple[Tensor, Tensor]:
        targeted = kwargs.get("targeted", False)
        best_x = None
        best_loss = None
        for _ in range(repeats):
            # run the attack
            xadv, losses = solver(*args, **kwargs)
            # Save best loss for each data point
            best_loss, best_x = _replace_best(
                loss=losses,
                best_loss=best_loss,
                data=xadv,
                best_data=best_x,
                min=targeted,
            )
        assert best_x is not None
        assert best_loss is not None
        return best_x, best_loss
    return random_restart_fn 
# A function that updates the best loss and best input
def _replace_best(
    loss: Tensor,
    best_loss: Optional[Tensor],
    data: Tensor,
    best_data: Optional[Tensor],
    min: bool = True,
) -> Tuple[Tensor, Tensor]:
    """Returns the data with the smallest (or largest) loss
    Parameters
    ----------
    loss : Tensor, shape-(N, ...)
        N: batch size
    best_loss : Optional[Tensor], shape-(N, ...)
        N: batch size
    data : Tensor, shape-(N, ...)
        N: batch size
    best_data : Optional[Tensor], shape-(N, ...)
        N: batch size
    min : bool (default: True)
        Whether best is minimum (True) or maximum (False)
    Returns
    -------
    best_loss, best_data : Tuple[Tensor, Tensor]
    """
    if best_loss is None:
        best_data = data
        best_loss = loss
    else:
        assert best_data is not None
        if min:
            replace = loss < best_loss
        else:
            replace = loss > best_loss
        best_data[replace] = data[replace]
        best_loss[replace] = loss[replace]
    return best_loss, best_data