# Source code for hydra_zen._hydra_overloads

# Copyright (c) 2024 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT
"""
Provides annotation overloads for various hydra functions, using the types defined in `hydra_zen.typing`.
This enables tools like IDEs to be more incisive during static analysis and to provide users with additional
context about their code.

E.g.

.. code::

   from hydra_zen import builds, instantiate
   DictConfig = builds(dict, a=1, b=2)  # type: Type[Builds[Type[dict]]]

   # static analysis tools can provide useful type information
   # about the object that is instantiated from the config
   out = instantiate(DictConfig)  # type: dict

"""

# pyright: strict
# pyright: reportPrivateUsage=false

import pathlib
from dataclasses import is_dataclass
from functools import partial, wraps
from typing import (
    IO,
    Any,
    Callable,
    Dict,
    List,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
    overload,
)

from hydra.utils import instantiate as hydra_instantiate
from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf

from .structured_configs._implementations import (
    ConfigComplex,
    ConfigPath,
    DefaultBuilds,
)
from .typing import Builds, Just, Partial
from .typing._implementations import DataClass_, HasTarget, InstOrType, IsPartial

# Names exported by `from ... import *`. `MISSING` is not defined here; it is
# the object imported from omegaconf above and re-exported as-is.
__all__ = ["instantiate", "to_yaml", "save_as_yaml", "load_from_yaml", "MISSING"]


# Generic placeholder for the type produced by instantiating a config; it links
# the `config` annotation to the return annotation in the `instantiate` overloads.
T = TypeVar("T")
# Any callable type. Used for wrappers/decorators annotated as taking a callable
# and returning that same callable type (`Callable[[F], F]`).
F = TypeVar("F", bound=Callable[..., Any])


def _call_target(
    _target_: F,
    _partial_: bool,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    full_key: str,
    *,
    target_wrapper: Callable[[F], F],
) -> Any:  # pragma: no cover
    """Call (or partially apply) a target after passing it through `target_wrapper`.

    Parameters
    ----------
    _target_ : F
        The callable to invoke.

    _partial_ : bool
        If `True`, return ``functools.partial(target, *args, **kwargs)``
        instead of calling the target.

    args : Tuple[Any, ...]
        Positional arguments; these are first normalized, together with
        `kwargs`, by Hydra's `_extract_pos_args`.

    kwargs : Dict[str, Any]
        Keyword arguments for the target.

    full_key : str
        If non-empty, it is appended to any error message as ``full_key: ...``.

    target_wrapper : Callable[[F], F]
        Applied to `_target_` before it is called. When `_target_` is
        `hydra_zen.funcs.zen_processing`, the wrapper is not applied here but
        is instead forwarded via the ``_zen_target_wrapper`` keyword argument.

    Returns
    -------
    Any
        The target's return value, or a `functools.partial` when `_partial_`
        is `True`.

    Raises
    ------
    hydra.errors.InstantiationException
        If collecting the arguments, building the partial, or calling the
        target raises; the original exception is chained as the cause.
    """
    import functools

    from hydra._internal.instantiate._instantiate2 import (
        _convert_target_to_string,
        _extract_pos_args,
    )
    from hydra.errors import InstantiationException
    from omegaconf import OmegaConf

    from hydra_zen.funcs import zen_processing

    def _detach_from_parent(values: Any) -> None:
        # Drop the parent link of every OmegaConf container in `values`.
        for value in values:
            if OmegaConf.is_config(value):
                value._set_parent(None)

    def _instantiation_error(summary: str, cause: Exception) -> Exception:
        # Assemble the error message: summary line, repr of the underlying
        # error, and the config key (when known).
        text = f"{summary}\n{repr(cause)}"
        if full_key:
            text += f"\nfull_key: {full_key}"
        return InstantiationException(text)

    try:
        args, kwargs = _extract_pos_args(args, kwargs)
        # detaching configs from parent.
        # At this time, everything is resolved and the parent link can cause
        # issues when serializing objects in some scenarios.
        _detach_from_parent(args)
        _detach_from_parent(kwargs.values())
    except Exception as err:
        raise _instantiation_error(
            f"Error in collecting args and kwargs for '{_convert_target_to_string(_target_)}':",
            err,
        ) from err

    # Error messages always name the target as given, never the wrapped one.
    unwrapped_target = _target_
    if _target_ is not zen_processing:
        _target_ = target_wrapper(_target_)
    else:
        kwargs["_zen_target_wrapper"] = target_wrapper

    if not _partial_:
        try:
            return _target_(*args, **kwargs)
        except Exception as err:
            raise _instantiation_error(
                f"Error in call to target '{_convert_target_to_string(unwrapped_target)}':",
                err,
            ) from err

    try:
        return functools.partial(_target_, *args, **kwargs)
    except Exception as err:
        raise _instantiation_error(
            f"Error in creating partial({_convert_target_to_string(unwrapped_target)}, ...) object:",
            err,
        ) from err


class _TightBind:  # pragma: no cover
    """Empty private type used only as the `config` annotation of the first `instantiate` overload."""


# Static-typing overloads for `instantiate`. They share one call signature and
# differ only in the `config` annotation and the declared return type.

# `config` annotated with the private `_TightBind` type -> `Any`.
# NOTE(review): presumably a placeholder overload that ordinary arguments never
# match -- confirm the intent before reordering or removing it.
@overload
def instantiate(
    config: _TightBind,
    *args: Any,
    _target_wrapper_: Union[Callable[[F], F], None] = ...,
    **kwargs: Any,
) -> Any: ...


# `ConfigPath` (instance or type) -> `pathlib.Path`.
@overload
def instantiate(
    config: InstOrType[ConfigPath],
    *args: Any,
    _target_wrapper_: Union[Callable[[F], F], None] = ...,
    **kwargs: Any,
) -> pathlib.Path: ...


# `ConfigComplex` (instance or type) -> `complex`.
@overload
def instantiate(
    config: InstOrType[ConfigComplex],
    *args: Any,
    _target_wrapper_: Union[Callable[[F], F], None] = ...,
    **kwargs: Any,
) -> complex: ...


# `Just[T]` config -> the object of type `T` itself.
@overload
def instantiate(
    config: InstOrType[Just[T]],
    *args: Any,
    _target_wrapper_: Union[Callable[[F], F], None] = ...,
    **kwargs: Any,
) -> T: ...


# Partial config targeting a callable that returns `T` -> `Partial[T]`.
@overload
def instantiate(
    config: InstOrType[IsPartial[Callable[..., T]]],
    *args: Any,
    _target_wrapper_: Union[Callable[[F], F], None] = ...,
    **kwargs: Any,
) -> Partial[T]: ...


# `Builds` config targeting a callable that returns `T` -> `T`.
@overload
def instantiate(
    config: InstOrType[Builds[Callable[..., T]]],
    *args: Any,
    _target_wrapper_: Union[Callable[[F], F], None] = ...,
    **kwargs: Any,
) -> T: ...


# Fallback for the remaining accepted config kinds, where no precise return
# type is declared -> `Any`.
@overload
def instantiate(
    config: Union[
        HasTarget,
        ListConfig,
        DictConfig,
        DataClass_,
        Type[DataClass_],
        Dict[Any, Any],
        List[Any],
    ],
    *args: Any,
    _target_wrapper_: Union[Callable[[F], F], None] = ...,
    **kwargs: Any,
) -> Any: ...


def instantiate(
    config: Any,
    *args: Any,
    _target_wrapper_: Union[Callable[[F], F], None] = None,
    **kwargs: Any,
) -> Any:
    """
    Instantiates the target of a targeted config.

    This is an alias of :func:`hydra.utils.instantiate` [1]_.

    By default, `instantiate` will recursively instantiate nested configurations [1]_.

    Parameters
    ----------
    config : Builds[Type[T] | Callable[..., T]]
        The targeted config whose target will be instantiated/called.

    *args: Any
        Override values, specified by-position. Take priority over
        the positional values provided by ``config``.

    **kwargs : Any
        Override values, specified by-name. Take priority over
        the named values provided by ``config``.

    _target_wrapper_ : Callable[[F], F] | None, optional (default=None)
        If specified, this wrapper is applied to _all_ targets during instantiation.
        This can be used to add custom validation/parsing to the config-instantiation
        process.

        I.e., For any target reached during recursive instantiation,
        `_target_wrapper_(target)(*args, **kwargs)` will be called rather than
        `target(*args, **kwargs)`.

    Returns
    -------
    instantiated : T
        The instantiated target. Instantiated using the values provided
        by ``config`` and/or overridden via ``*args`` and ``**kwargs``.

    See Also
    --------
    builds: Returns a config, which describes how to instantiate/call ``<hydra_target>``.
    just: Produces a config that, when instantiated by Hydra, "just" returns the un-instantiated target-object

    Notes
    -----
    This is an alias for ``hydra.utils.instantiate``, but adds additional static type
    information.

    During instantiation, Hydra performs runtime validation of data based on a limited set of
    type-annotations that can be associated with the fields of the provided config [2]_ [3]_.

    Hydra supports a string-based syntax for variable interpolation, which enables configured
    values to be set in a self-referential and dynamic manner [4]_.

    When ``_target_wrapper_`` is provided, Hydra's internal ``_call_target`` function is
    replaced for the duration of the call and is restored afterwards, even if
    instantiation raises.

    References
    ----------
    .. [1] https://hydra.cc/docs/advanced/instantiate_objects/overview
    .. [2] https://omegaconf.readthedocs.io/en/latest/structured_config.html#simple-types
    .. [3] https://omegaconf.readthedocs.io/en/latest/structured_config.html#runtime-type-validation-and-conversion
    .. [4] https://omegaconf.readthedocs.io/en/latest/usage.html#variable-interpolation

    Examples
    --------
    >>> from hydra_zen import builds, instantiate, just

    **Basic Usage**

    Instantiating a config that targets a class/type.

    >>> ConfDict = builds(dict, x=1)  # a targeted config
    >>> instantiate(ConfDict)  # calls `dict(x=1)`
    {'x': 1}

    Instantiating a config that targets a function.

    >>> def f(z): return z
    >>> ConfF = builds(f, z=22)  # a targeted config
    >>> instantiate(ConfF)  # calls `f(z=22)`
    22

    Providing a manual override, via ``instantiate(..., **kwargs)``

    >>> instantiate(ConfF, z='foo')  # calls `f(z='foo')`
    'foo'

    Recursive instantiation through nested configs.

    >>> inner = builds(dict, b="hi")
    >>> outer = builds(dict, a=inner)
    >>> instantiate(outer)  # calls `dict(a=dict(b='hi'))`
    {'a': {'b': 'hi'}}

    **Leveraging Variable Interpolation**

    Hydra provides a powerful language for absolute and relative
    interpolated variables among configs [4]_. Let's make a config
    where multiple fields reference the field ``name`` via absolute
    interpolation.

    >>> from hydra_zen import make_config
    >>> Conf = make_config("name", a="${name}", b=builds(dict, x="${name}"))

    Resolving the interpolation key: ``name``

    >>> instantiate(Conf, name="Jeff")
    {'a': 'Jeff', 'b': {'x': 'Jeff'}, 'name': 'Jeff'}

    **Runtime Data Validation via Hydra**

    >>> def g(x: float): return x  # note the annotation: float
    >>> Conf_g = builds(g, populate_full_signature=True)
    >>> instantiate(Conf_g, x=1.0)
    1.0

    Passing a non-float to ``x`` will produce a validation error upon instantiation

    >>> instantiate(Conf_g, x='hi')
    ValidationError: Value 'hi' could not be converted to Float
        full_key: x
        object_type=Builds_g

    Only a subset of primitive types are supported by Hydra's validation system [2]_.
    See :ref:`data-val` for more general data validation capabilities via hydra-zen.
    """
    if _target_wrapper_ is not None:
        from hydra._internal.instantiate import _instantiate2 as hydra_inst

        # Remember Hydra's own implementation so it can be put back afterwards.
        saved_call_target = hydra_inst._call_target
        wrapped_call_target = cast(
            F, partial(_call_target, target_wrapper=_target_wrapper_)
        )
        try:
            # The swap is on a module attribute, so every target reached during
            # this (recursive) instantiation goes through the wrapper.
            hydra_inst._call_target = wrapped_call_target
            return hydra_instantiate(config, *args, **kwargs)
        finally:
            hydra_inst._call_target = saved_call_target

    # No wrapper requested: plain pass-through to Hydra.
    return hydra_instantiate(config, *args, **kwargs)
def _apply_just(fn: "F") -> "F":
    """Decorate `fn` so that its config argument is always a dataclass.

    The config is the first parameter of `fn`. If the value supplied for it
    is not a dataclass (type or instance), it is replaced by
    ``DefaultBuilds.just(value)`` before `fn` is called.

    The config may be passed positionally, by the name of `fn`'s first
    parameter (e.g. ``config=...`` for a function declared as
    ``def save(config, f)``), or by the legacy keyword ``cfg=...``.

    Parameters
    ----------
    fn : F
        The function to wrap. Its first parameter receives the config.

    Returns
    -------
    F
        The wrapped function, typed as having the same signature as `fn`.
    """
    import inspect

    # Name of the parameter that receives the config, so that it can also be
    # recognised when the caller passes it by keyword.
    try:
        cfg_param = next(iter(inspect.signature(fn).parameters), None)
    except (TypeError, ValueError):  # signature cannot be introspected
        cfg_param = None

    unset = object()  # distinguishes "not supplied" from an explicit `None`

    @wraps(fn)
    def wrapper(cfg: Any = unset, *args: Any, **kwargs: Any) -> Any:
        if cfg is unset and cfg_param is not None and cfg_param in kwargs:
            cfg = kwargs.pop(cfg_param)
        if cfg is unset:
            # No config was supplied at all: defer to `fn` so that it raises
            # its own, accurately-worded TypeError.
            return fn(*args, **kwargs)
        if not is_dataclass(cfg):
            cfg = DefaultBuilds.just(cfg)
        return fn(cfg, *args, **kwargs)

    return cast("F", wrapper)
@_apply_just
def to_yaml(cfg: Any, *, resolve: bool = False, sort_keys: bool = False) -> str:
    """
    Serialize a config as a yaml-formatted string.

    This is an alias of ``omegaconf.OmegaConf.to_yaml``.

    Parameters
    ----------
    cfg : Any
        A valid configuration object, supported either by Hydra or hydra-zen

    resolve : bool, optional (default=False)
        If `True`, interpolated fields in `cfg` will be resolved in the yaml.

    sort_keys : bool, optional (default=False)
        If `True`, the yaml's entries will be alphabetically ordered.

    Returns
    -------
    yaml : str

    See Also
    --------
    save_as_yaml: Save a config to a yaml-format file.
    load_from_yaml: Load a config from a yaml-format file.

    Examples
    --------
    >>> from hydra_zen import builds, make_config, to_yaml

    **Basic usage**

    The yaml of a config with both an un-configured field and a configured field:

    >>> c1 = make_config("a", b=1)
    >>> print(to_yaml(c1))
    a: ???
    b: 1

    The yaml of a targeted config:

    >>> c2 = builds(dict, y=10)
    >>> print(to_yaml(c2))
    _target_: builtins.dict
    'y': 10

    hydra-zen's additional supported types can be specified as well

    >>> print(to_yaml(1+2j))
    real: 1.0
    imag: 2.0
    _target_: builtins.complex

    **Specifying resolve**

    The following is a config with interpolated fields.

    >>> c3 = make_config(a=builds(dict, b="${c}"), c=1)
    >>> print(to_yaml(c3, resolve=False))
    a:
      _target_: builtins.dict
      b: ${c}
    c: 1

    >>> print(to_yaml(c3, resolve=True))
    a:
      _target_: builtins.dict
      b: 1
    c: 1

    **Specifying sort_keys**

    >>> c4 = make_config("b", "a")  # field order: b then a
    >>> print(to_yaml(c4, sort_keys=False))
    b: ???
    a: ???

    >>> print(to_yaml(c4, sort_keys=True))
    a: ???
    b: ???
    """
    yaml_text = OmegaConf.to_yaml(cfg=cfg, resolve=resolve, sort_keys=sort_keys)
    return yaml_text
@_apply_just
def save_as_yaml(
    config: Any,
    f: Union[str, pathlib.Path, IO[Any]],
    resolve: bool = False,
) -> None:
    """
    Save a config to a yaml-format file

    This is an alias of ``omegaconf.OmegaConf.save`` [1]_.

    Parameters
    ----------
    config : Any
        A config object.

    f : str | pathlib.Path | IO[Any]
        The path of the file, or a file object, to be written to.

    resolve : bool, optional (default=False)
        If ``True`` interpolations will be resolved in the config prior to
        serialization [2]_. See Examples section of `to_yaml` for details.

    See Also
    --------
    to_yaml: Serialize a config as a yaml-formatted string.
    load_from_yaml: Load a config from a yaml-format file.

    References
    ----------
    .. [1] https://omegaconf.readthedocs.io/en/2.0_branch/usage.html#save-load-yaml-file
    .. [2] https://omegaconf.readthedocs.io/en/2.0_branch/usage.html#variable-interpolation

    Examples
    --------
    >>> from hydra_zen import make_config, save_as_yaml, load_from_yaml

    **Basic usage**

    >>> Conf = make_config(a=1, b="foo")
    >>> save_as_yaml(Conf, "test.yaml")  # file written to: test.yaml
    >>> load_from_yaml("test.yaml")
    {'a': 1, 'b': 'foo'}
    """
    outcome = OmegaConf.save(config=config, f=f, resolve=resolve)
    return outcome
def load_from_yaml(
    file_: Union[str, pathlib.Path, IO[Any]],
) -> Union[DictConfig, ListConfig]:
    """
    Load a config from a yaml-format file

    This is an alias of ``omegaconf.OmegaConf.load``.

    Parameters
    ----------
    file_ : str | pathlib.Path | IO[Any]
        The path to the yaml-formatted file, or the file object, that the
        config will be loaded from.

    Returns
    -------
    loaded_conf : DictConfig | ListConfig

    See Also
    --------
    save_as_yaml: Save a config to a yaml-format file.
    to_yaml: Serialize a config as a yaml-formatted string.

    References
    ----------
    .. [1] https://omegaconf.readthedocs.io/en/2.0_branch/usage.html#save-load-yaml-file

    Examples
    --------
    >>> from hydra_zen import make_config, save_as_yaml, load_from_yaml

    **Basic usage**

    >>> Conf = make_config(a=1, b="foo")
    >>> save_as_yaml(Conf, "test.yaml")  # file written to: test.yaml
    >>> load_from_yaml("test.yaml")
    {'a': 1, 'b': 'foo'}
    """
    loaded_conf = OmegaConf.load(file_)
    return loaded_conf