Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/equine/equine_gp.py: 96%

271 statements  

coverage.py v7.9.1, created at 2025-06-29 04:12 +0000

1# Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY 

2# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014). 

3# SPDX-License-Identifier: MIT 

4 

5from typing import Any, Optional, Union 

6 

7import icontract 

8import io 

9import math 

10import torch 

11from beartype import beartype 

12from collections import OrderedDict 

13from collections.abc import Callable 

14from datetime import datetime 

15from torch.utils.data import DataLoader, TensorDataset 

16from tqdm import tqdm 

17 

18from .equine import Equine, EquineOutput 

19from .utils import generate_support, generate_train_summary, stratified_train_test_split 

20 

21BatchType = tuple[torch.Tensor, ...] 

22# ------------------------------------------------------------------------------- 

23# Note that the below code for 

24# * `_random_ortho`, 

25# * `_RandomFourierFeatures`, and

26# * `_Laplace` 

27# is copied and modified from https://github.com/y0ast/DUE/blob/main/due/sngp.py 

28# under its original MIT license, redisplayed here: 

29# ------------------------------------------------------------------------------- 

30# MIT License 

31# 

32# Copyright (c) 2021 Joost van Amersfoort 

33# 

34# Permission is hereby granted, free of charge, to any person obtaining a copy 

35# of this software and associated documentation files (the "Software"), to deal 

36# in the Software without restriction, including without limitation the rights 

37# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

38# copies of the Software, and to permit persons to whom the Software is 

39# furnished to do so, subject to the following conditions: 

40# 

41# The above copyright notice and this permission notice shall be included in all 

42# copies or substantial portions of the Software. 

43# 

44# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

45# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

46# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

47# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

48# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

49# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

50# SOFTWARE. 

51# ------------------------------------------------------------------------------ 

52# Following the recommendation of their README at https://github.com/y0ast/DUE 

53# we encourage anyone using this code in their research to cite the following papers: 

54# 

55# @article{van2021on, 

56# title={On Feature Collapse and Deep Kernel Learning for Single Forward Pass Uncertainty}, 

57# author={van Amersfoort, Joost and Smith, Lewis and Jesson, Andrew and Key, Oscar and Gal, Yarin}, 

58# journal={arXiv preprint arXiv:2102.11409}, 

59# year={2021} 

60# } 

61# 

62# @article{liu2020simple, 

63# title={Simple and principled uncertainty estimation with deterministic deep learning via distance awareness}, 

64# author={Liu, Jeremiah and Lin, Zi and Padhy, Shreyas and Tran, Dustin and Bedrax Weiss, Tania and Lakshminarayanan, Balaji}, 

65# journal={Advances in Neural Information Processing Systems}, 

66# volume={33}, 

67# pages={7498--7512}, 

68# year={2020} 

69# } 

70 

71 

72@beartype 

73def _random_ortho(n: int, m: int) -> torch.Tensor: 

74 """ 

75 Generate a random orthonormal matrix. 

76 

77 Parameters 

78 ---------- 

79 n : int 

80 The number of rows. 

81 m : int 

82 The number of columns. 

83 

84 Returns 

85 ------- 

86 torch.Tensor 

87 The random orthonormal matrix. 

88 """ 

89 q, _ = torch.linalg.qr(torch.randn(n, m)) 

90 return q 

91 
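# --- Illustrative sketch (not part of equine_gp.py): for n >= m, torch.linalg.qr returns a Q
# factor with orthonormal columns, so Q.T @ Q is close to the identity:
#
#     Q = _random_ortho(8, 4)
#     assert Q.shape == (8, 4)
#     assert torch.allclose(Q.T @ Q, torch.eye(4), atol=1e-5)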

92 

93@beartype 

94class _RandomFourierFeatures(torch.nn.Module): 

95 """ 

96 A private class to generate random Fourier features for the embedding model. 

97 """ 

98 

99 def __init__( 

100 self, in_dim: int, num_random_features: int, feature_scale: Optional[float] 

101 ) -> None: 

102 """ 

103 Initialize the _RandomFourierFeatures module, which generates random Fourier features 

104 for the embedding model. 

105 

106 Parameters 

107 ---------- 

108 in_dim : int 

109 The input dimensionality. 

110 num_random_features : int 

111 The number of random Fourier features to generate. 

112 feature_scale : Optional[float] 

113 The scaling factor for the random Fourier features. If None, defaults to sqrt(num_random_features / 2). 

114 """ 

115 super().__init__() 

116 if feature_scale is None:  # coverage: branch 116 ↛ 117 never taken (the condition on line 116 was never true)

117 feature_scale = math.sqrt(num_random_features / 2) 

118 

119 self.register_buffer("feature_scale", torch.tensor(feature_scale)) 

120 

121 if num_random_features <= in_dim:  # coverage: branch 121 ↛ 122 never taken (the condition on line 121 was never true)

122 W = _random_ortho(in_dim, num_random_features) 

123 else: 

124 # generate blocks of orthonormal columns which are not necessarily orthonormal

125 # to each other. 

126 dim_left = num_random_features 

127 ws = [] 

128 while dim_left > in_dim: 

129 ws.append(_random_ortho(in_dim, in_dim)) 

130 dim_left -= in_dim 

131 ws.append(_random_ortho(in_dim, dim_left)) 

132 W = torch.cat(ws, 1) 

133 

134 feature_norm = torch.randn(W.shape) ** 2 

135 W = W * feature_norm.sum(0).sqrt() 

136 self.register_buffer("W", W) 

137 

138 b = torch.empty(num_random_features).uniform_(0, 2 * math.pi) 

139 self.register_buffer("b", b) 

140 

141 def forward(self, x: torch.Tensor) -> torch.Tensor: 

142 """ 

143 Compute the forward pass of the _RandomFourierFeatures module. 

144 

145 Parameters 

146 ---------- 

147 x : torch.Tensor 

148 The input tensor of shape (batch_size, in_dim). 

149 

150 Returns 

151 ------- 

152 torch.Tensor 

153 The output tensor of shape (batch_size, num_random_features). 

154 """ 

155 k = torch.cos(x @ self.W + self.b) 

156 k = k / self.feature_scale 

157 

158 return k 

159 

160 
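# --- Illustrative sketch (not part of equine_gp.py): random Fourier features let inner
# products of the mapped features approximate a shift-invariant (RBF-style) kernel
# (Rahimi & Recht, 2007); the class above uses an orthogonal variant of that construction.
# A minimal call, assuming a 16-dim input:
#
#     rff = _RandomFourierFeatures(in_dim=16, num_random_features=64, feature_scale=2.0)
#     x = torch.randn(32, 16)
#     k = rff(x)   # cos(x @ W + b) / feature_scale, shape (32, 64)
#     assert k.shape == (32, 64)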

161class _Laplace(torch.nn.Module): 

162 """ 

163 A private class to compute a Laplace approximation to a Gaussian Process (GP) 

164 """ 

165 

166 def __init__( 

167 self, 

168 feature_extractor: torch.nn.Module, 

169 num_deep_features: int, 

170 num_gp_features: int, 

171 normalize_gp_features: bool, 

172 num_random_features: int, 

173 num_outputs: int, 

174 feature_scale: Optional[float], 

175 mean_field_factor: Optional[float], # required for classification problems 

176 ridge_penalty: float = 1.0, 

177 ) -> None: 

178 """ 

179 Initialize the _Laplace module. 

180 

181 Parameters 

182 ---------- 

183 feature_extractor : torch.nn.Module 

184 The feature extractor module. 

185 num_deep_features : int 

186 The number of features output by the feature extractor. 

187 num_gp_features : int 

188 The number of features to use in the Gaussian process. 

189 normalize_gp_features : bool 

190 Whether to normalize the GP features. 

191 num_random_features : int 

192 The number of random Fourier features to use. 

193 num_outputs : int 

194 The number of outputs of the model. 

195 feature_scale : Optional[float] 

196 The scaling factor for the random Fourier features. 

197 mean_field_factor : Optional[float] 

198 The mean-field factor for the Gaussian-Softmax approximation. 

199 Required for classification problems. 

200 ridge_penalty : float, optional 

201 The ridge penalty for the Laplace approximation. 

202 """ 

203 super().__init__() 

204 self.feature_extractor = feature_extractor 

205 self.mean_field_factor = mean_field_factor 

206 self.ridge_penalty = ridge_penalty 

207 self.train_batch_size = 0 # to be set later 

208 

209 if num_gp_features > 0:  # coverage: branch 209 ↛ 219 never taken (the condition on line 209 was always true)

210 self.num_gp_features = num_gp_features 

211 self.register_buffer( 

212 "random_matrix", 

213 torch.normal(0, 0.05, (num_gp_features, num_deep_features)), 

214 ) 

215 self.jl: Callable = lambda x: torch.nn.functional.linear( 

216 x, self.random_matrix 

217 ) 

218 else: 

219 self.num_gp_features: int = num_deep_features 

220 self.jl: Callable = lambda x: x # Identity 

221 

222 self.normalize_gp_features = normalize_gp_features 

223 if normalize_gp_features:  # coverage: branch 223 ↛ 226 never taken (the condition on line 223 was always true)

224 self.normalize: torch.nn.LayerNorm = torch.nn.LayerNorm(num_gp_features) 

225 

226 self.rff: _RandomFourierFeatures = _RandomFourierFeatures( 

227 num_gp_features, num_random_features, feature_scale 

228 ) 

229 self.beta: torch.nn.Linear = torch.nn.Linear(num_random_features, num_outputs) 

230 

231 self.num_data = 0 # to be set later 

232 self.register_buffer("seen_data", torch.tensor(0)) 

233 

234 precision = torch.eye(num_random_features) * self.ridge_penalty 

235 self.register_buffer("precision", precision) 

236 

237 self.recompute_covariance = True 

238 self.register_buffer("covariance", torch.eye(num_random_features)) 

239 self.training_parameters_set = False 

240 

241 def reset_precision_matrix(self) -> None: 

242 """ 

243 Reset the precision matrix to the identity matrix times the ridge penalty. 

244 """ 

245 identity = torch.eye(self.precision.shape[0], device=self.precision.device) 

246 self.precision: torch.Tensor = identity * self.ridge_penalty 

247 self.seen_data: torch.Tensor = torch.tensor(0) 

248 self.recompute_covariance = True 

249 

250 @icontract.require(lambda num_data: num_data > 0) 

251 @icontract.require( 

252 lambda num_data, batch_size: (0 < batch_size) & (batch_size <= num_data) 

253 ) 

254 def set_training_params(self, num_data: int, batch_size: int) -> None: 

255 """ 

256 Set the training parameters for the Laplace approximation. 

257 

258 Parameters 

259 ---------- 

260 num_data : int 

261 The total number of data points. 

262 batch_size : int 

263 The batch size to use during training. 

264 """ 

265 self.num_data: int = num_data 

266 self.train_batch_size: int = batch_size 

267 self.training_parameters_set: bool = True 

268 

269 @icontract.require(lambda self: self.mean_field_factor is not None) 

270 def mean_field_logits( 

271 self, logits: torch.Tensor, pred_cov: torch.Tensor 

272 ) -> torch.Tensor: 

273 """ 

274 Compute the mean-field logits for the Gaussian-Softmax approximation. 

275 

276 Parameters 

277 ---------- 

278 logits : torch.Tensor 

279 The logits tensor of shape (batch_size, num_outputs). 

280 pred_cov : torch.Tensor 

281 The predicted covariance matrix of shape (batch_size, batch_size). 

282 

283 Returns 

284 ------- 

285 torch.Tensor 

286 The mean-field logits tensor of shape (batch_size, num_outputs). 

287 """ 

288 # Mean-Field approximation as alternative to MC integration of Gaussian-Softmax 

289 # Based on: https://arxiv.org/abs/2006.07584 

290 

291 logits_scale = torch.sqrt(1.0 + torch.diag(pred_cov) * self.mean_field_factor) 

292 if self.mean_field_factor > 0:  # type: ignore  # coverage: branch 292 ↛ 295 never taken (the condition on line 292 was always true)

293 logits = logits / logits_scale.unsqueeze(-1) 

294 

295 return logits 

296 
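# --- Illustrative sketch (not part of equine_gp.py): the mean-field adjustment divides each
# example's logits by sqrt(1 + lambda * var), where var = diag(pred_cov) is the predictive
# variance and lambda is `mean_field_factor`; higher variance flattens the softmax. For
# example, with lambda = 25 and var = 0.12:
#
#     scale = math.sqrt(1.0 + 0.12 * 25)   # = 2.0
#     # so a logit of 4.0 is rescaled to 4.0 / scale = 2.0 (a less confident prediction)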

297 @icontract.require(lambda self: self.training_parameters_set) 

298 def forward( 

299 self, x: torch.Tensor 

300 ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: 

301 """ 

302 Compute the forward pass of the Laplace approximation to the Gaussian Process. 

303 

304 Parameters 

305 ---------- 

306 x : torch.Tensor 

307 The input tensor of shape (batch_size, num_features). 

308 

309 Returns 

310 ------- 

311 Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]] 

312 If the model is in training mode, returns the predicted mean of shape (batch_size, 1). 

313 If the model is in evaluation mode, returns a tuple containing the predicted mean of shape (batch_size, 1) 

314 and the predicted covariance matrix of shape (batch_size, batch_size). 

315 """ 

316 f = self.feature_extractor(x) 

317 f_reduc = self.jl(f) 

318 if self.normalize_gp_features:  # coverage: branch 318 ↛ 321 never taken (the condition on line 318 was always true)

319 f_reduc = self.normalize(f_reduc) 

320 

321 k = self.rff(f_reduc) 

322 

323 pred = self.beta(k) 

324 

325 if self.training: 

326 precision_minibatch = k.t() @ k 

327 self.precision += precision_minibatch 

328 self.seen_data += x.shape[0] 

329 

330 assert ( 

331 self.seen_data <= self.num_data 

332 ), "Did not reset precision matrix at start of epoch" 

333 else: 

334 assert self.seen_data > ( 

335 self.num_data - self.train_batch_size 

336 ), "Not seen sufficient data for precision matrix" 

337 

338 if self.recompute_covariance: 

339 with torch.no_grad(): 

340 eps = 1e-7 

341 jitter = eps * torch.eye( 

342 self.precision.shape[1], 

343 device=self.precision.device, 

344 ) 

345 u, info = torch.linalg.cholesky_ex(self.precision + jitter) 

346 assert (info == 0).all(), "Precision matrix inversion failed!" 

347 torch.cholesky_inverse(u, out=self.covariance) 

348 

349 self.recompute_covariance: bool = False 

350 

351 with torch.no_grad(): 

352 pred_cov = k @ ((self.covariance @ k.t()) * self.ridge_penalty) 

353 

354 if self.mean_field_factor is None:  # coverage: branch 354 ↛ 355 never taken (the condition on line 354 was never true)

355 return pred, pred_cov 

356 else: 

357 pred = self.mean_field_logits(pred, pred_cov) 

358 

359 return pred 

360 

361 

362# ------------------------------------------------------------------------------- 

363# EquineGP, below, demonstrates how to adapt that approach in EQUINE 

364@beartype 

365class EquineGP(Equine): 

366 """ 

367 An example of an EQUINE model that builds upon the approach in "Spectral-normalized

368 Neural Gaussian Process" (SNGP). This wraps any pytorch embedding neural network and provides

369 the `forward`, `predict`, `save`, and `load` methods required by Equine. 

370 

371 Notes 

372 ----- 

373 Although this model builds upon the approach in SNGP, it does not enforce the spectral normalization

374 and ResNet architecture required for SNGP. Instead, it is a simple wrapper around 

375 any pytorch embedding neural network. Your mileage may vary. 

376 """ 

377 

378 def __init__( 

379 self, 

380 embedding_model: torch.nn.Module, 

381 emb_out_dim: int, 

382 num_classes: int, 

383 use_temperature: bool = False, 

384 init_temperature: float = 1.0, 

385 device: str = "cpu", 

386 feature_names: Optional[list[str]] = None, 

387 label_names: Optional[list[str]] = None, 

388 ) -> None: 

389 """ 

390 Initialize the EquineGP model. 

391 

392 Parameters 

393 ---------- 

394 embedding_model : torch.nn.Module 

395 Neural Network feature embedding. 

396 emb_out_dim : int 

397 The number of deep features from the feature embedding. 

398 num_classes : int 

399 The number of output classes this model predicts. 

400 use_temperature : bool, optional 

401 Whether to use temperature scaling after training. 

402 init_temperature : float, optional 

403 What to use as the initial temperature (1.0 has no effect). 

404 device : str, optional 

405 Either 'cuda' or 'cpu'. 

406 feature_names : list[str], optional 

407 List of strings of the names of the tabular features (e.g., ["duration", "fiat_mean", ...])

408 label_names : list[str], optional 

409 List of strings of the names of the labels (e.g., ["streaming", "voip", ...])

410 """ 

411 super().__init__( 

412 embedding_model, feature_names=feature_names, label_names=label_names 

413 ) 

414 self.num_deep_features = emb_out_dim 

415 self.num_gp_features = emb_out_dim 

416 self.normalize_gp_features = True 

417 self.num_random_features = 1024 

418 self.num_outputs = num_classes 

419 self.mean_field_factor = 25 

420 self.ridge_penalty = 1 

421 self.feature_scale: float = 2.0 

422 self.use_temperature = use_temperature 

423 self.init_temperature = init_temperature 

424 self.register_buffer( 

425 "temperature", torch.Tensor(self.init_temperature * torch.ones(1)) 

426 ) 

427 self.model: _Laplace = _Laplace( 

428 self.embedding_model, 

429 self.num_deep_features, 

430 self.num_gp_features, 

431 self.normalize_gp_features, 

432 self.num_random_features, 

433 self.num_outputs, 

434 self.feature_scale, 

435 self.mean_field_factor, 

436 self.ridge_penalty, 

437 ) 

438 self.device_type = device 

439 self.device: torch.device = torch.device(self.device_type) 

440 self.model.to(self.device) 

441 
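# --- Illustrative sketch (not part of equine_gp.py): EquineGP wraps an arbitrary embedding
# network. A minimal construction, assuming a hypothetical two-layer MLP with 10 input
# features, a 32-dim embedding, and 3 classes:
#
#     embedder = torch.nn.Sequential(
#         torch.nn.Linear(10, 64),
#         torch.nn.ReLU(),
#         torch.nn.Linear(64, 32),
#     )
#     model = EquineGP(embedder, emb_out_dim=32, num_classes=3, use_temperature=True)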

442 def train_model( 

443 self, 

444 dataset: TensorDataset, 

445 loss_fn: Callable, 

446 opt: torch.optim.Optimizer, 

447 num_epochs: int, 

448 batch_size: int = 64, 

449 calib_frac: float = 0.1, 

450 num_calibration_epochs: int = 2, 

451 calibration_lr: float = 0.01, 

452 vis_support: bool = False, 

453 support_size: int = 25, 

454 ) -> dict[str, Any]: 

455 """ 

456 Train or fine-tune an EquineGP model. 

457 

458 Parameters 

459 ---------- 

460 dataset : TensorDataset 

461 An iterable, pytorch TensorDataset. 

462 loss_fn : Callable 

463 A pytorch loss function, e.g., torch.nn.CrossEntropyLoss(). 

464 opt : torch.optim.Optimizer 

465 A pytorch optimizer, e.g., torch.optim.Adam(). 

466 num_epochs : int 

467 The desired number of epochs to use for training. 

468 batch_size : int, optional 

469 The number of samples to use per batch. 

470 calib_frac : float, optional 

471 Fraction of training data to use in temperature scaling. 

472 num_calibration_epochs : int, optional 

473 The desired number of epochs to use for temperature scaling. 

474 calibration_lr : float, optional 

475 Learning rate for temperature scaling. 

476 

477 Returns 

478 ------- 

479 dict[str, Any] 

480 A dict containing the training summary stats and a dataloader for the calibration data.

481 

482 Notes 

483 -----

484 - If `use_temperature` is True, temperature scaling will be used after training. 

485 - The calibration data is used to calibrate the temperature scaling. 

486 """ 

487 X, Y = dataset[:] 

488 calib_x = torch.Tensor() 

489 calib_y = torch.Tensor() 

490 if self.use_temperature: 

491 train_x, calib_x, train_y, calib_y = stratified_train_test_split( 

492 X, Y, test_size=calib_frac 

493 ) 

494 dataset = TensorDataset(torch.Tensor(train_x), torch.Tensor(train_y)) 

495 self.temperature: torch.Tensor = torch.Tensor( 

496 self.init_temperature * torch.ones(1) 

497 ).type_as(self.temperature) 

498 

499 self.validate_feature_label_names(X.shape[-1], self.num_outputs) 

500 

501 train_loader = DataLoader( 

502 dataset, batch_size=batch_size, shuffle=True, drop_last=True 

503 ) 

504 self.model.set_training_params(len(dataset), batch_size) 

505 self.model.train() 

506 for _ in tqdm(range(num_epochs)): 

507 self.model.reset_precision_matrix() 

508 epoch_loss = 0.0 

509 for i, (xs, labels) in enumerate(train_loader): 

510 opt.zero_grad() 

511 xs = xs.to(self.device) 

512 labels = labels.to(self.device) 

513 yhats = self.model(xs) 

514 loss = loss_fn(yhats, labels.to(torch.long)) 

515 loss.backward() 

516 opt.step() 

517 epoch_loss += loss.item() 

518 self.model.eval() 

519 

520 if vis_support: 

521 self.update_support(dataset.tensors[0], dataset.tensors[1], support_size) 

522 

523 calibration_loader = None 

524 if self.use_temperature: 

525 dataset_calibration = TensorDataset( 

526 torch.Tensor(calib_x), torch.Tensor(calib_y) 

527 ) 

528 calibration_loader = DataLoader( 

529 dataset_calibration, 

530 batch_size=batch_size, 

531 shuffle=True, 

532 drop_last=False, 

533 ) 

534 self.calibrate_temperature( 

535 calibration_loader, num_calibration_epochs, calibration_lr 

536 ) 

537 

538 _, train_y = dataset[:] 

539 date_trained = datetime.now().strftime("%m/%d/%Y, %H:%M:%S") 

540 self.train_summary: dict[str, Any] = generate_train_summary( 

541 self, train_y, date_trained 

542 ) 

543 

544 return_dict: dict[str, Any] = dict() 

545 return_dict["train_summary"] = self.train_summary 

546 return_dict["calibration_loader"] = calibration_loader 

547 

548 return return_dict 

549 
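# --- Illustrative sketch (not part of equine_gp.py): a typical training call, continuing the
# hypothetical `model` and 10-feature toy data from the sketch above:
#
#     X, Y = torch.randn(200, 10), torch.randint(0, 3, (200,))
#     results = model.train_model(
#         TensorDataset(X, Y),
#         loss_fn=torch.nn.CrossEntropyLoss(),
#         opt=torch.optim.Adam(model.parameters(), lr=1e-3),
#         num_epochs=10,
#         batch_size=32,
#     )
#     calibration_loader = results["calibration_loader"]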

550 def update_support( 

551 self, support_x: torch.Tensor, support_y: torch.Tensor, support_size: int 

552 ) -> None: 

553 """Function to update protonet support examples with given examples. 

554 

555 Parameters 

556 ---------- 

557 support_x : torch.Tensor 

558 Tensor containing support examples for protonet. 

559 support_y : torch.Tensor 

560 Tensor containing labels for given support examples. 

561 

562 Returns 

563 ------- 

564 None 

565 """ 

566 

567 labels, counts = torch.unique(support_y, return_counts=True) 

568 support = OrderedDict() 

569 for label, count in list(zip(labels.tolist(), counts.tolist())): 

570 class_support = generate_support( 

571 support_x, 

572 support_y, 

573 support_size=min(count, support_size), 

574 selected_labels=[label], 

575 ) 

576 support.update(class_support) 

577 

578 self.support = support 

579 

580 support_embeddings = OrderedDict().fromkeys(self.support.keys(), torch.Tensor()) 

581 for label in support: 

582 support_embeddings[label] = self.compute_embeddings(support[label]) 

583 

584 self.support_embeddings = support_embeddings 

585 self.prototypes: torch.Tensor = self.compute_prototypes() 

586 

587 def compute_embeddings(self, x: torch.Tensor) -> torch.Tensor: 

588 """ 

589 Method for computing deep embeddings for given input tensor. 

590 

591 Parameters 

592 ---------- 

593 x : torch.Tensor 

594 Input tensor for generating embeddings. 

595 

596 Returns 

597 ------- 

598 torch.Tensor 

599 Output embeddings.

600 """ 

601 f = self.model.feature_extractor(x) 

602 f_reduc = self.model.jl(f) 

603 if self.model.normalize_gp_features:  # coverage: branch 603 ↛ 606 never taken (the condition on line 603 was always true)

604 f_reduc = self.model.normalize(f_reduc) 

605 

606 return self.model.rff(f_reduc) 

607 

608 @icontract.require(lambda self: len(self.support) > 0) 

609 def compute_prototypes(self) -> torch.Tensor: 

610 """ 

611 Method for computing class prototypes based on given support examples. 

612 "Prototypes" in this context are the means of the support embeddings for each class.

613 

614 Returns 

615 ------- 

616 torch.Tensor 

617 Tensors of prototypes for each of the given classes in the support. 

618 """ 

619 # Compute support embeddings 

620 support_embeddings = OrderedDict().fromkeys(self.support.keys()) 

621 for label in self.support: 

622 support_embeddings[label] = self.compute_embeddings(self.support[label]) 

623 

624 # Compute prototype for each class 

625 proto_list = [] 

626 for label in self.support: # look at doing functorch 

627 class_prototype = torch.mean(support_embeddings[label], dim=0) # type: ignore 

628 proto_list.append(class_prototype) 

629 

630 prototypes = torch.stack(proto_list) 

631 

632 return prototypes 

633 

634 @icontract.require(lambda self: len(self.support) > 0) 

635 def get_support(self) -> OrderedDict[int, torch.Tensor]: 

636 """ 

637 Method for returning support examples used in training. 

638 

639 Returns 

640 ------- 

641 OrderedDict[int, torch.Tensor] 

642 Dictionary containing support examples for each class. 

643 """ 

644 return self.support 

645 

646 @icontract.require(lambda self: len(self.prototypes) > 0) 

647 def get_prototypes(self) -> torch.Tensor: 

648 """ 

649 Method for returning class prototypes. 

650 

651 Returns 

652 ------- 

653 torch.Tensor 

654 Tensors of prototypes for each of the given classes in the support. 

655 """ 

656 return self.prototypes 

657 

658 @icontract.require(lambda num_calibration_epochs: 0 < num_calibration_epochs) 

659 @icontract.require(lambda calibration_lr: calibration_lr > 0.0) 

660 def calibrate_temperature( 

661 self, 

662 calibration_loader: DataLoader[BatchType], 

663 num_calibration_epochs: int = 1, 

664 calibration_lr: float = 0.01, 

665 ) -> None: 

666 """ 

667 Fine-tune the temperature after training. Note that this function is also run at the conclusion of train_model.

668 

669 Parameters 

670 ---------- 

671 calibration_loader : DataLoader 

672 Data loader returned by train_model. 

673 num_calibration_epochs : int, optional 

674 Number of epochs to tune temperature. 

675 calibration_lr : float, optional 

676 Learning rate for temperature optimization. 

677 """ 

678 self.temperature.requires_grad = True 

679 loss_fn = torch.nn.functional.cross_entropy 

680 optimizer = torch.optim.Adam([self.temperature], lr=calibration_lr) 

681 for _ in range(num_calibration_epochs): 

682 for xs, labels in calibration_loader: 

683 optimizer.zero_grad() 

684 xs = xs.to(self.device) 

685 labels = labels.to(self.device) 

686 with torch.no_grad(): 

687 logits = self.model(xs) 

688 logits = logits / self.temperature 

689 loss = loss_fn(logits, labels.to(torch.long)) 

690 loss.backward() 

691 optimizer.step() 

692 self.temperature.requires_grad = False 

693 
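# --- Illustrative sketch (not part of equine_gp.py): temperature scaling learns a single
# scalar T so that softmax(logits / T) is better calibrated; it never changes the argmax
# class. Continuing the hypothetical training sketch above:
#
#     if calibration_loader is not None:   # present when use_temperature=True
#         model.calibrate_temperature(calibration_loader, num_calibration_epochs=2)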

694 def forward(self, X: torch.Tensor) -> torch.Tensor: 

695 """ 

696 EquineGP forward function, generates logits for classification. 

697 

698 Parameters 

699 ---------- 

700 X : torch.Tensor 

701 Input tensor for generating predictions. 

702 

703 Returns 

704 ------- 

705 torch.Tensor 

706 Output logits (scaled by the temperature).

707 """ 

708 X = X.to(self.device) 

709 preds = self.model(X) 

710 return preds / self.temperature.to(self.device) 

711 

712 @icontract.ensure( 

713 lambda result: all((0 <= result.ood_scores) & (result.ood_scores <= 1.0)) 

714 ) 

715 def predict(self, X: torch.Tensor) -> EquineOutput: 

716 """ 

717 Predict function for EquineGP, inherited and implemented from Equine. 

718 

719 Parameters 

720 ---------- 

721 X : torch.Tensor 

722 Input tensor. 

723 

724 Returns 

725 ------- 

726 EquineOutput 

727 Output object containing prediction probabilities and OOD scores. 

728 """ 

729 logits = self(X) 

730 preds = torch.softmax(logits, dim=1) 

731 equiprobable = torch.ones(self.num_outputs) / self.num_outputs 

732 max_entropy = torch.sum(torch.special.entr(equiprobable)) 

733 ood_score = torch.sum(torch.special.entr(preds), dim=1) / max_entropy 

734 embeddings = self.compute_embeddings(X) 

735 eq_out = EquineOutput( 

736 classes=preds, ood_scores=ood_score, embeddings=embeddings 

737 ) # TODO return embeddings 

738 

739 self.validate_feature_label_names(X.shape[-1], self.num_outputs) 

740 

741 return eq_out 

742 
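# --- Illustrative sketch (not part of equine_gp.py): the OOD score above is the softmax
# entropy divided by its maximum, log(num_classes), so it always lies in [0, 1]: a uniform
# prediction scores 1.0, a one-hot prediction 0.0. Continuing the hypothetical model above:
#
#     out = model.predict(torch.randn(5, 10))
#     assert out.classes.shape == (5, 3)
#     assert ((0.0 <= out.ood_scores) & (out.ood_scores <= 1.0)).all()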

743 def save(self, path: str) -> None: 

744 """ 

745 Function to save all model parameters to a file. 

746 

747 Parameters 

748 ---------- 

749 path : str 

750 Filename to write the model. 

751 """ 

752 model_settings = { 

753 "emb_out_dim": self.num_deep_features, 

754 "num_classes": self.num_outputs, 

755 "use_temperature": self.use_temperature, 

756 "init_temperature": self.temperature.item(), 

757 "device": self.device_type, 

758 } 

759 

760 jit_model = torch.jit.script(self.model.feature_extractor) 

761 buffer = io.BytesIO() 

762 torch.jit.save(jit_model, buffer) 

763 buffer.seek(0) 

764 

765 laplace_sd = self.model.state_dict() 

766 keys_to_delete = [] 

767 for key in laplace_sd: 

768 if "feature_extractor" in key: 

769 keys_to_delete.append(key) 

770 for key in keys_to_delete: 

771 del laplace_sd[key] 

772 

773 save_data = { 

774 "embed_jit_save": buffer, 

775 "feature_names": self.feature_names, 

776 "label_names": self.label_names, 

777 "laplace_model_save": laplace_sd, 

778 "num_data": self.model.num_data, 

779 "settings": model_settings, 

780 "support": self.support, 

781 "train_batch_size": self.model.train_batch_size, 

782 "train_summary": self.train_summary, 

783 } 

784 

785 torch.save(save_data, path) # TODO allow model checkpointing 

786 

787 @classmethod 

788 def load(cls, path: str) -> Equine: 

789 """ 

790 Function to load previously saved EquineGP model. 

791 

792 Parameters 

793 ---------- 

794 path : str 

795 Input filename. 

796 

797 Returns 

798 ------- 

799 EquineGP 

800 The reconstituted EquineGP object. 

801 """ 

802 model_save = torch.load(path, weights_only=False) 

803 jit_model = torch.jit.load(model_save.get("embed_jit_save")) 

804 eq_model = cls(jit_model, **model_save.get("settings")) 

805 

806 eq_model.feature_names = model_save.get("feature_names") 

807 eq_model.label_names = model_save.get("label_names") 

808 eq_model.train_summary = model_save.get("train_summary") 

809 

810 eq_model.model.load_state_dict( 

811 model_save.get("laplace_model_save"), strict=False 

812 ) 

813 eq_model.model.seen_data = model_save.get("laplace_model_save").get("seen_data") 

814 

815 eq_model.model.set_training_params( 

816 model_save.get("num_data"), model_save.get("train_batch_size") 

817 ) 

818 eq_model.eval() 

819 

820 support = model_save.get("support") 

821 if len(support) > 0: 

822 eq_model.support = support 

823 eq_model.prototypes = eq_model.compute_prototypes() 

824 

825 return eq_model
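# --- Illustrative sketch (not part of equine_gp.py): `save` TorchScripts the embedding
# network and stores it alongside the Laplace head's state dict, so `load` can rebuild the
# model without the embedder's original Python class. Continuing the hypothetical example:
#
#     model.save("equine_gp_model.pt")   # hypothetical filename
#     restored = EquineGP.load("equine_gp_model.pt")
#     out = restored.predict(torch.randn(5, 10))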