Coverage for /opt/hostedtoolcache/Python/3.10.19/x64/lib/python3.10/site-packages/equine/equine_gp.py: 96%
279 statements
coverage.py v7.13.0, created at 2025-12-18 23:02 +0000
# Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT

from typing import Any, Optional, Union

import icontract
import io
import math
import torch
from beartype import beartype
from collections import OrderedDict
from collections.abc import Callable, Iterable
from datetime import datetime
from torch.utils.data import DataLoader, Dataset
from torchmetrics.metric import Metric
from tqdm import tqdm

from .equine import Equine, EquineOutput
from .utils import generate_support, generate_train_summary

BatchType = tuple[torch.Tensor, ...]

# -------------------------------------------------------------------------------
# Note that the below code for
#   * `_random_ortho`,
#   * `_RandomFourierFeatures`, and
#   * `_Laplace`
# is copied and modified from https://github.com/y0ast/DUE/blob/main/due/sngp.py
# under its original MIT license, redisplayed here:
# -------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2021 Joost van Amersfoort
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ------------------------------------------------------------------------------
# Following the recommendation of their README at https://github.com/y0ast/DUE,
# we encourage anyone using this code in their research to cite the following papers:
#
# @article{van2021on,
#     title={On Feature Collapse and Deep Kernel Learning for Single Forward Pass Uncertainty},
#     author={van Amersfoort, Joost and Smith, Lewis and Jesson, Andrew and Key, Oscar and Gal, Yarin},
#     journal={arXiv preprint arXiv:2102.11409},
#     year={2021}
# }
#
# @article{liu2020simple,
#     title={Simple and principled uncertainty estimation with deterministic deep learning via distance awareness},
#     author={Liu, Jeremiah and Lin, Zi and Padhy, Shreyas and Tran, Dustin and Bedrax Weiss, Tania and Lakshminarayanan, Balaji},
#     journal={Advances in Neural Information Processing Systems},
#     volume={33},
#     pages={7498--7512},
#     year={2020}
# }


@beartype
def _random_ortho(n: int, m: int) -> torch.Tensor:
    """
    Generate a random orthonormal matrix.

    Parameters
    ----------
    n : int
        The number of rows.
    m : int
        The number of columns.

    Returns
    -------
    torch.Tensor
        The random orthonormal matrix.
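
    Examples
    --------
    A minimal sketch (illustrative; not part of the original source):

    >>> q = _random_ortho(8, 3)
    >>> q.shape
    torch.Size([8, 3])
    >>> torch.allclose(q.T @ q, torch.eye(3), atol=1e-5)
    True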
89 """
90 q, _ = torch.linalg.qr(torch.randn(n, m))
91 return q


@beartype
class _RandomFourierFeatures(torch.nn.Module):
    """
    A private class to generate random Fourier features for the embedding model.
    """

    def __init__(
        self, in_dim: int, num_random_features: int, feature_scale: Optional[float]
    ) -> None:
        """
        Initialize the _RandomFourierFeatures module, which generates random Fourier
        features for the embedding model.

        Parameters
        ----------
        in_dim : int
            The input dimensionality.
        num_random_features : int
            The number of random Fourier features to generate.
        feature_scale : Optional[float]
            The scaling factor for the random Fourier features.
            If None, defaults to sqrt(num_random_features / 2).
        """
        super().__init__()
        if feature_scale is None:
            feature_scale = math.sqrt(num_random_features / 2)

        self.register_buffer("feature_scale", torch.tensor(feature_scale))

        if num_random_features <= in_dim:
            W: torch.Tensor = _random_ortho(in_dim, num_random_features)
        else:
            # generate blocks of orthonormal rows which are not necessarily
            # orthonormal to each other
            dim_left = num_random_features
            ws = []
            while dim_left > in_dim:
                ws.append(_random_ortho(in_dim, in_dim))
                dim_left -= in_dim
            ws.append(_random_ortho(in_dim, dim_left))
            W: torch.Tensor = torch.cat(ws, 1)

        feature_norm = torch.randn(W.shape) ** 2
        W = W * feature_norm.sum(0).sqrt()
        self.register_buffer("W", W)

        b: torch.Tensor = torch.empty(num_random_features).uniform_(0, 2 * math.pi)
        self.register_buffer("b", b)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute the forward pass of the _RandomFourierFeatures module.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor of shape (batch_size, in_dim).

        Returns
        -------
        torch.Tensor
            The output tensor of shape (batch_size, num_random_features).
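
        Examples
        --------
        A shape-level sketch (illustrative; not part of the original source):

        >>> rff = _RandomFourierFeatures(in_dim=16, num_random_features=64, feature_scale=2.0)
        >>> rff(torch.randn(8, 16)).shape
        torch.Size([8, 64])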
156 """
157 k = torch.cos(x @ self.W + self.b)
158 k = k / self.feature_scale
160 return k


class _Laplace(torch.nn.Module):
    """
    A private class to compute a Laplace approximation to a Gaussian Process (GP).
    """

    def __init__(
        self,
        feature_extractor: torch.nn.Module,
        num_deep_features: int,
        num_gp_features: int,
        normalize_gp_features: bool,
        num_random_features: int,
        num_outputs: int,
        feature_scale: Optional[float],
        mean_field_factor: Optional[float],  # required for classification problems
        ridge_penalty: float = 1.0,
    ) -> None:
        """
        Initialize the _Laplace module.

        Parameters
        ----------
        feature_extractor : torch.nn.Module
            The feature extractor module.
        num_deep_features : int
            The number of features output by the feature extractor.
        num_gp_features : int
            The number of features to use in the Gaussian process.
        normalize_gp_features : bool
            Whether to normalize the GP features.
        num_random_features : int
            The number of random Fourier features to use.
        num_outputs : int
            The number of outputs of the model.
        feature_scale : Optional[float]
            The scaling factor for the random Fourier features.
        mean_field_factor : Optional[float]
            The mean-field factor for the Gaussian-Softmax approximation.
            Required for classification problems.
        ridge_penalty : float, optional
            The ridge penalty for the Laplace approximation.
        """
        super().__init__()
        self.feature_extractor = feature_extractor
        self.mean_field_factor = mean_field_factor
        self.ridge_penalty = ridge_penalty
        self.train_batch_size = 0  # to be set later

        if num_gp_features > 0:
            self.num_gp_features = num_gp_features
            random_matrix: torch.Tensor = torch.normal(
                0, 0.05, (num_gp_features, num_deep_features)
            )
            self.register_buffer("random_matrix", random_matrix)
            self.jl: Callable = lambda x: torch.nn.functional.linear(
                x, self.random_matrix
            )
        else:
            self.num_gp_features: int = num_deep_features
            self.jl: Callable = lambda x: x  # identity

        self.normalize_gp_features = normalize_gp_features
        if normalize_gp_features:
            self.normalize: torch.nn.LayerNorm = torch.nn.LayerNorm(num_gp_features)

        self.rff: _RandomFourierFeatures = _RandomFourierFeatures(
            num_gp_features, num_random_features, feature_scale
        )
        self.beta: torch.nn.Linear = torch.nn.Linear(num_random_features, num_outputs)

        self.num_data = 0  # to be set later
        self.register_buffer("seen_data", torch.tensor(0))

        precision = torch.eye(num_random_features) * self.ridge_penalty
        self.register_buffer("precision", precision)

        self.recompute_covariance = True
        self.register_buffer("covariance", torch.eye(num_random_features))
        self.training_parameters_set = False

    def reset_precision_matrix(self) -> None:
        """
        Reset the precision matrix to the identity matrix times the ridge penalty.
        """
        identity = torch.eye(self.precision.shape[0], device=self.precision.device)
        self.precision: torch.Tensor = identity * self.ridge_penalty
        self.seen_data: torch.Tensor = torch.tensor(0)
        self.recompute_covariance = True

    @icontract.require(lambda num_data: num_data > 0)
    @icontract.require(
        lambda num_data, batch_size: (0 < batch_size) & (batch_size <= num_data)
    )
    def set_training_params(self, num_data: int, batch_size: int) -> None:
        """
        Set the training parameters for the Laplace approximation.

        Parameters
        ----------
        num_data : int
            The total number of data points.
        batch_size : int
            The batch size to use during training.
        """
        self.num_data: int = num_data
        self.train_batch_size: int = batch_size
        self.training_parameters_set: bool = True

    @icontract.require(lambda mean_field_factor: mean_field_factor is not None)
    def mean_field_logits(
        self, logits: torch.Tensor, pred_cov: torch.Tensor, mean_field_factor: float
    ) -> torch.Tensor:
        """
        Compute the mean-field logits for the Gaussian-Softmax approximation.

        Parameters
        ----------
        logits : torch.Tensor
            The logits tensor of shape (batch_size, num_outputs).
        pred_cov : torch.Tensor
            The predicted covariance matrix of shape (batch_size, batch_size).
        mean_field_factor : float
            Diagonal scaling factor.

        Returns
        -------
        torch.Tensor
            The mean-field logits tensor of shape (batch_size, num_outputs).
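
        Examples
        --------
        A numeric sketch (illustrative; not part of the original source): each row
        of the logits is divided by sqrt(1 + diag(pred_cov) * mean_field_factor).

        >>> laplace = _Laplace(torch.nn.Identity(), 4, 4, True, 8, 2, None, 25.0)
        >>> logits = torch.tensor([[2.0, 0.0]])
        >>> pred_cov = torch.tensor([[3.0]])
        >>> scaled = laplace.mean_field_logits(logits, pred_cov, 25.0)
        >>> torch.allclose(scaled, logits / math.sqrt(1.0 + 3.0 * 25.0))
        True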
291 """
292 # Mean-Field approximation as alternative to MC integration of Gaussian-Softmax
293 # Based on: https://arxiv.org/abs/2006.07584
295 logits_scale = torch.sqrt(1.0 + torch.diag(pred_cov) * mean_field_factor)
296 if mean_field_factor > 0: 296 ↛ 299line 296 didn't jump to line 299 because the condition on line 296 was always true
297 logits = logits / logits_scale.unsqueeze(-1)
299 return logits

    @icontract.require(lambda self: self.training_parameters_set)
    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Compute the forward pass of the Laplace approximation to the Gaussian Process.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor of shape (batch_size, num_features).

        Returns
        -------
        Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]
            In training mode, returns the predicted logits of shape
            (batch_size, num_outputs). In evaluation mode, returns the mean-field
            logits, or, when no mean-field factor is set, a tuple containing the
            predicted logits of shape (batch_size, num_outputs) and the predicted
            covariance matrix of shape (batch_size, batch_size).
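
        Examples
        --------
        A protocol sketch (illustrative; not part of the original source):
        accumulate the precision matrix in train mode, then predict in eval mode.

        >>> laplace = _Laplace(torch.nn.Identity(), 4, 4, True, 8, 2, None, 25.0)
        >>> laplace.set_training_params(num_data=32, batch_size=32)
        >>> _ = laplace.train()
        >>> _ = laplace(torch.randn(32, 4))  # accumulates the precision matrix
        >>> _ = laplace.eval()
        >>> laplace(torch.randn(5, 4)).shape  # mean-field adjusted logits
        torch.Size([5, 2])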
319 """
320 f = self.feature_extractor(x)
321 f_reduc = self.jl(f)
322 if self.normalize_gp_features: 322 ↛ 325line 322 didn't jump to line 325 because the condition on line 322 was always true
323 f_reduc = self.normalize(f_reduc)
325 k = self.rff(f_reduc)
327 pred = self.beta(k)
329 if self.training:
330 precision_minibatch = k.t() @ k
331 self.precision += precision_minibatch
332 self.seen_data += x.shape[0]
334 assert (
335 self.seen_data <= self.num_data
336 ), "Did not reset precision matrix at start of epoch"
337 else:
338 assert self.seen_data > (
339 self.num_data - self.train_batch_size
340 ), "Not seen sufficient data for precision matrix"
342 if self.recompute_covariance:
343 with torch.no_grad():
344 eps = 1e-7
345 jitter = eps * torch.eye(
346 self.precision.shape[1],
347 device=self.precision.device,
348 )
349 u, info = torch.linalg.cholesky_ex(self.precision + jitter)
350 assert (info == 0).all(), "Precision matrix inversion failed!"
351 torch.cholesky_inverse(u, out=self.covariance)
353 self.recompute_covariance: bool = False
355 with torch.no_grad():
356 pred_cov = k @ ((self.covariance @ k.t()) * self.ridge_penalty)
358 if self.mean_field_factor is None: 358 ↛ 359line 358 didn't jump to line 359 because the condition on line 358 was never true
359 return pred, pred_cov
360 else:
361 pred = self.mean_field_logits(pred, pred_cov, self.mean_field_factor)
363 return pred
366# -------------------------------------------------------------------------------
367# EquineGP, below, demonstrates how to adapt that approach in EQUINE
368@beartype
369class EquineGP(Equine):
370 """
371 An example of an EQUINE model that builds upon the approach in "Spectral Norm
372 Gaussian Processes" (SNGP). This wraps any pytorch embedding neural network and provides
373 the `forward`, `predict`, `save`, and `load` methods required by Equine.
375 Notes
376 -----
377 Although this model build upon the approach in SNGP, it does not enforce the spectral normalization
378 and ResNet architecture required for SNGP. Instead, it is a simple wrapper around
379 any pytorch embedding neural network. Your mileage may vary.
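
    Examples
    --------
    A construction sketch (illustrative; not part of the original source), with a
    placeholder embedding network:

    >>> embed = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU())
    >>> model = EquineGP(embed, emb_out_dim=16, num_classes=3)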
380 """

    def __init__(
        self,
        embedding_model: torch.nn.Module,
        emb_out_dim: int,
        num_classes: int,
        num_random_features: int = 1024,
        init_temperature: float = 1.0,
        device: str = "cpu",
        feature_names: Optional[list[str]] = None,
        label_names: Optional[list[str]] = None,
    ) -> None:
        """
        Initialize the EquineGP model.

        Parameters
        ----------
        embedding_model : torch.nn.Module
            Neural network feature embedding.
        emb_out_dim : int
            The number of deep features from the feature embedding.
        num_classes : int
            The number of output classes this model predicts.
        num_random_features : int, optional
            The dimension of the output of the RandomFourierFeatures operation.
        init_temperature : float, optional
            What to use as the initial temperature (1.0 has no effect).
        device : str, optional
            Either 'cuda' or 'cpu'.
        feature_names : list[str], optional
            List of strings of the names of the tabular features (e.g., ["duration", "fiat_mean", ...]).
        label_names : list[str], optional
            List of strings of the names of the labels (e.g., ["streaming", "voip", ...]).
        """
        super().__init__(
            embedding_model, feature_names=feature_names, label_names=label_names
        )
        self.num_deep_features = emb_out_dim
        self.num_gp_features = emb_out_dim
        self.normalize_gp_features = True
        self.num_random_features = num_random_features
        self.num_outputs = num_classes
        self.mean_field_factor = 25
        self.ridge_penalty = 1
        self.feature_scale: float = 2.0
        self.init_temperature = init_temperature
        self.register_buffer(
            "temperature", torch.Tensor(self.init_temperature * torch.ones(1))
        )
        self.model: _Laplace = _Laplace(
            self.embedding_model,
            self.num_deep_features,
            self.num_gp_features,
            self.normalize_gp_features,
            self.num_random_features,
            self.num_outputs,
            self.feature_scale,
            self.mean_field_factor,
            self.ridge_penalty,
        )
        self.device_type = device
        self.device: torch.device = torch.device(self.device_type)
        self.model.to(self.device)

    def train_model(
        self,
        dataset: Dataset,
        loss_fn: Callable,
        opt: torch.optim.Optimizer,
        num_epochs: int,
        scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
        batch_size: int = 64,
        validation_dataset: Optional[Dataset] = None,
        val_metrics: Optional[Iterable[Metric]] = None,
        vis_support: bool = False,
        support_size: int = 25,
    ) -> dict[str, Any]:
        """
        Train or fine-tune an EquineGP model.

        Parameters
        ----------
        dataset : Dataset
            An iterable, pytorch TensorDataset.
        loss_fn : Callable
            A pytorch loss function, e.g., torch.nn.CrossEntropyLoss().
        opt : torch.optim.Optimizer
            A pytorch optimizer, e.g., torch.optim.Adam().
        num_epochs : int
            The desired number of epochs to use for training.
        scheduler : torch.optim.lr_scheduler.LRScheduler, optional
            A pytorch learning-rate scheduler, if one is desired.
        batch_size : int, optional
            The number of samples to use per batch.
        validation_dataset : Dataset, optional
            If provided, validation metrics are computed on this dataset after
            each epoch of training.
        val_metrics : Iterable[Metric], optional
            The torchmetrics metrics to compute on the validation dataset.
        vis_support : bool, optional
            Whether to update the support examples after training (used for
            visualization and prototypes).
        support_size : int, optional
            The maximum number of support examples to keep per class.

        Returns
        -------
        dict[str, Any]
            A dict containing the training summary under "train_summary" and,
            when a validation dataset is provided, the computed metrics under
            "val_metrics".
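
        Examples
        --------
        A training sketch (illustrative; not part of the original source),
        assuming ``model`` from the class-level example and a toy dataset:

        >>> from torch.utils.data import TensorDataset
        >>> data = TensorDataset(torch.randn(64, 8), torch.randint(0, 3, (64,)))
        >>> opt = torch.optim.Adam(model.parameters(), lr=1e-3)
        >>> results = model.train_model(
        ...     data, torch.nn.CrossEntropyLoss(), opt, num_epochs=2, batch_size=32
        ... )
        >>> sorted(results.keys())
        ['train_summary']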
483 """
485 self.validate_feature_label_names(dataset[0][0].shape[-1], self.num_outputs)
487 train_loader = DataLoader(
488 dataset, batch_size=batch_size, shuffle=True, drop_last=False
489 )
491 val_loader: Optional[DataLoader] = None
492 if validation_dataset is not None:
493 val_loader = DataLoader(
494 validation_dataset,
495 batch_size=batch_size,
496 shuffle=False,
497 drop_last=False,
498 )
500 self.model.set_training_params(len(dataset), batch_size)
501 val_metrics_outputs: Optional[list[list[float]]] = None
503 if validation_dataset is not None and val_metrics is not None:
504 val_metrics_outputs = [[] for i in range(len(list(val_metrics)))]
506 for _ in tqdm(range(num_epochs)):
507 self.model.train()
508 self.model.reset_precision_matrix()
509 epoch_loss = 0.0
510 for i, (xs, labels) in enumerate(train_loader):
511 opt.zero_grad()
512 xs = xs.to(self.device)
513 labels = labels.to(self.device)
514 yhats = self.model(xs)
515 loss = loss_fn(yhats, labels.to(torch.long))
516 loss.backward()
517 opt.step()
518 epoch_loss += loss.item()
519 if scheduler is not None:
520 scheduler.step()
521 self.model.eval()
522 # compute the validation metrics
523 if (
524 validation_dataset is not None
525 and val_loader is not None
526 and val_metrics is not None
527 and val_metrics_outputs is not None
528 ):
529 for _, (xs_val, labels_val) in enumerate(val_loader):
530 xs_val = xs_val.to(self.device)
531 labels_val = labels_val.to(self.device)
532 yhats_val = self.model(xs_val)
533 for metric in val_metrics:
534 metric.update(yhats_val, labels_val)
535 for i, metric in enumerate(val_metrics):
536 val_metrics_outputs[i].append(metric.compute())
537 if vis_support:
538 self.update_support(dataset.tensors[0], dataset.tensors[1], support_size)
540 _, train_y = dataset[:]
541 date_trained = datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
542 self.train_summary: dict[str, Any] = generate_train_summary(
543 self, train_y, date_trained
544 )
546 return_dict: dict[str, Any] = dict()
547 return_dict["train_summary"] = self.train_summary
548 if validation_dataset is not None:
549 return_dict["val_metrics"] = val_metrics_outputs
551 return return_dict

    def update_support(
        self, support_x: torch.Tensor, support_y: torch.Tensor, support_size: int
    ) -> None:
        """
        Update the support examples with the given examples.

        Parameters
        ----------
        support_x : torch.Tensor
            Tensor containing the support examples.
        support_y : torch.Tensor
            Tensor containing labels for the given support examples.
        support_size : int
            The maximum number of support examples to keep per class.

        Returns
        -------
        None
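
        Examples
        --------
        A usage sketch (illustrative; not part of the original source), assuming
        ``model`` from the class-level example and labeled training tensors:

        >>> train_x, train_y = torch.randn(60, 8), torch.randint(0, 3, (60,))
        >>> model.update_support(train_x, train_y, support_size=10)
        >>> model.get_prototypes().shape  # (num_classes, num_random_features)
        torch.Size([3, 1024])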
568 """
570 labels, counts = torch.unique(support_y, return_counts=True)
571 support = OrderedDict()
572 for label, count in list(zip(labels.tolist(), counts.tolist())):
573 class_support = generate_support(
574 support_x,
575 support_y,
576 support_size=min(count, support_size),
577 selected_labels=[label],
578 )
579 support.update(class_support)
581 self.support = support
583 support_embeddings = OrderedDict().fromkeys(self.support.keys(), torch.Tensor())
584 for label in support:
585 support_embeddings[label] = self.compute_embeddings(support[label])
587 self.support_embeddings = support_embeddings
588 self.prototypes: torch.Tensor = self.compute_prototypes()

    def compute_embeddings(self, x: torch.Tensor) -> torch.Tensor:
        """
        Method for computing deep embeddings for a given input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor for generating embeddings.

        Returns
        -------
        torch.Tensor
            Output embeddings.
        """
        f = self.model.feature_extractor(x)
        f_reduc = self.model.jl(f)
        if self.model.normalize_gp_features:
            f_reduc = self.model.normalize(f_reduc)

        return self.model.rff(f_reduc)

    @icontract.require(lambda self: len(self.support) > 0)
    def compute_prototypes(self) -> torch.Tensor:
        """
        Method for computing class prototypes based on given support examples.
        "Prototypes" in this context are the means of the support embeddings for
        each class.

        Returns
        -------
        torch.Tensor
            Tensors of prototypes for each of the given classes in the support.
        """
        # Compute support embeddings
        support_embeddings = OrderedDict().fromkeys(self.support.keys())
        for label in self.support:
            support_embeddings[label] = self.compute_embeddings(self.support[label])

        # Compute prototype for each class
        proto_list = []
        for label in self.support:  # look at doing functorch
            class_prototype = torch.mean(support_embeddings[label], dim=0)  # type: ignore
            proto_list.append(class_prototype)

        prototypes = torch.stack(proto_list)

        return prototypes

    @icontract.require(lambda self: len(self.support) > 0)
    def get_support(self) -> OrderedDict[int, torch.Tensor]:
        """
        Method for returning support examples used in training.

        Returns
        -------
        OrderedDict[int, torch.Tensor]
            Dictionary containing support examples for each class.
        """
        return self.support

    @icontract.require(lambda self: len(self.prototypes) > 0)
    def get_prototypes(self) -> torch.Tensor:
        """
        Method for returning class prototypes.

        Returns
        -------
        torch.Tensor
            Tensors of prototypes for each of the given classes in the support.
        """
        return self.prototypes

    @icontract.require(lambda num_calibration_epochs: 0 < num_calibration_epochs)
    @icontract.require(lambda calibration_lr: calibration_lr > 0.0)
    def calibrate_model(
        self,
        dataset: torch.utils.data.Dataset,
        num_calibration_epochs: int = 1,
        calibration_lr: float = 0.01,
        calibration_batch_size: int = 256,
    ) -> None:
        """
        Fine-tune the temperature after training.

        Parameters
        ----------
        dataset : TensorDataset
            An iterable, pytorch TensorDataset.
        num_calibration_epochs : int, optional
            Number of epochs to tune temperature.
        calibration_lr : float, optional
            Learning rate for temperature optimization.
        calibration_batch_size : int, optional
            The number of samples to use per calibration batch.
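
        Examples
        --------
        A calibration sketch (illustrative; not part of the original source),
        assuming a trained ``model`` and a held-out TensorDataset ``calib_data``:

        >>> model.calibrate_model(calib_data, num_calibration_epochs=2)
        >>> float(model.temperature) > 0.0
        True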
681 """
683 calibration_loader = DataLoader(
684 dataset,
685 batch_size=calibration_batch_size,
686 shuffle=True,
687 drop_last=False,
688 )
690 self.temperature.requires_grad = True
691 loss_fn = torch.nn.functional.cross_entropy
692 optimizer = torch.optim.Adam([self.temperature], lr=calibration_lr)
693 for _ in range(num_calibration_epochs):
694 for xs, labels in calibration_loader:
695 optimizer.zero_grad()
696 xs = xs.to(self.device)
697 labels = labels.to(self.device)
698 with torch.no_grad():
699 logits = self.model(xs)
700 logits = logits / self.temperature
701 loss = loss_fn(logits, labels.to(torch.long))
702 loss.backward()
703 optimizer.step()
704 self.temperature.requires_grad = False

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        EquineGP forward function, generates logits for classification.

        Parameters
        ----------
        X : torch.Tensor
            Input tensor for generating predictions.

        Returns
        -------
        torch.Tensor
            Temperature-scaled logits.
        """
        X = X.to(self.device)
        preds = self.model(X)
        return preds / self.temperature.to(self.device)

    @icontract.ensure(
        lambda result: all((0 <= result.ood_scores) & (result.ood_scores <= 1.0))
    )
    def predict(self, X: torch.Tensor) -> EquineOutput:
        """
        Predict function for EquineGP, inherited and implemented from Equine.

        Parameters
        ----------
        X : torch.Tensor
            Input tensor.

        Returns
        -------
        EquineOutput
            Output object containing prediction probabilities and OOD scores.
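
        Examples
        --------
        A usage sketch (illustrative; not part of the original source), assuming
        a trained ``model`` from the examples above. The OOD score is the softmax
        entropy normalized by its maximum possible value, so it always lies in
        [0, 1]:

        >>> out = model.predict(torch.randn(5, 8))
        >>> out.classes.shape
        torch.Size([5, 3])
        >>> bool(((out.ood_scores >= 0) & (out.ood_scores <= 1)).all())
        True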
740 """
741 logits = self(X)
742 preds = torch.softmax(logits, dim=1)
743 equiprobable = torch.ones(self.num_outputs) / self.num_outputs
744 max_entropy = torch.sum(torch.special.entr(equiprobable))
745 ood_score = torch.sum(torch.special.entr(preds), dim=1) / max_entropy
746 embeddings = self.compute_embeddings(X)
747 eq_out = EquineOutput(
748 classes=preds, ood_scores=ood_score, embeddings=embeddings
749 ) # TODO return embeddings
751 self.validate_feature_label_names(X.shape[-1], self.num_outputs)
753 return eq_out

    def save(self, path: str) -> None:
        """
        Function to save all model parameters to a file.

        Parameters
        ----------
        path : str
            Filename to write the model.
        """
        model_settings = {
            "emb_out_dim": self.num_deep_features,
            "num_classes": self.num_outputs,
            "init_temperature": self.temperature.item(),
            "device": self.device_type,
        }

        jit_model = torch.jit.script(self.model.feature_extractor)
        buffer = io.BytesIO()
        torch.jit.save(jit_model, buffer)
        buffer.seek(0)

        laplace_sd = self.model.state_dict()
        keys_to_delete = []
        for key in laplace_sd:
            if "feature_extractor" in key:
                keys_to_delete.append(key)
        for key in keys_to_delete:
            del laplace_sd[key]

        save_data = {
            "embed_jit_save": buffer,
            "feature_names": self.feature_names,
            "label_names": self.label_names,
            "laplace_model_save": laplace_sd,
            "num_data": self.model.num_data,
            "settings": model_settings,
            "support": self.support,
            "train_batch_size": self.model.train_batch_size,
            "train_summary": self.train_summary,
        }

        torch.save(save_data, path)  # TODO allow model checkpointing

    @classmethod
    def load(cls, path: str) -> Equine:
        """
        Function to load a previously saved EquineGP model.

        Parameters
        ----------
        path : str
            Input filename.

        Returns
        -------
        EquineGP
            The reconstituted EquineGP object.
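
        Examples
        --------
        A save/load round-trip sketch (illustrative; not part of the original
        source), assuming a trained ``model``; the filename is a placeholder:

        >>> model.save("equine_gp.pt")
        >>> restored = EquineGP.load("equine_gp.pt")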
812 """
813 model_save = torch.load(path, weights_only=False)
814 jit_model = torch.jit.load(model_save.get("embed_jit_save"))
815 eq_model = cls(jit_model, **model_save.get("settings"))
817 eq_model.feature_names = model_save.get("feature_names")
818 eq_model.label_names = model_save.get("label_names")
819 eq_model.train_summary = model_save.get("train_summary")
821 eq_model.model.load_state_dict(
822 model_save.get("laplace_model_save"), strict=False
823 )
824 eq_model.model.seen_data = model_save.get("laplace_model_save").get("seen_data")
826 eq_model.model.set_training_params(
827 model_save.get("num_data"), model_save.get("train_batch_size")
828 )
829 eq_model.eval()
831 support = model_save.get("support")
832 if len(support) > 0:
833 eq_model.support = support
834 eq_model.prototypes = eq_model.compute_prototypes()
836 return eq_model