# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
# SPDX-License-Identifier: MIT
from collections import defaultdict
from inspect import getattr_static
from pathlib import Path
from typing import (
Any,
Callable,
DefaultDict,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
import numpy as np
import torch as tr
from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.core.utils import JobReturn
from hydra_zen import hydra_list, launch, load_from_yaml, make_config, multirun, zen
from hydra_zen._compatibility import HYDRA_VERSION
from hydra_zen._launch import _NotSet
from typing_extensions import Self, TypeAlias, TypeGuard
from rai_toolbox._utils import value_check
LoadedValue: TypeAlias = Union[str, int, float, bool, List[Any], Dict[str, Any]]
__all__ = [
"BaseWorkflow",
"RobustnessCurve",
"MultiRunMetricsWorkflow",
]
T = TypeVar("T", List[Any], Tuple[Any, ...])
T1 = TypeVar("T1")
_VERSION_BASE_DEFAULT = _NotSet if HYDRA_VERSION < (1, 2, 0) else "1.1"
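# Illustrative: _sort_x_by_k(["b", "a"], [1, 0]) returns ["a", "b"]; the elements of
# `x` are reordered according to the ascending sort order of the keys in `k`.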
def _sort_x_by_k(x: T, k: Iterable[Any]) -> T:
k = tuple(k)
assert len(x) == len(k)
sorted_, _ = zip(*sorted(zip(x, k), key=lambda x: x[1]))
return type(x)(sorted_)
def _identity(x: T1) -> T1:
return x
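# Composes `pre_task` and `task` into a single task function: `pre_task(cfg)` is
# called for its side effects (e.g. seeding) before `task(cfg)`'s result is returned.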
def _task_calls(
pre_task: Callable[[Any], None], task: Callable[[Any], T1]
) -> Callable[[Any], T1]:
def wrapped(cfg: Any) -> T1:
pre_task(cfg)
return task(cfg)
return wrapped
class BaseWorkflow:
"""Provides an interface for creating a reusable workflow: encapsulated
"boilerplate" for running, aggregating, and analyzing one or more Hydra jobs.
Attributes
----------
cfgs : List[Any]
List of configurations for each Hydra job.
metrics : Dict[str, List[Any]]
Dictionary of metrics aggregated across all jobs.
workflow_overrides : Dict[str, Any]
Workflow parameters defined as additional arguments to `run`.
jobs : List[Any]
List of jobs returned for each experiment within the workflow.
working_dir : Optional[pathlib.Path]
The working directory of the experiment defined by Hydra's sweep directory
(`hydra.sweep.dir`).
"""
_REQUIRED_STATIC_METHODS = ("task", "pre_task")
cfgs: List[Any]
metrics: Dict[str, List[Any]]
workflow_overrides: Dict[str, Any]
jobs: Union[List[JobReturn], List[Any], JobReturn]
def __init__(self, eval_task_cfg=None) -> None:
"""Workflows and experiments using Hydra.
Parameters
----------
eval_task_cfg: Mapping | None (default: None)
The workflow configuration object.
"""
# we can do validation checks here
self.eval_task_cfg = (
eval_task_cfg if eval_task_cfg is not None else make_config()
)
# initialize attributes
self.cfgs = []
self.metrics = {}
self.workflow_overrides = {}
self._multirun_task_overrides = {}
self.jobs = []
self._working_dir = None
@property
def working_dir(self) -> Path:
if self._working_dir is None:
raise ValueError("`self.working_dir` must be set.")
return self._working_dir
@working_dir.setter
def working_dir(self, path: Union[str, Path]):
if isinstance(path, str):
path = Path(path)
value_check("path", path, type_=Path)
path = path.resolve()
if not path.is_dir():
raise FileNotFoundError(
f"`path` point to an existing directory, got {path}"
)
self._working_dir = path
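# Roughly, assuming Hydra's override grammar: _parse_overrides(["+foo=1", "bar=a,b"])
# yields {"foo": 1, "bar": multirun(["a", "b"])}; comma-separated sweep values are
# wrapped in `multirun` to distinguish them from plain (non-sweep) values.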
@staticmethod
def _parse_overrides(
overrides,
) -> Dict[str, Union[LoadedValue, Sequence[LoadedValue]]]:
parser = OverridesParser.create()
parsed_overrides = parser.parse_overrides(overrides=overrides)
output = {}
for override in parsed_overrides:
param_name = override.get_key_element()
val = override.value()
if override.is_sweep_override():
val = multirun(val.list) # type: ignore
param_name = param_name.split("+")[-1]
output[param_name] = val
return output
@property
def multirun_task_overrides(
self,
) -> Dict[str, Union[LoadedValue, Sequence[LoadedValue]]]:
"""Returns override param-name -> value.
A sequence of overrides associated with a multirun will
be stored in a `rai_toolbox.mushin.multirun` list. This makes it
possible to distinguish a multirun sweep from an override whose
single value happened to be a list.
Returns
-------
multirun_task_overrides: Dict[str, LoadedValue | Sequence[LoadedValue]]
Examples
--------
>>> from rai_toolbox.mushin import multirun, hydra_list
>>>
>>> class WorkFlow(MultiRunMetricsWorkflow):
... @staticmethod
... def task(*args, **kwargs):
... return None
>>>
>>> wf = WorkFlow()
>>> wf.run(foo=hydra_list(["val"]), bar=multirun(["a", "b"]), apple=1)
>>> wf.multirun_task_overrides
{'foo': ['val'], 'bar': multirun(['a', 'b']), 'apple': 1}
"""
if not self._multirun_task_overrides:
overrides = load_from_yaml(
self.working_dir / "multirun.yaml"
).hydra.overrides.task
output = self._parse_overrides(overrides)
self._multirun_task_overrides = output
return self._multirun_task_overrides
@staticmethod
def pre_task(*args: Any, **kwargs: Any) -> None:
"""Called prior to `task`
This can be useful for doing things like setting random seeds,
which must occur prior to instantiating objects for the evaluation
task.
Notes
-----
This function is automatically wrapped by `zen`, which is responsible
for parsing the function's signature and then extracting and instantiating
the corresponding fields from a Hydra config object – passing them to the
function. This behavior can be modified by `self.run(pre_task_fn_wrapper=...)`
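Examples
--------
A minimal sketch of a workflow whose `pre_task` seeds torch before `task`
runs; the `random_seed` field is illustrative and must be supplied via the
config or as an override to `run`:
>>> class SeededWorkflow(MultiRunMetricsWorkflow):
...     @staticmethod
...     def pre_task(random_seed: int):
...         tr.manual_seed(random_seed)
...     @staticmethod
...     def task():
...         return dict(value=tr.rand(1).item())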
"""
@staticmethod
def task(*args: Any, **kwargs: Any) -> Any:
"""User-defined task that is run by the workflow. This should be
a static method.
Arguments will be instantiated configuration variables. For example,
if the workflow configuration is structured as::
├── eval_task_cfg
│ ├── trainer
| ├── module
| ├── another_config
The inputs to `task` can be any of the three configurations:
`trainer`, `module`, or `another_config`. For example::
@staticmethod
def task(trainer: Trainer, module: LightningModule) -> None:
trainer.fit(module)
Notes
-----
This function is automatically wrapped by `zen`, which is responsible
for parsing the function's signature and then extracting and instantiating
the corresponding fields from a Hydra config object – passing them to the
function. This behavior can be modified by `self.run(task_fn_wrapper=...)`
"""
raise NotImplementedError()
def validate(self, include_pre_task: bool = True):
"""Validates that the configuration will execute with the user-defined evaluation task"""
if include_pre_task:
zen(self.pre_task).validate(self.eval_task_cfg)
zen(self.task).validate(self.eval_task_cfg)
def run(
self,
*,
working_dir: Optional[str] = None,
sweeper: Optional[str] = None,
launcher: Optional[str] = None,
overrides: Optional[List[str]] = None,
task_fn_wrapper: Union[
Callable[[Callable[..., T1]], Callable[[Any], T1]], None
] = zen,
pre_task_fn_wrapper: Union[
Callable[[Callable[..., None]], Callable[[Any], None]], None
] = zen,
version_base: Optional[Union[str, Type[_NotSet]]] = _VERSION_BASE_DEFAULT,
to_dictconfig: bool = False,
config_name: str = "rai_workflow",
job_name: str = "rai_workflow",
with_log_configuration: bool = True,
**workflow_overrides: Union[str, int, float, bool, dict, multirun, hydra_list],
):
"""Run the experiment.
Individual workflows can explicitly define `workflow_overrides` to improve
readability and understanding of what parameters are expected for a particular
workflow.
Parameters
----------
task_fn_wrapper: Callable[[Callable[..., T1]], Callable[[Any], T1]] | None, optional (default=rai_toolbox.mushin.zen)
A wrapper applied to `self.task` prior to launching the task.
The default wrapper is `rai_toolbox.mushin.zen`. Specify `None` for no
wrapper to be applied.
working_dir: str (default: None, the Hydra default will be used)
The directory to run the experiment in. This value is used for
setting `hydra.sweep.dir`.
sweeper: str | None (default: None)
The configuration name of the Hydra Sweeper to use (i.e., the override for
`hydra/sweeper=sweeper`)
launcher: str | None (default: None)
The configuration name of the Hydra Launcher to use (i.e., the override for
`hydra/launcher=launcher`)
overrides: List[str] | None (default: None)
Parameter overrides not considered part of the workflow parameter set.
This is helpful for filtering out parameters stored in
`self.workflow_overrides`.
version_base : Optional[str], optional (default=1.1)
Available starting with Hydra 1.2.0.
- If the `version_base` parameter is not specified, Hydra 1.x will use defaults compatible with version 1.1. In this case, a warning is issued to indicate that an explicit `version_base` is preferred.
- If the `version_base` parameter is `None`, then the defaults are chosen for the current minor Hydra version. For example, for Hydra 1.2 this would imply `config_path=None` and `hydra.job.chdir=False`.
- If the `version_base` parameter is an explicit version string like "1.1", then the defaults appropriate to that version are used.
to_dictconfig: bool (default: False)
If ``True``, convert a ``dataclasses.dataclass`` to an ``omegaconf.DictConfig``. Note that this
removes Hydra's capability for validation with structured configurations.
config_name : str (default: "rai_workflow")
Name of the stored configuration in Hydra's ConfigStore API.
job_name : str (default: "rai_workflow")
Name of job for logging.
with_log_configuration : bool (default: True)
If ``True``, enables the configuration of the logging subsystem from the loaded config.
**workflow_overrides: str | int | float | bool | multirun | hydra_list | dict
These parameters represent the values for configurations to use for the
experiment.
Passing `param=multirun([1, 2, 3])` will perform a multirun over those
three param values, whereas passing `param=hydra_list([1, 2, 3])` will
pass the entire list as a single input.
These values will be appended to the `overrides` for the Hydra job.
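Examples
--------
A sketch of launching a multirun over `epsilon` with a fixed `seed`; the
workflow class and both parameter names are illustrative:
>>> wf = MyWorkflow()  # doctest: +SKIP
>>> wf.run(epsilon=multirun([0.0, 1.0]), seed=0, working_dir="outputs/demo")  # doctest: +SKIP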
"""
launch_overrides = []
if overrides is not None:
launch_overrides.extend(overrides)
if working_dir is not None:
launch_overrides.append(f"hydra.sweep.dir={working_dir}")
if sweeper is not None:
launch_overrides.append(f"hydra/sweeper={sweeper}")
if launcher is not None:
launch_overrides.append(f"hydra/launcher={launcher}")
for k, v in workflow_overrides.items():
value_check(k, v, type_=(int, float, bool, str, dict, multirun, hydra_list))
if isinstance(v, multirun):
v = ",".join(str(item) for item in v)
prefix = ""
if (
not hasattr(self.eval_task_cfg, k)
or getattr(self.eval_task_cfg, k) is None
):
prefix = "+"
launch_overrides.append(f"{prefix}{k}={v}")
for _name in self._REQUIRED_STATIC_METHODS:
if _name == "task" and hasattr(self, "evaluation_task"):
# TODO: remove when evaluation_task support is removed
_name = "evaluation_task"
if not isinstance(getattr_static(self, _name), staticmethod):
raise TypeError(
f"{type(self).__name__}.{_name} must be a static method"
)
if task_fn_wrapper is None:
task_fn_wrapper = _identity
if pre_task_fn_wrapper is None:
pre_task_fn_wrapper = _identity
# Launch a Hydra multirun over the assembled overrides
jobs = launch(
self.eval_task_cfg,
_task_calls(
pre_task=pre_task_fn_wrapper(self.pre_task),
task=task_fn_wrapper(self.task),
),
overrides=launch_overrides,
multirun=True,
version_base=version_base,
to_dictconfig=to_dictconfig,
config_name=config_name,
job_name=job_name,
with_log_configuration=with_log_configuration,
)
if isinstance(jobs, List) and len(jobs) == 1:
# Hydra's multirun launch returns a nested list: [[job_0, job_1, ...]]
jobs = jobs[0]
_job_nums = [j.hydra_cfg.hydra.job.num for j in jobs]
# ensure jobs are always sorted by job-num
jobs = _sort_x_by_k(jobs, _job_nums)
self.jobs = jobs
self.jobs_post_process()
def jobs_post_process(self): # pragma: no cover
"""Method to extract attributes and metrics relevant to the workflow."""
raise NotImplementedError()
def plot(self, **kwargs) -> None: # pragma: no cover
"""Plot workflow metrics."""
raise NotImplementedError()
def to_xarray(self): # pragma: no cover
"""Convert workflow data to xArray Dataset or DataArray."""
raise NotImplementedError()
def _non_str_sequence(x: Any) -> TypeGuard[Sequence[Any]]:
return isinstance(x, Sequence) and not isinstance(x, str)
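# Illustrative: a list of torch tensors (anything implementing `__array__`) is
# converted to a list of numpy arrays; a list of plain floats is returned unchanged.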
def _coerce_list_of_arraylikes(v: List[Any]):
if v and hasattr(v[0], "__array__"):
return [np.asarray(i) for i in v]
return v
class MultiRunMetricsWorkflow(BaseWorkflow):
"""Abstract class for workflows that record metrics using Hydra multirun.
This workflow creates subdirectories of multirun experiments using Hydra. These directories
contain the Hydra YAML configuration and any saved metrics file (defined by the evaluation task)::
├── working_dir
│ ├── <experiment directory name: 0>
│ | ├── <hydra output subdirectory: (default: .hydra)>
| | | ├── config.yaml
| | | ├── hydra.yaml
| | | ├── overrides.yaml
│ | ├── <metrics_filename>
│ ├── <experiment directory name: 1>
| | ...
The evaluation task is expected to return a dictionary that maps
`metric-name (str) -> value (number | Sequence[number])`
Examples
--------
Let's create a simple workflow where we perform a multirun over a parameter,
`epsilon`, and evaluate a task function that computes an accuracy and loss based on
that `epsilon` value and a specified `scale`.
>>> from rai_toolbox.mushin.workflows import MultiRunMetricsWorkflow
>>> from rai_toolbox.mushin import multirun
>>> class LocalRobustness(MultiRunMetricsWorkflow):
... @staticmethod
... def task(epsilon: float, scale: float) -> dict:
... epsilon *= scale
... val = 100 - epsilon**2
... result = dict(accuracies=val+2, loss=epsilon**2)
... tr.save(result, "test_metrics.pt")
... return result
We'll run this workflow for six total configurations of three `epsilon` values and
two `scale` values. This will launch a Hydra multirun job and aggregate the results.
>>> wf = LocalRobustness()
>>> wf.run(epsilon=multirun([1.0, 2.0, 3.0]), scale=multirun([0.1, 1.0]))
[2022-05-02 11:57:59,219][HYDRA] Launching 6 jobs locally
[2022-05-02 11:57:59,220][HYDRA] #0 : +epsilon=1.0 +scale=0.1
[2022-05-02 11:57:59,312][HYDRA] #1 : +epsilon=1.0 +scale=1.0
[2022-05-02 11:57:59,405][HYDRA] #2 : +epsilon=2.0 +scale=0.1
[2022-05-02 11:57:59,498][HYDRA] #3 : +epsilon=2.0 +scale=1.0
[2022-05-02 11:57:59,590][HYDRA] #4 : +epsilon=3.0 +scale=0.1
[2022-05-02 11:57:59,683][HYDRA] #5 : +epsilon=3.0 +scale=1.0
Now that this workflow has run, we can view the results as an xarray-dataset whose
coordinates reflect the multirun parameters that were varied, and whose
data-variables are our recorded metrics: "accuracies" and "loss".
>>> ds = wf.to_xarray()
>>> ds
<xarray.Dataset>
Dimensions: (epsilon: 3, scale: 2)
Coordinates:
* epsilon (epsilon) float64 1.0 2.0 3.0
* scale (scale) float64 0.1 1.0
Data variables:
accuracies (epsilon, scale) float64 102.0 101.0 102.0 98.0 101.9 93.0
loss (epsilon, scale) float64 0.01 1.0 0.04 4.0 0.09 9.0
We can also load this workflow by providing the working directory where it was run.
>>> loaded = LocalRobustness().load_from_dir(wf.working_dir, "test_metrics.pt")
>>> loaded.to_xarray()
<xarray.Dataset>
Dimensions: (epsilon: 3, scale: 2)
Coordinates:
* epsilon (epsilon) float64 1.0 2.0 3.0
* scale (scale) float64 0.1 1.0
Data variables:
accuracies (epsilon, scale) float64 102.0 101.0 102.0 98.0 101.9 93.0
loss (epsilon, scale) float64 0.01 1.0 0.04 4.0 0.09 9.0
"""
def __init__(self, eval_task_cfg=None, working_dir: Optional[Path] = None) -> None:
super().__init__(eval_task_cfg)
self._working_dir = working_dir
if self._working_dir is not None:
self.load_from_dir(self.working_dir, metrics_filename=None)
# TODO: add target_job_dirs example
# Document .swap_dims({"job_dir": <...>}) and .set_index(job_dir=[...]).unstack("job_dir")
# for re-indexing based on overrides values
_JOBDIR_NAME: str = "job_dir"
_target_dir_multirun_overrides: Optional[DefaultDict[str, List[Any]]] = None
output_subdir: Optional[str] = None
# List of all the dirs that the multirun writes to; sorted by job-num
multirun_working_dirs: Optional[List[Path]] = None
@staticmethod
def task(*args: Any, **kwargs: Any) -> Mapping[str, Any]: # pragma: no cover
"""Abstract `staticmethod` for users to define the task that is configured and
launched by the workflow"""
raise NotImplementedError()
@staticmethod
def metric_load_fn(file_path: Path) -> Mapping[str, Any]:
"""Loads a metric file and returns a dictionary of metric-name -> metric-value
mappings.
The default metric load function is `torch.load`.
Parameters
----------
file_path : Path
Returns
-------
named_metrics : Mapping[str, Any]
metric-name -> metric-value(s)
Examples
--------
Designing a workflow that uses the `pickle` module to save and load
metrics
>>> from rai_toolbox.mushin import MultiRunMetricsWorkflow, multirun
>>> import pickle
>>>
>>> class PickledWorkFlow(MultiRunMetricsWorkflow):
... @staticmethod
... def metric_load_fn(file_path: Path):
... with file_path.open("rb") as f:
... return pickle.load(f)
...
... @staticmethod
... def task(a, b):
... with open("./metrics.pkl", "wb") as f:
... pickle.dump(dict(a=a, b=b), f)
>>>
>>> wf = PickledWorkFlow()
>>> wf.run(a=multirun([1, 2, 3]), b=False)
>>> wf.load_metrics("metrics.pkl")
>>> wf.metrics
dict(a=[1, 2, 3], b=[False, False, False])"""
return tr.load(file_path)
def run(
self,
*,
task_fn_wrapper: Union[
Callable[[Callable[..., T1]], Callable[[Any], T1]], None
] = zen,
working_dir: Optional[str] = None,
sweeper: Optional[str] = None,
launcher: Optional[str] = None,
overrides: Optional[List[str]] = None,
version_base: Optional[Union[str, Type[_NotSet]]] = _VERSION_BASE_DEFAULT,
target_job_dirs: Optional[Sequence[Union[str, Path]]] = None,
to_dictconfig: bool = False,
config_name: str = "rai_workflow",
job_name: str = "rai_workflow",
with_log_configuration: bool = True,
**workflow_overrides: Union[str, int, float, bool, dict, multirun, hydra_list],
):
# TODO: add docs
if target_job_dirs is not None:
if isinstance(target_job_dirs, str):
raise TypeError(
f"`target_job_dirs` must be a sequence of pathlike objects, got: {target_job_dirs}"
)
value_check("target_job_dirs", target_job_dirs, type_=Sequence)
target_job_dirs = [Path(s).resolve() for s in target_job_dirs]
for d in target_job_dirs:
if not d.is_dir() or not d.exists():
raise FileNotFoundError(
f"The specified target directory – {d} – does not exist."
)
target_job_dirs = multirun([str(s) for s in target_job_dirs])
workflow_overrides[self._JOBDIR_NAME] = target_job_dirs
return super().run(
task_fn_wrapper=task_fn_wrapper,
working_dir=working_dir,
sweeper=sweeper,
launcher=launcher,
overrides=overrides,
version_base=version_base,
to_dictconfig=to_dictconfig,
config_name=config_name,
job_name=job_name,
with_log_configuration=with_log_configuration,
**workflow_overrides,
)
@property
def target_dir_multirun_overrides(self) -> Dict[str, List[Any]]:
"""
For a multirun that sweeps over the target directories of a
previous multirun, `target_dir_multirun_overrides` provides
the flattened overrides for that previous run.
Examples
--------
>>> class A(MultiRunMetricsWorkflow):
... @staticmethod
... def task(value: float, scale: float):
... pass
...
>>> class B(MultiRunMetricsWorkflow):
... @staticmethod
... def task():
... pass
>>> a = A()
>>> a.run(value=multirun([-1.0, 0.0, 1.0]), scale=multirun([11.0, 9.0]))
[2022-05-13 17:19:51,497][HYDRA] Launching 6 jobs locally
[2022-05-13 17:19:51,497][HYDRA] #0 : +value=-1.0 +scale=11.0
[2022-05-13 17:19:51,555][HYDRA] #1 : +value=-1.0 +scale=9.0
[2022-05-13 17:19:51,729][HYDRA] #2 : +value=1.0 +scale=11.0
[2022-05-13 17:19:51,787][HYDRA] #3 : +value=1.0 +scale=9.0
>>> b = B()
>>> b.run(target_job_dirs=a.multirun_working_dirs)
[2022-05-13 17:19:59,900][HYDRA] Launching 6 jobs locally
[2022-05-13 17:19:59,900][HYDRA] #0 : +job_dir=/home/scratch/multirun/0
[2022-05-13 17:19:59,958][HYDRA] #1 : +job_dir=/home/scratch/multirun/1
[2022-05-13 17:20:00,015][HYDRA] #2 : +job_dir=/home/scratch/multirun/2
[2022-05-13 17:20:00,073][HYDRA] #3 : +job_dir=/home/scratch/multirun/3
>>> b.target_dir_multirun_overrides
{'value': [-1.0, -1.0, 1.0, 1.0],
'scale': [11.0, 9.0, 11.0, 9.0]}"""
if self._target_dir_multirun_overrides is not None:
return dict(self._target_dir_multirun_overrides)
assert self.output_subdir is not None
multirun_cfg = self.working_dir / "multirun.yaml"
self._target_dir_multirun_overrides = defaultdict(list)
overrides = load_from_yaml(multirun_cfg).hydra.overrides.task
self.overrides = overrides
dirs = []
for o in overrides:
k, v = o.split("=")
k = k.replace("+", "")
if k == self._JOBDIR_NAME:
dirs = v.split(",")
break
for d in dirs:
overrides: List[str] = list(
load_from_yaml(Path(d) / f"{self.output_subdir}/overrides.yaml")
)
output = self._parse_overrides(overrides)
for ko, vo in output.items():
self._target_dir_multirun_overrides[ko].append(vo)
return dict(self._target_dir_multirun_overrides)
def jobs_post_process(self):
assert len(self.jobs) > 0
# TODO: Make protocol type for JobReturn
assert isinstance(self.jobs[0], JobReturn)
self.jobs: List[JobReturn]
self.multirun_working_dirs = []
for job in self.jobs:
_hydra_cfg = job.hydra_cfg
assert _hydra_cfg is not None
assert job.working_dir is not None
_cwd = _hydra_cfg.hydra.runtime.cwd
working_dir = Path(_cwd) / job.working_dir
self.multirun_working_dirs.append(working_dir)
# set working directory of this workflow
self.working_dir = self.multirun_working_dirs[0].parent
hydra_cfg = self.jobs[0].hydra_cfg
assert hydra_cfg is not None
self.output_subdir = hydra_cfg.hydra.output_subdir
# extract configs, overrides, and metrics
self.cfgs = [j.cfg for j in self.jobs]
job_metrics = [j.return_value for j in self.jobs]
self.metrics = self._process_metrics(job_metrics)
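# Illustrative: _process_metrics([{"acc": [0.5]}, {"acc": [0.7]}]) yields
# {"acc": [0.5, 0.7]}; single-element lists are unwrapped before aggregation.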
@staticmethod
def _process_metrics(job_metrics: List[Dict[str, Any]]) -> Dict[str, Any]:
metrics = defaultdict(list)
for task_metrics in job_metrics:
if task_metrics is None:
continue
for k, v in task_metrics.items():
# get item if it's a single element array
if isinstance(v, list) and len(v) == 1:
v = v[0]
metrics[k].append(v)
return metrics
def load_from_dir(
self: Self,
working_dir: Union[Path, str],
metrics_filename: Union[str, Sequence[str], None],
) -> Self:
"""Loading workflow job data from a given working directory. The workflow
is loaded in-place and "self" is returned by this method.
Parameters
----------
working_dir: str | Path
The base working directory of the experiment. It is expected
that subdirectories within this working directory will contain
individual Hydra jobs data (yaml configurations) and saved metrics files.
metrics_filename: str | Sequence[str] | None
The filename(s) or glob-pattern(s) used to load the metrics.
If `None`, no metric files are loaded and any metrics already stored in `self.metrics` are kept.
Returns
-------
loaded_workflow : Self
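Examples
--------
A sketch of reloading a completed multirun; the workflow class, path, and
metrics filename are illustrative:
>>> wf = MyWorkflow().load_from_dir("multirun/2022-05-02/11-57-59", "metrics.pt")  # doctest: +SKIP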
"""
self.working_dir = Path(working_dir)
self.output_subdir = load_from_yaml(
self.working_dir / "multirun.yaml"
).hydra.output_subdir
self.multirun_working_dirs = list(
(x.parent for x in self.working_dir.glob(f"**/*/{self.output_subdir}"))
)
# ensure working dirs are sorted by job num
_job_nums = (
load_from_yaml(dir_ / f"{self.output_subdir}/hydra.yaml").hydra.job.num
for dir_ in self.multirun_working_dirs
)
self.multirun_working_dirs = _sort_x_by_k(self.multirun_working_dirs, _job_nums)
self.cfgs = []
for dir_ in self.multirun_working_dirs:
# Ensure we load saved YAML configurations for each job (in hydra.job.output_subdir)
cfg_file = dir_ / f"{self.output_subdir}/config.yaml"
assert cfg_file.exists(), cfg_file
self.cfgs.append(load_from_yaml(cfg_file))
if metrics_filename is not None:
self.load_metrics(metrics_filename)
return self
def load_metrics(
self, metrics_filename: Union[str, Sequence[str]]
) -> Dict[str, List[Any]]:
"""Loads and aggregates across all multirun working dirs, and stores
the metrics in `self.metrics`.
`self.metric_load_fn` is used to load each job's metric file(s).
Parameters
----------
metrics_filename : str | Sequence[str]
The filename(s) or glob-pattern(s) used to load the metrics.
Returns
-------
metrics : Dict[str, List[Any]]
Examples
--------
Creating a workflow that saves named metrics using `torch.save`
>>> from rai_toolbox.mushin.workflows import MultiRunMetricsWorkflow, multirun
>>> import torch as tr
>>>
>>> class TorchWorkFlow(MultiRunMetricsWorkflow):
... @staticmethod
... def task(a, b):
... tr.save(dict(a=a, b=b), "metrics.pt")
...
>>> wf = TorchWorkFlow()
>>> wf.run(a=multirun([1, 2, 3]), b=False)
[2022-06-01 12:35:51,650][HYDRA] Launching 3 jobs locally
[2022-06-01 12:35:51,650][HYDRA] #0 : +a=1 +b=False
[2022-06-01 12:35:51,715][HYDRA] #1 : +a=2 +b=False
[2022-06-01 12:35:51,780][HYDRA] #2 : +a=3 +b=False
`~MultiRunMetricsWorkflow` uses `torch.load` by default to load metrics files
(refer to `~MultiRunMetricsWorkflow.metric_load_fn` to change this behavior).
>>> wf.load_metrics("metrics.pt")
defaultdict(list, {'a': [1, 2, 3], 'b': [False, False, False]})
>>> wf.metrics
defaultdict(list, {'a': [1, 2, 3], 'b': [False, False, False]})
"""
if self.multirun_working_dirs is None:
self.load_from_dir(self.working_dir, metrics_filename=None)
assert self.multirun_working_dirs is not None
if isinstance(metrics_filename, str):
metrics_filename = [metrics_filename]
job_metrics = []
for dir_ in self.multirun_working_dirs:
_metrics = {}
for name in metrics_filename:
files = sorted(dir_.glob(name))
if not files:
raise FileNotFoundError(
f"No files with the path/pattern {dir_/name} were found"
)
for f_ in files:
_metrics.update(self.metric_load_fn(f_))
job_metrics.append(_metrics)
self.metrics = self._process_metrics(job_metrics)
return self.metrics
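# Illustrative: multirun([1, [2, 3]]) -> [1, "[2, 3]"] (the inner list is stringified);
# a non-multirun sequence becomes a single string via str(); scalars pass through unchanged.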
@staticmethod
def _sanitize_coordinate_for_xarray(
value: Union[LoadedValue, Sequence[LoadedValue]]
) -> Union[str, int, float, bool, List[Union[str, int, float, bool]]]:
"""Nested sequences are not permitted for xarray coordinates. This
Returns a list of scalars when `value` is a multi-run or a scalar.
Inner sequences are converted to strings"""
if _non_str_sequence(value):
if isinstance(value, multirun):
_seq: Sequence[LoadedValue] = value
return [str(_v) if _non_str_sequence(_v) else _v for _v in _seq]
return str(value)
return value # type: ignore
def to_xarray(
self,
include_working_subdirs_as_data_var: bool = False,
coord_from_metrics: Optional[str] = None,
non_multirun_params_as_singleton_dims: bool = False,
metrics_filename: Union[str, Sequence[str], None] = None,
):
"""Convert workflow data to xarray Dataset.
Parameters
----------
include_working_subdirs_as_data_var : bool, optional (default=False)
If `True` then the data-variable "working_subdir" will be included in the
xarray. This data variable is used to lookup the working sub-dir path
(a string) by multirun coordinate.
coord_from_metrics : str | None (default: None)
If not `None` defines the metric key to use as a coordinate
in the `Dataset`. This function assumes that this coordinate
represents the leading dimension for all data-variables.
non_multirun_params_as_singleton_dims : bool, optional (default=False)
If `True` then non-multirun entries from `workflow_overrides` will be
included as length-1 dimensions in the xarray. Useful for merging/
concatenation with other Datasets
metrics_filename : str | Sequence[str] | None (default: None)
The filename(s) or glob-pattern(s) used to load the metrics.
If `None`, the metrics stored in `self.metrics` are used.
Returns
-------
results : xarray.Dataset
A dataset whose dimensions and coordinate-values are determined by the
quantities over which the multi-run was performed. The data variables
correspond to the named results returned by the jobs."""
import xarray as xr
if metrics_filename is not None:
if self.multirun_working_dirs is None:
self.load_from_dir(self.working_dir, metrics_filename=metrics_filename)
else:
self.load_metrics(metrics_filename)
# all overrides containing non-multirun lists must be converted to
# strings so that xarray treats that list value as a "scalar"
#
# stores: override-name -> value
# where value is either a scalar (i.e. int|float|bool|str) or a list of scalars
# A list of scalars indicates a multirun
cast_overrides = {
k: self._sanitize_coordinate_for_xarray(value)
for k, value in self.multirun_task_overrides.items()
}
orig_coords = {
k: (v if _non_str_sequence(v) else [v])
for k, v in cast_overrides.items()
if non_multirun_params_as_singleton_dims or _non_str_sequence(v)
}
metric_coords = {}
if coord_from_metrics:
if coord_from_metrics not in self.metrics:
raise ValueError(
f"key `{coord_from_metrics}` not in metrics (available: "
f"{list(self.metrics.keys())})"
)
v = _coerce_list_of_arraylikes(self.metrics[coord_from_metrics])
v = np.asarray(v)
if v.ndim > 1: # pragma: no cover
# assume this coord was repeated across experiments, e.g., "epochs"
v = v[0]
metric_coords[coord_from_metrics] = v
attrs = {k: v for k, v in cast_overrides.items() if not _non_str_sequence(v)}
# we will add additional coordinates as-needed for multi-dim metrics
coords: Dict[str, Any] = orig_coords.copy()
shape = tuple(len(v) for v in coords.values())
metrics_to_add = self.metrics.copy()
if (
include_working_subdirs_as_data_var
and self.multirun_working_dirs is not None
):
metrics_to_add["working_subdir"] = [
str(p) for p in self.multirun_working_dirs
]
data = {}
for k, v in metrics_to_add.items():
if coord_from_metrics and k == coord_from_metrics:
continue
v = _coerce_list_of_arraylikes(v)
datum = np.asarray(v).reshape(shape + np.asarray(v[0]).shape)
k_coords = list(orig_coords)
for n in range(datum.ndim - len(orig_coords)):
if coord_from_metrics and n < len(metric_coords):
# Assume the first coordinate of the metric is the metric coordinate dimension
k_coords += list(metric_coords.keys())
for mk, mv in metric_coords.items():
coords[mk] = mv
else:
# Create additional arbitrary coordinates as-needed for non-scalar
# metrics
k_coords += [f"{k}_dim{n}"]
coords[f"{k}_dim{n}"] = np.arange(datum.shape[len(orig_coords) + n])
data[k] = (k_coords, datum)
coords.update(metric_coords)
out = xr.Dataset(coords=coords, data_vars=data, attrs=attrs)
if self._JOBDIR_NAME in set(out.coords):
exp_dir = out.coords[self._JOBDIR_NAME]
coords = {}
for k, v in self.target_dir_multirun_overrides.items():
if len(v) == len(exp_dir):
if (
len(set(np.unique(v))) > 1
or non_multirun_params_as_singleton_dims
):
coords[k] = (
[self._JOBDIR_NAME],
[self._sanitize_coordinate_for_xarray(item) for item in v],
)
out = out.assign_coords(coords)
return out
class RobustnessCurve(MultiRunMetricsWorkflow):
"""Abstract class for workflows that measure performance for different perturbation
values.
This workflow requires and uses the parameter `epsilon` as the configuration
option for varying the perturbation.
See Also
--------
MultiRunMetricsWorkflow
"""
def run(
self,
*,
epsilon: Union[str, Sequence[float]],
task_fn_wrapper: Union[
Callable[[Callable[..., T1]], Callable[[Any], T1]], None
] = zen,
target_job_dirs: Optional[Sequence[Union[str, Path]]] = None, # TODO: add docs
version_base: Optional[Union[str, Type[_NotSet]]] = _VERSION_BASE_DEFAULT,
working_dir: Optional[str] = None,
sweeper: Optional[str] = None,
launcher: Optional[str] = None,
overrides: Optional[List[str]] = None,
**workflow_overrides: Union[str, int, float, bool, multirun, hydra_list],
):
"""Run the experiment for varying value `epsilon`.
Parameters
----------
epsilon: str | Sequence[float]
The configuration parameter for the perturbation. Unlike Hydra overrides,
this parameter can be a list of floats that will be converted into a
multirun sequence override for Hydra.
task_fn_wrapper: Callable[[Callable[..., T1]], Callable[[Any], T1]] | None, optional (default=rai_toolbox.mushin.zen)
A wrapper applied to `self.task` prior to launching the task.
The default wrapper is `rai_toolbox.mushin.zen`. Specify `None` for no
wrapper to be applied.
working_dir: str (default: None, the Hydra default will be used)
The directory to run the experiment in. This value is used for
setting `hydra.sweep.dir`.
sweeper: str | None (default: None)
The configuration name of the Hydra Sweeper to use (i.e., the override for
`hydra/sweeper=sweeper`)
launcher: str | None (default: None)
The configuration name of the Hydra Launcher to use (i.e., the override for
`hydra/launcher=launcher`)
overrides: List[str] | None (default: None)
Parameter overrides not considered part of the workflow parameter set.
This is helpful for filtering out parameters stored in
`self.workflow_overrides`.
**workflow_overrides: dict | str | int | float | bool | multirun | hydra_list
These parameters represent the values for configurations to use for the
experiment.
These values will be appended to the `overrides` for the Hydra job.
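Examples
--------
A sketch of sweeping three perturbation sizes; the workflow class and the
`dataset` override are illustrative:
>>> wf = MyRobustnessCurve()  # doctest: +SKIP
>>> wf.run(epsilon=[0.0, 0.5, 1.0], dataset="cifar10")  # doctest: +SKIP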
"""
if not isinstance(epsilon, str):
epsilon = multirun(epsilon)
return super().run(
task_fn_wrapper=task_fn_wrapper,
working_dir=working_dir,
sweeper=sweeper,
launcher=launcher,
version_base=version_base,
overrides=overrides,
**workflow_overrides,
# for multiple multi-run params, epsilon should be the fastest-varying param;
# i.e. epsilon should be the trailing dim in the multi-dim array of results
target_job_dirs=target_job_dirs,
epsilon=epsilon,
)
def to_xarray(
self,
include_working_subdirs_as_data_var: bool = False,
coord_from_metrics: Optional[str] = None,
non_multirun_params_as_singleton_dims: bool = False,
metrics_filename: Union[str, Sequence[str], None] = None,
):
"""Convert workflow data to xarray Dataset.
Parameters
----------
include_working_subdirs_as_data_var : bool, optional (default=False)
If `True` then the data-variable "working_subdir" will be included in the
xarray. This data variable is used to lookup the working sub-dir path
(a string) by multirun coordinate.
coord_from_metrics : str | None (default: None)
If not `None` defines the metric key to use as a coordinate
in the `Dataset`. This function assumes that this coordinate
represents the leading dimension for all data-variables.
non_multirun_params_as_singleton_dims : bool, optional (default=False)
If `True` then non-multirun entries from `workflow_overrides` will be
included as length-1 dimensions in the xarray. Useful for merging/
concatenation with other Datasets
metrics_filename : str | Sequence[str] | None (default: None)
The filename(s) or glob-pattern(s) used to load the metrics.
If `None`, the metrics stored in `self.metrics` are used.
Returns
-------
results : xarray.Dataset
A dataset whose dimensions and coordinate-values are determined by the
quantities over which the multi-run was performed. The data variables
correspond to the named results returned by the jobs."""
return (
super()
.to_xarray(
include_working_subdirs_as_data_var=include_working_subdirs_as_data_var,
coord_from_metrics=coord_from_metrics,
non_multirun_params_as_singleton_dims=non_multirun_params_as_singleton_dims,
metrics_filename=metrics_filename,
)
.sortby("epsilon")
)
def plot(
self,
metric: str,
ax: Any = None,
group: Optional[str] = None,
save_filename: Optional[str] = None,
non_multirun_params_as_singleton_dims: bool = False,
**kwargs,
) -> Any:
"""Plot metrics versus `epsilon`.
Using the `xarray.Dataset` from `to_xarray`, plot the metrics
against the workflow perturbation parameters.
Parameters
----------
metric: str
The name of the saved metric to plot.
ax: Axes | None (default: None)
If not `None`, the matplotlib.Axes to use for plotting.
group: str | None (default: None)
Needed if other parameters besides `epsilon` were varied.
save_filename: str | None (default: None)
If not `None` save figure to the filename provided.
non_multirun_params_as_singleton_dims : bool, optional (default=False)
If `True` then non-multirun entries from `workflow_overrides` will be
included as length-1 dimensions in the xarray. Useful for merging/
concatenation with other Datasets
**kwargs: Any
Additional arguments passed to `xarray.plot`.
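Examples
--------
A sketch of plotting a recorded metric against `epsilon` and saving the
figure; the metric name is illustrative:
>>> wf.plot("accuracies", save_filename="robustness.png")  # doctest: +SKIP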
"""
import matplotlib.pyplot as plt
if ax is None:
_, ax = plt.subplots()
xdata = self.to_xarray(
non_multirun_params_as_singleton_dims=non_multirun_params_as_singleton_dims
)
if group is None:
plots = xdata[metric].plot.line(x="epsilon", ax=ax, **kwargs)
else:
# TODO: xarray.groupby doesn't support multidimensional grouping
dg = xdata.groupby(group)
plots = [
grp[metric].plot(x="epsilon", label=name, ax=ax, **kwargs)
for name, grp in dg
]
if save_filename is not None:
plt.savefig(save_filename)
return plots