# Copyright (c) 2024 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
import warnings
from collections import UserList
from dataclasses import fields, is_dataclass
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
from hydra import initialize
from hydra._internal.callbacks import Callbacks
from hydra.core.config_store import ConfigStore
from hydra.core.global_hydra import GlobalHydra
from hydra.core.utils import JobReturn, run_job
from hydra.plugins.sweeper import Sweeper
from hydra.types import HydraContext, RunMode
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing_extensions import Literal, TypeAlias
from hydra_zen._hydra_overloads import instantiate
from hydra_zen.typing._implementations import DataClass_, InstOrType
T = TypeVar("T", bound=Any)
HydraPrimitives: TypeAlias = Union[None, int, float, bool, str, Dict[str, str]]
if TYPE_CHECKING: # pragma: no cover
# branching needed to deal with pyright type-completeness complaints
TUserList: TypeAlias = UserList[Any]
else:
TUserList = UserList
class _NotSet: # pragma: no cover
pass
T1 = TypeVar("T1", bound=HydraPrimitives)
class hydra_list(TUserList, Generic[T1]):
"""Signals that a sequence is provided as a single configured value (i.e. it is not
to be iterated over during a multirun)"""
pass
T2 = TypeVar("T2", bound=Union[HydraPrimitives, hydra_list[HydraPrimitives]])
class multirun(TUserList, Generic[T2]):
"""Signals that a sequence is to be iterated over in a multirun"""
pass
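# A rough illustration of the distinction (see `_process_dict_overrides` below):
#   overrides={"a": multirun([1, 2, 3])}   -> "a=1,2,3"      (three jobs in a multirun)
#   overrides={"a": hydra_list([1, 2, 3])} -> "a=[1, 2, 3]"  (one job; `a` is the whole list)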
def _safe_name(x: Any) -> str:
return getattr(x, "__name__", str(x))
def value_check(
name: str,
value: T,
type_: Union[type, Tuple[type, ...]],
) -> T:
"""
For internal use only.
Used to check that `value` is an instance of `type_`.
Examples
--------
>>> value_check("x", 1, type_=str)
Traceback (most recent call last):
...
TypeError: `x` must be of type(s) `str`, got 1 (type: int)
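
A passing check returns the value unchanged:

>>> value_check("x", 1, type_=int)
1
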
Raises
------
TypeError"""
# check internal params
assert isinstance(name, str), name
if not isinstance(value, type_):
raise TypeError(
f"`{name}` must be of type(s) "
f"`{_safe_name(type_)}`, got {value} (type: {_safe_name(type(value))})"
)
return cast(T, value)
OverrideValues: TypeAlias = Union[
HydraPrimitives,
multirun[Union[HydraPrimitives, hydra_list[HydraPrimitives]]],
hydra_list[HydraPrimitives],
]
OverrideDict: TypeAlias = Mapping[str, OverrideValues]
def _process_dict_overrides(overrides: OverrideDict) -> List[str]:
"""Convert dict overrides to a list of Hydra CLI compatible args"""
launch_overrides = []
for k, v in overrides.items():
if v is None:
v = "null"
value_check(
k,
v,
type_=(int, float, bool, str, dict, multirun, hydra_list),
)
if isinstance(v, multirun):
v = ",".join(str(item) for item in v)
launch_overrides.append(f"{k}={v}")
return launch_overrides
def _store_config(
cfg: Union[DataClass_, Type[DataClass_], DictConfig, ListConfig, Mapping[Any, Any]],
config_name: str = "hydra_launch",
) -> str:
"""Stores configuration object in Hydra's ConfigStore.
Parameters
----------
cfg : Union[DataClass_, Type[DataClass_], DictConfig, ListConfig, Mapping[Any, Any]]
A configuration as a dataclass, configuration object, or a dictionary.
config_name : str (default: hydra_launch)
The configuration name used to store the configuration.
Returns
-------
config_name : str
The name under which the configuration was stored.
Notes
-----
The input configuration is registered in the Hydra ConfigStore [1]_ using a
user-provided config name.
References
----------
.. [1] https://hydra.cc/docs/tutorials/structured_config/config_store
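Examples
--------
A minimal sketch (assuming `hydra_zen.make_config` is available):

>>> from hydra_zen import make_config
>>> _store_config(make_config(a=1), config_name="my_conf")
'my_conf'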
"""
cs = ConfigStore().instance()
cs.store(name=config_name, node=cfg)
return config_name
@overload
def launch(
config: Union[InstOrType[DataClass_], Mapping[str, Any]],
task_function: Callable[[Any], Any],
overrides: Optional[Union[OverrideDict, List[str]]] = ...,
multirun: Literal[False] = ...,
version_base: Optional[Union[str, Type[_NotSet]]] = ...,
to_dictconfig: bool = ...,
config_name: str = ...,
job_name: str = ...,
with_log_configuration: bool = ...,
**override_kwargs: OverrideValues,
) -> JobReturn: ...
@overload
def launch(
config: Union[InstOrType[DataClass_], Mapping[str, Any]],
task_function: Callable[[Any], Any],
overrides: Optional[Union[OverrideDict, List[str]]] = ...,
multirun: Literal[True] = ...,
version_base: Optional[Union[str, Type[_NotSet]]] = ...,
to_dictconfig: bool = ...,
config_name: str = ...,
job_name: str = ...,
with_log_configuration: bool = ...,
**override_kwargs: OverrideValues,
) -> Any: ...
def launch(
config: Union[InstOrType[DataClass_], Mapping[str, Any]],
task_function: Callable[[Any], Any],
overrides: Optional[Union[OverrideDict, List[str]]] = None,
multirun: bool = False,
version_base: Optional[Union[str, Type[_NotSet]]] = _NotSet,
to_dictconfig: bool = False,
config_name: str = "zen_launch",
job_name: str = "zen_launch",
with_log_configuration: bool = True,
**override_kwargs: OverrideValues,
) -> Union[JobReturn, Any]:
r"""
Launches a Hydra job from a Python function rather than a CLI.
`launch` is designed to closely match the interface of the standard Hydra CLI.
For example, launching a Hydra job from the CLI via::
$ python my_task.py job/group=group_name job.group.param=1
corresponds to the following usage of `launch`:
>>> job = launch(config, task_function, overrides=["job/group=group_name", "job.group.param=1"])
Parameters
----------
config : DataClass_ | Type[DataClass_] | Mapping[str, Any]
A config that will be passed to ``task_function``.
task_function : Callable[[DictConfig], Any]
The function that Hydra will execute. Its input will be ``config``, which
has been modified via the specified ``overrides``
overrides : Optional[Union[OverrideMapping, List[str]]] (default: None)
If provided, sets/overrides values in ``config``. See [1]_ and [2]_
for a detailed discussion of the "grammar" supported by ``overrides``.
multirun : bool (default: False)
Launch a Hydra multi-run ([3]_).
version_base : Optional[str], optional (default=not-specified)
Available starting with Hydra 1.2.0.
- If the `version_base` parameter is not specified, Hydra 1.x will use defaults compatible with version 1.1. In this case a warning is also issued, indicating that an explicit `version_base` is preferred.
- If the `version_base` parameter is `None`, then the defaults are chosen for the current minor Hydra version. For example, for Hydra 1.2 this implies `config_path=None` and `hydra.job.chdir=False`.
- If the `version_base` parameter is an explicit version string like "1.1", then the defaults appropriate to that version are used.
to_dictconfig : bool (default: False)
If ``True``, convert a ``dataclasses.dataclass`` config to an ``omegaconf.DictConfig``.
Note that this removes Hydra's ability to perform validation against structured
configs.
config_name : str (default: "zen_launch")
Name of the stored configuration in Hydra's ConfigStore API.
job_name : str (default: "zen_launch")
Name of the Hydra job.
with_log_configuration : bool (default: True)
If ``True``, enables the configuration of the logging subsystem from the loaded
config.
**override_kwargs : OverrideValues
Keyword arguments used to override existing configuration values. Note that this only
works when the name of the configuration value is a valid Python identifier; e.g., it
does not support Hydra's append syntax (`+param`).
Returns
-------
result : hydra.core.utils.JobReturn | Any
If ``multirun is False``:
A ``JobReturn`` object storing the results of the Hydra experiment via the
following attributes
- ``cfg``: Reflects ``config``
- ``overrides``: Reflects ``overrides``
- ``return_value``: The return value of the task function
- ``hydra_cfg``: The Hydra configuration object
- ``working_dir``: The experiment working directory
- ``task_name``: The task name of the Hydra job
- ``status``: A ``JobStatus`` enum reporting whether or not the job completed successfully
Else:
Return values of all launched jobs (depends on the Sweeper implementation).
References
----------
.. [1] https://hydra.cc/docs/advanced/override_grammar/basic
.. [2] https://hydra.cc/docs/configure_hydra/intro
.. [3] https://hydra.cc/docs/tutorials/basic/running_your_app/multi-run
Examples
--------
**Basic usage**
Let's define and launch a trivial Hydra app.
>>> from hydra_zen import make_config, launch, to_yaml
First, we will define a config, which determines the configurable interface to our
"app". For the purpose of example, we'll design the "interface" of this config to accept
two configurable parameters: ``a`` and ``b``.
>>> Conf = make_config("a", "b")
Our task function accepts the config as an input and uses it to run some generic functionality.
For simplicity's sake, let's design this task function to: convert the job's config to a
yaml-formatted string, print it, and then return the string.
>>> def task_fn(cfg):
... out = to_yaml(cfg) # task's input config, converted to yaml-string
... print(out)
... return out
Now, let's use `launch` to run this task function via Hydra, using particular configured
values (or, "overrides") for ``a`` and ``b``.
>>> job_out = launch(Conf, task_fn, a=1, b='foo')
a: 1
b: foo
Let's inspect ``job_out`` to see the ways that it summarizes the results of this job.
>>> job_out.return_value # the value returned by `task_fn`
'a: 1\nb: foo\n'
>>> job_out.working_dir # where the job's outputs, logs, and configs are saved
'outputs/2021-10-19/15-27-11'
>>> job_out.cfg # the particular config used to run our task-function
{'a': 1, 'b': 'foo'}
>>> job_out.overrides # the overrides that we provided
['a=1', "b='foo'"]
>>> job_out.status # the job's completion status
<JobStatus.COMPLETED: 1>
**Launching a multirun job**
We can launch multiple runs of our task-function, using various configured values.
Let's launch a multirun that sweeps over three configurations of ``a``.
>>> (outputs,) = launch(
... Conf,
... task_fn,
... a="1,2,3",
... b="bar",
... multirun=True,
... )
[2021-10-19 17:50:07,334][HYDRA] Launching 3 jobs locally
[2021-10-19 17:50:07,334][HYDRA] #0 : a=1 b='bar'
a: 1
b: bar
[2021-10-19 17:50:07,434][HYDRA] #1 : a=2 b='bar'
a: 2
b: bar
[2021-10-19 17:50:07,535][HYDRA] #2 : a=3 b='bar'
a: 3
b: bar
``outputs`` contains the three corresponding ``JobReturn`` instances.
>>> len(outputs)
3
>>> [j.cfg for j in outputs]
[{'a': 1, 'b': 'bar'}, {'a': 2, 'b': 'bar'}, {'a': 3, 'b': 'bar'}]
Each run's outputs, logs, and configs are saved to separate working directories
>>> [j.working_dir for j in outputs]
['multirun/2021-10-19/17-50-07\\0',
'multirun/2021-10-19/17-50-07\\1',
'multirun/2021-10-19/17-50-07\\2']
**Launching with quoted overrides**
Some of the Hydra CLI override syntax cannot be specified as keyword arguments. In such cases we can instead provide a list or a dict with quoted overrides.
>>> job_out = launch(Conf, task_fn, a=1, b="foo", overrides={"+c": 22})
a: 1
b: foo
c: 22
>>> job_out.overrides # the overrides that we provided
['a=1', 'b=foo', '+c=22']
>>> launch(Conf, task_fn, overrides=["a=1", "b='foo'", "+c=22"])
a: 1
b: foo
c: 22
**Using hydra_zen.multirun**
Multi-run values can be specified directly, without having to form a quoted multi-run string, by storing the values in a `hydra_zen.multirun` list.
>>> import random
>>> from hydra_zen import launch, instantiate, make_config, multirun
>>> values_for_experiment = [random.uniform(0, 1) for i in range(10)]
>>> jobs = launch(
... make_config(),
... instantiate,
... overrides={
... "+param": multirun(values_for_experiment)
... },
... multirun=True
... )
If, instead, you want to configure a list as a single value - not to be iterated
over in a multirun - use `hydra_zen.hydra_list`.
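For example, a minimal sketch mirroring the multirun example above:

>>> from hydra_zen import hydra_list
>>> job = launch(
...     make_config(),
...     instantiate,
...     overrides={"+param": hydra_list([1, 2, 3])},  # one job; param=[1, 2, 3]
... )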
"""
# used for check below
_num_dataclass_fields = 0
if is_dataclass(config):
_num_dataclass_fields = len(fields(config))
# store config in ConfigStore
if to_dictconfig and is_dataclass(config):
# convert Dataclass to a DictConfig
dictconfig = OmegaConf.create(
OmegaConf.to_container(OmegaConf.structured(config))
)
config_name = _store_config(dictconfig, config_name)
else:
config_name = _store_config(config, config_name)
# allow user to provide a dictionary of override values
# instead of just a list of strings
overrides = overrides if overrides is not None else []
if isinstance(overrides, Mapping):
overrides = _process_dict_overrides(overrides)
override_kwargs_list = _process_dict_overrides(override_kwargs)
overrides += override_kwargs_list
# Initialize Hydra and add the config_path to the config search path
with initialize(
config_path=None,
job_name=job_name,
**({} if version_base is _NotSet else {"version_base": version_base}), # type: ignore
):
# taken from hydra.compose with support for MULTIRUN
gh = GlobalHydra.instance()
assert gh.hydra is not None
# Load configuration
cfg = gh.hydra.compose_config(
config_name=config_name,
overrides=overrides,
run_mode=RunMode.RUN if not multirun else RunMode.MULTIRUN,
from_shell=False,
with_log_configuration=with_log_configuration,
)
callbacks = Callbacks(cfg)
run_start = (
callbacks.on_run_start if not multirun else callbacks.on_multirun_start
)
run_start(config=cfg, config_name=config_name)
hydra_context = HydraContext(
config_loader=gh.config_loader(), callbacks=callbacks
)
if not multirun:
job = run_job(
hydra_context=hydra_context,
task_function=task_function,
config=cfg,
job_dir_key="hydra.run.dir",
job_subdir_key=None,
configure_logging=with_log_configuration,
)
callbacks.on_run_end(config=cfg, config_name=config_name, job_return=job)
# access the result to trigger an exception in case the job failed.
_ = job.return_value
else:
# Instantiate sweeper without using Hydra's Plugin discovery (Zen!)
sweeper = instantiate(cfg.hydra.sweeper)
assert isinstance(sweeper, Sweeper)
sweeper.setup(
config=cfg,
hydra_context=hydra_context,
task_function=task_function,
)
task_overrides = OmegaConf.to_container(
cfg.hydra.overrides.task, resolve=False
)
assert isinstance(task_overrides, list)
job = sweeper.sweep(arguments=task_overrides)
callbacks.on_multirun_end(config=cfg, config_name=config_name)
if is_dataclass(config): # pragma: no cover
_num_dataclass_fields_after = len(fields(config))
if (
_num_dataclass_fields_after == 0
and _num_dataclass_fields_after < _num_dataclass_fields
):
warnings.warn(
"Your dataclass-based config was mutated by this run. If you just "
"executed with a `hydra/launcher` that utilizes cloudpickle (e.g., "
"hydra-submitit-launcher), there is a known issue with dataclasses "
"(see: https://github.com/cloudpipe/cloudpickle/issues/386). You will "
"have to restart your interactive environment to run `launch` again. "
"To avoid this issue you can use the `launch` option: "
"`to_dictconfig=True`."
)
return job