# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
import os
import subprocess
import sys
from pathlib import Path
from time import sleep
from typing import Any, Callable, TypeVar
import numpy as np
from hydra.core.hydra_config import HydraConfig
from hydra_zen import load_from_yaml
from omegaconf.errors import ConfigAttributeError
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerFn
from torch import distributed
from .._compatibility import PL_VERSION, Version
# Return type of the callable handed to `_HydraDDPLauncher.launch`; `launch`
# returns whatever that callable returns.
R = TypeVar("R")
def _setup_environment() -> None:
    if distributed.is_initialized():
        distributed.destroy_process_group()
def _teardown() -> None:
    # Remove PL environments so next multirun starts fresh
    envs = (
        "LOCAL_RANK",
        "NODE_RANK",
        "WORLD_SIZE",
        "MASTER_ADDR",
        "MASTER_PORT",
        "PL_GLOBAL_SEED",
    )
    for name in envs:
        os.environ.pop(name, None)
def _subprocess_call(local_rank: int, testing: bool, predicting: bool) -> None:
    """Launch a child process that runs the current Hydra job for one local rank.

    The child re-executes ``rai_toolbox.mushin.lightning._pl_main`` against the
    ``config.yaml`` that Hydra saved for the current job, in the current working
    directory, with ``LOCAL_RANK`` set to ``local_rank``.

    Parameters
    ----------
    local_rank : int
        Local rank assigned to the child; exported as ``LOCAL_RANK`` and passed
        as the ``++pl_local_rank`` override.
    testing : bool
        Passed to the child as the ``++pl_testing`` override.
    predicting : bool
        Passed to the child as the ``++pl_predicting`` override.

    Raises
    ------
    FileNotFoundError
        If Hydra did not save a ``config.yaml`` for the current job.
    ConfigAttributeError
        If the saved configuration lacks a ``trainer`` or ``module`` entry.
    """
    env_copy = os.environ.copy()
    env_copy["LOCAL_RANK"] = f"{local_rank}"

    # CWD is the Hydra working directory
    cwd = os.getcwd()
    # Quote the path so characters like `=` in the directory name survive
    # Hydra's override parsing.
    os_cwd = f'"{cwd}"'

    hydra_cfg = HydraConfig.get()
    hydra_output = (
        os.path.join(cwd, hydra_cfg.output_subdir)
        if hydra_cfg.output_subdir is not None
        else cwd
    )

    # Validate the minimal configuration requirements. Raise explicitly rather
    # than `assert` so the check is not stripped under `python -O`.
    config = Path(hydra_output) / "config.yaml"
    if not config.exists():
        raise FileNotFoundError(
            f"HydraDDP requires the Hydra job configuration file {config}, "
            "but it does not exist."
        )
    cfg = load_from_yaml(config)
    if "trainer" not in cfg or "module" not in cfg:
        raise ConfigAttributeError(
            "Missing configurations `trainer` and `module` are required for use with HydraDDP.  See documentation for further details."
        )

    command = [
        sys.executable,
        "-m",
        "rai_toolbox.mushin.lightning._pl_main",
        # Load the configuration Hydra saved for this job.
        "-cp",
        hydra_output,
        "-cn",
        "config.yaml",
        # Flags telling `_pl_main.py` which Trainer entry point to run
        # (fit by default, or test / predict).
        f"++pl_testing={'true' if testing else 'false'}",
        f"++pl_predicting={'true' if predicting else 'false'}",
        # Flag for the local rank
        f"++pl_local_rank={local_rank}",
        # Run in the same directory as rank 0, but give each rank its own
        # Hydra output sub-directory.
        f"hydra.run.dir={os_cwd}",
        f"hydra.output_subdir=.pl_hydra_rank_{local_rank}",
        f"hydra.job.name={hydra_cfg.job.name}",
    ]
    subprocess.Popen(command, env=env_copy, cwd=cwd)
if PL_VERSION >= Version(1, 6, 0):
    from pytorch_lightning.strategies.ddp import DDPStrategy
    from pytorch_lightning.strategies.launchers.subprocess_script import (
        _SubprocessScriptLauncher,
    )
    class HydraDDP(DDPStrategy):  # type: ignore
        """DDP Strategy that supports Hydra run and multirun jobs.

        This strategy assumes a PyTorch Lightning ``Trainer.fit``, ``Trainer.test``,
        or ``Trainer.predict`` call has been configured to execute via Hydra. It
        requires that Hydra saves a ``config.yaml`` for the job (in
        ``<working directory>/<hydra.output_subdir>``, or in the working directory
        itself when ``hydra.output_subdir`` is null) with the following
        keys/properties set::

           ├── Config
           │    ├── trainer: A `pytorch_lightning.Trainer` configuration
           │    ├── module: A `pytorch_lightning.LightningModule` configuration
           │    ├── datamodule: [OPTIONAL] A `pytorch_lightning.LightningDataModule` configuration

        This strategy launches one child subprocess for each additional process
        (local rank) beyond the first, using the following base command::

           python -m rai_toolbox.mushin.lightning._pl_main -cp <directory containing config.yaml> -cn config.yaml

        Examples
        --------
        First, define a Hydra configuration using hydra-zen:

        >>> import pytorch_lightning as pl
        >>> from hydra_zen import builds, make_config
        >>> from rai_toolbox.mushin import HydraDDP
        >>> from rai_toolbox.mushin.testing.lightning import SimpleLightningModule
        >>> TrainerConfig = builds(
        ...     pl.Trainer,
        ...     accelerator="auto",
        ...     gpus=2,
        ...     max_epochs=1,
        ...     fast_dev_run=True,
        ...     strategy=builds(HydraDDP),
        ...     populate_full_signature=True,
        ... )
        >>> ModuleConfig = builds(SimpleLightningModule)
        >>> Config = make_config(trainer=TrainerConfig, module=ModuleConfig)

        Next, define a task function to execute the Hydra job:

        >>> from hydra_zen import instantiate
        >>> def task_function(cfg):
        ...     obj = instantiate(cfg)
        ...     obj.trainer.fit(obj.module)

        Launch the Hydra+Lightning DDP job:

        >>> from hydra_zen import launch
        >>> job = launch(Config, task_function)

        ``HydraDDP`` also supports ``LightningDataModule`` configuration:

        >>> DataModuleConfig = ...  # A LightningDataModule config
        >>> Config = make_config(
        ...     trainer=TrainerConfig,
        ...     module=ModuleConfig,
        ...     datamodule=DataModuleConfig,
        ... )

        Next, define a task function to execute the Hydra job:

        >>> from hydra_zen import instantiate
        >>> def task_function(cfg):
        ...     obj = instantiate(cfg)
        ...     obj.trainer.fit(obj.module, datamodule=obj.datamodule)

        Launch the Hydra+Lightning DDP job:

        >>> from hydra_zen import launch
        >>> job = launch(Config, task_function)
        """

        def setup_environment(self) -> None:
            """Destroy any already-initialized process group, then run the parent setup."""
            _setup_environment()
            super().setup_environment()

        def _configure_launcher(self) -> None:
            """Install ``_HydraDDPLauncher`` unless the cluster creates processes itself.

            Raises
            ------
            TypeError
                If no cluster environment has been set.
            """
            if self.cluster_environment is None:  # pragma: no cover
                raise TypeError("HydraDDP.cluster_environment is None")
            if not self.cluster_environment.creates_processes_externally:
                self._launcher = _HydraDDPLauncher(
                    self.cluster_environment, self.num_processes, self.num_nodes
                )
                self._rank_0_will_call_children_scripts = True

        def teardown(self) -> None:
            """Performs additional teardown steps for PL to allow for Hydra multirun jobs."""
            super().teardown()
            # Clear the DDP environment variables so the next multirun job starts fresh.
            _teardown()
    class _HydraDDPLauncher(_SubprocessScriptLauncher):
        """Subprocess launcher that starts the non-zero local ranks from the Hydra config.

        Each child process is started with ``_subprocess_call``, which re-executes
        ``rai_toolbox.mushin.lightning._pl_main`` against the Hydra job's saved
        ``config.yaml``.
        """

        @property
        def is_interactive_compatible(self) -> bool:  # pragma: no cover
            # Always reports compatibility with interactive environments.
            # NOTE(review): presumably safe because children are started from the
            # saved Hydra config rather than by re-running the parent script.
            return True

        def launch(
            self,
            function: Callable[..., R],
            *args: Any,
            trainer: Trainer,
            **kwargs: Any,
        ) -> R:
            """Creates new processes, then calls the given function.

            Parameters
            ----------
            function : Callable[..., ReturnType]
                A callback function to execute after all processes have been created.
                It is up to the implementation of this function to synchronize the processes, e.g., with barriers.
            *args : Any
                Optional positional arguments to be passed to the given function.
            trainer : pytorch_lightning.Trainer
                Reference to the `pytorch_lightning.Trainer` (unused).
            **kwargs : Any
                Optional keyword arguments to be passed to the given function.

            Returns
            -------
            ReturnType
                The value returned by ``function(*args, **kwargs)``.
            """
            del trainer  # unused
            if (
                not self.cluster_environment.creates_processes_externally
            ):  # pragma: no cover
                # Children must run the same Trainer entry point as this process;
                # it is inferred from the name of the Trainer method being launched.
                function_name = function.__name__
                self._call_children_scripts(
                    testing=function_name == "_test_impl",
                    predicting=function_name == "_predict_impl",
                )
            return function(*args, **kwargs)

        def _call_children_scripts(self, testing: bool, predicting: bool) -> None:
            """Export the DDP environment variables and spawn one child per extra local rank.

            Parameters
            ----------
            testing : bool
                Forwarded to each child as its ``testing`` flag.
            predicting : bool
                Forwarded to each child as its ``predicting`` flag.
            """
            # bookkeeping of spawned processes
            self._check_can_spawn_children()

            # DDP environment variables (set in `os.environ`, which the children
            # copy when they are spawned)
            os.environ["MASTER_ADDR"] = self.cluster_environment.main_address
            os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
            # allow the user to pass the node rank
            os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
            os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())
            os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}"

            for local_rank in range(1, self.num_processes):
                _subprocess_call(local_rank, testing, predicting)
                # Starting all processes at once can cause issues with
                # dataloaders, so stagger launches by a random 1-5 second delay.
                sleep(np.random.uniform(1, 5))
else:  # pragma: no cover
    from pytorch_lightning.plugins.training_type.ddp import DDPPlugin  # type: ignore
    # Legacy implementation for PyTorch Lightning < 1.6, built on `DDPPlugin`.
    class HydraDDP(DDPPlugin):
        """DDP Strategy that supports Hydra run and multirun jobs.

        This strategy assumes a PyTorch Lightning ``Trainer.fit``, ``Trainer.test``,
        or ``Trainer.predict`` call has been configured to execute via Hydra. It
        requires that Hydra saves a ``config.yaml`` for the job (in
        ``<working directory>/<hydra.output_subdir>``, or in the working directory
        itself when ``hydra.output_subdir`` is null) with the following
        keys/properties set::

           ├── Config
           │    ├── trainer: A `pytorch_lightning.Trainer` configuration
           │    ├── module: A `pytorch_lightning.LightningModule` configuration
           │    ├── datamodule: [OPTIONAL] A `pytorch_lightning.LightningDataModule` configuration

        This strategy launches one child subprocess for each additional process
        (local rank) beyond the first, using the following base command::

           python -m rai_toolbox.mushin.lightning._pl_main -cp <directory containing config.yaml> -cn config.yaml

        Examples
        --------
        First, define a Hydra configuration using hydra-zen:

        >>> import pytorch_lightning as pl
        >>> from hydra_zen import builds, make_config
        >>> from rai_toolbox.mushin import HydraDDP
        >>> from rai_toolbox.mushin.testing.lightning import SimpleLightningModule
        >>> TrainerConfig = builds(
        ...     pl.Trainer,
        ...     accelerator="auto",
        ...     gpus=2,
        ...     max_epochs=1,
        ...     fast_dev_run=True,
        ...     strategy=builds(HydraDDP),
        ...     populate_full_signature=True,
        ... )
        >>> ModuleConfig = builds(SimpleLightningModule)
        >>> Config = make_config(trainer=TrainerConfig, module=ModuleConfig)

        Next, define a task function to execute the Hydra job:

        >>> from hydra_zen import instantiate
        >>> def task_function(cfg):
        ...     obj = instantiate(cfg)
        ...     obj.trainer.fit(obj.module)

        Launch the Hydra+Lightning DDP job:

        >>> from hydra_zen import launch
        >>> job = launch(Config, task_function)

        ``HydraDDP`` also supports ``LightningDataModule`` configuration:

        >>> DataModuleConfig = ...  # A LightningDataModule config
        >>> Config = make_config(
        ...     trainer=TrainerConfig,
        ...     module=ModuleConfig,
        ...     datamodule=DataModuleConfig,
        ... )

        Next, define a task function to execute the Hydra job:

        >>> from hydra_zen import instantiate
        >>> def task_function(cfg):
        ...     obj = instantiate(cfg)
        ...     obj.trainer.fit(obj.module, datamodule=obj.datamodule)

        Launch the Hydra+Lightning DDP job:

        >>> from hydra_zen import launch
        >>> job = launch(Config, task_function)
        """

        def setup_environment(self) -> None:
            """Destroy any already-initialized process group, then run the parent setup."""
            _setup_environment()
            super().setup_environment()

        def _call_children_scripts(self):
            """Export the DDP environment variables and spawn one child per extra local rank.

            Raises
            ------
            TypeError
                If the lightning module, its trainer, or the cluster environment is unset.
            """
            if self.lightning_module is None:  # pragma: no cover
                raise TypeError("HydraDDP.lightning_module is None")
            if self.lightning_module.trainer is None:  # pragma: no cover
                raise TypeError("HydraDDP.lightning_module.trainer is None")
            if self.cluster_environment is None:  # pragma: no cover
                raise TypeError("HydraDDP.cluster_environment is None")

            # bookkeeping of spawned processes
            self._check_can_spawn_children()

            # DDP environment variables (set in `os.environ`, which the children
            # copy when they are spawned)
            os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
            os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
            # allow the user to pass the node rank
            os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
            os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())
            os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}"

            # Left empty: this implementation does not record child process handles.
            self.interactive_ddp_procs = []

            # The Trainer entry point does not change while the children are being
            # launched, so determine the flags once rather than per child.
            trainer_fn = self.lightning_module.trainer.state.fn
            testing = trainer_fn == TrainerFn.TESTING
            predicting = trainer_fn == TrainerFn.PREDICTING

            for local_rank in range(1, self.num_processes):
                _subprocess_call(local_rank, testing=testing, predicting=predicting)
                # Starting all processes at once can cause issues with
                # dataloaders, so stagger launches by a random 1-5 second delay.
                sleep(np.random.uniform(1, 5))

            self._rank_0_has_called_call_children_scripts = True

        def teardown(self) -> None:
            """Performs additional teardown steps for PL to allow for Hydra multirun jobs."""
            super().teardown()
            # Clear the DDP environment variables so the next multirun job starts fresh.
            _teardown()