Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/equine/equine_gp.py: 96%
271 statements
1# Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
2# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
3# SPDX-License-Identifier: MIT
5from typing import Any, Optional, Union
7import icontract
8import io
9import math
10import torch
11from beartype import beartype
12from collections import OrderedDict
13from collections.abc import Callable
14from datetime import datetime
15from torch.utils.data import DataLoader, TensorDataset
16from tqdm import tqdm
18from .equine import Equine, EquineOutput
19from .utils import generate_support, generate_train_summary, stratified_train_test_split
21BatchType = tuple[torch.Tensor, ...]
22# -------------------------------------------------------------------------------
23# Note that the below code for
24# * `_random_ortho`,
25# * `_RandomFourierFeatures`, and
26# * `_Laplace`
27# is copied and modified from https://github.com/y0ast/DUE/blob/main/due/sngp.py
28# under its original MIT license, redisplayed here:
29# -------------------------------------------------------------------------------
30# MIT License
31#
32# Copyright (c) 2021 Joost van Amersfoort
33#
34# Permission is hereby granted, free of charge, to any person obtaining a copy
35# of this software and associated documentation files (the "Software"), to deal
36# in the Software without restriction, including without limitation the rights
37# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
38# copies of the Software, and to permit persons to whom the Software is
39# furnished to do so, subject to the following conditions:
40#
41# The above copyright notice and this permission notice shall be included in all
42# copies or substantial portions of the Software.
43#
44# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
45# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
46# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
47# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
48# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
49# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
50# SOFTWARE.
51# ------------------------------------------------------------------------------
52# Following the recommendation of their README at https://github.com/y0ast/DUE
53# we encourage anyone using this code in their research to cite the following papers:
54#
55# @article{van2021on,
56# title={On Feature Collapse and Deep Kernel Learning for Single Forward Pass Uncertainty},
57# author={van Amersfoort, Joost and Smith, Lewis and Jesson, Andrew and Key, Oscar and Gal, Yarin},
58# journal={arXiv preprint arXiv:2102.11409},
59# year={2021}
60# }
61#
62# @article{liu2020simple,
63# title={Simple and principled uncertainty estimation with deterministic deep learning via distance awareness},
64# author={Liu, Jeremiah and Lin, Zi and Padhy, Shreyas and Tran, Dustin and Bedrax Weiss, Tania and Lakshminarayanan, Balaji},
65# journal={Advances in Neural Information Processing Systems},
66# volume={33},
67# pages={7498--7512},
68# year={2020}
69# }
72@beartype
73def _random_ortho(n: int, m: int) -> torch.Tensor:
74 """
75 Generate a random orthonormal matrix.
77 Parameters
78 ----------
79 n : int
80 The number of rows.
81 m : int
82 The number of columns.
84 Returns
85 -------
86 torch.Tensor
87 The random orthonormal matrix.
88 """
89 q, _ = torch.linalg.qr(torch.randn(n, m))
90 return q
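# Illustrative sketch (not part of the original module): for m <= n, the matrix
# returned by `_random_ortho` has orthonormal columns, so Q.T @ Q is close to
# the identity. The sizes below are arbitrary example values.
def _demo_random_ortho() -> None:  # pragma: no cover
    q = _random_ortho(8, 3)  # shape (8, 3)
    assert torch.allclose(q.t() @ q, torch.eye(3), atol=1e-5)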
93@beartype
94class _RandomFourierFeatures(torch.nn.Module):
95 """
96 A private class to generate random Fourier features for the embedding model.
97 """
99 def __init__(
100 self, in_dim: int, num_random_features: int, feature_scale: Optional[float]
101 ) -> None:
102 """
103 Initialize the _RandomFourierFeatures module, which generates random Fourier features
104 for the embedding model.
106 Parameters
107 ----------
108 in_dim : int
109 The input dimensionality.
110 num_random_features : int
111 The number of random Fourier features to generate.
112 feature_scale : Optional[float]
113 The scaling factor for the random Fourier features. If None, defaults to sqrt(num_random_features / 2).
114 """
115 super().__init__()
116 if feature_scale is None:
117 feature_scale = math.sqrt(num_random_features / 2)
119 self.register_buffer("feature_scale", torch.tensor(feature_scale))
121 if num_random_features <= in_dim:
122 W = _random_ortho(in_dim, num_random_features)
123 else:
124 # generate blocks of orthonormal rows which are not necessarily orthonormal
125 # to each other.
126 dim_left = num_random_features
127 ws = []
128 while dim_left > in_dim:
129 ws.append(_random_ortho(in_dim, in_dim))
130 dim_left -= in_dim
131 ws.append(_random_ortho(in_dim, dim_left))
132 W = torch.cat(ws, 1)
134 feature_norm = torch.randn(W.shape) ** 2
135 W = W * feature_norm.sum(0).sqrt()
136 self.register_buffer("W", W)
138 b = torch.empty(num_random_features).uniform_(0, 2 * math.pi)
139 self.register_buffer("b", b)
141 def forward(self, x: torch.Tensor) -> torch.Tensor:
142 """
143 Compute the forward pass of the _RandomFourierFeatures module.
145 Parameters
146 ----------
147 x : torch.Tensor
148 The input tensor of shape (batch_size, in_dim).
150 Returns
151 -------
152 torch.Tensor
153 The output tensor of shape (batch_size, num_random_features).
154 """
155 k = torch.cos(x @ self.W + self.b)
156 k = k / self.feature_scale
158 return k
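# Illustrative sketch (not part of the original module): the module above maps
# x to cos(x @ W + b) / feature_scale, and inner products of these features
# approximate a shift-invariant (RBF-like) kernel. The dimensions below are
# arbitrary example values.
def _demo_random_fourier_features() -> None:  # pragma: no cover
    rff = _RandomFourierFeatures(in_dim=4, num_random_features=256, feature_scale=None)
    x = torch.randn(10, 4)
    k = rff(x)  # shape (10, 256)
    gram = k @ k.t()  # approximate kernel Gram matrix, shape (10, 10)
    assert k.shape == (10, 256) and gram.shape == (10, 10)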
161class _Laplace(torch.nn.Module):
162 """
163 A private class to compute a Laplace approximation to a Gaussian Process (GP)
164 """
166 def __init__(
167 self,
168 feature_extractor: torch.nn.Module,
169 num_deep_features: int,
170 num_gp_features: int,
171 normalize_gp_features: bool,
172 num_random_features: int,
173 num_outputs: int,
174 feature_scale: Optional[float],
175 mean_field_factor: Optional[float], # required for classification problems
176 ridge_penalty: float = 1.0,
177 ) -> None:
178 """
179 Initialize the _Laplace module.
181 Parameters
182 ----------
183 feature_extractor : torch.nn.Module
184 The feature extractor module.
185 num_deep_features : int
186 The number of features output by the feature extractor.
187 num_gp_features : int
188 The number of features to use in the Gaussian process.
189 normalize_gp_features : bool
190 Whether to normalize the GP features.
191 num_random_features : int
192 The number of random Fourier features to use.
193 num_outputs : int
194 The number of outputs of the model.
195 feature_scale : Optional[float]
196 The scaling factor for the random Fourier features.
197 mean_field_factor : Optional[float]
198 The mean-field factor for the Gaussian-Softmax approximation.
199 Required for classification problems.
200 ridge_penalty : float, optional
201 The ridge penalty for the Laplace approximation.
202 """
203 super().__init__()
204 self.feature_extractor = feature_extractor
205 self.mean_field_factor = mean_field_factor
206 self.ridge_penalty = ridge_penalty
207 self.train_batch_size = 0 # to be set later
209 if num_gp_features > 0:
210 self.num_gp_features = num_gp_features
211 self.register_buffer(
212 "random_matrix",
213 torch.normal(0, 0.05, (num_gp_features, num_deep_features)),
214 )
215 self.jl: Callable = lambda x: torch.nn.functional.linear(
216 x, self.random_matrix
217 )
218 else:
219 self.num_gp_features: int = num_deep_features
220 self.jl: Callable = lambda x: x # Identity
222 self.normalize_gp_features = normalize_gp_features
223 if normalize_gp_features:
224 self.normalize: torch.nn.LayerNorm = torch.nn.LayerNorm(num_gp_features)
226 self.rff: _RandomFourierFeatures = _RandomFourierFeatures(
227 num_gp_features, num_random_features, feature_scale
228 )
229 self.beta: torch.nn.Linear = torch.nn.Linear(num_random_features, num_outputs)
231 self.num_data = 0 # to be set later
232 self.register_buffer("seen_data", torch.tensor(0))
234 precision = torch.eye(num_random_features) * self.ridge_penalty
235 self.register_buffer("precision", precision)
237 self.recompute_covariance = True
238 self.register_buffer("covariance", torch.eye(num_random_features))
239 self.training_parameters_set = False
241 def reset_precision_matrix(self) -> None:
242 """
243 Reset the precision matrix to the identity matrix times the ridge penalty.
244 """
245 identity = torch.eye(self.precision.shape[0], device=self.precision.device)
246 self.precision: torch.Tensor = identity * self.ridge_penalty
247 self.seen_data: torch.Tensor = torch.tensor(0)
248 self.recompute_covariance = True
250 @icontract.require(lambda num_data: num_data > 0)
251 @icontract.require(
252 lambda num_data, batch_size: (0 < batch_size) & (batch_size <= num_data)
253 )
254 def set_training_params(self, num_data: int, batch_size: int) -> None:
255 """
256 Set the training parameters for the Laplace approximation.
258 Parameters
259 ----------
260 num_data : int
261 The total number of data points.
262 batch_size : int
263 The batch size to use during training.
264 """
265 self.num_data: int = num_data
266 self.train_batch_size: int = batch_size
267 self.training_parameters_set: bool = True
269 @icontract.require(lambda self: self.mean_field_factor is not None)
270 def mean_field_logits(
271 self, logits: torch.Tensor, pred_cov: torch.Tensor
272 ) -> torch.Tensor:
273 """
274 Compute the mean-field logits for the Gaussian-Softmax approximation.
276 Parameters
277 ----------
278 logits : torch.Tensor
279 The logits tensor of shape (batch_size, num_outputs).
280 pred_cov : torch.Tensor
281 The predicted covariance matrix of shape (batch_size, batch_size).
283 Returns
284 -------
285 torch.Tensor
286 The mean-field logits tensor of shape (batch_size, num_outputs).
287 """
288 # Mean-Field approximation as alternative to MC integration of Gaussian-Softmax
289 # Based on: https://arxiv.org/abs/2006.07584
291 logits_scale = torch.sqrt(1.0 + torch.diag(pred_cov) * self.mean_field_factor)
292 if self.mean_field_factor > 0: # type: ignore
293 logits = logits / logits_scale.unsqueeze(-1)
295 return logits
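    # Worked form of the correction above (a sketch, following the mean-field
    # approximation in https://arxiv.org/abs/2006.07584): with
    # lambda = mean_field_factor and sigma_i^2 = diag(pred_cov)[i], each row of
    # logits is rescaled as
    #
    #     logits_i <- logits_i / sqrt(1 + lambda * sigma_i^2)
    #
    # so rows with larger predictive variance are pulled toward a uniform
    # softmax distribution.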
297 @icontract.require(lambda self: self.training_parameters_set)
298 def forward(
299 self, x: torch.Tensor
300 ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
301 """
302 Compute the forward pass of the Laplace approximation to the Gaussian Process.
304 Parameters
305 ----------
306 x : torch.Tensor
307 The input tensor of shape (batch_size, num_features).
309 Returns
310 -------
311 Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]
312 If the model is in training mode, returns the predicted logits of shape (batch_size, num_outputs).
313 If the model is in evaluation mode, returns the mean-field logits of shape (batch_size, num_outputs) or,
314 when `mean_field_factor` is None, a tuple of the logits and the predicted covariance matrix of shape (batch_size, batch_size).
315 """
316 f = self.feature_extractor(x)
317 f_reduc = self.jl(f)
318 if self.normalize_gp_features:
319 f_reduc = self.normalize(f_reduc)
321 k = self.rff(f_reduc)
323 pred = self.beta(k)
325 if self.training:
326 precision_minibatch = k.t() @ k
327 self.precision += precision_minibatch
328 self.seen_data += x.shape[0]
330 assert (
331 self.seen_data <= self.num_data
332 ), "Did not reset precision matrix at start of epoch"
333 else:
334 assert self.seen_data > (
335 self.num_data - self.train_batch_size
336 ), "Not seen sufficient data for precision matrix"
338 if self.recompute_covariance:
339 with torch.no_grad():
340 eps = 1e-7
341 jitter = eps * torch.eye(
342 self.precision.shape[1],
343 device=self.precision.device,
344 )
345 u, info = torch.linalg.cholesky_ex(self.precision + jitter)
346 assert (info == 0).all(), "Precision matrix inversion failed!"
347 torch.cholesky_inverse(u, out=self.covariance)
349 self.recompute_covariance: bool = False
351 with torch.no_grad():
352 pred_cov = k @ ((self.covariance @ k.t()) * self.ridge_penalty)
354 if self.mean_field_factor is None:
355 return pred, pred_cov
356 else:
357 pred = self.mean_field_logits(pred, pred_cov)
359 return pred
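# ------------------------------------------------------------------------------
# Sketch of the Laplace bookkeeping above (not part of the original module):
# during training, each batch of random Fourier features K (shape
# batch_size x num_random_features) adds K^T K to the precision matrix, so
# after one full epoch
#
#     precision = ridge_penalty * I + sum over batches of K^T K
#
# At evaluation time the covariance is recovered once as
# covariance = precision^{-1} (via a Cholesky factorization), and the
# predictive covariance for a batch is
#
#     pred_cov = K @ covariance @ K^T * ridge_penalty
# ------------------------------------------------------------------------------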
362# -------------------------------------------------------------------------------
363# EquineGP, below, demonstrates how to adapt that approach in EQUINE
364@beartype
365class EquineGP(Equine):
366 """
367 An example of an EQUINE model that builds upon the approach in "Spectral Norm
368 Gaussian Processes" (SNGP). This wraps any pytorch embedding neural network and provides
369 the `forward`, `predict`, `save`, and `load` methods required by Equine.
371 Notes
372 -----
373 Although this model builds upon the approach in SNGP, it does not enforce the spectral normalization
374 and ResNet architecture required for SNGP. Instead, it is a simple wrapper around
375 any pytorch embedding neural network. Your mileage may vary.
376 """
378 def __init__(
379 self,
380 embedding_model: torch.nn.Module,
381 emb_out_dim: int,
382 num_classes: int,
383 use_temperature: bool = False,
384 init_temperature: float = 1.0,
385 device: str = "cpu",
386 feature_names: Optional[list[str]] = None,
387 label_names: Optional[list[str]] = None,
388 ) -> None:
389 """
390 Initialize the EquineGP model.
392 Parameters
393 ----------
394 embedding_model : torch.nn.Module
395 Neural Network feature embedding.
396 emb_out_dim : int
397 The number of deep features from the feature embedding.
398 num_classes : int
399 The number of output classes this model predicts.
400 use_temperature : bool, optional
401 Whether to use temperature scaling after training.
402 init_temperature : float, optional
403 What to use as the initial temperature (1.0 has no effect).
404 device : str, optional
405 Either 'cuda' or 'cpu'.
406 feature_names : list[str], optional
407 List of strings of the names of the tabular features (e.g., ["duration", "fiat_mean", ...])
408 label_names : list[str], optional
409 List of strings of the names of the labels (e.g., ["streaming", "voip", ...])
410 """
411 super().__init__(
412 embedding_model, feature_names=feature_names, label_names=label_names
413 )
414 self.num_deep_features = emb_out_dim
415 self.num_gp_features = emb_out_dim
416 self.normalize_gp_features = True
417 self.num_random_features = 1024
418 self.num_outputs = num_classes
419 self.mean_field_factor = 25
420 self.ridge_penalty = 1
421 self.feature_scale: float = 2.0
422 self.use_temperature = use_temperature
423 self.init_temperature = init_temperature
424 self.register_buffer(
425 "temperature", torch.Tensor(self.init_temperature * torch.ones(1))
426 )
427 self.model: _Laplace = _Laplace(
428 self.embedding_model,
429 self.num_deep_features,
430 self.num_gp_features,
431 self.normalize_gp_features,
432 self.num_random_features,
433 self.num_outputs,
434 self.feature_scale,
435 self.mean_field_factor,
436 self.ridge_penalty,
437 )
438 self.device_type = device
439 self.device: torch.device = torch.device(self.device_type)
440 self.model.to(self.device)
442 def train_model(
443 self,
444 dataset: TensorDataset,
445 loss_fn: Callable,
446 opt: torch.optim.Optimizer,
447 num_epochs: int,
448 batch_size: int = 64,
449 calib_frac: float = 0.1,
450 num_calibration_epochs: int = 2,
451 calibration_lr: float = 0.01,
452 vis_support: bool = False,
453 support_size: int = 25,
454 ) -> dict[str, Any]:
455 """
456 Train or fine-tune an EquineGP model.
458 Parameters
459 ----------
460 dataset : TensorDataset
461 An iterable, pytorch TensorDataset.
462 loss_fn : Callable
463 A pytorch loss function, e.g., torch.nn.CrossEntropyLoss().
464 opt : torch.optim.Optimizer
465 A pytorch optimizer, e.g., torch.optim.Adam().
466 num_epochs : int
467 The desired number of epochs to use for training.
468 batch_size : int, optional
469 The number of samples to use per batch.
470 calib_frac : float, optional
471 Fraction of training data to use in temperature scaling.
472 num_calibration_epochs : int, optional
473 The desired number of epochs to use for temperature scaling.
474 calibration_lr : float, optional
475 Learning rate for temperature scaling.
477 Returns
478 -------
479 dict[str, Any]
480 A dict containing a dict of summary stats and a dataloader for the calibration data.
482 Notes
483 -------
484 - If `use_temperature` is True, temperature scaling will be used after training.
485 - The calibration data is used to calibrate the temperature scaling.
486 """
487 X, Y = dataset[:]
488 calib_x = torch.Tensor()
489 calib_y = torch.Tensor()
490 if self.use_temperature:
491 train_x, calib_x, train_y, calib_y = stratified_train_test_split(
492 X, Y, test_size=calib_frac
493 )
494 dataset = TensorDataset(torch.Tensor(train_x), torch.Tensor(train_y))
495 self.temperature: torch.Tensor = torch.Tensor(
496 self.init_temperature * torch.ones(1)
497 ).type_as(self.temperature)
499 self.validate_feature_label_names(X.shape[-1], self.num_outputs)
501 train_loader = DataLoader(
502 dataset, batch_size=batch_size, shuffle=True, drop_last=True
503 )
504 self.model.set_training_params(len(dataset), batch_size)
505 self.model.train()
506 for _ in tqdm(range(num_epochs)):
507 self.model.reset_precision_matrix()
508 epoch_loss = 0.0
509 for i, (xs, labels) in enumerate(train_loader):
510 opt.zero_grad()
511 xs = xs.to(self.device)
512 labels = labels.to(self.device)
513 yhats = self.model(xs)
514 loss = loss_fn(yhats, labels.to(torch.long))
515 loss.backward()
516 opt.step()
517 epoch_loss += loss.item()
518 self.model.eval()
520 if vis_support:
521 self.update_support(dataset.tensors[0], dataset.tensors[1], support_size)
523 calibration_loader = None
524 if self.use_temperature:
525 dataset_calibration = TensorDataset(
526 torch.Tensor(calib_x), torch.Tensor(calib_y)
527 )
528 calibration_loader = DataLoader(
529 dataset_calibration,
530 batch_size=batch_size,
531 shuffle=True,
532 drop_last=False,
533 )
534 self.calibrate_temperature(
535 calibration_loader, num_calibration_epochs, calibration_lr
536 )
538 _, train_y = dataset[:]
539 date_trained = datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
540 self.train_summary: dict[str, Any] = generate_train_summary(
541 self, train_y, date_trained
542 )
544 return_dict: dict[str, Any] = dict()
545 return_dict["train_summary"] = self.train_summary
546 return_dict["calibration_loader"] = calibration_loader
548 return return_dict
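    # Hedged usage sketch (not part of the original module): with
    # `use_temperature=True`, a `calib_frac` slice of the dataset is held out
    # and returned as "calibration_loader", which can be passed back to
    # `calibrate_temperature` for further tuning. `emb_net` and the tensors
    # below are hypothetical placeholders:
    #
    #     model = EquineGP(emb_net, emb_out_dim=32, num_classes=4, use_temperature=True)
    #     out = model.train_model(
    #         TensorDataset(train_x, train_y),
    #         torch.nn.CrossEntropyLoss(),
    #         torch.optim.Adam(model.parameters(), lr=1e-3),
    #         num_epochs=10,
    #     )
    #     model.calibrate_temperature(out["calibration_loader"], num_calibration_epochs=2)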
550 def update_support(
551 self, support_x: torch.Tensor, support_y: torch.Tensor, support_size: int
552 ) -> None:
553 """Function to update protonet support examples with given examples.
555 Parameters
556 ----------
557 support_x : torch.Tensor
558 Tensor containing support examples for protonet.
559 support_y : torch.Tensor
560 Tensor containing labels for given support examples.
562 Returns
563 -------
564 None
565 """
567 labels, counts = torch.unique(support_y, return_counts=True)
568 support = OrderedDict()
569 for label, count in list(zip(labels.tolist(), counts.tolist())):
570 class_support = generate_support(
571 support_x,
572 support_y,
573 support_size=min(count, support_size),
574 selected_labels=[label],
575 )
576 support.update(class_support)
578 self.support = support
580 support_embeddings = OrderedDict().fromkeys(self.support.keys(), torch.Tensor())
581 for label in support:
582 support_embeddings[label] = self.compute_embeddings(support[label])
584 self.support_embeddings = support_embeddings
585 self.prototypes: torch.Tensor = self.compute_prototypes()
587 def compute_embeddings(self, x: torch.Tensor) -> torch.Tensor:
588 """
589 Method for computing deep embeddings for given input tensor.
591 Parameters
592 ----------
593 x : torch.Tensor
594 Input tensor for generating embeddings.
596 Returns
597 -------
598 torch.Tensor
599 Output embeddings.
600 """
601 f = self.model.feature_extractor(x)
602 f_reduc = self.model.jl(f)
603 if self.model.normalize_gp_features:
604 f_reduc = self.model.normalize(f_reduc)
606 return self.model.rff(f_reduc)
608 @icontract.require(lambda self: len(self.support) > 0)
609 def compute_prototypes(self) -> torch.Tensor:
610 """
611 Method for computing class prototypes based on given support examples.
612 "Prototypes" in this context are the means of the support embeddings for each class.
614 Returns
615 -------
616 torch.Tensor
617 Tensors of prototypes for each of the given classes in the support.
618 """
619 # Compute support embeddings
620 support_embeddings = OrderedDict().fromkeys(self.support.keys())
621 for label in self.support:
622 support_embeddings[label] = self.compute_embeddings(self.support[label])
624 # Compute prototype for each class
625 proto_list = []
626 for label in self.support: # look at doing functorch
627 class_prototype = torch.mean(support_embeddings[label], dim=0) # type: ignore
628 proto_list.append(class_prototype)
630 prototypes = torch.stack(proto_list)
632 return prototypes
634 @icontract.require(lambda self: len(self.support) > 0)
635 def get_support(self) -> OrderedDict[int, torch.Tensor]:
636 """
637 Method for returning support examples used in training.
639 Returns
640 -------
641 OrderedDict[int, torch.Tensor]
642 Dictionary containing support examples for each class.
643 """
644 return self.support
646 @icontract.require(lambda self: len(self.prototypes) > 0)
647 def get_prototypes(self) -> torch.Tensor:
648 """
649 Method for returning class prototypes.
651 Returns
652 -------
653 torch.Tensor
654 Tensors of prototypes for each of the given classes in the support.
655 """
656 return self.prototypes
658 @icontract.require(lambda num_calibration_epochs: 0 < num_calibration_epochs)
659 @icontract.require(lambda calibration_lr: calibration_lr > 0.0)
660 def calibrate_temperature(
661 self,
662 calibration_loader: DataLoader[BatchType],
663 num_calibration_epochs: int = 1,
664 calibration_lr: float = 0.01,
665 ) -> None:
666 """
667 Fine-tune the temperature after training. Note this function is also run at the conclusion of train_model.
669 Parameters
670 ----------
671 calibration_loader : DataLoader
672 Data loader returned by train_model.
673 num_calibration_epochs : int, optional
674 Number of epochs to tune temperature.
675 calibration_lr : float, optional
676 Learning rate for temperature optimization.
677 """
678 self.temperature.requires_grad = True
679 loss_fn = torch.nn.functional.cross_entropy
680 optimizer = torch.optim.Adam([self.temperature], lr=calibration_lr)
681 for _ in range(num_calibration_epochs):
682 for xs, labels in calibration_loader:
683 optimizer.zero_grad()
684 xs = xs.to(self.device)
685 labels = labels.to(self.device)
686 with torch.no_grad():
687 logits = self.model(xs)
688 logits = logits / self.temperature
689 loss = loss_fn(logits, labels.to(torch.long))
690 loss.backward()
691 optimizer.step()
692 self.temperature.requires_grad = False
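    # Sketch of the effect (not part of the original module): the learned
    # temperature T rescales logits before the softmax, i.e. softmax(z / T),
    # so T > 1 softens over-confident predictions and T < 1 sharpens them. The
    # same division by `self.temperature` is applied at inference time in
    # `forward` below.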
694 def forward(self, X: torch.Tensor) -> torch.Tensor:
695 """
696 EquineGP forward function, generates logits for classification.
698 Parameters
699 ----------
700 X : torch.Tensor
701 Input tensor for generating predictions.
703 Returns
704 -------
705 torch.Tensor
706 Output logits, scaled by the model temperature.
707 """
708 X = X.to(self.device)
709 preds = self.model(X)
710 return preds / self.temperature.to(self.device)
712 @icontract.ensure(
713 lambda result: all((0 <= result.ood_scores) & (result.ood_scores <= 1.0))
714 )
715 def predict(self, X: torch.Tensor) -> EquineOutput:
716 """
717 Predict function for EquineGP, inherited and implemented from Equine.
719 Parameters
720 ----------
721 X : torch.Tensor
722 Input tensor.
724 Returns
725 -------
726 EquineOutput
727 Output object containing prediction probabilities and OOD scores.
728 """
729 logits = self(X)
730 preds = torch.softmax(logits, dim=1)
731 equiprobable = torch.ones(self.num_outputs) / self.num_outputs
732 max_entropy = torch.sum(torch.special.entr(equiprobable))
733 ood_score = torch.sum(torch.special.entr(preds), dim=1) / max_entropy
734 embeddings = self.compute_embeddings(X)
735 eq_out = EquineOutput(
736 classes=preds, ood_scores=ood_score, embeddings=embeddings
737 ) # TODO return embeddings
739 self.validate_feature_label_names(X.shape[-1], self.num_outputs)
741 return eq_out
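    # Worked example of the OOD score above (a sketch): the score is the
    # softmax entropy normalized by the maximum entropy log(num_outputs), so it
    # always lies in [0, 1]. With num_outputs = 2, a confident prediction
    # [0.99, 0.01] has entropy ~0.056 and score ~0.08, while a maximally
    # uncertain [0.5, 0.5] has entropy log(2) and score 1.0.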
743 def save(self, path: str) -> None:
744 """
745 Function to save all model parameters to a file.
747 Parameters
748 ----------
749 path : str
750 Filename to write the model.
751 """
752 model_settings = {
753 "emb_out_dim": self.num_deep_features,
754 "num_classes": self.num_outputs,
755 "use_temperature": self.use_temperature,
756 "init_temperature": self.temperature.item(),
757 "device": self.device_type,
758 }
760 jit_model = torch.jit.script(self.model.feature_extractor)
761 buffer = io.BytesIO()
762 torch.jit.save(jit_model, buffer)
763 buffer.seek(0)
765 laplace_sd = self.model.state_dict()
766 keys_to_delete = []
767 for key in laplace_sd:
768 if "feature_extractor" in key:
769 keys_to_delete.append(key)
770 for key in keys_to_delete:
771 del laplace_sd[key]
773 save_data = {
774 "embed_jit_save": buffer,
775 "feature_names": self.feature_names,
776 "label_names": self.label_names,
777 "laplace_model_save": laplace_sd,
778 "num_data": self.model.num_data,
779 "settings": model_settings,
780 "support": self.support,
781 "train_batch_size": self.model.train_batch_size,
782 "train_summary": self.train_summary,
783 }
785 torch.save(save_data, path) # TODO allow model checkpointing
787 @classmethod
788 def load(cls, path: str) -> Equine:
789 """
790 Function to load previously saved EquineGP model.
792 Parameters
793 ----------
794 path : str
795 Input filename.
797 Returns
798 -------
799 EquineGP
800 The reconstituted EquineGP object.
801 """
802 model_save = torch.load(path, weights_only=False)
803 jit_model = torch.jit.load(model_save.get("embed_jit_save"))
804 eq_model = cls(jit_model, **model_save.get("settings"))
806 eq_model.feature_names = model_save.get("feature_names")
807 eq_model.label_names = model_save.get("label_names")
808 eq_model.train_summary = model_save.get("train_summary")
810 eq_model.model.load_state_dict(
811 model_save.get("laplace_model_save"), strict=False
812 )
813 eq_model.model.seen_data = model_save.get("laplace_model_save").get("seen_data")
815 eq_model.model.set_training_params(
816 model_save.get("num_data"), model_save.get("train_batch_size")
817 )
818 eq_model.eval()
820 support = model_save.get("support")
821 if len(support) > 0:
822 eq_model.support = support
823 eq_model.prototypes = eq_model.compute_prototypes()
825 return eq_model
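# ------------------------------------------------------------------------------
# Hedged end-to-end sketch (not part of the original module): construct, train,
# predict, then round-trip through save/load. `emb_net`, the random tensors,
# and the file name are hypothetical placeholders.
def _demo_equine_gp_round_trip() -> None:  # pragma: no cover
    emb_net = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU())
    train_x = torch.randn(256, 16)
    train_y = torch.randint(0, 3, (256,))

    model = EquineGP(emb_net, emb_out_dim=32, num_classes=3)
    model.train_model(
        TensorDataset(train_x, train_y),
        torch.nn.CrossEntropyLoss(),
        torch.optim.Adam(model.parameters(), lr=1e-3),
        num_epochs=5,
    )

    out = model.predict(torch.randn(8, 16))  # EquineOutput with .classes and .ood_scores
    model.save("equine_gp_example.eq")
    reloaded = EquineGP.load("equine_gp_example.eq")
    print(out.ood_scores, reloaded.predict(torch.randn(8, 16)).ood_scores)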