rai_toolbox.mushin.HydraDDP

class rai_toolbox.mushin.HydraDDP(accelerator=None, parallel_devices=None, cluster_environment=None, checkpoint_io=None, precision_plugin=None, ddp_comm_state=None, ddp_comm_hook=None, ddp_comm_wrapper=None, model_averaging_period=None, process_group_backend=None, timeout=datetime.timedelta(seconds=1800), start_method='popen', **kwargs)

DDP Strategy that supports Hydra run and multirun jobs.

This strategy assumes that a PyTorch Lightning `Trainer.fit` or `Trainer.test` call has been configured to execute via Hydra. It requires that Hydra save a `config.yaml` in the current working directory with the following keys set:

├── Config
│    ├── trainer: A `pytorch_lightning.Trainer` configuration
│    ├── module: A `pytorch_lightning.LightningModule` configuration
│    ├── datamodule: [OPTIONAL] A `pytorch_lightning.LightningDataModule` configuration

This strategy launches a child subprocess for each GPU beyond the first, using the following base command:

python -m rai_toolbox.mushin.lightning._pl_main -cp <path to config.yaml> -cn config.yaml

Examples

First define a Hydra configuration using hydra-zen:

>>> import pytorch_lightning as pl
... from hydra_zen import builds, make_config
... from rai_toolbox.mushin import HydraDDP
... from rai_toolbox.mushin.testing.lightning import SimpleLightningModule
...
... TrainerConfig = builds(
...     pl.Trainer,
...     accelerator="auto",
...     gpus=2,
...     max_epochs=1,
...     fast_dev_run=True,
...     strategy=builds(HydraDDP),
...     populate_full_signature=True
... )
...
... ModuleConfig = builds(SimpleLightningModule)
...
... Config = make_config(
...     trainer=TrainerConfig,
...     module=ModuleConfig
... )
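
As an optional sanity check (not part of the required workflow), hydra-zen's `to_yaml` will render the config so you can confirm it exposes the `trainer` and `module` keys that HydraDDP expects to find in the saved `config.yaml`:

>>> from hydra_zen import to_yaml
>>> print(to_yaml(Config))  # renders the `trainer` and `module` keys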

Next, define a task function to execute the Hydra job:

>>> from hydra_zen import instantiate
>>> def task_function(cfg):
...     obj = instantiate(cfg)
...     obj.trainer.fit(obj.module)

Launch the Hydra+Lightning DDP job:

>>> from hydra_zen import launch
>>> job = launch(Config, task_function)
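
Because HydraDDP supports multirun jobs as well, the same task function can be swept over Hydra overrides. A minimal sketch (the `trainer.max_epochs` override values are illustrative):

>>> jobs = launch(
...     Config,
...     task_function,
...     overrides=["trainer.max_epochs=1,2"],
...     multirun=True,
... )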

HydraDDP also supports a `LightningDataModule` configuration:

>>> DataModuleConfig = ... # A LightningDataModule config
>>> Config = make_config(
...     trainer=TrainerConfig,
...     module=ModuleConfig,
...     datamodule=DataModuleConfig
... )

Next, define a task function to execute the Hydra job:

>>> from hydra_zen import instantiate
>>> def task_function(cfg):
...     obj = instantiate(cfg)
...     obj.trainer.fit(obj.module, datamodule=obj.datamodule)

Launch the Hydra+Lightning DDP job:

>>> from hydra_zen import launch
>>> job = launch(Config, task_function)
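
The same pattern applies to evaluation: as noted above, HydraDDP also supports `Trainer.test`. A minimal sketch of such a task function:

>>> def test_task_function(cfg):
...     obj = instantiate(cfg)
...     obj.trainer.test(obj.module, datamodule=obj.datamodule)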

__init__(accelerator=None, parallel_devices=None, cluster_environment=None, checkpoint_io=None, precision_plugin=None, ddp_comm_state=None, ddp_comm_hook=None, ddp_comm_wrapper=None, model_averaging_period=None, process_group_backend=None, timeout=datetime.timedelta(seconds=1800), start_method='popen', **kwargs)

Methods

__init__([accelerator, parallel_devices, ...])