Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/equine/utils.py: 100%
163 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-29 04:12 +0000
1# Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
2# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
3# SPDX-License-Identifier: MIT
5from typing import Any, Union
7import icontract
8import torch
9from beartype import beartype
10from collections import OrderedDict
11from torchmetrics.classification import (
12 MulticlassAccuracy,
13 MulticlassCalibrationError,
14 MulticlassConfusionMatrix,
15 MulticlassF1Score,
16)
18from .equine import Equine
19from .equine_output import EquineOutput
@icontract.require(lambda y_hat, y_test: y_hat.size(dim=0) == y_test.size(dim=0))
@icontract.ensure(lambda result: result >= 0.0)
@beartype
def brier_score(y_hat: torch.Tensor, y_test: torch.Tensor) -> float:
    """
    Compute the Brier score for a multiclass problem:
    $$ \\frac{1}{N} \\sum_{i=1}^{N} \\sum_{j=1}^{M} (f_{ij} - o_{ij})^2 , $$
    where $f_{ij}$ is the predicted probability of class $j$ for inference sample $i$
    and $o_{ij}$ is the one-hot encoded ground truth label.

    Parameters
    ----------
    y_hat : torch.Tensor
        Predicted class probabilities, one row per sample and one column per class.
    y_test : torch.Tensor
        Integer class labels (ground truth); cast to ``long`` before one-hot encoding.

    Returns
    -------
    float
        Brier score: the mean over samples of the summed squared error per sample.
    """
    _, n_classes = y_hat.size()
    # One-hot targets get exactly as many columns as y_hat so the subtraction is elementwise.
    targets = torch.nn.functional.one_hot(y_test.long(), num_classes=n_classes)
    squared_error_per_sample = ((y_hat - targets) ** 2).sum(dim=1)
    return squared_error_per_sample.mean().item()
@icontract.require(lambda y_hat, y_test: y_hat.size(dim=0) == y_test.size(dim=0))
@icontract.ensure(lambda result: result <= 1.0)
@beartype
def brier_skill_score(y_hat: torch.Tensor, y_test: torch.Tensor) -> float:
    """
    Compute the Brier skill score as compared to randomly guessing:
    $$ 1 - BS(\\hat{y}) / BS(u) , $$
    where $u$ assigns probability $1/M$ to each of the $M$ classes.

    Parameters
    ----------
    y_hat : torch.Tensor
        Probabilities for each class, one row per sample and one column per class.
    y_test : torch.Tensor
        Integer class labels (ground truth).

    Returns
    -------
    float
        Brier skill score (1.0 is perfect, 0.0 matches uniform guessing).

    Raises
    ------
    ZeroDivisionError
        If ``y_hat`` has a single column: the uniform reference then has a
        Brier score of exactly zero.
    """
    (_, num_classes) = y_hat.size()
    # Uniform reference forecast. It is created on y_hat's device so that
    # brier_score does not mix CPU and accelerator tensors for GPU inputs.
    # Its dtype is the default float dtype, as with the previous torch.ones.
    random_guess = torch.full(y_hat.shape, 1.0 / num_classes, device=y_hat.device)
    bs0 = brier_score(random_guess, y_test)
    bs1 = brier_score(y_hat, y_test)
    bss = 1.0 - bs1 / bs0
    return bss
@icontract.require(lambda y_hat, y_test: y_hat.size(dim=0) == y_test.size(dim=0))
@icontract.ensure(lambda result: (0.0 <= result) and (result <= 1.0))
@beartype
def expected_calibration_error(
    y_hat: torch.Tensor, y_test: torch.Tensor, *, n_bins: int = 25
) -> float:
    """
    Compute the expected calibration error (ECE) for a multiclass problem.

    Parameters
    ----------
    y_hat : torch.Tensor
        Probabilities for each class, one row per sample and one column per class.
    y_test : torch.Tensor
        Class label indices (ground truth).
    n_bins : int, optional
        Number of confidence bins used by the calibration estimate (default 25).

    Returns
    -------
    float
        Expected calibration error (L1 norm over the confidence bins).
    """
    (_, num_classes) = y_hat.size()
    # The metric is an nn.Module. Move it to the inputs' device so its state
    # lives alongside the tensors it is updated with.
    metric = MulticlassCalibrationError(
        num_classes=num_classes, n_bins=n_bins, norm="l1"
    ).to(y_hat.device)
    ece = metric(y_hat, y_test).item()
    return ece
@icontract.require(
    lambda train_y, selected_labels: len(selected_labels) <= len(train_y)
)
@icontract.ensure(
    lambda result, selected_labels: set(result.keys()).issubset(set(selected_labels))
)
@beartype
def _get_shuffle_idxs_by_class(
    train_y: torch.Tensor, selected_labels: list
) -> dict[Any, torch.Tensor]:
    """
    Internal helper function to randomly select indices of example classes for a given
    set of labels.

    Parameters
    ----------
    train_y : torch.Tensor
        Label data (expected to be 1-D).
    selected_labels : list
        List of labels to gather indices for.

    Returns
    -------
    dict[Any, torch.Tensor]
        Ordered mapping from each selected label to a randomly permuted 1-D tensor
        of the positions in ``train_y`` holding that label (empty if none match).
    """
    shuffled_idxs_by_class = OrderedDict()
    for label in selected_labels:
        # torch.where(...)[0] is always 1-D, even when exactly one example
        # matches. argwhere(...).squeeze() collapses that case to a 0-d tensor,
        # and then label_idxs.shape[0] raises IndexError.
        label_idxs = torch.where(train_y == label)[0]
        shuffled_idxs_by_class[label] = label_idxs[torch.randperm(label_idxs.shape[0])]

    return shuffled_idxs_by_class
@icontract.require(lambda train_x, train_y: len(train_x) <= len(train_y))
@icontract.require(
    lambda selected_labels, train_x: (0 < len(selected_labels))
    & (len(selected_labels) < len(train_x))
)
@icontract.require(
    lambda support_size, train_x: (0 < support_size) & (support_size < len(train_x))
)
@icontract.require(
    lambda support_size, selected_labels, train_x: support_size * len(selected_labels)
    <= len(train_x)
)
@icontract.require(
    lambda selected_labels, shuffled_indexes: (
        (len(shuffled_indexes.keys()) == len(selected_labels))
        if shuffled_indexes is not None
        else True
    )
)
@icontract.ensure(
    lambda result, selected_labels: len(result.keys()) == len(selected_labels)
)
@beartype
def generate_support(
    train_x: torch.Tensor,
    train_y: torch.Tensor,
    support_size: int,
    selected_labels: list[Any],
    shuffled_indexes: Union[None, dict[Any, torch.Tensor]] = None,
) -> OrderedDict[int, torch.Tensor]:
    """
    Randomly select up to `support_size` examples for each class in
    `selected_labels` from `train_x` (labelled by `train_y`).

    Parameters
    ----------
    train_x : torch.Tensor
        Input training data.
    train_y : torch.Tensor
        Corresponding classification labels.
    support_size : int
        Number of support examples for each class.
    selected_labels : list
        Selected class labels to generate examples from.
    shuffled_indexes: Union[None, dict[Any, torch.Tensor]], optional
        Precomputed per-label shuffled indices. When given, they are used as-is
        and the per-class count check is skipped.

    Returns
    -------
    OrderedDict[int, torch.Tensor]
        Ordered dictionary keyed by ``int(label)``, in ``selected_labels`` order,
        holding the first ``support_size`` shuffled examples of that label.

    Raises
    ------
    ValueError
        If indices are computed here and a selected class has fewer than
        ``support_size`` examples.
    AssertionError
        If the indices for a label are empty or point at other labels
        (only when assertions are enabled).
    """
    if shuffled_indexes is not None:
        shuffled_idxs = shuffled_indexes
    else:
        labels, counts = torch.unique(train_y, return_counts=True)
        for lbl, n_examples in zip(labels, counts):
            if lbl in selected_labels and n_examples < support_size:
                raise ValueError(f"Not enough support examples in class {lbl}")
        shuffled_idxs = _get_shuffle_idxs_by_class(train_y, selected_labels)

    support = OrderedDict[int, torch.Tensor]()
    for lbl in selected_labels:
        idxs = shuffled_idxs[lbl]
        class_x = train_x[idxs]
        # Every index for this label must actually carry this label.
        assert torch.unique(train_y[idxs]).tolist() == [
            lbl
        ], "Not enough support for label " + str(lbl)
        support[int(lbl)] = class_x[:support_size]

    return support
@icontract.require(lambda train_x: len(train_x.shape) >= 2)
@icontract.require(lambda train_y: len(train_y.shape) == 1)
@icontract.require(lambda support_size: support_size > 1)
@icontract.require(lambda way: way > 0)
@icontract.require(lambda episode_size: episode_size > 0)
@icontract.ensure(lambda result: len(result) == 3)
@icontract.ensure(lambda result: result[1].shape[0] == result[2].shape[0])
@icontract.ensure(lambda way, result: len(result[0]) == way)
@icontract.ensure(
    lambda support_size, result: all(
        len(support) == support_size for support in result[0].values()
    )
)
@beartype
def generate_episode(
    train_x: torch.Tensor,
    train_y: torch.Tensor,
    support_size: int,
    way: int,
    episode_size: int,
) -> tuple[OrderedDict[int, torch.Tensor], torch.Tensor, torch.Tensor]:
    """
    Generate a single episode of data for a few-shot learning task.

    Parameters
    ----------
    train_x : torch.Tensor
        Input training data.
    train_y : torch.Tensor
        Corresponding classification labels (1-D).
    support_size : int
        Number of support examples for each class.
    way : int
        Number of classes in the episode.
    episode_size : int
        Total number of query examples in the episode; each class contributes
        at most ``episode_size // way`` of them.

    Returns
    -------
    tuple[OrderedDict[int, torch.Tensor], torch.Tensor, torch.Tensor]
        Support examples keyed by episode label ``0..way-1``, query examples,
        and the 1-D ``long`` query labels (also re-indexed to ``0..way-1``).

    Raises
    ------
    ValueError
        If ``way`` exceeds the number of distinct labels, or a selected class has
        fewer than ``support_size`` examples.
    AssertionError
        If a selected class has no examples left for the query set after its
        support examples are taken (only when assertions are enabled).
    """
    labels, counts = torch.unique(train_y, return_counts=True)
    if way > len(labels):
        raise ValueError(
            f"The way (#classes in each episode), {way}, must be <= number of labels, {len(labels)}"
        )

    # Sorted so that the labels are in the same order every time.
    selected_labels = sorted(labels[torch.randperm(labels.shape[0])][:way].tolist())

    for label, count in zip(labels, counts):
        if (label in selected_labels) and (count < support_size):
            raise ValueError(f"Not enough support examples in class {label}")
    shuffled_idxs = _get_shuffle_idxs_by_class(train_y, selected_labels)

    # The same shuffle is shared with generate_support: the first `support_size`
    # shuffled examples form the support set, and the query set is drawn after them.
    support = generate_support(
        train_x, train_y, support_size, selected_labels, shuffled_idxs
    )

    examples_per_task = episode_size // way

    episode_data_list = []
    episode_label_list = []
    episode_support = OrderedDict()
    for episode_label, label in enumerate(selected_labels):
        class_idxs = shuffled_idxs[label]
        shuffled_x = train_x[class_idxs]
        # Episode labels are sequential (0..way-1) rather than the original labels.
        shuffled_y = torch.full((len(class_idxs),), episode_label, dtype=torch.long)

        num_remaining_examples = shuffled_x.shape[0] - support_size
        assert num_remaining_examples > 0, (
            "Cannot have "
            + str(num_remaining_examples)
            + " left with support_size "
            + str(support_size)
            + " and shape "
            + str(shuffled_x.shape)
            + " from train_x shaped "
            + str(train_x.shape)
        )
        episode_end_idx = support_size + min(num_remaining_examples, examples_per_task)

        episode_data_list.append(shuffled_x[support_size:episode_end_idx])
        episode_label_list.append(shuffled_y[support_size:episode_end_idx])
        episode_support[episode_label] = support[label]

    episode_x = torch.concat(episode_data_list)
    # The concatenation is already 1-D long. It is not squeezed: squeezing would
    # turn a single query label into a 0-d tensor without a shape[0].
    episode_y = torch.concat(episode_label_list)

    return episode_support, episode_x, episode_y
@icontract.require(
    lambda eq_preds, true_y: eq_preds.classes.size(dim=0) == true_y.size(dim=0)
)
@beartype
def generate_model_metrics(
    eq_preds: EquineOutput, true_y: torch.Tensor
) -> dict[str, Any]:
    """
    Generate various metrics for evaluating a model's performance.

    Parameters
    ----------
    eq_preds : EquineOutput
        Model predictions; ``eq_preds.classes`` holds one row of class
        probabilities per sample.
    true_y : torch.Tensor
        True class labels.

    Returns
    -------
    dict[str, Any]
        Dictionary with keys ``accuracy`` and ``microF1Score`` (values as
        returned by torchmetrics), ``confusionMatrix`` (nested list; rows are
        true labels and columns are predicted labels), ``brierScore``,
        ``brierSkillScore`` and ``expectedCalibrationError`` (floats).
    """
    probs = eq_preds.classes
    num_classes = probs.shape[1]
    pred_y = torch.argmax(probs, dim=1)

    # torchmetrics metrics are nn.Modules. Put their state on the same device
    # as the predictions.
    device = probs.device
    accuracy = MulticlassAccuracy(num_classes=num_classes).to(device)
    f1_score = MulticlassF1Score(num_classes=num_classes, average="micro").to(device)
    confusion_matrix = MulticlassConfusionMatrix(num_classes=num_classes).to(device)

    # torchmetrics expects (preds, target). Swapping the two would turn the
    # default (macro) accuracy into macro precision and transpose the
    # confusion matrix.
    metrics = {
        "accuracy": accuracy(pred_y, true_y),
        "microF1Score": f1_score(pred_y, true_y),
        "confusionMatrix": confusion_matrix(pred_y, true_y).tolist(),
        "brierScore": brier_score(probs, true_y),
        "brierSkillScore": brier_skill_score(probs, true_y),
        "expectedCalibrationError": expected_calibration_error(probs, true_y),
    }
    return metrics
@icontract.require(lambda Y: len(Y.shape) == 1)
@icontract.ensure(
    lambda result: all("label" in d and "numExamples" in d for d in result)
)
@icontract.ensure(lambda result: all(d["numExamples"] >= 0 for d in result))
@beartype
def get_num_examples_per_label(Y: torch.Tensor) -> list[dict[str, Any]]:
    """
    Count how many times each distinct label occurs in a 1-D label tensor.

    Parameters
    ----------
    Y : torch.Tensor
        Tensor of class labels.

    Returns
    -------
    list[dict[str, Any]]
        One ``{"label": ..., "numExamples": ...}`` entry per distinct label, in
        the order produced by ``Tensor.unique``, holding Python scalars.
    """
    unique_labels, label_counts = Y.unique(return_counts=True)
    return [
        {"label": lbl.item(), "numExamples": cnt.item()}
        for lbl, cnt in zip(unique_labels, label_counts)
    ]
@icontract.require(lambda train_y: train_y.shape[0] > 0)
@beartype
def generate_train_summary(
    model: Equine, train_y: torch.Tensor, date_trained: str
) -> dict[str, Any]:
    """
    Build a summary dictionary describing the data a model was trained on.

    Parameters
    ----------
    model : Equine
        Model object; only its class name is recorded.
    train_y : torch.Tensor
        Training labels.
    date_trained : str
        Date of training, stored verbatim.

    Returns
    -------
    dict[str, Any]
        Dictionary with ``numTrainExamples`` (per-label counts),
        ``dateTrained`` and ``modelType``.
    """
    per_label_counts = get_num_examples_per_label(train_y)
    model_type = model.__class__.__name__
    return {
        "numTrainExamples": per_label_counts,
        "dateTrained": date_trained,
        "modelType": model_type,
    }
@icontract.require(
    lambda eq_preds, test_y: test_y.shape[0] == eq_preds.classes.shape[0]
)
@beartype
def generate_model_summary(
    model: Equine,
    eq_preds: EquineOutput,
    test_y: torch.Tensor,
) -> dict[str, Any]:
    """
    Combine evaluation metrics, test-set label counts and the model's stored
    training summary into a single dictionary.

    Parameters
    ----------
    model : Equine
        Model object; its ``train_summary`` mapping is merged in last.
    eq_preds : EquineOutput
        Model predictions.
    test_y : torch.Tensor
        True class labels.

    Returns
    -------
    dict[str, Any]
        Dictionary containing the model summary. On key collisions, entries
        from ``model.train_summary`` take precedence.
    """
    metrics = generate_model_metrics(eq_preds, test_y)
    test_counts = get_num_examples_per_label(test_y)
    summary = {**metrics, "numTestExamples": test_counts}
    summary.update(model.train_summary)
    return summary
@icontract.require(lambda cov: cov.shape[-2] == cov.shape[-1])
def mahalanobis_distance_nosq(x: torch.Tensor, cov: torch.Tensor) -> torch.Tensor:
    """
    Compute the squared Mahalanobis distance $x^T C^{-1} x$ (i.e. without the
    square root), assuming each covariance matrix $C$ is symmetric positive definite.

    Parameters
    ----------
    x : torch.Tensor
        Vectors to compute distances for. NOTE(review): the result is summed
        over dim 1 of ``prod @ x``; presumably x is ``(dim,)`` or
        ``(num_classes, dim, n)``. Confirm with callers.
    cov : torch.Tensor
        Stack of square covariance matrices; the first dimension is the number
        of classes (the code transposes dims 1 and 2, so at least 3-D).

    Returns
    -------
    torch.Tensor
        ``sum((S^{-1/2} U^T x)^2, dim=1)``, i.e. the squared Mahalanobis
        distances.
    """
    # For SPD C = U S U^T we have C^{-1} = (S^{-1/2} U^T)^T (S^{-1/2} U^T),
    # so x^T C^{-1} x = ||S^{-1/2} U^T x||^2.
    U, S, _ = torch.linalg.svd(cov)
    # Batched diagonal matrices of 1/sqrt(singular values), one per class.
    S_inv_sqrt = torch.diag_embed(torch.sqrt(1.0 / S))
    prod = torch.matmul(S_inv_sqrt, torch.transpose(U, 1, 2))
    dist = torch.sum(torch.square(torch.matmul(prod, x)), dim=1)
    return dist
@icontract.require(
    lambda X, Y: X.shape[0] == Y.shape[0],
    "X and Y must have the same number of samples.",
)
@icontract.require(
    lambda test_size: 0.0 < test_size < 1.0, "test_size must be between 0 and 1."
)
@icontract.ensure(
    lambda result: len(result) == 4, "Function must return four elements."
)
@icontract.ensure(
    lambda X, result: result[0].shape[0] + result[1].shape[0] == X.shape[0],
    "Total samples must be preserved.",
)
@icontract.ensure(
    lambda Y, result: result[2].shape[0] + result[3].shape[0] == Y.shape[0],
    "Total labels must be preserved.",
)
@icontract.ensure(
    lambda result: result[0].shape[0] == result[2].shape[0],
    "Train features and labels must match in size.",
)
@icontract.ensure(
    lambda result: result[1].shape[0] == result[3].shape[0],
    "Test features and labels must match in size.",
)
@beartype
def stratified_train_test_split(
    X: torch.Tensor, Y: torch.Tensor, test_size: float
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    A pytorch-ified version of sklearn's train_test_split with data stratification.

    For each distinct label, ``round(count * test_size)`` randomly chosen
    samples go to the test split and the remainder go to the train split.

    Parameters
    ----------
    X : torch.Tensor
        Input features tensor of shape (n_samples, n_features).
    Y : torch.Tensor
        Labels tensor of shape (n_samples,).
    test_size : float
        Proportion of the dataset to include in the test split (between 0.0 and 1.0).

    Returns
    -------
    train_x : torch.Tensor
        Training set features.
    calib_x : torch.Tensor
        Test set features.
    train_y : torch.Tensor
        Training set labels.
    calib_y : torch.Tensor
        Test set labels.
    """
    classes, class_sizes = torch.unique(Y, return_counts=True)
    n_test_per_class = (class_sizes.float() * test_size).round().long()

    train_parts: list[torch.Tensor] = []
    test_parts: list[torch.Tensor] = []
    for cls, n_test in zip(classes, n_test_per_class):
        members = torch.where(Y == cls)[0]
        members = members[torch.randperm(len(members))]
        train_parts.append(members[n_test:])
        test_parts.append(members[:n_test])

    train_idx = torch.cat(train_parts)
    test_idx = torch.cat(test_parts)
    return X[train_idx], X[test_idx], Y[train_idx], Y[test_idx]