# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
import os
import subprocess
import sys
from pathlib import Path
from time import sleep
from typing import Any, Callable, TypeVar
import numpy as np
from hydra.core.hydra_config import HydraConfig
from hydra_zen import load_from_yaml
from omegaconf.errors import ConfigAttributeError
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerFn
from torch import distributed
from .._compatibility import PL_VERSION, Version
R = TypeVar("R")
def _setup_environment() -> None:
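"""Destroys any lingering ``torch.distributed`` process group so that a new DDP run (e.g. the next job in a Hydra multirun) can initialize cleanly."""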
if distributed.is_initialized():
distributed.destroy_process_group()
def _teardown() -> None:
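"""Clears the DDP-related environment variables set by Lightning during a previous run."""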
# Remove PL environment variables so the next multirun job starts fresh
envs = (
"LOCAL_RANK",
"NODE_RANK",
"WORLD_SIZE",
"MASTER_ADDR",
"MASTER_PORT",
"PL_GLOBAL_SEED",
)
for name in envs:
os.environ.pop(name, None)
def _subprocess_call(local_rank: int, testing: bool, predicting: bool) -> None:
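"""Launches a non-blocking child process for ``local_rank`` that re-runs the Hydra-saved ``config.yaml`` through ``rai_toolbox.mushin.lightning._pl_main``."""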
env_copy = os.environ.copy()
env_copy["LOCAL_RANK"] = f"{local_rank}"
# CWD is the Hydra working directory
cwd = os.getcwd()
os_cwd = (
f'"{cwd}"' # this is needed to handle characters like `=` in the directory name
)
command = [
sys.executable,
"-m",
"rai_toolbox.mushin.lightning._pl_main",
]
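# Locate the config.yaml that Hydra saved for this job; the child process
# re-loads it rather than re-parsing the original command line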
hydra_cfg = HydraConfig.get()
hydra_output = (
os.path.join(cwd, hydra_cfg.output_subdir)
if hydra_cfg.output_subdir is not None
else cwd
)
# Validate that the minimal configuration requirements are met
config = Path(hydra_output) / "config.yaml"
assert config.exists()
cfg = load_from_yaml(config)
if "trainer" not in cfg or "module" not in cfg:
raise ConfigAttributeError(
"Missing configurations `trainer` and `module` are required for use with HydraDDP. See documentation for further details."
)
# Point the CLI command at the Hydra-saved config
command += ["-cp", hydra_output, "-cn", "config.yaml"]
# Set flag to run Trainer.fit or Trainer.test in `_pl_main.py`
command += ["++pl_testing=" + ("false" if not testing else "true")]
# Set flag to run Trainer.fit or Trainer.predict in `_pl_main.py`
command += ["++pl_predicting=" + ("false" if not predicting else "true")]
# Set flag for local rank
command += [f"++pl_local_rank={local_rank}"]
command += [
f"hydra.run.dir={os_cwd}",
f"hydra.output_subdir=.pl_hydra_rank_{local_rank}",
f"hydra.job.name={hydra_cfg.job.name}",
]
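# Launch the child rank without blocking; it inherits MASTER_ADDR/MASTER_PORT
# and the other DDP variables from ``env_copy`` and joins rank 0's process group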
subprocess.Popen(command, env=env_copy, cwd=cwd)
if PL_VERSION >= Version(1, 6, 0):
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.launchers.subprocess_script import (
_SubprocessScriptLauncher,
)
class HydraDDP(DDPStrategy): # type: ignore
"""DDP Strategy that supports Hydra run and multirun jobs.
This strategy assumes a PyTorch Lightning `Trainer.fit` or `Trainer.test` has been configured
to execute via Hydra. It requires that Hydra saves a `config.yaml` in the current working directory with the following keys/properties set::
├── Config
│ ├── trainer: A `pytorch_lightning.Trainer` configuration
│ ├── module: A `pytorch_lightning.LightningModule` configuration
│ ├── datamodule: [OPTIONAL] A `pytorch_lightning.LightningDataModule` configuration
This strategy launches a child subprocess for each additional GPU beyond the first using the following base command::
python -m rai_toolbox.mushin.lightning._pl_main -cp <path to config.yaml> -cn config.yaml
Examples
--------
First define a Hydra configuration using hydra-zen:
>>> import pytorch_lightning as pl
... from hydra_zen import builds, make_config
... from rai_toolbox.mushin import HydraDDP
... from rai_toolbox.mushin.testing.lightning import SimpleLightningModule
...
... TrainerConfig = builds(
... pl.Trainer,
... accelerator="auto",
... gpus=2,
... max_epochs=1,
... fast_dev_run=True,
... strategy=builds(HydraDDP),
... populate_full_signature=True
... )
...
... ModuleConfig = builds(SimpleLightningModule)
...
... Config = make_config(
... trainer=TrainerConfig,
... module=ModuleConfig
... )
Next, define a task function to execute the Hydra job:
>>> from hydra_zen import instantiate
>>> def task_function(cfg):
... obj = instantiate(cfg)
... obj.trainer.fit(obj.module)
Launch the Hydra+Lightning DDP job:
>>> from hydra_zen import launch
>>> job = launch(Config, task_function)
``HydraDDP`` also supports ``LightningDataModule`` configuration.
>>> DataModuleConfig = ... # A LightningDataModule config
>>> Config = make_config(
... trainer=TrainerConfig,
... module=ModuleConfig,
... datamodule=DataModuleConfig
... )
Next, define a task function to execute the Hydra job:
>>> from hydra_zen import instantiate
>>> def task_function(cfg):
... obj = instantiate(cfg)
... obj.trainer.fit(obj.module, datamodule=obj.datamodule)
Launch the Hydra+Lightning DDP job:
>>> from hydra_zen import launch
>>> job = launch(Config, task_function)
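Because ``HydraDDP`` clears Lightning's DDP environment variables during teardown, the same setup also works for Hydra multiruns. A minimal sketch (assuming ``TrainerConfig`` exposes ``max_epochs``, as in the config above):
>>> jobs = launch(Config, task_function, overrides=["trainer.max_epochs=1,2"], multirun=True)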
"""
def setup_environment(self) -> None:
_setup_environment()
super().setup_environment()
def _configure_launcher(self) -> None:
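# Replace Lightning's default launcher with one that re-launches child ranks
# from the Hydra-saved config instead of re-invoking ``sys.argv``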
if self.cluster_environment is None: # pragma: no cover
raise TypeError("HydraDDP.cluster_environment is None")
if not self.cluster_environment.creates_processes_externally:
self._launcher = _HydraDDPLauncher(
self.cluster_environment, self.num_processes, self.num_nodes
)
self._rank_0_will_call_children_scripts = True
def teardown(self) -> None:
"""Performs additional teardown steps for PL to allow for Hydra multirun jobs."""
super().teardown()
_teardown()
class _HydraDDPLauncher(_SubprocessScriptLauncher):
@property
def is_interactive_compatible(self) -> bool: # pragma: no cover
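# Child ranks are re-launched from the saved Hydra config rather than from
# ``sys.argv``, so this launcher also works in interactive sessions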
return True
def launch(
self,
function: Callable[..., R],
*args: Any,
trainer: Trainer,
**kwargs: Any,
) -> R:
"""Creates new processes, then calls the given function.
Parameters
----------
function : Callable[[...], ReturnType]
A callback function to execute after all processes have been created.
It is up to the implementation of this function to synchronize the processes, e.g., with barriers.
*args : Any
Optional positional arguments to be passed to the given function.
trainer : pytorch_lightning.Trainer
Reference to the `pytorch_lightning.Trainer`; unused by this launcher.
**kwargs : Any
Optional keyword arguments to be passed to the given function.
Returns
-------
ReturnType
The value returned by ``function``.
"""
del trainer # unused
if (
not self.cluster_environment.creates_processes_externally
): # pragma: no cover
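# Infer the Trainer stage from the name of the Lightning-internal
# entry point being launched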
testing = function.__name__ == "_test_impl"
predicting = function.__name__ == "_predict_impl"
self._call_children_scripts(testing=testing, predicting=predicting)
return function(*args, **kwargs)
def _call_children_scripts(self, testing: bool, predicting: bool):
# bookkeeping of spawned processes
self._check_can_spawn_children()
# DDP Environment variables
os.environ["MASTER_ADDR"] = self.cluster_environment.main_address
os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
# allow the user to pass the node rank
os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())
os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}"
for local_rank in range(1, self.num_processes):
_subprocess_call(local_rank, testing, predicting)
# starting all processes at once can cause issues with dataloaders,
# so stagger each launch by a random 1-5 second delay
delay = np.random.uniform(1, 5, 1)[0]
sleep(delay)
else: # pragma: no cover
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # type: ignore
class HydraDDP(DDPPlugin):
"""DDP Strategy that supports Hydra run and multirun jobs.
This strategy assumes a PyTorch Lightning `Trainer.fit` or `Trainer.test` has been configured
to execute via Hydra. It requires that Hydra saves a `config.yaml` in the current working directory with the following keys/properties set::
├── Config
│ ├── trainer: A `pytorch_lightning.Trainer` configuration
│ ├── module: A `pytorch_lightning.LightningModule` configuration
│ ├── datamodule: [OPTIONAL] A `pytorch_lightning.LightningDataModule` configuration
This strategy launches a child subprocess for each additional GPU beyond the first using the following base command::
python -m rai_toolbox.mushin.lightning._pl_main -cp <path to config.yaml> -cn config.yaml
Examples
--------
First define a Hydra configuration using hydra-zen:
>>> import pytorch_lightning as pl
... from hydra_zen import builds, make_config
... from rai_toolbox.mushin import HydraDDP
... from rai_toolbox.mushin.testing.lightning import SimpleLightningModule
...
... TrainerConfig = builds(
... pl.Trainer,
... accelerator="auto",
... gpus=2,
... max_epochs=1,
... fast_dev_run=True,
... strategy=builds(HydraDDP),
... populate_full_signature=True
... )
...
... ModuleConfig = builds(SimpleLightningModule)
...
... Config = make_config(
... trainer=TrainerConfig,
... module=ModuleConfig
... )
Next, define a task function to execute the Hydra job:
>>> from hydra_zen import instantiate
>>> def task_function(cfg):
... obj = instantiate(cfg)
... obj.trainer.fit(obj.module)
Launch the Hydra+Lightning DDP job:
>>> from hydra_zen import launch
>>> job = launch(Config, task_function)
``HydraDDP`` also supports ``LightningDataModule`` configuration.
>>> DataModuleConfig = ... # A LightningDataModule config
>>> Config = make_config(
... trainer=TrainerConfig,
... module=ModuleConfig,
... datamodule=DataModuleConfig
... )
Next, define a task function to execute the Hydra job:
>>> from hydra_zen import instantiate
>>> def task_function(cfg):
... obj = instantiate(cfg)
... obj.trainer.fit(obj.module, datamodule=obj.datamodule)
Launch the Hydra+Lightning DDP job:
>>> from hydra_zen import launch
>>> job = launch(Config, task_function)
"""
def setup_environment(self) -> None:
_setup_environment()
super().setup_environment()
def _call_children_scripts(self):
if self.lightning_module is None: # pragma: no cover
raise TypeError("HydraDDP.lightning_module is None")
if self.lightning_module.trainer is None: # pragma: no cover
raise TypeError("HydraDDP.lightning_module.trainer is None")
if self.cluster_environment is None: # pragma: no cover
raise TypeError("HydraDDP.cluster_environment is None")
# bookkeeping of spawned processes
self._check_can_spawn_children()
# DDP Environment variables
os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
# allow the user to pass the node rank
os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())
os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}"
self.interactive_ddp_procs = []
for local_rank in range(1, self.num_processes):
testing = self.lightning_module.trainer.state.fn == TrainerFn.TESTING
predicting = (
self.lightning_module.trainer.state.fn == TrainerFn.PREDICTING
)
_subprocess_call(local_rank, testing=testing, predicting=predicting)
# starting all processes at once can cause issues with dataloaders,
# so stagger each launch by a random 1-5 second delay
delay = np.random.uniform(1, 5, 1)[0]
sleep(delay)
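# Inform PL that rank 0 has already launched the child processes so it does
# not attempt to spawn them again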
self._rank_0_has_called_call_children_scripts = True
def teardown(self) -> None:
"""Performs additional teardown steps for PL to allow for Hydra multirun jobs."""
super().teardown()
_teardown()