rai_toolbox.mushin.HydraDDP#
- class rai_toolbox.mushin.HydraDDP(accelerator=None, parallel_devices=None, cluster_environment=None, checkpoint_io=None, precision_plugin=None, ddp_comm_state=None, ddp_comm_hook=None, ddp_comm_wrapper=None, model_averaging_period=None, process_group_backend=None, timeout=datetime.timedelta(seconds=1800), start_method='popen', **kwargs)[source]#
DDP Strategy that supports Hydra run and multirun jobs.
This strategy assumes a PyTorch Lightning Trainer.fit or Trainer.test has been configured to execute via Hydra. It requires that Hydra saves a config.yaml in the current working directory with the following keys/properties set:

├── Config
│   ├── trainer: A `pytorch_lightning.Trainer` configuration
│   ├── module: A `pytorch_lightning.LightningModule` configuration
│   ├── datamodule: [OPTIONAL] A `pytorch_lightning.LightningDataModule` configuration
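As an illustration, a saved config.yaml satisfying these requirements might look like the following sketch. The `_target_` paths under `module` and `datamodule` (`my_project.MyLightningModule`, `my_project.MyDataModule`) are placeholders for your own classes, not part of the toolbox:

```yaml
trainer:
  _target_: pytorch_lightning.Trainer
  accelerator: auto
  gpus: 2
  strategy:
    _target_: rai_toolbox.mushin.HydraDDP
module:
  _target_: my_project.MyLightningModule   # placeholder: your LightningModule
datamodule:                                # optional
  _target_: my_project.MyDataModule        # placeholder: your LightningDataModule
```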
This strategy will launch a child subprocess for each GPU beyond the first, using the following base command:
python -m rai_toolbox.mushin.lightning._pl_main -cp <path to config.yaml> -cn config.yaml
Examples
First define a Hydra configuration using hydra-zen:
>>> import pytorch_lightning as pl
>>> from hydra_zen import builds, make_config
>>> from rai_toolbox.mushin import HydraDDP
>>> from rai_toolbox.mushin.testing.lightning import SimpleLightningModule

>>> TrainerConfig = builds(
...     pl.Trainer,
...     accelerator="auto",
...     gpus=2,
...     max_epochs=1,
...     fast_dev_run=True,
...     strategy=builds(HydraDDP),
...     populate_full_signature=True,
... )

>>> ModuleConfig = builds(SimpleLightningModule)

>>> Config = make_config(
...     trainer=TrainerConfig,
...     module=ModuleConfig,
... )
Next, define a task function to execute the Hydra job:
>>> from hydra_zen import instantiate
>>> def task_function(cfg):
...     obj = instantiate(cfg)
...     obj.trainer.fit(obj.module)
Launch the Hydra+Lightning DDP job:

>>> from hydra_zen import launch
>>> job = launch(Config, task_function)
HydraDDP also supports LightningDataModule configuration.

>>> DataModuleConfig = ...  # A LightningDataModule config
>>> Config = make_config(
...     trainer=TrainerConfig,
...     module=ModuleConfig,
...     datamodule=DataModuleConfig,
... )
Next, define a task function to execute the Hydra job:
>>> from hydra_zen import instantiate
>>> def task_function(cfg):
...     obj = instantiate(cfg)
...     obj.trainer.fit(obj.module, datamodule=obj.datamodule)
Launch the Hydra+Lightning DDP job:
>>> from hydra_zen import launch
>>> job = launch(Config, task_function)
- __init__(accelerator=None, parallel_devices=None, cluster_environment=None, checkpoint_io=None, precision_plugin=None, ddp_comm_state=None, ddp_comm_hook=None, ddp_comm_wrapper=None, model_averaging_period=None, process_group_backend=None, timeout=datetime.timedelta(seconds=1800), start_method='popen', **kwargs)#
Methods
__init__([accelerator, parallel_devices, ...])