rai_toolbox.losses.jensen_shannon_divergence

rai_toolbox.losses.jensen_shannon_divergence(*probs, weight=None)

Computes the Jensen-Shannon divergence [1] between n distributions:

\(JSD(P_1, P_2, ..., P_n)\)

This loss is symmetric and is bounded by \(0 \leq JSD(P_1, P_2, \ldots, P_n) \leq \ln(n)\).
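With equal weights, this generalized divergence can be written in terms of the Shannon entropy \(H\) as the entropy of the mixture minus the mean of the entropies (a standard identity, stated here for context rather than quoted from the toolbox source):

\[JSD(P_1, P_2, \ldots, P_n) = H\!\left(\frac{1}{n}\sum_{i=1}^{n} P_i\right) - \frac{1}{n}\sum_{i=1}^{n} H(P_i)\]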

Parameters:
probs : ArrayLike, shape-(N, D)

A collection of n probability distributions. Each conveys a batch of N distributions over D categories.

weight : Optional[float]

An optional scaling factor that is applied to the returned loss.

Returns:
loss : tr.Tensor, shape-()

The scalar loss computed via the batch-mean.

Notes

The JSD loss is computed for each corresponding n-tuple of distributions, and the mean over the resulting batch of N values is returned.
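The identity above suggests a simple reference computation. The following is a minimal sketch, assuming PyTorch and equal weighting; _jsd_reference is a hypothetical helper written for illustration, not part of the toolbox API:

import torch as tr

def _jsd_reference(*probs, weight=None):
    # Hypothetical illustration only -- not the toolbox implementation.
    # Each element of `probs` is shape-(N, D): a batch of N distributions.
    ps = tr.stack([tr.as_tensor(p) for p in probs])  # shape-(n, N, D)
    mixture = ps.mean(dim=0)                         # shape-(N, D)
    # tr.special.entr(x) computes -x * log(x) elementwise, with entr(0) = 0
    h_mixture = tr.special.entr(mixture).sum(dim=-1)      # entropy of mixture, shape-(N,)
    h_mean = tr.special.entr(ps).sum(dim=-1).mean(dim=0)  # mean entropy, shape-(N,)
    loss = (h_mixture - h_mean).mean()  # batch-mean -> scalar, shape-()
    return loss if weight is None else weight * loss

On the distributions from the example below, this sketch reproduces tensor(0.4621), i.e. \(\frac{2}{3}\ln(2)\).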

References

[1] Jensen–Shannon divergence, Wikipedia: https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence

Examples

Let’s measure the divergence between three discrete distributions, each over two categories.

>>> from rai_toolbox.losses import jensen_shannon_divergence
>>> P1 = [[0.0, 1.0]]
>>> P2 = [[1.0, 0.0]]
>>> P3 = [[0.5, 0.5]]
>>> jensen_shannon_divergence(P1, P2, P3)
tensor(0.4621)

The divergence is symmetric: reordering the inputs does not change the result.

>>> jensen_shannon_divergence(P1, P3, P2)
tensor(0.4621)
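
The upper bound \(\ln(n)\) is attained by mutually disjoint point-mass distributions. Assuming weight simply scales the returned loss, as documented above, the following values follow directly from the bound and the example (outputs shown to four decimal places):

>>> jensen_shannon_divergence([[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]], [[0.0, 0.0, 1.0]])
tensor(1.0986)
>>> jensen_shannon_divergence(P1, P2, P3, weight=2.0)
tensor(0.9242)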