Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/equine/equine.py: 100%

38 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-29 04:12 +0000

1# Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY 

2# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014). 

3# SPDX-License-Identifier: MIT 

4 

5from typing import Any, Optional, TypeVar 

6 

7import icontract 

8import torch 

9from abc import ABC, abstractmethod 

10from collections import OrderedDict 

11from torch.utils.data import TensorDataset 

12 

13from .equine_output import EquineOutput 

14 

# Type variable bound to Equine, so that class-level constructors such as
# `Equine.load` can be annotated as returning the subclass they are called on.
AnyEquine = TypeVar("AnyEquine", bound="Equine")

17 

18 

class Equine(torch.nn.Module, ABC):
    """Abstract base class (ABC) for EQUINE models.

    EQUINE extends torch's ``nn.Module`` with an interface for uncertainty
    quantification and visualization. Most importantly, the ``.predict()``
    method must be implemented to return an EquineOutput object that contains
    both the class probabilities *and* an out-of-distribution (ood) score.

    Parameters
    ----------
    embedding_model : torch.nn.Module
        The embedding model to use.
    head_layers : int, optional
        The number of layers to use in the model head, by default 1.
    device : str, optional
        The device to train the equine model on (defaults to cpu).
    feature_names : list[str], optional
        List of strings of the names of the tabular features
        (ex ["duration", "fiat_mean", ...])
    label_names : list[str], optional
        List of strings of the names of the labels
        (ex ["streaming", "voip", ...])

    Attributes
    ----------
    device : str
        The device to train the equine model on (defaults to cpu).
    embedding_model : torch.nn.Module
        The neural embedding model to enrich with uncertainty quantification.
    feature_names : list[str], optional
        List of strings of the names of the tabular features
        (ex ["duration", "fiat_mean", ...])
    head_layers : int
        The number of linear layers to append to the embedding model
        (default 1, not always used).
    label_names : list[str], optional
        List of strings of the names of the labels
        (ex ["streaming", "voip", ...])
    train_summary : dict[str, Any]
        A dictionary containing information about the model training.
        Initialized with the keys "numTrainExamples", "dateTrained" and
        "modelType".
    support : OrderedDict[int, torch.Tensor]
        Initialized empty; expected to be filled in by concrete subclasses.
    support_embeddings : OrderedDict[int, torch.Tensor]
        Initialized empty; expected to be filled in by concrete subclasses.
    prototypes : torch.Tensor
        Initialized as an empty tensor; expected to be filled in by concrete
        subclasses.

    Raises
    ------
    TypeError
        If a subclass that does not override every abstract method is
        instantiated (standard ``abc.ABC`` behavior).
    NotImplementedError
        If one of the base-class method bodies is invoked directly, e.g. via
        ``super()`` or by calling ``load`` on a class that does not override it.
    """

    def __init__(
        self,
        embedding_model: torch.nn.Module,
        head_layers: int = 1,
        device: str = "cpu",
        feature_names: Optional[list[str]] = None,
        label_names: Optional[list[str]] = None,
    ) -> None:
        super().__init__()
        self.embedding_model = embedding_model
        self.head_layers = head_layers
        self.train_summary: dict[str, Any] = {
            "numTrainExamples": 0,
            "dateTrained": "",
            "modelType": "",
        }
        self.device = device
        # Move this module, and explicitly the embedding model, to the
        # requested device.
        self.to(device)
        self.embedding_model.to(device)
        self.feature_names = feature_names
        self.label_names = label_names

        # Containers that start out empty.
        self.support: OrderedDict[int, torch.Tensor] = OrderedDict()
        self.support_embeddings: OrderedDict[int, torch.Tensor] = OrderedDict()
        self.prototypes: torch.Tensor = torch.Tensor()

    @abstractmethod
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model. This is to preserve the usual behavior
        of torch.nn.Module.

        Parameters
        ----------
        X : torch.Tensor
            The input data.

        Returns
        -------
        torch.Tensor
            The output of the model.
        """
        raise NotImplementedError

    @abstractmethod
    def predict(self, X: torch.Tensor) -> EquineOutput:
        """
        Upon implementation, predicts the class probabilities and
        out-of-distribution (ood) scores for the given input data.

        Parameters
        ----------
        X : torch.Tensor
            The input data.

        Returns
        -------
        EquineOutput
            An EquineOutput object containing the class probabilities and OOD scores.
        """
        raise NotImplementedError

    @abstractmethod
    def train_model(
        self, dataset: TensorDataset, *args: Any, **kwargs: Any
    ) -> dict[str, Any]:
        """
        Upon implementation, train the model on the given dataset.

        Parameters
        ----------
        dataset : TensorDataset
            TensorDataset containing the training data.
        *args
            Additional positional arguments to pass to the training function.
        **kwargs
            Additional keyword arguments to pass to the training function.

        Returns
        -------
        dict[str, Any]
            Dictionary containing summary training information and any other data
            Note that at least one key should be 'train_summary'
        """
        raise NotImplementedError

    @abstractmethod
    def get_prototypes(self) -> torch.Tensor:
        """
        Upon implementation, returns the prototype embeddings

        Returns
        -------
        torch.Tensor
            A torch tensor of the prototype embeddings
        """
        raise NotImplementedError

    @abstractmethod
    def save(self, path: str) -> None:
        """
        Upon implementation, save the model to the given file path.

        Parameters
        ----------
        path : str
            File path to save the model to.
        """
        raise NotImplementedError

    @classmethod
    def load(cls: type[AnyEquine], path: str) -> AnyEquine:
        """
        Upon implementation, load the model from the given file path.

        This method is not marked abstract, so a subclass that does not
        override it can still be instantiated; calling it then raises
        NotImplementedError.

        Parameters
        ----------
        path : str
            File path to load the model from.

        Returns
        -------
        Equine
            Loaded model object.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError

    def get_label_names(self) -> Optional[list[str]]:
        """
        Retrieve the label names used in the model.

        Returns
        -------
        Optional[list[str]]
            A list of label names if available; otherwise, None.
        """
        # Tolerate instances on which the attribute was never set.
        return getattr(self, "label_names", None)

    def get_feature_names(self) -> Optional[list[str]]:
        """
        Retrieve the feature names used in the model.

        Returns
        -------
        Optional[list[str]]
            A list of feature names if available; otherwise, None.
        """
        # Tolerate instances on which the attribute was never set.
        return getattr(self, "feature_names", None)

    @icontract.require(
        lambda num_features, num_classes: num_features > 0 and num_classes > 0
    )
    def validate_feature_label_names(self, num_features: int, num_classes: int) -> None:
        """
        Validate that the feature names and label names, if provided, match the expected counts.

        Names that are None are treated as "not provided" and are not checked.
        Both counts must be strictly positive; this precondition is enforced
        by the icontract decorator.

        Parameters
        ----------
        num_features : int
            The expected number of features.
        num_classes : int
            The expected number of classes.

        Raises
        ------
        ValueError
            If the length of `feature_names` does not match `num_features`, or
            if the length of `label_names` does not match `num_classes`.
        """
        feature_names = self.get_feature_names()
        if feature_names is not None and len(feature_names) != num_features:
            raise ValueError(
                f"The length of feature_names ({len(feature_names)}) does not match the number of data features ({num_features}). Update feature_names or set feature_names to None."
            )

        label_names = self.get_label_names()
        if label_names is not None and len(label_names) != num_classes:
            raise ValueError(
                f"The length of label_names ({len(label_names)}) does not match the number of classes ({num_classes}). Update label_names or set label_names to None."
            )