rai_toolbox.losses.jensen_shannon_divergence
- rai_toolbox.losses.jensen_shannon_divergence(*probs, weight=None)
Computes the Jensen-Shannon divergence [1] between n distributions:
\(\mathrm{JSD}(P_1, P_2, \ldots, P_n)\)
This loss is symmetric (its value does not depend on the order of the inputs) and is bounded: \(0 \leq \mathrm{JSD}(P_1, P_2, \ldots, P_n) \leq \ln(n)\).
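For reference, here is the standard equal-weight form of the generalized JSD. \(M\) is the average of the inputs and \(H\) is the Shannon entropy in nats. The assumption here is that the function weights every input by \(1/n\). That assumption is consistent with the \(\ln(n)\) bound and reproduces the output shown in the examples.

\[
\mathrm{JSD}(P_1, \ldots, P_n) = H(M) - \frac{1}{n}\sum_{i=1}^{n} H(P_i) = \frac{1}{n}\sum_{i=1}^{n} \mathrm{KL}(P_i \,\|\, M),
\qquad M = \frac{1}{n}\sum_{i=1}^{n} P_i,
\qquad H(P) = -\sum_{d=1}^{D} P(d) \ln P(d)
\]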
- Parameters:
- probs : ArrayLike, shape-(N, D)
A collection of n probability distributions. Each conveys a batch of N distributions over D categories.
- weight : Optional[float]
A scaling factor applied to the consistency loss. If None (the default), the loss is not scaled.
- Returns:
- loss : tr.Tensor, shape-()
The scalar loss computed via the batch-mean.
Notes
The JSD is computed for each of the N corresponding n-tuples of distributions (one row taken from each input), and the mean over the batch is returned.
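The following plain-Python sketch makes the per-tuple computation and the batch-mean concrete by computing the same quantity from scratch. It is not rai_toolbox's implementation: it returns a float rather than a tensor and computes no gradients. The helper names entropy and jsd_reference are hypothetical. The sketch assumes equal mixture weights 1/n and assumes that weight simply multiplies the batch-mean loss.

```python
import math
from typing import Optional, Sequence

Batch = Sequence[Sequence[float]]  # shape-(N, D): N distributions over D categories


def entropy(p: Sequence[float]) -> float:
    """Shannon entropy in nats; categories with zero probability contribute 0."""
    return -sum(x * math.log(x) for x in p if x > 0)


def jsd_reference(*probs: Batch, weight: Optional[float] = None) -> float:
    """Hypothetical plain-Python reference, not rai_toolbox's implementation."""
    n = len(probs)
    if n < 2:
        raise ValueError("expected at least two batches of distributions")
    batch_size = len(probs[0])
    if any(len(p) != batch_size for p in probs):
        raise ValueError("all inputs must share the same batch size N")

    total = 0.0
    # zip(*probs) yields one n-tuple of corresponding rows per batch index
    for rows in zip(*probs):
        mixture = [sum(col) / n for col in zip(*rows)]  # equal-weight mixture M
        # JSD of this n-tuple: H(M) minus the average entropy of the inputs
        total += entropy(mixture) - sum(entropy(r) for r in rows) / n

    loss = total / batch_size  # batch-mean over the N tuples
    if weight is not None:
        loss *= weight  # assumption: weight is a plain multiplicative factor
    return loss


P1 = [[0.0, 1.0]]
P2 = [[1.0, 0.0]]
P3 = [[0.5, 0.5]]
print(round(jsd_reference(P1, P2, P3), 4))  # 0.4621
print(round(jsd_reference(P1, P3, P2), 4))  # 0.4621
```

With n inputs that are one-hot on distinct categories, the mixture is uniform over n categories and every input has zero entropy. The sketch then returns exactly ln(n), the upper bound stated for this loss.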
References
Examples
Let’s measure the divergence between three discrete distributions of length two.
>>> from rai_toolbox.losses import jensen_shannon_divergence
>>> P1 = [[0.0, 1.0]]
>>> P2 = [[1.0, 0.0]]
>>> P3 = [[0.5, 0.5]]
>>> jensen_shannon_divergence(P1, P2, P3)
tensor(0.4621)
The divergence is symmetric.
>>> jensen_shannon_divergence(P1, P3, P2)
tensor(0.4621)
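The printed value can be checked by hand. Let \(M\) be the average of the three inputs and \(H\) the Shannon entropy in nats. Assuming the equal-weight form \(\mathrm{JSD} = H(M) - \frac{1}{n}\sum_i H(P_i)\), the calculation is:

\[
M = \tfrac{1}{3}(P_1 + P_2 + P_3) = (0.5,\ 0.5), \qquad H(M) = \ln 2, \qquad H(P_1) = H(P_2) = 0, \qquad H(P_3) = \ln 2
\]
\[
\mathrm{JSD}(P_1, P_2, P_3) = \ln 2 - \tfrac{1}{3}(0 + 0 + \ln 2) = \tfrac{2}{3}\ln 2 \approx 0.4621
\]

Permuting the inputs leaves both \(M\) and the average entropy unchanged, so reordering the arguments returns the same value.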