Source code for rai_toolbox.losses._jensen_shannon_divergence

# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT

from typing import Optional

import torch as tr
import torch.nn.functional as F

from rai_toolbox._typing import ArrayLike


def jensen_shannon_divergence(
    *probs: ArrayLike, weight: Optional[float] = None
) -> tr.Tensor:
    r"""
    Computes the Jensen-Shannon divergence [1]_ between n distributions:
    :math:`JSD(P_1, P_2, ..., P_n)`

    This loss is symmetric and is bounded by
    :math:`0 \leq JSD(P_1, P_2, ..., P_n) \leq \ln(n)`

    Parameters
    ----------
    probs : ArrayLike, shape-(N, D)
        A collection of n probability distributions. Each conveys a batch of N
        distributions over D categories.

    weight : Optional[float]
        An optional scaling factor applied to the computed loss.

    Returns
    -------
    loss : tr.Tensor, shape-(,)
        The scalar loss computed via the batch-mean.

    Notes
    -----
    The JSD loss is computed for each corresponding n-tuple of distributions and
    the batch-mean is ultimately returned.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence

    Examples
    --------
    Let's measure the divergence between three discrete distributions of length-two.

    >>> from rai_toolbox.losses import jensen_shannon_divergence
    >>> P1 = [[0.0, 1.0]]
    >>> P2 = [[1.0, 0.0]]
    >>> P3 = [[0.5, 0.5]]
    >>> jensen_shannon_divergence(P1, P2, P3)
    tensor(0.4621)

    The divergence is symmetric.

    >>> jensen_shannon_divergence(P1, P3, P2)
    tensor(0.4621)
    """
    list_probs = [tr.as_tensor(p) for p in probs]

    if len(list_probs) < 2 or any(
        not isinstance(p, tr.Tensor) or p.dim() != 2 for p in list_probs
    ):
        raise ValueError(
            "*probs must consist of at least two Tensors, and each tensor must "
            f"have a shape of (N, D). Got {probs}"
        )

    zero = tr.tensor(0.0).type_as(list_probs[0])

    # Clamp mixture distribution to avoid exploding KL divergence
    log_p_mixture = tr.clamp(sum(list_probs, zero) / len(list_probs), 1e-7, 1).log()

    # Average KL(p_i || mixture) over the n distributions; `batchmean` averages
    # over the batch dimension N.
    loss = sum(
        (F.kl_div(log_p_mixture, p, reduction="batchmean") for p in list_probs), zero
    ) / len(probs)

    if weight is not None:
        loss = loss * weight

    return loss
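For reference, the docstring example can be cross-checked with a minimal sketch (assuming only PyTorch, and not part of the toolbox source) that computes the same quantity via the entropy form :math:`JSD(P_1, ..., P_n) = H(M) - \frac{1}{n}\sum_i H(P_i)`, where :math:`M` is the mean of the distributions; the ``_entropy`` helper below is hypothetical and written only for this illustration.

import torch as tr

def _entropy(p: tr.Tensor) -> tr.Tensor:
    # Shannon entropy (in nats) along the category axis; the clamp mirrors the
    # 1e-7 clamp used above so that 0 * log(0) contributes ~0.
    p = tr.clamp(p, 1e-7, 1.0)
    return -(p * p.log()).sum(dim=-1)

P1 = tr.tensor([[0.0, 1.0]])
P2 = tr.tensor([[1.0, 0.0]])
P3 = tr.tensor([[0.5, 0.5]])

mixture = (P1 + P2 + P3) / 3  # shape-(1, 2) mean distribution M
jsd = _entropy(mixture) - (_entropy(P1) + _entropy(P2) + _entropy(P3)) / 3
print(jsd.mean())  # ~tensor(0.4621), agreeing with jensen_shannon_divergence(P1, P2, P3)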