Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/equine/load_equine_model.py: 100%

11 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-29 04:12 +0000

1# Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY 

2# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014). 

3# SPDX-License-Identifier: MIT 

4 

5import torch 

6 

7from .equine import Equine 

8from .equine_gp import EquineGP 

9from .equine_protonet import EquineProtonet 

10 

11 

def load_equine_model(model_path: str) -> Equine:
    """
    Attempt to load an EQUINE model from a file

    The checkpoint is first read only to look up
    ``checkpoint["train_summary"]["modelType"]``. The ``load`` method of the
    matching EQUINE class is then called with the same path to build the model.

    Parameters
    ----------
    model_path : str
        The path to the model file

    Returns
    -------
    Equine
        The loaded EQUINE model

    Raises
    ------
    ValueError
        If the model type is unknown
    KeyError
        If the checkpoint has no ``train_summary`` / ``modelType`` entry

    Notes
    -----
    The checkpoint is read with ``weights_only=False``, which unpickles
    arbitrary Python objects. Only load files from trusted sources.
    """
    # Only the model-type string is needed here, so map every tensor to CPU.
    # This avoids allocating GPU memory for a throwaway read, and it lets this
    # metadata lookup succeed for a GPU-saved checkpoint on a CPU-only machine.
    # The result is indexed in a single expression, so the full checkpoint is
    # not kept alive while the model class loads the file below.
    # NOTE(review): weights_only=False runs pickle code from the file; do not
    # call this on untrusted input.
    model_type = torch.load(model_path, map_location="cpu", weights_only=False)[
        "train_summary"
    ]["modelType"]

    if model_type == "EquineProtonet":
        return EquineProtonet.load(model_path)
    if model_type == "EquineGP":
        return EquineGP.load(model_path)
    raise ValueError(f"Unknown model type '{model_type}'")