Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/equine/equine.py: 100%
38 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-29 04:12 +0000
1# Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
2# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
3# SPDX-License-Identifier: MIT
5from typing import Any, Optional, TypeVar
7import icontract
8import torch
9from abc import ABC, abstractmethod
10from collections import OrderedDict
11from torch.utils.data import TensorDataset
13from .equine_output import EquineOutput
# Type variable bound to ``Equine`` (forward reference by name, since the
# class is defined below); used to annotate the return type of ``Equine.load``.
AnyEquine = TypeVar("AnyEquine", bound="Equine")
class Equine(torch.nn.Module, ABC):
    """EQUINE Abstract Base Class (ABC).

    EQUINE extends torch's ``nn.Module`` with methods that enable uncertainty
    quantification and visualization. Most importantly, the ``.predict()``
    method must be implemented to return an ``EquineOutput`` object that
    contains both the class probabilities *and* an out-of-distribution (OOD)
    score.

    Parameters
    ----------
    embedding_model : torch.nn.Module
        The embedding model to use.
    head_layers : int, optional
        The number of layers to use in the model head, by default 1.
    device : str, optional
        The device to train the equine model on (defaults to cpu).
    feature_names : list[str], optional
        List of strings of the names of the tabular features
        (ex ["duration", "fiat_mean", ...]).
    label_names : list[str], optional
        List of strings of the names of the labels
        (ex ["streaming", "voip", ...]).

    Attributes
    ----------
    device : str
        The device the model was moved to at construction (defaults to cpu).
    embedding_model : torch.nn.Module
        The neural embedding model to enrich with uncertainty quantification.
    feature_names : list[str], optional
        List of strings of the names of the tabular features
        (ex ["duration", "fiat_mean", ...]).
    head_layers : int
        The number of linear layers to append to the embedding model
        (default 1, not always used).
    label_names : list[str], optional
        List of strings of the names of the labels
        (ex ["streaming", "voip", ...]).
    train_summary : dict[str, Any]
        A dictionary containing information about the model training.
        Initialized with the keys ``"numTrainExamples"``, ``"dateTrained"``
        and ``"modelType"``.
    support : OrderedDict[int, torch.Tensor]
        Initialized empty. Presumably maps a class index to that class's
        support examples once a subclass populates it (confirm in subclasses).
    support_embeddings : OrderedDict[int, torch.Tensor]
        Initialized empty. Presumably holds the embeddings of ``support``
        (confirm in subclasses).
    prototypes : torch.Tensor
        Initialized as an empty tensor; see ``get_prototypes``.

    Raises
    ------
    TypeError
        On instantiation, if a subclass does not override every method
        marked ``@abstractmethod`` (standard ``abc`` behavior).
    NotImplementedError
        If a base-class implementation is invoked directly, e.g. calling
        ``load`` on a subclass that does not override it.
    """

    def __init__(
        self,
        embedding_model: torch.nn.Module,
        head_layers: int = 1,
        device: str = "cpu",
        feature_names: Optional[list[str]] = None,
        label_names: Optional[list[str]] = None,
    ) -> None:
        super().__init__()
        self.embedding_model = embedding_model
        self.head_layers = head_layers
        # Placeholder training metadata; presumably overwritten by concrete
        # ``train_model`` implementations (confirm in subclasses).
        self.train_summary: dict[str, Any] = {
            "numTrainExamples": 0,
            "dateTrained": "",
            "modelType": "",
        }
        self.device = device
        # Assigning an nn.Module to ``self.embedding_model`` registers it as a
        # submodule, so ``self.to`` already moves it; the explicit second call
        # is kept as an extra guarantee.
        self.to(device)
        self.embedding_model.to(device)
        self.feature_names = feature_names
        self.label_names = label_names

        # These are created after ``self.to(device)`` and are plain
        # attributes (not registered buffers), so ``Module.to`` does not move
        # them; ``prototypes`` starts as an empty tensor on torch's default
        # device.
        self.support: OrderedDict[int, torch.Tensor] = OrderedDict()
        self.support_embeddings: OrderedDict[int, torch.Tensor] = OrderedDict()
        self.prototypes: torch.Tensor = torch.Tensor()

    @abstractmethod
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model. This is to preserve the usual behavior
        of torch.nn.Module.

        Parameters
        ----------
        X : torch.Tensor
            The input data.

        Returns
        -------
        torch.Tensor
            The output of the model.
        """
        raise NotImplementedError

    @abstractmethod
    def predict(self, X: torch.Tensor) -> EquineOutput:
        """
        Upon implementation, predict the class probabilities and
        out-of-distribution (OOD) scores for the given input data.

        Parameters
        ----------
        X : torch.Tensor
            The input data.

        Returns
        -------
        EquineOutput
            An EquineOutput object containing the class probabilities and OOD scores.
        """
        raise NotImplementedError

    @abstractmethod
    def train_model(
        self, dataset: TensorDataset, *args: Any, **kwargs: Any
    ) -> dict[str, Any]:
        """
        Upon implementation, train the model on the given dataset.

        Parameters
        ----------
        dataset : TensorDataset
            TensorDataset containing the training data.
        *args
            Additional positional arguments to pass to the training function.
        **kwargs
            Additional keyword arguments to pass to the training function.

        Returns
        -------
        dict[str, Any]
            Dictionary containing summary training information and any other
            data. At least one key should be ``'train_summary'``.
        """
        raise NotImplementedError

    @abstractmethod
    def get_prototypes(self) -> torch.Tensor:
        """
        Upon implementation, return the prototype embeddings.

        Returns
        -------
        torch.Tensor
            A torch tensor of the prototype embeddings.
        """
        raise NotImplementedError

    @abstractmethod
    def save(self, path: str) -> None:
        """
        Upon implementation, save the model to the given file path.

        Parameters
        ----------
        path : str
            File path to save the model to.
        """
        raise NotImplementedError

    @classmethod
    def load(cls: type[AnyEquine], path: str) -> AnyEquine:
        """
        Upon implementation, load the model from the given file path.

        Unlike the other interface methods, this one is not marked
        ``@abstractmethod``. Subclasses that do not override it can still be
        instantiated, but calling ``load`` on them raises
        ``NotImplementedError``.

        Parameters
        ----------
        path : str
            File path to load the model from.

        Returns
        -------
        Equine
            Loaded model object, typed as the class ``load`` was called on.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError

    def get_label_names(self) -> Optional[list[str]]:
        """
        Retrieve the label names used in the model.

        Returns
        -------
        Optional[list[str]]
            A list of label names if available; otherwise, None.
        """
        # ``__init__`` always sets ``label_names``; the ``None`` fallback only
        # matters for instances on which it was never set (e.g. created
        # without running ``__init__``).
        return getattr(self, "label_names", None)

    def get_feature_names(self) -> Optional[list[str]]:
        """
        Retrieve the feature names used in the model.

        Returns
        -------
        Optional[list[str]]
            A list of feature names if available; otherwise, None.
        """
        # ``__init__`` always sets ``feature_names``; the ``None`` fallback
        # only matters for instances on which it was never set (e.g. created
        # without running ``__init__``).
        return getattr(self, "feature_names", None)

    @icontract.require(
        lambda num_features, num_classes: num_features > 0 and num_classes > 0
    )
    def validate_feature_label_names(self, num_features: int, num_classes: int) -> None:
        """
        Validate that the feature names and label names, if provided, match the expected counts.

        Parameters
        ----------
        num_features : int
            The expected number of features; must be positive.
        num_classes : int
            The expected number of classes; must be positive.

        Raises
        ------
        icontract.ViolationError
            If `num_features` or `num_classes` is not positive (precondition).
        ValueError
            If the length of `feature_names` does not match `num_features`, or
            if the length of `label_names` does not match `num_classes`.
        """
        feature_names = self.get_feature_names()
        if feature_names is not None and len(feature_names) != num_features:
            raise ValueError(
                f"The length of feature_names ({len(feature_names)}) does not match the number of data features ({num_features}). Update feature_names or set feature_names to None."
            )

        label_names = self.get_label_names()
        if label_names is not None and len(label_names) != num_classes:
            raise ValueError(
                f"The length of label_names ({len(label_names)}) does not match the number of classes ({num_classes}). Update label_names or set label_names to None."
            )