Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/equine/equine_protonet.py: 96%

298 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-29 04:12 +0000

1# Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY 

2# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014). 

3# SPDX-License-Identifier: MIT 

4from __future__ import annotations 

5 

6from typing import Any, Optional 

7 

8import icontract 

9import io 

10import numpy as np 

11import torch 

12import warnings 

13from beartype import beartype 

14from collections import OrderedDict 

15from collections.abc import Callable 

16from datetime import datetime 

17from enum import Enum 

18from scipy.stats import gaussian_kde 

19from torch.utils.data import TensorDataset 

20from tqdm import tqdm 

21 

22from .equine import Equine, EquineOutput 

23from .utils import ( 

24 generate_episode, 

25 generate_support, 

26 generate_train_summary, 

27 mahalanobis_distance_nosq, 

28 stratified_train_test_split, 

29) 

30 

31 

32##################################### 

class CovType(Enum):
    """
    Covariance structures that EQUINE can use when measuring distances to prototypes.

    Members
    -------
    UNIT
        Unit variance in every embedding dimension.
    DIAGONAL
        Per-dimension variances only (off-diagonal terms ignored).
    FULL
        Full covariance matrix.
    """

    UNIT = "unit"
    DIAGONAL = "diag"
    FULL = "full"


# Covariance structure a Protonet uses when its support is updated outside training mode.
PRED_COV_TYPE = CovType.DIAGONAL
# Covariance structure for the global (all-class) moments used in OOD scoring.
OOD_COV_TYPE = CovType.DIAGONAL
# Regularization added to covariances and inside the square root of distances.
DEFAULT_EPSILON = 1e-5
# Covariance regularization scheme handed to Protonet by EquineProtonet.
COV_REG_TYPE = "epsilon"

47 

48 

49############################################### 

50 

51 

@beartype
class Protonet(torch.nn.Module):
    """
    Private class that implements a prototypical neural network for use in EQUINE.

    Inputs are embedded by the wrapped ``embedding_model`` followed by a trainable
    linear head. Each class in the current support set is summarized by the mean of
    its support embeddings (its prototype) and a regularized covariance. Queries are
    scored by their covariance-scaled distance to every prototype.
    """

    def __init__(
        self,
        embedding_model: torch.nn.Module,
        emb_out_dim: int,
        cov_type: CovType,
        cov_reg_type: str,
        epsilon: float,
        device: str = "cpu",
    ) -> None:
        """
        Protonet class constructor.

        Parameters
        ----------
        embedding_model : torch.nn.Module
            The PyTorch embedding model to generate logits with.
        emb_out_dim : int
            Dimension size of given embedding model's output.
        cov_type : CovType
            Type of covariance to use when computing distances while in training mode
            [unit, diag, full].
        cov_reg_type : str
            Type of regularization to use when generating the covariance matrix [epsilon, shared].
        epsilon : float
            Epsilon value used for covariance regularization and distance stabilization.
        device : str, optional
            The device to train the protonet model on (defaults to cpu).
        """
        super().__init__()
        self.embedding_model = embedding_model
        self.cov_type = cov_type
        self.cov_reg_type = cov_reg_type
        self.epsilon = epsilon
        self.emb_out_dim = emb_out_dim
        self.to(device)
        self.device = device

        # Both dicts are (re)populated by update_support(), keyed by class label.
        self.support: OrderedDict[int, torch.Tensor] = OrderedDict()
        self.support_embeddings: OrderedDict[int, torch.Tensor] = OrderedDict()
        self.model_head: torch.nn.Module = self.create_model_head(emb_out_dim)
        self.model_head.to(device)

    def create_model_head(self, emb_out_dim: int) -> torch.nn.Linear:
        """
        Method for adding a PyTorch layer on top of the given embedding model. This layer
        is intended to offer extra degrees of freedom for distance learning in the embedding space.

        Parameters
        ----------
        emb_out_dim : int
            Dimension size of the embedding model output.

        Returns
        -------
        torch.nn.Linear
            A square ``emb_out_dim -> emb_out_dim`` linear layer.
        """
        return torch.nn.Linear(emb_out_dim, emb_out_dim)

    def compute_embeddings(self, X: torch.Tensor) -> torch.Tensor:
        """
        Method for calculating model embeddings using both the given embedding model and the added model head.

        Parameters
        ----------
        X : torch.Tensor
            Input tensor to compute embeddings on. It is moved to ``self.device`` first.

        Returns
        -------
        torch.Tensor
            Fully computed embedding tensors for the given X tensor.
        """
        model_embeddings = self.embedding_model(X.to(self.device))
        return self.model_head(model_embeddings)

    @icontract.require(lambda self: len(self.support_embeddings) > 0)
    def compute_prototypes(self) -> torch.Tensor:
        """
        Method for computing class prototypes based on given support examples.
        ``Prototypes'' in this context are the means of the support embeddings for each class.

        Returns
        -------
        torch.Tensor
            Stacked prototypes, one row per class, in ``support_embeddings`` iteration order.
        """
        # The row order here defines the column order of every distance / probability output.
        return torch.stack(
            [
                torch.mean(class_embeddings, dim=0)
                for class_embeddings in self.support_embeddings.values()
            ]
        )

    @icontract.require(lambda self: len(self.support_embeddings) > 0)
    def compute_covariance(self, cov_type: CovType) -> torch.Tensor:
        """
        Method for generating the (regularized) support example covariance matrix(es) used for calculating distances.
        Note that this method is only called once per episode, and the resulting tensor is used for all queries.

        Parameters
        ----------
        cov_type : CovType
            Type of covariance to use [unit, diag, full].

        Returns
        -------
        torch.Tensor
            Stacked regularized class covariances, one entry per class in
            ``support_embeddings`` order (classes x dim for unit/diag,
            classes x dim x dim for full).
        """
        class_cov_dict = OrderedDict(
            (label, self.compute_covariance_by_type(cov_type, class_embeddings))
            for label, class_embeddings in self.support_embeddings.items()
        )
        reg_covariance_dict = self.regularize_covariance(
            class_cov_dict, cov_type, self.cov_reg_type
        )
        return torch.stack(list(reg_covariance_dict.values()))

    def compute_covariance_by_type(
        self, cov_type: CovType, embedding: torch.Tensor
    ) -> torch.Tensor:
        """
        Select the appropriate covariance matrix type based on cov_type.

        Parameters
        ----------
        cov_type : CovType
            Type of covariance to use [unit, diag, full].
        embedding : torch.Tensor
            Embedding tensor (examples x dim) to use when generating the covariance matrix.

        Returns
        -------
        torch.Tensor
            The requested covariance on ``embedding``'s device: a dim x dim matrix for
            FULL, or a length-dim vector of variances (DIAGONAL) or ones (UNIT).

        Raises
        ------
        ValueError
            If ``cov_type`` is not a supported covariance type.
        """
        if cov_type == CovType.FULL:
            return torch.cov(embedding.T)
        if cov_type == CovType.DIAGONAL:
            return torch.var(embedding, dim=0)
        if cov_type == CovType.UNIT:
            # Allocate on the same device as the other branches' outputs.
            return torch.ones(self.emb_out_dim, device=embedding.device)
        raise ValueError(f"Unsupported covariance type: {cov_type!r}")

    def regularize_covariance(
        self,
        class_cov_dict: OrderedDict[int, torch.Tensor],
        cov_type: CovType,
        cov_reg_type: str,
    ) -> OrderedDict[int, torch.Tensor]:
        """
        Method to add regularization to each class covariance matrix based on the selected regularization type.

        Parameters
        ----------
        class_cov_dict : OrderedDict[int, torch.Tensor]
            A dictionary containing each class and the corresponding covariance matrix.
            It is updated in place and also returned.
        cov_type : CovType
            Type of covariance to use [unit, diag, full].
        cov_reg_type : str
            ``"epsilon"`` adds ``self.epsilon`` to the (diagonal) variances.
            ``"shared"`` shrinks each class covariance toward the pooled covariance before
            adding epsilon. UNIT covariances fall back to epsilon with a warning.

        Returns
        -------
        OrderedDict[int, torch.Tensor]
            Dictionary containing the regularized class covariance matrices.

        Raises
        ------
        ValueError
            If ``cov_type`` or ``cov_reg_type`` is not supported.
        """
        if cov_type == CovType.FULL:
            regularization = torch.diag(
                self.epsilon * torch.ones(self.emb_out_dim, device=self.device)
            )
        elif cov_type == CovType.DIAGONAL:
            regularization = self.epsilon * torch.ones(
                self.emb_out_dim, device=self.device
            )
        elif cov_type == CovType.UNIT:
            regularization = torch.zeros(self.emb_out_dim, device=self.device)
        else:
            raise ValueError(f"Unsupported covariance type: {cov_type!r}")

        if cov_reg_type == "shared":
            if cov_type not in (CovType.FULL, CovType.DIAGONAL):
                # UNIT covariances carry no data to pool: fall back to epsilon regularization.
                for label in self.support_embeddings:
                    class_cov_dict[label] = (
                        class_cov_dict[label].to(self.device) + regularization
                    )
                warnings.warn(
                    "Covariance type UNIT is incompatible with shared regularization, "
                    "reverting to epsilon regularization"
                )
                return class_cov_dict

            shared_covariance = self.compute_shared_covariance(class_cov_dict, cov_type)
            for label in self.support_embeddings:
                num_class_support = self.support_embeddings[label].shape[0]
                # Classes with more support examples trust their own covariance more.
                lamb = num_class_support / (num_class_support + 1)
                class_cov_dict[label] = (
                    lamb * class_cov_dict[label]
                    + (1 - lamb) * shared_covariance
                    + regularization
                )
        elif cov_reg_type == "epsilon":
            for label in class_cov_dict:
                class_cov_dict[label] = (
                    class_cov_dict[label].to(self.device) + regularization
                )
        else:
            raise ValueError(
                f"Unsupported covariance regularization type: {cov_reg_type!r}"
            )

        return class_cov_dict

    def compute_shared_covariance(
        self, class_cov_dict: OrderedDict[int, torch.Tensor], cov_type: CovType
    ) -> torch.Tensor:
        """
        Method to calculate a shared covariance matrix.

        The shared covariance is the pooled estimate
        ``sum_k (n_k - 1) * Sigma_k / (N - 1)``. Here ``n_k`` is the number of support
        examples of class ``k``, taken from ``self.support_embeddings``, and
        ``N = sum_k n_k``. This is useful when the number of support examples for each
        class is small.

        Parameters
        ----------
        class_cov_dict : OrderedDict[int, torch.Tensor]
            A dictionary containing each class and the corresponding covariance matrix.
            Every key must also be a key of ``self.support_embeddings``.
        cov_type : CovType
            Type of covariance to use [diag, full].

        Returns
        -------
        torch.Tensor
            Tensor containing the shared covariance matrix.

        Raises
        ------
        ValueError
            If ``cov_type`` is UNIT.
        """
        if cov_type == CovType.FULL:
            shared_covariance = torch.zeros(
                (self.emb_out_dim, self.emb_out_dim), device=self.device
            )
        elif cov_type == CovType.DIAGONAL:
            shared_covariance = torch.zeros(self.emb_out_dim, device=self.device)
        else:
            raise ValueError(
                "Shared covariance can only be used with FULL or DIAGONAL (not UNIT) covariance types"
            )

        # Weight by the number of support *examples* per class. A covariance tensor's
        # shape[0] is the embedding dimension, not the sample count.
        support_counts = {
            label: self.support_embeddings[label].shape[0] for label in class_cov_dict
        }
        total_support = sum(support_counts.values())

        for label, class_covariance in class_cov_dict.items():
            # Undo the per-class (n_k - 1) normalization applied by cov/var.
            shared_covariance = (
                shared_covariance
                + (support_counts[label] - 1) * class_covariance.to(self.device)
            )

        # Re-normalize over the pooled sample count.
        return shared_covariance / (total_support - 1)

    @icontract.require(lambda X_embed, mu: X_embed.shape[-1] == mu.shape[-1])
    @icontract.ensure(lambda result: torch.all(result >= 0))
    def compute_distance(
        self, X_embed: torch.Tensor, mu: torch.Tensor, cov: torch.Tensor
    ) -> torch.Tensor:
        """
        Method to compute the distances to class prototypes for the given embeddings.

        Parameters
        ----------
        X_embed : torch.Tensor
            The embeddings of the query examples (examples x dim).
        mu : torch.Tensor
            The class prototypes (means of the support embeddings).
        cov : torch.Tensor
            The support covariance. A 2-D tensor (classes x dim) is treated as diagonal
            variances; a 3-D tensor (classes x dim x dim) as full covariance matrices.

        Returns
        -------
        torch.Tensor
            ``sqrt(d^2 + epsilon)`` distances, examples x classes. The class axis is
            squeezed away when there is a single class, e.g. the global null distance.
        """
        queries = torch.unsqueeze(X_embed, 1)  # examples x 1 x dim
        diff = torch.sub(mu, queries)  # examples x classes x dim (via broadcasting)

        if cov.dim() == 2:  # diagonal (or unit) covariance
            scaled_sq_diff = torch.div(
                (diff**2).to(self.device), cov.to(self.device)
            )
            # Zero-variance dimensions can produce NaN/inf; neutralize them.
            dist = torch.sum(torch.nan_to_num(scaled_sq_diff), dim=2)  # examples x classes
            dist = dist.squeeze(dim=1)
            dist = torch.sqrt(dist + self.epsilon)
        else:  # cov.dim() == 3: full covariance
            diff = diff.permute(1, 2, 0)  # classes x dim x examples
            dist = mahalanobis_distance_nosq(diff, cov)
            dist = torch.sqrt(dist.permute(1, 0) + self.epsilon)  # examples x classes
            dist = dist.squeeze(dim=1)
        return dist

    def compute_classes(self, distances: torch.Tensor) -> torch.Tensor:
        """
        Method to compute predicted classes from distances via a softmax function.

        Parameters
        ----------
        distances : torch.Tensor
            The distances of embeddings to class prototypes.

        Returns
        -------
        torch.Tensor
            Class probabilities: softmax of the negated distances over the last axis.
        """
        return torch.nn.functional.softmax(torch.neg(distances), dim=-1)

    def forward(self, X: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Protonet forward function, generates class probability predictions and distances from prototypes.

        Parameters
        ----------
        X : torch.Tensor
            Input tensor of queries for generating predictions.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            tuple containing class probability predictions, and class distances from prototypes.

        Raises
        ------
        ValueError
            If ``update_support()`` has not been called yet.
        """
        if len(self.support) == 0 or len(self.support_embeddings) == 0:
            raise ValueError(
                "No support examples found. Protonet Model requires model support to "
                "be set with the 'update_support()' method before calling forward."
            )

        X_embed = self.compute_embeddings(X)
        if X_embed.shape == torch.Size([self.emb_out_dim]):
            X_embed = X_embed.unsqueeze(dim=0)  # handle single examples
        distances = self.compute_distance(X_embed, self.prototypes, self.covariance)
        classes = self.compute_classes(distances)

        return classes, distances

    def update_support(self, support: OrderedDict[int, torch.Tensor]) -> None:
        """
        Method to update the support examples, and all the calculations that rely on them.

        In training mode the covariance uses ``self.cov_type``. Otherwise the global
        moments (used for OOD scoring) are refreshed and the covariance uses
        ``PRED_COV_TYPE``.

        Parameters
        ----------
        support : OrderedDict
            Ordered dict containing class labels and their associated support examples.
        """
        self.support = support  # TODO torch.nn.ParameterDict(support)
        self.support_embeddings = OrderedDict(
            (label, self.compute_embeddings(examples))
            for label, examples in support.items()
        )  # TODO torch.nn.ParameterDict(support_embs)

        self.prototypes: torch.Tensor = self.compute_prototypes()

        if not self.training:
            self.compute_global_moments()
            self.covariance: torch.Tensor = self.compute_covariance(
                cov_type=PRED_COV_TYPE
            )
        else:
            self.covariance = self.compute_covariance(cov_type=self.cov_type)

    @icontract.require(lambda self: len(self.support_embeddings) > 0)
    def compute_global_moments(self) -> None:
        """
        Method to calculate the global moments of the support embeddings for use in OOD score generation.

        Sets ``self.global_mean`` (mean over all support embeddings) and
        ``self.global_covariance`` (an ``OOD_COV_TYPE`` covariance over all support
        embeddings, with a leading axis of size 1, epsilon-regularized).
        """
        embeddings = torch.cat(list(self.support_embeddings.values()))
        global_covariance = torch.unsqueeze(
            self.compute_covariance_by_type(OOD_COV_TYPE, embeddings), dim=0
        )
        regularized = self.regularize_covariance(
            OrderedDict([(0, global_covariance)]), OOD_COV_TYPE, "epsilon"
        )
        self.global_covariance: torch.Tensor = regularized[0]
        self.global_mean: torch.Tensor = torch.mean(embeddings, dim=0)

452 

453############################################### 

454@beartype 

455class EquineProtonet(Equine): 

456 """ 

457 A class representing an EQUINE model that utilizes protonets and (optionally) relative Mahalanobis distances 

458 to generate OOD and model confidence scores. This wraps any pytorch embedding neural network 

459 and provides the `forward`, `predict`, `save`, and `load` methods required by Equine. 

460 """ 

461 

462 def __init__( 

463 self, 

464 embedding_model: torch.nn.Module, 

465 emb_out_dim: int, 

466 cov_type: CovType = CovType.UNIT, 

467 relative_mahal: bool = True, 

468 use_temperature: bool = False, 

469 init_temperature: float = 1.0, 

470 device: str = "cpu", 

471 feature_names: Optional[list[str]] = None, 

472 label_names: Optional[list[str]] = None, 

473 ) -> None: 

474 """ 

475 EquineProtonet class constructor 

476 

477 Parameters 

478 ---------- 

479 embedding_model : torch.nn.Module 

480 Neural Network feature embedding model. 

481 emb_out_dim : int 

482 The number of output features from the embedding model. 

483 cov_type : CovType, optional 

484 The type of covariance to use when training the protonet [UNIT, DIAG, FULL], by default CovType.UNIT. 

485 relative_mahal : bool, optional 

486 Use relative mahalanobis distance for OOD calculations. If false, uses standard mahalanobis distance instead, by default True. 

487 use_temperature : bool, optional 

488 Whether to use temperature scaling after training, by default False. 

489 init_temperature : float, optional 

490 What to use as the initial temperature (1.0 has no effect), by default 1.0. 

491 device : str, optional 

492 The device to train the equine model on (defaults to cpu). 

493 feature_names : list[str], optional 

494 List of strings of the names of the tabular features (ex ["duration", "fiat_mean", ...]) 

495 label_names : list[str], optional 

496 List of strings of the names of the labels (ex ["streaming", "voip", ...]) 

497 """ 

498 super().__init__( 

499 embedding_model, 

500 device=device, 

501 feature_names=feature_names, 

502 label_names=label_names, 

503 ) 

504 self.cov_type = cov_type 

505 self.cov_reg_type = COV_REG_TYPE 

506 self.relative_mahal = relative_mahal 

507 self.emb_out_dim = emb_out_dim 

508 self.epsilon = DEFAULT_EPSILON 

509 self.outlier_score_kde: OrderedDict[int, gaussian_kde] = OrderedDict() 

510 self.model_summary: dict[str, Any] = dict() 

511 self.use_temperature = use_temperature 

512 self.init_temperature = init_temperature 

513 self.register_buffer( 

514 "temperature", torch.Tensor(self.init_temperature * torch.ones(1)) 

515 ) 

516 

517 self.model: torch.nn.Module = Protonet( 

518 embedding_model, 

519 self.emb_out_dim, 

520 self.cov_type, 

521 self.cov_reg_type, 

522 self.epsilon, 

523 device=device, 

524 ) 

525 

526 def forward(self, X: torch.Tensor) -> torch.Tensor: 

527 """ 

528 Generates logits for classification based on the input tensor. 

529 

530 Parameters 

531 ---------- 

532 X : torch.Tensor 

533 The input tensor for generating predictions. 

534 

535 Returns 

536 ------- 

537 torch.Tensor 

538 The output class predictions. 

539 """ 

540 preds, _ = self.model(X) 

541 return preds 

542 

543 @icontract.require(lambda calib_frac: calib_frac > 0 and calib_frac < 1) 

544 def train_model( 

545 self, 

546 dataset: TensorDataset, 

547 num_episodes: int, 

548 calib_frac: float = 0.2, 

549 support_size: int = 25, 

550 way: int = 3, 

551 episode_size: int = 100, 

552 loss_fn: Callable = torch.nn.functional.cross_entropy, 

553 opt_class: Callable = torch.optim.Adam, 

554 num_calibration_epochs: int = 2, 

555 calibration_lr: float = 0.01, 

556 ) -> dict[str, Any]: 

557 """ 

558 Train or fine-tune an EquineProtonet model. 

559 

560 Parameters 

561 ---------- 

562 dataset : TensorDataset 

563 Input pytorch TensorDataset of training data for model. 

564 num_episodes : int 

565 The desired number of episodes to use for training. 

566 calib_frac : float, optional 

567 Fraction of given training data to reserve for model calibration, by default 0.2. 

568 support_size : int, optional 

569 Number of support examples to generate for each class, by default 25. 

570 way : int, optional 

571 Number of classes to train on per episode, by default 3. 

572 episode_size : int, optional 

573 Number of examples to use per episode, by default 100. 

574 loss_fn : Callable, optional 

575 A pytorch loss function, eg., torch.nn.CrossEntropyLoss(), by default torch.nn.functional.cross_entropy. 

576 opt_class : Callable, optional 

577 A pytorch optimizer, e.g., torch.optim.Adam, by default torch.optim.Adam. 

578 num_calibration_epochs : int, optional 

579 The desired number of epochs to use for temperature scaling, by default 2. 

580 calibration_lr : float, optional 

581 Learning rate for temperature scaling, by default 0.01. 

582 

583 Returns 

584 ------- 

585 tuple[dict[str, Any], torch.Tensor, torch.Tensor] 

586 A tuple containing the model summary, the held out calibration data, and the calibration labels. 

587 """ 

588 self.train() 

589 

590 if self.use_temperature: 

591 self.temperature: torch.Tensor = torch.Tensor( 

592 self.init_temperature * torch.ones(1) 

593 ).type_as(self.temperature) 

594 

595 X, Y = dataset[:] 

596 

597 self.validate_feature_label_names(X.shape[-1], torch.unique(Y).shape[0]) 

598 

599 train_x, calib_x, train_y, calib_y = stratified_train_test_split( 

600 X, Y, test_size=calib_frac 

601 ) 

602 optimizer = opt_class(self.parameters()) 

603 

604 train_x.to(self.device) 

605 train_y.to(self.device) 

606 calib_x.to(self.device) 

607 calib_y.to(self.device) 

608 

609 for i in tqdm(range(num_episodes)): 

610 optimizer.zero_grad() 

611 

612 support, episode_x, episode_y = generate_episode( 

613 train_x, train_y, support_size, way, episode_size 

614 ) 

615 self.model.update_support(support) 

616 

617 _, dists = self.model(episode_x) 

618 loss_value = loss_fn( 

619 torch.neg(dists).to(self.device), episode_y.to(self.device) 

620 ) 

621 loss_value.backward() 

622 optimizer.step() 

623 

624 self.eval() 

625 full_support = generate_support( 

626 train_x, 

627 train_y, 

628 support_size, 

629 selected_labels=torch.unique(train_y).tolist(), 

630 ) 

631 

632 self.model.update_support( 

633 full_support 

634 ) # update support with final selected examples 

635 

636 X_embed = self.model.compute_embeddings(calib_x) 

637 pred_probs, dists = self.model(calib_x) 

638 ood_dists = self._compute_ood_dist(X_embed, pred_probs, dists) 

639 self._fit_outlier_scores(ood_dists, calib_y) 

640 

641 if self.use_temperature: 

642 self.calibrate_temperature( 

643 calib_x, calib_y, num_calibration_epochs, calibration_lr 

644 ) 

645 

646 date_trained = datetime.now().strftime("%m/%d/%Y, %H:%M:%S") 

647 self.train_summary: dict[str, Any] = generate_train_summary( 

648 self, train_y, date_trained 

649 ) 

650 return_dict: dict[str, Any] = dict() 

651 return_dict["train_summary"] = self.train_summary 

652 return_dict["calib_x"] = calib_x 

653 return_dict["calib_y"] = calib_y 

654 return return_dict 

655 

656 def calibrate_temperature( 

657 self, 

658 calib_x: torch.Tensor, 

659 calib_y: torch.Tensor, 

660 num_calibration_epochs: int = 1, 

661 calibration_lr: float = 0.01, 

662 ) -> None: 

663 """ 

664 Fine-tune the temperature after training. Note that this function is also run at the conclusion of train_model. 

665 

666 Parameters 

667 ---------- 

668 calib_x : torch.Tensor 

669 Training data to be used for temperature calibration. 

670 calib_y : torch.Tensor 

671 Labels corresponding to `calib_x`. 

672 num_calibration_epochs : int, optional 

673 Number of epochs to tune temperature, by default 1. 

674 calibration_lr : float, optional 

675 Learning rate for temperature optimization, by default 0.01. 

676 

677 Returns 

678 ------- 

679 None 

680 """ 

681 self.temperature.requires_grad = True 

682 optimizer = torch.optim.Adam([self.temperature], lr=calibration_lr) 

683 for t in range(num_calibration_epochs): 

684 optimizer.zero_grad() 

685 with torch.no_grad(): 

686 pred_probs, dists = self.model(calib_x) 

687 dists = dists.to(self.device) / self.temperature.to(self.device) 

688 loss = torch.nn.functional.cross_entropy( 

689 torch.neg(dists).to(self.device), calib_y.to(torch.long).to(self.device) 

690 ) 

691 loss.backward() 

692 optimizer.step() 

693 self.temperature.requires_grad = False 

694 

695 @icontract.ensure(lambda self: len(self.model.support_embeddings) > 0) 

696 def _fit_outlier_scores( 

697 self, ood_dists: torch.Tensor, calib_y: torch.Tensor 

698 ) -> None: 

699 """ 

700 Private function to fit outlier scores with a kernel density estimate (KDE). 

701 

702 Parameters 

703 ---------- 

704 ood_dists : torch.Tensor 

705 Tensor of computed OOD distances. 

706 calib_y : torch.Tensor 

707 Tensor of class labels for `ood_dists` examples. 

708 

709 Returns 

710 ------- 

711 None 

712 """ 

713 for label in self.model.support_embeddings.keys(): 

714 class_ood_dists = ood_dists[calib_y == int(label)].cpu().detach().numpy() 

715 class_kde = gaussian_kde(class_ood_dists) # TODO convert to torch func 

716 self.outlier_score_kde[label] = class_kde 

717 

718 def _compute_outlier_scores(self, ood_dists, predictions) -> torch.Tensor: 

719 """ 

720 Private function to compute OOD scores using the calculated kernel density estimate (KDE). 

721 

722 Parameters 

723 ---------- 

724 ood_dists : torch.Tensor 

725 Tensor of computed OOD distances. 

726 predictions : torch.Tensor 

727 Tensor of model protonet predictions. 

728 

729 Returns 

730 ------- 

731 torch.Tensor 

732 Tensor of OOD scores for the given examples. 

733 """ 

734 ood_scores = torch.zeros_like(ood_dists) 

735 for i in range(len(predictions)): 

736 # Use KDE and RMD corresponding to the predicted class 

737 predicted_class = int(torch.argmax(predictions[i, :])) 

738 p_value = self.outlier_score_kde[int(predicted_class)].integrate_box_1d( 

739 ood_dists[i].detach().numpy(), np.inf 

740 ) 

741 ood_scores[i] = 1.0 - np.clip(p_value, 0.0, 1.0) 

742 

743 return ood_scores 

744 

745 @icontract.ensure(lambda result: len(result) > 0) 

746 def _compute_ood_dist( 

747 self, 

748 X_embeddings: torch.Tensor, 

749 predictions: torch.Tensor, 

750 distances: torch.Tensor, 

751 ) -> torch.Tensor: 

752 """ 

753 Private function to compute OOD distances using a distance function. 

754 

755 Parameters 

756 ---------- 

757 X_embeddings : torch.Tensor 

758 Tensor of example embeddings. 

759 predictions : torch.Tensor 

760 Tensor of model protonet predictions for the given embeddings. 

761 distances : torch.Tensor 

762 Tensor of calculated protonet distances for the given embeddings. 

763 

764 Returns 

765 ------- 

766 torch.Tensor 

767 Tensor of OOD distances for the given embeddings. 

768 """ 

769 preds = torch.argmax(predictions, dim=1) 

770 preds = preds.unsqueeze(dim=-1) 

771 # Calculate (Relative) Mahalanobis Distance: 

772 if self.relative_mahal: 

773 null_distance = self.model.compute_distance( 

774 X_embeddings, self.model.global_mean, self.model.global_covariance 

775 ) 

776 null_distance = null_distance.unsqueeze(dim=-1) 

777 ood_dist = distances.gather(1, preds) - null_distance 

778 else: 

779 ood_dist = distances.gather(1, preds) 

780 

781 ood_dist = torch.reshape(ood_dist, (-1,)) 

782 return ood_dist 

783 

784 def predict(self, X: torch.Tensor) -> EquineOutput: 

785 """Predict function for EquineProtonet, inherited and implemented from Equine. 

786 

787 Parameters 

788 ---------- 

789 X : torch.Tensor 

790 Input tensor. 

791 

792 Returns 

793 ------- 

794 EquineOutput 

795 Output object containing prediction probabilities and OOD scores. 

796 """ 

797 X_embed = self.model.compute_embeddings(X) 

798 if X_embed.shape == torch.Size([self.model.emb_out_dim]): 

799 X_embed = X_embed.unsqueeze(dim=0) # Handle single examples 

800 preds, dists = self.model(X) 

801 if self.use_temperature: 

802 dists = dists / self.temperature 

803 preds = torch.softmax(torch.negative(dists), dim=1) 

804 ood_dist = self._compute_ood_dist(X_embed, preds, dists) 

805 ood_scores = self._compute_outlier_scores(ood_dist, preds) 

806 

807 self.validate_feature_label_names(X.shape[-1], preds.shape[-1]) 

808 

809 return EquineOutput(classes=preds, ood_scores=ood_scores, embeddings=X_embed) 

810 

811 @icontract.require(lambda calib_frac: (calib_frac > 0.0) and (calib_frac < 1.0)) 

812 def update_support( 

813 self, 

814 support_x: torch.Tensor, 

815 support_y: torch.Tensor, 

816 calib_frac: float, 

817 label_names: Optional[list[str]] = None, 

818 ) -> None: 

819 """Function to update protonet support examples with given examples. 

820 

821 Parameters 

822 ---------- 

823 support_x : torch.Tensor 

824 Tensor containing support examples for protonet. 

825 support_y : torch.Tensor 

826 Tensor containing labels for given support examples. 

827 calib_frac : float 

828 Fraction of given support data to use for OOD calibration. 

829 label_names : list[str], optional 

830 List of strings of the names of the labels (ex ["streaming", "voip", ...]) 

831 

832 Returns 

833 ------- 

834 None 

835 """ 

836 

837 support_x, calib_x, support_y, calib_y = stratified_train_test_split( 

838 support_x, support_y, test_size=calib_frac 

839 ) 

840 labels, counts = torch.unique(support_y, return_counts=True) 

841 if label_names is not None: 841 ↛ 842line 841 didn't jump to line 842 because the condition on line 841 was never true

842 self.label_names = label_names 

843 self.validate_feature_label_names(support_x.shape[-1], labels.shape[0]) 

844 

845 support = OrderedDict() 

846 for label, count in list(zip(labels.tolist(), counts.tolist())): 

847 class_support = generate_support( 

848 support_x, 

849 support_y, 

850 support_size=count, 

851 selected_labels=[label], 

852 ) 

853 support.update(class_support) 

854 

855 self.model.update_support(support) 

856 

857 X_embed = self.model.compute_embeddings(calib_x) 

858 preds, dists = self.model(calib_x) 

859 ood_dists = self._compute_ood_dist(X_embed, preds, dists) 

860 

861 self._fit_outlier_scores(ood_dists, calib_y) 

862 

863 @icontract.require(lambda self: len(self.model.support) > 0) 

864 def get_support(self) -> OrderedDict[int, torch.Tensor]: 

865 """ 

866 Get the support examples for the model. 

867 

868 Returns 

869 ------- 

870 OrderedDict[int, torch.Tensor] 

871 The support examples for the model. 

872 """ 

873 return self.model.support 

874 

875 @icontract.require(lambda self: len(self.model.prototypes) > 0) 

876 def get_prototypes(self) -> torch.Tensor: 

877 """ 

878 Get the prototypes for the model (the class means of the support embeddings). 

879 

880 Returns 

881 ------- 

882 torch.Tensor 

883 The prototpes for the model. 

884 """ 

885 return self.model.prototypes 

886 

887 def save(self, path: str) -> None: 

888 """ 

889 Save all model parameters to a file. 

890 

891 Parameters 

892 ---------- 

893 path : str 

894 Filename to write the model. 

895 

896 Returns 

897 ------- 

898 None 

899 """ 

900 model_settings = { 

901 "cov_type": self.cov_type, 

902 "emb_out_dim": self.emb_out_dim, 

903 "use_temperature": self.use_temperature, 

904 "init_temperature": self.temperature.item(), 

905 "relative_mahal": self.relative_mahal, 

906 } 

907 

908 jit_model = torch.jit.script(self.model.embedding_model) 

909 buffer = io.BytesIO() 

910 torch.jit.save(jit_model, buffer) 

911 buffer.seek(0) 

912 

913 save_data = { 

914 "embed_jit_save": buffer, 

915 "feature_names": self.feature_names, 

916 "label_names": self.label_names, 

917 "model_head_save": self.model.model_head.state_dict(), 

918 "outlier_kde": self.outlier_score_kde, 

919 "settings": model_settings, 

920 "support": self.model.support, 

921 "train_summary": self.train_summary, 

922 } 

923 

924 torch.save(save_data, path) # TODO allow model checkpointing 

925 

926 @classmethod 

927 def load(cls, path: str) -> Equine: # noqa: F821 

928 """ 

929 Load a previously saved EquineProtonet model. 

930 

931 Parameters 

932 ---------- 

933 path : str 

934 The filename of the saved model. 

935 

936 Returns 

937 ------- 

938 EquineProtonet 

939 The reconstituted EquineProtonet object. 

940 """ 

941 model_save = torch.load(path, weights_only=False) 

942 support = model_save.get("support") 

943 jit_model = torch.jit.load(model_save.get("embed_jit_save")) 

944 eq_model = cls(jit_model, **model_save.get("settings")) 

945 

946 eq_model.model.model_head.load_state_dict(model_save.get("model_head_save")) 

947 eq_model.eval() 

948 eq_model.model.update_support(support) 

949 

950 eq_model.feature_names = model_save.get("feature_names") 

951 eq_model.label_names = model_save.get("label_names") 

952 eq_model.outlier_score_kde = model_save.get("outlier_kde") 

953 eq_model.train_summary = model_save.get("train_summary") 

954 

955 return eq_model