Source code for rai_toolbox.mushin.lightning.callbacks
# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
from collections import defaultdict
from pathlib import Path
from typing import Union
import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from rai_toolbox._utils import value_check
class MetricsCallback(Callback):
"""Saves validation and test metrics stored in `trainer.callback_metrics`.
Parameters
----------
save_dir : str, optional (default=".")
filename : str, optional (default="metrics.pt")
The base filename used to store metrics. For `FITTING` the file is prepended
with "fit_" and for `TESTING` the file is prepended with `test_`.
Notes
-----
No metrics will be saved during `FITTING` if no validation metrics are calculated.
This is a limitation of PyTorch Lightning. Future versions will save the training
step metrics when no validation metrics are calculated.
Examples
--------
>>> from pytorch_lightning import Trainer
>>> from rai_toolbox.mushin import MetricsCallback
>>> metrics_callback = MetricsCallback()
>>> trainer = Trainer(callbacks=[metrics_callback])
"""
def __init__(
self,
save_dir: Union[Path, str] = ".",
filename: Union[Path, str] = "metrics.pt",
):
super().__init__()
self.save_dir = Path(save_dir)
self.filename = value_check("filename", filename, type_=(str, Path))
        # Metrics accumulated across the run, keyed by metric name
        self.train_metrics = defaultdict(list)
        self.val_metrics = defaultdict(list)
        self.test_metrics = defaultdict(list)

    def _get_filename(self, stage: str) -> Path:
        # e.g. "fit_metrics.pt" or "test_metrics.pt" under `save_dir`
        return self.save_dir / f"{stage}_{self.filename}"

    def _process_metrics(self, stored_metrics, metrics):
        # Convert tensor metrics to Python scalars (0-d tensors) or CPU NumPy
        # arrays before appending them to the stored-metrics lists
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                if v.ndim == 0:
                    v = v.item()
                else:
                    v = v.cpu().numpy()
            stored_metrics[k].append(v)

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
        # Make sure PL is not doing its sanity-check run
        if trainer.sanity_checking:
            return self.val_metrics

        metrics = trainer.callback_metrics
        self.val_metrics["epoch"].append(pl_module.current_epoch)
        self._process_metrics(self.val_metrics, metrics)
        torch.save(self.val_metrics, self._get_filename("fit"))
        return self.val_metrics

    def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
        metrics = trainer.callback_metrics
        self._process_metrics(self.test_metrics, metrics)
        torch.save(self.test_metrics, self._get_filename("test"))
        return self.test_metrics
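
# A minimal usage sketch, assuming a LightningModule `MyModule` and a
# datamodule `dm` (both illustrative names, not defined in this module):
# after fitting/testing with the callback attached, the saved metrics can be
# reloaded with `torch.load`.
#
#     import torch
#     from pytorch_lightning import Trainer
#     from rai_toolbox.mushin import MetricsCallback
#
#     cb = MetricsCallback(save_dir="outputs")
#     trainer = Trainer(callbacks=[cb], max_epochs=2)
#     trainer.fit(MyModule(), datamodule=dm)   # writes outputs/fit_metrics.pt
#     trainer.test(MyModule(), datamodule=dm)  # writes outputs/test_metrics.pt
#
#     fit_metrics = torch.load("outputs/fit_metrics.pt")
#     fit_metrics["epoch"]  # epoch index recorded at each validation end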