Coverage for /opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/equine/equine_gp.py: 96%

279 statements  

coverage.py v7.13.0, created at 2025-12-18 23:02 +0000

1# Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY 

2# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014). 

3# SPDX-License-Identifier: MIT 

4 

5from typing import Any, Optional, Union 

6 

7import icontract 

8import io 

9import math 

10import torch 

11from beartype import beartype 

12from collections import OrderedDict 

13from collections.abc import Callable, Iterable 

14from datetime import datetime 

15from torch.utils.data import DataLoader, Dataset 

16from torchmetrics.metric import Metric 

17from tqdm import tqdm 

18 

19from .equine import Equine, EquineOutput 

20from .utils import generate_support, generate_train_summary 

21 

22BatchType = tuple[torch.Tensor, ...] 

23# ------------------------------------------------------------------------------- 

24# Note that the below code for 

25# * `_random_ortho`, 

26# * `_RandomFourierFeatures`, and 

27# * `_Laplace` 

28# is copied and modified from https://github.com/y0ast/DUE/blob/main/due/sngp.py 

29# under its original MIT license, redisplayed here: 

30# ------------------------------------------------------------------------------- 

31# MIT License 

32# 

33# Copyright (c) 2021 Joost van Amersfoort 

34# 

35# Permission is hereby granted, free of charge, to any person obtaining a copy 

36# of this software and associated documentation files (the "Software"), to deal 

37# in the Software without restriction, including without limitation the rights 

38# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

39# copies of the Software, and to permit persons to whom the Software is 

40# furnished to do so, subject to the following conditions: 

41# 

42# The above copyright notice and this permission notice shall be included in all 

43# copies or substantial portions of the Software. 

44# 

45# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

46# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

47# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

48# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

49# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

50# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

51# SOFTWARE. 

52# ------------------------------------------------------------------------------ 

53# Following the recommendation of their README at https://github.com/y0ast/DUE 

54# we encourage anyone using this code in their research to cite the following papers: 

55# 

56# @article{van2021on, 

57# title={On Feature Collapse and Deep Kernel Learning for Single Forward Pass Uncertainty}, 

58# author={van Amersfoort, Joost and Smith, Lewis and Jesson, Andrew and Key, Oscar and Gal, Yarin}, 

59# journal={arXiv preprint arXiv:2102.11409}, 

60# year={2021} 

61# } 

62# 

63# @article{liu2020simple, 

64# title={Simple and principled uncertainty estimation with deterministic deep learning via distance awareness}, 

65# author={Liu, Jeremiah and Lin, Zi and Padhy, Shreyas and Tran, Dustin and Bedrax Weiss, Tania and Lakshminarayanan, Balaji}, 

66# journal={Advances in Neural Information Processing Systems}, 

67# volume={33}, 

68# pages={7498--7512}, 

69# year={2020} 

70# } 

71 

72 

73@beartype 

74def _random_ortho(n: int, m: int) -> torch.Tensor: 

75 """ 

76 Generate a random orthonormal matrix. 

77 

78 Parameters 

79 ---------- 

80 n : int 

81 The number of rows. 

82 m : int 

83 The number of columns. 

84 

85 Returns 

86 ------- 

87 torch.Tensor 

88 The random orthonormal matrix. 

89 """ 

90 q, _ = torch.linalg.qr(torch.randn(n, m)) 

91 return q 

92 
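A minimal usage sketch (illustrative, not part of the module): when m <= n, the
columns of the returned matrix are orthonormal, so q.T @ q recovers the identity.

    q = _random_ortho(8, 4)
    assert torch.allclose(q.T @ q, torch.eye(4), atol=1e-5)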

93 

94@beartype 

95class _RandomFourierFeatures(torch.nn.Module): 

96 """ 

97 A private class to generate random Fourier features for the embedding model. 

98 """ 

99 

100 def __init__( 

101 self, in_dim: int, num_random_features: int, feature_scale: Optional[float] 

102 ) -> None: 

103 """ 

104 Initialize the _RandomFourierFeatures module, which generates random Fourier features 

105 for the embedding model. 

106 

107 Parameters 

108 ---------- 

109 in_dim : int 

110 The input dimensionality. 

111 num_random_features : int 

112 The number of random Fourier features to generate. 

113 feature_scale : Optional[float] 

114 The scaling factor for the random Fourier features. If None, defaults to sqrt(num_random_features / 2). 

115 """ 

116 super().__init__() 

117 if feature_scale is None:  117 ↛ 118 (line 117 didn't jump to line 118 because the condition was never true)

118 feature_scale = math.sqrt(num_random_features / 2) 

119 

120 self.register_buffer("feature_scale", torch.tensor(feature_scale)) 

121 

122 if num_random_features <= in_dim:  122 ↛ 123 (line 122 didn't jump to line 123 because the condition was never true)

123 W: torch.Tensor = _random_ortho(in_dim, num_random_features) 

124 else: 

125 # generate blocks of orthonormal columns which are not necessarily orthonormal 

126 # to each other. 

127 dim_left = num_random_features 

128 ws = [] 

129 while dim_left > in_dim: 

130 ws.append(_random_ortho(in_dim, in_dim)) 

131 dim_left -= in_dim 

132 ws.append(_random_ortho(in_dim, dim_left)) 

133 W: torch.Tensor = torch.cat(ws, 1) 

134 

135 feature_norm = torch.randn(W.shape) ** 2 

136 

137 W = W * feature_norm.sum(0).sqrt() 

138 self.register_buffer("W", W) 

139 

140 b: torch.Tensor = torch.empty(num_random_features).uniform_(0, 2 * math.pi) 

141 self.register_buffer("b", b) 

142 

143 def forward(self, x: torch.Tensor) -> torch.Tensor: 

144 """ 

145 Compute the forward pass of the _RandomFourierFeatures module. 

146 

147 Parameters 

148 ---------- 

149 x : torch.Tensor 

150 The input tensor of shape (batch_size, in_dim). 

151 

152 Returns 

153 ------- 

154 torch.Tensor 

155 The output tensor of shape (batch_size, num_random_features). 

156 """ 

157 k = torch.cos(x @ self.W + self.b) 

158 k = k / self.feature_scale 

159 

160 return k 

161 

162 
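With the default feature_scale of sqrt(num_random_features / 2), inner products of
these features approximate an RBF kernel with unit bandwidth. A rough Monte Carlo
check (illustrative, not part of the module; the tolerance is loose by design):

    rff = _RandomFourierFeatures(in_dim=4, num_random_features=4096, feature_scale=None)
    x, y = torch.randn(1, 4), torch.randn(1, 4)
    approx = (rff(x) * rff(y)).sum()                # phi(x) . phi(y)
    exact = torch.exp(-0.5 * (x - y).pow(2).sum())  # RBF kernel value k(x, y)
    assert (approx - exact).abs() < 0.2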

163class _Laplace(torch.nn.Module): 

164 """ 

165 A private class to compute a Laplace approximation to a Gaussian Process (GP). 

166 """ 

167 

168 def __init__( 

169 self, 

170 feature_extractor: torch.nn.Module, 

171 num_deep_features: int, 

172 num_gp_features: int, 

173 normalize_gp_features: bool, 

174 num_random_features: int, 

175 num_outputs: int, 

176 feature_scale: Optional[float], 

177 mean_field_factor: Optional[float], # required for classification problems 

178 ridge_penalty: float = 1.0, 

179 ) -> None: 

180 """ 

181 Initialize the _Laplace module. 

182 

183 Parameters 

184 ---------- 

185 feature_extractor : torch.nn.Module 

186 The feature extractor module. 

187 num_deep_features : int 

188 The number of features output by the feature extractor. 

189 num_gp_features : int 

190 The number of features to use in the Gaussian process. 

191 normalize_gp_features : bool 

192 Whether to normalize the GP features. 

193 num_random_features : int 

194 The number of random Fourier features to use. 

195 num_outputs : int 

196 The number of outputs of the model. 

197 feature_scale : Optional[float] 

198 The scaling factor for the random Fourier features. 

199 mean_field_factor : Optional[float] 

200 The mean-field factor for the Gaussian-Softmax approximation. 

201 Required for classification problems. 

202 ridge_penalty : float, optional 

203 The ridge penalty for the Laplace approximation. 

204 """ 

205 super().__init__() 

206 self.feature_extractor = feature_extractor 

207 self.mean_field_factor = mean_field_factor 

208 self.ridge_penalty = ridge_penalty 

209 self.train_batch_size = 0 # to be set later 

210 

211 if num_gp_features > 0:  211 ↛ 221 (line 211 didn't jump to line 221 because the condition was always true)

212 self.num_gp_features = num_gp_features 

213 random_matrix: torch.Tensor = torch.normal( 

214 0, 0.05, (num_gp_features, num_deep_features) 

215 ) 

216 self.register_buffer("random_matrix", random_matrix) 

217 self.jl: Callable = lambda x: torch.nn.functional.linear( 

218 x, self.random_matrix 

219 ) 

220 else: 

221 self.num_gp_features: int = num_deep_features 

222 self.jl: Callable = lambda x: x # Identity 

223 

224 self.normalize_gp_features = normalize_gp_features 

225 if normalize_gp_features:  225 ↛ 228 (line 225 didn't jump to line 228 because the condition was always true)

226 self.normalize: torch.nn.LayerNorm = torch.nn.LayerNorm(num_gp_features) 

227 

228 self.rff: _RandomFourierFeatures = _RandomFourierFeatures( 

229 num_gp_features, num_random_features, feature_scale 

230 ) 

231 self.beta: torch.nn.Linear = torch.nn.Linear(num_random_features, num_outputs) 

232 

233 self.num_data = 0 # to be set later 

234 self.register_buffer("seen_data", torch.tensor(0)) 

235 

236 precision = torch.eye(num_random_features) * self.ridge_penalty 

237 self.register_buffer("precision", precision) 

238 

239 self.recompute_covariance = True 

240 self.register_buffer("covariance", torch.eye(num_random_features)) 

241 self.training_parameters_set = False 

242 

243 def reset_precision_matrix(self) -> None: 

244 """ 

245 Reset the precision matrix to the identity matrix times the ridge penalty. 

246 """ 

247 identity = torch.eye(self.precision.shape[0], device=self.precision.device) 

248 self.precision: torch.Tensor = identity * self.ridge_penalty 

249 self.seen_data: torch.Tensor = torch.tensor(0) 

250 self.recompute_covariance = True 

251 

252 @icontract.require(lambda num_data: num_data > 0) 

253 @icontract.require( 

254 lambda num_data, batch_size: (0 < batch_size) & (batch_size <= num_data) 

255 ) 

256 def set_training_params(self, num_data: int, batch_size: int) -> None: 

257 """ 

258 Set the training parameters for the Laplace approximation. 

259 

260 Parameters 

261 ---------- 

262 num_data : int 

263 The total number of data points. 

264 batch_size : int 

265 The batch size to use during training. 

266 """ 

267 self.num_data: int = num_data 

268 self.train_batch_size: int = batch_size 

269 self.training_parameters_set: bool = True 

270 

271 @icontract.require(lambda mean_field_factor: mean_field_factor is not None) 

272 def mean_field_logits( 

273 self, logits: torch.Tensor, pred_cov: torch.Tensor, mean_field_factor: float 

274 ) -> torch.Tensor: 

275 """ 

276 Compute the mean-field logits for the Gaussian-Softmax approximation. 

277 

278 Parameters 

279 ---------- 

280 logits : torch.Tensor 

281 The logits tensor of shape (batch_size, num_outputs). 

282 pred_cov : torch.Tensor 

283 The predicted covariance matrix of shape (batch_size, batch_size). 

284 mean_field_factor : float 

285 Scaling factor applied to the diagonal of the predictive covariance. 

286 

287 Returns 

288 ------- 

289 torch.Tensor 

290 The mean-field logits tensor of shape (batch_size, num_outputs). 

291 """ 

292 # Mean-Field approximation as alternative to MC integration of Gaussian-Softmax 

293 # Based on: https://arxiv.org/abs/2006.07584 

294 

295 logits_scale = torch.sqrt(1.0 + torch.diag(pred_cov) * mean_field_factor) 

296 if mean_field_factor > 0:  296 ↛ 299 (line 296 didn't jump to line 299 because the condition was always true)

297 logits = logits / logits_scale.unsqueeze(-1) 

298 

299 return logits 

300 
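In symbols, with lambda the mean-field factor and sigma^2 = diag(pred_cov) the
per-sample predictive variance, this computes logits / sqrt(1 + lambda * sigma^2),
which recovers the raw logits as lambda approaches 0.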

301 @icontract.require(lambda self: self.training_parameters_set) 

302 def forward( 

303 self, x: torch.Tensor 

304 ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: 

305 """ 

306 Compute the forward pass of the Laplace approximation to the Gaussian Process. 

307 

308 Parameters 

309 ---------- 

310 x : torch.Tensor 

311 The input tensor of shape (batch_size, num_features). 

312 

313 Returns 

314 ------- 

315 Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]] 

316 If the model is in training mode, returns the predicted mean of shape (batch_size, num_outputs). 

317 If the model is in evaluation mode, returns the mean-field logits of the same shape or, when 

318 mean_field_factor is None, a tuple of that mean and the predicted covariance matrix of shape (batch_size, batch_size). 

319 """ 

320 f = self.feature_extractor(x) 

321 f_reduc = self.jl(f) 

322 if self.normalize_gp_features:  322 ↛ 325 (line 322 didn't jump to line 325 because the condition was always true)

323 f_reduc = self.normalize(f_reduc) 

324 

325 k = self.rff(f_reduc) 

326 

327 pred = self.beta(k) 

328 

329 if self.training: 

330 precision_minibatch = k.t() @ k 

331 self.precision += precision_minibatch 

332 self.seen_data += x.shape[0] 

333 

334 assert ( 

335 self.seen_data <= self.num_data 

336 ), "Did not reset precision matrix at start of epoch" 

337 else: 

338 assert self.seen_data > ( 

339 self.num_data - self.train_batch_size 

340 ), "Not seen sufficient data for precision matrix" 

341 

342 if self.recompute_covariance: 

343 with torch.no_grad(): 

344 eps = 1e-7 

345 jitter = eps * torch.eye( 

346 self.precision.shape[1], 

347 device=self.precision.device, 

348 ) 

349 u, info = torch.linalg.cholesky_ex(self.precision + jitter) 

350 assert (info == 0).all(), "Precision matrix inversion failed!" 

351 torch.cholesky_inverse(u, out=self.covariance) 

352 

353 self.recompute_covariance: bool = False 

354 

355 with torch.no_grad(): 

356 pred_cov = k @ ((self.covariance @ k.t()) * self.ridge_penalty) 

357 

358 if self.mean_field_factor is None:  358 ↛ 359 (line 358 didn't jump to line 359 because the condition was never true)

359 return pred, pred_cov 

360 else: 

361 pred = self.mean_field_logits(pred, pred_cov, self.mean_field_factor) 

362 

363 return pred 

364 

365 
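The contracts and assertions above imply a per-epoch protocol, sketched here under
assumed names (laplace, loader, opt, loss_fn); EquineGP.train_model below follows it:

    laplace.set_training_params(num_data=len(dataset), batch_size=batch_size)
    for _ in range(num_epochs):
        laplace.train()
        laplace.reset_precision_matrix()   # precision <- ridge_penalty * I
        for xs, ys in loader:              # each training forward adds k.T @ k
            opt.zero_grad()
            loss_fn(laplace(xs), ys).backward()
            opt.step()
    laplace.eval()  # the next forward Cholesky-inverts the accumulated precision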

366# ------------------------------------------------------------------------------- 

367# EquineGP, below, demonstrates how to adapt that approach in EQUINE 

368@beartype 

369class EquineGP(Equine): 

370 """ 

371 An example of an EQUINE model that builds upon the approach in "Spectral-normalized 

372 Neural Gaussian Processes" (SNGP). This wraps any pytorch embedding neural network and provides 

373 the `forward`, `predict`, `save`, and `load` methods required by Equine. 

374 

375 Notes 

376 ----- 

377 Although this model builds upon the approach in SNGP, it does not enforce the spectral normalization 

378 and ResNet architecture required for SNGP. Instead, it is a simple wrapper around 

379 any pytorch embedding neural network. Your mileage may vary. 

380 """ 

381 

382 def __init__( 

383 self, 

384 embedding_model: torch.nn.Module, 

385 emb_out_dim: int, 

386 num_classes: int, 

387 num_random_features: int = 1024, 

388 init_temperature: float = 1.0, 

389 device: str = "cpu", 

390 feature_names: Optional[list[str]] = None, 

391 label_names: Optional[list[str]] = None, 

392 ) -> None: 

393 """ 

394 Initialize the EquineGP model. 

395 

396 Parameters 

397 ---------- 

398 embedding_model : torch.nn.Module 

399 Neural Network feature embedding. 

400 emb_out_dim : int 

401 The number of deep features from the feature embedding. 

402 num_classes : int 

403 The number of output classes this model predicts. 

404 num_random_features : int 

405 The dimension of the output of the RandomFourierFeatures operation. 

406 init_temperature : float, optional 

407 What to use as the initial temperature (1.0 has no effect). 

408 device : str, optional 

409 Either 'cuda' or 'cpu'. 

410 feature_names : list[str], optional 

411 List of strings of the names of the tabular features (e.g., ["duration", "fiat_mean", ...]) 

412 label_names : list[str], optional 

413 List of strings of the names of the labels (e.g., ["streaming", "voip", ...]) 

414 """ 

415 super().__init__( 

416 embedding_model, feature_names=feature_names, label_names=label_names 

417 ) 

418 self.num_deep_features = emb_out_dim 

419 self.num_gp_features = emb_out_dim 

420 self.normalize_gp_features = True 

421 self.num_random_features = num_random_features 

422 self.num_outputs = num_classes 

423 self.mean_field_factor = 25 

424 self.ridge_penalty = 1 

425 self.feature_scale: float = 2.0 

426 self.init_temperature = init_temperature 

427 self.register_buffer( 

428 "temperature", torch.Tensor(self.init_temperature * torch.ones(1)) 

429 ) 

430 self.model: _Laplace = _Laplace( 

431 self.embedding_model, 

432 self.num_deep_features, 

433 self.num_gp_features, 

434 self.normalize_gp_features, 

435 self.num_random_features, 

436 self.num_outputs, 

437 self.feature_scale, 

438 self.mean_field_factor, 

439 self.ridge_penalty, 

440 ) 

441 self.device_type = device 

442 self.device: torch.device = torch.device(self.device_type) 

443 self.model.to(self.device) 

444 
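A construction sketch (the two-layer embedding below is a hypothetical stand-in for
any pytorch feature extractor whose output width matches emb_out_dim):

    embedding = torch.nn.Sequential(
        torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 32)
    )
    model = EquineGP(embedding, emb_out_dim=32, num_classes=3)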

445 def train_model( 

446 self, 

447 dataset: Dataset, 

448 loss_fn: Callable, 

449 opt: torch.optim.Optimizer, 

450 num_epochs: int, 

451 scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, 

452 batch_size: int = 64, 

453 validation_dataset: Optional[Dataset] = None, 

454 val_metrics: Optional[Iterable[Metric]] = None, 

455 vis_support: bool = False, 

456 support_size: int = 25, 

457 ) -> dict[str, Any]: 

458 """ 

459 Train or fine-tune an EquineGP model. 

460 

461 Parameters 

462 ---------- 

463 dataset : TensorDataset 

464 An iterable, pytorch TensorDataset. 

465 loss_fn : Callable 

466 A pytorch loss function, e.g., torch.nn.CrossEntropyLoss(). 

467 opt : torch.optim.Optimizer 

468 A pytorch optimizer, e.g., torch.optim.Adam(). 

469 num_epochs : int 

470 The desired number of epochs to use for training. 

471 scheduler : torch.optim.lr_scheduler.LRScheduler, optional 

472 A pytorch scheduler, if one is desired. 

473 validation_dataset : Dataset, optional 

474 If provided, validation metrics are computed on this dataset after each epoch of training. 

475 batch_size : int, optional 

476 The number of samples to use per batch. 

477 

478 Returns 

479 ------- 

480 dict[str, Any] 

481 A dict containing the training summary and, when a validation dataset is provided, the per-epoch validation metrics. 

482 

483 """ 

484 

485 self.validate_feature_label_names(dataset[0][0].shape[-1], self.num_outputs) 

486 

487 train_loader = DataLoader( 

488 dataset, batch_size=batch_size, shuffle=True, drop_last=False 

489 ) 

490 

491 val_loader: Optional[DataLoader] = None 

492 if validation_dataset is not None: 

493 val_loader = DataLoader( 

494 validation_dataset, 

495 batch_size=batch_size, 

496 shuffle=False, 

497 drop_last=False, 

498 ) 

499 

500 self.model.set_training_params(len(dataset), batch_size) 

501 val_metrics_outputs: Optional[list[list[float]]] = None 

502 

503 if validation_dataset is not None and val_metrics is not None: 

504 val_metrics_outputs = [[] for i in range(len(list(val_metrics)))] 

505 

506 for _ in tqdm(range(num_epochs)): 

507 self.model.train() 

508 self.model.reset_precision_matrix() 

509 epoch_loss = 0.0 

510 for i, (xs, labels) in enumerate(train_loader): 

511 opt.zero_grad() 

512 xs = xs.to(self.device) 

513 labels = labels.to(self.device) 

514 yhats = self.model(xs) 

515 loss = loss_fn(yhats, labels.to(torch.long)) 

516 loss.backward() 

517 opt.step() 

518 epoch_loss += loss.item() 

519 if scheduler is not None: 

520 scheduler.step() 

521 self.model.eval() 

522 # compute the validation metrics 

523 if ( 

524 validation_dataset is not None 

525 and val_loader is not None 

526 and val_metrics is not None 

527 and val_metrics_outputs is not None 

528 ): 

529 for _, (xs_val, labels_val) in enumerate(val_loader): 

530 xs_val = xs_val.to(self.device) 

531 labels_val = labels_val.to(self.device) 

532 yhats_val = self.model(xs_val) 

533 for metric in val_metrics: 

534 metric.update(yhats_val, labels_val) 

535 for i, metric in enumerate(val_metrics): 

536 val_metrics_outputs[i].append(metric.compute()) 

537 if vis_support: 

538 self.update_support(dataset.tensors[0], dataset.tensors[1], support_size) 

539 

540 _, train_y = dataset[:] 

541 date_trained = datetime.now().strftime("%m/%d/%Y, %H:%M:%S") 

542 self.train_summary: dict[str, Any] = generate_train_summary( 

543 self, train_y, date_trained 

544 ) 

545 

546 return_dict: dict[str, Any] = dict() 

547 return_dict["train_summary"] = self.train_summary 

548 if validation_dataset is not None: 

549 return_dict["val_metrics"] = val_metrics_outputs 

550 

551 return return_dict 

552 
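A training sketch for the loop above (X and Y are hypothetical tensors of features
and integer class labels):

    from torch.utils.data import TensorDataset

    dataset = TensorDataset(X, Y.to(torch.long))
    results = model.train_model(
        dataset,
        loss_fn=torch.nn.CrossEntropyLoss(),
        opt=torch.optim.Adam(model.parameters()),
        num_epochs=10,
    )
    print(results["train_summary"])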

553 def update_support( 

554 self, support_x: torch.Tensor, support_y: torch.Tensor, support_size: int 

555 ) -> None: 

556 """Function to update protonet support examples with given examples. 

557 

558 Parameters 

559 ---------- 

560 support_x : torch.Tensor 

561 Tensor containing the candidate support examples. 

562 support_y : torch.Tensor 

563 Tensor containing labels for given support examples. 

564 

565 Returns 

566 ------- 

567 None 

568 """ 

569 

570 labels, counts = torch.unique(support_y, return_counts=True) 

571 support = OrderedDict() 

572 for label, count in list(zip(labels.tolist(), counts.tolist())): 

573 class_support = generate_support( 

574 support_x, 

575 support_y, 

576 support_size=min(count, support_size), 

577 selected_labels=[label], 

578 ) 

579 support.update(class_support) 

580 

581 self.support = support 

582 

583 support_embeddings = OrderedDict().fromkeys(self.support.keys(), torch.Tensor()) 

584 for label in support: 

585 support_embeddings[label] = self.compute_embeddings(support[label]) 

586 

587 self.support_embeddings = support_embeddings 

588 self.prototypes: torch.Tensor = self.compute_prototypes() 

589 

590 def compute_embeddings(self, x: torch.Tensor) -> torch.Tensor: 

591 """ 

592 Method for computing deep embeddings for given input tensor. 

593 

594 Parameters 

595 ---------- 

596 x : torch.Tensor 

597 Input tensor for generating embeddings. 

598 

599 Returns 

600 ------- 

601 torch.Tensor 

602 The output embeddings. 

603 """ 

604 f = self.model.feature_extractor(x) 

605 f_reduc = self.model.jl(f) 

606 if self.model.normalize_gp_features:  606 ↛ 609 (line 606 didn't jump to line 609 because the condition was always true)

607 f_reduc = self.model.normalize(f_reduc) 

608 

609 return self.model.rff(f_reduc) 

610 

611 @icontract.require(lambda self: len(self.support) > 0) 

612 def compute_prototypes(self) -> torch.Tensor: 

613 """ 

614 Method for computing class prototypes based on given support examples. 

615 "Prototypes" in this context are the means of the support embeddings for each class. 

616 

617 Returns 

618 ------- 

619 torch.Tensor 

620 Tensors of prototypes for each of the given classes in the support. 

621 """ 

622 # Compute support embeddings 

623 support_embeddings = OrderedDict().fromkeys(self.support.keys()) 

624 for label in self.support: 

625 support_embeddings[label] = self.compute_embeddings(self.support[label]) 

626 

627 # Compute prototype for each class 

628 proto_list = [] 

629 for label in self.support: # look at doing functorch 

630 class_prototype = torch.mean(support_embeddings[label], dim=0) # type: ignore 

631 proto_list.append(class_prototype) 

632 

633 prototypes = torch.stack(proto_list) 

634 

635 return prototypes 

636 

637 @icontract.require(lambda self: len(self.support) > 0) 

638 def get_support(self) -> OrderedDict[int, torch.Tensor]: 

639 """ 

640 Method for returning support examples used in training. 

641 

642 Returns 

643 ------- 

644 OrderedDict[int, torch.Tensor] 

645 Dictionary containing support examples for each class. 

646 """ 

647 return self.support 

648 

649 @icontract.require(lambda self: len(self.prototypes) > 0) 

650 def get_prototypes(self) -> torch.Tensor: 

651 """ 

652 Method for returning class prototypes. 

653 

654 Returns 

655 ------- 

656 torch.Tensor 

657 Tensors of prototypes for each of the given classes in the support. 

658 """ 

659 return self.prototypes 

660 

661 @icontract.require(lambda num_calibration_epochs: 0 < num_calibration_epochs) 

662 @icontract.require(lambda calibration_lr: calibration_lr > 0.0) 

663 def calibrate_model( 

664 self, 

665 dataset: torch.utils.data.Dataset, 

666 num_calibration_epochs: int = 1, 

667 calibration_lr: float = 0.01, 

668 calibration_batch_size: int = 256, 

669 ) -> None: 

670 """ 

671 Fine-tune the temperature after training; call this separately, as EquineGP's train_model does not run it automatically. 

672 

673 Parameters 

674 ---------- 

675 dataset : TensorDataset 

676 An iterable, pytorch TensorDataset. 

677 num_calibration_epochs : int, optional 

678 Number of epochs to tune temperature. 

679 calibration_lr : float, optional 

680 Learning rate for temperature optimization. 

681 """ 

682 

683 calibration_loader = DataLoader( 

684 dataset, 

685 batch_size=calibration_batch_size, 

686 shuffle=True, 

687 drop_last=False, 

688 ) 

689 

690 self.temperature.requires_grad = True 

691 loss_fn = torch.nn.functional.cross_entropy 

692 optimizer = torch.optim.Adam([self.temperature], lr=calibration_lr) 

693 for _ in range(num_calibration_epochs): 

694 for xs, labels in calibration_loader: 

695 optimizer.zero_grad() 

696 xs = xs.to(self.device) 

697 labels = labels.to(self.device) 

698 with torch.no_grad(): 

699 logits = self.model(xs) 

700 logits = logits / self.temperature 

701 loss = loss_fn(logits, labels.to(torch.long)) 

702 loss.backward() 

703 optimizer.step() 

704 self.temperature.requires_grad = False 

705 
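Because only the temperature has requires_grad=True here, calibration leaves the
network weights untouched. A sketch (calibration_dataset is a hypothetical held-out
TensorDataset):

    model.calibrate_model(calibration_dataset, num_calibration_epochs=2)
    print(model.temperature.item())  # learned temperature; 1.0 means no rescaling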

706 def forward(self, X: torch.Tensor) -> torch.Tensor: 

707 """ 

708 EquineGP forward function, generates logits for classification. 

709 

710 Parameters 

711 ---------- 

712 X : torch.Tensor 

713 Input tensor for generating predictions. 

714 

715 Returns 

716 ------- 

717 torch.Tensor 

718 Output probabilities computed. 

719 """ 

720 X = X.to(self.device) 

721 preds = self.model(X) 

722 return preds / self.temperature.to(self.device) 

723 

724 @icontract.ensure( 

725 lambda result: all((0 <= result.ood_scores) & (result.ood_scores <= 1.0)) 

726 ) 

727 def predict(self, X: torch.Tensor) -> EquineOutput: 

728 """ 

729 Predict function for EquineGP, inherited and implemented from Equine. 

730 

731 Parameters 

732 ---------- 

733 X : torch.Tensor 

734 Input tensor. 

735 

736 Returns 

737 ------- 

738 EquineOutput 

739 Output object containing prediction probabilities and OOD scores. 

740 """ 

741 logits = self(X) 

742 preds = torch.softmax(logits, dim=1) 

743 equiprobable = torch.ones(self.num_outputs) / self.num_outputs 

744 max_entropy = torch.sum(torch.special.entr(equiprobable)) 

745 ood_score = torch.sum(torch.special.entr(preds), dim=1) / max_entropy 

746 embeddings = self.compute_embeddings(X) 

747 eq_out = EquineOutput( 

748 classes=preds, ood_scores=ood_score, embeddings=embeddings 

749 ) # TODO return embeddings 

750 

751 self.validate_feature_label_names(X.shape[-1], self.num_outputs) 

752 

753 return eq_out 

754 
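The OOD score is the softmax entropy divided by the maximum possible entropy
log(num_classes), so it always lies in [0, 1]. A prediction sketch (X_test is a
hypothetical batch of inputs):

    out = model.predict(X_test)
    labels = out.classes.argmax(dim=1)   # class probabilities -> hard labels
    uncertain = out.ood_scores > 0.9     # near-uniform, high-entropy predictions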

755 def save(self, path: str) -> None: 

756 """ 

757 Function to save all model parameters to a file. 

758 

759 Parameters 

760 ---------- 

761 path : str 

762 Filename to write the model. 

763 """ 

764 model_settings = { 

765 "emb_out_dim": self.num_deep_features, 

766 "num_classes": self.num_outputs, 

767 "init_temperature": self.temperature.item(), 

768 "device": self.device_type, 

769 } 

770 

771 jit_model = torch.jit.script(self.model.feature_extractor) 

772 buffer = io.BytesIO() 

773 torch.jit.save(jit_model, buffer) 

774 buffer.seek(0) 

775 

776 laplace_sd = self.model.state_dict() 

777 keys_to_delete = [] 

778 for key in laplace_sd: 

779 if "feature_extractor" in key: 

780 keys_to_delete.append(key) 

781 for key in keys_to_delete: 

782 del laplace_sd[key] 

783 

784 save_data = { 

785 "embed_jit_save": buffer, 

786 "feature_names": self.feature_names, 

787 "label_names": self.label_names, 

788 "laplace_model_save": laplace_sd, 

789 "num_data": self.model.num_data, 

790 "settings": model_settings, 

791 "support": self.support, 

792 "train_batch_size": self.model.train_batch_size, 

793 "train_summary": self.train_summary, 

794 } 

795 

796 torch.save(save_data, path) # TODO allow model checkpointing 

797 

798 @classmethod 

799 def load(cls, path: str) -> Equine: 

800 """ 

801 Function to load previously saved EquineGP model. 

802 

803 Parameters 

804 ---------- 

805 path : str 

806 Input filename. 

807 

808 Returns 

809 ------- 

810 EquineGP 

811 The reconstituted EquineGP object. 

812 """ 

813 model_save = torch.load(path, weights_only=False) 

814 jit_model = torch.jit.load(model_save.get("embed_jit_save")) 

815 eq_model = cls(jit_model, **model_save.get("settings")) 

816 

817 eq_model.feature_names = model_save.get("feature_names") 

818 eq_model.label_names = model_save.get("label_names") 

819 eq_model.train_summary = model_save.get("train_summary") 

820 

821 eq_model.model.load_state_dict( 

822 model_save.get("laplace_model_save"), strict=False 

823 ) 

824 eq_model.model.seen_data = model_save.get("laplace_model_save").get("seen_data") 

825 

826 eq_model.model.set_training_params( 

827 model_save.get("num_data"), model_save.get("train_batch_size") 

828 ) 

829 eq_model.eval() 

830 

831 support = model_save.get("support") 

832 if len(support) > 0: 

833 eq_model.support = support 

834 eq_model.prototypes = eq_model.compute_prototypes() 

835 

836 return eq_model
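A round-trip sketch (the filename is arbitrary): save serializes the TorchScripted
feature extractor alongside the Laplace state, and load reconstitutes an equivalent
model:

    model.save("equine_gp_model.eq")
    restored = EquineGP.load("equine_gp_model.eq")
    assert torch.allclose(model.temperature, restored.temperature)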