Skip to content

load_equine_model

load_equine_model(model_path) ¤

Attempt to load an EQUINE model from a file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_path` | `str` | The path to the model file | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Equine` | The loaded EQUINE model |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the model type is unknown |

Source code in `src/equine/load_equine_model.py` (lines 12–39):
def load_equine_model(model_path: str) -> Equine:
    """
    Attempt to load an EQUINE model from a file

    The file is read once to find the model type stored under
    ``["train_summary"]["modelType"]``. The model is then constructed by the
    matching class's ``load`` method (``EquineProtonet`` or ``EquineGP``).

    Parameters
    ----------
    model_path : str
        The path to the model file

    Returns
    -------
    Equine
        The loaded EQUINE model

    Raises
    ------
    ValueError
        If the model type is unknown
    KeyError
        If the file has no ``["train_summary"]["modelType"]`` entry
    """
    # Only the type string is needed here, so map every tensor to the CPU.
    # Without map_location, tensors are restored to the device they were
    # saved on. A GPU-saved checkpoint would then fail to load on a CPU-only
    # machine, and device memory would be spent on data discarded just below.
    # NOTE(review): torch.load unpickles the file; only load trusted files.
    checkpoint = torch.load(model_path, map_location="cpu")
    model_type = checkpoint["train_summary"]["modelType"]
    # Drop the peeked contents before the class-specific load reads the file again.
    del checkpoint

    if model_type == "EquineProtonet":
        return EquineProtonet.load(model_path)
    if model_type == "EquineGP":
        return EquineGP.load(model_path)
    raise ValueError(f"Unknown model type '{model_type}'")