Mushin

Important

The features provided by rai_toolbox.mushin are in an early-beta phase and may undergo compatibility-breaking changes.

Note

rai_toolbox.mushin requires additional dependencies, which can be installed via:

$ pip install rai-toolbox[mushin]

Mushin means, roughly, “no-mind”; rai_toolbox.mushin provides utilities that greatly reduce the boilerplate code and overall complexity of conducting machine learning experiments, tests, and analyses. Unlike the rest of the toolbox, which adheres solely to essential PyTorch APIs, mushin is intentionally opinionated and specialized in its design: it reflects the shared workflows, best practices, and favorite tools that we (the toolbox development team) use for research and development.

As such, mushin is designed around PyTorch-Lightning, which facilitates boilerplate-free and performant machine learning work, and around hydra-zen, which makes it easy to design configurable and reproducible workflows that leverage the Hydra framework.

Workflows

Workflows simplify and automate configuring, running, and reproducing data science and machine learning experiments. In part, they streamline organizing and running jobs with the Hydra framework and aggregating the results of those jobs for analysis.
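The pattern these workflows automate (run a task once per combination of configuration overrides, as a Hydra multirun sweep does, then gather every job's results in one place) can be sketched in plain Python. This is a minimal illustration under our own assumptions: `multirun` and `toy_task` are hypothetical names, no Hydra machinery is involved, and this is not how mushin's workflows are implemented.

```python
import itertools
from typing import Any, Callable, Dict, List


def multirun(task: Callable[..., Dict[str, float]], **sweeps: List[Any]) -> List[Dict[str, Any]]:
    """Hypothetical helper: run `task` once per combination of swept values.

    Loosely mimics a Hydra multirun sweep: each combination of overrides is
    one "job", and each job's returned metrics are recorded alongside the
    overrides that produced them.
    """
    names = list(sweeps)
    records = []
    for values in itertools.product(*(sweeps[n] for n in names)):
        overrides = dict(zip(names, values))
        metrics = task(**overrides)
        # Aggregate: one flat record per job (overrides + metrics).
        records.append({**overrides, **metrics})
    return records


# Hypothetical stand-in task; a real job would train or evaluate a model.
def toy_task(lr: float, batch_size: int) -> Dict[str, float]:
    return {"score": lr * batch_size}


results = multirun(toy_task, lr=[0.1, 0.01], batch_size=[16, 32])
for row in results:
    print(row)
```

Each record pairs a job's overrides with its metrics, which keeps the aggregated results easy to load into a table for analysis.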

BaseWorkflow([eval_task_cfg])

Provides an interface for creating a reusable workflow: encapsulated "boilerplate" for running, aggregating, and analyzing one or more Hydra jobs.

MultiRunMetricsWorkflow([eval_task_cfg, ...])

Abstract class for workflows that record metrics using Hydra multirun.

RobustnessCurve([eval_task_cfg, working_dir])

Abstract class for workflows that measure performance for different perturbation values.
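A robustness curve records a performance metric at each of several perturbation magnitudes. The toy sketch below illustrates the idea with a hypothetical one-dimensional threshold classifier, for which the worst-case perturbation of size `epsilon` simply pushes each point toward the decision boundary. All names here (`predict`, `worst_case_accuracy`, `robustness_curve`) are illustrative assumptions, not part of RobustnessCurve's API; in a mushin workflow, each perturbation value would presumably correspond to a job in a Hydra multirun sweep that evaluates a trained model.

```python
from typing import Dict, Sequence


def predict(x: float) -> int:
    # Toy 1-D classifier: the decision boundary sits at x = 0.
    return 1 if x > 0 else 0


def worst_case_accuracy(xs: Sequence[float], ys: Sequence[int], epsilon: float) -> float:
    """Accuracy when each input may be shifted by at most `epsilon`.

    For this threshold classifier, the worst-case shift moves each point
    `epsilon` toward (and possibly across) the boundary.
    """
    correct = 0
    for x, y in zip(xs, ys):
        # Move the point toward the boundary, i.e., against its own label.
        x_adv = x - epsilon if y == 1 else x + epsilon
        correct += int(predict(x_adv) == y)
    return correct / len(xs)


def robustness_curve(
    xs: Sequence[float], ys: Sequence[int], epsilons: Sequence[float]
) -> Dict[float, float]:
    # One evaluation per perturbation magnitude, aggregated into one mapping.
    return {eps: worst_case_accuracy(xs, ys, eps) for eps in epsilons}


xs = [-2.0, -0.5, 0.3, 1.0, 1.5]
ys = [0, 0, 1, 1, 1]
curve = robustness_curve(xs, ys, [0.0, 0.4, 1.2, 2.5])
print(curve)
```

Accuracy falls from 1.0 at `epsilon = 0` to 0.0 once `epsilon` exceeds every point's distance from the boundary.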

PyTorch-Lightning Utilities

Tools and utilities that make PyTorch-Lightning easy to use for our work. Some of these utilities also enable much-needed compatibility between PyTorch-Lightning and Hydra.

MetricsCallback([save_dir, filename])

Saves validation and test metrics stored in trainer.callback_metrics.

HydraDDP([accelerator, parallel_devices, ...])

A distributed data parallel (DDP) strategy that supports Hydra run and multirun jobs.
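In PyTorch-Lightning, `trainer.callback_metrics` is a dictionary mapping logged metric names to their most recent values. The sketch below illustrates the general idea behind a metrics-saving callback: snapshot such a mapping after each evaluation and write the accumulated history to `<save_dir>/<filename>`. `MetricsRecorder` is a hypothetical, framework-free stand-in; it is not MetricsCallback's implementation, and the JSON output format is our own assumption.

```python
import json
import os
import tempfile
from typing import Dict, List


class MetricsRecorder:
    """Hypothetical stand-in for a metrics-saving callback.

    Each call to `record` snapshots a metrics mapping (analogous to
    `trainer.callback_metrics` at the end of a validation epoch); `save`
    writes the full history to `<save_dir>/<filename>` as JSON.
    """

    def __init__(self, save_dir: str = ".", filename: str = "metrics.json") -> None:
        self.save_dir = save_dir
        self.filename = filename
        self.history: Dict[str, List[float]] = {}

    def record(self, metrics: Dict[str, float]) -> None:
        for name, value in metrics.items():
            # Convert to a plain float so the history is JSON-serializable
            # (Lightning stores tensors; `float()` also works on 0-d tensors).
            self.history.setdefault(name, []).append(float(value))

    def save(self) -> str:
        os.makedirs(self.save_dir, exist_ok=True)
        path = os.path.join(self.save_dir, self.filename)
        with open(path, "w") as f:
            json.dump(self.history, f)
        return path


recorder = MetricsRecorder(save_dir=tempfile.mkdtemp())
recorder.record({"val_loss": 0.9, "val_acc": 0.60})
recorder.record({"val_loss": 0.7, "val_acc": 0.72})
path = recorder.save()
with open(path) as f:
    print(json.load(f))
```

Storing one list per metric keeps the per-epoch history aligned, so each index corresponds to one recorded evaluation.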