# Source code for rai_toolbox.losses._jensen_shannon_divergence
# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
from typing import Optional
import torch as tr
import torch.nn.functional as F
from rai_toolbox._typing import ArrayLike
# [docs]
def jensen_shannon_divergence(
    *probs: ArrayLike, weight: Optional[float] = None
) -> tr.Tensor:
    r"""
    Computes the Jensen-Shannon divergence [1]_ between n distributions:
    :math:`JSD(P_1, P_2, ..., P_n)`

    This loss is symmetric and is bounded by :math:`0 <= JSD(P_1, P_2, ..., P_n) <= \ln(n)`

    Parameters
    ----------
    probs : ArrayLike, shape-(N, D)
        A collection of n >= 2 probability distributions. Each conveys a batch of N
        distributions over D categories. All of them must share the same shape.
        Non-floating-point inputs (e.g. ints or bools) are cast to the default
        floating-point dtype.

    weight : Optional[float]
        A scaling factor that will be applied to the consistency loss. If `None`,
        the loss is returned unscaled.

    Returns
    -------
    loss : tr.Tensor, shape-(,)
        The scalar loss computed via the batch-mean.

    Raises
    ------
    ValueError
        If fewer than two distributions are provided, if any distribution is not
        two-dimensional, or if the distributions do not all share the same shape.

    Notes
    -----
    The JSD loss is computed for each corresponding n-tuple of distributions and
    the batch-mean is ultimately returned.

    The mixture distribution is clamped to a minimum of 1e-7 before its log is
    taken, so that zero-probability categories do not produce infinite terms.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence

    Examples
    --------
    Let's measure the divergence between three discrete distributions of length-two.

    >>> from rai_toolbox.losses import jensen_shannon_divergence
    >>> P1 = [[0.0, 1.0]]
    >>> P2 = [[1.0, 0.0]]
    >>> P3 = [[0.5, 0.5]]
    >>> jensen_shannon_divergence(P1, P2, P3)
    tensor(0.4621)

    The divergence is symmetric.

    >>> jensen_shannon_divergence(P1, P3, P2)
    tensor(0.4621)
    """
    # `tr.as_tensor` always yields a Tensor, so only the count and rank need checking.
    list_probs = [tr.as_tensor(p) for p in probs]
    if len(list_probs) < 2 or any(p.dim() != 2 for p in list_probs):
        raise ValueError(
            f"*probs must consist of at least two Tensors, and each tensor must have a shape of (N, D). Got {probs}"
        )

    # Mismatched shapes would otherwise silently broadcast (e.g. (1, D) vs (N, D))
    # in both the mixture sum and `kl_div`, producing a meaningless loss.
    expected_shape = list_probs[0].shape
    if any(p.shape != expected_shape for p in list_probs):
        raise ValueError(
            f"All distributions in *probs must share the same shape (N, D). Got shapes {[tuple(p.shape) for p in list_probs]}"
        )

    # Integer/bool inputs (e.g. one-hot lists of ints) must be made floating-point:
    # otherwise the accumulator below adopts their dtype, and bool "addition" is a
    # logical OR, which would corrupt the mixture distribution.
    list_probs = [
        p if p.is_floating_point() else p.to(tr.get_default_dtype())
        for p in list_probs
    ]

    num_dists = len(list_probs)
    zero = tr.tensor(0.0).type_as(list_probs[0])

    # Clamp mixture distribution to avoid exploding KL divergence
    log_p_mixture = tr.clamp(sum(list_probs, zero) / num_dists, 1e-7, 1).log()

    # JSD = (1/n) * sum_i KL(P_i || M); `kl_div(log M, P_i)` computes KL(P_i || M),
    # averaged over the batch dimension via reduction="batchmean".
    loss = (
        sum(
            (F.kl_div(log_p_mixture, p, reduction="batchmean") for p in list_probs),
            zero,
        )
        / num_dists
    )

    if weight is not None:
        loss = loss * weight
    return loss