Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/equine/utils.py: 100%

163 statements  

coverage.py v7.9.1, created at 2025-06-29 04:12 +0000

# Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT

from typing import Any, Union

import icontract
import torch
from beartype import beartype
from collections import OrderedDict
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassCalibrationError,
    MulticlassConfusionMatrix,
    MulticlassF1Score,
)

from .equine import Equine
from .equine_output import EquineOutput


@icontract.require(lambda y_hat, y_test: y_hat.size(dim=0) == y_test.size(dim=0))
@icontract.ensure(lambda result: result >= 0.0)
@beartype
def brier_score(y_hat: torch.Tensor, y_test: torch.Tensor) -> float:
    """
    Compute the Brier score for a multiclass problem:
    $$ \\frac{1}{N} \\sum_{i=1}^{N} \\sum_{j=1}^{M} (f_{ij} - o_{ij})^2 , $$
    where $f_{ij}$ is the predicted probability of class $j$ for inference sample $i$
    and $o_{ij}$ is the one-hot encoded ground truth label.

    Parameters
    ----------
    y_hat : torch.Tensor
        Probabilities for each class.
    y_test : torch.Tensor
        Integer class labels (ground truth).

    Returns
    -------
    float
        Brier score.
    """
    (_, num_classes) = y_hat.size()
    one_hot_y_test = torch.nn.functional.one_hot(y_test.long(), num_classes=num_classes)
    bs = torch.mean(torch.sum((y_hat - one_hot_y_test) ** 2, dim=1)).item()
    return bs
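# Usage sketch (illustrative values, not part of the original module):
#
#   probs = torch.tensor([[0.7, 0.2, 0.1],
#                         [0.1, 0.8, 0.1]])
#   labels = torch.tensor([0, 1])
#   brier_score(probs, labels)  # -> 0.10, the mean of 0.14 and 0.06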


@icontract.require(lambda y_hat, y_test: y_hat.size(dim=0) == y_test.size(dim=0))
@icontract.ensure(lambda result: result <= 1.0)
@beartype
def brier_skill_score(y_hat: torch.Tensor, y_test: torch.Tensor) -> float:
    """
    Compute the Brier skill score relative to a uniform random-guess baseline.

    Parameters
    ----------
    y_hat : torch.Tensor
        Probabilities for each class.
    y_test : torch.Tensor
        Integer class labels (ground truth).

    Returns
    -------
    float
        Brier skill score.
    """
    (_, num_classes) = y_hat.size()
    random_guess = (1.0 / num_classes) * torch.ones(y_hat.size())
    bs0 = brier_score(random_guess, y_test)
    bs1 = brier_score(y_hat, y_test)
    bss = 1.0 - bs1 / bs0
    return bss
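# Usage sketch (illustrative): with 3 classes, the uniform baseline scores
# (1 - 1/3)^2 + 2 * (1/3)^2 = 2/3 per sample, so for the probs/labels above:
#
#   brier_skill_score(probs, labels)  # -> 1 - 0.10 / (2/3) = 0.85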


@icontract.require(lambda y_hat, y_test: y_hat.size(dim=0) == y_test.size(dim=0))
@icontract.ensure(lambda result: (0.0 <= result) and (result <= 1.0))
@beartype
def expected_calibration_error(y_hat: torch.Tensor, y_test: torch.Tensor) -> float:
    """
    Compute the expected calibration error (ECE) for a multiclass problem.

    Parameters
    ----------
    y_hat : torch.Tensor
        Probabilities for each class.
    y_test : torch.Tensor
        Class label indices (ground truth).

    Returns
    -------
    float
        Expected calibration error.
    """
    (_, num_classes) = y_hat.size()
    metric = MulticlassCalibrationError(num_classes=num_classes, n_bins=25, norm="l1")
    ece = metric(y_hat, y_test).item()
    return ece
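# Usage sketch (illustrative): MulticlassCalibrationError bins predictions by
# confidence (25 bins here) and takes the frequency-weighted average of
# |accuracy - confidence| over the bins ("l1" norm).
#
#   expected_calibration_error(probs, labels)  # 0.0 means perfectly calibrated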


@icontract.require(
    lambda train_y, selected_labels: len(selected_labels) <= len(train_y)
)
@icontract.ensure(
    lambda result, selected_labels: set(result.keys()).issubset(set(selected_labels))
)
@beartype
def _get_shuffle_idxs_by_class(
    train_y: torch.Tensor, selected_labels: list
) -> dict[Any, torch.Tensor]:
    """
    Internal helper function to randomly shuffle the indices of examples for a
    given set of labels.

    Parameters
    ----------
    train_y : torch.Tensor
        Label data.
    selected_labels : list
        List of unique labels found in the label data.

    Returns
    -------
    dict[Any, torch.Tensor]
        Dictionary mapping each label to a shuffled tensor of its indices.
    """
    shuffled_idxs_by_class = OrderedDict()
    for label in selected_labels:
        label_idxs = torch.argwhere(train_y == label).squeeze()
        shuffled_idxs_by_class[label] = label_idxs[torch.randperm(label_idxs.shape[0])]

    return shuffled_idxs_by_class
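# Behavior sketch (illustrative): for train_y = torch.tensor([0, 1, 0, 1]) and
# selected_labels = [0, 1], the result maps 0 to a permutation of tensor([0, 2])
# and 1 to a permutation of tensor([1, 3]).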


@icontract.require(lambda train_x, train_y: len(train_x) <= len(train_y))
@icontract.require(
    lambda selected_labels, train_x: (0 < len(selected_labels))
    & (len(selected_labels) < len(train_x))
)
@icontract.require(
    lambda support_size, train_x: (0 < support_size) & (support_size < len(train_x))
)
@icontract.require(
    lambda support_size, selected_labels, train_x: support_size * len(selected_labels)
    <= len(train_x)
)
@icontract.require(
    lambda selected_labels, shuffled_indexes: (
        (len(shuffled_indexes.keys()) == len(selected_labels))
        if shuffled_indexes is not None
        else True
    )
)
@icontract.ensure(
    lambda result, selected_labels: len(result.keys()) == len(selected_labels)
)
@beartype
def generate_support(
    train_x: torch.Tensor,
    train_y: torch.Tensor,
    support_size: int,
    selected_labels: list[Any],
    shuffled_indexes: Union[None, dict[Any, torch.Tensor]] = None,
) -> OrderedDict[int, torch.Tensor]:
    """
    Randomly select `support_size` examples of each class in `selected_labels`
    from the examples in `train_x` (with corresponding labels in `train_y`)
    and return them as a dictionary.

    Parameters
    ----------
    train_x : torch.Tensor
        Input training data.
    train_y : torch.Tensor
        Corresponding classification labels.
    support_size : int
        Number of support examples for each class.
    selected_labels : list[Any]
        Selected class labels to generate examples from.
    shuffled_indexes : Union[None, dict[Any, torch.Tensor]], optional
        Precomputed shuffled indices to use, if available.

    Returns
    -------
    OrderedDict[int, torch.Tensor]
        Ordered dictionary of class labels with corresponding support examples.
    """
    labels, counts = torch.unique(train_y, return_counts=True)
    if shuffled_indexes is None:
        for label, count in list(zip(labels, counts)):
            if (label in selected_labels) and (count < support_size):
                raise ValueError(f"Not enough support examples in class {label}")
        shuffled_idxs = _get_shuffle_idxs_by_class(train_y, selected_labels)
    else:
        shuffled_idxs = shuffled_indexes

    support = OrderedDict[int, torch.Tensor]()
    for label in selected_labels:
        shuffled_x = train_x[shuffled_idxs[label]]

        assert torch.unique(train_y[shuffled_idxs[label]]).tolist() == [
            label
        ], "Not enough support for label " + str(label)
        selected_support = shuffled_x[:support_size]
        support[int(label)] = selected_support

    return support
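# Usage sketch (illustrative toy data): draw 2 support examples per class.
#
#   train_x = torch.randn(10, 4)
#   train_y = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
#   support = generate_support(train_x, train_y, support_size=2,
#                              selected_labels=[0, 1])
#   support[0].shape  # -> torch.Size([2, 4])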


@icontract.require(lambda train_x: len(train_x.shape) >= 2)
@icontract.require(lambda train_y: len(train_y.shape) == 1)
@icontract.require(lambda support_size: support_size > 1)
@icontract.require(lambda way: way > 0)
@icontract.require(lambda episode_size: episode_size > 0)
@icontract.ensure(lambda result: len(result) == 3)
@icontract.ensure(lambda result: result[1].shape[0] == result[2].shape[0])
@icontract.ensure(lambda way, result: len(result[0]) == way)
@icontract.ensure(
    lambda support_size, result: all(
        len(support) == support_size for support in result[0].values()
    )
)
@beartype
def generate_episode(
    train_x: torch.Tensor,
    train_y: torch.Tensor,
    support_size: int,
    way: int,
    episode_size: int,
) -> tuple[OrderedDict[int, torch.Tensor], torch.Tensor, torch.Tensor]:
    """
    Generate a single episode of data for a few-shot learning task.

    Parameters
    ----------
    train_x : torch.Tensor
        Input training data.
    train_y : torch.Tensor
        Corresponding classification labels.
    support_size : int
        Number of support examples for each class.
    way : int
        Number of classes in the episode.
    episode_size : int
        Total number of examples in the episode.

    Returns
    -------
    tuple[OrderedDict[int, torch.Tensor], torch.Tensor, torch.Tensor]
        Tuple of support examples, query examples, and query labels.
    """
    labels, counts = torch.unique(train_y, return_counts=True)
    if way > len(labels):
        raise ValueError(
            f"The way (#classes in each episode), {way}, must be <= number of labels, {len(labels)}"
        )

    selected_labels = sorted(
        labels[torch.randperm(labels.shape[0])][:way].tolist()
    )  # need to be in the same order every time

    for label, count in list(zip(labels, counts)):
        if (label in selected_labels) and (count < support_size):
            raise ValueError(f"Not enough support examples in class {label}")
    shuffled_idxs = _get_shuffle_idxs_by_class(train_y, selected_labels)

    support = generate_support(
        train_x, train_y, support_size, selected_labels, shuffled_idxs
    )

    examples_per_task = episode_size // way

    episode_data_list = []
    episode_label_list = []
    episode_support = OrderedDict()
    for episode_label, label in enumerate(selected_labels):
        shuffled_x = train_x[shuffled_idxs[label]]
        shuffled_y = torch.Tensor(
            [episode_label] * len(shuffled_idxs[label])
        )  # need sequential labels for the episode

        num_remaining_examples = shuffled_x.shape[0] - support_size
        assert num_remaining_examples > 0, (
            "Cannot have "
            + str(num_remaining_examples)
            + " examples left with support_size "
            + str(support_size)
            + " and shape "
            + str(shuffled_x.shape)
            + " from train_x shaped "
            + str(train_x.shape)
        )
        episode_end_idx = support_size + min(num_remaining_examples, examples_per_task)

        episode_data_list.append(shuffled_x[support_size:episode_end_idx])
        episode_label_list.append(shuffled_y[support_size:episode_end_idx])
        episode_support[episode_label] = support[label]

    episode_x = torch.concat(episode_data_list)
    episode_y = torch.concat(episode_label_list)

    return episode_support, episode_x, episode_y.squeeze().to(torch.long)
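# Usage sketch (illustrative, reusing the toy train_x/train_y above): a 2-way
# episode with 2 support examples per class and up to 6 query examples.
#
#   support, query_x, query_y = generate_episode(
#       train_x, train_y, support_size=2, way=2, episode_size=6
#   )
#   len(support)      # -> 2 (episode classes are relabeled 0..way-1)
#   query_y.unique()  # -> tensor([0, 1])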


@icontract.require(
    lambda eq_preds, true_y: eq_preds.classes.size(dim=0) == true_y.size(dim=0)
)
@beartype
def generate_model_metrics(
    eq_preds: EquineOutput, true_y: torch.Tensor
) -> dict[str, Any]:
    """
    Generate various metrics for evaluating a model's performance.

    Parameters
    ----------
    eq_preds : EquineOutput
        Model predictions.
    true_y : torch.Tensor
        True class labels.

    Returns
    -------
    dict[str, Any]
        Dictionary of model metrics.
    """
    pred_y = torch.argmax(eq_preds.classes, dim=1)
    num_classes = eq_preds.classes.shape[1]
    accuracy = MulticlassAccuracy(num_classes=num_classes)
    f1_score = MulticlassF1Score(num_classes=num_classes, average="micro")
    confusion_matrix = MulticlassConfusionMatrix(num_classes=num_classes)
    # torchmetrics expects metric(preds, target)
    metrics = {
        "accuracy": accuracy(pred_y, true_y),
        "microF1Score": f1_score(pred_y, true_y),
        "confusionMatrix": confusion_matrix(pred_y, true_y).tolist(),
        "brierScore": brier_score(eq_preds.classes, true_y),
        "brierSkillScore": brier_skill_score(eq_preds.classes, true_y),
        "expectedCalibrationError": expected_calibration_error(
            eq_preds.classes, true_y
        ),
    }
    return metrics
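# Usage sketch (assumes eq_preds.classes holds per-class probabilities, as the
# metric functions above require):
#
#   metrics = generate_model_metrics(eq_preds, test_y)
#   metrics["accuracy"], metrics["brierScore"]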


@icontract.require(lambda Y: len(Y.shape) == 1)
@icontract.ensure(
    lambda result: all("label" in d and "numExamples" in d for d in result)
)
@icontract.ensure(lambda result: all(d["numExamples"] >= 0 for d in result))
@beartype
def get_num_examples_per_label(Y: torch.Tensor) -> list[dict[str, Any]]:
    """
    Get the number of examples per label in the given tensor.

    Parameters
    ----------
    Y : torch.Tensor
        Tensor of class labels.

    Returns
    -------
    list[dict[str, Any]]
        List of dictionaries containing each label and its number of examples.
    """
    tensor_labels, tensor_counts = Y.unique(return_counts=True)

    examples_per_label = []
    for i, label in enumerate(tensor_labels):
        examples_per_label.append(
            {"label": label.item(), "numExamples": tensor_counts[i].item()}
        )

    return examples_per_label
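# Usage sketch (illustrative):
#
#   get_num_examples_per_label(torch.tensor([0, 0, 1]))
#   # -> [{"label": 0, "numExamples": 2}, {"label": 1, "numExamples": 1}]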


@icontract.require(lambda train_y: train_y.shape[0] > 0)
@beartype
def generate_train_summary(
    model: Equine, train_y: torch.Tensor, date_trained: str
) -> dict[str, Any]:
    """
    Generate a summary of the training data.

    Parameters
    ----------
    model : Equine
        Model object.
    train_y : torch.Tensor
        Training labels.
    date_trained : str
        Date of training.

    Returns
    -------
    dict[str, Any]
        Dictionary containing training summary.
    """
    train_summary = {
        "numTrainExamples": get_num_examples_per_label(train_y),
        "dateTrained": date_trained,
        "modelType": model.__class__.__name__,
    }
    return train_summary


@icontract.require(
    lambda eq_preds, test_y: test_y.shape[0] == eq_preds.classes.shape[0]
)
@beartype
def generate_model_summary(
    model: Equine,
    eq_preds: EquineOutput,
    test_y: torch.Tensor,
) -> dict[str, Any]:
    """
    Generate a summary of the model's performance.

    Parameters
    ----------
    model : Equine
        Model object.
    eq_preds : EquineOutput
        Model predictions.
    test_y : torch.Tensor
        True class labels.

    Returns
    -------
    dict[str, Any]
        Dictionary containing model summary.
    """
    summary = generate_model_metrics(eq_preds, test_y)
    summary["numTestExamples"] = get_num_examples_per_label(test_y)
    summary.update(model.train_summary)  # union of train_summary and generated metrics

    return summary


@icontract.require(lambda cov: cov.shape[-2] == cov.shape[-1])
def mahalanobis_distance_nosq(x: torch.Tensor, cov: torch.Tensor) -> torch.Tensor:
    """
    Compute the squared Mahalanobis distance $x^T C^{-1} x$ (no square root),
    assuming `cov` is symmetric positive definite.

    Parameters
    ----------
    x : torch.Tensor
        Vectors to compute distances for.
    cov : torch.Tensor
        Covariance matrix; the first dimension is assumed to be the number of classes.

    Returns
    -------
    torch.Tensor
        Squared Mahalanobis distances.
    """
    # For symmetric PD cov = U S U^T, cov^{-1} = U S^{-1} U^T, so
    # x^T cov^{-1} x = || S^{-1/2} U^T x ||^2.
    U, S, _ = torch.linalg.svd(cov)
    S_inv_sqrt = torch.stack(
        [torch.diag(torch.sqrt(1.0 / S[i])) for i in range(S.shape[0])], dim=0
    )
    prod = torch.matmul(S_inv_sqrt, torch.transpose(U, 1, 2))
    dist = torch.sum(torch.square(torch.matmul(prod, x)), dim=1)
    return dist
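# Shape sketch (illustrative): with cov shaped (C, d, d) and x shaped (C, d, N),
# the result is (C, N): one squared distance per class for each of N vectors.
#
#   cov = torch.eye(3).repeat(2, 1, 1)  # identity -> plain squared L2 norm
#   x = torch.ones(2, 3, 5)
#   mahalanobis_distance_nosq(x, cov)   # -> tensor of 3.0s with shape (2, 5)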


@icontract.require(
    lambda X, Y: X.shape[0] == Y.shape[0],
    "X and Y must have the same number of samples.",
)
@icontract.require(
    lambda test_size: 0.0 < test_size < 1.0, "test_size must be between 0 and 1."
)
@icontract.ensure(
    lambda result: len(result) == 4, "Function must return four elements."
)
@icontract.ensure(
    lambda X, result: result[0].shape[0] + result[1].shape[0] == X.shape[0],
    "Total samples must be preserved.",
)
@icontract.ensure(
    lambda Y, result: result[2].shape[0] + result[3].shape[0] == Y.shape[0],
    "Total labels must be preserved.",
)
@icontract.ensure(
    lambda result: result[0].shape[0] == result[2].shape[0],
    "Train features and labels must match in size.",
)
@icontract.ensure(
    lambda result: result[1].shape[0] == result[3].shape[0],
    "Test features and labels must match in size.",
)
@beartype
def stratified_train_test_split(
    X: torch.Tensor, Y: torch.Tensor, test_size: float
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    A PyTorch version of sklearn's train_test_split with data stratification.

    Parameters
    ----------
    X : torch.Tensor
        Input features tensor of shape (n_samples, n_features).
    Y : torch.Tensor
        Labels tensor of shape (n_samples,).
    test_size : float
        Proportion of the dataset to include in the test split (between 0.0 and 1.0).

    Returns
    -------
    train_x : torch.Tensor
        Training set features.
    calib_x : torch.Tensor
        Test (calibration) set features.
    train_y : torch.Tensor
        Training set labels.
    calib_y : torch.Tensor
        Test (calibration) set labels.
    """
    unique_classes, class_counts = torch.unique(Y, return_counts=True)
    test_counts = (class_counts.float() * test_size).round().long()
    train_indices = []
    test_indices = []

    # Split each class independently so that class proportions are preserved.
    for cls, test_count in zip(unique_classes, test_counts):
        cls_indices = torch.where(Y == cls)[0]
        cls_indices = cls_indices[torch.randperm(len(cls_indices))]
        test_idx = cls_indices[:test_count]
        train_idx = cls_indices[test_count:]
        train_indices.append(train_idx)
        test_indices.append(test_idx)

    train_indices = torch.cat(train_indices)
    test_indices = torch.cat(test_indices)

    train_x = X[train_indices]
    train_y = Y[train_indices]
    calib_x = X[test_indices]
    calib_y = Y[test_indices]

    return train_x, calib_x, train_y, calib_y
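# Usage sketch (illustrative): a stratified 75/25 split of a toy dataset.
#
#   X = torch.randn(8, 2)
#   Y = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
#   train_x, calib_x, train_y, calib_y = stratified_train_test_split(X, Y, 0.25)
#   train_y.bincount()   # -> tensor([3, 3])
#   calib_y.bincount()   # -> tensor([1, 1])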