Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/equine/equine_protonet.py: 96%
298 statements
# Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
from __future__ import annotations

from typing import Any, Optional

import icontract
import io
import numpy as np
import torch
import warnings
from beartype import beartype
from collections import OrderedDict
from collections.abc import Callable
from datetime import datetime
from enum import Enum
from scipy.stats import gaussian_kde
from torch.utils.data import TensorDataset
from tqdm import tqdm

from .equine import Equine, EquineOutput
from .utils import (
    generate_episode,
    generate_support,
    generate_train_summary,
    mahalanobis_distance_nosq,
    stratified_train_test_split,
)

#####################################
class CovType(Enum):
    """
    Enum class for covariance types used in EQUINE.
    """

    UNIT = "unit"
    DIAGONAL = "diag"
    FULL = "full"


PRED_COV_TYPE = CovType.DIAGONAL
OOD_COV_TYPE = CovType.DIAGONAL
DEFAULT_EPSILON = 1e-5
COV_REG_TYPE = "epsilon"


###############################################
@beartype
class Protonet(torch.nn.Module):
    """
    Private class that implements a prototypical neural network for use in EQUINE.
    """

    def __init__(
        self,
        embedding_model: torch.nn.Module,
        emb_out_dim: int,
        cov_type: CovType,
        cov_reg_type: str,
        epsilon: float,
        device: str = "cpu",
    ) -> None:
        """
        Protonet class constructor.

        Parameters
        ----------
        embedding_model : torch.nn.Module
            The PyTorch embedding model to generate logits with.
        emb_out_dim : int
            Dimension size of the given embedding model's output.
        cov_type : CovType
            Type of covariance to use when computing distances [unit, diag, full].
        cov_reg_type : str
            Type of regularization to use when generating the covariance matrix [epsilon, shared].
        epsilon : float
            Epsilon value to use for covariance regularization.
        device : str, optional
            The device to train the protonet model on (defaults to cpu).
        """
        super().__init__()
        self.embedding_model = embedding_model
        self.cov_type = cov_type
        self.cov_reg_type = cov_reg_type
        self.epsilon = epsilon
        self.emb_out_dim = emb_out_dim
        self.to(device)
        self.device = device

        self.support: OrderedDict[int, torch.Tensor] = OrderedDict()
        self.support_embeddings: OrderedDict[int, torch.Tensor] = OrderedDict()
        self.model_head: torch.nn.Module = self.create_model_head(emb_out_dim)
        self.model_head.to(device)

    def create_model_head(self, emb_out_dim: int) -> torch.nn.Linear:
        """
        Method for adding a PyTorch layer on top of the given embedding model. This layer
        is intended to offer extra degrees of freedom for distance learning in the embedding space.

        Parameters
        ----------
        emb_out_dim : int
            Dimension size of the embedding model output.

        Returns
        -------
        torch.nn.Linear
            The created PyTorch model layer.
        """
        return torch.nn.Linear(emb_out_dim, emb_out_dim)

    def compute_embeddings(self, X: torch.Tensor) -> torch.Tensor:
        """
        Method for calculating model embeddings using both the given embedding model and the added model head.

        Parameters
        ----------
        X : torch.Tensor
            Input tensor to compute embeddings on.

        Returns
        -------
        torch.Tensor
            Fully computed embedding tensors for the given X tensor.
        """
        model_embeddings = self.embedding_model(X.to(self.device))
        head_embeddings = self.model_head(model_embeddings)
        return head_embeddings

    @icontract.require(lambda self: len(self.support_embeddings) > 0)
    def compute_prototypes(self) -> torch.Tensor:
        """
        Method for computing class prototypes based on given support examples.
        "Prototypes" in this context are the means of the support embeddings for each class.

        Returns
        -------
        torch.Tensor
            Tensors of prototypes for each of the given classes in the support.
        """
        # Compute prototype for each class
        proto_list = []
        for label in self.support_embeddings:  # look at doing functorch
            class_prototype = torch.mean(self.support_embeddings[label], dim=0)
            proto_list.append(class_prototype)

        prototypes = torch.stack(proto_list)

        return prototypes

    @icontract.require(lambda self: len(self.support_embeddings) > 0)
    def compute_covariance(self, cov_type: CovType) -> torch.Tensor:
        """
        Method for generating the (regularized) support example covariance matrices used for calculating distances.
        Note that this method is only called once per episode, and the resulting tensor is used for all queries.

        Parameters
        ----------
        cov_type : CovType
            Type of covariance to use [unit, diag, full].

        Returns
        -------
        torch.Tensor
            Tensor containing the generated regularized covariance matrix.
        """
        class_cov_dict = OrderedDict().fromkeys(
            self.support_embeddings.keys(), torch.Tensor()
        )
        for label in self.support_embeddings.keys():
            class_covariance = self.compute_covariance_by_type(
                cov_type, self.support_embeddings[label]
            )
            class_cov_dict[label] = class_covariance

        reg_covariance_dict = self.regularize_covariance(
            class_cov_dict, cov_type, self.cov_reg_type
        )
        reg_covariance = torch.stack(list(reg_covariance_dict.values()))

        return reg_covariance  # TODO try putting everything on GPU with .to() and see if faster

    def compute_covariance_by_type(
        self, cov_type: CovType, embedding: torch.Tensor
    ) -> torch.Tensor:
        """
        Select the appropriate covariance matrix type based on cov_type.

        Parameters
        ----------
        cov_type : CovType
            Type of covariance to use [unit, diag, full].
        embedding : torch.Tensor
            Embedding tensor to use when generating the covariance matrix.

        Returns
        -------
        torch.Tensor
            Tensor containing the requested covariance matrix.
        """
        if cov_type == CovType.FULL:
            class_covariance = torch.cov(embedding.T)
        elif cov_type == CovType.DIAGONAL:
            class_covariance = torch.var(embedding, dim=0)
        elif cov_type == CovType.UNIT:
            class_covariance = torch.ones(self.emb_out_dim)
        else:
            raise ValueError(f"Unknown covariance type: {cov_type}")

        return class_covariance
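
    # Editorial sketch: for an embedding tensor `emb` of shape (n, d) on a
    # hypothetical Protonet `net` with emb_out_dim == d, the three types yield:
    #
    #     net.compute_covariance_by_type(CovType.FULL, emb)      # (d, d), via torch.cov
    #     net.compute_covariance_by_type(CovType.DIAGONAL, emb)  # (d,), per-dimension variances
    #     net.compute_covariance_by_type(CovType.UNIT, emb)      # (d,), all ones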

    def regularize_covariance(
        self,
        class_cov_dict: OrderedDict[int, torch.Tensor],
        cov_type: CovType,
        cov_reg_type: str,
    ) -> OrderedDict[int, torch.Tensor]:
        """
        Method to add regularization to each class covariance matrix based on the selected regularization type.

        Parameters
        ----------
        class_cov_dict : OrderedDict[int, torch.Tensor]
            A dictionary containing each class and the corresponding covariance matrix.
        cov_type : CovType
            Type of covariance to use [unit, diag, full].
        cov_reg_type : str
            Type of regularization to use [epsilon, shared].

        Returns
        -------
        OrderedDict[int, torch.Tensor]
            Dictionary containing the regularized class covariance matrices.
        """
        if cov_type == CovType.FULL:
            regularization = torch.diag(self.epsilon * torch.ones(self.emb_out_dim)).to(
                self.device
            )
        elif cov_type == CovType.DIAGONAL:
            regularization = self.epsilon * torch.ones(self.emb_out_dim).to(self.device)
        elif cov_type == CovType.UNIT:
            regularization = torch.zeros(self.emb_out_dim).to(self.device)

        if cov_reg_type == "shared":
            if cov_type != CovType.FULL and cov_type != CovType.DIAGONAL:
                for label in self.support_embeddings:
                    class_cov_dict[label] = class_cov_dict[label] + regularization
                warnings.warn(
                    "Covariance type UNIT is incompatible with shared regularization, "
                    "reverting to epsilon regularization"
                )
                return class_cov_dict

            shared_covariance = self.compute_shared_covariance(class_cov_dict, cov_type)

            for label in self.support_embeddings:
                num_class_support = self.support_embeddings[label].shape[0]
                lamb = num_class_support / (num_class_support + 1)

                class_cov_dict[label] = (
                    lamb * class_cov_dict[label]
                    + (1 - lamb) * shared_covariance
                    + regularization
                )

        elif cov_reg_type == "epsilon":
            for label in class_cov_dict.keys():
                class_cov_dict[label] = (
                    class_cov_dict[label].to(self.device) + regularization
                )

        return class_cov_dict
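
    # Editorial note: for "shared" regularization, the update above is the
    # shrinkage estimate
    #
    #     Sigma_k <- lamb * Sigma_k + (1 - lamb) * Sigma_shared + epsilon * I,
    #     with lamb = n_k / (n_k + 1),
    #
    # so classes with few support examples (small n_k) are pulled harder toward
    # the shared covariance, while "epsilon" regularization only adds epsilon * I.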

    def compute_shared_covariance(
        self, class_cov_dict: OrderedDict[int, torch.Tensor], cov_type: CovType
    ) -> torch.Tensor:
        """
        Method to calculate a shared covariance matrix.

        The shared covariance matrix is calculated as the weighted average of the class covariance matrices,
        where the weights are the number of support examples for each class. This is useful when the number of
        support examples for each class is small.

        Parameters
        ----------
        class_cov_dict : OrderedDict[int, torch.Tensor]
            A dictionary containing each class and the corresponding covariance matrix.
        cov_type : CovType
            Type of covariance to use [unit, diag, full].

        Returns
        -------
        torch.Tensor
            Tensor containing the shared covariance matrix.
        """
        # Weight by support-example counts (not by covariance tensor dimensions)
        total_support = sum(emb.shape[0] for emb in self.support_embeddings.values())

        if cov_type == CovType.FULL:
            shared_covariance = torch.zeros((self.emb_out_dim, self.emb_out_dim))
        elif cov_type == CovType.DIAGONAL:
            shared_covariance = torch.zeros(self.emb_out_dim)
        else:
            raise ValueError(
                "Shared covariance can only be used with FULL or DIAGONAL (not UNIT) covariance types"
            )

        for label in class_cov_dict:
            num_class_support = self.support_embeddings[label].shape[0]
            shared_covariance = (
                shared_covariance + (num_class_support - 1) * class_cov_dict[label]
            )  # undo N-1 div from cov

        shared_covariance = shared_covariance / (
            total_support - 1
        )  # redo N-1 div for shared cov

        return shared_covariance
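
    # Editorial note: the loop above computes the pooled (within-class) covariance
    #
    #     Sigma_shared = (1 / (N - 1)) * sum_k (n_k - 1) * Sigma_k,  N = sum_k n_k,
    #
    # i.e., it undoes each class's (n_k - 1) normalization, accumulates the
    # scatter, and renormalizes by the total support count.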

    @icontract.require(lambda X_embed, mu: X_embed.shape[-1] == mu.shape[-1])
    @icontract.ensure(lambda result: torch.all(result >= 0))
    def compute_distance(
        self, X_embed: torch.Tensor, mu: torch.Tensor, cov: torch.Tensor
    ) -> torch.Tensor:
        """
        Method to compute the distances to class prototypes for the given embeddings.

        Parameters
        ----------
        X_embed : torch.Tensor
            The embeddings of the query examples.
        mu : torch.Tensor
            The class prototypes (means of the support embeddings).
        cov : torch.Tensor
            The support covariance matrix.

        Returns
        -------
        torch.Tensor
            The calculated distances from each of the class prototypes for the given embeddings.
        """
        _queries = torch.unsqueeze(X_embed, 1)  # examples x 1 x dimension
        diff = torch.sub(mu, _queries)

        if len(cov.shape) == 2:  # diagonal covariance
            # examples x classes x dimension
            sq_diff = diff**2
            div = torch.div(sq_diff.to(self.device), cov.to(self.device))
            dist = torch.nan_to_num(div)
            dist = torch.sum(dist, dim=2)  # examples x classes
            dist = dist.squeeze(dim=1)
            dist = torch.sqrt(dist + self.epsilon)  # examples x classes
        else:  # len(cov.shape) == 3: full covariance
            diff = diff.permute(1, 2, 0)  # classes x dimension x examples
            dist = mahalanobis_distance_nosq(diff, cov)
            dist = torch.sqrt(dist.permute(1, 0) + self.epsilon)  # examples x classes
            dist = dist.squeeze(dim=1)
        return dist
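
    # Editorial note: both branches compute a (regularized) Mahalanobis distance
    # between each query embedding x and each class prototype mu_k,
    #
    #     d_k(x) = sqrt((x - mu_k)^T Sigma_k^{-1} (x - mu_k) + epsilon),
    #
    # where, in the 2-D (diagonal) branch, the quadratic form reduces to
    # sum_j (x_j - mu_{k,j})^2 / sigma_{k,j}^2.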

    def compute_classes(self, distances: torch.Tensor) -> torch.Tensor:
        """
        Method to compute predicted class probabilities from distances via a softmax function.

        Parameters
        ----------
        distances : torch.Tensor
            The distances of embeddings to class prototypes.

        Returns
        -------
        torch.Tensor
            Tensor of class predictions.
        """
        softmax = torch.nn.functional.softmax(torch.neg(distances), dim=-1)
        return softmax

    def forward(self, X: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Protonet forward function; generates class probability predictions and distances from prototypes.

        Parameters
        ----------
        X : torch.Tensor
            Input tensor of queries for generating predictions.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            Tuple containing class probability predictions and class distances from prototypes.
        """
        if len(self.support) == 0 or len(self.support_embeddings) == 0:
            raise ValueError(
                "No support examples found. Protonet Model requires model support to "
                "be set with the 'update_support()' method before calling forward."
            )

        X_embed = self.compute_embeddings(X)
        if X_embed.shape == torch.Size([self.emb_out_dim]):
            X_embed = X_embed.unsqueeze(dim=0)  # handle single examples
        distances = self.compute_distance(X_embed, self.prototypes, self.covariance)
        classes = self.compute_classes(distances)

        return classes, distances

    def update_support(self, support: OrderedDict[int, torch.Tensor]) -> None:
        """
        Method to update the support examples and all the calculations that rely on them.

        Parameters
        ----------
        support : OrderedDict
            Ordered dict containing class labels and their associated support examples.
        """
        self.support = support  # TODO torch.nn.ParameterDict(support)

        support_embs = OrderedDict().fromkeys(support.keys(), torch.Tensor())
        for label in support:
            support_embs[label] = self.compute_embeddings(support[label])

        self.support_embeddings = (
            support_embs  # TODO torch.nn.ParameterDict(support_embs)
        )

        self.prototypes: torch.Tensor = self.compute_prototypes()

        if not self.training:
            self.compute_global_moments()
            self.covariance: torch.Tensor = self.compute_covariance(
                cov_type=PRED_COV_TYPE
            )
        else:
            self.covariance: torch.Tensor = self.compute_covariance(
                cov_type=self.cov_type
            )

    @icontract.require(lambda self: len(self.support_embeddings) > 0)
    def compute_global_moments(self) -> None:
        """Method to calculate the global moments of the support embeddings for use in OOD score generation."""
        embeddings = torch.cat(list(self.support_embeddings.values()))
        self.global_covariance = torch.unsqueeze(
            self.compute_covariance_by_type(OOD_COV_TYPE, embeddings), dim=0
        )
        global_reg_input = OrderedDict().fromkeys([0], torch.Tensor())
        global_reg_input[0] = self.global_covariance
        self.global_covariance: torch.Tensor = self.regularize_covariance(
            global_reg_input, OOD_COV_TYPE, "epsilon"
        )[0]
        self.global_mean: torch.Tensor = torch.mean(embeddings, dim=0)


###############################################
@beartype
class EquineProtonet(Equine):
    """
    A class representing an EQUINE model that utilizes protonets and (optionally) relative Mahalanobis distances
    to generate OOD and model confidence scores. This wraps any PyTorch embedding neural network
    and provides the `forward`, `predict`, `save`, and `load` methods required by Equine.
    """

    def __init__(
        self,
        embedding_model: torch.nn.Module,
        emb_out_dim: int,
        cov_type: CovType = CovType.UNIT,
        relative_mahal: bool = True,
        use_temperature: bool = False,
        init_temperature: float = 1.0,
        device: str = "cpu",
        feature_names: Optional[list[str]] = None,
        label_names: Optional[list[str]] = None,
    ) -> None:
        """
        EquineProtonet class constructor.

        Parameters
        ----------
        embedding_model : torch.nn.Module
            Neural network feature embedding model.
        emb_out_dim : int
            The number of output features from the embedding model.
        cov_type : CovType, optional
            The type of covariance to use when training the protonet [UNIT, DIAGONAL, FULL], by default CovType.UNIT.
        relative_mahal : bool, optional
            Use relative Mahalanobis distance for OOD calculations. If False, uses standard Mahalanobis distance instead, by default True.
        use_temperature : bool, optional
            Whether to use temperature scaling after training, by default False.
        init_temperature : float, optional
            What to use as the initial temperature (1.0 has no effect), by default 1.0.
        device : str, optional
            The device to train the equine model on (defaults to cpu).
        feature_names : list[str], optional
            List of strings of the names of the tabular features (e.g., ["duration", "fiat_mean", ...]).
        label_names : list[str], optional
            List of strings of the names of the labels (e.g., ["streaming", "voip", ...]).
        """
        super().__init__(
            embedding_model,
            device=device,
            feature_names=feature_names,
            label_names=label_names,
        )
        self.cov_type = cov_type
        self.cov_reg_type = COV_REG_TYPE
        self.relative_mahal = relative_mahal
        self.emb_out_dim = emb_out_dim
        self.epsilon = DEFAULT_EPSILON
        self.outlier_score_kde: OrderedDict[int, gaussian_kde] = OrderedDict()
        self.model_summary: dict[str, Any] = dict()
        self.use_temperature = use_temperature
        self.init_temperature = init_temperature
        self.register_buffer(
            "temperature", torch.Tensor(self.init_temperature * torch.ones(1))
        )

        self.model: torch.nn.Module = Protonet(
            embedding_model,
            self.emb_out_dim,
            self.cov_type,
            self.cov_reg_type,
            self.epsilon,
            device=device,
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Generates class probability predictions for the input tensor.

        Parameters
        ----------
        X : torch.Tensor
            The input tensor for generating predictions.

        Returns
        -------
        torch.Tensor
            The output class predictions.
        """
        preds, _ = self.model(X)
        return preds

    @icontract.require(lambda calib_frac: calib_frac > 0 and calib_frac < 1)
    def train_model(
        self,
        dataset: TensorDataset,
        num_episodes: int,
        calib_frac: float = 0.2,
        support_size: int = 25,
        way: int = 3,
        episode_size: int = 100,
        loss_fn: Callable = torch.nn.functional.cross_entropy,
        opt_class: Callable = torch.optim.Adam,
        num_calibration_epochs: int = 2,
        calibration_lr: float = 0.01,
    ) -> dict[str, Any]:
        """
        Train or fine-tune an EquineProtonet model.

        Parameters
        ----------
        dataset : TensorDataset
            Input PyTorch TensorDataset of training data for the model.
        num_episodes : int
            The desired number of episodes to use for training.
        calib_frac : float, optional
            Fraction of given training data to reserve for model calibration, by default 0.2.
        support_size : int, optional
            Number of support examples to generate for each class, by default 25.
        way : int, optional
            Number of classes to train on per episode, by default 3.
        episode_size : int, optional
            Number of examples to use per episode, by default 100.
        loss_fn : Callable, optional
            A PyTorch loss function, e.g., torch.nn.CrossEntropyLoss(), by default torch.nn.functional.cross_entropy.
        opt_class : Callable, optional
            A PyTorch optimizer, e.g., torch.optim.Adam, by default torch.optim.Adam.
        num_calibration_epochs : int, optional
            The desired number of epochs to use for temperature scaling, by default 2.
        calibration_lr : float, optional
            Learning rate for temperature scaling, by default 0.01.

        Returns
        -------
        dict[str, Any]
            A dictionary containing the train summary, the held-out calibration data, and the calibration labels.
        """
        self.train()

        if self.use_temperature:
            self.temperature: torch.Tensor = torch.Tensor(
                self.init_temperature * torch.ones(1)
            ).type_as(self.temperature)

        X, Y = dataset[:]
        self.validate_feature_label_names(X.shape[-1], torch.unique(Y).shape[0])

        train_x, calib_x, train_y, calib_y = stratified_train_test_split(
            X, Y, test_size=calib_frac
        )
        optimizer = opt_class(self.parameters())

        # torch.Tensor.to() is not in-place; reassign to actually move the data
        train_x = train_x.to(self.device)
        train_y = train_y.to(self.device)
        calib_x = calib_x.to(self.device)
        calib_y = calib_y.to(self.device)

        for _ in tqdm(range(num_episodes)):
            optimizer.zero_grad()

            support, episode_x, episode_y = generate_episode(
                train_x, train_y, support_size, way, episode_size
            )
            self.model.update_support(support)

            _, dists = self.model(episode_x)
            loss_value = loss_fn(
                torch.neg(dists).to(self.device), episode_y.to(self.device)
            )
            loss_value.backward()
            optimizer.step()

        self.eval()
        full_support = generate_support(
            train_x,
            train_y,
            support_size,
            selected_labels=torch.unique(train_y).tolist(),
        )

        self.model.update_support(
            full_support
        )  # update support with final selected examples

        X_embed = self.model.compute_embeddings(calib_x)
        pred_probs, dists = self.model(calib_x)
        ood_dists = self._compute_ood_dist(X_embed, pred_probs, dists)
        self._fit_outlier_scores(ood_dists, calib_y)

        if self.use_temperature:
            self.calibrate_temperature(
                calib_x, calib_y, num_calibration_epochs, calibration_lr
            )

        date_trained = datetime.now().strftime("%m/%d/%Y, %H:%M:%S")
        self.train_summary: dict[str, Any] = generate_train_summary(
            self, train_y, date_trained
        )
        return_dict: dict[str, Any] = dict()
        return_dict["train_summary"] = self.train_summary
        return_dict["calib_x"] = calib_x
        return_dict["calib_y"] = calib_y
        return return_dict

    def calibrate_temperature(
        self,
        calib_x: torch.Tensor,
        calib_y: torch.Tensor,
        num_calibration_epochs: int = 1,
        calibration_lr: float = 0.01,
    ) -> None:
        """
        Fine-tune the temperature after training. Note that this function is also run at the
        conclusion of train_model when use_temperature is True.

        Parameters
        ----------
        calib_x : torch.Tensor
            Training data to be used for temperature calibration.
        calib_y : torch.Tensor
            Labels corresponding to `calib_x`.
        num_calibration_epochs : int, optional
            Number of epochs to tune temperature, by default 1.
        calibration_lr : float, optional
            Learning rate for temperature optimization, by default 0.01.

        Returns
        -------
        None
        """
        self.temperature.requires_grad = True
        optimizer = torch.optim.Adam([self.temperature], lr=calibration_lr)
        for _ in range(num_calibration_epochs):
            optimizer.zero_grad()
            with torch.no_grad():
                pred_probs, dists = self.model(calib_x)
            dists = dists.to(self.device) / self.temperature.to(self.device)
            loss = torch.nn.functional.cross_entropy(
                torch.neg(dists).to(self.device), calib_y.to(torch.long).to(self.device)
            )
            loss.backward()
            optimizer.step()
        self.temperature.requires_grad = False
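
    # Editorial note: this is standard temperature scaling applied to distances;
    # the calibrated probabilities are softmax(-d / T), and only the scalar T is
    # optimized (the model forward pass runs under torch.no_grad()). A minimal
    # standalone sketch with hypothetical `dists`/`labels` tensors:
    #
    #     T = torch.ones(1, requires_grad=True)
    #     opt = torch.optim.Adam([T], lr=0.01)
    #     loss = torch.nn.functional.cross_entropy(torch.neg(dists / T), labels)
    #     loss.backward(); opt.step()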

    @icontract.ensure(lambda self: len(self.model.support_embeddings) > 0)
    def _fit_outlier_scores(
        self, ood_dists: torch.Tensor, calib_y: torch.Tensor
    ) -> None:
        """
        Private function to fit outlier scores with a kernel density estimate (KDE).

        Parameters
        ----------
        ood_dists : torch.Tensor
            Tensor of computed OOD distances.
        calib_y : torch.Tensor
            Tensor of class labels for `ood_dists` examples.

        Returns
        -------
        None
        """
        for label in self.model.support_embeddings.keys():
            class_ood_dists = ood_dists[calib_y == int(label)].cpu().detach().numpy()
            class_kde = gaussian_kde(class_ood_dists)  # TODO convert to torch func
            self.outlier_score_kde[label] = class_kde

    def _compute_outlier_scores(
        self, ood_dists: torch.Tensor, predictions: torch.Tensor
    ) -> torch.Tensor:
        """
        Private function to compute OOD scores using the calculated kernel density estimate (KDE).

        Parameters
        ----------
        ood_dists : torch.Tensor
            Tensor of computed OOD distances.
        predictions : torch.Tensor
            Tensor of model protonet predictions.

        Returns
        -------
        torch.Tensor
            Tensor of OOD scores for the given examples.
        """
        ood_scores = torch.zeros_like(ood_dists)
        for i in range(len(predictions)):
            # Use the KDE and RMD corresponding to the predicted class
            predicted_class = int(torch.argmax(predictions[i, :]))
            p_value = self.outlier_score_kde[predicted_class].integrate_box_1d(
                ood_dists[i].cpu().detach().numpy(), np.inf
            )
            ood_scores[i] = 1.0 - np.clip(p_value, 0.0, 1.0)

        return ood_scores
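
    # Editorial note: integrate_box_1d integrates the fitted KDE over
    # [ood_dist, inf), giving a one-sided p-value p = P(D >= ood_dist) under the
    # calibration distribution of the predicted class; the returned score 1 - p
    # approaches 1 when an example lies farther out than nearly all of the
    # calibration data.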

    @icontract.ensure(lambda result: len(result) > 0)
    def _compute_ood_dist(
        self,
        X_embeddings: torch.Tensor,
        predictions: torch.Tensor,
        distances: torch.Tensor,
    ) -> torch.Tensor:
        """
        Private function to compute OOD distances using a distance function.

        Parameters
        ----------
        X_embeddings : torch.Tensor
            Tensor of example embeddings.
        predictions : torch.Tensor
            Tensor of model protonet predictions for the given embeddings.
        distances : torch.Tensor
            Tensor of calculated protonet distances for the given embeddings.

        Returns
        -------
        torch.Tensor
            Tensor of OOD distances for the given embeddings.
        """
        preds = torch.argmax(predictions, dim=1)
        preds = preds.unsqueeze(dim=-1)
        # Calculate (relative) Mahalanobis distance:
        if self.relative_mahal:
            null_distance = self.model.compute_distance(
                X_embeddings, self.model.global_mean, self.model.global_covariance
            )
            null_distance = null_distance.unsqueeze(dim=-1)
            ood_dist = distances.gather(1, preds) - null_distance
        else:
            ood_dist = distances.gather(1, preds)

        ood_dist = torch.reshape(ood_dist, (-1,))
        return ood_dist

    def predict(self, X: torch.Tensor) -> EquineOutput:
        """
        Predict function for EquineProtonet, inherited and implemented from Equine.

        Parameters
        ----------
        X : torch.Tensor
            Input tensor.

        Returns
        -------
        EquineOutput
            Output object containing prediction probabilities and OOD scores.
        """
        X_embed = self.model.compute_embeddings(X)
        if X_embed.shape == torch.Size([self.model.emb_out_dim]):
            X_embed = X_embed.unsqueeze(dim=0)  # handle single examples
        preds, dists = self.model(X)
        if self.use_temperature:
            dists = dists / self.temperature
            preds = torch.softmax(torch.negative(dists), dim=1)
        ood_dist = self._compute_ood_dist(X_embed, preds, dists)
        ood_scores = self._compute_outlier_scores(ood_dist, preds)

        self.validate_feature_label_names(X.shape[-1], preds.shape[-1])

        return EquineOutput(classes=preds, ood_scores=ood_scores, embeddings=X_embed)

    @icontract.require(lambda calib_frac: (calib_frac > 0.0) and (calib_frac < 1.0))
    def update_support(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        calib_frac: float,
        label_names: Optional[list[str]] = None,
    ) -> None:
        """
        Function to update protonet support examples with given examples.

        Parameters
        ----------
        support_x : torch.Tensor
            Tensor containing support examples for the protonet.
        support_y : torch.Tensor
            Tensor containing labels for the given support examples.
        calib_frac : float
            Fraction of given support data to use for OOD calibration.
        label_names : list[str], optional
            List of strings of the names of the labels (e.g., ["streaming", "voip", ...]).

        Returns
        -------
        None
        """
        support_x, calib_x, support_y, calib_y = stratified_train_test_split(
            support_x, support_y, test_size=calib_frac
        )
        labels, counts = torch.unique(support_y, return_counts=True)
        if label_names is not None:
            self.label_names = label_names
        self.validate_feature_label_names(support_x.shape[-1], labels.shape[0])

        support = OrderedDict()
        for label, count in zip(labels.tolist(), counts.tolist()):
            class_support = generate_support(
                support_x,
                support_y,
                support_size=count,
                selected_labels=[label],
            )
            support.update(class_support)

        self.model.update_support(support)

        X_embed = self.model.compute_embeddings(calib_x)
        preds, dists = self.model(calib_x)
        ood_dists = self._compute_ood_dist(X_embed, preds, dists)

        self._fit_outlier_scores(ood_dists, calib_y)

    @icontract.require(lambda self: len(self.model.support) > 0)
    def get_support(self) -> OrderedDict[int, torch.Tensor]:
        """
        Get the support examples for the model.

        Returns
        -------
        OrderedDict[int, torch.Tensor]
            The support examples for the model.
        """
        return self.model.support

    @icontract.require(lambda self: len(self.model.prototypes) > 0)
    def get_prototypes(self) -> torch.Tensor:
        """
        Get the prototypes for the model (the class means of the support embeddings).

        Returns
        -------
        torch.Tensor
            The prototypes for the model.
        """
        return self.model.prototypes

    def save(self, path: str) -> None:
        """
        Save all model parameters to a file.

        Parameters
        ----------
        path : str
            Filename to write the model to.

        Returns
        -------
        None
        """
        model_settings = {
            "cov_type": self.cov_type,
            "emb_out_dim": self.emb_out_dim,
            "use_temperature": self.use_temperature,
            "init_temperature": self.temperature.item(),
            "relative_mahal": self.relative_mahal,
        }

        jit_model = torch.jit.script(self.model.embedding_model)
        buffer = io.BytesIO()
        torch.jit.save(jit_model, buffer)
        buffer.seek(0)

        save_data = {
            "embed_jit_save": buffer,
            "feature_names": self.feature_names,
            "label_names": self.label_names,
            "model_head_save": self.model.model_head.state_dict(),
            "outlier_kde": self.outlier_score_kde,
            "settings": model_settings,
            "support": self.model.support,
            "train_summary": self.train_summary,
        }

        torch.save(save_data, path)  # TODO allow model checkpointing

    @classmethod
    def load(cls, path: str) -> Equine:  # noqa: F821
        """
        Load a previously saved EquineProtonet model.

        Parameters
        ----------
        path : str
            The filename of the saved model.

        Returns
        -------
        EquineProtonet
            The reconstituted EquineProtonet object.
        """
        model_save = torch.load(path, weights_only=False)
        support = model_save.get("support")
        jit_model = torch.jit.load(model_save.get("embed_jit_save"))
        eq_model = cls(jit_model, **model_save.get("settings"))

        eq_model.model.model_head.load_state_dict(model_save.get("model_head_save"))
        eq_model.eval()
        eq_model.model.update_support(support)

        eq_model.feature_names = model_save.get("feature_names")
        eq_model.label_names = model_save.get("label_names")
        eq_model.outlier_score_kde = model_save.get("outlier_kde")
        eq_model.train_summary = model_save.get("train_summary")

        return eq_model
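

###############################################
# Editorial usage sketch (not part of the library): a minimal end-to-end run on
# synthetic data, exercising only the public API defined above. The toy
# embedding model, tensor shapes, and episode sizes are illustrative.
if __name__ == "__main__":
    torch.manual_seed(0)
    embedder = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
    model = EquineProtonet(embedder, emb_out_dim=8)

    # Synthetic 3-class dataset with 4 tabular features
    dataset = TensorDataset(torch.randn(300, 4), torch.randint(0, 3, (300,)))

    # A few episodes, just enough to exercise the episodic training loop
    results = model.train_model(
        dataset, num_episodes=10, support_size=5, way=3, episode_size=30
    )
    print(results["train_summary"])

    out = model.predict(torch.randn(5, 4))
    print(out.classes.shape, out.ood_scores.shape)  # expect (5, 3) and (5,)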