Coverage for /opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/site-packages/equine/load_equine_model.py: 100%
11 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-29 04:12 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-29 04:12 +0000
1# Copyright 2024, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
2# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
3# SPDX-License-Identifier: MIT
5import torch
7from .equine import Equine
8from .equine_gp import EquineGP
9from .equine_protonet import EquineProtonet
def load_equine_model(model_path: str) -> Equine:
    """
    Attempt to load an EQUINE model from a file

    The file is first read to look up ``["train_summary"]["modelType"]``,
    and the path is then passed to the ``load`` method of the matching
    model class.

    Parameters
    ----------
    model_path : str
        The path to the model file

    Returns
    -------
    Equine
        The loaded EQUINE model

    Raises
    ------
    ValueError
        If the model type is unknown

    Notes
    -----
    The file is read with ``torch.load(..., weights_only=False)``, which
    unpickles arbitrary Python objects and can therefore execute code
    embedded in the file. Only load model files from trusted sources.
    """
    # NOTE(security): weights_only=False performs a full unpickle; do not
    # pass untrusted files to this function.
    # Only the model-type string is needed from this first read, so map all
    # tensors to CPU: this avoids allocating device (e.g. GPU) memory for a
    # throwaway copy and still works when the device the model was saved on
    # is not available on this machine.
    model_type = torch.load(model_path, map_location="cpu", weights_only=False)[
        "train_summary"
    ]["modelType"]

    # Dispatch to the class that knows how to reconstruct this model type.
    if model_type == "EquineProtonet":
        return EquineProtonet.load(model_path)
    if model_type == "EquineGP":
        return EquineGP.load(model_path)
    raise ValueError(f"Unknown model type '{model_type}'")